Cosan  1.0
Data Analytics Library
randomgridsearch.h
Go to the documentation of this file.
1 //
2 // Created by Xinyu Zhang on 4/4/21.
3 //
4 
5 #ifndef COSAN_RANDOMGRIDSEARCH_H
6 #define COSAN_RANDOMGRIDSEARCH_H
7 
10 namespace Cosan{
11  template<typename NumericType,
12  Derived<CosanModel> Model,
13  Derived<CosanMetric<NumericType>> Metric,
14  Derived<Splitter> Split,
15  typename = typename std::enable_if<std::is_arithmetic<NumericType>::value,NumericType>::type>
16  class RandomGridSearch: public Search{
17  public:
18  RandomGridSearch() = delete;
20  Model & estimator,
21  Metric & metric,
22  Split & split,
23  const std::vector<NumericType> & paramGrid,long unsigned int nsamples= 100): Search() {
24  NumericType minError = std::numeric_limits<NumericType>::infinity();
25 
26  NumericType currError;
27  std::vector<NumericType> RandomChoice;
28  std::sample(paramGrid.begin(), paramGrid.end(), std::back_inserter(RandomChoice),
29  std::min({paramGrid.size(),nsamples}), std::mt19937{std::random_device{}()});
30 
31  decltype(bestParam) currParam;
32  for (gsl::index i = 0; i < RandomChoice.size(); ++i){
33  currParam = paramGrid[i];
34  estimator.SetParams(paramGrid[i]);
35  currError = crossValidation(CRD, estimator, metric, split);
36  if (currError < minError)
37  {
38  minError = currError;
39  bestParam = currParam;
40  }
41  }
42  }
43  auto GetBestParams(){return bestParam;}
44 
45  private:
47 
48 
49  template<typename NumericType,
50  Derived<CosanModel> Model,
51  Derived<CosanMetric<NumericType>> Metric,
52  Derived<Splitter> Split,
53  typename = typename std::enable_if<std::is_arithmetic<NumericType>::value,NumericType>::type>
55  public:
58  Model & estimator,
59  Metric & metric,
60  Split & split,
61  const std::vector<NumericType> & paramGrid,long unsigned int nsamples = 100,int nthreads = -1): Search() {
62  NumericType minError = std::numeric_limits<NumericType>::infinity();
63 
64  NumericType currError;
65  std::vector<NumericType> RandomChoice;
66  std::sample(paramGrid.begin(), paramGrid.end(), std::back_inserter(RandomChoice),
67  std::min({paramGrid.size(),nsamples}), std::mt19937{std::random_device{}()});
68 
69  std::vector<NumericType> allError(RandomChoice.size());
70  if (nthreads == -1){
71  omp_set_num_threads(omp_get_max_threads());
72  }
73  else{
74  omp_set_num_threads(nthreads);
75  }
76  #pragma omp parallel for
77  for (gsl::index i = 0; i < RandomChoice.size(); ++i){
78  estimator.SetParams(RandomChoice[i]);
79  allError[i] = crossValidation(CRD, estimator, metric, split);
80  }
81  bestParam =RandomChoice[std::distance(allError.begin(), std::min_element(allError.begin(), allError.end()))];
82 
83  }
84  auto GetBestParams(){return bestParam;}
85 
86  private:
88 
89 
90  template<typename NumericType,
91  Derived<CosanModel> Model,
92  Derived<CosanMetric<NumericType>> Metric,
93  Derived<Splitter> Split,
94  typename = typename std::enable_if<std::is_arithmetic<NumericType>::value,NumericType>::type>
96  public:
99  Model & estimator,
100  Metric & metric,
101  Split & split,
102  const std::vector<std::vector<NumericType>> & paramGrid,
103  long unsigned int nsamples = 100): Search() {
104  NumericType minError = std::numeric_limits<NumericType>::infinity();
105  NumericType currError;
106  std::vector<std::vector<NumericType>> RandomChoice;
107  std::sample(paramGrid.begin(), paramGrid.end(), std::back_inserter(RandomChoice),
108  std::min({paramGrid.size(),nsamples}), std::mt19937{std::random_device{}()});
109 
110  decltype(bestParam) currParam;
111 
112  for (gsl::index i = 0; i < RandomChoice.size(); ++i){
113  currParam = paramGrid[i];
114  estimator.SetParams(paramGrid[i]);
115  currError = crossValidation(CRD, estimator, metric, split);
116  if (currError < minError)
117  {
118  minError = currError;
119  bestParam = currParam;
120  }
121  }
122  }
123  auto GetBestParams(){return bestParam;}
124 
125  private:
126  std::vector<NumericType> bestParam;};
127 
128 
129  template<typename NumericType,
130  Derived<CosanModel> Model,
131  Derived<CosanMetric<NumericType>> Metric,
132  Derived<Splitter> Split,
133  typename = typename std::enable_if<std::is_arithmetic<NumericType>::value,NumericType>::type>
135  public:
138  Model & estimator,
139  Metric & metric,
140  Split & split,
141  const std::vector<std::vector<NumericType>> & paramGrid,
142  long unsigned int nsamples = 100,int nthreads = -1): Search() {
143  NumericType minError = std::numeric_limits<NumericType>::infinity();
144  NumericType currError;
145  std::vector<NumericType> RandomChoice;
146  std::sample(paramGrid.begin(), paramGrid.end(), std::back_inserter(RandomChoice),
147  std::min({paramGrid.size(),nsamples}), std::mt19937{std::random_device{}()});
148  std::vector<std::vector<NumericType>> allError(paramGrid.size());
149  if (nthreads == -1){
150  omp_set_num_threads(omp_get_max_threads());
151  }
152  else{
153  omp_set_num_threads(nthreads);
154  }
155  #pragma omp parallel for
156  for (gsl::index i = 0; i < paramGrid.size(); ++i){
157  estimator.SetParams(paramGrid[i]);
158  allError[i] = crossValidation(CRD, estimator, metric, split);
159  }
160  bestParam =paramGrid[std::distance(allError.begin(), std::min_element(allError.begin(), allError.end()))];
161  }
162  auto GetBestParams(){return bestParam;}
163 
164  private:
165  std::vector<NumericType> bestParam;};
166 
167 }
168 #endif //COSAN_RANDOMGRIDSEARCH_H
selection.h
Cosan
Definition: CosanBO.h:29
Cosan::RandomGridSearchMultiParallel::bestParam
std::vector< NumericType > bestParam
Definition: randomgridsearch.h:165
Cosan::RandomGridSearchParallel::RandomGridSearchParallel
RandomGridSearchParallel(CosanData< NumericType > &CRD, Model &estimator, Metric &metric, Split &split, const std::vector< NumericType > &paramGrid, long unsigned int nsamples=100, int nthreads=-1)
Definition: randomgridsearch.h:57
Cosan::RandomGridSearch
Definition: randomgridsearch.h:16
NumericType
double NumericType
Definition: onehotencodingTest.cpp:20
Cosan::RandomGridSearchMultiParallel::RandomGridSearchMultiParallel
RandomGridSearchMultiParallel(CosanData< NumericType > &CRD, Model &estimator, Metric &metric, Split &split, const std::vector< std::vector< NumericType >> &paramGrid, long unsigned int nsamples=100, int nthreads=-1)
Definition: randomgridsearch.h:137
Cosan::crossValidation
NumericType crossValidation(CosanData< NumericType > &CRD, Model &estimator, Metric &metric, Split &split)
Definition: crossvalidation.h:32
Cosan::RandomGridSearchMulti::RandomGridSearchMulti
RandomGridSearchMulti()=delete
Cosan::RandomGridSearchMultiParallel
Definition: randomgridsearch.h:134
Cosan::RandomGridSearch::GetBestParams
auto GetBestParams()
Definition: randomgridsearch.h:43
Cosan::RandomGridSearchParallel::bestParam
NumericType bestParam
Definition: randomgridsearch.h:87
Cosan::CosanData
Data container.
Definition: CosanData.h:546
Cosan::RandomGridSearchMultiParallel::RandomGridSearchMultiParallel
RandomGridSearchMultiParallel()=delete
Cosan::RandomGridSearchMulti::bestParam
std::vector< NumericType > bestParam
Definition: randomgridsearch.h:126
Cosan::Search
Definition: selection.h:27
crossvalidation.h
Cosan::RandomGridSearchMulti::GetBestParams
auto GetBestParams()
Definition: randomgridsearch.h:123
Cosan::RandomGridSearchMulti
Definition: randomgridsearch.h:95
Cosan::RandomGridSearch::RandomGridSearch
RandomGridSearch(CosanData< NumericType > &CRD, Model &estimator, Metric &metric, Split &split, const std::vector< NumericType > &paramGrid, long unsigned int nsamples=100)
Definition: randomgridsearch.h:19
Cosan::RandomGridSearch::bestParam
NumericType bestParam
Definition: randomgridsearch.h:46
Cosan::RandomGridSearch::RandomGridSearch
RandomGridSearch()=delete
Cosan::RandomGridSearchMultiParallel::GetBestParams
auto GetBestParams()
Definition: randomgridsearch.h:162
Cosan::RandomGridSearchMulti::RandomGridSearchMulti
RandomGridSearchMulti(CosanData< NumericType > &CRD, Model &estimator, Metric &metric, Split &split, const std::vector< std::vector< NumericType >> &paramGrid, long unsigned int nsamples=100)
Definition: randomgridsearch.h:98
Cosan::RandomGridSearchParallel
Definition: randomgridsearch.h:54
Cosan::RandomGridSearchParallel::GetBestParams
auto GetBestParams()
Definition: randomgridsearch.h:84
Cosan::RandomGridSearchParallel::RandomGridSearchParallel
RandomGridSearchParallel()=delete