Edinburgh Speech Tools  2.4-release
 All Classes Functions Variables Typedefs Enumerations Enumerator Friends Pages
wagon_aux.cc
1 /*************************************************************************/
2 /* */
3 /* Centre for Speech Technology Research */
4 /* University of Edinburgh, UK */
5 /* Copyright (c) 1996,1997 */
6 /* All Rights Reserved. */
7 /* */
8 /* Permission is hereby granted, free of charge, to use and distribute */
9 /* this software and its documentation without restriction, including */
10 /* without limitation the rights to use, copy, modify, merge, publish, */
11 /* distribute, sublicense, and/or sell copies of this work, and to */
12 /* permit persons to whom this work is furnished to do so, subject to */
13 /* the following conditions: */
14 /* 1. The code must retain the above copyright notice, this list of */
15 /* conditions and the following disclaimer. */
16 /* 2. Any modifications must be clearly marked as such. */
17 /* 3. Original authors' names are not deleted. */
18 /* 4. The authors' names are not used to endorse or promote products */
19 /* derived from this software without specific prior written */
20 /* permission. */
21 /* */
22 /* THE UNIVERSITY OF EDINBURGH AND THE CONTRIBUTORS TO THIS WORK */
23 /* DISCLAIM ALL WARRANTIES WITH REGARD TO THIS SOFTWARE, INCLUDING */
24 /* ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS, IN NO EVENT */
25 /* SHALL THE UNIVERSITY OF EDINBURGH NOR THE CONTRIBUTORS BE LIABLE */
26 /* FOR ANY SPECIAL, INDIRECT OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES */
27 /* WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN */
28 /* AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, */
29 /* ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF */
30 /* THIS SOFTWARE. */
31 /* */
32 /*************************************************************************/
33 /* Author : Alan W Black */
34 /* Date : May 1996 */
35 /*-----------------------------------------------------------------------*/
36 /* */
37 /* Various method functions */
38 /*=======================================================================*/
39 
40 #include <cstdlib>
41 #include <iostream>
42 #include <cstring>
43 #include "EST_unix.h"
44 #include "EST_cutils.h"
45 #include "EST_Token.h"
46 #include "EST_Wagon.h"
47 #include "EST_math.h"
48 
49 
50 EST_Val WNode::predict(const WVector &d)
51 {
52  if (leaf())
53  return impurity.value();
54  else if (question.ask(d))
55  return left->predict(d);
56  else
57  return right->predict(d);
58 }
59 
60 WNode *WNode::predict_node(const WVector &d)
61 {
62  if (leaf())
63  return this;
64  else if (question.ask(d))
65  return left->predict_node(d);
66  else
67  return right->predict_node(d);
68 }
69 
70 int WNode::pure(void)
71 {
72  // A node is pure if it has no sub-nodes or its not of type class
73 
74  if ((left == 0) && (right == 0))
75  return TRUE;
76  else if (get_impurity().type() != wnim_class)
77  return TRUE;
78  else
79  return FALSE;
80 }
81 
82 void WNode::prune(void)
83 {
84  // Check all sub-nodes and if they are all of the same class
85  // delete their sub nodes. Returns pureness of this node
86 
87  if (pure() == FALSE)
88  {
89  // Ok lets try and make it pure
90  if (left != 0) left->prune();
91  if (right != 0) right->prune();
92 
93  // Have to check purity as well as values to ensure left and right
94  // don't further split
95  if ((left->pure() == TRUE) && ((right->pure() == TRUE)) &&
96  (left->get_impurity().value() == right->get_impurity().value()))
97  {
98  delete left; left = 0;
99  delete right; right = 0;
100  }
101  }
102 
103 }
104 
105 void WNode::held_out_prune()
106 {
107  // prune tree with held out data
108  // Check if node's questions differentiates for the held out data
109  // if not, prune all sub_nodes
110 
111  // Rescore with prune data
112  set_impurity(WImpurity(get_data())); // for this new data
113 
114  if (left != 0)
115  {
116  wgn_score_question(question,get_data());
117  if (question.get_score() < get_impurity().measure())
118  { // its worth goint ot the next level
119  wgn_find_split(question,get_data(),
120  left->get_data(),
121  right->get_data());
122  left->held_out_prune();
123  right->held_out_prune();
124  }
125  else
126  { // not worth the split so prune both sub_nodes
127  delete left; left = 0;
128  delete right; right = 0;
129  }
130  }
131 }
132 
133 void WNode::print_out(ostream &s, int margin)
134 {
135  int i;
136 
137  s << endl;
138  for (i=0;i<margin;i++) s << " ";
139  s << "(";
140  if (left==0) // base case
141  s << impurity;
142  else
143  {
144  s << question;
145  left->print_out(s,margin+1);
146  right->print_out(s,margin+1);
147  }
148  s << ")";
149 }
150 
151 ostream & operator <<(ostream &s, WNode &n)
152 {
153  // Output this node and its sub-node
154 
155  n.print_out(s,0);
156  s << endl;
157  return s;
158 }
159 
160 void WDataSet::ignore_non_numbers()
161 {
162  /* For ols we want to ignore anything that is categorial */
163  int i;
164 
165  for (i=0; i<dlength; i++)
166  {
167  if ((p_type[i] == wndt_binary) ||
168  (p_type[i] == wndt_float))
169  continue;
170  else
171  {
172  p_ignore[i] = TRUE;
173  }
174  }
175 
176  return;
177 }
178 
// Read a Wagon description file and set up this dataset's per-field name,
// type and ignore flag.  The first form read by vload() is taken as a list
// of field descriptions: (name type) for typed fields, or
// (name value1 value2 ...) for an enumerated discrete field.
// Side effects on globals:
//   wgn_predictee   - index of the field named wgn_predictee_name, or 0
//                     when no name is given; wagon_error() if not found.
//   wgn_count_field - set to the field named wgn_count_field_name or typed
//                     "count"; otherwise left at its previous value.
//   wgn_discretes   - gains one entry per enumerated discrete field.
// Fields typed "ignore" or whose names are listed in `ignores` are marked
// ignored; ignoring the predictee is an error.
void WDataSet::load_description(const EST_String &fname, LISP ignores)
{
    // Initialise a dataset with sizes and types
    EST_String tname;
    int i;
    LISP description,d;

    description = car(vload(fname,1));
    dlength = siod_llength(description);

    p_type.resize(dlength);
    p_ignore.resize(dlength);
    p_name.resize(dlength);

    // -1 means "named predictee not seen yet"; checked after the loop.
    if (wgn_predictee_name == "")
        wgn_predictee = 0; // default predictee is first field
    else
        wgn_predictee = -1;

    for (i=0,d=description; d != NIL; d=cdr(d),i++)
    {
        p_name[i] = get_c_string(car(car(d)));
        tname = get_c_string(car(cdr(car(d))));
        p_ignore[i] = FALSE;
        if ((wgn_predictee_name != "") && (wgn_predictee_name == p_name[i]))
            wgn_predictee = i;
        if ((wgn_count_field_name != "") &&
            (wgn_count_field_name == p_name[i]))
            wgn_count_field = i;
        if ((tname == "count") || (i == wgn_count_field))
        {
            // The count must be ignored, repeat it if you want it too
            p_type[i] = wndt_ignore; // the count must be ignored
            p_ignore[i] = TRUE;
            wgn_count_field = i;
        }
        else if ((tname == "ignore") || (siod_member_str(p_name[i],ignores)))
        {
            p_type[i] = wndt_ignore; // user specified ignore
            p_ignore[i] = TRUE;
            if (i == wgn_predictee)
                wagon_error(EST_String("predictee \"")+p_name[i]+
                            "\" can't be ignored \n");
        }
        else if (siod_llength(car(d)) > 2)
        {
            // More than one entry after the name: an enumerated discrete
            // field.  Its value list is registered in wgn_discretes and
            // the returned index is used as this field's type.
            LISP rest = cdr(car(d));
            EST_StrList sl;
            siod_list_to_strlist(rest,sl);
            p_type[i] = wgn_discretes.def(sl);
            // A leading "_other_" value is passed to def_val()
            // (presumably the fallback for unseen values - confirm).
            if (streq(get_c_string(car(rest)),"_other_"))
                wgn_discretes[p_type[i]].def_val("_other_");
        }
        else if (tname == "binary")
            p_type[i] = wndt_binary;
        else if (tname == "cluster")
            p_type[i] = wndt_cluster;
        else if (tname == "vector")
            p_type[i] = wndt_vector;
        else if (tname == "trajectory")
            p_type[i] = wndt_trajectory;
        else if (tname == "ols")
            p_type[i] = wndt_ols;
        else if (tname == "matrix")
            p_type[i] = wndt_matrix;
        else if (tname == "float")
            p_type[i] = wndt_float;
        else
        {
            wagon_error(EST_String("Unknown type \"")+tname+
                        "\" for field number "+itoString(i)+
                        "/"+p_name[i]+" in description file \""+fname+"\"");
        }
    }

    if (wgn_predictee == -1)
    {
        wagon_error(EST_String("predictee field \"")+wgn_predictee_name+
                    "\" not found in description ");
    }
}
260 
261 const int WQuestion::ask(const WVector &w) const
262 {
263  // Ask this question of the given vector
264  switch (op)
265  {
266  case wnop_equal: // for numbers
267  if (w.get_flt_val(feature_pos) == operand1.Float())
268  return TRUE;
269  else
270  return FALSE;
271  case wnop_binary: // for numbers
272  if (w.get_int_val(feature_pos) == 1)
273  return TRUE;
274  else
275  return FALSE;
276  case wnop_greaterthan:
277  if (w.get_flt_val(feature_pos) > operand1.Float())
278  return TRUE;
279  else
280  return FALSE;
281  case wnop_lessthan:
282  if (w.get_flt_val(feature_pos) < operand1.Float())
283  return TRUE;
284  else
285  return FALSE;
286  case wnop_is: // for classes
287  if (w.get_int_val(feature_pos) == operand1.Int())
288  return TRUE;
289  else
290  return FALSE;
291  case wnop_in: // for subsets -- note operand is list of ints
292  if (ilist_member(operandl,w.get_int_val(feature_pos)))
293  return TRUE;
294  else
295  return FALSE;
296  default:
297  wagon_error("Unknown test operator");
298  }
299 
300  return FALSE;
301 }
302 
303 ostream& operator<<(ostream& s, const WQuestion &q)
304 {
305  EST_String name;
306  static EST_Regex needquotes(".*[()'\";., \t\n\r].*");
307 
308  s << "(" << wgn_dataset.feat_name(q.get_fp());
309  switch (q.get_op())
310  {
311  case wnop_equal:
312  s << " = " << q.get_operand1().string();
313  break;
314  case wnop_binary:
315  break;
316  case wnop_greaterthan:
317  s << " > " << q.get_operand1().Float();
318  break;
319  case wnop_lessthan:
320  s << " < " << q.get_operand1().Float();
321  break;
322  case wnop_is:
323  name = wgn_discretes[wgn_dataset.ftype(q.get_fp())].
324  name(q.get_operand1().Int());
325  s << " is ";
326  if (name.matches(needquotes))
327  s << quote_string(name,"\"","\\",1);
328  else
329  s << name;
330  break;
331  case wnop_matches:
332  name = wgn_discretes[wgn_dataset.ftype(q.get_fp())].
333  name(q.get_operand1().Int());
334  s << " matches " << quote_string(name,"\"","\\",1);
335  break;
336  case wnop_in:
337  s << " in (";
338  for (int l=0; l < q.get_operandl().length(); l++)
339  {
340  name = wgn_discretes[wgn_dataset.ftype(q.get_fp())].
341  name(q.get_operandl().nth(l));
342  if (name.matches(needquotes))
343  s << quote_string(name,"\"","\\",1);
344  else
345  s << name;
346  s << " ";
347  }
348  s << ")";
349  break;
350  // SunCC wont let me add this
351 // default:
352 // s << " unknown operation ";
353  }
354  s << ")";
355 
356  return s;
357 }
358 
359 EST_Val WImpurity::value(void)
360 {
361  // Returns the recommended value for this
362  EST_String s;
363  double prob;
364 
365  if (t==wnim_unset)
366  {
367  cerr << "WImpurity: no value currently set\n";
368  return EST_Val(0.0);
369  }
370  else if (t==wnim_class)
371  return EST_Val(p.most_probable(&prob));
372  else if (t==wnim_cluster)
373  return EST_Val(a.mean());
374  else if (t==wnim_ols) /* OLS TBA */
375  return EST_Val(a.mean());
376  else if (t==wnim_vector)
377  return EST_Val(a.mean()); /* wnim_vector */
378  else if (t==wnim_trajectory)
379  return EST_Val(a.mean()); /* NOT YET WRITTEN */
380  else
381  return EST_Val(a.mean());
382 }
383 
384 double WImpurity::samples(void)
385 {
386  if (t==wnim_float)
387  return a.samples();
388  else if (t==wnim_class)
389  return (int)p.samples();
390  else if (t==wnim_cluster)
391  return members.length();
392  else if (t==wnim_ols)
393  return members.length();
394  else if (t==wnim_vector)
395  return members.length();
396  else if (t==wnim_trajectory)
397  return members.length();
398  else
399  return 0;
400 }
401 
402 WImpurity::WImpurity(const WVectorVector &ds)
403 {
404  int i;
405 
406  t=wnim_unset;
407  a.reset(); trajectory=0; l=0; width=0;
408  data = &ds; // for ols, model calculation
409  for (i=0; i < ds.n(); i++)
410  {
411  if (t == wnim_ols)
412  cumulate(i,1);
413  else if (wgn_count_field == -1)
414  cumulate((*(ds(i)))[wgn_predictee],1);
415  else
416  cumulate((*(ds(i)))[wgn_predictee],
417  (*(ds(i)))[wgn_count_field]);
418  }
419 }
420 
421 float WImpurity::measure(void)
422 {
423  if (t == wnim_float)
424  return a.variance()*a.samples();
425  else if (t == wnim_vector)
426  return vector_impurity();
427  else if (t == wnim_trajectory)
428  return trajectory_impurity();
429  else if (t == wnim_matrix)
430  return a.variance()*a.samples();
431  else if (t == wnim_class)
432  return p.entropy()*p.samples();
433  else if (t == wnim_cluster)
434  return cluster_impurity();
435  else if (t == wnim_ols)
436  return ols_impurity(); /* RMSE for OLS model */
437  else
438  {
439  cerr << "WImpurity: can't measure unset object" << endl;
440  return 0.0;
441  }
442 }
443 
// Impurity of a set of vector-valued samples.  Each member is a row index
// into wgn_VertexTrack, and member_counts holds its weight (both lists are
// appended together in cumulate() for vector predictees).  Only the
// "simple distance" variant under #if 1 is compiled.  For every channel j
// with wgn_VertexFeats.a(0,j) > 0, the count-weighted stddev of that
// channel over the members is added to `a`.  The result is
// mean(per-channel stddev) * sample count.  The #if 0 sections are
// alternative measures kept for reference only.
// Side effect: resets and refills the member SuffStats `a`, whose mean is
// later printed by operator<<.
float WImpurity::vector_impurity()
{
    // Find the mean/stddev for all values in all vectors
    // sum the variances and multiply them by the number of members
    EST_Litem *pp;
    EST_Litem *countpp;
    int i,j;
    EST_SuffStats b;
    double count = 1;

    a.reset();
#if 1
    /* simple distance */
    for (j=0; j<wgn_VertexFeats.num_channels(); j++)
    {
        if (wgn_VertexFeats.a(0,j) > 0.0)
        {
            b.reset();
            // members and member_counts are walked in lockstep
            for (pp=members.head(), countpp=member_counts.head(); pp != 0; pp=pp->next(), countpp=countpp->next())
            {
                i = members.item(pp);

                // Accumulate the value with count
                b.cumulate(wgn_VertexTrack.a(i,j), member_counts.item(countpp)) ;
            }
            a += b.stddev();
            // Overwritten for every used channel.  All channels see the
            // same members, so this ends up as b's sample total
            // (presumably count-weighted).  Stays 1 if no channel is used.
            count = b.samples();
        }
    }
#endif

#if 0
    EST_SuffStats *c;
    float x, lshift, rshift, ushift;
    /* Find base mean, then measure do fshift to find best match */
    c = new EST_SuffStats[wgn_VertexTrack.num_channels()+1];
    for (j=0; j<wgn_VertexFeats.num_channels(); j++)
    {
        if (wgn_VertexFeats.a(0,j) > 0.0)
        {
            c[j].reset();
            for (pp=members.head(), countpp=member_counts.head(); pp != 0;
                 pp=pp->next(), countpp=countpp->next())
            {
                i = members.item(pp);
                // Accumulate the value with count
                c[j].cumulate(wgn_VertexTrack.a(i,j),member_counts.item(countpp));
            }
            count = c[j].samples();
        }
    }

    /* Pass through again but vary the num_channels offset (hardcoded) */
    for (pp=members.head(), countpp=member_counts.head(); pp != 0;
         pp=pp->next(), countpp=countpp->next())
    {
        int q;
        float bshift, qshift;
        /* For each sample */
        i = members.item(pp);
        /* Find the value left shifted, unshifted, and right shifted */
        lshift = 0; ushift = 0; rshift = 0;
        bshift = 0;
        for (q=-20; q<=20; q++)
        {
            qshift = 0;
            for (j=67+q; j<147+q/*hardcoded*/; j++)
            {
                x = c[j].mean() - wgn_VertexTrack(i,j);
                qshift += sqrt(x*x);
                if ((bshift > 0) && (qshift > bshift))
                    break;
            }
            if ((bshift == 0) || (qshift < bshift))
                bshift = qshift;
        }
        a += bshift;
    }

#endif

#if 0
    /* full covariance */
    /* worse in listening experiments */
    EST_SuffStats **cs;
    int mmm;
    cs = new EST_SuffStats *[wgn_VertexTrack.num_channels()+1];
    for (j=0; j<=wgn_VertexTrack.num_channels(); j++)
        cs[j] = new EST_SuffStats[wgn_VertexTrack.num_channels()+1];
    /* Find means for diagonal */
    for (j=0; j<wgn_VertexFeats.num_channels(); j++)
    {
        if (wgn_VertexFeats.a(0,j) > 0.0)
        {
            for (pp=members.head(); pp != 0; pp=pp->next())
                cs[j][j] += wgn_VertexTrack.a(members.item(pp),j);
        }
    }
    for (j=0; j<wgn_VertexFeats.num_channels(); j++)
    {
        for (i=j+1; i<wgn_VertexFeats.num_channels(); i++)
            if (wgn_VertexFeats.a(0,j) > 0.0)
            {
                for (pp=members.head(); pp != 0; pp=pp->next())
                {
                    mmm = members.item(pp);
                    cs[i][j] += (wgn_VertexTrack.a(mmm,i)-cs[j][j].mean())*
                        (wgn_VertexTrack.a(mmm,j)-cs[j][j].mean());
                }
            }
    }
    for (j=0; j<wgn_VertexFeats.num_channels(); j++)
    {
        for (i=j+1; i<wgn_VertexFeats.num_channels(); i++)
            if (wgn_VertexFeats.a(0,j) > 0.0)
                a += cs[i][j].stddev();
    }
    count = cs[0][0].samples();
#endif

#if 0
    // look at mean euclidean distance between vectors
    EST_Litem *qq;
    int x,y;
    double d,q;
    count = 0;
    for (pp=members.head(); pp != 0; pp=pp->next())
    {
        x = members.item(pp);
        count++;
        for (qq=pp->next(); qq != 0; qq=qq->next())
        {
            y = members.item(qq);
            for (q=0.0,j=0; j<wgn_VertexFeats.num_channels(); j++)
                if (wgn_VertexFeats.a(0,j) > 0.0)
                {
                    d = wgn_VertexTrack(x,j)-wgn_VertexTrack(y,j);
                    q += d*d;
                }
            a += sqrt(q);
        }

    }
#endif

    // This is sum of stddev*samples
    return a.mean() * count;
}
592 
593 WImpurity::~WImpurity()
594 {
595  int j;
596 
597  if (trajectory != 0)
598  {
599  for (j=0; j<l; j++)
600  delete [] trajectory[j];
601  delete [] trajectory;
602  trajectory = 0;
603  l = 0;
604  }
605 }
606 
607 
// Impurity of a cluster of units, each a span of wgn_VertexTrack rows.
// A unit starts at wgn_UnitTrack.a(i,0) and has length
// wgn_UnitTrack.a(i,1).  The result is cached: `score` is returned
// directly once `trajectory` has been built, and `trajectory`/`l` keep the
// per-point statistics for operator<< and the destructor.
// Two models are used:
//  - plain: no unit contains a row whose channel 0 is -1.  Every unit is
//    resampled to l points (mean unit length, minimum 7), and the score is
//    mean per-point/per-channel stddev * number of members.
//  - OLA: some unit contains such a -1 marker row.  The parts before and
//    after the marker are resampled to l1 and l2 points (mean lengths,
//    minimum 10).  The stddevs are weighted by a triangular window before
//    taking the mean.
// Only channels with wgn_VertexFeats.a(0,j) > 0 are accumulated.
// NOTE(review): `width` is set only in the plain branch.  With no members,
// lss.mean() is taken over an empty set - behaviour not checked here.
float WImpurity::trajectory_impurity()
{
    // Find the mean length of all the units in the cluster
    // Create that number of points
    // Interpolate each unit to that number of points
    // collect means and standard deviations for each point
    // impurity is sum of the variance for each point and each coef
    // multiplied by the number of units.
    EST_Litem *pp;
    int i, j;
    int s, ti, ni, q;
    int s1l, s2l;
    double n, m, m1, m2, w;
    EST_SuffStats lss, stdss;
    EST_SuffStats l1ss, l2ss;
    int l1, l2;
    int ola=0;

    if (trajectory != 0)
    { /* already done this */
        return score;
    }

    // First pass: length statistics.  l1ss/l2ss collect the lengths
    // before/after the -1 marker row; lss the whole unit lengths.
    lss.reset();
    l = 0;
    for (pp=members.head(); pp != 0; pp=pp->next())
    {
        i = members.item(pp);
        for (q=0; q<wgn_UnitTrack.a(i,1); q++)
        {
            ni = (int)wgn_UnitTrack.a(i,0)+q;
            if (wgn_VertexTrack.a(ni,0) == -1.0)
            {
                l1ss += q;
                ola = 1;
                break;
            }
        }
        if (q==wgn_UnitTrack.a(i,1))
        { /* can't find -1 center point, so put all in l2 */
            l1ss += 0;
            l2ss += q;
        }
        else
            l2ss += wgn_UnitTrack.a(i,1) - (q+1) - 1;
        lss += wgn_UnitTrack.a(i,1); /* length of each unit in the cluster */
        if (wgn_UnitTrack.a(i,1) > l)
            l = (int)wgn_UnitTrack.a(i,1);
    }

    if (ola==0) /* no -1's so its not an ola type cluster */
    {
        l = ((int)lss.mean() < 7) ? 7 : (int)lss.mean();

        /* a list of SuffStats on for each point in the trajectory */
        trajectory = new EST_SuffStats *[l];
        width = wgn_VertexTrack.num_channels()+1;
        for (j=0; j<l; j++)
            trajectory[j] = new EST_SuffStats[width];

        for (pp=members.head(); pp != 0; pp=pp->next())
        { /* for each unit */
            i = members.item(pp);
            m = (float)wgn_UnitTrack.a(i,1)/(float)l; /* find interpolation */
            s = (int)wgn_UnitTrack.a(i,0); /* start point */
            // n stays below the unit length, so s+ni is inside the unit
            for (ti=0,n=0.0; ti<l; ti++,n+=m)
            {
                ni = (int)n; // hmm floor or nint ??
                for (j=0; j<wgn_VertexFeats.num_channels(); j++)
                {
                    if (wgn_VertexFeats.a(0,j) > 0.0)
                        trajectory[ti][j] += wgn_VertexTrack.a(s+ni,j);
                }
            }
        }

        /* find sum of sum of stddev for all coefs of all traj points */
        stdss.reset();
        for (ti=0; ti<l; ti++)
            for (j=0; j<wgn_VertexFeats.num_channels(); j++)
            {
                if (wgn_VertexFeats.a(0,j) > 0.0)
                    stdss += trajectory[ti][j].stddev();
            }

        // This is sum of all stddev * samples
        score = stdss.mean() * members.length();
    }
    else
    { /* OLA model */
        // Point layout: [0,l1) first half, l1 holds a -1 marker, then the
        // second half (up to l2 points), then a -2 marker.  The -2 lands
        // on l-1 when the second half is non-empty.
        l1 = (l1ss.mean() < 10.0) ? 10 : (int)l1ss.mean();
        l2 = (l2ss.mean() < 10.0) ? 10 : (int)l2ss.mean();
        l = l1 + l2 + 1 + 1;

        /* a list of SuffStats on for each point in the trajectory */
        trajectory = new EST_SuffStats *[l];
        for (j=0; j<l; j++)
            trajectory[j] = new EST_SuffStats[wgn_VertexTrack.num_channels()+1];

        for (pp=members.head(); pp != 0; pp=pp->next())
        { /* for each unit */
            i = members.item(pp);
            s1l = 0;
            s = (int)wgn_UnitTrack.a(i,0); /* start point */
            for (q=0; q<wgn_UnitTrack.a(i,1); q++)
                if (wgn_VertexTrack.a(s+q,0) == -1.0)
                {
                    s1l = q; /* printf("awb q is -1 at %d\n",q); */
                    break;
                }
            s2l = (int)wgn_UnitTrack.a(i,1) - (s1l + 2);
            m1 = (float)(s1l)/(float)l1; /* find interpolation step */
            m2 = (float)(s2l)/(float)l2; /* find interpolation step */
            /* First half */
            for (ti=0,n=0.0; s1l > 0 && ti<l1; ti++,n+=m1)
            {
                // clamp to the last row before the marker
                ni = s + (((int)n < s1l) ? (int)n : s1l - 1);
                for (j=0; j<wgn_VertexFeats.num_channels(); j++)
                    if (wgn_VertexFeats.a(0,j) > 0.0)
                        trajectory[ti][j] += wgn_VertexTrack.a(ni,j);
            }
            ti = l1; /* do it explicitly in case s1l < 1 */
            for (j=0; j<wgn_VertexFeats.num_channels(); j++)
                if (wgn_VertexFeats.a(0,j) > 0.0)
                    trajectory[ti][j] += -1;
            /* Second half */
            s += s1l+1;
            for (ti++,n=0.0; s2l > 0 && ti<l-1; ti++,n+=m2)
            {
                ni = s + (((int)n < s2l) ? (int)n : s2l - 1);
                for (j=0; j<wgn_VertexFeats.num_channels(); j++)
                    if (wgn_VertexFeats.a(0,j) > 0.0)
                        trajectory[ti][j] += wgn_VertexTrack.a(ni,j);
            }
            for (j=0; j<wgn_VertexFeats.num_channels(); j++)
                if (wgn_VertexFeats.a(0,j) > 0.0)
                    trajectory[ti][j] += -2;
        }

        /* find sum of sum of stddev for all coefs of all traj points */
        /* windowing the sums with a triangular weight window */
        // First half: weight ramps up from 0.  Second half (from l1+1):
        // weight ramps down from 1.  The marker slots l1 and l-1 are
        // skipped.
        stdss.reset();
        m = 1.0/(float)l1;
        for (w=0.0,ti=0; ti<l1; ti++,w+=m)
            for (j=0; j<wgn_VertexFeats.num_channels(); j++)
                if (wgn_VertexFeats.a(0,j) > 0.0)
                    stdss += trajectory[ti][j].stddev() * w;
        m = 1.0/(float)l2;
        for (w=1.0,ti++; ti<l-1; ti++,w-=m)
            for (j=0; j<wgn_VertexFeats.num_channels(); j++)
                if (wgn_VertexFeats.a(0,j) > 0.0)
                    stdss += trajectory[ti][j].stddev() * w;

        // This is sum of all stddev * samples
        score = stdss.mean() * members.length();
    }
    return score;
}
766 
767 static void part_to_ols_data(EST_FMatrix &X, EST_FMatrix &Y,
768  EST_IVector &included,
769  EST_StrList &feat_names,
770  const EST_IList &members,
771  const WVectorVector &d)
772 {
773  int m,n,p;
774  int w, xm=0;
775  EST_Litem *pp;
776  WVector *wv;
777 
778  w = wgn_dataset.width();
779  included.resize(w);
780  X.resize(members.length(),w);
781  Y.resize(members.length(),1);
782  feat_names.append("Intercept");
783  included[0] = TRUE;
784 
785  for (p=0,pp=members.head(); pp; p++,pp=pp->next())
786  {
787  n = members.item(pp);
788  if (n < 0)
789  {
790  p--;
791  continue;
792  }
793  wv = d(n);
794  Y.a_no_check(p,0) = (*wv)[0];
795  X.a_no_check(p,0) = 1;
796  for (m=1,xm=1; m < w; m++)
797  {
798  if (wgn_dataset.ftype(m) == wndt_float)
799  {
800  if (p == 0) // only do this once
801  {
802  feat_names.append(wgn_dataset.feat_name(m));
803  }
804  X.a_no_check(p,xm) = (*wv)[m];
805  included.a_no_check(xm) = FALSE;
806  included.a_no_check(xm) = TRUE;
807  xm++;
808  }
809  }
810  }
811 
812  included.resize(xm);
813  X.resize(p,xm);
814  Y.resize(p,1);
815 }
816 
817 float WImpurity::ols_impurity()
818 {
819  // Build an OLS model for the current data and measure it against
820  // the data itself and give a RMSE
821  EST_FMatrix X,Y;
822  EST_IVector included;
823  EST_FMatrix coeffs;
824  EST_StrList feat_names;
825  float best_score;
826  EST_FMatrix coeffsl;
827  EST_FMatrix pred;
828  float cor,rmse;
829 
830  // Load the sample members into matrices for ols
831  part_to_ols_data(X,Y,included,feat_names,members,*data);
832 
833  // Find the best ols model.
834  // Far too computationally expensive
835  // if (!stepwise_ols(X,Y,feat_names,0.0,coeffs,
836  // X,Y,included,best_score))
837  // return WGN_HUGE_VAL; // couldn't find a model
838 
839  // Non stepwise model
840  if (!robust_ols(X,Y,included,coeffsl))
841  {
842  // printf("no robust ols\n");
843  return WGN_HUGE_VAL;
844  }
845  ols_apply(X,coeffsl,pred);
846  ols_test(Y,pred,cor,rmse);
847  best_score = cor;
848 
849  printf("Impurity OLS X(%d,%d) Y(%d,%d) %f, %f, %f\n",
850  X.num_rows(),X.num_columns(),Y.num_rows(),Y.num_columns(),
851  rmse,cor,
852  1-best_score);
853  if (fabs(coeffsl[0]) > 10000)
854  {
855  // printf("weird sized Intercept %f\n",coeffsl[0]);
856  return WGN_HUGE_VAL;
857  }
858 
859  return (1-best_score) *members.length();
860 }
861 
862 float WImpurity::cluster_impurity()
863 {
864  // Find the mean distance between all members of the dataset
865  // Uses the global DistMatrix for distances between members of
866  // the cluster set. Distances are assumed to be symmetric thus only
867  // the bottom half of the distance matrix is filled
868  EST_Litem *pp, *q;
869  int i,j;
870  double dist;
871 
872  a.reset();
873  for (pp=members.head(); pp != 0; pp=pp->next())
874  {
875  i = members.item(pp);
876  for (q=pp->next(); q != 0; q=q->next())
877  {
878  j = members.item(q);
879  dist = (j < i ? wgn_DistMatrix.a_no_check(i,j) :
880  wgn_DistMatrix.a_no_check(j,i));
881  a+=dist; // cumulate for whole cluster
882  }
883  }
884 
885  // This is sum distance between cross product of members
886 // return a.sum();
887  if (a.samples() > 1)
888  return a.stddev() * a.samples();
889  else
890  return 0.0;
891 }
892 
893 float WImpurity::cluster_distance(int i)
894 {
895  // Distance this unit is from all others in this cluster
896  // in absolute standard deviations from the the mean.
897  float dist = cluster_member_mean(i);
898  float mdist = dist-a.mean();
899 
900  if (mdist == 0.0)
901  return 0.0;
902  else
903  return fabs((dist-a.mean())/a.stddev());
904 
905 }
906 
907 int WImpurity::in_cluster(int i)
908 {
909  // Would this be a member of this cluster?. Returns 1 if
910  // its distance is less than at least one other
911  float dist = cluster_member_mean(i);
912  EST_Litem *pp;
913 
914  for (pp=members.head(); pp != 0; pp=pp->next())
915  {
916  if (dist < cluster_member_mean(members.item(pp)))
917  return 1;
918  }
919  return 0;
920 }
921 
922 float WImpurity::cluster_ranking(int i)
923 {
924  // Position in ranking closest to centre
925  float dist = cluster_distance(i);
926  EST_Litem *pp;
927  int ranking = 1;
928 
929  for (pp=members.head(); pp != 0; pp=pp->next())
930  {
931  if (dist >= cluster_distance(members.item(pp)))
932  ranking++;
933  }
934 
935  return ranking;
936 }
937 
938 float WImpurity::cluster_member_mean(int i)
939 {
940  // Returns the mean difference between this member and all others
941  // in cluster
942  EST_Litem *q;
943  int j,n;
944  double dist,sum;
945 
946  for (sum=0.0,n=0,q=members.head(); q != 0; q=q->next())
947  {
948  j = members.item(q);
949  if (i != j)
950  {
951  dist = (j < i ? wgn_DistMatrix(i,j) : wgn_DistMatrix(j,i));
952  sum += dist;
953  n++;
954  }
955  }
956 
957  return ( n == 0 ? 0.0 : sum/n );
958 }
959 
960 void WImpurity::cumulate(const float pv,double count)
961 {
962  // Cumulate data for impurity calculation
963 
964  if (wgn_dataset.ftype(wgn_predictee) == wndt_cluster)
965  {
966  t = wnim_cluster;
967  members.append((int)pv);
968  }
969  else if (wgn_dataset.ftype(wgn_predictee) == wndt_ols)
970  {
971  t = wnim_ols;
972  members.append((int)pv);
973  }
974  else if (wgn_dataset.ftype(wgn_predictee) == wndt_vector)
975  {
976  t = wnim_vector;
977 
978  // AUP: Implement counts in vectors
979  members.append((int)pv);
980  member_counts.append((float)count);
981  }
982  else if (wgn_dataset.ftype(wgn_predictee) == wndt_trajectory)
983  {
984  t = wnim_trajectory;
985  members.append((int)pv);
986  }
987  else if (wgn_dataset.ftype(wgn_predictee) >= wndt_class)
988  {
989  if (t == wnim_unset)
990  p.init(&wgn_discretes[wgn_dataset.ftype(wgn_predictee)]);
991  t = wnim_class;
992  p.cumulate((int)pv,count);
993  }
994  else if (wgn_dataset.ftype(wgn_predictee) == wndt_binary)
995  {
996  t = wnim_float;
997  a.cumulate((int)pv,count);
998  }
999  else if (wgn_dataset.ftype(wgn_predictee) == wndt_float)
1000  {
1001  t = wnim_float;
1002  a.cumulate(pv,count);
1003  }
1004  else
1005  {
1006  wagon_error("WImpurity: cannot cumulate EST_Val type");
1007  }
1008 }
1009 
// Print an impurity as the leaf value of a Wagon tree.  The format
// depends on the impurity type:
//   float      : (stddev mean)
//   vector     : (((v sd) ...) score).  The per-channel mean (when
//                wgn_vertex_output == "mean") or the member vector closest
//                to the channel means, with channel stddevs.
//   trajectory : (((mean sd) ...)\n ... ) followed by imp.a.mean()
//   cluster    : ((member mean_distance) ...) followed by imp.a.mean()
//   ols        : ((feature coeff) ...) followed by the fit correlation
//   class      : ((name prob) ... most_probable)
// Calls vector_impurity()/trajectory_impurity(), which refresh state in imp.
ostream & operator <<(ostream &s, WImpurity &imp)
{
    int j,i;
    EST_SuffStats b;

    if (imp.t == wnim_float)
        s << "(" << imp.a.stddev() << " " << imp.a.mean() << ")";
    else if (imp.t == wnim_vector)
    {
        EST_Litem *p, *countp;
        s << "((";
        // Refills imp.a so the trailing score printed below is current.
        imp.vector_impurity();
        if (wgn_vertex_output == "mean") //output means
        {
            for (j=0; j<wgn_VertexTrack.num_channels(); j++)
            {
                b.reset();
                for (p=imp.members.head(), countp=imp.member_counts.head(); p != 0; p=p->next(), countp=countp->next())
                {
                    // Accumulate the members with their counts
                    b.cumulate(wgn_VertexTrack.a(imp.members.item(p),j), imp.member_counts.item(countp));
                    //b += wgn_VertexTrack.a(imp.members.item(p),j);
                }
                s << "(" << b.mean() << " ";
                // non-finite stddev (e.g. too few samples) printed as 0.001
                if (isfinite(b.stddev()))
                    s << b.stddev() << ")";
                else
                    s << "0.001" << ")";
                if (j+1<wgn_VertexTrack.num_channels())
                    s << " ";
            }
        }
        else /* output best in the cluster */
        {
            /* print out vector closest to center, rather than average */
            // "Closest" = smallest squared Euclidean distance to the channel
            // means over channels with wgn_VertexFeats.a(0,j) > 0.
            // NOTE(review): unlike the "mean" branch, member_counts are not
            // used here.
            double best = WGN_HUGE_VAL;
            double x,d;
            int bestp = 0;
            EST_SuffStats *cs;

            cs = new EST_SuffStats [wgn_VertexTrack.num_channels()+1];

            for (j=0; j<wgn_VertexFeats.num_channels(); j++)
                if (wgn_VertexFeats.a(0,j) > 0.0)
                {
                    cs[j].reset();
                    for (p=imp.members.head(); p != 0; p=p->next())
                    {
                        cs[j] += wgn_VertexTrack.a(imp.members.item(p),j);
                    }
                }

            for (p=imp.members.head(); p != 0; p=p->next())
            {
                for (x=0.0,j=0; j<wgn_VertexFeats.num_channels(); j++)
                    if (wgn_VertexFeats.a(0,j) > 0.0)
                    {
                        d = (wgn_VertexTrack.a(imp.members.item(p),j)-cs[j].mean())
                            /* / cs[j].stddev() */ ; /* seems worse 061218 */
                        x += d*d;
                    }
                if (x < best)
                {
                    bestp = imp.members.item(p);
                    best = x;
                }
            }
            // Unused channels have empty cs[j]; their non-finite stddev is
            // printed as 0.
            for (j=0; j<wgn_VertexTrack.num_channels(); j++)
            {
                s << "( ";
                s << wgn_VertexTrack.a(bestp,j);
                // s << " 0 "; // fake stddev
                s << " ";
                if (isfinite(cs[j].stddev()))
                    s << cs[j].stddev();
                else
                    s << "0";
                s << " ) ";
                if (j+1<wgn_VertexTrack.num_channels())
                    s << " ";
            }

            delete [] cs;
        }
        s << ") ";
        s << imp.a.mean() << ")";
    }
    else if (imp.t == wnim_trajectory)
    {
        s << "((";
        // Builds imp.trajectory / imp.l if not already cached.
        imp.trajectory_impurity();
        for (i=0; i<imp.l; i++)
        {
            s << "(";
            for (j=0; j<wgn_VertexTrack.num_channels(); j++)
            {
                s << "(" << imp.trajectory[i][j].mean() << " "
                  << imp.trajectory[i][j].stddev() << " " << ")";
            }
            s << ")\n";
        }
        s << ") ";
        // NOTE(review): trajectory_impurity() does not touch imp.a and
        // cumulate() only appends members for trajectories, so this is
        // whatever imp.a currently holds.
        s << imp.a.mean() << ")";
    }
    else if (imp.t == wnim_cluster)
    {
        EST_Litem *p;
        s << "((";
        for (p=imp.members.head(); p != 0; p=p->next())
        {
            // Ouput cluster member and its mean distance to others
            s << "(" << imp.members.item(p) << " " <<
                imp.cluster_member_mean(imp.members.item(p)) << ")";
            if (p->next() != 0)
                s << " ";
        }
        s << ") ";
        // Mean of cross product of distances (cluster score)
        s << imp.a.mean() << ")";
    }
    else if (imp.t == wnim_ols)
    {
        /* Output intercept, feature names and coefficients for ols model */
        EST_FMatrix X,Y;
        EST_IVector included;
        EST_FMatrix coeffs;
        EST_StrList feat_names;
        EST_FMatrix coeffsl;
        EST_FMatrix pred;
        float cor=0.0,rmse;

        s << "((";
        // Load the sample members into matrices for ols
        part_to_ols_data(X,Y,included,feat_names,imp.members,*(imp.data));
        if (!robust_ols(X,Y,included,coeffsl))
        {
            printf("no robust ols\n");
            // shouldn't happen
        }
        else
        {
            ols_apply(X,coeffsl,pred);
            ols_test(Y,pred,cor,rmse);
            for (i=0; i<coeffsl.num_rows(); i++)
            {
                s << "(";
                s << feat_names.nth(i);
                s << " ";
                s << coeffsl[i];
                s << ") ";
            }
        }

        // Trailing value is the fit correlation (0.0 if no model was found)
        s << ") " << cor << ")";
    }
    else if (imp.t == wnim_class)
    {
        EST_Litem *i;
        EST_String name;
        double prob;

        s << "(";
        for (i=imp.p.item_start(); !imp.p.item_end(i); i=imp.p.item_next(i))
        {
            imp.p.item_prob(i,name,prob);
            s << "(" << name << " " << prob << ") ";
        }
        s << imp.p.most_probable(&prob) << ")";
    }
    else
        s << "([WImpurity unset])";

    return s;
}
1186 
1187 
1188 
1189