Edinburgh Speech Tools  2.4-release
 All Classes Functions Variables Typedefs Enumerations Enumerator Friends Pages
wfst_train.cc
1 /*************************************************************************/
2 /* */
3 /* Language Technologies Institute */
4 /* Carnegie Mellon University */
5 /* Copyright (c) 1999-2003 */
6 /* All Rights Reserved. */
7 /* */
8 /* Permission is hereby granted, free of charge, to use and distribute */
9 /* this software and its documentation without restriction, including */
10 /* without limitation the rights to use, copy, modify, merge, publish, */
11 /* distribute, sublicense, and/or sell copies of this work, and to */
12 /* permit persons to whom this work is furnished to do so, subject to */
13 /* the following conditions: */
14 /* 1. The code must retain the above copyright notice, this list of */
15 /* conditions and the following disclaimer. */
16 /* 2. Any modifications must be clearly marked as such. */
17 /* 3. Original authors' names are not deleted. */
18 /* 4. The authors' names are not used to endorse or promote products */
19 /* derived from this software without specific prior written */
20 /* permission. */
21 /* */
22 /* CARNEGIE MELLON UNIVERSITY AND THE CONTRIBUTORS TO THIS WORK */
23 /* DISCLAIM ALL WARRANTIES WITH REGARD TO THIS SOFTWARE, INCLUDING */
24 /* ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS, IN NO EVENT */
25 /* SHALL CARNEGIE MELLON UNIVERSITY NOR THE CONTRIBUTORS BE LIABLE */
26 /* FOR ANY SPECIAL, INDIRECT OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES */
27 /* WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN */
28 /* AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, */
29 /* ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF */
30 /* THIS SOFTWARE. */
31 /* */
32 /*************************************************************************/
33 /* Author : Alan W Black */
34 /* Date : October 1999 */
35 /*-----------------------------------------------------------------------*/
36 /* */
37 /* Training method to split states of existing WFST based on data to */
38 /* optimize entropy */
39 /* */
40 /* Confusing as this has nothing to do with the modelling */
41 /* technique known as "maximum entropy" */
42 /* */
43 /*=======================================================================*/
44 #include <iostream>
45 #include <cstdlib>
46 #include "EST_WFST.h"
47 #include "wfst_aux.h"
48 #include "EST_Token.h"
49 #include "EST_simplestats.h"
50 
51 VAL_REGISTER_TYPE_NODEL(trans,EST_WFST_Transition)
52 SIOD_REGISTER_CLASS(trans,EST_WFST_Transition)
53 VAL_REGISTER_CLASS(pdf,EST_DiscreteProbDistribution)
54 SIOD_REGISTER_CLASS(pdf,EST_DiscreteProbDistribution)
55 
56 static LISP *find_state_usage(EST_WFST &wfst, LISP data);
57 static double entropy(const EST_WFST_State *s);
58 static LISP *find_state_entropies(const EST_WFST &wfst, LISP *data);
59 EST_WFST_Transition *find_best_trans_split(EST_WFST &wfst,
60  int split_state,
61  LISP *data);
62 static LISP find_best_split(EST_WFST &wfst,
63  int split_state_name,
64  LISP *data);
65 static double find_score_if_split(EST_WFST &wfst,
66  int fromstate,
67  EST_WFST_Transition *trans,
68  LISP *data);
69 static LISP find_split_pdfs(EST_WFST &wfst,
70  int split_state_name,
71  LISP *data,
73 static double score_pdf_combine(EST_DiscreteProbDistribution &a,
76 #if 0
77 static void split_state(EST_WFST &wfst, EST_WFST_Transition *trans);
78 #endif
79 static void split_state(EST_WFST &wfst, LISP trans_list, int ostate);
80 
81 LISP load_string_data(EST_WFST &wfst,EST_String &filename)
82 {
83  // Load in sentences into data table, assume sentence per line
84  EST_TokenStream ts;
85  LISP ss = NIL;
86  EST_String t;
87  int id;
88  int i,j;
89 
90  if (ts.open(filename) == -1)
91  EST_error("wfst_train: failed to read data from \"%s\"",
92  (const char *)filename);
93 
94  i = 0;
95  j = 0;
96  while (!ts.eof())
97  {
98  LISP s = NIL;
99  do
100  {
101  t = (EST_String)ts.get();
102  id = wfst.in_symbol(t);
103  if (id == -1)
104  {
105  cerr << "wfst_train: data contains unknown symbol \"" <<
106  t << "\"" << endl;
107  }
108  s = cons(flocons(id),s);
109  j++;
110  }
111  while (!ts.eoln() && !ts.eof());
112  i++;
113  ss = cons(reverse(s),ss);
114  }
115 
116  printf("wfst_train: loaded %d lines of %d tokens\n",
117  i,j);
118 
119  return reverse(ss);
120 }
121 
122 static LISP *find_state_usage(EST_WFST &wfst, LISP data)
123 {
124  // Builds list of states, and which data points the represent
125  LISP *state_data = new LISP[wfst.num_states()];
126  static LISP ddd = NIL;
127  int s,i,id;
128  LISP d,w;
129  EST_WFST_Transition *trans;
130 // EST_Litem *tp;
131 
132  if (ddd == NIL)
133  gc_protect(&ddd);
134 
135  ddd = NIL;
136 
137  wfst.start_cumulate(); // zero existing weights
138 
139  for (i=0; i < wfst.num_states(); i++)
140  {
141  state_data[i] = NIL;
142  ddd = cons(state_data[i],ddd);
143 // // smoothing
144 // for (tp=wfst.state(i)->transitions.head(); tp != 0; tp = tp->next())
145 // wfst.state(i)->transitions(tp)->set_weight(1);
146  }
147 
148  for (i=0,d=data; d; d=cdr(d),i++)
149  {
150  s = wfst.start_state();
151  for (w=car(d); w; w=cdr(w))
152  {
153  state_data[s] = cons(w,state_data[s]);
154  id = get_c_int(car(w));
155  trans = wfst.find_transition(s,id,id);
156  if (!trans)
157  {
158  printf("sentence %d not in language, skipping\n",i);
159  continue;
160  }
161  else
162  {
163  trans->set_weight(trans->weight()+1);
164  s = trans->state();
165  }
166  }
167  }
168 
169  wfst.stop_cumulate();
170  return state_data;
171 }
172 
173 static double entropy(const EST_WFST_State *s)
174 {
175  double sentropy,w;
176  EST_Litem *tp;
177  for (sentropy=0,tp=s->transitions.head(); tp != 0; tp = tp->next())
178  {
179  w = s->transitions(tp)->weight(); /* the probability */
180  if (w > 0)
181  sentropy += w * log(w);
182  }
183  return -1 * sentropy;
184 }
185 
186 void wfst_train(EST_WFST &wfst, LISP data)
187 {
188  LISP *state_data;
189  LISP *state_entropies;
190  LISP best_trans_list = NIL;
191  int c=0,i, max_entropy_state;
192  gc_protect(&data);
193 
194  while (1)
195  {
196  // Build table of state to points in data, and cumulate transitions
197  state_data = find_state_usage(wfst,data);
198 
199  /* find entropy for each state (sorted) */
200  state_entropies = find_state_entropies(wfst,state_data);
201 
202  max_entropy_state = -1;
203  for (i=0; i < wfst.num_states(); i++)
204  {
205 // double me = (double)get_c_float(car(state_entropies[i]));
206  max_entropy_state = get_c_int(cdr(state_entropies[i]));
207 // printf("trying %d %g\n",max_entropy_state,me);
208 
209 // best_trans = find_best_trans_split(wfst,max_entropy_state,
210 // state_data);
211  best_trans_list = find_best_split(wfst,max_entropy_state,
212  state_data);
213  if (best_trans_list != NIL)
214  break;
215 // else
216 // printf("No best trans\n");
217  }
218  delete [] state_entropies;
219 
220  if (max_entropy_state == -1)
221  {
222  printf("No new max_entropy state\n");
223  break;
224  }
225  if (best_trans_list == NIL)
226  {
227  printf("No best_trans in max_entropy state\n");
228  break;
229  }
230 
231  /* for each transition *entering* max_entropy_state */
232  /* find entropy if it were split */
233  /* find best split */
234 
235  /* print stats */
236  /* some sort of stop check */
237  c++;
238  printf("c is %d\n",c);
239  if (c > 5000)
240  {
241  printf("reached cycle end %d\n",c);
242  break;
243  }
244  /* split on best split */
245  split_state(wfst, best_trans_list, max_entropy_state);
246 
247  if ((c % 100) == 0)
248  {
249  EST_String chkpntname = "chkpnt";
250  char bbb[7];
251  sprintf(bbb,"%03d",c);
252  wfst.save(chkpntname+bbb+".wfst");
253  }
254 
255  delete [] state_data;
256  user_gc(NIL);
257  }
258 }
259 
260 static int me_compare_function(const void *a, const void *b)
261 {
262  LISP la;
263  LISP lb;
264  la = *(LISP *)a;
265  lb = *(LISP *)b;
266 
267  float fa = get_c_float(car(la));
268  float fb = get_c_float(car(lb));
269 
270  if (fa < fb)
271  return 1;
272  else if (fa == fb)
273  return 0;
274  else
275  return -1;
276 }
277 
278 static LISP *find_state_entropies(const EST_WFST &wfst, LISP *data)
279 {
280  double all_entropy = 0;
281  int i;
282  double sentropy;
283  LISP *slist = new LISP[wfst.num_states()];
284  static LISP ddd = NIL;
285 
286  if (ddd == NIL)
287  gc_protect(&ddd);
288  ddd = NIL;
289 
290  for (i=0; i < wfst.num_states(); i++)
291  {
292  const EST_WFST_State *s = wfst.state(i);
293  sentropy = entropy(s);
294 // printf("dlength is %d %d\n",i,siod_llength(data[i]));
295  all_entropy += sentropy * siod_llength(data[i]);
296  slist[i] = cons(flocons(sentropy),flocons(i));
297  ddd = cons(slist[i],ddd);
298  }
299  printf("average entropy is %g\n",all_entropy/i);
300 
301  qsort(slist,wfst.num_states(),sizeof(LISP),me_compare_function);
302 
303  return slist;
304 }
305 
306 static LISP find_best_split(EST_WFST &wfst,
307  int split_state_name,
308  LISP *data)
309 {
310  // Find the best partition of incoming translations that
311  // minimises entropy
312  EST_DiscreteProbDistribution pdf_all(&wfst.in_symbols());
313  EST_DiscreteProbDistribution *a_pdf, *b_pdf;
314  LISP splits,s,dd,r;
315  LISP *ssplits;
316  gc_protect(&splits);
317  EST_String sname;
318  int b,best_b;
319  EST_Litem *i;
320  int num_pdfs;
321  double best_score, score, sfreq;
322 
323  for (dd = data[split_state_name]; dd; dd = cdr(dd))
324  pdf_all.cumulate(get_c_int(car(car(dd))));
325  splits = find_split_pdfs(wfst,split_state_name,data,pdf_all);
326  if (siod_llength(splits) < 2)
327  return NIL;
328  ssplits = new LISP[siod_llength(splits)];
329  for (num_pdfs=0,s=splits; s != NIL; s=cdr(s),num_pdfs++)
330  ssplits[num_pdfs] = car(s);
331 
332  qsort(ssplits,num_pdfs,sizeof(LISP),me_compare_function);
333  // Combine trans pdfs in pdfs until more combination doesn't improve
334  while (1)
335  {
336 
337  best_score = get_c_float(car(ssplits[0]));
338  best_b = -1;
339  a_pdf = pdf(car(cdr(cdr(ssplits[0]))));
340  for (b=1; b < num_pdfs; b++)
341  {
342  if (ssplits[b] == NIL)
343  continue;
344  score = score_pdf_combine(*a_pdf,*pdf(car(cdr(cdr(ssplits[b])))),
345  pdf_all);
346  if (score < best_score)
347  {
348  best_score = score;
349  best_b = b;
350  }
351  }
352 
353  // combine a and b
354  if (best_b == -1)
355  break;
356  else
357  {
358  // combine a and b
359  // Add trans to 0
360  setcar(cdr(ssplits[0]),
361  append(car(cdr(ssplits[0])),
362  car(cdr(ssplits[best_b]))));
363  setcar(ssplits[0], flocons(best_score));
364  // Update 0's pdf with values from best_b's
365  b_pdf = pdf(car(cdr(cdr(ssplits[best_b]))));
366  for (i=b_pdf->item_start(); !b_pdf->item_end(i);
367  i = b_pdf->item_next(i))
368  {
369  b_pdf->item_freq(i,sname,sfreq);
370  a_pdf->cumulate(i,sfreq);
371  }
372  ssplits[best_b] = NIL;
373  }
374 
375  }
376 
377  printf("score %g ",(double)get_c_float(car(ssplits[0])));
378  for (dd=car(cdr(ssplits[0])); dd; dd=cdr(dd))
379  printf("%s ",(const char *)wfst.in_symbol(trans(car(dd))->in_symbol()));
380  printf("\n");
381  gc_unprotect(&splits);
382  r = car(cdr(ssplits[0]));
383  delete [] ssplits;
384  return r;
385 }
386 
387 static double score_pdf_combine(EST_DiscreteProbDistribution &a,
390 {
391  // Find score of (a+b) vs (all-(a+b))
393  EST_DiscreteProbDistribution all_but_ab(all);
394  EST_Litem *i;
395  EST_String sname;
396  double sfreq, score;
397  for (i=b.item_start(); !b.item_end(i);
398  i = b.item_next(i))
399  {
400  b.item_freq(i,sname,sfreq);
401  ab.cumulate(i,sfreq);
402  }
403 
404  for (i=ab.item_start(); !ab.item_end(i);
405  i = ab.item_next(i))
406  {
407  ab.item_freq(i,sname,sfreq);
408  all_but_ab.cumulate(i,-1*sfreq);
409  }
410 
411  score = (ab.entropy() * ab.samples()) +
412  (all_but_ab.entropy() * all_but_ab.samples());
413 
414  return score;
415 
416 }
417 
418 static LISP find_split_pdfs(EST_WFST &wfst,
419  int split_state_name,
420  LISP *data,
422 {
423  // Find following pdfs for each incoming transition as if they where
424  // split to a new state
425  int i,id, in;
426  EST_Litem *tp;
427  LISP pdfs = NIL,dd,ttt,p,t;
429  double value;
430 
431  for (i=0; i < wfst.num_states(); i++)
432  {
433  const EST_WFST_State *s = wfst.state(i);
434  for (tp=s->transitions.head(); tp != 0; tp = tp->next())
435  {
436  if ((s->transitions(tp)->state() == split_state_name)
437  && (s->transitions(tp)->weight() > 0))
438  {
439  in = s->transitions(tp)->in_symbol();
442  for (dd = data[i]; dd; dd = cdr(dd))
443  {
444  id = get_c_int(car(car(dd)));
445  if (id == in)
446  { // This one would go to the new state so we count it
447  if (cdr(car(dd))) // not end of data string
448  pdf->cumulate(get_c_int(car(cdr(car(dd)))));
449  }
450  }
451  // value, list of trans, pdf
452  value = score_pdf_combine(*pdf,empty,pdf_all);
453  if ((value > 0) && // ignore transitions with no data
454  (pdf->samples() > 10))// and those with only a few data pnts
455  {
456  t = siod(s->transitions(tp));
457  p = siod(pdf);
458  ttt = cons(flocons(value),
459  cons(cons(t,NIL),
460  cons(p,NIL)));
461  pdfs = cons(ttt,pdfs);
462  }
463  else
464  delete pdf;
465  }
466  }
467  }
468  return pdfs;
469 }
470 
471 EST_WFST_Transition *find_best_trans_split(EST_WFST &wfst,
472  int split_state_name,
473  LISP *data)
474 {
475  EST_Litem *tp;
476  EST_WFST_Transition *best_trans = 0;
477  const EST_WFST_State *split_state = wfst.state(split_state_name);
478  double best_score,bb;
479  int i;
480 
481  best_score = entropy(split_state)*siod_llength(data[split_state_name]);
482 // printf("unsplit score %g\n",best_score);
483 
484  /* For each transition going to split_state */
485  for (i=1; i < wfst.num_states(); i++)
486  {
487  const EST_WFST_State *s = wfst.state(i);
488  for (tp=s->transitions.head(); tp != 0; tp = tp->next())
489  {
490  if ((wfst.state(s->transitions(tp)->state()) == split_state) &&
491  (s->transitions(tp)->weight() > 0))
492  {
493  bb = find_score_if_split(wfst,i,s->transitions(tp),data);
494 // cout << i << " "
495 // << wfst.in_symbol(s->transitions(tp)->in_symbol()) << " "
496 // << s->transitions(tp)->state() << " " << bb << endl;
497  if (bb == -1) /* didn't find a split */
498  continue;
499  if (bb < best_score)
500  {
501  best_score = bb;
502  best_trans = s->transitions(tp);
503  }
504  }
505  }
506  }
507 
508  if (best_trans)
509  cout << "best " << wfst.in_symbol(best_trans->in_symbol()) << " "
510  << best_trans->weight() << " "
511  << best_trans->state() << " " << best_score << endl;
512  return best_trans;
513 }
514 
515 static double find_score_if_split(EST_WFST &wfst,
516  int fromstate,
517  EST_WFST_Transition *trans,
518  LISP *data)
519 {
520  double ent_split;
521  double ent_remain;
522  double score;
523  EST_DiscreteProbDistribution pdf_split(&wfst.in_symbols());
524  EST_DiscreteProbDistribution pdf_remain(&wfst.in_symbols());
525  int in, tostate, id;
526  EST_Litem *i;
527  double sfreq;
528  EST_String sname;
529 
530  ent_split = ent_remain = 32*32*32*32;
531  LISP dd;
532 
533 // printf("considering %d %s %g %d\n",
534 // fromstate,
535 // (const char *)wfst.in_symbol(trans->in_symbol()),
536 // trans->weight(),
537 // trans->state());
538 
539  /* find entropy of possible new state */
540  /* for each data point through fromstate */
541  in = trans->in_symbol();
542  for (dd = data[fromstate]; dd; dd = cdr(dd))
543  {
544  id = get_c_int(car(car(dd)));
545  if (id == in)
546  { // This one would go to the new state so we count it
547  if (cdr(car(dd))) // not end of data string
548  pdf_split.cumulate(get_c_int(car(cdr(car(dd)))));
549  }
550  }
551  if (pdf_split.samples() > 0)
552  ent_split = pdf_split.entropy();
553  /* find entropy of old state minus trans into it */
554  tostate = trans->state();
555  // Actually only need to do this once per state
556  for (dd = data[tostate]; dd; dd = cdr(dd))
557  pdf_remain.cumulate(get_c_int(car(car(dd))));
558  // Subtract the bit thats split
559  for (i=pdf_split.item_start(); !pdf_split.item_end(i);
560  i = pdf_split.item_next(i))
561  {
562  pdf_split.item_freq(i,sname,sfreq);
563  pdf_remain.cumulate(i,-1*sfreq);
564  }
565  if (pdf_remain.samples() > 0)
566  ent_remain = pdf_remain.entropy();
567 
568  if ((pdf_remain.samples() == 0) ||
569  (pdf_split.samples() == 0))
570  return -1;
571 
572  score = (ent_remain * pdf_remain.samples()) +
573  (ent_split * pdf_split.samples());
574 // printf("tostate %d remain %g %d split %g %d score %g\n",
575 // tostate, ent_remain, (int)pdf_remain.samples(),
576 // ent_split, (int)pdf_split.samples(), score);
577 
578  return score;
579 }
580 
#if 0
// Disabled single-transition variant of split_state().
//
// Adds a new state, redirects `trans` to it, and gives the new state a
// zero-weight copy of every transition leaving the state `trans` used to
// point at.
static void split_state(EST_WFST &wfst, EST_WFST_Transition *trans)
{
    const int new_state = wfst.add_state(wfst_final);
    const int old_state = trans->state();

    /* must be done before adding the new transitions to new_state */
    trans->set_state(new_state);

    const EST_WFST_State *source = wfst.state(old_state);
    for (EST_Litem *item = source->transitions.head(); item != 0;
         item = item->next())
    {
        EST_WFST_Transition *orig = source->transitions(item);
        wfst.state_non_const(new_state)->
            add_transition(0.0, /* weight will be filled in later */
                           orig->state(),
                           orig->in_symbol(),
                           orig->out_symbol());
    }
}
#endif
609 
610 static void split_state(EST_WFST &wfst, LISP trans_list, int ostate)
611 {
612  /* Split off a new state for given trans. Add transitions */
613  /* to this new state for all transitions in (old) state trans */
614  /* goes to */
615  EST_Litem *tp;
616  int nstate = wfst.add_state(wfst_final);
617  LISP t;
618 
619  /* must be done before adding the new transitions to nstate */
620  for (t=trans_list; t; t=cdr(t))
621  trans(car(t))->set_state(nstate);
622 
623  for (tp=wfst.state(ostate)->transitions.head(); tp != 0; tp = tp->next())
624  {
625  wfst.state_non_const(nstate)->
626  add_transition(0.0, /* weight will be filled in later*/
627  wfst.state(ostate)->transitions(tp)->state(),
628  wfst.state(ostate)->transitions(tp)->in_symbol(),
629  wfst.state(ostate)->transitions(tp)->out_symbol());
630 
631  }
632 }
633