Cosan  1.0
Data Analytics Library
kfold.h
Go to the documentation of this file.
1 //
2 // Created by Xinyu Zhang on 4/6/21.
3 //
4 // TODO: omp_set_num_threads()
5 #ifndef COSAN_KFOLD_H
6 #define COSAN_KFOLD_H
8 namespace Cosan {
9  /**
10  * https://en.wikipedia.org/wiki/Cross-validation_(statistics)
11  */
12  class KFold : public Splitter {
13  public:
14  KFold() : Splitter() {}
15  KFold(gsl::index kfoldnumber) : Splitter(kfoldnumber) {}
16  KFold(gsl::index nrows, gsl::index kfoldnumber): Splitter(nrows,kfoldnumber){}
17  void SetSplit(gsl::index nrows) {
18  if (nrows <= KFoldNumber) {
19  throw SmallRows;
20  }
21  std::vector <gsl::index> idx(nrows);
22  std::iota(idx.begin(), idx.end(), 0);
23  gsl::index foldSize = nrows / KFoldNumber;
24  for (gsl::index i = 0; i < KFoldNumber; i++) {
25  std::vector <gsl::index> testidx(foldSize), trainidx;
26  if (i == KFoldNumber - 1) {
27  testidx.resize(nrows - i * foldSize);
28  }
29  fmt::print("Current Index is {:}, trainidx size:{:}, testidx size:{:}\n",
30  i, nrows - foldSize, foldSize);
31  std::iota(testidx.begin(), testidx.end(), i * foldSize);
32  std::set_difference(idx.begin(), idx.end(), testidx.begin(), testidx.end(),
33  std::inserter(trainidx, trainidx.begin()));
34 
35  split_batch.push_back({trainidx, testidx});
36 
37  }
38  }
39 
40 // void SetSplit(gsl::index nrows,gsl::index kfoldnumber = KFoldNumber ){
41 // if (nrows<=kfoldnumber){
42 // throw SmallRows;
43 // }
44 // KFoldNumber = kfoldnumber;
45 // std::vector<gsl::index> idx(nrows);
46 // std::iota(idx.begin(), idx.end(), 0);
47 // gsl::index foldSize = nrows/kfoldnumber;
48 // for (gsl::index i = 0;i<kfoldnumber;i++){
49 // std::vector<gsl::index> testidx(foldSize),trainidx;
50 // if (i==kfoldnumber-1){
51 // testidx.resize(nrows-i*foldSize);
52 // }
53 // fmt::print("Current Index is {:}, trainidx size:{:}, testidx size:{:}\n",
54 // i,nrows-foldSize,foldSize);
55 // std::iota(testidx.begin(), testidx.end(), i*foldSize);
56 // std::set_difference(idx.begin(), idx.end(), testidx.begin(), testidx.end(),
57 // std::inserter(trainidx, trainidx.begin()));
58 //
59 // split_batch.push_back({trainidx,testidx});
60 //
61 // }
62 // }
63  std::vector <std::tuple<std::vector < gsl::index>, std::vector<gsl::index>> > &
64 
65  GetSplit() { return split_batch; }
66 
67  private:
68  std::vector <std::tuple<std::vector < gsl::index>, std::vector<gsl::index>> >
70  };
71 
72 // class KFoldParallel : public Splitter {
73 // public:
74 // KFoldParallel() : Splitter() {
75 // KFoldNumber = 5;
76 // }
77 //
78 // KFoldParallel(gsl::index kfoldnumber) : Splitter() {
79 // KFoldNumber = kfoldnumber;
80 // }
81 //
82 // KFoldParallel(gsl::index nrows, gsl::index kfoldnumber) : Splitter(nrows, kfoldnumber) {
83 // KFoldNumber = kfoldnumber;
84 // SetSplit(nrows);
85 // }
86 //
87 // void SetSplit(gsl::index nrows) {
88 // if (nrows <= KFoldNumber) {
89 // throw SmallRows;
90 // }
91 // std::vector <gsl::index> idx(nrows);
92 // std::iota(idx.begin(), idx.end(), 0);
93 // gsl::index foldSize = nrows / KFoldNumber;
94 //#pragma omp parallel for
95 // for (gsl::index i = 0; i < KFoldNumber; i++) {
96 // std::vector <gsl::index> testidx(foldSize), trainidx;
97 // if (i == KFoldNumber - 1) {
98 // testidx.resize(nrows - i * foldSize);
99 // }
100 // fmt::print("Current Index is {:}, trainidx size:{:}, testidx size:{:}\n",
101 // i, nrows - foldSize, foldSize);
102 // std::iota(testidx.begin(), testidx.end(), i * foldSize);
103 // std::set_difference(idx.begin(), idx.end(), testidx.begin(), testidx.end(),
104 // std::inserter(trainidx, trainidx.begin()));
105 //
106 // split_batch.push_back({trainidx, testidx});
107 //
108 // }
109 // }
110 //
111 //// void SetSplit(gsl::index nrows,gsl::index kfoldnumber){
112 //// if (nrows<=kfoldnumber){
113 //// throw SmallRows;
114 //// }
115 //// KFoldNumber = kfoldnumber;
116 //// std::vector<gsl::index> idx(nrows);
117 //// std::iota(idx.begin(), idx.end(), 0);
118 //// gsl::index foldSize = nrows/kfoldnumber;
119 //// #pragma omp parallel for
120 //// for (gsl::index i = 0;i<kfoldnumber;i++){
121 //// std::vector<gsl::index> testidx(foldSize),trainidx;
122 //// if (i==kfoldnumber-1){
123 //// testidx.resize(nrows-i*foldSize);
124 //// }
125 //// fmt::print("Current Index is {:}, trainidx size:{:}, testidx size:{:}\n",
126 //// i,nrows-foldSize,foldSize);
127 //// std::iota(testidx.begin(), testidx.end(), i*foldSize);
128 //// std::set_difference(idx.begin(), idx.end(), testidx.begin(), testidx.end(),
129 //// std::inserter(trainidx, trainidx.begin()));
130 ////
131 //// split_batch.push_back({trainidx,testidx});
132 ////
133 //// }
134 // std::vector <std::tuple<std::vector < gsl::index>, std::vector<gsl::index>> > &
135 //
136 // GetSplit() { return split_batch; }
137 //
138 // private:
139 // std::vector <std::tuple<std::vector < gsl::index>, std::vector<gsl::index>> >
140 // split_batch;
141 // };
142 
143  class KFoldParallel : public Splitter {
144  public:
146  KFoldParallel(gsl::index kfoldnumber) : Splitter(kfoldnumber) {}
147  KFoldParallel(gsl::index nrows, gsl::index kfoldnumber): Splitter(nrows,kfoldnumber){}
148  void SetSplit(gsl::index nrows){
149  if (nrows <= KFoldNumber) {
150  throw SmallRows;
151  }
152  std::vector <gsl::index> idx(nrows);
153  std::iota(idx.begin(), idx.end(), 0);
154  gsl::index foldSize = nrows / KFoldNumber;
155 // std::mutex mylock;
156  split_batch.resize(KFoldNumber);
157  #pragma omp parallel for
158  for (gsl::index i = 0; i < KFoldNumber; i++) {
159  std::vector <gsl::index> testidx(foldSize), trainidx;
160  if (i == KFoldNumber - 1) {
161  testidx.resize(nrows - i * foldSize);
162  }
163  fmt::print(
164  "Current Index is {:}, the current thread num is {:}, total number of threads {:}. trainidx size:{:}, testidx size:{:}\n",
165  i, omp_get_thread_num(), omp_get_num_threads(), nrows - foldSize, foldSize);
166  std::iota(testidx.begin(), testidx.end(), i * foldSize);
167  std::set_difference(idx.begin(), idx.end(), testidx.begin(), testidx.end(),
168  std::inserter(trainidx, trainidx.begin()));
169  split_batch[i] = {trainidx, testidx};
170  }
171  }
172  std::vector <std::tuple<std::vector < gsl::index>, std::vector<gsl::index>> > & GetSplit() { return split_batch; }
173 
174  private:
175  std::vector <std::tuple<std::vector < gsl::index>, std::vector<gsl::index>> > split_batch;
176  };
177 
178 }
179 
180 
181 #endif //COSAN_KFOLD_H
selection.h
Cosan
Definition: CosanBO.h:29
Cosan::SmallRows
Cosan::TooSmallSizeException SmallRows
Cosan::KFold::KFold
KFold()
Definition: kfold.h:14
Cosan::KFold::split_batch
std::vector< std::tuple< std::vector< gsl::index >, std::vector< gsl::index > > > split_batch
Definition: kfold.h:69
Cosan::KFold::KFold
KFold(gsl::index nrows, gsl::index kfoldnumber)
Definition: kfold.h:16
Cosan::KFoldParallel
Definition: kfold.h:143
Cosan::KFold
Definition: kfold.h:12
Cosan::KFoldParallel::KFoldParallel
KFoldParallel(gsl::index kfoldnumber)
Definition: kfold.h:146
Cosan::KFoldParallel::GetSplit
std::vector< std::tuple< std::vector< gsl::index >, std::vector< gsl::index > > > & GetSplit()
Definition: kfold.h:172
Cosan::KFoldParallel::split_batch
std::vector< std::tuple< std::vector< gsl::index >, std::vector< gsl::index > > > split_batch
Definition: kfold.h:175
Cosan::KFold::KFold
KFold(gsl::index kfoldnumber)
Definition: kfold.h:15
Cosan::KFold::SetSplit
void SetSplit(gsl::index nrows)
Definition: kfold.h:17
Cosan::KFoldParallel::KFoldParallel
KFoldParallel(gsl::index nrows, gsl::index kfoldnumber)
Definition: kfold.h:147
Cosan::Splitter::KFoldNumber
gsl::index KFoldNumber
Definition: selection.h:49
Cosan::Splitter
Definition: selection.h:31
Cosan::KFoldParallel::SetSplit
void SetSplit(gsl::index nrows)
Definition: kfold.h:148
Cosan::KFoldParallel::KFoldParallel
KFoldParallel()
Definition: kfold.h:145
Cosan::KFold::GetSplit
std::vector< std::tuple< std::vector< gsl::index >, std::vector< gsl::index > > > & GetSplit()
Definition: kfold.h:65