49 #include "EST_Token.h"
50 #include "EST_FMatrix.h"
51 #include "EST_multistats.h"
52 #include "EST_Wagon.h"
// Global configuration knobs for the Wagon CART builder.
64 int wgn_min_cluster_size = 50;
67 int wgn_quiet = FALSE;
68 int wgn_verbose = FALSE;
// -1 means no per-vector count/weight field is present in the data.
69 int wgn_count_field = -1;
// Index of the field being predicted.
71 int wgn_predictee = 0;
// Number of partitions tried when building questions on float features.
73 float wgn_float_range_split = 10;
74 float wgn_balance = 0;
// Forward declarations of the static helpers defined later in this file.
79 static float do_summary(
WNode &tree,
WDataSet &ds,ostream *output);
// Per-predictee-type scorers: each tests a built tree against a dataset
// and returns a summary score.
80 static float test_tree_float(
WNode &tree,
WDataSet &ds,ostream *output);
81 static float test_tree_class(
WNode &tree,
WDataSet &ds,ostream *output);
82 static float test_tree_cluster(
WNode &tree,
WDataSet &dataset, ostream *output);
83 static float test_tree_vector(
WNode &tree,
WDataSet &dataset,ostream *output);
84 static float test_tree_trajectory(
WNode &tree,
WDataSet &dataset,ostream *output);
85 static float test_tree_ols(
WNode &tree,
WDataSet &dataset,ostream *output);
// Recursive node splitter used to grow the tree.
86 static int wagon_split(
int margin,
WNode &node);
87 static void construct_binary_ques(
int feat,
WQuestion &test_ques);
// Helper for the stepwise feature-selection search.
92 static WNode *wagon_stepwise_find_next_best(
float &bscore,
int &best_feat);
// EST template-container declarations (and, when INSTANTIATE_TEMPLATES is
// defined, explicit instantiations) for lists/vectors of WVector pointers.
94 Declare_TList_T(
WVector *, WVectorP)
96 Declare_TVector_Base_T(
WVector *,NULL,NULL,WVectorP)
98 #if defined(INSTANTIATE_TEMPLATES)
100 #include "../base_class/EST_TList.cc"
101 #include "../base_class/EST_TVector.cc"
103 Instantiate_TList_T(WVector *, WVectorP)
105 Instantiate_TVector(WVector *)
109 void wgn_load_datadescription(
EST_String fname,LISP ignores)
112 wgn_dataset.load_description(fname,ignores);
113 wgn_test_dataset.load_description(fname,ignores);
// Dataset-loading body (fragment): reads one whitespace-tokenised line per
// data vector, parsing each field according to its declared type.
// NOTE(review): the function header, braces, declarations and loop headers
// are missing from this extract; comments cover only the visible statements.
124 if (ts.
open(fname) == -1)
125 wagon_error(
EST_String(
"unable to open data file \"")+
// One freshly allocated vector per data line, sized to the field count.
133 v =
new WVector(dataset.width());
137 int type = dataset.ftype(i);
// Float-typed (and OLS) fields, plus the count/weight field, use atof.
138 if ((type == wndt_float) ||
139 (type == wndt_ols) ||
140 (wgn_count_field == i))
143 float f = atof(ts.
get().string());
// A bad float is reported (with field name and vector number) and the
// value is stored as 0.0 rather than aborting the load.
148 cout << fname <<
": bad float " << f
150 dataset.feat_name(i) <<
" vector " <<
151 dataset.samples() << endl;
152 v->set_flt_val(i,0.0);
// Binary, cluster, vector and trajectory fields are stored as ints.
155 else if (type == wndt_binary)
156 v->set_int_val(i,atoi(ts.
get().string()));
157 else if (type == wndt_cluster)
158 v->set_int_val(i,atoi(ts.
get().string()));
159 else if (type == wndt_vector)
160 v->set_int_val(i,atoi(ts.
get().string()));
161 else if (type == wndt_trajectory)
167 v->set_int_val(i,atoi(ts.
get().string()));
169 else if (type == wndt_ignore)
// Class-valued fields: map the token to its index in the discrete set;
// unknown values are reported with field name and vector number.
177 int n = wgn_discretes.discrete(type).
name(s);
180 cout << fname <<
": bad value " << s <<
" in field " <<
181 dataset.feat_name(i) <<
" vector " <<
182 dataset.samples() << endl;
// Consume fields until end of line or the declared width is reached.
189 while (!ts.
eoln() && i<dataset.width());
// Too few fields on a line is a fatal error.
191 if (i != dataset.width())
193 wagon_error(fname+
": data vector "+itoString(nvec)+
" contains "
194 +itoString(i)+
" parameters instead of "+
195 itoString(dataset.width()));
// Too many fields on a line is likewise fatal.
199 cerr << fname <<
": data vector " << nvec <<
200 " contains too many parameters instead of "
201 << dataset.width() << endl;
202 wagon_error(
EST_String(
"extra parameter(s) from ")+
// Report what was loaded.
208 cout <<
"Dataset of " << dataset.samples() <<
" vectors of " <<
209 dataset.width() <<
" parameters from: " << fname << endl;
213 float summary_results(
WNode &tree,ostream *output)
215 if (wgn_test_dataset.samples() != 0)
216 return do_summary(tree,wgn_test_dataset,output);
218 return do_summary(tree,wgn_dataset,output);
221 static float do_summary(
WNode &tree,
WDataSet &ds,ostream *output)
223 if (wgn_dataset.ftype(wgn_predictee) == wndt_cluster)
224 return test_tree_cluster(tree,ds,output);
225 else if (wgn_dataset.ftype(wgn_predictee) == wndt_vector)
226 return test_tree_vector(tree,ds,output);
227 else if (wgn_dataset.ftype(wgn_predictee) == wndt_trajectory)
228 return test_tree_trajectory(tree,ds,output);
229 else if (wgn_dataset.ftype(wgn_predictee) == wndt_ols)
230 return test_tree_ols(tree,ds,output);
231 else if (wgn_dataset.ftype(wgn_predictee) >= wndt_class)
232 return test_tree_class(tree,ds,output);
234 return test_tree_float(tree,ds,output);
// Build a tree from the training data and report its score (fragment).
// NOTE(review): local declarations, braces and the return statement are
// missing from this extract.
237 WNode *wgn_build_tree(
float &score)
// Set up the data, reserving the held-out portion, then grow the tree.
243 wgn_set_up_data(top->get_data(),wgn_dataset,wgn_held_out,TRUE);
246 wagon_split(margin,*top);
// If held-out data exists, swap it in and prune the tree against it.
248 if (wgn_held_out > 0)
250 wgn_set_up_data(top->get_data(),wgn_dataset,wgn_held_out,FALSE);
251 top->held_out_prune();
// Final score for the (possibly pruned) tree; no output stream.
257 score = summary_results(*top,0);
// wgn_set_up_data (fragment): walk the dataset; when `in` is TRUE keep
// every vector whose index-mod-100 falls outside the held-out range.
272 for (j=i=0,d=ds.head(); d != 0; d=d->next(),j++)
274 if ((in) && ((j%100) >= held_out))
// Score a classification tree: accumulates accuracy, entropy/perplexity,
// and (for the B_NB_F1 option) precision/recall/F1 on class "B".
// NOTE(review): braces, several declarations and statements are missing
// from this extract; comments cover only the visible lines.
289 static float test_tree_class(
WNode &tree,
WDataSet &dataset,ostream *output)
299 float correct=0,total=0, count=0;
301 float bcorrect=0, bpredicted=0, bactual=0;
302 float precision=0, recall=0;
// Per-vector loop: predict a class and accumulate log-prob weighted by
// the vector's count field (count defaults when wgn_count_field == -1).
304 for (p=dataset.head(); p != 0; p=p->next())
306 pnode = tree.predict_node((*dataset(p)));
307 predict = (
EST_String)pnode->get_impurity().value();
308 if (wgn_count_field == -1)
311 count = dataset(p)->get_flt_val(wgn_count_field);
312 prob = pnode->get_impurity().pd().probability(predict);
// H accumulates count-weighted log probabilities for entropy/perplexity.
313 H += (log(prob))*count;
314 type = dataset.ftype(wgn_predictee);
315 real = wgn_discretes[type].name(dataset(p)->get_int_val(wgn_predictee));
317 if (wgn_opt_param ==
"B_NB_F1")
// Build the label lexicon for the confusion matrix, then print it.
335 for (i=0; i<wgn_discretes[dataset.ftype(wgn_predictee)].length(); i++)
336 lex.
append(wgn_discretes[dataset.ftype(wgn_predictee)].name(i));
342 print_confusion(m,pairs,lex);
343 *output <<
";; entropy " << (-1*(H/total)) <<
" perplexity " <<
344 pow(2.0,(-1*(H/total))) << endl;
// Optimisation-parameter-specific return values (negated so that larger
// is better for the caller).
349 if (wgn_opt_param ==
"entropy")
350 return -pow(2.0,(-1*(H/total)));
351 else if(wgn_opt_param ==
"B_NB_F1")
// F1 computation; the visible guards avoid division when possible
// (precision+recall must be non-zero before fmeasure is computed).
356 precision = bcorrect/bpredicted;
360 recall = bcorrect/bactual;
362 if((precision+recall) !=0)
363 fmeasure = 2* (precision*recall)/(precision+recall);
364 cout<<
"F1 :" << fmeasure <<
" Prec:" << precision <<
" Rec:" << recall <<
" B-Pred:" << bpredicted <<
" B-Actual:" << bactual <<
" B-Correct:" << bcorrect << endl;
// Default: raw classification accuracy.
368 return (
float)correct/(float)total;
// Score a vector-predicting tree: for each sample and each active vertex
// feature, compare the leaf's mean track value against the actual value,
// accumulating RMSE / correlation / mean-absolute-error statistics.
// NOTE(review): braces, declarations and several statements are missing
// from this extract; comments cover only the visible lines.
371 static float test_tree_vector(
WNode &tree,
WDataSet &dataset,ostream *output)
378 float predict, actual;
386 for (p=dataset.head(); p != 0; p=p->next())
388 leaf = tree.predict_node((*dataset(p)));
389 pos = dataset(p)->get_int_val(wgn_predictee);
// Only features enabled in wgn_VertexFeats (> 0.0) are scored.
391 if (wgn_VertexFeats.
a(0,j) > 0.0)
// Accumulate the leaf members' track values into sufficient stats b.
394 for (pp=leaf->get_impurity().members.head(); pp != 0; pp=pp->next())
396 i = leaf->get_impurity().members.
item(pp);
397 b += wgn_VertexTrack.
a(i,j);
400 actual = wgn_VertexTrack.
a(pos,j);
401 if (wgn_count_field == -1)
404 count = dataset(p)->get_flt_val(wgn_count_field);
405 x.cumulate(predict,count);
406 y.cumulate(actual,count);
// Error terms; the stddev-normalised form appears on an alternate path.
409 error = predict-actual;
411 error = (predict-actual)/b.
stddev();
412 error = predict-actual;
413 se.cumulate((error*error),count);
414 e.cumulate(fabs(error),count);
415 xx.cumulate(predict*predict,count);
416 yy.cumulate(actual*actual,count);
417 xy.cumulate(predict*actual,count);
// Summary to the output stream (";;"-prefixed) and to stdout.
442 <<
";; RMSE " << ftoString(sqrt(se.
mean()),4,1)
443 <<
" Correlation is " << ftoString(cor,4,1)
444 <<
" Mean (abs) Error " << ftoString(e.
mean(),4,1)
445 <<
" (" << ftoString(e.
stddev(),4,1) <<
")" << endl;
447 cout <<
"RMSE " << ftoString(sqrt(se.
mean()),4,1)
448 <<
" Correlation is " << ftoString(cor,4,1)
449 <<
" Mean (abs) Error " << ftoString(e.
mean(),4,1)
450 <<
" (" << ftoString(e.
stddev(),4,1) <<
")" << endl;
// When optimising for RMSE, return its negation (larger is better).
453 if (wgn_opt_param ==
"rmse")
454 return -sqrt(se.
mean());
// Score a trajectory-predicting tree; structurally parallel to
// test_tree_vector above (leaf-member track means vs. actual values,
// RMSE / correlation / mean-absolute-error accumulation).
// NOTE(review): braces, declarations and several statements are missing
// from this extract; comments cover only the visible lines.
459 static float test_tree_trajectory(
WNode &tree,
WDataSet &dataset,ostream *output)
467 float predict, actual;
475 for (p=dataset.head(); p != 0; p=p->next())
477 leaf = tree.predict_node((*dataset(p)));
478 pos = dataset(p)->get_int_val(wgn_predictee);
// Only features enabled in wgn_VertexFeats (> 0.0) are scored.
480 if (wgn_VertexFeats.
a(0,j) > 0.0)
483 for (pp=leaf->get_impurity().members.head(); pp != 0; pp=pp->next())
485 i = leaf->get_impurity().members.
item(pp);
486 b += wgn_VertexTrack.
a(i,j);
489 actual = wgn_VertexTrack.
a(pos,j);
490 if (wgn_count_field == -1)
493 count = dataset(p)->get_flt_val(wgn_count_field);
494 x.cumulate(predict,count);
495 y.cumulate(actual,count);
// Error terms; the stddev-normalised form appears on an alternate path.
498 error = predict-actual;
500 error = (predict-actual)/b.
stddev();
501 error = predict-actual;
502 se.cumulate((error*error),count);
503 e.cumulate(fabs(error),count);
504 xx.cumulate(predict*predict,count);
505 yy.cumulate(actual*actual,count);
506 xy.cumulate(predict*actual,count);
// Summary to the output stream (";;"-prefixed) and to stdout.
531 <<
";; RMSE " << ftoString(sqrt(se.
mean()),4,1)
532 <<
" Correlation is " << ftoString(cor,4,1)
533 <<
" Mean (abs) Error " << ftoString(e.
mean(),4,1)
534 <<
" (" << ftoString(e.
stddev(),4,1) <<
")" << endl;
536 cout <<
"RMSE " << ftoString(sqrt(se.
mean()),4,1)
537 <<
" Correlation is " << ftoString(cor,4,1)
538 <<
" Mean (abs) Error " << ftoString(e.
mean(),4,1)
539 <<
" (" << ftoString(e.
stddev(),4,1) <<
")" << endl;
// When optimising for RMSE, return its negation (larger is better).
542 if (wgn_opt_param ==
"rmse")
543 return -sqrt(se.
mean());
// Score a cluster-predicting tree: per-sample cluster distance, membership
// hit-rate and ranking, summarised to the output stream and stdout.
// NOTE(review): braces and the local declarations (leaf, real, counters,
// sufficient-stats objects) are missing from this extract.
548 static float test_tree_cluster(
WNode &tree,
WDataSet &dataset,ostream *output)
557 for (p=dataset.head(); p != 0; p=p->next())
559 leaf = tree.predict_node((*dataset(p)));
560 real = dataset(p)->get_int_val(wgn_predictee);
561 meandist += leaf->get_impurity().cluster_distance(real);
562 right_cluster += leaf->get_impurity().in_cluster(real);
563 ranking += leaf->get_impurity().cluster_ranking(real);
// ";;"-prefixed summary to the output stream, then the same to stdout.
570 *output <<
";; Right cluster " << right_cluster <<
" (" <<
571 (int)(100.0*(
float)right_cluster/(
float)dataset.length()) <<
572 "%) mean ranking " << ranking.mean() << " mean distance "
573 << meandist.mean() << endl;
574 cout << "Right cluster " << right_cluster << " (" <<
575 (
int)(100.0*(
float)right_cluster/(
float)dataset.length()) <<
576 "%) mean ranking " << ranking.mean() << " mean distance "
577 << meandist.mean() << endl;
// Score decreases with mean distance; 10000 is an arbitrary offset so
// that larger return values are better.
580 return 10000-meandist.mean();
// Score a float-predicting (regression) tree: count-weighted RMSE,
// correlation and mean absolute error of prediction vs. actual value.
// NOTE(review): braces, declarations and several statements are missing
// from this extract; comments cover only the visible lines.
583 static
float test_tree_float(
WNode &tree,
WDataSet &dataset,ostream *output)
592 for (p=dataset.head(); p != 0; p=p->next())
594 predict = tree.predict((*dataset(p)));
595 real = dataset(p)->get_flt_val(wgn_predictee);
// count defaults when no count field is declared (wgn_count_field == -1).
596 if (wgn_count_field == -1)
599 count = dataset(p)->get_flt_val(wgn_count_field);
600 x.cumulate(predict,count);
601 y.cumulate(real,count);
602 error = predict-real;
603 se.cumulate((error*error),count);
604 e.cumulate(fabs(error),count);
605 xx.cumulate(predict*predict,count);
606 yy.cumulate(real*real,count);
607 xy.cumulate(predict*real,count);
// Summary to the output stream (";;"-prefixed) and to stdout.
631 <<
";; RMSE " << ftoString(sqrt(se.
mean()),4,1)
632 <<
" Correlation is " << ftoString(cor,4,1)
633 <<
" Mean (abs) Error " << ftoString(e.
mean(),4,1)
634 <<
" (" << ftoString(e.
stddev(),4,1) <<
")" << endl;
636 cout <<
"RMSE " << ftoString(sqrt(se.
mean()),4,1)
637 <<
" Correlation is " << ftoString(cor,4,1)
638 <<
" Mean (abs) Error " << ftoString(e.
mean(),4,1)
639 <<
" (" << ftoString(e.
stddev(),4,1) <<
")" << endl;
// When optimising for RMSE, return its negation (larger is better).
642 if (wgn_opt_param ==
"rmse")
643 return -sqrt(se.
mean());
// Score an OLS (linear-model-in-leaf) tree; same count-weighted RMSE /
// correlation / mean-absolute-error accumulation as test_tree_float.
// NOTE(review): braces, declarations and several statements (including
// the computation of `predict` from the leaf) are missing from this
// extract; comments cover only the visible lines.
648 static float test_tree_ols(
WNode &tree,
WDataSet &dataset,ostream *output)
658 for (p=dataset.head(); p != 0; p=p->next())
660 leaf = tree.predict_node((*dataset(p)));
663 real = dataset(p)->get_flt_val(wgn_predictee);
664 if (wgn_count_field == -1)
667 count = dataset(p)->get_flt_val(wgn_count_field);
668 x.cumulate(predict,count);
669 y.cumulate(real,count);
670 error = predict-real;
671 se.cumulate((error*error),count);
672 e.cumulate(fabs(error),count);
673 xx.cumulate(predict*predict,count);
674 yy.cumulate(real*real,count);
675 xy.cumulate(predict*real,count);
// Summary to the output stream (";;"-prefixed) and to stdout.
699 <<
";; RMSE " << ftoString(sqrt(se.
mean()),4,1)
700 <<
" Correlation is " << ftoString(cor,4,1)
701 <<
" Mean (abs) Error " << ftoString(e.
mean(),4,1)
702 <<
" (" << ftoString(e.
stddev(),4,1) <<
")" << endl;
704 cout <<
"RMSE " << ftoString(sqrt(se.
mean()),4,1)
705 <<
" Correlation is " << ftoString(cor,4,1)
706 <<
" Mean (abs) Error " << ftoString(e.
mean(),4,1)
707 <<
" (" << ftoString(e.
stddev(),4,1) <<
")" << endl;
// When optimising for RMSE, return its negation (larger is better).
710 if (wgn_opt_param ==
"rmse")
711 return -sqrt(se.
mean());
// Recursively split a node: compute its impurity, find the best question,
// and if the question genuinely reduces impurity, split and recurse on
// both children; otherwise report the node as a stopped leaf.
// NOTE(review): braces, declarations (q, l, r, i) and the return paths
// are missing from this extract.
716 static int wagon_split(
int margin,
WNode &node)
722 node.set_impurity(
WImpurity(node.get_data()));
723 q = find_best_question(node.get_data());
729 double impurity_measure = node.get_impurity().measure();
730 double question_score = q.get_score();
// Split only when the question is valid (score below WGN_HUGE_VAL) and
// strictly better than the node's own impurity.
732 if ((question_score < WGN_HUGE_VAL) &&
733 (question_score < impurity_measure))
739 wgn_find_split(q,node.get_data(),l->get_data(),r->get_data());
740 node.set_subnodes(l,r);
741 node.set_question(q);
// margin controls the indentation of the progress trace.
745 for (i=0; i < margin; i++)
750 wagon_split(margin,*l);
752 wagon_split(margin,*r);
// No useful split: trace the stopped node's sample count and impurity.
761 for (i=0; i < margin; i++)
763 cout <<
"stopped samples: " << node.samples() <<
" impurity: "
764 << node.get_impurity() << endl;
// wgn_find_split (fragment): partition the node's data by asking the
// chosen question of every vector.
779 for (iy=in=i=0; i < ds.
n(); i++)
780 if (q.ask(*ds(i)) == TRUE)
// find_best_question (fragment): try a question on every non-ignored
// feature, keeping the one with the best (lowest) score. WGN_HUGE_VAL
// marks "no usable question" for a feature.
// NOTE(review): the function header, braces and the score comparison are
// missing from this extract.
794 bscore = tscore = WGN_HUGE_VAL;
795 best_ques.set_score(bscore);
796 for (i=0;i < wgn_dataset.width(); i++)
// Skip ignored features and the predictee itself.
799 if ((wgn_dataset.ignore(i) == TRUE) ||
800 (i == wgn_predictee))
801 tscore = WGN_HUGE_VAL;
802 else if (wgn_dataset.ftype(i) == wndt_binary)
804 construct_binary_ques(i,test_ques);
805 tscore = wgn_score_question(test_ques,dset);
807 else if (wgn_dataset.ftype(i) == wndt_float)
809 tscore = construct_float_ques(i,test_ques,dset);
811 else if (wgn_dataset.ftype(i) == wndt_ignore)
812 tscore = WGN_HUGE_VAL;
// Subset selection for class features is currently disabled: the
// wagon_error fires before construct_class_ques_subset is reached.
815 else if (wgn_csubset && (wgn_dataset.ftype(i) >= wndt_class))
817 wagon_error(
"subset selection temporarily deleted");
818 tscore = construct_class_ques_subset(i,test_ques,dset);
821 else if (wgn_dataset.ftype(i) >= wndt_class)
822 tscore = construct_class_ques(i,test_ques,dset);
// Record the new best question and its score.
825 best_ques = test_ques;
826 best_ques.set_score(tscore);
// construct_class_ques (fragment): try an "is <class>" question for every
// value of the feature's discrete set, keeping the best score.
// NOTE(review): the function header, braces and the best-score update are
// missing from this extract.
837 float tscore,bscore = WGN_HUGE_VAL;
842 test_q.set_oper(wnop_is);
845 for (cl=0; cl < wgn_discretes[wgn_dataset.ftype(feat)].length(); cl++)
847 test_q.set_operand1(
EST_Val(cl));
848 tscore = wgn_score_question(test_q,ds);
// Build an "in <subset>" question for a class feature: score each class
// singly, sort classes by score, then grow the subset greedily.
// NOTE(review): braces, the trailing parameter, the greedy-growth body
// and the final return are missing from this extract.
860 static float construct_class_ques_subset(
int feat,
WQuestion &ques,
868 float tscore,bscore = WGN_HUGE_VAL;
873 ques.set_oper(wnop_is);
// Score every class value independently first.
874 float *scores =
new float[wgn_discretes[wgn_dataset.ftype(feat)].length()];
877 for (cl=0; cl < wgn_discretes[wgn_dataset.ftype(feat)].length(); cl++)
879 ques.set_operand(flocons(cl));
880 scores[cl] = wgn_score_question(ques,ds);
883 LISP order = sort_class_scores(feat,scores);
// Degenerate case: only one scorable class -> plain "is" question.
886 if (siod_llength(order) == 1)
888 ques.set_oper(wnop_is);
889 ques.set_operand(car(order));
890 return scores[get_c_int(car(order))];
// Otherwise grow an "in"-subset from the sorted class list.
893 ques.set_oper(wnop_in);
895 for (l=cdr(order); CDR(l) != NIL; l = cdr(l))
898 tscore = wgn_score_question(ques,ds);
// Simplify the winning subset back to "is" where possible: a singleton,
// or the complement-of-best-single case.
909 if (siod_llength(best_l) == 1)
911 ques.set_oper(wnop_is);
912 ques.set_operand(car(best_l));
914 else if (equal(cdr(order),best_l) != NIL)
916 ques.set_oper(wnop_is);
917 ques.set_operand(car(order));
921 cout <<
"Found a good subset" << endl;
922 ques.set_operand(best_l);
// Build a LISP list of class indices ordered by ascending score,
// skipping classes whose score is WGN_HUGE_VAL (unscorable).
// NOTE(review): braces, the empty-list seed case and the return are
// missing from this extract; insertion is by linear scan.
928 static LISP sort_class_scores(
int feat,
float *scores)
935 for (i=0; i < wgn_discretes[wgn_dataset.ftype(feat)].length(); i++)
937 if (scores[i] != WGN_HUGE_VAL)
940 items = cons(flocons(i),NIL);
// Insert i before the first existing entry with a larger score.
943 for (l=items; l != NIL; l=cdr(l))
945 if (scores[i] < scores[get_c_int(car(l))])
947 CDR(l) = cons(car(l),cdr(l));
// Larger than everything seen so far: append at the tail.
953 items = l_append(items,cons(flocons(i),NIL));
// construct_float_ques (fragment): find the feature's min/max over the
// data, then try wgn_float_range_split evenly spaced "less-than"
// thresholds, keeping the best-scoring one.
// NOTE(review): the function header, braces, min/max update and the
// best-score bookkeeping are missing from this extract.
965 float tscore,bscore = WGN_HUGE_VAL;
969 float max,min,val,incr;
972 test_q.set_oper(wnop_lessthan);
// Scan the data for the feature's range.
975 min = max = ds(0)->get_flt_val(feat);
976 for (d=0; d < ds.
n(); d++)
978 val = ds(d)->get_flt_val(feat);
986 incr = (max-min)/wgn_float_range_split;
// Try each candidate threshold p in (min, max).
990 for (i=0,p=min+incr; i < wgn_float_range_split; i++,p += incr )
992 test_q.set_operand1(
EST_Val(p));
993 tscore = wgn_score_question(test_q,ds);
1004 static void construct_binary_ques(
int feat,
WQuestion &test_ques)
1010 test_ques.set_fp(feat);
1011 test_ques.set_oper(wnop_binary);
1012 test_ques.set_operand1(
EST_Val(
""));
// score_question_set (fragment): split the data by asking the question of
// every (non-skipped) vector, accumulating predictee statistics into the
// "yes" (y) and "no" (n) sides, then reject unbalanced splits.
// NOTE(review): the function header, braces and several statements are
// missing from this extract; the final visible line belongs to
// wgn_score_question, which delegates here with ignorenth == 1.
1020 int d, num_yes, num_no;
1024 num_yes = num_no = 0;
1027 for (d=0; d < ds.
n(); d++)
// Every ignorenth-th vector is skipped when ignorenth >= 2 (used for
// cross-validation-style subsetting).
1029 if ((ignorenth < 2) ||
1030 (d%ignorenth != ignorenth-1))
1033 if (wgn_count_field == -1)
1036 count = (*wv)[wgn_count_field];
// Route the vector's predictee value to the yes- or no-side stats;
// for OLS trees the vector index d is accumulated instead.
1038 if (q.ask(*wv) == TRUE)
1041 if (wgn_dataset.ftype(wgn_predictee) == wndt_ols)
1042 y.cumulate(d,count);
1044 y.cumulate((*wv)[wgn_predictee],count);
1049 if (wgn_dataset.ftype(wgn_predictee) == wndt_ols)
1050 n.cumulate(d,count);
1052 n.cumulate((*wv)[wgn_predictee],count);
// Minimum-side-size rule: either the fixed wgn_min_cluster_size, or a
// balance-proportional size when wgn_balance demands a larger one.
1062 if ((wgn_balance == 0.0) ||
1063 (ds.
n()/wgn_balance < wgn_min_cluster_size))
1064 min_cluster = wgn_min_cluster_size;
1066 min_cluster = (int)(ds.
n()/wgn_balance);
// Reject splits where either side is too small.
1068 if ((y.samples() < min_cluster) ||
1069 (n.samples() < min_cluster))
1070 return WGN_HUGE_VAL;
// wgn_score_question: score using the full dataset (ignorenth == 1).
1090 return score_question_set(q,ds,1);
// Stepwise feature selection: start with all features ignored, then
// repeatedly un-ignore whichever single feature most improves the tree
// score, stopping when the improvement falls within `limit` percent.
// NOTE(review): braces, declarations and the loop/return structure are
// missing from this extract.
1093 WNode *wagon_stepwise(
float limit)
1101 WNode *best = 0,*new_best = 0;
1102 float bscore,best_score = -WGN_HUGE_VAL;
// Begin with every feature ignored...
1107 for (i=0; i < wgn_dataset.width(); i++)
1108 wgn_dataset.set_ignore(i,TRUE);
// ...except those that can never be selected anyway.
1110 for (i=0; i < wgn_dataset.width(); i++)
1112 if ((wgn_dataset.ftype(i) == wndt_ignore) || (i == wgn_predictee))
1120 new_best = wagon_stepwise_find_next_best(bscore,best_feat);
// Stop when the gain (within a limit-percent tolerance) is no better
// than the current best score.
1122 if ((bscore - fabs(bscore * (limit/100))) <= best_score)
1130 best_score = bscore;
// Permanently enable the winning feature and report it.
1133 wgn_dataset.set_ignore(best_feat,FALSE);
1136 fprintf(stdout,
"FEATURE %d %s: %2.4f\n",
1138 (
const char *)wgn_dataset.feat_name(best_feat),
// One round of the stepwise search: for each still-ignored candidate
// feature, temporarily enable it, build a tree, and remember the feature
// giving the highest score. Outputs via the bscore/best_feat references.
// NOTE(review): braces, the best-tree bookkeeping and the return are
// missing from this extract.
1149 static WNode *wagon_stepwise_find_next_best(
float &bscore,
int &best_feat)
1154 float best_score = -WGN_HUGE_VAL;
1155 int best_new_feat = -1;
1158 for (i=0; i < wgn_dataset.width(); i++)
// Skip unscorable features, the predictee, and... (branch bodies lost);
// the ignore(i)==TRUE case is the one actually tried below.
1160 if (wgn_dataset.ftype(i) == wndt_ignore)
1162 else if (i == wgn_predictee)
1164 else if (wgn_dataset.ignore(i) == TRUE)
// Trial: enable feature i, build and score a tree, then re-ignore it.
1170 wgn_dataset.set_ignore(i,FALSE);
1172 current = wgn_build_tree(score);
1174 if (score > best_score)
1190 wgn_dataset.set_ignore(i,TRUE);
// Report the round's winner through the reference parameters.
1194 bscore = best_score;
1195 best_feat = best_new_feat;