X-Boost  2.3.8
AdaBoostTrainer.h
Go to the documentation of this file.
1 /* XBoost: Ada-Boost and Friends on Haar/ICF/HOG Features, Library and ToolBox
2  *
3  * Copyright (c) 2008-2014 Paolo Medici <medici@ce.unipr.it>
4  *
5  * This library is free software; you can redistribute it and/or
6  * modify it under the terms of the GNU Lesser General Public
7  * License as published by the Free Software Foundation; either
8  * version 2 of the License, or (at your option) any later version.
9  *
10  * This library is distributed in the hope that it will be useful,
11  * but WITHOUT ANY WARRANTY; without even the implied warranty of
12  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
13  * Lesser General Public License for more details.
14  *
15  * You should have received a copy of the GNU Lesser General Public
16  * License along with this library; if not, write to the
17  * Free Software Foundation, Inc., 59 Temple Place - Suite 330,
18  * Boston, MA 02111-1307, USA.
19  */
20 
21 #ifndef _ADABOOST_TRAINER_H
22 #define _ADABOOST_TRAINER_H
23 
29 #include "DataSet.h"
30 #include "Test.h"
31 #include "AdaBoost.h"
34 #include "Pattern/Pattern.h"
35 #include "AdaBoostCommon.h"
36 
37 #include <vector>
38 #include <cmath>
39 #include <algorithm>
40 #include <limits>
41 
47 template<class Oracle>
48 class AdaBoostTrainer: public AdaBoost< typename Oracle::ClassifierType >, public Oracle {
49 public:
50 
52 
55 
56 
59 
62 
65 
66 private:
67 
69  ClassifierType m_classifier;
70 
74  DataSetType m_validation_set;
75 
77  bool m_useMAdaBoost;
78 
79 public:
80  AdaBoostTrainer() : m_useMAdaBoost(false) { }
81  ~AdaBoostTrainer() { }
82 
84  void EnableMAdaBoost(bool mada)
85  {
86  m_useMAdaBoost = mada;
87  }
88 
90  static std::string signature() {
92  }
93 
94  template<class SourceDataSetType>
95  void SetValidationSet(const SourceDataSetType & set)
96  {
97  std::cout << "Update with: " << set.n_patternP << "(+), " << set.n_patternN << "(-)\n";
98 
99  m_validation_set = set;
100  //std::cin.get();
101  }
102 
104  const DataSetType & GetValidationSet() const {
105  return m_validation_set;
106  }
107 
108 
111  void ResetValidationSet(double positive_weight = -1.0) // se e' random conviene iniziare con -1.0
112  {
113  ResetWeight(m_validation_set, positive_weight);
114  }
115 
118  void Restart(bool reOptimize);
119 
122  return m_classifier;
123  }
124 
126  const ClassifierType &Classifier() const {
127  return m_classifier;
128  }
129 
134  bool Train();
135 
138  bool Test();
139 
141  ReportTest TestAndExportStat(double th, int max_concurrent_jobs);
142 };
143 
144 
146 
147 
148 template<class Oracle>
150 {
151  std::cout << m_classifier.size() << " classifiers:\n";
152 
153  return ::Test( m_classifier, m_validation_set);
154 }
155 template<class Oracle>
156 ReportTest AdaBoostTrainer<Oracle>::TestAndExportStat(double th, int max_concurrent_jobs)
157 {
158  std::cout << m_classifier.size() << " classifiers:\n";
159 
160  return ::TestAndExportStat( m_classifier, m_validation_set, th, max_concurrent_jobs);
161 }
162 
163 
164 template<class Oracle>
165 void AdaBoostTrainer<Oracle>::Restart(bool reOptimize)
166 {
167  ClassifierListType & list = m_classifier.list();
168 // for any Classifier call Post
169  for(typename ClassifierListType::iterator i = list.begin(); i != list.end();)
170  {
171  if(reOptimize)
172  {
173  // reoptimize (using validation? training?)
174  Oracle::Optimize(*i);
175  }
176 
177  // TODO MAdaBoost is missing?
178 
179  // Update AdaBoost weights on training/validation set
180  bool valid = this->Update(*i, Oracle::GetTrainingSet(), m_validation_set);
181 
182  if( !valid )
183  {
184  std::cout << "Bad classifier removed from list" << std::endl;
185  i = list.erase(i);
186  }
187  else
188  {
189  ++i;
190  }
191  }
192 
193  Test(); // nono ?
194 
195 }
196 
197 template<class Oracle>
199 {
200  WeakClassifierType bestH;
201 
202  if(m_useMAdaBoost)
203  {
204  // TODO: probabilmente non funziona
205 
206  // Modify the "Traning Set" using MAdaBoost variant
207  DataSetType & train = Oracle::GetTrainingSet();
208  double d = 1.0/m_validation_set.Size(); // MAdaBoost threshold
209  int n = 0;
210  for(unsigned int i =0; i<train.Size(); i++)
211  if( train.templates[i].d > d)
212  {
213  train.templates[i].d = d;
214  ++n;
215  }
216  std::cout << "[+] MAdaBoost: " << n << " samples of " << train.Size() << " (" << (n*100)/train.Size() << "%) have been limited in weight (" << d <<")" << std::endl;
217  }
218 
219 // get the best Weak Classifier using AdaBoost* metric on Training Set
220  if(!Oracle::GetHypothesis(bestH))
221  return false;
222 
223  // Call AdaBoost Update on validation_set. Eventually improve bestH using validation set.
224  // TODO: attenzione che se calcola alpha sul validation set, probabilmente si ottengono degli alpha negativi (anche in Madaboost)
225  bool valid = this->Update(bestH, Oracle::GetTrainingSet(), m_validation_set);
226 
227  if(valid)
228  {
229  // store
230  m_classifier.insert(bestH);
231  }
232  else
233  {
234  std::cout << "[WW] classifier skipped by update rule" << std::endl;
235  }
236 
237  return true;
238 }
239 
240 #endif
241 
void ResetWeight(DataSet &list, double priori_knownledge)
Definition: WeightedPattern.h:75
ReportTest TestAndExportStat(double th, int max_concurrent_jobs)
Definition: AdaBoostTrainer.h:156
bool Test(const ClassifierType &c, const DataType &data, double threshold=0.0)
Definition: Test.h:156
a Voting Boostable classifier
const DataSetType & GetValidationSet() const
returns the validation set
Definition: AdaBoostTrainer.h:104
ClassifierType & Classifier()
Direct Access to Final Classifier.
Definition: AdaBoostTrainer.h:121
ClassifierType
Definition: Types.h:31
static std::string signature()
the internal signature of the classifier (not of the trainer)
Definition: AdaBoostTrainer.h:90
Definition: AdaBoostTrainer.h:48
unsigned int Size() const
Returns the number of allocated samples (the complete size of the DataSet)
Definition: DataSet.h:101
a file containing miscellaneous types
BoostClassifier< WeakClassifierType > ClassifierType
The Final Strong Classifier as Additive Model.
Definition: AdaBoostTrainer.h:61
ListType templates
a collection of Pattern used in this dataset
Definition: DataSet.h:73
void ResetValidationSet(double positive_weight=-1.0)
Definition: AdaBoostTrainer.h:111
void Restart(bool reOptimize)
Definition: AdaBoostTrainer.h:165
implements the generic pattern object
const ClassifierType & Classifier() const
Direct Access to Final Classifier.
Definition: AdaBoostTrainer.h:126
ClassifierType::ClassifierListType ClassifierListType
The internal list, inside BoostClassifier.
Definition: AdaBoostTrainer.h:64
some functions to test the classifier on the set
ReportTest TestAndExportStat(const ClassifierType &c, const DataType &data, double threshold, int max_concurrent_jobs)
Definition: Test.h:265
a majority-voting classifier.
Definition: Test.h:39
Definition: AdaBoost.h:32
Oracle::ClassifierType WeakClassifierType
Weak Classifier Type, reported by Oracle.
Definition: AdaBoostTrainer.h:58
static std::string signature()
propagates the signature
Definition: BoostClassifier.h:133
declares a DataSet
DataSetHandle< AdaBoostPattern > DataSetType
DataSet required by AdaBoostTrainer.
Definition: AdaBoostTrainer.h:54
Helps declare some traits.
Definition: Aggregator.h:30
std::vector< BoostedClassifierType > ClassifierListType
List of Boosted classifier.
Definition: BoostClassifier.h:69
void EnableMAdaBoost(bool mada)
Enables or disables the MAdaBoost variant.
Definition: AdaBoostTrainer.h:84
bool Train()
Definition: AdaBoostTrainer.h:198
bool Test()
Definition: AdaBoostTrainer.h:149