21 #ifndef _DECISION_TREE_LEARNER_H
22 #define _DECISION_TREE_LEARNER_H
30 #include "Classifier/DecisionTree.h"
40 return a.response < b.response;
54 std::pair<int, float> FindThreshold(
const DTResponse *data,
int n,
int A,
int B);
72 std::pair<int, float> FindThreshold<GiniIndex>(
const DTResponse *data,
int n,
int A,
int B)
76 std::pair<int, float> out;
92 out.second = 2.0 * (float)(a2*b2)/(float)(a2+b2);
95 for(
int i=0;i<n-1;++i)
100 if(i!=0 && data[i-1].response != data[i].response)
102 float index = (float)(a1*b1)/(float)(a1+b1) + (float)(a2*b2)/(float)(a2+b2);
107 if(index <= out.second)
111 out.first = ( data[i-1].response + data[i].response ) / 2;
115 if(data[i].category==-1)
126 template<
class FeatureExtractor,
class PatternType>
127 float TestThreshold(
const std::vector<PatternType> & patterns,
long stride,
const FeatureExtractor & h,
int th)
131 for(
int i = 0;i<patterns.size(); ++i)
133 int resp = h(patterns[i].data+stride+1, stride);
136 if(patterns[i].category == 1)
143 if(patterns[i].category == 1)
150 std::cout <<
"Expected: " << a1 <<
',' << b1 <<
' ' << a2 <<
',' << b2 << std::endl;
152 return (
float)(a1*b1)/(
float)(a1+b1) + (
float)(a2*b2)/(
float)(a2+b2);
155 template<
class FeatureExtractor,
class PatternType,
class Param>
158 for(
int i =0;i<n;++i)
160 resp[i].response = op.response(getData1(patterns[i], param), getData2(patterns[i], param) );
161 resp[i].category = patterns[i].category;
165 std::sort(&resp[0], &resp[n]);
171 template<
class FeatureExtractor,
class FeatureGenerator,
class PatternType,
class Param1>
174 std::pair<int, float> out, best;
175 int n = patterns.size();
184 for(
int i =0;i<n;++i)
186 if(patterns[i].category == 1)
192 if(patterns[i].category == -1)
198 throw std::runtime_error(
"wrong category label");
204 if(max_depth == 0 || (A==0) || (B==0) )
207 throw std::runtime_error(
"wrong splitting");
214 root.
category = (float)(A-B)/(float)(A+B);
225 initial_index = (float)(A*B)/(float)(A+B);
232 build_response(resp, &patterns[0], param, n, h);
233 out = FindThreshold<GiniIndex>(resp, n, A, B);
239 if(iter == 0 || out.second < best.second)
248 throw std::runtime_error(
"Feature Generator missconfigured");
250 if(best.second > initial_index)
257 root.
category = (float)(A-B)/(float)(A+B);
267 std::vector<PatternType> left, right;
269 for(
int i =0;i<n;++i)
272 int resp = bestH.response(getData1(patterns[i], param), getData2(patterns[i], param) );
274 if(resp > best.first)
275 right.push_back(patterns[i]);
277 left.push_back(patterns[i]);
285 root.
th = best.first;
289 BuildDecisionTree<FeatureExtractor>(*root.
left, left, param, f, max_depth-1);
291 BuildDecisionTree<FeatureExtractor>(*root.right, right, param, f, max_depth-1);
302 template<
class Aggregator>
327 template<
class FeatureExtractor,
class FeatureGenerator>
334 template<
class FeatureType>
342 template<
class DecisionTreeType,
class DataSetType>
343 bool Prune(DecisionTreeType * root,
const DataSetType & data,
float threshold)
347 Prune(left, data, threshold);
348 Prune(right, data, threshold);
356 for(
int i = 0;i<data.templates.size();i++)
358 float ret = (*root)(data.templates[i].data+data.width+1+1,data.width+1);
Feature FeatureType
The feature type generated by this generator.
Definition: FeatureGenerator.h:41
DataType th
Classifier Threshold.
Definition: DecisionTree.h:41
int depth
max depth
Definition: DecisionTreeLearner.h:314
DecisionTree< T, DataType > * left
sub-nodes (if null is a leaf)
Definition: DecisionTree.h:44
virtual void Reset()=0
reset any internal counters
float category
if there are no subnodes, this value (in the range -1 .. 1) gives the class and its probability
Definition: DecisionTree.h:47
void SetTrainingSet(const DataSetHandleType &src)
import a training set
Definition: DecisionTreeLearner.h:321
Definition: FeatureGenerator.h:36
A Decision Tree builder.
Definition: DecisionTreeLearner.h:303
bool BuildDecisionTree(DecisionTree< FeatureExtractor > &root, const std::vector< PatternType > &patterns, Param1 param, FeatureGenerator &f, int max_depth)
Definition: DecisionTreeLearner.h:172
discrete policy
Definition: DecisionTreeLearner.h:44
void Test(const DecisionTree< FeatureType > &root)
Test tree performance on the training set.
Definition: DecisionTreeLearner.h:335
Definition: DecisionTreeLearner.h:33
a FeatureExtractor returns a scalar number with no relationship to the classification ...
Definition: Types.h:35
virtual bool Next(Feature &out)=0
return the next feature, or return false
bool BuildDecisionTree(DecisionTree< FeatureExtractor > &root, FeatureGenerator &f) const
Build a tree.
Definition: DecisionTreeLearner.h:328
implement the generic pattern object
T classifier
a feature extractor
Definition: DecisionTree.h:38
some function to test the classifier on the set
Definition: DecisionTree.h:34
weighted policy
Definition: DecisionTreeLearner.h:49