/* NOTE(review): this file is a fragmented extract — original line numbers are
 * embedded in the text and many intermediate lines are missing.  Only
 * comments are added here; the code bytes are left untouched. */

/* Forward declarations for the static helpers defined later in this file. */

/* Load the vocabulary word list from vfile. */
54 static void load_vocab(
const EST_String &vfile);
/* NOTE(review): parameter list truncated by the extract — load_wstream is
 * called later with four arguments (file, vocab file, stream, observations). */
58 static void load_wstream(
const EST_String &filename,
/* Load per-frame ngram left contexts (used with the -given option). */
63 static void load_given(
const EST_String &filename,
64 const int ngram_order);
/* Language-model probability of extending Viterbi path p; may write *state. */
66 static double find_gram_prob(
EST_VTPath *p,
int *state);
/* As find_gram_prob, but consults the per-frame "given" contexts at `time`. */
69 static double find_extra_gram_prob(
EST_VTPath *p,
int *state,
int time);
/* Nonzero if s is a special context marker; its numeric value goes to val. */
73 static int is_a_special(
const EST_String &s,
int &val);
/* File-scope state, mostly configured from the command line in main(). */

/* Deepest "given" history referenced by any special marker (see load_given). */
74 static int max_history=0;
/* Context tags used for positions before the start of the sentence. */
77 static EST_String pstring = SENTENCE_START_MARKER;
/* NOTE(review): ppstring defaults to the END marker, which looks odd for a
 * "prev_prev" tag — confirm against the -default_tags handling in main(). */
78 static EST_String ppstring = SENTENCE_END_MARKER;
/* Scale factors applied to log probabilities (-lm_scale / -ob_scale*). */
79 static float lm_scale = 1.0;
80 static float ob_scale = 1.0;
81 static float ob_scale2 = 1.0;
/* Pruning parameters; negative means the corresponding pruning is disabled. */
85 static float ob_beam=-1;
86 static int n_beam = -1;
/* Set by -trace: print search progress to stderr. */
88 static bool trace_on = FALSE;
/* Probability floors in the log domain (SAFE_LOG_ZERO = effectively zero). */
91 static double ob_log_prob_floor = SAFE_LOG_ZERO;
92 static double ob_log_prob_floor2 = SAFE_LOG_ZERO;
93 static double lm_log_prob_floor = SAFE_LOG_ZERO;
95 int btest_debug = FALSE;
/* Nonzero when -given supplied per-frame left contexts. */
101 int using_given=FALSE;
/* Nonzero when -ob_type is "probs" and observations must be converted. */
104 int take_logs = FALSE;
/* Entry point: parse options, load the ngram grammar, vocabulary and
 * observation stream(s), run the Viterbi search and print the best path.
 * NOTE(review): large stretches of this function are missing from this
 * extract (the embedded line numbers jump), so the comments below describe
 * only what the surviving lines show. */
174 int main(
int argc,
char **argv)
/* Usage summary and option table handed to parse_command_line. */
181 parse_command_line(argc, argv,
182 EST_String(
"[observations file] -o [output file]\n")+
183 "Summary: find the most likely path through a sequence of\n"+
184 " observations, constrained by a language model.\n"+
185 "-ngram <string> Grammar file, required\n"+
186 "-given <string> ngram left contexts, per frame\n"+
187 "-vocab <string> File with names of vocabulary, this\n"+
188 " must be same number as width of observations, required\n"+
189 "-ob_type <string> Observation type : likelihood .... and change doc\"probs\" or \"logs\" (default is \"logs\")\n"+
190 "\nFloor values and scaling (scaling is applied after floor value)\n"+
191 "-lm_floor <float> LM floor probability\n"+
192 "-lm_scale <float> LM scale factor factor (applied to log prob)\n"+
193 "-ob_floor <float> Observations floor probability\n"+
194 "-ob_scale <float> Observation scale factor (applied to prob or log prob, depending on -ob_type)\n\n"+
195 "-prev_tag <string>\n"+
196 " tag before sentence start\n"+
197 "-prev_prev_tag <string>\n"+
198 " all words before 'prev_tag'\n"+
199 "-last_tag <string>\n"+
200 " after sentence end\n"+
/* NOTE(review): the -default_tags text below appears garbled by the extract
 * (a '+' seems to be missing and SENTENCE_END_MARKER is repeated where a
 * distinct prev-prev marker might be expected) — verify against the
 * original source. */
201 "-default_tags use default tags of "+SENTENCE_START_MARKER+
","
202 SENTENCE_END_MARKER+
" and "+SENTENCE_END_MARKER+
"\n"+
205 "-observes2 <string> second observations (overlays first, ob_type must be same)\n"+
206 "-ob_floor2 <float> \n"+
207 "-ob_scale2 <float> \n\n"+
208 "-ob_prune <float> observation pruning beam width (log) probability\n"+
209 "-n_prune <int> top-n pruning of observations\n"+
210 "-prune <float> pruning beam width (log) probability\n"+
211 "-trace show details of search as it proceeds\n",
/* Exactly one positional observations file is required. */
216 if (files.length() != 1)
219 cerr <<
": you must give exactly one observations file on the command line";
221 cerr <<
"(use -observes2 for optional second observations)" << endl;
/* Load the required ngram grammar (-ngram). */
227 ngram.load(al.
val(
"-ngram"));
231 cerr << argv[0] <<
": no ngram specified" << endl;
237 cerr <<
"You must provide a vocabulary file !" << endl;
/* Load the primary observation stream, then optionally overlay a second. */
241 load_wstream(files.
first(),al.
val(
"-vocab"),wstream,observations);
244 load_wstream(al.
val(
"-observes2"),al.
val(
"-vocab"),wstream,observations2);
/* Per-frame left contexts (-given); needs the ngram order for validation. */
250 load_given(al.
val(
"-given"),ngram.order());
/* Scale factors, if supplied on the command line. */
255 lm_scale = al.
fval(
"-lm_scale");
260 ob_scale = al.
fval(
"-ob_scale");
265 ob_scale2 = al.
fval(
"-ob_scale2");
/* Sentence-boundary tag overrides. */
270 pstring = al.
val(
"-prev_tag");
271 if (al.
present(
"-prev_prev_tag"))
272 ppstring = al.
val(
"-prev_prev_tag");
/* Pruning settings: beam width, observation beam, top-n. */
276 beam = al.
fval(
"-prune");
281 ob_beam = al.
fval(
"-ob_prune");
287 n_beam = al.
ival(
"-n_prune");
290 cerr <<
"WARNING : " << n_beam;
291 cerr <<
" is not a reasonable value for -n_prune !" << endl;
/* Floor probabilities: each must lie in [0,1] before conversion to log. */
305 floor = al.
fval(
"-lm_floor");
308 cerr <<
"Error : LM floor probability is negative !" << endl;
313 cerr <<
"Error : LM floor probability > 1 " << endl;
316 lm_log_prob_floor = safe_log(floor);
322 floor = al.
fval(
"-ob_floor");
325 cerr <<
"Error : Observation floor probability is negative !" << endl;
330 cerr <<
"Error : Observation floor probability > 1 " << endl;
333 ob_log_prob_floor = safe_log(floor);
338 floor = al.
fval(
"-ob_floor2");
341 cerr <<
"Error : Observation2 floor probability is negative !" << endl;
346 cerr <<
"Error : Observation2 floor probability > 1 " << endl;
349 ob_log_prob_floor2 = safe_log(floor);
/* Observation type must be "logs" or "probs"; anything else is an error. */
355 if(al.
val(
"-ob_type") ==
"logs")
357 else if(al.
val(
"-ob_type") ==
"probs")
361 cerr <<
"\"" << al.
val(
"-ob_type")
362 <<
"\" is not a valid ob_type : try \"logs\" or \"probs\"" << endl;
/* Run the search and report the result, or fail if no path was found. */
367 if(do_search(wstream))
368 print_results(wstream);
370 cerr <<
"No path could be found." << endl;
/* NOTE(review): fragment of the results-printing routine (presumably
 * print_results, called from main) — its opening lines are elided from this
 * extract.  Writes "word score" lines for each stream item, either to
 * stdout or to the -o output file. */
384 else if ((fd = fopen(out_file,
"wb")) == NULL)
386 cerr <<
"can't open \"" << out_file <<
"\" for output" << endl;
/* Walk the stream; each item carries the winning word and its score,
 * stored by the search under the "best" / "best_score" features. */
390 for (s=wstream.
head(); s != 0 ; s=s->next())
392 predict = s->f(
"best").
string();
393 pscore = s->f(
"best_score");
394 fprintf(fd,
"%s %f\n",(
const char *)predict,pscore);
/* NOTE(review): fragment of the search driver (presumably do_search) — its
 * opening lines are elided.  Sets up the Viterbi decoder over the word
 * stream, applies pruning if requested, and returns the decoder's verdict. */
407 states = ngram.num_states();
410 vc.initialise(&wstream);
/* Pruning only when at least one beam was set on the command line. */
412 if((beam > 0) || (ob_beam > 0))
413 vc.set_pruning_parameters(beam,ob_beam);
418 cerr <<
"Starting Viterbi search..." << endl;
/* result("best") reports whether a best path was found/stored. */
423 return vc.result(
"best");
/* Load an observation matrix from filename and check it against the vocab.
 * NOTE(review): most of the parameter list and body are elided from this
 * extract; the surviving lines show the load, an existence check, a
 * width-vs-vocabulary consistency check, and per-column word registration. */
427 static void load_wstream(
const EST_String &filename,
439 if (obs.
load(filename,0.10) != 0)
441 cerr <<
"can't find observations file \"" << filename <<
"\"" << endl;
/* The vocabulary size must match the observation matrix width. */
447 cerr <<
"Number in vocab (" << vocab.length() <<
448 ") not equal to observation's width (" <<
455 add_word(w,itoString(i),i);
/* Load the per-frame "given" left contexts (-given option).  Each frame
 * supplies ngram_order-1 context strings; special markers in them determine
 * how much real history (max_history) the search must keep.
 * NOTE(review): body partially elided in this extract. */
459 static void load_given(
const EST_String &filename,
460 const int ngram_order)
467 if (load_TList_of_StrVector(given,filename,ngram_order-1) != 0)
469 cerr <<
"can't load given file \"" << filename <<
"\"" << endl;
/* Scan every context string; a special marker with value j refers to
 * history position -j, so track the deepest reference in max_history. */
474 for (p = given.head(); p; p = p->next())
476 for(i=0;i<given(p).length();i++)
477 if( is_a_special( given(p)(i), j) && (-j > max_history))
/* Read the vocabulary file: one item per word, recording each word's name
 * and its position ("pos") — the column index into the observation matrix.
 * NOTE(review): body partially elided in this extract. */
484 static void load_vocab(
const EST_String &vfile)
489 if (ts.
open(vfile) == -1)
491 cerr <<
"can't find vocab file \"" << vfile <<
"\"" << endl;
/* Each vocabulary item stores its word name and observation column. */
506 item->set_name(word);
507 item->
set(
"pos",pos);
/* NOTE(review): fragment of the candidate/observation scoring routine — its
 * header is elided.  For the frame's observation row ("pos" feature), reads
 * a probability per vocabulary word from one or two observation matrices,
 * converts to log10 if needed, floors, and sums into the candidate score. */
513 double prob=1.0,prob2=1.0;
521 observe = s->f(
"pos");
/* One candidate per vocabulary entry / observation channel. */
522 for (i=0,p=vocab.head(); i < observations.
num_channels(); i++,p=p->next())
526 prob = observations.
a(observe,i);
528 prob2 = observations2.
a(observe,i);
/* When inputs are raw probabilities, convert to log10 then floor. */
532 prob = safe_log10(prob);
533 if (prob < ob_log_prob_floor)
534 prob = ob_log_prob_floor;
538 prob2 = safe_log10(prob2);
539 if (prob2 < ob_log_prob_floor2)
540 prob2 = ob_log_prob_floor2;
/* When inputs are already logs, only the floor is applied. */
545 if (prob < ob_log_prob_floor)
546 prob = ob_log_prob_floor;
547 if ((num_obs == 2) && (prob2 < ob_log_prob_floor2))
548 prob2 = ob_log_prob_floor2;
/* Combined (log-domain) observation score for this candidate. */
555 c->score = prob + prob2;
/* Optional top-n pruning of the candidate list (-n_prune). */
567 top_n_candidates(all_c);
/* NOTE(review): fragment of the path-extension routine — its header is
 * elided.  Combines the LM probability (from the "given" variant when
 * -given contexts are in use, otherwise the plain ngram) with the
 * candidate's observation score to produce the new path score. */
590 prob = find_extra_gram_prob(np,&np->state,c->s->f(
"pos"));
592 prob = find_gram_prob(np,&np->state);
/* LM probability is floored in the log10 domain. */
594 lprob = safe_log10(prob);
595 if (lprob < lm_log_prob_floor)
596 lprob = lm_log_prob_floor;
/* Record this step's local (combined) score on the path features. */
600 np->f.
set(
"lscore",(c->score+lprob));
/* First frame: no predecessor score; otherwise accumulate along the path. */
602 np->score = (c->score+lprob);
604 np->score = (c->score+lprob) + p->score;
/* LM probability of the word at the end of path p, given the previous
 * ngram.order()-1 words along the path (padded with the boundary tags when
 * the path is shorter than the ngram context).  Also advances *state via
 * ngram.predict on the shifted window.
 * NOTE(review): body partially elided in this extract. */
609 static double find_gram_prob(
EST_VTPath *p,
int *state)
613 double prob=0.0,nprob;
/* Walk back along the path filling the context window right-to-left;
 * positions before the start of the path get the boundary tag. */
618 for (pp=p->from,i=ngram.order()-2; i >= 0; i--)
622 window[i] = pp->c->name.
string();
626 window[i] = ppstring;
/* Current candidate word occupies the final window slot. */
633 window[ngram.order()-1] = p->c->name.
string();
638 prob = (double)pd.probability(p->c->name.
string());
/* Shift the window left by one and predict, to update the ngram state. */
640 for (i=0; i < ngram.order()-1; i++)
641 window[i] = window(i+1);
642 ngram.predict(window,&nprob,state);
/* As find_gram_prob, but the context window is built from the per-frame
 * "given" contexts (fill_window) plus the real path history at `time`.
 * NOTE(review): body partially elided in this extract. */
648 static double find_extra_gram_prob(
EST_VTPath *p,
int *state,
int time)
652 double prob=0.0,nprob;
/* Real word history from the path, then the window for this frame. */
656 get_history(history,p);
658 fill_window(window,history,p,time);
671 prob = (double)pd.probability(p->c->name.
string());
/* Shift the history by one (newest word at index 0) before predicting
 * with next frame's window, to advance the ngram state. */
676 for(i=history.length()-1;i>0;i--)
677 history[i] = history(i-1);
678 history[0] = p->c->name.
string();
681 fill_window(window,history,p,time+1);
682 ngram.predict(window,&nprob,state);
/* NOTE(review): fragment of the history-building routine (presumably
 * get_history) — header elided.  Fills `history` from the path's
 * predecessors, padding positions beyond the path start with the
 * sentence-boundary tags. */
695 for (pp=p->from,i=0; i < history.
length(); i++)
700 history[i] = pp->c->name.
string();
704 history[i] = ppstring;
707 history[i] = pstring;
/* NOTE(review): fragment of the window-filling routine (presumably
 * fill_window) — header elided.  Builds the ngram context window for frame
 * `time` from that frame's "given" strings; special markers (value j)
 * index into the real history as history(-1-j), others are used verbatim. */
723 if( time >= given.length() )
733 window[ngram.order()-1] = p->c->name.
string();
739 for(i=0;i<ngram.order()-1;i++)
742 if( is_a_special( (*this_g)(i), j))
743 window[i] = history(-1-j);
745 window[i] = (*this_g)(i);
/* Nonzero if s is a special context marker; writes its numeric value to val.
 * NOTE(review): the entire body is elided from this extract — only the
 * signature survives. */
751 static int is_a_special(
const EST_String &s,
int &val)
/* NOTE(review): fragment of the top-n pruning routine (presumably
 * top_n_candidates) — header elided.  Repeats n_beam times: scan the
 * remaining candidate list for the best-scoring entry, unlink it, and move
 * it onto the kept (top_c) list. */
782 for(i=0;i<n_beam;i++)
/* q trails p so the best node can be unlinked in O(1). */
791 for(p=all_c;p!= NULL;q=p,p=p->next)
794 if(p->score > this_best->score)
801 if(this_best == NULL)
/* Unlink this_best: either it was the list head or mid-list. */
805 if(prev_to_best == NULL)
807 all_c = this_best->next;
810 prev_to_best->next = this_best->next;
/* Push onto the kept list. */
812 this_best->next = top_c;