Edinburgh Speech Tools  2.4-release
 All Classes Functions Variables Typedefs Enumerations Enumerator Friends Pages
EST_SCFG_inout.cc
1 /*************************************************************************/
2 /* */
3 /* Centre for Speech Technology Research */
4 /* University of Edinburgh, UK */
5 /* Copyright (c) 1997 */
6 /* All Rights Reserved. */
7 /* */
8 /* Permission is hereby granted, free of charge, to use and distribute */
9 /* this software and its documentation without restriction, including */
10 /* without limitation the rights to use, copy, modify, merge, publish, */
11 /* distribute, sublicense, and/or sell copies of this work, and to */
12 /* permit persons to whom this work is furnished to do so, subject to */
13 /* the following conditions: */
14 /* 1. The code must retain the above copyright notice, this list of */
15 /* conditions and the following disclaimer. */
16 /* 2. Any modifications must be clearly marked as such. */
17 /* 3. Original authors' names are not deleted. */
18 /* 4. The authors' names are not used to endorse or promote products */
19 /* derived from this software without specific prior written */
20 /* permission. */
21 /* */
22 /* THE UNIVERSITY OF EDINBURGH AND THE CONTRIBUTORS TO THIS WORK */
23 /* DISCLAIM ALL WARRANTIES WITH REGARD TO THIS SOFTWARE, INCLUDING */
24 /* ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS, IN NO EVENT */
25 /* SHALL THE UNIVERSITY OF EDINBURGH NOR THE CONTRIBUTORS BE LIABLE */
26 /* FOR ANY SPECIAL, INDIRECT OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES */
27 /* WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN */
28 /* AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, */
29 /* ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF */
30 /* THIS SOFTWARE. */
31 /* */
32 /*************************************************************************/
33 /* Author : Alan W Black */
34 /* Date : October 1997 */
35 /*-----------------------------------------------------------------------*/
36 /* */
37 /* Implementation of an inside-outside reestimation procedure for */
38 /* building a stochastic CFG seeded with a bracket corpus. */
39 /* Based on "Inside-Outside Reestimation from partially bracketed */
40 /* corpora", F Pereira and Y. Schabes. pp 128-135, 30th ACL, Newark, */
41 /* Delaware 1992. */
42 /* */
/* This should really be done in the log domain.  Addition in the log    */
/* domain can be done with a formula in Huang, Ariki and Jack:           */
/*                                                                       */
/*     log(a+b) = log(1 + e^(log(a)-log(b))) + log(b)                    */
47 /* */
48 /*=======================================================================*/
49 #include <cstdlib>
50 #include "EST_SCFG_Chart.h"
51 #include "EST_simplestats.h"
52 #include "EST_math.h"
53 #include "EST_TVector.h"
54 
55 static const EST_bracketed_string def_val_s;
56 static EST_bracketed_string error_return_s;
59 
60 
61 #if defined(INSTANTIATE_TEMPLATES)
62 #include "../base_class/EST_TVector.cc"
63 
65 #endif
66 
67 void set_corpus(EST_Bcorpus &b, LISP examples)
68 {
69  LISP e;
70  int i;
71 
72  b.resize(siod_llength(examples));
73 
74  for (i=0,e=examples; e != NIL; e=cdr(e),i++)
75  b.a_no_check(i).set_bracketed_string(car(e));
76 }
77 
78 void EST_bracketed_string::init()
79 {
80  bs = NIL;
81  gc_protect(&bs);
82  symbols = 0;
83  valid_spans = 0;
84  p_length = 0;
85 }
86 
87 EST_bracketed_string::EST_bracketed_string()
88 {
89  init();
90 }
91 
92 EST_bracketed_string::EST_bracketed_string(LISP string)
93 {
94  init();
95 
96  set_bracketed_string(string);
97 }
98 
99 EST_bracketed_string::~EST_bracketed_string()
100 {
101  int i;
102  bs=NIL;
103  gc_unprotect(&bs);
104  delete [] symbols;
105  for (i=0; i < p_length; i++)
106  delete [] valid_spans[i];
107  delete [] valid_spans;
108 }
109 
110 void EST_bracketed_string::set_bracketed_string(LISP string)
111 {
112 
113  bs=NIL;
114  delete [] symbols;
115 
116  p_length = find_num_nodes(string);
117  symbols = new LISP[p_length];
118 
119  set_leaf_indices(string,0,symbols);
120 
121  bs = string;
122 
123  int i,j;
124  valid_spans = new int*[length()];
125  for (i=0; i < length(); i++)
126  {
127  valid_spans[i] = new int[length()+1];
128  for (j=i+1; j <= length(); j++)
129  valid_spans[i][j] = 0;
130  }
131 
132  // fill in valid table
133  if (p_length > 0)
134  find_valid(0,bs);
135 
136 }
137 
138 int EST_bracketed_string::find_num_nodes(LISP string)
139 {
140  // This wont could nil as an atom
141  if (string == NIL)
142  return 0;
143  else if (CONSP(string))
144  return find_num_nodes(car(string))+
145  find_num_nodes(cdr(string));
146  else
147  return 1;
148 }
149 
150 int EST_bracketed_string::set_leaf_indices(LISP string,int i,LISP *syms)
151 {
152  if (string == NIL)
153  return i;
154  else if (!CONSP(car(string)))
155  {
156  syms[i] = string;
157  return set_leaf_indices(cdr(string),i+1,syms);
158  }
159  else // car is a tree
160  {
161  return set_leaf_indices(cdr(string),
162  set_leaf_indices(car(string),i,syms),
163  syms);
164  }
165 }
166 
167 void EST_bracketed_string::find_valid(int s,LISP t) const
168 {
169  LISP l;
170  int c;
171 
172  if (consp(t))
173  {
174  for (c=s,l=t; l != NIL; l=cdr(l))
175  {
176  c += num_leafs(car(l));
177  valid_spans[s][c] = 1;
178  }
179  find_valid(s,car(t));
180  find_valid(s+num_leafs(car(t)),cdr(t));
181  }
182 }
183 
184 int EST_bracketed_string::num_leafs(LISP t) const
185 {
186  if (t == NIL)
187  return 0;
188  else if (!consp(t))
189  return 1;
190  else
191  return num_leafs(car(t)) + num_leafs(cdr(t));
192 }
193 
194 EST_SCFG_traintest::EST_SCFG_traintest(void) : EST_SCFG()
195 {
196  inside = 0;
197  outside = 0;
198  n.resize(0);
199  d.resize(0);
200 }
201 
202 EST_SCFG_traintest::~EST_SCFG_traintest(void)
203 {
204 
205 }
206 
208 {
209  set_corpus(corpus,vload(filename,1));
210 }
211 
212 // From the formula in the paper
213 double EST_SCFG_traintest::f_I_cal(int c, int p, int i, int k)
214 {
215  // Find Inside probability
216  double res;
217 
218  if (i == k-1)
219  {
220  res = prob_U(p,terminal(corpus.a_no_check(c).symbol_at(i)));
221 // printf("prob_U p %s (%d) %d m %s (%d) res %g\n",
222 // (const char *)nonterminal(p),p,
223 // i,
224 // (const char *)corpus.a_no_check(c).symbol_at(i),
225 // terminal(corpus.a_no_check(c).symbol_at(i)),
226 // res);
227  }
228  else if (corpus.a_no_check(c).valid(i,k) == TRUE)
229  {
230  int j;
231  double s=0;
232  int q,r;
233 
234  for (q = 0; q < num_nonterminals(); q++)
235  for (r = 0; r < num_nonterminals(); r++)
236  {
237  double pBpqr = prob_B(p,q,r);
238  if (pBpqr > 0)
239  for (j=i+1; j < k; j++)
240  {
241  double in = f_I(c,q,i,j);
242  if (in > 0)
243  s += pBpqr * in * f_I(c,r,j,k);
244  }
245  }
246  res = s;
247  }
248  else
249  res = 0.0;
250 
251  inside[p][i][k] = res;
252 
253 // printf("f_I p %s i %d k %d res %g\n",
254 // (const char *)nonterminal(p),i,k,res);
255 
256  return res;
257 }
258 
259 double EST_SCFG_traintest::f_O_cal(int c, int p, int i, int k)
260 {
261  // Find Outside probability
262  double res;
263 
264  if ((i == 0) && (k == corpus.a_no_check(c).length()))
265  {
266  if (p == distinguished_symbol()) // distinguished non-terminal
267  res = 1.0;
268  else
269  res = 0.0;
270  }
271  else if (corpus.a_no_check(c).valid(i,k) == TRUE)
272  {
273  double s1=0.0;
274  double s2,s3;
275  double pBqrp,pBqpr;
276  int j;
277  int q,r;
278 
279  for (q = 0; q < num_nonterminals(); q++)
280  for (r = 0; r < num_nonterminals(); r++)
281  {
282  pBqrp = prob_B(q,r,p);
283  s2 = s3 = 0.0;
284  if (pBqrp > 0)
285  {
286  for (j=0;j < i; j++)
287  {
288  double out = f_O(c,q,j,k);
289  if (out > 0)
290  s2 += out * f_I(c,r,j,i);
291  }
292  s2 *= pBqrp;
293  }
294  pBqpr = prob_B(q,p,r);
295  if (pBqpr > 0)
296  {
297  for (j=k+1;j <= corpus.a_no_check(c).length(); j++)
298  {
299  double out = f_O(c,q,i,j);
300  if (out > 0)
301  s3 += out * f_I(c,r,k,j);
302  }
303  s3 *= pBqpr;
304  }
305  s1 += s2 + s3;
306  }
307  res = s1;
308  }
309  else // not a valid bracketing
310  res = 0.0;
311 
312  outside[p][i][k] = res;
313 
314  return res;
315 }
316 
317 void EST_SCFG_traintest::reestimate_rule_prob_B(int c, int ri, int p, int q, int r)
318 {
319  // Re-estimate probability for binary rules
320  int i,j,k;
321  double n2=0;
322 
323  double pBpqr = prob_B(p,q,r);
324 
325  if (pBpqr > 0)
326  {
327  for (i=0; i <= corpus.a_no_check(c).length()-2; i++)
328  for (j=i+1; j <= corpus.a_no_check(c).length()-1; j++)
329  {
330  double d1 = f_I(c,q,i,j);
331  if (d1 == 0) continue;
332  for (k=j+1; k <= corpus.a_no_check(c).length(); k++)
333  {
334  double d2 = f_I(c,r,j,k);
335  if (d2 == 0) continue;
336  double d3 = f_O(c,p,i,k);
337  if (d3 == 0) continue;
338  n2 += d1 * d2 * d3;
339  }
340  }
341  n2 *= pBpqr;
342  }
343  // f_P(c) is probably redundant
344  double fp = f_P(c);
345  double n1,d1;
346  n1 = n2 / fp;
347  if (fp == 0) n1=0;
348 
349  d1 = f_P(c,p) / fp;
350  if (fp == 0) d1=0;
351  // printf("n1 %f d1 %f n2 %f fp %f\n",n1,d1,n2,fp);
352  n[ri] += n1;
353  d[ri] += d1;
354 
355 }
356 
357 void EST_SCFG_traintest::reestimate_rule_prob_U(int c,int ri, int p, int m)
358 {
359  // Re-estimate probability for unary rules
360  int i;
361 
362 // printf("reestimate_rule_prob_U: %f p %s m %s\n",
363 // prob_U(ip,im),
364 // (const char *)p,
365 // (const char *)m);
366 
367  double n2=0;
368 
369  for (i=1; i < corpus.a_no_check(c).length(); i++)
370  if (m == terminal(corpus.a_no_check(c).symbol_at(i-1)))
371  n2 += prob_U(p,m) * f_O(c,p,i-1,i);
372 
373  double fP = f_P(c);
374  if (fP != 0)
375  {
376  n[ri] += n2 / fP;
377  d[ri] += f_P(c,p) / fP;
378  }
379 }
380 
381 double EST_SCFG_traintest::f_P(int c)
382 {
383  return f_I(c,distinguished_symbol(),0,corpus.a_no_check(c).length());
384 }
385 
386 double EST_SCFG_traintest::f_P(int c,int p)
387 {
388  int i,j;
389  double db=0;
390 
391  for (i=0; i < corpus.a_no_check(c).length(); i++)
392  for (j=i+1; j <= corpus.a_no_check(c).length(); j++)
393  {
394  double d1 = f_O(c,p,i,j);
395  if (d1 == 0) continue;
396  db += f_I(c,p,i,j)*d1;
397  }
398 
399  return db;
400 }
401 
402 void EST_SCFG_traintest::reestimate_grammar_probs(int passes,
403  int startpass,
404  int checkpoint,
405  int spread,
406  const EST_String &outfile)
407 {
408  // Iterate over the corpus cummulating factors for each rules
409  // This reduces the space requirements and recalculations of
410  // values for each sentences.
411  // Repeat training passes to number specified
412  int pass = 0;
413  double zero=0;
414  double se;
415  int ri,c;
416 
417  n.resize(rules.length());
418  d.resize(rules.length());
419 
420  for (pass = startpass; pass < passes; pass++)
421  {
422  EST_Litem *r;
423  double mC, lPc;
424 
425  d.fill(zero);
426  n.fill(zero);
428 
429  for (mC=0.0,lPc=0.0,c=0; c < corpus.length(); c++)
430  {
431  // For skipping some sentences to speed up convergence
432  if ((spread > 0) && (((c+(pass*spread))%100) >= spread))
433  continue;
434  printf(" %d",c); fflush(stdout);
435  if (corpus.a_no_check(c).length() == 0) continue;
436  init_io_cache(c,num_nonterminals());
437  for (ri=0,r=rules.head(); r != 0; r=r->next(),ri++)
438  {
439  if (rules(r).type() == est_scfg_binary_rule)
440  reestimate_rule_prob_B(c,ri,
441  rules(r).mother(),
442  rules(r).daughter1(),
443  rules(r).daughter2());
444  else
445  reestimate_rule_prob_U(c,
446  ri,
447  rules(r).mother(),
448  rules(r).daughter1());
449  }
450  lPc += safe_log(f_P(c));
451  mC += corpus.a_no_check(c).length();
452  clear_io_cache(c);
453  }
454  printf("\n");
455 
456  for (se=0.0,ri=0,r=rules.head(); r != 0; r=r->next(),ri++)
457  {
458  double n_prob = n[ri]/d[ri];
459  if (d[ri] == 0)
460  n_prob = 0;
461  se += (n_prob-rules(r).prob())*(n_prob-rules(r).prob());
462  rules(r).set_prob(n_prob);
463  }
464  printf("pass %d cross entropy %g RMSE %f %f %d\n",
465  pass,-(lPc/mC),sqrt(se/rules.length()),
466  se,rules.length());
467 
468  if (checkpoint != -1)
469  {
470  if ((pass % checkpoint) == checkpoint-1)
471  {
472  char cp[20];
473  sprintf(cp,".%03d",pass);
474  save(outfile+cp);
475  user_gc(NIL); // just to keep things neat
476  }
477  }
478 
479  }
480 }
481 
483  int startpass,
484  int checkpoint,
485  int spread,
486  const EST_String &outfile)
487 {
488  // Train a Stochastic CFG using the inside outside algorithm
489 
490  reestimate_grammar_probs(passes, startpass, checkpoint,
491  spread, outfile);
492 }
493 
494 void EST_SCFG_traintest::init_io_cache(int c,int nt)
495 {
496  // Build an array to cache the in/out values
497  int i,j,k;
498  int mc = corpus.a_no_check(c).length()+1;
499 
500  inside = new double**[nt];
501  outside = new double**[nt];
502  for (i=0; i < nt; i++)
503  {
504  inside[i] = new double*[mc];
505  outside[i] = new double*[mc];
506  for (j=0; j < mc; j++)
507  {
508  inside[i][j] = new double[mc];
509  outside[i][j] = new double[mc];
510  for (k=0; k < mc; k++)
511  {
512  inside[i][j][k] = -1;
513  outside[i][j][k] = -1;
514  }
515  }
516  }
517 }
518 
519 void EST_SCFG_traintest::clear_io_cache(int c)
520 {
521  int mc = corpus.a_no_check(c).length()+1;
522  int i,j;
523 
524  if (inside == 0)
525  return;
526 
527  for (i=0; i < num_nonterminals(); i++)
528  {
529  for (j=0; j < mc; j++)
530  {
531  delete [] inside[i][j];
532  delete [] outside[i][j];
533  }
534  delete [] inside[i];
535  delete [] outside[i];
536  }
537 
538  delete [] inside;
539  delete [] outside;
540 
541  inside = 0;
542  outside = 0;
543 }
544 
545 double EST_SCFG_traintest::cross_entropy()
546 {
547  double lPc=0,mC=0;
548  int c;
549 
550  for (c=0; c < corpus.length(); c++)
551  {
552  lPc += log(f_P(c));
553  mC += corpus.a_no_check(c).length();
554  }
555 
556  return -(lPc/mC);
557 }
558 
560 {
561  // Test corpus against current grammar.
562  double mC,lPc;
563  int c,i;
564  int failed=0;
565  double fP;
566 
567  // Lets try simply finding the cross entropy
568  n.resize(rules.length());
569  d.resize(rules.length());
570  for (i=0; i < rules.length(); i++)
571  d[i] = n[i] = 0.0;
572 
573  for (mC=0.0,lPc=0.0,c=0; c < corpus.length(); c++)
574  {
575  if (corpus.length() > 50)
576  {
577  printf(" %d",c);
578  fflush(stdout);
579  }
580  init_io_cache(c,num_nonterminals());
581  fP = f_P(c);
582  if (fP == 0)
583  failed++;
584  else
585  {
586  lPc += safe_log(fP);
587  mC += corpus.a_no_check(c).length();
588  }
589  clear_io_cache(c);
590  }
591  if (corpus.length() > 50)
592  printf("\n");
593 
594  cout << "cross entropy " << -(lPc/mC) << " (" << failed << " failed out of " <<
595  corpus.length() << " sentences )" << endl;
596 
597 }
598