48 #include "EST_Token.h"
49 #include "EST_simplestats.h"
56 static LISP *find_state_usage(
EST_WFST &wfst, LISP data);
58 static LISP *find_state_entropies(const
EST_WFST &wfst, LISP *data);
62 static LISP find_best_split(
EST_WFST &wfst,
65 static
double find_score_if_split(
EST_WFST &wfst,
69 static LISP find_split_pdfs(
EST_WFST &wfst,
79 static void split_state(
EST_WFST &wfst, LISP trans_list,
int ostate);
90 if (ts.
open(filename) == -1)
91 EST_error(
"wfst_train: failed to read data from \"%s\"",
92 (
const char *)filename);
105 cerr <<
"wfst_train: data contains unknown symbol \"" <<
108 s = cons(flocons(
id),s);
111 while (!ts.
eoln() && !ts.
eof());
113 ss = cons(reverse(s),ss);
116 printf(
"wfst_train: loaded %d lines of %d tokens\n",
122 static LISP *find_state_usage(
EST_WFST &wfst, LISP data)
125 LISP *state_data =
new LISP[wfst.num_states()];
126 static LISP ddd = NIL;
139 for (i=0; i < wfst.num_states(); i++)
142 ddd = cons(state_data[i],ddd);
148 for (i=0,d=data; d; d=cdr(d),i++)
150 s = wfst.start_state();
151 for (w=car(d); w; w=cdr(w))
153 state_data[s] = cons(w,state_data[s]);
154 id = get_c_int(car(w));
158 printf(
"sentence %d not in language, skipping\n",i);
163 trans->set_weight(trans->weight()+1);
177 for (sentropy=0,tp=s->transitions.head(); tp != 0; tp = tp->next())
179 w = s->transitions(tp)->weight();
181 sentropy += w * log(w);
183 return -1 * sentropy;
186 void wfst_train(
EST_WFST &wfst, LISP data)
189 LISP *state_entropies;
190 LISP best_trans_list = NIL;
191 int c=0,i, max_entropy_state;
197 state_data = find_state_usage(wfst,data);
200 state_entropies = find_state_entropies(wfst,state_data);
202 max_entropy_state = -1;
203 for (i=0; i < wfst.num_states(); i++)
206 max_entropy_state = get_c_int(cdr(state_entropies[i]));
211 best_trans_list = find_best_split(wfst,max_entropy_state,
213 if (best_trans_list != NIL)
218 delete [] state_entropies;
220 if (max_entropy_state == -1)
222 printf(
"No new max_entropy state\n");
225 if (best_trans_list == NIL)
227 printf(
"No best_trans in max_entropy state\n");
238 printf(
"c is %d\n",c);
241 printf(
"reached cycle end %d\n",c);
245 split_state(wfst, best_trans_list, max_entropy_state);
251 sprintf(bbb,
"%03d",c);
252 wfst.
save(chkpntname+bbb+
".wfst");
255 delete [] state_data;
260 static int me_compare_function(
const void *a,
const void *b)
267 float fa = get_c_float(car(la));
268 float fb = get_c_float(car(lb));
278 static LISP *find_state_entropies(
const EST_WFST &wfst, LISP *data)
280 double all_entropy = 0;
283 LISP *slist =
new LISP[wfst.num_states()];
284 static LISP ddd = NIL;
290 for (i=0; i < wfst.num_states(); i++)
293 sentropy = entropy(s);
295 all_entropy += sentropy * siod_llength(data[i]);
296 slist[i] = cons(flocons(sentropy),flocons(i));
297 ddd = cons(slist[i],ddd);
299 printf(
"average entropy is %g\n",all_entropy/i);
301 qsort(slist,wfst.num_states(),
sizeof(LISP),me_compare_function);
306 static LISP find_best_split(
EST_WFST &wfst,
307 int split_state_name,
321 double best_score, score, sfreq;
323 for (dd = data[split_state_name]; dd; dd = cdr(dd))
324 pdf_all.cumulate(get_c_int(car(car(dd))));
325 splits = find_split_pdfs(wfst,split_state_name,data,pdf_all);
326 if (siod_llength(splits) < 2)
328 ssplits =
new LISP[siod_llength(splits)];
329 for (num_pdfs=0,s=splits; s != NIL; s=cdr(s),num_pdfs++)
330 ssplits[num_pdfs] = car(s);
332 qsort(ssplits,num_pdfs,
sizeof(LISP),me_compare_function);
337 best_score = get_c_float(car(ssplits[0]));
339 a_pdf = pdf(car(cdr(cdr(ssplits[0]))));
340 for (b=1; b < num_pdfs; b++)
342 if (ssplits[b] == NIL)
344 score = score_pdf_combine(*a_pdf,*pdf(car(cdr(cdr(ssplits[b])))),
346 if (score < best_score)
360 setcar(cdr(ssplits[0]),
361 append(car(cdr(ssplits[0])),
362 car(cdr(ssplits[best_b]))));
363 setcar(ssplits[0], flocons(best_score));
365 b_pdf = pdf(car(cdr(cdr(ssplits[best_b]))));
366 for (i=b_pdf->item_start(); !b_pdf->item_end(i);
367 i = b_pdf->item_next(i))
369 b_pdf->item_freq(i,sname,sfreq);
370 a_pdf->cumulate(i,sfreq);
372 ssplits[best_b] = NIL;
377 printf(
"score %g ",(
double)get_c_float(car(ssplits[0])));
378 for (dd=car(cdr(ssplits[0])); dd; dd=cdr(dd))
379 printf(
"%s ",(
const char *)wfst.
in_symbol(trans(car(dd))->in_symbol()));
381 gc_unprotect(&splits);
382 r = car(cdr(ssplits[0]));
401 ab.cumulate(i,sfreq);
404 for (i=ab.item_start(); !ab.item_end(i);
407 ab.item_freq(i,sname,sfreq);
408 all_but_ab.cumulate(i,-1*sfreq);
411 score = (ab.entropy() * ab.samples()) +
412 (all_but_ab.entropy() * all_but_ab.samples());
418 static LISP find_split_pdfs(
EST_WFST &wfst,
419 int split_state_name,
427 LISP pdfs = NIL,dd,ttt,p,t;
431 for (i=0; i < wfst.num_states(); i++)
434 for (tp=s->transitions.head(); tp != 0; tp = tp->next())
436 if ((s->transitions(tp)->state() == split_state_name)
437 && (s->transitions(tp)->weight() > 0))
439 in = s->transitions(tp)->in_symbol();
442 for (dd = data[i]; dd; dd = cdr(dd))
444 id = get_c_int(car(car(dd)));
448 pdf->
cumulate(get_c_int(car(cdr(car(dd)))));
452 value = score_pdf_combine(*pdf,empty,pdf_all);
456 t = siod(s->transitions(tp));
458 ttt = cons(flocons(value),
461 pdfs = cons(ttt,pdfs);
472 int split_state_name,
478 double best_score,bb;
481 best_score = entropy(split_state)*siod_llength(data[split_state_name]);
485 for (i=1; i < wfst.num_states(); i++)
488 for (tp=s->transitions.head(); tp != 0; tp = tp->next())
490 if ((wfst.
state(s->transitions(tp)->state()) == split_state) &&
491 (s->transitions(tp)->weight() > 0))
493 bb = find_score_if_split(wfst,i,s->transitions(tp),data);
502 best_trans = s->transitions(tp);
509 cout <<
"best " << wfst.
in_symbol(best_trans->in_symbol()) <<
" "
510 << best_trans->weight() <<
" "
511 << best_trans->state() <<
" " << best_score << endl;
515 static double find_score_if_split(
EST_WFST &wfst,
530 ent_split = ent_remain = 32*32*32*32;
541 in = trans->in_symbol();
542 for (dd = data[fromstate]; dd; dd = cdr(dd))
544 id = get_c_int(car(car(dd)));
548 pdf_split.cumulate(get_c_int(car(cdr(car(dd)))));
551 if (pdf_split.samples() > 0)
552 ent_split = pdf_split.entropy();
554 tostate = trans->state();
556 for (dd = data[tostate]; dd; dd = cdr(dd))
557 pdf_remain.cumulate(get_c_int(car(car(dd))));
559 for (i=pdf_split.item_start(); !pdf_split.item_end(i);
560 i = pdf_split.item_next(i))
562 pdf_split.item_freq(i,sname,sfreq);
563 pdf_remain.cumulate(i,-1*sfreq);
565 if (pdf_remain.samples() > 0)
566 ent_remain = pdf_remain.entropy();
568 if ((pdf_remain.samples() == 0) ||
569 (pdf_split.samples() == 0))
572 score = (ent_remain * pdf_remain.samples()) +
573 (ent_split * pdf_split.samples());
589 int ostate = trans->state();
593 trans->set_state(nstate);
595 for (tp=wfst.
state(ostate)->transitions.head(); tp != 0; tp = tp->next())
599 wfst.
state(ostate)->transitions(tp)->state(),
600 wfst.
state(ostate)->transitions(tp)->in_symbol(),
601 wfst.
state(ostate)->transitions(tp)->out_symbol());
610 static void split_state(
EST_WFST &wfst, LISP trans_list,
int ostate)
620 for (t=trans_list; t; t=cdr(t))
621 trans(car(t))->set_state(nstate);
623 for (tp=wfst.
state(ostate)->transitions.head(); tp != 0; tp = tp->next())
627 wfst.
state(ostate)->transitions(tp)->state(),
628 wfst.
state(ostate)->transitions(tp)->in_symbol(),
629 wfst.
state(ostate)->transitions(tp)->out_symbol());