X-Boost  2.3.8
DecisionTreeLearner.h
Go to the documentation of this file.
1 /* XBoost: Ada-Boost and Friends on Haar/ICF/HOG Features, Library and ToolBox
2  *
3  * Copyright (c) 2008-2014 Paolo Medici <medici@ce.unipr.it>
4  *
5  * This library is free software; you can redistribute it and/or
6  * modify it under the terms of the GNU Lesser General Public
7  * License as published by the Free Software Foundation; either
8  * version 2 of the License, or (at your option) any later version.
9  *
10  * This library is distributed in the hope that it will be useful,
11  * but WITHOUT ANY WARRANTY; without even the implied warranty of
12  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
13  * Lesser General Public License for more details.
14  *
15  * You should have received a copy of the GNU Lesser General Public
16  * License along with this library; if not, write to the
17  * Free Software Foundation, Inc., 59 Temple Place - Suite 330,
18  * Boston, MA 02111-1307, USA.
19  */
20 
21 #ifndef _DECISION_TREE_LEARNER_H
22 #define _DECISION_TREE_LEARNER_H
23 
27 #include "Pattern/Pattern.h"
28 #include "DataSet.h"
29 #include "InformationMetrics.h"
30 #include "Classifier/DecisionTree.h"
31 #include "Test.h"
32 
/// One training sample projected onto a single feature: the scalar
/// feature response paired with the sample's class label.
/// Arrays of DTResponse are sorted by response (see operator< below)
/// so candidate split thresholds can be scanned in one pass.
struct DTResponse {
 int response; ///< scalar feature response for this sample
 int category; ///< class label: +1 positive, -1 negative
};
37 
38 bool operator < (const DTResponse &a, const DTResponse & b)
39 {
40  return a.response < b.response;
41 }
42 
45  int A, B;
46 };
47 
50  double A,B;
51 };
52 
/// Find the best split threshold over a response table that has been
/// sorted ascending by response (build_response guarantees this).
/// T selects the split-quality policy (e.g. GiniIndex).
/// @param data sorted (response, category) records
/// @param n    number of records
/// @param A    total number of positive (+1) samples in @a data
/// @param B    total number of negative (-1) samples in @a data
/// @return (threshold, index): a sample with response > threshold goes
///         RIGHT; a lower index means a better split.
template<class T>
std::pair<int, float> FindThreshold(const DTResponse *data, int n, int A, int B);
55 
// Sketch of the classic tree-growing recursion (translated from Italian):
// void grow(X, y)
// {
//   if all values of Y are equal, return Y
//   if all values of X are equal (i.e. X.row() < 2), return a LeafNode by majority vote on Y
//   otherwise find the attribute with the highest Info Gain (IG)
//
//   pick m features at random
//   for j = 1 to m:
//     IG = f(Y|Xj)
// }
69 
70 // TODO: extend to AdaBoost weighted samples
71 template<>
72 std::pair<int, float> FindThreshold<GiniIndex>(const DTResponse *data, int n, int A, int B)
73 {
74 int a1,b1; // left count
75 int a2,b2; // right count
76 std::pair<int, float> out;
77 // trovare threshold e parity che massimizzino l'indice di Gini
78 // max(a1*b1/(a1+b1) + a2*b2/(a2*b2));
79 
80 // value > th ? RIGHT : LEFT
81 
82 // con la soglia minima, tutti gli elementi sono nella RIGHT.
83 a2=A;
84 b2=B;
85 // e nessun elemento nella LEFT
86 a1=0;
87 b1=0;
88 /*
89 out.second = (float)(a1*b1)/(float)(a1+b1);
90 out.first = data[0].response;*/
91 
92 out.second = 2.0 * (float)(a2*b2)/(float)(a2+b2);
93 
94 
95 for(int i=0;i<n-1;++i)
96  {
97 // std::cout << data[i].response << ',' << data[i].category << ' ';
98 
99  // look for transition
100  if(i!=0 && data[i-1].response != data[i].response)
101  {
102  float index = (float)(a1*b1)/(float)(a1+b1) + (float)(a2*b2)/(float)(a2+b2);
103 
104  // std::cout << i << ' ' << index << ';';
105  // std::cout << index << ' ' << a1 << ' ' << b1 << ' ' << a2 << ' ' << b2 << '\n';
106 
107  if(index <= out.second) // keep last
108  {
109 // std::cout << index << ' ' << a1 << ' ' << b1 << ' ' << a2 << ' ' << b2 << '\n';
110  out.second = index;
111  out.first = ( data[i-1].response + data[i].response ) / 2;
112  }
113  }
114 
115  if(data[i].category==-1) // B
116  { b1++,b2--; } // TODO: extend to AdaBoost case
117  else
118  { a1++,a2--; } // TODO: extend to AdaBoost case
119 
120 
121  }
122 // exit(1);
123 return out;
124 }
125 
/// Debug helper: recompute the split impurity of (h, th) directly on the
/// pattern set, to cross-check FindThreshold's single-pass result.
/// @param patterns training samples (need .data pointer and .category)
/// @param stride   row stride forwarded to the feature extractor
/// @param h        feature extractor, called as h(ptr, stride)
/// @param th       threshold under test; response > th routes RIGHT
/// @return sum of per-side a*b/(a+b) terms; an empty side contributes 0
///         (FIX: previously 0/0 produced NaN when a side was empty).
template<class FeatureExtractor, class PatternType>
float TestThreshold(const std::vector<PatternType> & patterns, long stride, const FeatureExtractor & h, int th)
{
 int a1 = 0, b1 = 0;   // class counts LEFT of the threshold  (resp <= th)
 int a2 = 0, b2 = 0;   // class counts RIGHT of the threshold (resp >  th)

 // FIX: size_t index avoids the signed/unsigned comparison of the
 // original int loop counter against patterns.size().
 for(std::size_t i = 0; i < patterns.size(); ++i)
 {
  int resp = h(patterns[i].data + stride + 1, stride);
  if(resp > th)
  {
   if(patterns[i].category == 1) a2++;
   else                          b2++;
  }
  else
  {
   if(patterns[i].category == 1) a1++;
   else                          b1++;
  }
 }

 std::cout << "Expected: " << a1 << ',' << b1 << ' ' << a2 << ',' << b2 << std::endl;

 float leftTerm  = (a1 + b1) ? (float)(a1 * b1) / (float)(a1 + b1) : 0.0f;
 float rightTerm = (a2 + b2) ? (float)(a2 * b2) / (float)(a2 + b2) : 0.0f;
 return leftTerm + rightTerm;
}
154 
155 template<class FeatureExtractor, class PatternType, class Param>
156 void build_response(DTResponse *resp, const PatternType * patterns, Param param, int n, const FeatureExtractor & op)
157 {
158  for(int i =0;i<n;++i)
159  {
160  resp[i].response = op.response(getData1(patterns[i], param), getData2(patterns[i], param) );
161  resp[i].category = patterns[i].category;
162  }
163 
164  // sorting
165  std::sort(&resp[0], &resp[n]);
166 }
167 
168 // static int min_tree_size = 4;
169 
171 template<class FeatureExtractor, class FeatureGenerator, class PatternType, class Param1>
172 bool BuildDecisionTree(DecisionTree<FeatureExtractor> &root, const std::vector<PatternType> & patterns, Param1 param, FeatureGenerator & f, int max_depth)
173 {
174  std::pair<int, float> out, best;
175  int n = patterns.size();
176  int A,B;
177 // float wA, wB;
178  A=B=0;
179 // wA = wB = 0.0f;
180 
181  f.Reset();
182 
183  // check
184  for(int i =0;i<n;++i)
185  {
186  if(patterns[i].category == 1)
187  {
188  A++;
189 // wA += patterns[i].GetWeightedCategory();
190  }
191  else
192  if(patterns[i].category == -1)
193  {
194  B++;
195 // wB += patterns[i].GetWeightedCategory();
196  }
197  else
198  throw std::runtime_error("wrong category label");
199  }
200 
201 // std::cout << "A:" << A << " B:" << B << std::endl;
202  float initial_index;
203 
204  if(max_depth == 0 || (A==0) || (B==0) /* || (A+B < min_tree_size) */ )
205  {
206  if(A==0 && B==0)
207  throw std::runtime_error("wrong splitting");
208 
209  std::cout << 'x';
210  std::cout.flush();
211 
212 // std::cout << "LEAF" << std::endl;
213  // root->category = (A>B) ? 1 : -1;
214  root.category = (float)(A-B)/(float)(A+B);
215 // root.category = (wA+wB)/(wA-wB);
216  return true;
217  }
218  else
219  {
220 
221  DTResponse *resp = new DTResponse[ n ] ;
222  typename FeatureGenerator::FeatureType h, bestH;
223  best.second = -1.0f;
224 
225  initial_index = (float)(A*B)/(float)(A+B);
226 
227  int iter = 0;
228  // std::cout << "Evaluate " << src.size() << " features" << std::endl;
229  while(f.Next(h))
230  {
231 
232  build_response(resp, &patterns[0], param, n, h);
233  out = FindThreshold<GiniIndex>(resp, n, A, B);
234 /*
235  float tst = TestThreshold<typename FeatureGenerator::FeatureType>(patterns, stride, h, out.first);
236  if(tst != out.second)
237  throw std::runtime_error("test failed");
238  */
239  if(iter == 0 || out.second < best.second)
240  {
241  best = out;
242  bestH = h;
243  }
244  ++iter;
245  }
246 
247  if(iter == 0)
248  throw std::runtime_error("Feature Generator missconfigured");
249 
250  if(best.second > initial_index)
251  {
252  std::cout << '?';
253  std::cout.flush();
254 
255 // std::cout << "WEAKLEAF" << std::endl;
256  // root->category = (A>B) ? 1 : -1;
257  root.category = (float)(A-B)/(float)(A+B);
258  // root.category = (wA+wB)/(wA-wB);
259  return true;
260  }
261 
262 // std::cout << "from: " << initial_index << '\n';
263 // std::cout << "index:" << best.second << " th:" << best.first << std::endl;
264 // float tst = TestThreshold<typename FeatureGenerator::FeatureType>(patterns, stride, bestH, best.first);
265 // std::cout << "\t" << tst << std::endl;
266 //
267  std::vector<PatternType> left, right;
268 
269  for(int i =0;i<n;++i)
270  {
271  // int resp = bestH(patterns[i].data+stride+1, stride);
272  int resp = bestH.response(getData1(patterns[i], param), getData2(patterns[i], param) );
273 
274  if(resp > best.first)
275  right.push_back(patterns[i]);
276  else
277  left.push_back(patterns[i]);
278 
279  }
280 
281  delete [] resp;
282 
283  root.classifier = bestH;
284 
285  root.th = best.first;
286 // std::cout << "{\n";
288  root.right = new DecisionTree<FeatureExtractor>();
289  BuildDecisionTree<FeatureExtractor>(*root.left, left, param, f, max_depth-1);
290 // std::cout << ",\n";
291  BuildDecisionTree<FeatureExtractor>(*root.right, right, param, f, max_depth-1);
292 // std::cout << "}\n";
293 
294  std::cout << '.';
295  std::cout.flush();
296 
297  return true;
298  }
299 }
300 
302 template<class Aggregator>
303 struct DecisionTreeLearner: public DataSetHandle<Aggregator > {
304 public:
305 
308 
309  DECLARE_AGGREGATOR
310 
311 public:
312 
314  int depth;
315 
316 public:
317 
318  DecisionTreeLearner() : depth(10) { }
319 
322  {
323  static_cast< DataSetHandleType &>(*this) = src;
324  }
325 
327  template<class FeatureExtractor, class FeatureGenerator>
329  {
330  return ::BuildDecisionTree<FeatureExtractor>(root, DataSetHandle<Aggregator>::templates, this->GetParams(), f, depth);
331  }
332 
334  template<class FeatureType>
335  void Test(const DecisionTree<FeatureType> & root)
336  {
337  ::Test(root, static_cast<const DataSetHandle< Aggregator > &>(*this) );
338  }
339 };
340 
#if 0
// Disabled draft of a post-training pruning pass.
// NOTE(review): as written it would not compile — 'left'/'right' are
// not declared in this scope (presumably root->left / root->right were
// intended), the final loop computes 'ret' without using it, and the
// non-leaf path falls off the end without returning a value.
template<class DecisionTreeType, class DataSetType>
bool Prune(DecisionTreeType * root, const DataSetType & data, float threshold)
{
 if(left!=0)
 {
 Prune(left, data, threshold);
 Prune(right, data, threshold);
 }
 else
 {
 // is a leaf
 return true;
 }

 for(int i = 0;i<data.templates.size();i++)
 {
 float ret = (*root)(data.templates[i].data+data.width+1+1,data.width+1);
 }
}
#endif
362 
363 
364 #endif
Feature FeatureType
The feature type generated by this generator.
Definition: FeatureGenerator.h:41
DataType th
Classifier Threshold.
Definition: DecisionTree.h:41
int depth
max depth
Definition: DecisionTreeLearner.h:314
DecisionTree< T, DataType > * left
sub-nodes (if null is a leaf)
Definition: DecisionTree.h:44
virtual void Reset()=0
reset any internal counters
float category
if there are no subnodes, this value in -1 .. 1 returns the class and its probability
Definition: DecisionTree.h:47
void SetTrainingSet(const DataSetHandleType &src)
import a training set
Definition: DecisionTreeLearner.h:321
Definition: FeatureGenerator.h:36
A Decision Tree builder.
Definition: DecisionTreeLearner.h:303
bool BuildDecisionTree(DecisionTree< FeatureExtractor > &root, const std::vector< PatternType > &patterns, Param1 param, FeatureGenerator &f, int max_depth)
Definition: DecisionTreeLearner.h:172
Definition: DataSet.h:33
Definition: DataSet.h:50
discrete policy
Definition: DecisionTreeLearner.h:44
void Test(const DecisionTree< FeatureType > &root)
Test tree performance on the training set.
Definition: DecisionTreeLearner.h:335
Definition: DecisionTreeLearner.h:33
a FeatureExtractor returns a scalar number with no direct relationship to the classification ...
Definition: Types.h:35
virtual bool Next(Feature &out)=0
return the next feature, or return false
bool BuildDecisionTree(DecisionTree< FeatureExtractor > &root, FeatureGenerator &f) const
Build a tree.
Definition: DecisionTreeLearner.h:328
implement the generic pattern object
T classifier
a feature extractor
Definition: DecisionTree.h:38
some function to test the classifier on the set
some metrics used when train decision trees
declare a DataSet
Definition: DecisionTree.h:34
weighted policy
Definition: DecisionTreeLearner.h:49