Main Page · Modules · All Classes · Class Hierarchy
MAClassifier.hpp
1 /*
2  * This file is part of the AiBO+ project
3  *
4  * Copyright (C) 2005-2016 Csaba Kertész (csaba.kertesz@gmail.com)
5  *
6  * AiBO+ is free software; you can redistribute it and/or modify
7  * it under the terms of the GNU General Public License as published by
8  * the Free Software Foundation; either version 2 of the License, or
9  * (at your option) any later version.
10  *
11  * AiBO+ is distributed in the hope that it will be useful,
12  * but WITHOUT ANY WARRANTY; without even the implied warranty of
13  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14  * GNU General Public License for more details.
15  *
16  * You should have received a copy of the GNU General Public License
17  * along with this program; if not, write to the Free Software
18  * Foundation, Inc., 59 Temple Street #330, Boston, MA 02111-1307, USA.
19  *
20  */
21 
22 #pragma once
23 
24 #include "MAClassifierTypes.hpp"
25 
26 #include "core/MACoreTypes.hpp"
27 
28 #include <MCTypes.hpp>
29 
30 #include <boost/serialization/access.hpp>
31 #include <boost/serialization/split_member.hpp>
32 #include <boost/tuple/tuple.hpp>
33 #include <boost/scoped_ptr.hpp>
34 #include <boost/shared_ptr.hpp>
35 
36 class CvANN_MLP;
37 class CvEmWrapper;
38 class CvNormalBayesClassifier;
39 class CvKNearest;
40 class CvSVM;
41 class CvSVMParams;
42 class CvDTree;
43 class CvDTreeParams;
44 class CvERTrees;
45 class CvGBTrees;
46 class DlibWrapper;
47 class GpWrapper;
48 class ME_Model;
49 class LWPR_Object;
50 class PLS_Model;
51 class MCBinaryData;
52 class MAClassifier;
53 class MAModel;
54 class MARandomTrees;
55 
61 namespace MA
62 {
64 typedef boost::shared_ptr<MAClassifier> ClassifierSPtr;
65
66 /**
67  * @brief Cross-validation results.
68  *
69  * The first table list contains the confusion matrices of the testing results.
70  * The second table list contains the confusion matrices of the validation results.
71  * The third element is a list of sample rankings, where a higher value means better
72  * model-building capability.
73  *
74  */
75 typedef boost::tuple<MA::FloatTableList, MA::FloatTableList, MC::FloatList> CvResultsType;
76 } // namespace MA
77 
82 {
83  friend class boost::serialization::access;
84 
86 
87 
94  MAClassifier();
95 
96 public:
97 
111  MAClassifier(MA::CRMethodType method, int label_count, bool regression = false);
112  ~MAClassifier();
113 
121  bool IsValid() const;
122 
130  MA::CRMethodType GetMethodType() const;
131 
139  bool IsRegression() const;
140 
148  unsigned int GetFeatureVectorSize() const;
149 
157  MC::FloatTable& GetFeatureVectors();
158 
166  MC::FloatList& GetLabels();
167 
175  MC::FloatList GetModelLabels() const;
176 
184  void PrioritizeClasses(const MC::FloatList& labels);
185 
192  void Reset();
193 
204  void SetParameter(MA::CRMethodParamType method_parameter, float value);
205 
215  float GetParameter(MA::CRMethodParamType method_parameter);
216 
226  void SetPreprocessingMode(MA::FeaturePreprocessingType preprocessing_mode);
227 
239  void AddSamples(const MC::FloatTable& input_vectors, const MC::FloatList& labels);
240 
253  MC::FloatList Predict(const MC::FloatTable& input_vectors, MC::FloatList& confidences);
254 
267  float Predict(const MC::FloatList& input_vector, MC::FloatList& confidence);
268 private:
269  float PredictReal(const MC::FloatList& input_vector, MC::FloatList& confidence);
270  void Train();
271 public:
272 
294  MA::CvResultsType CrossValidate(int iterations, MA::ClassifierCrossValidationType cv_type,
295  float cv_parameter, const MC::FloatTable& samples,
296  const MC::FloatList& labels,
297  const MC::FloatTable& validation_samples = MC::FloatTable(),
298  const MC::FloatList& validation_labels = MC::FloatList(),
299  float regression_accuracy = MCFloatInfinity());
300 
311  MCBinaryData* Encode() const;
312 
325  static MAClassifier* Decode(const MCBinaryData& data);
326 
337  MAModel* ExportModel() const;
338 
339 private:
340 
346  void CreateClassifier();
347 
348  template<class Archive>
349  void load(Archive& archive, const unsigned int version);
350  template<class Archive>
351  void save(Archive& archive, const unsigned int version) const;
352  BOOST_SERIALIZATION_SPLIT_MEMBER();
353 
355  MA::CRMethodType Classifier;
357  MA::FeaturePreprocessingType Preprocessing;
361  bool Trained;
363  unsigned int FeatureCount;
365  unsigned int LabelCount;
367  float SvmGamma;
369  float SvmC;
371  float SvmP;
373  float SvmNu;
377  float DlibEpsilon;
381  float RvmRbfGamma;
383  float KrrGamma;
385  float KrrLambda;
387  float MeL1;
389  float MeL2;
391  float LwprAlpha;
392  /*
393  * Components parameter (partial least squares regression)
394  * 0 is the default value which means autodetection
395  */
396  int PlsrComponents;
400  float KrlsGamma;
403  // Tree node sample limit (decision/random trees)
404  int TreeNodeSampleLimit;
408  boost::scoped_ptr<CvANN_MLP> NeuralNetwork;
410  boost::scoped_ptr<CvNormalBayesClassifier> BayesClassifier;
412  boost::scoped_ptr<CvKNearest> KNearestClassifier;
414  boost::scoped_ptr<CvSVM> SvmClassifier;
416  boost::scoped_ptr<CvSVMParams> SvmClassifierParams;
418  boost::scoped_ptr<CvDTree> DecisionTree;
420  boost::scoped_ptr<MARandomTrees> RandomTrees;
422  boost::scoped_ptr<CvERTrees> ExtremeRandomTrees;
424  boost::scoped_ptr<CvGBTrees> GradientBoostedTrees;
426  boost::scoped_ptr<CvEmWrapper> EmClassifier;
428  boost::scoped_ptr<DlibWrapper> DlibFunctions;
430  boost::scoped_ptr<ME_Model> MaxEntropy;
432  boost::scoped_ptr<LWPR_Object> Lwpr;
434  boost::scoped_ptr<PLS_Model> Plsr;
436  boost::scoped_ptr<GpWrapper> Gpr;
438  MC::FloatTable CachedSamples;
440  MC::FloatList CachedLabels;
442  MA::FloatSet CachedUniqueLabels;
444  MC::FloatTable PreprocessedData;
446  MC::FloatList PrioritizedClasses;
448  MC::FloatList PriorityLabels;
449 };
450 
float MeL1
L1 regularization parameter (maximum entropy)
float SvmC
C parameter (support vector machine)
boost::scoped_ptr< CvGBTrees > GradientBoostedTrees
Gradient boosted trees classifier.
MC::FloatTable CachedSamples
Cached training samples.
boost::scoped_ptr< CvSVMParams > SvmClassifierParams
Support vector machine classifier parameters.
Binary data class.
boost::scoped_ptr< CvANN_MLP > NeuralNetwork
Neural network classifier.
float SvmGamma
γ parameter (support vector machine)
boost::scoped_ptr< LWPR_Object > Lwpr
Locally weighted projection regression.
boost::scoped_ptr< GpWrapper > Gpr
Gaussian process regression.
boost::scoped_ptr< PLS_Model > Plsr
Partial least squares regression.
float KrrGamma
Gamma parameter (kernel ridge regression/rbf kernel)
MA::FeaturePreprocessingType Preprocessing
Preprocessing type.
Classifier model based on OpenCV classifiers.
Definition: MAModel.hpp:50
float KrlsGamma
Gamma parameter (kernel recursive least squares/rbf kernel)
int RtMaxForestSize
Maximum number of trees in a forest (random/extremely randomized/gradient boosted trees) ...
boost::scoped_ptr< CvNormalBayesClassifier > BayesClassifier
Naive Bayes classifier.
#define MC_DISABLE_COPY(class_name)
Helper macro to disable the copy constructor and assignment operator of a class (object copying) ...
Definition: MCDefs.hpp:604
float RvmRbfGamma
Gamma parameter (relevance vector machine/rbf kernel)
unsigned int LabelCount
Label count.
boost::scoped_ptr< CvSVM > SvmClassifier
Support vector machine classifier.
MC::FloatList PrioritizedClasses
Prioritized classes for OpenCV classifiers.
MA::CRMethodType Classifier
Classifier type.
boost::scoped_ptr< ME_Model > MaxEntropy
Maximum entropy classifier.
boost::scoped_ptr< CvERTrees > ExtremeRandomTrees
Extremely randomized trees classifier.
unsigned int FeatureCount
Feature count.
boost::scoped_ptr< CvEmWrapper > EmClassifier
Expectation maximization classifier.
boost::scoped_ptr< CvKNearest > KNearestClassifier
K-nearest neighbor classifier.
float KrlsTolerance
Tolerance parameter (kernel recursive least squares)
MA::FloatSet CachedUniqueLabels
Cached unique labels.
float SvmNu
ν parameter (support vector machine)
bool Regression
Whether the instance is used for regression (true) or classification (false).
float DlibEpsilon
Epsilon parameter (support/relevance vector machine in dlib)
boost::scoped_ptr< MARandomTrees > RandomTrees
Random trees classifier.
boost::scoped_ptr< DlibWrapper > DlibFunctions
Wrapper for Dlib classifier/regression functions.
boost::scoped_ptr< CvDTree > DecisionTree
Decision tree classifier.
float SvmP
P parameter (support vector machine)
bool SvmAutoTrain
Auto-train parameter (support vector machine)
int TreeMaxDepth
Maximum depth (decision/random/extremely randomized/gradient boosted trees)
float MCFloatInfinity()
Get float infinity.
Definition: MCDefs.cpp:110
float LwprAlpha
α parameter (locally weighted projection regression)
MC::FloatList PriorityLabels
Priors for OpenCV classifiers.
MC::FloatTable PreprocessedData
Cached preprocessed data.
float MeL2
L2 regularization parameter (maximum entropy)
bool Trained
Whether the classifier is already trained.
Common interface for multiple classifiers and regression algorithms.
float KrrLambda
Lambda parameter (kernel ridge regression)
float RvmSigmoidGamma
Gamma parameter (relevance vector machine/sigmoid kernel)
MC::FloatList CachedLabels
Cached training labels.