Cosan  1.0
Data Analytics Library
gridsearch.h
Go to the documentation of this file.
1 #ifndef TUNING_H
2 #define TUNING_H
3 
4 // TODO: change after integrating with the module
5 
6 // import from Cosan
7 // TODO #include<> cross validate here
10 namespace Cosan
11 {
12  /**
13  * Hyperparameter tuning for supervised models that have one or more hyperparameter(s) to tune
14  * Input required:
15  * estimator: class Model&, a model whose hyperparameters need to be tuned;
16  * metric: class Metric&, a metric to use in cross-validation
17  * split: Split & splitter method;*
18  * paramGrid: a vector of hyperparameters combination. For each entry of the vector, it corresponds to one choice of hyperparameter combination;
19  * Ouput: call .GetBestParams() to get the best hyperparameters combination.
20  * the choice of the hyper-parameter in paramGrid that forms the most accurate model
21  **/
22  template<Numeric NumericType,
23  Derived<CosanModel> Model,
24  Derived<CosanMetric<NumericType>> Metric,
25  Derived<Splitter> Split>
26  class GridSearch: public Search{
27  public:
28  GridSearch() = delete;
30  Model & estimator,
31  Metric & metric,
32  Split & split,
33  const std::vector<NumericType> & paramGrid): Search() {
34  NumericType minError = std::numeric_limits<NumericType>::infinity();
35  NumericType currError;
36  decltype(bestParam) currParam;
37  for (gsl::index i = 0; i < paramGrid.size(); ++i){
38  currParam = paramGrid[i];
39  estimator.SetParams(paramGrid[i]);
40  currError = crossValidation(CRD, estimator, metric, split);
41  if (currError < minError)
42  {
43  minError = currError;
44  bestParam = currParam;
45  }
46  }
47  }
48  auto GetBestParams(){return bestParam;}
49 
50  private:
52 
53  /**
54  * @details parallel version for grid search
55  * Hyperparameter tuning for supervised models that have one or more hyperparameter(s) to tune
56  * Input required:
57  * estimator: class Model&, a model whose hyperparameters need to be tuned;
58  * metric: class Metric&, a metric to use in cross-validation
59  * split: Split & splitter method;*
60  * paramGrid: a vector of hyperparameters combination. For each entry of the vector, it corresponds to one choice of hyperparameter combination;
61  * nthreads: int. number of threads to be used for parallel computing.
62  * Ouput: call .GetBestParams() to get the best hyperparameters combination.
63  * the choice of the hyper-parameter in paramGrid that forms the most accurate model
64  **/
65  template<Numeric NumericType,
66  Derived<CosanModel> Model,
67  Derived<CosanMetric<NumericType>> Metric,
68  Derived<Splitter> Split>
69  class GridSearchParallel: public Search{
70  public:
71  GridSearchParallel() = delete;
72 // GridSearch(const std::variant<CosanRawData<NumericType>,CosanData<NumericType>> &CRD,
73 // Model & estimator,
74 // Metric & metric,
75 // Split & split,
76 // const std::vector<std::variant<NumericType,std::vector<NumericType>>> & paramGrid): Selection() {
78  Model & estimator,
79  Metric & metric,
80  Split & split,
81  const std::vector<NumericType> & paramGrid, int nthreads = -1): Search() {
82  NumericType minError = std::numeric_limits<NumericType>::infinity();
83  std::vector<NumericType> allError(paramGrid.size());
84  if (nthreads == -1){
85  omp_set_num_threads(omp_get_max_threads());
86  }
87  else{
88  omp_set_num_threads(nthreads);
89  }
90  #pragma omp parallel for
91  for (gsl::index i = 0; i < paramGrid.size(); ++i){
92  estimator.SetParams(paramGrid[i]);
93  allError[i] = crossValidation(CRD, estimator, metric, split);
94  }
95  bestParam =paramGrid[std::distance(allError.begin(), std::min_element(allError.begin(), allError.end()))];
96  }
97  auto GetBestParams(){return bestParam;}
98 
99  private:
101 
102 
103 // template<typename NumericType,
104 // Derived<CosanModel> Model,
105 // Derived<CosanMetric<NumericType>> Metric,
106 // Derived<Splitter> Split,
107 // typename = typename std::enable_if<std::is_arithmetic<NumericType>::value,NumericType>::type>
108  template<Numeric NumericType,
109  Derived<CosanModel> Model,
110  Derived<CosanMetric<NumericType>> Metric,
111  Derived<Splitter> Split>
112  class GridSearchMulti: public Search{
113  public:
114  GridSearchMulti() = delete;
115 // GridSearch(const std::variant<CosanRawData<NumericType>,CosanData<NumericType>> &CRD,
116 // Model & estimator,
117 // Metric & metric,
118 // Split & split,
119 // const std::vector<std::variant<NumericType,std::vector<NumericType>>> & paramGrid): Selection() {
121  Model & estimator,
122  Metric & metric,
123  Split & split,
124  const std::vector<std::vector<NumericType>> & paramGrid): Search() {
125  NumericType minError = std::numeric_limits<NumericType>::infinity();
126  NumericType currError;
127  decltype(bestParam) currParam;
128 
129  for (gsl::index i = 0; i < paramGrid.size(); ++i){
130  currParam = paramGrid[i];
131  estimator.SetParams(paramGrid[i]);
132  currError = crossValidation(CRD, estimator, metric, split);
133  if (currError < minError)
134  {
135  minError = currError;
136  bestParam = currParam;
137  }
138  }
139  }
140  auto GetBestParams(){return bestParam;}
141 
142  private:
143  std::vector<NumericType> bestParam;};
144 
145 // template<typename NumericType,
146 // Derived<CosanModel> Model,
147 // Derived<CosanMetric<NumericType>> Metric,
148 // Derived<Splitter> Split,
149 // typename = typename std::enable_if<std::is_arithmetic<NumericType>::value,NumericType>::type>
150  template<Numeric NumericType,
151  Derived<CosanModel> Model,
152  Derived<CosanMetric<NumericType>> Metric,
153  Derived<Splitter> Split>
155  public:
157 // GridSearch(const std::variant<CosanRawData<NumericType>,CosanData<NumericType>> &CRD,
158 // Model & estimator,
159 // Metric & metric,
160 // Split & split,
161 // const std::vector<std::variant<NumericType,std::vector<NumericType>>> & paramGrid): Selection() {
163  Model & estimator,
164  Metric & metric,
165  Split & split,
166  const std::vector<std::vector<NumericType>> & paramGrid,
167  int nthreads = -1): Search() {
168  NumericType minError = std::numeric_limits<NumericType>::infinity();
169  std::vector<NumericType> allError(paramGrid.size());
170  if (nthreads == -1){
171  omp_set_num_threads(omp_get_max_threads());
172  }
173  else{
174  omp_set_num_threads(nthreads);
175  }
176  #pragma omp parallel for
177  for (gsl::index i = 0; i < paramGrid.size(); ++i){
178  estimator.SetParams(paramGrid[i]);
179  allError[i] = crossValidation(CRD, estimator, metric, split);
180  }
181  bestParam =paramGrid[std::distance(allError.begin(), std::min_element(allError.begin(), allError.end()))];
182  }
183  auto GetBestParams(){return bestParam;}
184 
185  private:
186  std::vector<NumericType> bestParam;};
187 
188 }
189 #endif
Cosan::GridSearchParallel::GridSearchParallel
GridSearchParallel()=delete
selection.h
Cosan
Definition: CosanBO.h:29
Cosan::GridSearch::bestParam
NumericType bestParam
Definition: gridsearch.h:51
Cosan::GridSearch
Definition: gridsearch.h:26
NumericType
double NumericType
Definition: onehotencodingTest.cpp:20
Cosan::GridSearchMulti::bestParam
std::vector< NumericType > bestParam
Definition: gridsearch.h:143
Numeric
concept Numeric
Definition: CosanBO.h:23
Cosan::GridSearchParallel::bestParam
NumericType bestParam
Definition: gridsearch.h:100
Cosan::GridSearch::GridSearch
GridSearch()=delete
Cosan::GridSearch::GridSearch
GridSearch(CosanData< NumericType > &CRD, Model &estimator, Metric &metric, Split &split, const std::vector< NumericType > &paramGrid)
Definition: gridsearch.h:29
Cosan::GridSearchMultiParallel::GetBestParams
auto GetBestParams()
Definition: gridsearch.h:183
Cosan::crossValidation
NumericType crossValidation(CosanData< NumericType > &CRD, Model &estimator, Metric &metric, Split &split)
Definition: crossvalidation.h:32
Cosan::GridSearchMultiParallel::GridSearchMultiParallel
GridSearchMultiParallel(CosanData< NumericType > &CRD, Model &estimator, Metric &metric, Split &split, const std::vector< std::vector< NumericType >> &paramGrid, int nthreads=-1)
Definition: gridsearch.h:162
Cosan::GridSearchMulti::GridSearchMulti
GridSearchMulti()=delete
Cosan::CosanData
Data container.
Definition: CosanData.h:546
Cosan::GridSearchMultiParallel::bestParam
std::vector< NumericType > bestParam
Definition: gridsearch.h:186
Cosan::GridSearchMultiParallel::GridSearchMultiParallel
GridSearchMultiParallel()=delete
Cosan::Search
Definition: selection.h:27
Cosan::GridSearchMulti::GetBestParams
auto GetBestParams()
Definition: gridsearch.h:140
crossvalidation.h
Cosan::GridSearchParallel::GridSearchParallel
GridSearchParallel(CosanData< NumericType > &CRD, Model &estimator, Metric &metric, Split &split, const std::vector< NumericType > &paramGrid, int nthreads=-1)
Definition: gridsearch.h:77
Cosan::GridSearchMulti::GridSearchMulti
GridSearchMulti(CosanData< NumericType > &CRD, Model &estimator, Metric &metric, Split &split, const std::vector< std::vector< NumericType >> &paramGrid)
Definition: gridsearch.h:120
Cosan::GridSearch::GetBestParams
auto GetBestParams()
Definition: gridsearch.h:48
Cosan::GridSearchMulti
Definition: gridsearch.h:112
Cosan::GridSearchParallel::GetBestParams
auto GetBestParams()
Definition: gridsearch.h:97
Cosan::GridSearchParallel
Definition: gridsearch.h:69
Cosan::GridSearchMultiParallel
Definition: gridsearch.h:154