Cosan  1.0
Data Analytics Library
crossvalidation.h
Go to the documentation of this file.
1 #ifndef CROSSVALIDATE_H
2 #define CROSSVALIDATE_H
3 //#include<cosan/selection/KFold.h>
4 //#include<cosan/model/AllModels.h>
5 //#include<cosan/utils/CosanMetric.h>
6 #include <variant>
7 #include <vector>
9 namespace Cosan
10 {
11  /*
12  Conduct Cross Validation on a given model
13  Inputs:
14  model: CosanModel&, the reference to the model to be validated;
15  metric: Metric&, an error metric derived from CosanMetric<NumericType>;
16  CRD: CosanData<NumericType>&, the data set to cross-validate on;
17  split: Split&, a splitter derived from the Splitter class.
18 // */
19 // NumericType crossValidation(const std::variant<CosanRawData<NumericType>,CosanData<NumericType>> &CRD,
20 // Model& estimator,
21 // Metric & metric,
22 // Split & split){
23 // template<typename NumericType,
24 // Derived<CosanModel> Model,
25 // Derived<CosanMetric<NumericType>> Metric,
26 // Derived<Splitter> Split,
27 // typename = typename std::enable_if<std::is_arithmetic<NumericType>::value,NumericType>::type >
28  template<Numeric NumericType,
29  Derived<CosanModel> Model,
30  Derived<CosanMetric<NumericType>> Metric,
31  Derived<Splitter> Split>
33  Model& estimator,
34  Metric & metric,
35  Split & split){
36  split.SetSplit(CRD.GetrowsX());
39  std::vector<NumericType> errors;
40  errors.resize(split.GetKFoldNumber());
41  std::vector< std::tuple<std::vector<gsl::index>,std::vector<gsl::index> > > split_idx = split.GetSplit();
42 // #pragma omp parallel for
43  for (gsl::index i =0;i<split.GetKFoldNumber();i++){
44  auto & each = split_idx[i];
45  CosanMatrix<NumericType> X_train = X(std::get<0>(each),Eigen::all),Y_train = Y(std::get<0>(each),Eigen::all);
46  CosanMatrix<NumericType> X_test = X(std::get<1>(each),Eigen::all),Y_test = Y(std::get<1>(each),Eigen::all);
47  estimator.fit(X_train, Y_train);
48  errors[i] = metric.GetError(estimator.predict(X_test ),Y_test);
49  }
50  return std::accumulate(errors.begin(), errors.end(), 0)/errors.size();
51  };
52 
53 // template<typename NumericType,
54 // Derived<CosanModel> Model,
55 // Derived<CosanMetric<NumericType>> Metric,
56 // Derived<Splitter> Split,
57 // typename = typename std::enable_if<std::is_arithmetic<NumericType>::value,NumericType>::type >
58  template<Numeric NumericType,
59  Derived<CosanModel> Model,
60  Derived<CosanMetric<NumericType>> Metric,
61  Derived<Splitter> Split>
63  Model& estimator,
64  Metric & metric,
65  Split & split, int nthreads = -1){
66  split.SetSplit(CRD.GetrowsX());
69  std::vector<NumericType> errors;
70  errors.resize(split.GetKFoldNumber());
71  std::vector< std::tuple<std::vector<gsl::index>,std::vector<gsl::index> > > split_idx = split.GetSplit();
72  if (nthreads == -1){
73  omp_set_num_threads(omp_get_max_threads());
74  }
75  else{
76  omp_set_num_threads(nthreads);
77  }
78  #pragma omp parallel for
79  for (gsl::index i =0;i<split.GetKFoldNumber();i++){
80  auto & each = split_idx[i];
81  CosanMatrix<NumericType> X_train = X(std::get<0>(each),Eigen::all),Y_train = Y(std::get<0>(each),Eigen::all);
82  CosanMatrix<NumericType> X_test = X(std::get<1>(each),Eigen::all),Y_test = Y(std::get<1>(each),Eigen::all);
83  estimator.fit(X_train, Y_train);
84  errors[i] = metric.GetError(estimator.predict(X_test ),Y_test);
85  }
86  return std::accumulate(errors.begin(), errors.end(), 0)/errors.size();
87  };
88 
89 
90 }
91 
92 #endif
selection.h
Cosan
Definition: CosanBO.h:29
Cosan::CosanRawData::GetInput
CosanMatrix< NumericType > GetInput()
Get a copy of CosanMatrix<NumericType> X.
Definition: CosanData.h:141
Cosan::CosanRawData::GetTarget
CosanMatrix< NumericType > GetTarget()
Get a copy of CosanMatrix<NumericType> Y.
Definition: CosanData.h:147
NumericType
double NumericType
Definition: onehotencodingTest.cpp:20
Numeric
concept Numeric
Definition: CosanBO.h:23
Cosan::crossValidation
NumericType crossValidation(CosanData< NumericType > &CRD, Model &estimator, Metric &metric, Split &split)
Definition: crossvalidation.h:32
Cosan::CosanMatrix
Eigen::Matrix< NumericType, Eigen::Dynamic, Eigen::Dynamic > CosanMatrix
Definition: CosanBO.h:37
Cosan::CosanData
Data container.
Definition: CosanData.h:546
Cosan::crossValidationParallel
NumericType crossValidationParallel(CosanData< NumericType > &CRD, Model &estimator, Metric &metric, Split &split, int nthreads=-1)
Definition: crossvalidation.h:62
Cosan::CosanRawData::GetrowsX
gsl::index GetrowsX()
Get the number of rows for X.
Definition: CosanData.h:254