50 #include "EST_SCFG_Chart.h"
51 #include "EST_simplestats.h"
53 #include "EST_TVector.h"
61 #if defined(INSTANTIATE_TEMPLATES)
62 #include "../base_class/EST_TVector.cc"
72 b.
resize(siod_llength(examples));
74 for (i=0,e=examples; e != NIL; e=cdr(e),i++)
78 void EST_bracketed_string::init()
87 EST_bracketed_string::EST_bracketed_string()
92 EST_bracketed_string::EST_bracketed_string(LISP
string)
96 set_bracketed_string(
string);
99 EST_bracketed_string::~EST_bracketed_string()
105 for (i=0; i < p_length; i++)
106 delete [] valid_spans[i];
107 delete [] valid_spans;
110 void EST_bracketed_string::set_bracketed_string(LISP
string)
116 p_length = find_num_nodes(
string);
117 symbols =
new LISP[p_length];
119 set_leaf_indices(
string,0,symbols);
124 valid_spans =
new int*[length()];
125 for (i=0; i < length(); i++)
127 valid_spans[i] =
new int[length()+1];
128 for (j=i+1; j <= length(); j++)
129 valid_spans[i][j] = 0;
138 int EST_bracketed_string::find_num_nodes(LISP
string)
143 else if (CONSP(
string))
144 return find_num_nodes(car(
string))+
145 find_num_nodes(cdr(
string));
150 int EST_bracketed_string::set_leaf_indices(LISP
string,
int i,LISP *syms)
154 else if (!CONSP(car(
string)))
157 return set_leaf_indices(cdr(
string),i+1,syms);
161 return set_leaf_indices(cdr(
string),
162 set_leaf_indices(car(
string),i,syms),
167 void EST_bracketed_string::find_valid(
int s,LISP t)
const
174 for (c=s,l=t; l != NIL; l=cdr(l))
176 c += num_leafs(car(l));
177 valid_spans[s][c] = 1;
179 find_valid(s,car(t));
180 find_valid(s+num_leafs(car(t)),cdr(t));
184 int EST_bracketed_string::num_leafs(LISP t)
const
191 return num_leafs(car(t)) + num_leafs(cdr(t));
194 EST_SCFG_traintest::EST_SCFG_traintest(
void) :
EST_SCFG()
202 EST_SCFG_traintest::~EST_SCFG_traintest(
void)
209 set_corpus(corpus,vload(filename,1));
213 double EST_SCFG_traintest::f_I_cal(
int c,
int p,
int i,
int k)
237 double pBpqr =
prob_B(p,q,r);
239 for (j=i+1; j < k; j++)
241 double in = f_I(c,q,i,j);
243 s += pBpqr * in * f_I(c,r,j,k);
251 inside[p][i][k] = res;
259 double EST_SCFG_traintest::f_O_cal(
int c,
int p,
int i,
int k)
264 if ((i == 0) && (k == corpus.
a_no_check(c).length()))
266 if (p == distinguished_symbol())
288 double out = f_O(c,q,j,k);
290 s2 += out * f_I(c,r,j,i);
297 for (j=k+1;j <= corpus.
a_no_check(c).length(); j++)
299 double out = f_O(c,q,i,j);
301 s3 += out * f_I(c,r,k,j);
312 outside[p][i][k] = res;
317 void EST_SCFG_traintest::reestimate_rule_prob_B(
int c,
int ri,
int p,
int q,
int r)
323 double pBpqr =
prob_B(p,q,r);
327 for (i=0; i <= corpus.
a_no_check(c).length()-2; i++)
328 for (j=i+1; j <= corpus.
a_no_check(c).length()-1; j++)
330 double d1 = f_I(c,q,i,j);
331 if (d1 == 0)
continue;
332 for (k=j+1; k <= corpus.
a_no_check(c).length(); k++)
334 double d2 = f_I(c,r,j,k);
335 if (d2 == 0)
continue;
336 double d3 = f_O(c,p,i,k);
337 if (d3 == 0)
continue;
357 void EST_SCFG_traintest::reestimate_rule_prob_U(
int c,
int ri,
int p,
int m)
369 for (i=1; i < corpus.
a_no_check(c).length(); i++)
371 n2 +=
prob_U(p,m) * f_O(c,p,i-1,i);
377 d[ri] += f_P(c,p) / fP;
381 double EST_SCFG_traintest::f_P(
int c)
383 return f_I(c,distinguished_symbol(),0,corpus.
a_no_check(c).length());
386 double EST_SCFG_traintest::f_P(
int c,
int p)
391 for (i=0; i < corpus.
a_no_check(c).length(); i++)
392 for (j=i+1; j <= corpus.
a_no_check(c).length(); j++)
394 double d1 = f_O(c,p,i,j);
395 if (d1 == 0)
continue;
396 db += f_I(c,p,i,j)*d1;
402 void EST_SCFG_traintest::reestimate_grammar_probs(
int passes,
420 for (pass = startpass; pass < passes; pass++)
429 for (mC=0.0,lPc=0.0,c=0; c < corpus.
length(); c++)
432 if ((spread > 0) && (((c+(pass*spread))%100) >= spread))
434 printf(
" %d",c); fflush(stdout);
435 if (corpus.
a_no_check(c).length() == 0)
continue;
437 for (ri=0,r=
rules.head(); r != 0; r=r->next(),ri++)
439 if (
rules(r).type() == est_scfg_binary_rule)
440 reestimate_rule_prob_B(c,ri,
442 rules(r).daughter1(),
443 rules(r).daughter2());
445 reestimate_rule_prob_U(c,
448 rules(r).daughter1());
450 lPc += safe_log(f_P(c));
456 for (se=0.0,ri=0,r=
rules.head(); r != 0; r=r->next(),ri++)
458 double n_prob = n[ri]/d[ri];
461 se += (n_prob-
rules(r).prob())*(n_prob-
rules(r).prob());
462 rules(r).set_prob(n_prob);
464 printf(
"pass %d cross entropy %g RMSE %f %f %d\n",
465 pass,-(lPc/mC),sqrt(se/
rules.length()),
468 if (checkpoint != -1)
470 if ((pass % checkpoint) == checkpoint-1)
473 sprintf(cp,
".%03d",pass);
490 reestimate_grammar_probs(passes, startpass, checkpoint,
494 void EST_SCFG_traintest::init_io_cache(
int c,
int nt)
500 inside =
new double**[nt];
501 outside =
new double**[nt];
502 for (i=0; i < nt; i++)
504 inside[i] =
new double*[mc];
505 outside[i] =
new double*[mc];
506 for (j=0; j < mc; j++)
508 inside[i][j] =
new double[mc];
509 outside[i][j] =
new double[mc];
510 for (k=0; k < mc; k++)
512 inside[i][j][k] = -1;
513 outside[i][j][k] = -1;
519 void EST_SCFG_traintest::clear_io_cache(
int c)
529 for (j=0; j < mc; j++)
531 delete [] inside[i][j];
532 delete [] outside[i][j];
535 delete [] outside[i];
545 double EST_SCFG_traintest::cross_entropy()
550 for (c=0; c < corpus.
length(); c++)
570 for (i=0; i <
rules.length(); i++)
573 for (mC=0.0,lPc=0.0,c=0; c < corpus.
length(); c++)
594 cout <<
"cross entropy " << -(lPc/mC) <<
" (" << failed <<
" failed out of " <<
595 corpus.
length() <<
" sentences )" << endl;