41 #include "EST_Pathname.h"
// Constructor taking a probability and two integer symbol indices.
// The parameter list matches the three-argument set_rule() overload in this
// file, which tags a rule as est_scfg_unary_rule.
// NOTE(review): presumably p is the mother symbol and m the single daughter
// (cf. the mother()/daughter1() accessors used in rule_prob_cache) -- confirm
// against the class declaration.
44 EST_SCFG_Rule::EST_SCFG_Rule(
double prob,
int p,
int m)
// Constructor taking a probability and three integer symbol indices.
// The parameter list matches the four-argument set_rule() overload in this
// file, which tags a rule as est_scfg_binary_rule.
// NOTE(review): presumably p is the mother symbol and q, r the two daughters
// (cf. mother()/daughter1()/daughter2() in rule_prob_cache) -- confirm against
// the class declaration.
49 EST_SCFG_Rule::EST_SCFG_Rule(
double prob,
int p,
int q,
int r)
// Three-argument overload: (re)initialise this rule from a probability and two
// integer symbol indices, and mark it as a unary rule.
// NOTE(review): presumably p is the mother and m the daughter index -- confirm
// against the class declaration.
54 void EST_SCFG_Rule::set_rule(
double prob,
int p,
int m)
// Tag the rule as unary; other code in this file branches on
// type() == est_scfg_unary_rule.
59 p_type = est_scfg_unary_rule;
// Four-argument overload: (re)initialise this rule from a probability and
// three integer symbol indices, and mark it as a binary rule.
// NOTE(review): presumably p is the mother and q, r the first and second
// daughter indices -- confirm against the class declaration.
62 void EST_SCFG_Rule::set_rule(
double prob,
int p,
int q,
int r)
// Tag the rule as binary; other code in this file branches on
// type() == est_scfg_binary_rule.
68 p_type = est_scfg_binary_rule;
// Construct a grammar from a LISP value rs.
// NOTE(review): presumably rs is a list of rules, each of the form
// (prob mother daughter) or (prob mother daughter1 daughter2), as parsed by
// the rule-loading code elsewhere in this file -- confirm.
77 EST_SCFG::EST_SCFG(LISP rs)
// Destructor.
84 EST_SCFG::~EST_SCFG(
void)
// Release the p_prob_B / p_prob_U arrays, which delete_rule_prob_cache()
// frees with delete [].
87 delete_rule_prob_cache();
96 for (r=rs; r != NIL; r=cdr(r))
98 LISP p = car(cdr(car(r)));
99 if (!strlist_member(nt,get_c_string(p)))
100 nt.
append(get_c_string(p));
101 if (siod_llength(car(r)) == 3)
103 LISP d = car(cdr(cdr(car(r))));
104 if (!strlist_member(t,get_c_string(d)))
105 t.
append(get_c_string(d));
109 LISP d1 = car(cdr(cdr(car(r))));
110 LISP d2 = car(cdr(cdr(cdr(car(r)))));
111 if (!strlist_member(nt,get_c_string(d1)))
112 nt.
append(get_c_string(d1));
113 if (!strlist_member(nt,get_c_string(d2)))
114 nt.
append(get_c_string(d2));
127 delete_rule_prob_cache();
130 nonterminals.
init(nt_list);
131 terminals.
init(term_list);
133 if (!consp(car(cdr(car(lrules)))))
134 p_distinguished_symbol =
137 cerr <<
"SCFG: no distinguished symbol" << endl;
139 for (r=lrules; r != NIL; r=cdr(r))
141 if ((siod_llength(car(r)) < 3) ||
142 (siod_llength(car(r)) > 4) ||
143 (!numberp(car(car(r)))))
144 cerr <<
"SCFG rule is malformed" << endl;
149 if (siod_llength(car(r)) == 3)
151 int m =
nonterminal(get_c_string(car(cdr(car(r)))));
152 int d =
terminal(get_c_string(car(cdr(cdr(car(r))))));
153 rule.set_rule(get_c_float(car(car(r))),m,d);
157 int p =
nonterminal(get_c_string(car(cdr(car(r)))));
158 int d1=
nonterminal(get_c_string(car(cdr(cdr(car(r))))));
159 int d2 =
nonterminal(get_c_string(car(cdr(cdr(cdr(car(r)))))));
160 rule.set_rule(get_c_float(car(car(r))),p,d1,d2);
175 for (r=NIL,p=
rules.head(); p != 0; p=p->next())
177 if (
rules(p).type() == est_scfg_unary_rule)
178 r = cons(cons(flocons(
rules(p).prob()),
182 else if (
rules(p).type() == est_scfg_binary_rule)
183 r = cons(cons(flocons(
rules(p).prob()),
197 rs = vload(filename,1);
214 if ((fd=fopen(outfile,
"w")) == NULL)
216 cerr <<
"scfg_train: failed to open file \"" << outfile <<
217 "\" for writing" << endl;
218 return misc_write_error;
223 pprint_to_fd(fd,car(r));
// Copy every rule's probability into array tables indexed by symbol number:
// binary rules go to p_prob_B[mother][daughter1][daughter2], unary rules to
// p_prob_U[mother][daughter1].
// NOTE(review): the tables must already be allocated and sized before this
// loop runs; presumably unfilled entries are given a default value first --
// confirm.
232 void EST_SCFG::rule_prob_cache()
// Walk the whole rule list once.
260 for (pp=
rules.head(); pp != 0; pp = pp->next())
262 if (
rules(pp).type() == est_scfg_binary_rule)
// Binary rule: index by mother and both daughters.
264 int p =
rules(pp).mother();
265 int q =
rules(pp).daughter1();
266 int r =
rules(pp).daughter2();
267 p_prob_B[p][q][r] =
rules(pp).prob();
269 else if (
rules(pp).type() == est_scfg_unary_rule)
// Unary rule: index by mother and its single daughter (daughter1()).
271 int p =
rules(pp).mother();
272 int m =
rules(pp).daughter1();
273 p_prob_U[p][m] =
rules(pp).prob();
// Free the arrays behind p_prob_B and p_prob_U.
// NOTE(review): the statements below presumably sit inside loops over i (and
// j for the innermost arrays), and the top-level arrays are presumably freed
// afterwards -- confirm.
278 void EST_SCFG::delete_rule_prob_cache()
// Innermost p_prob_B arrays go first, before the array that points to them.
288 delete [] p_prob_B[i][j];
289 delete [] p_prob_B[i];
290 delete [] p_prob_U[i];
302 return s <<
"<<EST_SCFG_Rule>>";
306 #if defined(INSTANTIATE_TEMPLATES)
307 #include "../base_class/EST_TList.cc"
308 #include "../base_class/EST_TSortable.cc"