44 #include "EST_cutils.h"
45 #include "EST_Token.h"
46 #include "EST_Wagon.h"
53 return impurity.value();
54 else if (question.ask(d))
55 return left->predict(d);
57 return right->predict(d);
64 else if (question.ask(d))
65 return left->predict_node(d);
67 return right->predict_node(d);
74 if ((left == 0) && (right == 0))
76 else if (get_impurity().type() != wnim_class)
82 void WNode::prune(
void)
90 if (left != 0) left->prune();
91 if (right != 0) right->prune();
95 if ((left->pure() == TRUE) && ((right->pure() == TRUE)) &&
96 (left->get_impurity().value() == right->get_impurity().value()))
98 delete left; left = 0;
99 delete right; right = 0;
105 void WNode::held_out_prune()
116 wgn_score_question(question,get_data());
117 if (question.get_score() < get_impurity().measure())
119 wgn_find_split(question,get_data(),
122 left->held_out_prune();
123 right->held_out_prune();
127 delete left; left = 0;
128 delete right; right = 0;
133 void WNode::print_out(ostream &s,
int margin)
138 for (i=0;i<margin;i++) s <<
" ";
145 left->print_out(s,margin+1);
146 right->print_out(s,margin+1);
151 ostream & operator <<(ostream &s,
WNode &n)
160 void WDataSet::ignore_non_numbers()
165 for (i=0; i<dlength; i++)
167 if ((p_type[i] == wndt_binary) ||
168 (p_type[i] == wndt_float))
179 void WDataSet::load_description(
const EST_String &fname, LISP ignores)
186 description = car(vload(fname,1));
187 dlength = siod_llength(description);
193 if (wgn_predictee_name ==
"")
198 for (i=0,d=description; d != NIL; d=cdr(d),i++)
200 p_name[i] = get_c_string(car(car(d)));
201 tname = get_c_string(car(cdr(car(d))));
203 if ((wgn_predictee_name !=
"") && (wgn_predictee_name == p_name[i]))
205 if ((wgn_count_field_name !=
"") &&
206 (wgn_count_field_name == p_name[i]))
208 if ((tname ==
"count") || (i == wgn_count_field))
211 p_type[i] = wndt_ignore;
215 else if ((tname ==
"ignore") || (siod_member_str(p_name[i],ignores)))
217 p_type[i] = wndt_ignore;
219 if (i == wgn_predictee)
220 wagon_error(
EST_String(
"predictee \"")+p_name[i]+
221 "\" can't be ignored \n");
223 else if (siod_llength(car(d)) > 2)
225 LISP rest = cdr(car(d));
227 siod_list_to_strlist(rest,sl);
228 p_type[i] = wgn_discretes.def(sl);
229 if (streq(get_c_string(car(rest)),
"_other_"))
230 wgn_discretes[p_type[i]].def_val(
"_other_");
232 else if (tname ==
"binary")
233 p_type[i] = wndt_binary;
234 else if (tname ==
"cluster")
235 p_type[i] = wndt_cluster;
236 else if (tname ==
"vector")
237 p_type[i] = wndt_vector;
238 else if (tname ==
"trajectory")
239 p_type[i] = wndt_trajectory;
240 else if (tname ==
"ols")
241 p_type[i] = wndt_ols;
242 else if (tname ==
"matrix")
243 p_type[i] = wndt_matrix;
244 else if (tname ==
"float")
245 p_type[i] = wndt_float;
248 wagon_error(
EST_String(
"Unknown type \"")+tname+
249 "\" for field number "+itoString(i)+
250 "/"+p_name[i]+
" in description file \""+fname+
"\"");
254 if (wgn_predictee == -1)
256 wagon_error(
EST_String(
"predictee field \"")+wgn_predictee_name+
257 "\" not found in description ");
261 const int WQuestion::ask(
const WVector &w)
const
267 if (w.get_flt_val(feature_pos) == operand1.
Float())
272 if (w.get_int_val(feature_pos) == 1)
276 case wnop_greaterthan:
277 if (w.get_flt_val(feature_pos) > operand1.
Float())
282 if (w.get_flt_val(feature_pos) < operand1.
Float())
287 if (w.get_int_val(feature_pos) == operand1.
Int())
292 if (ilist_member(operandl,w.get_int_val(feature_pos)))
297 wagon_error(
"Unknown test operator");
303 ostream& operator<<(ostream& s,
const WQuestion &q)
306 static EST_Regex needquotes(
".*[()'\";., \t\n\r].*");
308 s <<
"(" << wgn_dataset.feat_name(q.get_fp());
312 s <<
" = " << q.get_operand1().
string();
316 case wnop_greaterthan:
317 s <<
" > " << q.get_operand1().
Float();
320 s <<
" < " << q.get_operand1().
Float();
323 name = wgn_discretes[wgn_dataset.ftype(q.get_fp())].
324 name(q.get_operand1().
Int());
327 s << quote_string(name,
"\"",
"\\",1);
332 name = wgn_discretes[wgn_dataset.ftype(q.get_fp())].
333 name(q.get_operand1().
Int());
334 s <<
" matches " << quote_string(name,
"\"",
"\\",1);
338 for (
int l=0; l < q.get_operandl().length(); l++)
340 name = wgn_discretes[wgn_dataset.ftype(q.get_fp())].
341 name(q.get_operandl().
nth(l));
342 if (name.matches(needquotes))
343 s << quote_string(name,
"\"",
"\\",1);
367 cerr <<
"WImpurity: no value currently set\n";
370 else if (t==wnim_class)
372 else if (t==wnim_cluster)
374 else if (t==wnim_ols)
376 else if (t==wnim_vector)
378 else if (t==wnim_trajectory)
384 double WImpurity::samples(
void)
388 else if (t==wnim_class)
390 else if (t==wnim_cluster)
391 return members.length();
392 else if (t==wnim_ols)
393 return members.length();
394 else if (t==wnim_vector)
395 return members.length();
396 else if (t==wnim_trajectory)
397 return members.length();
407 a.
reset(); trajectory=0; l=0; width=0;
409 for (i=0; i < ds.
n(); i++)
413 else if (wgn_count_field == -1)
414 cumulate((*(ds(i)))[wgn_predictee],1);
416 cumulate((*(ds(i)))[wgn_predictee],
417 (*(ds(i)))[wgn_count_field]);
421 float WImpurity::measure(
void)
425 else if (t == wnim_vector)
426 return vector_impurity();
427 else if (t == wnim_trajectory)
428 return trajectory_impurity();
429 else if (t == wnim_matrix)
431 else if (t == wnim_class)
433 else if (t == wnim_cluster)
434 return cluster_impurity();
435 else if (t == wnim_ols)
436 return ols_impurity();
439 cerr <<
"WImpurity: can't measure unset object" << endl;
444 float WImpurity::vector_impurity()
459 if (wgn_VertexFeats.
a(0,j) > 0.0)
462 for (pp=members.head(), countpp=member_counts.head(); pp != 0; pp=pp->next(), countpp=countpp->next())
464 i = members.
item(pp);
467 b.cumulate(wgn_VertexTrack.
a(i,j), member_counts.
item(countpp)) ;
477 float x, lshift, rshift, ushift;
482 if (wgn_VertexFeats.
a(0,j) > 0.0)
485 for (pp=members.head(), countpp=member_counts.head(); pp != 0;
486 pp=pp->next(), countpp=countpp->next())
488 i = members.
item(pp);
490 c[j].cumulate(wgn_VertexTrack.
a(i,j),member_counts.
item(countpp));
497 for (pp=members.head(), countpp=member_counts.head(); pp != 0;
498 pp=pp->next(), countpp=countpp->next())
501 float bshift, qshift;
503 i = members.
item(pp);
505 lshift = 0; ushift = 0; rshift = 0;
507 for (q=-20; q<=20; q++)
510 for (j=67+q; j<147+q; j++)
512 x = c[j].
mean() - wgn_VertexTrack(i,j);
514 if ((bshift > 0) && (qshift > bshift))
517 if ((bshift == 0) || (qshift < bshift))
536 if (wgn_VertexFeats.
a(0,j) > 0.0)
538 for (pp=members.head(); pp != 0; pp=pp->next())
539 cs[j][j] += wgn_VertexTrack.
a(members.
item(pp),j);
545 if (wgn_VertexFeats.
a(0,j) > 0.0)
547 for (pp=members.head(); pp != 0; pp=pp->next())
549 mmm = members.
item(pp);
550 cs[i][j] += (wgn_VertexTrack.
a(mmm,i)-cs[j][j].
mean())*
551 (wgn_VertexTrack.
a(mmm,j)-cs[j][j].
mean());
558 if (wgn_VertexFeats.
a(0,j) > 0.0)
570 for (pp=members.head(); pp != 0; pp=pp->next())
572 x = members.
item(pp);
574 for (qq=pp->next(); qq != 0; qq=qq->next())
576 y = members.
item(qq);
578 if (wgn_VertexFeats.
a(0,j) > 0.0)
580 d = wgn_VertexTrack(x,j)-wgn_VertexTrack(y,j);
590 return a.
mean() * count;
593 WImpurity::~WImpurity()
600 delete [] trajectory[j];
601 delete [] trajectory;
608 float WImpurity::trajectory_impurity()
620 double n, m, m1, m2, w;
633 for (pp=members.head(); pp != 0; pp=pp->next())
635 i = members.
item(pp);
636 for (q=0; q<wgn_UnitTrack.
a(i,1); q++)
638 ni = (int)wgn_UnitTrack.
a(i,0)+q;
639 if (wgn_VertexTrack.
a(ni,0) == -1.0)
646 if (q==wgn_UnitTrack.
a(i,1))
652 l2ss += wgn_UnitTrack.
a(i,1) - (q+1) - 1;
653 lss += wgn_UnitTrack.
a(i,1);
654 if (wgn_UnitTrack.
a(i,1) > l)
655 l = (
int)wgn_UnitTrack.
a(i,1);
660 l = ((int)lss.
mean() < 7) ? 7 : (
int)lss.
mean();
668 for (pp=members.head(); pp != 0; pp=pp->next())
670 i = members.
item(pp);
671 m = (float)wgn_UnitTrack.
a(i,1)/(float)l;
672 s = (int)wgn_UnitTrack.
a(i,0);
673 for (ti=0,n=0.0; ti<l; ti++,n+=m)
678 if (wgn_VertexFeats.
a(0,j) > 0.0)
679 trajectory[ti][j] += wgn_VertexTrack.
a(s+ni,j);
686 for (ti=0; ti<l; ti++)
689 if (wgn_VertexFeats.
a(0,j) > 0.0)
690 stdss += trajectory[ti][j].
stddev();
694 score = stdss.
mean() * members.length();
698 l1 = (l1ss.
mean() < 10.0) ? 10 : (int)l1ss.
mean();
699 l2 = (l2ss.
mean() < 10.0) ? 10 : (int)l2ss.
mean();
707 for (pp=members.head(); pp != 0; pp=pp->next())
709 i = members.
item(pp);
711 s = (int)wgn_UnitTrack.
a(i,0);
712 for (q=0; q<wgn_UnitTrack.
a(i,1); q++)
713 if (wgn_VertexTrack.
a(s+q,0) == -1.0)
718 s2l = (int)wgn_UnitTrack.
a(i,1) - (s1l + 2);
719 m1 = (float)(s1l)/(float)l1;
720 m2 = (float)(s2l)/(float)l2;
722 for (ti=0,n=0.0; s1l > 0 && ti<l1; ti++,n+=m1)
724 ni = s + (((int)n < s1l) ? (int)n : s1l - 1);
726 if (wgn_VertexFeats.
a(0,j) > 0.0)
727 trajectory[ti][j] += wgn_VertexTrack.
a(ni,j);
731 if (wgn_VertexFeats.
a(0,j) > 0.0)
732 trajectory[ti][j] += -1;
735 for (ti++,n=0.0; s2l > 0 && ti<l-1; ti++,n+=m2)
737 ni = s + (((int)n < s2l) ? (int)n : s2l - 1);
739 if (wgn_VertexFeats.
a(0,j) > 0.0)
740 trajectory[ti][j] += wgn_VertexTrack.
a(ni,j);
743 if (wgn_VertexFeats.
a(0,j) > 0.0)
744 trajectory[ti][j] += -2;
751 for (w=0.0,ti=0; ti<l1; ti++,w+=m)
753 if (wgn_VertexFeats.
a(0,j) > 0.0)
754 stdss += trajectory[ti][j].
stddev() * w;
756 for (w=1.0,ti++; ti<l-1; ti++,w-=m)
758 if (wgn_VertexFeats.
a(0,j) > 0.0)
759 stdss += trajectory[ti][j].
stddev() * w;
762 score = stdss.
mean() * members.length();
778 w = wgn_dataset.width();
780 X.
resize(members.length(),w);
781 Y.
resize(members.length(),1);
782 feat_names.
append(
"Intercept");
785 for (p=0,pp=members.head(); pp; p++,pp=pp->next())
787 n = members.
item(pp);
796 for (m=1,xm=1; m < w; m++)
798 if (wgn_dataset.ftype(m) == wndt_float)
802 feat_names.
append(wgn_dataset.feat_name(m));
817 float WImpurity::ols_impurity()
831 part_to_ols_data(X,Y,included,feat_names,members,*data);
840 if (!robust_ols(X,Y,included,coeffsl))
845 ols_apply(X,coeffsl,pred);
846 ols_test(Y,pred,cor,rmse);
849 printf(
"Impurity OLS X(%d,%d) Y(%d,%d) %f, %f, %f\n",
853 if (fabs(coeffsl[0]) > 10000)
859 return (1-best_score) *members.length();
862 float WImpurity::cluster_impurity()
873 for (pp=members.head(); pp != 0; pp=pp->next())
875 i = members.
item(pp);
876 for (q=pp->next(); q != 0; q=q->next())
879 dist = (j < i ? wgn_DistMatrix.
a_no_check(i,j) :
893 float WImpurity::cluster_distance(
int i)
897 float dist = cluster_member_mean(i);
898 float mdist = dist-a.
mean();
907 int WImpurity::in_cluster(
int i)
911 float dist = cluster_member_mean(i);
914 for (pp=members.head(); pp != 0; pp=pp->next())
916 if (dist < cluster_member_mean(members.
item(pp)))
922 float WImpurity::cluster_ranking(
int i)
925 float dist = cluster_distance(i);
929 for (pp=members.head(); pp != 0; pp=pp->next())
931 if (dist >= cluster_distance(members.
item(pp)))
938 float WImpurity::cluster_member_mean(
int i)
946 for (sum=0.0,n=0,q=members.head(); q != 0; q=q->next())
951 dist = (j < i ? wgn_DistMatrix(i,j) : wgn_DistMatrix(j,i));
957 return ( n == 0 ? 0.0 : sum/n );
960 void WImpurity::cumulate(
const float pv,
double count)
964 if (wgn_dataset.ftype(wgn_predictee) == wndt_cluster)
969 else if (wgn_dataset.ftype(wgn_predictee) == wndt_ols)
974 else if (wgn_dataset.ftype(wgn_predictee) == wndt_vector)
980 member_counts.
append((
float)count);
982 else if (wgn_dataset.ftype(wgn_predictee) == wndt_trajectory)
987 else if (wgn_dataset.ftype(wgn_predictee) >= wndt_class)
990 p.init(&wgn_discretes[wgn_dataset.ftype(wgn_predictee)]);
992 p.cumulate((
int)pv,count);
994 else if (wgn_dataset.ftype(wgn_predictee) == wndt_binary)
997 a.cumulate((
int)pv,count);
999 else if (wgn_dataset.ftype(wgn_predictee) == wndt_float)
1002 a.cumulate(pv,count);
1006 wagon_error(
"WImpurity: cannot cumulate EST_Val type");
1010 ostream & operator <<(ostream &s,
WImpurity &imp)
1015 if (imp.t == wnim_float)
1016 s <<
"(" << imp.a.
stddev() <<
" " << imp.a.
mean() <<
")";
1017 else if (imp.t == wnim_vector)
1021 imp.vector_impurity();
1022 if (wgn_vertex_output ==
"mean")
1027 for (p=imp.members.head(), countp=imp.member_counts.head(); p != 0; p=p->next(), countp=countp->next())
1030 b.cumulate(wgn_VertexTrack.
a(imp.members.
item(p),j), imp.member_counts.
item(countp));
1033 s <<
"(" << b.
mean() <<
" ";
1034 if (isfinite(b.
stddev()))
1037 s <<
"0.001" <<
")";
1045 double best = WGN_HUGE_VAL;
1053 if (wgn_VertexFeats.
a(0,j) > 0.0)
1056 for (p=imp.members.head(); p != 0; p=p->next())
1058 cs[j] += wgn_VertexTrack.
a(imp.members.
item(p),j);
1062 for (p=imp.members.head(); p != 0; p=p->next())
1065 if (wgn_VertexFeats.
a(0,j) > 0.0)
1067 d = (wgn_VertexTrack.
a(imp.members.
item(p),j)-cs[j].mean())
1073 bestp = imp.members.
item(p);
1080 s << wgn_VertexTrack.
a(bestp,j);
1083 if (isfinite(cs[j].stddev()))
1095 s << imp.a.
mean() <<
")";
1097 else if (imp.t == wnim_trajectory)
1100 imp.trajectory_impurity();
1101 for (i=0; i<imp.l; i++)
1106 s <<
"(" << imp.trajectory[i][j].
mean() <<
" "
1107 << imp.trajectory[i][j].
stddev() <<
" " <<
")";
1113 s << imp.a.
mean() <<
")";
1115 else if (imp.t == wnim_cluster)
1119 for (p=imp.members.head(); p != 0; p=p->next())
1122 s <<
"(" << imp.members.
item(p) <<
" " <<
1123 imp.cluster_member_mean(imp.members.
item(p)) <<
")";
1129 s << imp.a.
mean() <<
")";
1131 else if (imp.t == wnim_ols)
1144 part_to_ols_data(X,Y,included,feat_names,imp.members,*(imp.data));
1145 if (!robust_ols(X,Y,included,coeffsl))
1147 printf(
"no robust ols\n");
1152 ols_apply(X,coeffsl,pred);
1153 ols_test(Y,pred,cor,rmse);
1154 for (i=0; i<coeffsl.
num_rows(); i++)
1157 s << feat_names.
nth(i);
1165 s <<
") " << cor <<
")";
1167 else if (imp.t == wnim_class)
1177 s <<
"(" << name <<
" " << prob <<
") ";
1182 s <<
"([WImpurity unset])";