00001
00002
00003
00004
00005
00006
#ifndef __PROM_BASE_H__
00007
#define __PROM_BASE_H__
00008
00009
/*
 * Multigrid algorithm selector.
 * NOTE(review): names suggest multiplicative / full / additive variants.
 */
typedef enum {
  PROMMGMULTIPLICATIVE = 0,
  PROMMGFULL,
  PROMMGADDITIVE
} PromMGAlgo;

/*
 * Krylov solver selector.  The last enumerator, PROMNUMKSP, is the number of
 * preceding entries rather than a solver of its own (presumably used to size
 * tables, as PROMNUMPC is for PromPC_base::PromPCStrings).
 */
typedef enum {
  PROMKSPGMRES,
  PROMKSPRICHARDSON,
  PROMKSPBCGS,
  PROMKSPCR,
  PROMKSPTFQMR,
  PROMKSPPREONLY,
  PROMKSPCG,
  PROMKSPCHEB,
  PROMKSPBREZ,
  PROMKSPFGMRES,
  PROMKSPSYMMLQ,
  PROMKSPMINRES,
  PROMNUMKSP
} PromKSPType;

/*
 * Preconditioner selector.  PROMNUMPC is the entry count; it sizes
 * PromPC_base::PromPCStrings.
 */
typedef enum {
  PROMPCILU,
  PROMPCLU,
  PROMPCJACOBI,
  PROMPCASM,
  PROMPCNONE,
  PROMPCMG,
  PROMPCGAUSSSEIDEL,
  PROMPCSYMGAUSSSEIDEL,
  PROMPCNODALASM,
  PROMNUMPC
} PromPCType;
00012
00014 class PromComm_base
00015 {
00016
public:
00017
virtual MPI_Comm operator()()
const = 0;
00018
virtual MPI_Comm MPIComm()
const = 0;
00019
virtual int mype()
const = 0;
00020
virtual int np()
const = 0;
00021
virtual int getNewTag()
const = 0;
00022
virtual int getNewTag(MPI_Comm comm)
const = 0;
00023 };
00024
00026
#undef __FUNCT__
00027
#define __FUNCT__ "PromMap::PromMap_base"
00028
class PromMap_base
00029 {
00030
friend class PromMatrix;
00031
protected:
00032 PromMap_base(
int nloc,
int nloceq,
const PromComm_base &comm );
00033
public:
00034
int ResetMap(
int nloc,
int nloceq ) {
00035
int ii;
00036 MPI_Comm mpi = comm_(); assert(mpi!=MPI_COMM_NULL);
int np = comm_.np();
00037
00038 MPI_Allgather( &nloc, 1, MPI_INT, proc_gnode_+1, 1, MPI_INT, mpi );
00039
for( ii = proc_gnode_[0] = 0 ; ii < np ; ii++ )
00040 proc_gnode_[ii+1] = proc_gnode_[ii+1] + proc_gnode_[ii];
00041
00042 MPI_Allgather( &nloceq, 1, MPI_INT, proc_geq_+1, 1, MPI_INT, mpi );
00043
for( ii = proc_geq_[0] = 0 ; ii < np ; ii++ )
00044 proc_geq_[ii+1] = proc_geq_[ii+1] + proc_geq_[ii];
00045
return 0;
00046 }
00047
virtual ~PromMap_base(){
00048 PetscFree( proc_gnode_ ); PetscFree( proc_geq_ );
00049 proc_gnode_ = proc_geq_ = NULL;
00050 }
00051
virtual int nlocal() const = 0;
00052 virtual
int nglobal() const = 0;
00053 virtual
int nlocalEq() const = 0;
00054 virtual
int nglobalEq() const = 0;
00055 virtual
bool isConstNDF() const = 0;
00056 virtual
int constNDFSize() const = 0;
00057 virtual
int locID_ndf(const
int lid) const = 0;
00058
bool operator==(const PromMap_base &map)const{
00059 assert(
this!=NULL);
00060
return (nlocal()==map.nlocal() && nglobal()==map.nglobal() &&
00061 nlocalEq()==map.nlocalEq() && nglobalEq()==map.nglobalEq() );
00062 }
00063
bool operator!=(
const PromMap_base &map)
const{
00064 assert(
this!=NULL);
00065
return (nlocal()!=map.nlocal() || nglobal()!=map.nglobal() ||
00066 nlocalEq()!=map.nlocalEq() || nglobalEq()!=map.nglobalEq() );
00067 }
00068
00069
int my0Eq()const{
return proc_geq_[comm_.mype()]; }
00070
int my0Nd()const{
return proc_gnode_[comm_.mype()]; }
00071
const int * proc_gnode()
const {
return proc_gnode_; }
00072
const int * proc_geq()
const {
return proc_geq_; }
00073
protected:
00074
00075
int *proc_gnode_;
00076
int *proc_geq_;
00077
public:
00078
const PromComm_base &comm_;
00079 };
00080
00081
class PromMap;
00082
class PromCRVector;
00083
class PromSolver;
00085
00088 class PromVector_base
00089 {
00090
friend class PromSolver;
00091
friend class PromCRVector;
00092
friend class PromVector;
00093
friend class PromProjKKTSolver;
00094
public:
00095
PromVector_base(
const PromMap &RowMap) : map_(RowMap) {}
00096
virtual ~
PromVector_base(){}
00097
00098
virtual int GetArray(
double **arr )
const = 0;
00099
virtual int RestoreArray(
double **arr )
const = 0;
00100
virtual int Set(
double val ) = 0;
00101
virtual int Scale(
double val ) = 0;
00102
virtual int Sqrt() = 0;
00103
virtual int Dot(
const PromVector_base*
const b,
double*
const dot)
const=0;
00104
virtual int SetValues(
const int nrow,
int *rowp,
const double *vals,
00105
const int add_type) = 0;
00106
virtual int Norm2(
double *
const norm )
const = 0;
00107
virtual int NormMax(
double *
const norm )
const = 0;
00108
virtual int Copy(
const PromVector_base *
const pvec ) = 0;
00109
virtual int AXPBY(
double a,
double b,
const PromVector_base *
const X)=0;
00110
virtual int AYPX(
double a,
const PromVector_base *
const X ) = 0;
00111
virtual int AXPY(
double a,
const PromVector_base *
const X ) = 0;
00112
virtual int WAXPY(
double alpha,
const PromVector_base *
const X,
00113
const PromVector_base *
const Y) = 0;
00114
virtual int MAXPY(
const int nv,
double a[],
const PromVector_base *
const*X)=0;
00115
virtual int MDot(
const int nv,
const PromVector_base *
const*Xarr,
double arr[])
const=0;
00116
virtual int Assembly() = 0;
00117
virtual int PointwiseMult(
const PromVector_base *
const X,
00118
const PromVector_base *
const Y ) = 0;
00119
virtual int PointwiseDiv(
const PromVector_base *
const X,
00120
const PromVector_base *
const Y ) = 0;
00121
virtual int Reciprocal() = 0;
00122
virtual int getN()
const = 0;
00123
virtual int getn()
const = 0;
00124
virtual int my0()
const = 0;
00125
virtual double operator[](
int i)
const = 0;
00126
virtual bool isComplex()
const{
return FALSE;}
00127
virtual int Print(
char *str=NULL, FILE *file=stderr,
int limit=100000)
const{
00128
double *arr;
int n = getn()*(isComplex()?2:1);
00129 GetArray( &arr );
00130 PromContext::printArr( file, arr, (limit > n) ? n : limit,
00131 (str==NULL) ?
"vector[%d]" : str );
00132 RestoreArray( &arr );
00133
return 0;
00134 }
00135
00136
const PromMap &map_;
00137 };
00138
00139
class PromPC_base;
00141
00144 class PromMatrix_base
00145 {
00146
friend class PromMatrix;
00147
protected:
00148
PromMatrix_base(
const PromMap &map );
00149
public:
00150
virtual int Transpose(
const PromMap &map,
PromMatrix_base ** )
const = 0;
00151
virtual ~
PromMatrix_base(){ DeleteProctable(); }
00153
virtual int GetLocalNodeRow(
const int brow,
const int len,
int &ncols,
00154
int adjacs[],
double **Aij = NULL )
const = 0;
00155
virtual int GetDiagonal(
PromVector_base *work )
const = 0;
00156
virtual int MultTranspose(
const PromVector_base *
const x,
00157
PromVector_base *
const b )
const = 0;
00158
virtual int MultTransposeAdd(
const PromVector_base *
const x,
00159
const PromVector_base *
const y,
00160
PromVector_base *
const b )
const = 0;
00161
virtual int Mult(
const PromVector_base *
const x,
00162
PromVector_base *
const b )
const = 0;
00163
virtual int KSPMult(
const PromVector_base *
const x,
00164
PromVector_base *
const b )
const = 0;
00165
virtual int MultAdd(
const PromVector_base *
const x,
00166
const PromVector_base *
const y,
00167
PromVector_base *
const b )
const = 0;
00168
virtual int GetRow(
int row,
int *ncols,
int **colpp=NULL,
double **arrdat=NULL)
const = 0;
00169
virtual int RestoreRow(
int row,
int*ncols,
int**colpp=NULL,
double**arrdat=NULL)
00170
const = 0;
00171
virtual int SetValues(
const int nrow,
int *pgeqi,
const int ncol,
00172
int *colp,
double *vals,
const int add_type ) = 0;
00173
virtual int SetValuesBlocked(
const int nrow,
int *pgidi,
const int ncol,
00174
int *pcol,
double *vs,
const int add_type) = 0;
00175
virtual int ZeroEntries() = 0;
00176
virtual int ZeroRows( PromIS is,
double diags = 0.0 ) = 0;
00177
virtual int Assembly() = 0;
00178
virtual int FlushAssembly() = 0;
00179
virtual int Shift(
double val ) = 0;
00180
virtual int Scale(
double val ) = 0;
00181
virtual int Scale(
const PromVector_base *
const x,
00182
const PromVector_base *
const b ) = 0;
00183
virtual int NormFrob(
double *
const norm )
const = 0;
00184
virtual int NormInf(
double *
const norm )
const = 0;
00185
virtual int AXPY(
double val,
const PromMatrix_base *
const AA ) = 0;
00186
virtual int SetOption(
int opt ) = 0;
00187
virtual int getN()
const = 0;
00188
virtual int getn()
const = 0;
00189
virtual int getM()
const = 0;
00190
virtual int getm()
const = 0;
00191
virtual int getMNodes()
const = 0;
00192
virtual int getLocalNNZ()
const = 0;
00193
virtual int getGlobalNNZ()
const = 0;
00194
virtual MPI_Comm MPIComm()
const = 0;
00195
virtual int GetGArray(
int **arr,
int *sz )
const = 0;
00196
virtual int RestoreGArray(
int **arr )
const = 0;
00197
virtual bool isOK()
const = 0;
00198
virtual bool isAssembled()
const = 0;
00199
virtual bool isSymmetricValued()
const = 0;
00200
virtual bool isSymmetricMatrix()
const = 0;
00201
virtual int Print(
char *str=NULL, FILE *file=stderr)
const=0;
00202
virtual int ConvertToSymm()=0;
00203
virtual int ConvertFromSymm()=0;
00204
virtual int ConvertToSuperLU()=0;
00205
virtual bool isSuperLU()
const=0;
00207
virtual int ComputeLambda(
const PromOptions &opt, PromPerfMonitor &perf,
00208
int maxits,
PromPC_base *
const pc = NULL)
00209
const;
00210
virtual bool isComplex()
const{
return FALSE;}
00211
double getDRho()
const { assert(
this!=NULL);
return DRhoA_; }
00212
int SetDRho(
double r)
const{assert(
this!=NULL); DRhoA_ = r;
return 0;}
00213
bool isDirty()
const{
return dirty_data_;}
00214
int SetDirty(
bool b = TRUE)
const { dirty_data_ = b;
return 0; }
00215
bool isIndefinite()
const {
return (indefinite_ == 1); }
00216
bool isDefinite()
const {
return (indefinite_ == -1); }
00217
int SetIndefinite(
bool b = TRUE)
const { indefinite_ = (b ? 1 : -1);
return 0; }
00218
int SetIndefinite(
int i)
const { indefinite_ = i;
return 0; }
00219
int DeleteProctable();
00220
00221
00222
const PromMap &map_;
00223
00225 int numRecv_;
00226
bool pRecvDone_;
00227
int *recvSzs_;
00228
int *recvParts_;
00229 PromTable<PromTable<PromTable<double*>*>*> proci_table_gidi_tablej_;
00230 PromTable<int> *gid_ndf_;
00231
protected:
00232
mutable float DRhoA_;
00233
mutable bool dirty_data_;
00234
mutable short int indefinite_;
00235 };
00236
00237
00238
00239
00241
00244
class PromMatrix;
00245
class PromCRMatrix;
00246
class PromPC;
00247
00248
class PromPCKKT
00249 {
00250
public:
00251 PromPCKKT(
const PromMatrix *CC) : C_(CC), dpc_(NULL), r_(NULL), work_(NULL){}
00252
virtual ~PromPCKKT();
00253
virtual int Create(
const PromOptions &opts, PromPerfMonitor &perf,
00254
const PromMatrix *
const AA,
00255
const PromCRVector *
const XX,
bool ilu = FALSE ) = 0;
00256
virtual int SetOperator(
const PromCRMatrix *
const KK ) = 0;
00257
virtual int SetNewOperator(
const PromCRMatrix *
const KK ) = 0;
00258
virtual int SetNewPVec(
const PromMap &map ) = 0;
00259
virtual int SetUp();
00260
virtual int Apply(
PromPC_base *
const pc,
00261
const PromCRMatrix *
const AA,
00262
const PromCRVector *
const bb, PromCRVector *
const xx,
00263
const bool zeroguess=TRUE ) = 0;
00264
virtual const char * getPCString() const = 0;
00265 virtual
int getNumBlocks() const = 0;
00266
00267 const
PromMatrix *C_;
00268
PromPC *dpc_;
00269 protected:
00270 PromCRVector *work_;
00271 PromCRVector *r_;
00272 };
00273
00274
00275
00276
00277
00278 #undef __FUNCT__
00279 #define __FUNCT__ "class
PromBitArr"
00281
00284 class
PromBitArr
00285 {
00286
public:
00287 PromBitArr(
int sz) : isz_(sz), count_(0) {
00288 assert(
sizeof(
unsigned int) == 4);
00289
int bsz = sz/32+1;
00290 PetscMalloc( bsz*
sizeof(
unsigned int), &data_ );
00291 PetscMemzero( (
char*)data_, bsz*
sizeof(
unsigned int) );
00292 }
00293 ~PromBitArr(){ PetscFree( data_ ); }
00294
bool operator[](
const int ii)
const {
00295 assert(
this!=NULL);
00296
unsigned int a = data_[ii/32], k = ii%32, p = 32 - k - 1;
00297
unsigned int l = a << p;
00298
return ((l>>31) == 1);
00299 }
00300
int SetAt(
const int ii,
bool val){
00301 assert(
this!=NULL);
00302
unsigned int a = data_[ii/32], k = ii%32, p = 32 - k - 1;
00303
unsigned int l = a << p;
00304
bool b = ((l>>31) == 1);
00305
if( b != val ){
00306
unsigned int jj = 1; jj = jj << k;
00307
if( b ){ a -= jj; count_--; }
00308
else{ a += jj; count_++; }
00309 data_[ii/32] = a;
00310 }
00311
return 0;
00312 }
00313
int SetZero(){
00314 assert(
this!=NULL);
int bsz = isz_/32+1;
00315 PetscMemzero( (
char*)data_, bsz*
sizeof(
unsigned int) );
00316 count_ = 0;
00317
return 0;
00318 }
00319
int getSize()
const{
return isz_;}
00320
int getCount()
const{
return count_;}
00321
private:
00323 unsigned int *data_;
00324
unsigned const int isz_;
00325
unsigned int count_;
00326 };
00327
00331 class PromPCKKT_ASM :
public PromPCKKT
00332 {
00333
public:
00334
PromPCKKT_ASM(
const PromMatrix *
const CC,
00335
const int naggs, PromIS *CRIS );
00336
virtual ~
PromPCKKT_ASM();
00337
virtual int Create(
const PromOptions &opts,
00338 PromPerfMonitor &perf,
00339
const PromMatrix *
const AA,
00340
const PromCRVector *
const XX,
bool ilu = FALSE );
00341
virtual int SetOperator(
const PromCRMatrix *
const KK );
00342
virtual int SetNewOperator(
const PromCRMatrix *
const KK );
00343
virtual int SetNewPVec(
const PromMap &map );
00344
virtual int Apply(
PromPC_base *
const pc,
00345
const PromCRMatrix *
const AA,
00346
const PromCRVector *
const bb, PromCRVector *
const xx,
00347
const bool zeroguess = TRUE );
00348
virtual const char*getPCString()
const{assert(
this!=NULL);
return "KKT-ASM";}
00349
virtual int getNumBlocks()
const;
00350
protected:
00351
virtual int CreatePrimAggs(
const PromMatrix *
const AA,
bool ilu );
00352
int Apply_private(
const PromCRVector*
const bb, PromCRVector*
const xx )
const;
00353
#if defined(PROM_USE_PETSC)
00354
void *aux_mats_;
00355
void *d_scatters_;
00356
#else // petra
00357
#endif
00358
00359 PromIS *PrimIS_;
00360 PromIS *
const CRIS_;
00361
int *mask_;
00362
double *zeros_;
00363
const short int nBlocks_;
00364
short int nMask_;
00365
public:
00366
PromBitArr *lid_isInKKT_;
00367 };
00368
00372 class PromPCKKT_Shell :
public PromPCKKT_ASM
00373 {
00374
public:
00375
PromPCKKT_Shell(
const PromMatrix*
const AA,
const PromMatrix*
const CC ) :
00376
PromPCKKT_ASM( CC, 1, make_is_w_all( CC ) ){}
00377
virtual int Apply(
PromPC_base *
const,
const PromCRMatrix *
const AA,
00378
const PromCRVector *
const bb, PromCRVector *
const xx,
00379
const bool zeroguess = TRUE );
00380
virtual const char * getPCString()
const;
00381
virtual int getNumBlocks()
const {
return 1; }
00382
protected:
00383
virtual int CreatePrimAggs(
const PromMatrix *
const AA,
bool ilu );
00384
private:
00385
static PromIS *make_is_w_all(
const PromMatrix*
const mat );
00386 };
00387
00389
00392 class PromPCKKT_Seg :
public PromPCKKT
00393 {
00394
public:
00395
PromPCKKT_Seg(
const PromMatrix *
const CC,
const PromMatrix *
const CCt ) :
00396 PromPCKKT(CC), CDinvCt_(CCt) {}
00397
virtual int Create(
const PromOptions &opts,
00398 PromPerfMonitor &perf,
00399
const PromMatrix *
const AA,
00400
const PromCRVector *
const XX,
bool ilu = FALSE );
00401
virtual int SetOperator(
const PromCRMatrix *
const KK );
00402
virtual int SetNewOperator(
const PromCRMatrix *
const KK );
00403
virtual int SetNewPVec(
const PromMap &map ){
return 0; }
00404
virtual int Apply(
PromPC_base*
const,
const PromCRMatrix *
const AA,
00405
const PromCRVector *
const bb, PromCRVector *
const xx,
00406
const bool zeroguess = TRUE );
00407
virtual const char * getPCString()
const {
return "KKT-Segregated"; }
00408
virtual int getNumBlocks()
const {
return 0; }
00409
00410
protected:
00411
00412
const PromMatrix *CDinvCt_;
00413 };
00414
00415
00416
00417
00418
00419
class PromCRMatrix;
00420
class PromGrid;
00421
#undef __FUNCT__
00422
#define __FUNCT__ "class PromPC_base"
00423
00424
00427 class PromPC_base
00428 {
00429
friend class Prometheus_LinSysCore;
00430
friend class PromMG;
00431
friend class PromPCKKT_ASM;
00432
friend class Prometheus;
00433
protected:
00434
PromPC_base(
const PromOptions &opt, PromPerfMonitor &perf):
00435 options_(opt),perf_mon_(perf),numits_(1),setup_called_(0),kkt_(NULL),
00436 KKT_(NULL){}
00437
public:
00438
virtual ~
PromPC_base();
00439
00440
int SetUp();
00441
virtual int SetUp_private() = 0;
00443
virtual int Apply(
const PromVector_base *
const bb,
00444
PromVector_base *
const xx,
const bool zerox = TRUE );
00445
virtual int Apply_private(
const PromVector_base*
const bb,
00446
PromVector_base *
const xx,
00447
const bool zerox = TRUE) = 0;
00449
int SetOperator(
const PromMatrix_base *
const A);
00450
int SetNewOperator(
const PromMatrix_base *
const A);
00451
virtual int SetOperator_private(
const PromMatrix_base *
const A ) = 0;
00452
int CreateKKTSmoother(
const PromMatrix *
const AA,
const PromMatrix *
const CC,
00453
const PromCRVector *
const XX,
const int naggs, PromIS *CRIS,
00454
const PromMatrix *
const CCt );
00455
int CreateKKTSolver(
const PromMatrix *
const AA,
const PromMatrix *
const CC,
00456
const PromCRVector *
const XX );
00458
virtual const char * getPCString()
const = 0;
00459
virtual const char * getSubPCString()
const = 0;
00460
virtual int getNumBlocks()
const{
return (kkt_==NULL)?0:kkt_->getNumBlocks();}
00461
virtual int PrintLevelInfo()
const {
return 0; }
00462
virtual PromPCType getType()
const = 0;
00463
virtual int takeNonZeroGuess()
const{
return 0; }
00464
protected:
00465
virtual int SetNewOperator_private(
const PromMatrix_base *
const A ) = 0;
00466
public:
00467
bool operator==(
const PromPC_base* a )
const {
00468 assert(
this != NULL );
00469
return (a->
getType() == getType());
00470 }
00471
bool operator!=(
const PromPC_base* a )
const {
00472 assert(
this != NULL );
00473
return (a->
getType() != getType());
00474 }
00476
00477
virtual int CreateASM(
const int nblocks, PromIS *blockISs ) {
00478 PetscPrintf(MPI_COMM_SELF,
"[?]%s dummy method ??????\n",
00479
"PromPC_base::CreateASM");
00480
return 1;
00481 }
00482
virtual int SetType( PromPCType type,
const PromVector_base *
const v ){
00483 PetscPrintf(MPI_COMM_SELF,
"[?]%s dummy method ??????\n",
00484
"PromPC_base::SetType");
00485
return 1;
00486 }
00487
int SetNumIts(
int nn) { numits_ = nn;
return 0; }
00488
int getNumIts()
const {
return numits_; }
00489
00490
static const char *
const PromPCStrings[PROMNUMPC];
00491
protected:
00492
const PromOptions &options_;
00493 PromPerfMonitor &perf_mon_;
00494
const PromCRMatrix *KKT_;
00495
short int numits_;
00496
short int setup_called_;
00497
public:
00498 PromPCKKT *kkt_;
00499 };
00500
00501
#endif // __PROM_BASE_H__