Edinburgh Speech Tools  2.4-release
EST_ols.cc
/*************************************************************************/
/*                                                                       */
/*                Centre for Speech Technology Research                  */
/*                     University of Edinburgh, UK                       */
/*                       Copyright (c) 1998                              */
/*                        All Rights Reserved.                           */
/*                                                                       */
/*  Permission is hereby granted, free of charge, to use and distribute  */
/*  this software and its documentation without restriction, including   */
/*  without limitation the rights to use, copy, modify, merge, publish,  */
/*  distribute, sublicense, and/or sell copies of this work, and to      */
/*  permit persons to whom this work is furnished to do so, subject to   */
/*  the following conditions:                                            */
/*   1. The code must retain the above copyright notice, this list of    */
/*      conditions and the following disclaimer.                         */
/*   2. Any modifications must be clearly marked as such.                */
/*   3. Original authors' names are not deleted.                         */
/*   4. The authors' names are not used to endorse or promote products   */
/*      derived from this software without specific prior written        */
/*      permission.                                                      */
/*                                                                       */
/*  THE UNIVERSITY OF EDINBURGH AND THE CONTRIBUTORS TO THIS WORK        */
/*  DISCLAIM ALL WARRANTIES WITH REGARD TO THIS SOFTWARE, INCLUDING      */
/*  ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS, IN NO EVENT   */
/*  SHALL THE UNIVERSITY OF EDINBURGH NOR THE CONTRIBUTORS BE LIABLE     */
/*  FOR ANY SPECIAL, INDIRECT OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES    */
/*  WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN   */
/*  AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION,          */
/*  ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF       */
/*  THIS SOFTWARE.                                                       */
/*                                                                       */
/*************************************************************************/
/*             Author :  Alan W Black (and lots of books)                */
/*             Date   :  January 1998                                    */
/*-----------------------------------------------------------------------*/
/*             Ordinary Least Squares/Linear regression                  */
/*                                                                       */
/*=======================================================================*/
#include <cmath>
#include "EST_multistats.h"
#include "EST_simplestats.h"

static void ols_load_selected_feats(const EST_FMatrix &X,
                                    const EST_IVector &included,
                                    EST_FMatrix &Xl);
static int ols_stepwise_find_best(const EST_FMatrix &X,
                                  const EST_FMatrix &Y,
                                  EST_IVector &included,
                                  EST_FMatrix &coeffs,
                                  float &bscore,
                                  int &best_feat,
                                  const EST_FMatrix &Xtest,
                                  const EST_FMatrix &Ytest,
                                  const EST_StrList &feat_names);

int ols(const EST_FMatrix &X, const EST_FMatrix &Y, EST_FMatrix &coeffs)
{
    // Ordinary least squares: X contains the samples, with 1 (for the
    // intercept) in column 0; Y contains the corresponding target values
    // in a single column.
    EST_FMatrix Xplus;

    if (!pseudo_inverse(X,Xplus))
        return FALSE;

    multiply(Xplus,Y,coeffs);

    return TRUE;
}
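
// Usage sketch (illustrative only): build a design matrix with a leading
// column of ones for the intercept, then solve for the coefficients.  The
// sizes and the feat1/feat2/target arrays here are hypothetical.
//
//     EST_FMatrix X(ntrain,3), Y(ntrain,1), coeffs;
//     for (int s=0; s<ntrain; s++)
//     {
//         X.a(s,0) = 1.0;        // intercept term
//         X.a(s,1) = feat1[s];   // first feature
//         X.a(s,2) = feat2[s];   // second feature
//         Y.a(s,0) = target[s];  // value to predict
//     }
//     if (ols(X,Y,coeffs))       // coeffs is X.num_columns() x 1
//         printf("intercept %f\n", coeffs(0,0));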

int robust_ols(const EST_FMatrix &X,
               const EST_FMatrix &Y,
               EST_FMatrix &coeffs)
{
    // Convenience version: include every column of X.
    EST_IVector included;
    int i;

    included.resize(X.num_columns());
    for (i=0; i<included.length(); i++)
        included.a_no_check(i) = TRUE;

    return robust_ols(X,Y,included,coeffs);
}

int robust_ols(const EST_FMatrix &X,
               const EST_FMatrix &Y,
               EST_IVector &included,
               EST_FMatrix &coeffs)
{
    // As with ols(), but if the pseudo-inverse fails, remove the offending
    // feature and try again until it works.  This can be costly, but it
    // saves *you* from finding the singularity.
    // The output is expanded with weights of 0 for the omitted features.
    EST_FMatrix Xl;
    EST_FMatrix coeffsl;
    EST_FMatrix Xplus;
    int i,j,singularity=-1;

    if (X.num_rows() <= X.num_columns())
    {
        cerr << "OLS: fewer rows than columns, so cannot find solution."
             << endl;
        return FALSE;
    }
    if (X.num_columns() != included.length())
    {
        cerr << "OLS: `included' list wrong size: internal error."
             << endl;
        return FALSE;
    }

    while (TRUE)
    {
        ols_load_selected_feats(X,included,Xl);
        if (pseudo_inverse(Xl,Xplus,singularity))
        {
            multiply(Xplus,Y,coeffsl);
            break;
        }
        else
        {   // found a singularity, so try again without that column:
            // remap the singularity position in Xl back to a column of X
            // (the walk below assumes column 0, the intercept, is included)
            int s;
            for (s=i=0; i<singularity; i++)
            {
                s++;
                while ((included(s) == FALSE) ||
                       (included(s) == OLS_IGNORE))
                    s++;
            }
            if (included(s) == FALSE)
            {   // oops
                cerr << "OLS: found singularity twice, shouldn't happen"
                     << endl;
                return FALSE;
            }
            else
            {
                cerr << "OLS: omitting singularity in column " << s << endl;
                included[s] = FALSE;
            }
        }
    }

    // Map the coefficients back, making the coefficient 0 for singular
    // (omitted) columns
    coeffs.resize(X.num_columns(),1);
    for (j=i=0; i<X.num_columns(); i++)
        if (included(i))
        {
            coeffs.a_no_check(i,0) = coeffsl(j,0);
            j++;
        }
        else
            coeffs.a_no_check(i,0) = 0.0;

    return TRUE;
}
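
// Usage sketch (illustrative only): the `included' vector lets the caller
// exclude columns up front; robust_ols() will additionally zero out any
// column it has to drop because of a singularity.  Column index 3 below is
// just an example.
//
//     EST_IVector included;
//     included.resize(X.num_columns());
//     for (int c=0; c<included.length(); c++)
//         included.a_no_check(c) = TRUE;
//     included[3] = FALSE;            // exclude column 3 by hand
//     EST_FMatrix coeffs;
//     if (robust_ols(X,Y,included,coeffs))
//         ;                           // coeffs(3,0) is 0.0, others as fitted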

static void ols_load_selected_feats(const EST_FMatrix &X,
                                    const EST_IVector &included,
                                    EST_FMatrix &Xl)
{
    // Copy into Xl only those columns of X that are marked TRUE in included
    int i,j,k,width;

    for (width=i=0; i<included.length(); i++)
        if (included(i) == TRUE)
            width++;

    Xl.resize(X.num_rows(),width);

    for (i=0; i<X.num_rows(); i++)
        for (k=j=0; j < X.num_columns(); j++)
            if (included(j) == TRUE)
            {
                Xl.a_no_check(i,k) = X.a_no_check(i,j);
                k++;
            }
}

int ols_apply(const EST_FMatrix &samples,
              const EST_FMatrix &coeffs,
              EST_FMatrix &res)
{
    // Apply the coefficients to the samples, putting the predictions in res.

    if (samples.num_columns() != coeffs.num_rows())
        return FALSE;

    multiply(samples,coeffs,res);

    return TRUE;
}
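
// Usage sketch (illustrative only): new samples must be laid out exactly
// like the training rows, including the leading column of ones, so that
// samples.num_columns() == coeffs.num_rows().  f1 and f2 are hypothetical
// feature values.
//
//     EST_FMatrix newX(1,3), pred;
//     newX.a(0,0) = 1.0;              // intercept
//     newX.a(0,1) = f1;
//     newX.a(0,2) = f2;
//     if (ols_apply(newX,coeffs,pred))
//         printf("prediction %f\n", pred(0,0));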

int stepwise_ols(const EST_FMatrix &X,
                 const EST_FMatrix &Y,
                 const EST_StrList &feat_names,
                 float limit,
                 EST_FMatrix &coeffs,
                 const EST_FMatrix &Xtest,
                 const EST_FMatrix &Ytest,
                 EST_IVector &included)
{
    // Find the features that contribute to the correlation using a
    // greedy algorithm

    EST_FMatrix coeffsl;
    float best_score=0.0,bscore;
    int i,best_feat;
    int nf=1;   // for nice printing of progress

    for (i=1; i < X.num_columns(); i++)
    {
        if (!ols_stepwise_find_best(X,Y,included,coeffsl,
                                    bscore,best_feat,Xtest,Ytest,
                                    feat_names))
        {
            cerr << "OLS: stepwise failed" << endl;
            return FALSE;
        }
        if ((bscore - (bscore * (limit/100))) <= best_score)
            break;
        else
        {
            best_score = bscore;
            coeffs = coeffsl;
            included[best_feat] = TRUE;
            printf("FEATURE %d %s: %2.4f\n",
                   nf,
                   (const char *)feat_names.nth(best_feat),
                   best_score);
            fflush(stdout);
            nf++;
        }
    }

    return TRUE;
}
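
// Usage sketch (illustrative only): the caller provides held-out test data
// and an `included' vector; a common setup is to start with everything FALSE
// except column 0 (the intercept).  `limit' is a percentage: each new feature
// must improve the test-set correlation by roughly that relative amount.
//
//     EST_IVector included;
//     included.resize(X.num_columns());
//     for (int c=0; c<included.length(); c++)
//         included.a_no_check(c) = FALSE;
//     included[0] = TRUE;             // always keep the intercept
//     EST_FMatrix coeffs;
//     if (stepwise_ols(X,Y,feat_names,1.0,coeffs,Xtest,Ytest,included))
//         ;                           // coeffs covers every column of X,
//                                     // unselected features get weight 0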

static int ols_stepwise_find_best(const EST_FMatrix &X,
                                  const EST_FMatrix &Y,
                                  EST_IVector &included,
                                  EST_FMatrix &coeffs,
                                  float &bscore,
                                  int &best_feat,
                                  const EST_FMatrix &Xtest,
                                  const EST_FMatrix &Ytest,
                                  const EST_StrList &feat_names)
{
    // Try each currently excluded feature in turn and report the one whose
    // addition gives the best correlation on the test set
    EST_FMatrix coeffsl;
    bscore = 0;
    best_feat = -1;
    int i;

    for (i=0; i < included.length(); i++)
    {
        if (included.a_no_check(i) == FALSE)
        {
            float cor, rmse;
            EST_FMatrix pred;
            included.a_no_check(i) = TRUE;
            if (!robust_ols(X,Y,included,coeffsl))
                return FALSE;   // failed for some reason
            ols_apply(Xtest,coeffsl,pred);
            ols_test(Ytest,pred,cor,rmse);
            printf("tested %d %s %f best %f\n",
                   i,(const char *)feat_names.nth(i),cor,bscore);
            if (fabs(cor) > bscore)
            {
                bscore = fabs(cor);
                coeffs = coeffsl;
                best_feat = i;
            }
            included.a_no_check(i) = FALSE;
        }
    }

    return TRUE;
}

int ols_test(const EST_FMatrix &real,
             const EST_FMatrix &predicted,
             float &correlation,
             float &rmse)
{
    // Others probably want this function too
    // Return the correlation and RMSE between column 0 of real and predicted
    int i;
    float p,r;
    EST_SuffStats x,y,xx,yy,xy,se,e;
    double error;
    double v1,v2,v3;

    if (real.num_rows() != predicted.num_rows())
        return FALSE;   // can't do this

    for (i=0; i < real.num_rows(); i++)
    {
        r = real(i,0);
        p = predicted(i,0);
        x += p;
        y += r;
        error = p-r;
        se += error*error;
        e += fabs(error);
        xx += p*p;
        yy += r*r;
        xy += p*r;
    }

    rmse = sqrt(se.mean());

    v1 = xx.mean()-(x.mean()*x.mean());
    v2 = yy.mean()-(y.mean()*y.mean());

    v3 = v1*v2;

    if (v3 <= 0)
    {   // happens when there's very little variation in x
        correlation = 0;
        rmse = se.mean();
        return FALSE;
    }
    // Pearson's product moment correlation coefficient
    correlation = (xy.mean() - (x.mean()*y.mean())) / sqrt(v3);

    // I hate to have to do this but it is necessary.
    // When the variation of X is very small, v1*v2 approaches 0 (the
    // negative and equal-to-zero cases are caught above), but that check
    // may not be enough when v1 or v2 are very small but positive.
    // So I catch it here.  If I knew more math I'd be able to describe
    // this better, but the code would remain the same.
    if ((correlation <= 1.0) && (correlation >= -1.0))
        return TRUE;
    else
    {
        correlation = 0;
        return FALSE;
    }
}
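
// Note on the statistics above: x, y, xx, yy and xy accumulate running means
// over the test rows, so the correlation computed is
//     r = (mean(p*r) - mean(p)*mean(r)) / sqrt(var(p) * var(r))
// with var(p) = mean(p*p) - mean(p)^2, i.e. Pearson's correlation between
// the predicted and real values, while rmse = sqrt(mean((p-r)^2)).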