00001 #ifndef MLAPI_SERIALMATRIX_H
00002 #define MLAPI_SERIALMATRIX_H
00003
00013
00014
00015
00016
00017
00018 #include "ml_common.h"
00019
00020 #include "ml_include.h"
00021
00022 #include "ml_comm.h"
00023 #include "MLAPI_Error.h"
00024 #include "MLAPI_Space.h"
00025 #include "MLAPI_Operator.h"
00026 #include "Epetra_Vector.h"
00027 #include "Epetra_RowMatrix.h"
00028 #include "Teuchos_RefCountPtr.hpp"
00029 #include <iomanip>
00030
00031 namespace MLAPI {
00032
00033 class Epetra_SerialMatrix : public Epetra_RowMatrix {
00034
00035 public:
00036
00037 Epetra_SerialMatrix(const Space& RowSpace, const Space& ColSpace)
00038 {
00039 NumMyRows_ = RowSpace.GetNumMyElements();
00040 NumMyCols_ = ColSpace.GetNumMyElements();
00041
00042 NumMyNonzeros_ = 0;
00043 NumMyDiagonals_ = 0;
00044
00045 if (GetNumProcs() != 1)
00046 ML_THROW("Class SerialMatrix can only be used for serial computations.", -1);
00047
00048 RowMap_ = Teuchos::rcp(new Epetra_Map(NumMyRows_,0,GetEpetra_Comm()));
00049 ColMap_ = Teuchos::rcp(new Epetra_Map(NumMyCols_,0,GetEpetra_Comm()));
00050
00051 ptr_.resize(NumMyRows_);
00052 }
00053
00054 virtual int NumMyRowEntries(int MyRow, int & NumEntries) const
00055 {
00056 #ifdef MLAPI_CHECK
00057 if (MyRow < 0 || MyRow >= NumMyRows())
00058 ML_THROW("Requested not valid row (" + GetString(MyRow) +").", -1);
00059 #endif
00060 NumEntries = ptr_[MyRow].size();
00061
00062 return(0);
00063 }
00064
00065 virtual int MaxNumEntries() const
00066 {
00067 int res = 0, res_i = 0;
00068
00069 for (int i = 0 ; i < NumMyRows() ; ++i) {
00070 NumMyRowEntries(i, res_i);
00071 if (res_i > res)
00072 res = res_i;
00073 }
00074
00075 return(res);
00076 }
00077
00078 virtual int ExtractMyRowCopy(int MyRow, int Length, int & NumEntries,
00079 double *Values, int * Indices) const
00080 {
00081 NumMyRowEntries(MyRow, NumEntries);
00082 if (Length < NumEntries) ML_CHK_ERR(-1);
00083 if (MyRow < 0 || MyRow >= NumMyRows())
00084 ML_CHK_ERR(-2);
00085
00086 int count = 0;
00087 for (where_ = ptr_[MyRow].begin() ; where_ != ptr_[MyRow].end() ; ++where_) {
00088 Indices[count] = where_->first;
00089 Values[count] = where_->second;
00090 ++count;
00091 }
00092 return(0);
00093 }
00094
00095 virtual int ExtractDiagonalCopy(Epetra_Vector & Diagonal) const
00096 {
00097 #ifdef MLAPI_CHECK
00098 if (!Diagonal.Map().SameAs(RowMatrixRowMap()))
00099 ML_CHK_ERR(-1);
00100 #endif
00101
00102 Diagonal.PutScalar(0.0);
00103
00104 for (int i = 0 ; i < NumMyRows() ; ++i) {
00105 for (where_ = ptr_[i].begin() ; where_ != ptr_[i].end() ; ++where_) {
00106 if (where_->first == i) {
00107 Diagonal[i] = where_->second;
00108 break;
00109 }
00110 }
00111 }
00112 return(0);
00113 }
00114
00115 virtual int Multiply(bool TransA, const Epetra_MultiVector& X,
00116 Epetra_MultiVector& Y) const
00117 {
00118
00119 Y.PutScalar(0.0);
00120
00121 if (!TransA) {
00122 for (int v = 0 ; v < X.NumVectors() ; ++v) {
00123 for (int i = 0 ; i < NumMyRows() ; ++i) {
00124 for (where_ = ptr_[i].begin() ; where_ != ptr_[i].end() ; ++where_) {
00125 Y[v][i] += (where_->second) * X[v][where_->first];
00126 }
00127 }
00128 }
00129 }
00130 else {
00131 for (int v = 0 ; v < X.NumVectors() ; ++v) {
00132 for (int i = 0 ; i < NumMyRows() ; ++i) {
00133 for (where_ = ptr_[i].begin() ; where_ != ptr_[i].end() ; ++where_) {
00134 Y[v][where_->first] += (where_->second) * X[v][i];
00135 }
00136 }
00137 }
00138 }
00139
00140 return(0);
00141 }
00142
00143 virtual int Solve(bool Upper, bool Trans, bool UnitDiagonal, const Epetra_MultiVector& X,
00144 Epetra_MultiVector& Y) const
00145 {
00146 ML_CHK_ERR(-1);
00147 }
00148
00149 virtual int InvRowSums(Epetra_Vector& x) const
00150 {
00151 ML_CHK_ERR(-1);
00152 }
00153
00154 virtual int LeftScale(const Epetra_Vector& x)
00155 {
00156 ML_CHK_ERR(-1);
00157 }
00158
00159 virtual int InvColSums(Epetra_Vector& x) const
00160 {
00161 ML_CHK_ERR(-1);
00162 }
00163
00164 virtual int RightScale(const Epetra_Vector& x)
00165 {
00166 ML_CHK_ERR(-1);
00167 }
00168
00169 virtual bool Filled() const
00170 {
00171 return(true);
00172 }
00173
00174 virtual double NormInf() const
00175 {
00176 ML_CHK_ERR(-1);
00177 }
00178
00179 virtual double NormOne() const
00180 {
00181 ML_CHK_ERR(-1);
00182 }
00183
00184 virtual int NumGlobalNonzeros() const
00185 {
00186 return(NumMyNonzeros_);
00187 }
00188
00189 virtual int NumGlobalRows() const
00190 {
00191 return(NumMyRows_);
00192 }
00193
00194 virtual int NumGlobalCols() const
00195 {
00196 return(NumMyCols_);
00197 }
00198
00199 virtual int NumGlobalDiagonals() const
00200 {
00201 return(NumMyDiagonals_);
00202 }
00203
00204 virtual int NumMyNonzeros() const
00205 {
00206 return(NumMyNonzeros_);
00207 }
00208
00209 virtual int NumMyRows() const
00210 {
00211 return(NumMyRows_);
00212 }
00213
00214 virtual int NumMyCols() const
00215 {
00216 return(NumMyCols_);
00217 }
00218
00219 virtual int NumMyDiagonals() const
00220 {
00221 return(NumMyDiagonals_);
00222 }
00223
00224 virtual bool LowerTriangular() const
00225 {
00226 return(false);
00227 }
00228
00229 virtual bool UpperTriangular() const
00230 {
00231 return(false);
00232 }
00233
00234 virtual const Epetra_Map & RowMatrixRowMap() const
00235 {
00236 return(*(RowMap_.get()));
00237 }
00238
00239 virtual const Epetra_Map & RowMatrixColMap() const
00240 {
00241 return(*(ColMap_.get()));
00242 }
00243
00244 virtual const Epetra_Import * RowMatrixImporter() const
00245 {
00246 return(0);
00247 }
00248
00249 virtual const Epetra_Map& OperatorDomainMap() const
00250 {
00251 return(*(ColMap_.get()));
00252 }
00253
00254 virtual const Epetra_Map& OperatorRangeMap() const
00255 {
00256 return(*(RowMap_.get()));
00257 }
00258
00259 virtual const Epetra_Map& Map() const
00260 {
00261 return(*(ColMap_.get()));
00262 }
00263
00265
00266 virtual int SetUseTranspose(bool)
00267 {
00268 ML_CHK_ERR(-1);
00269 }
00270
00271 virtual int Apply(const Epetra_MultiVector& X, Epetra_MultiVector& Y) const
00272 {
00273 return(Multiply(false, X, Y));
00274 }
00275
00276 virtual int ApplyInverse(const Epetra_MultiVector& X,
00277 Epetra_MultiVector& Y) const
00278 {
00279 ML_CHK_ERR(-1);
00280 }
00281
00282 virtual const char* Label() const
00283 {
00284 return("Epetra_SerialMatrix");
00285 }
00286
00287 virtual bool UseTranspose() const
00288 {
00289 return(false);
00290 }
00291
00292 virtual bool HasNormInf() const
00293 {
00294 return(false);
00295 }
00296
00297 virtual const Epetra_Comm& Comm() const
00298 {
00299 return(GetEpetra_Comm());
00300 }
00301
00302 inline double& operator()(const int row, const int col)
00303 {
00304 #ifdef MLAPI_CHECK
00305 if (row < 0 || row >= NumMyRows())
00306 ML_THROW("Requested not valid row (" + GetString(row) +").", -1);
00307 if (col < 0 || row >= NumMyCols())
00308 ML_THROW("Requested not valid column (" + GetString(col) +").", -1);
00309 #endif
00310 where_ = ptr_[row].find(col);
00311
00312 if (where_ != ptr_[row].end())
00313
00314 return(where_->second);
00315 else {
00316 ptr_[row][col] = 0.0;
00317
00318 ++NumMyNonzeros_;
00319
00320 if (row == col)
00321 ++NumMyDiagonals_;
00322
00323 return(ptr_[row][col]);
00324 }
00325 }
00326
00327 private:
00328
00329 Epetra_SerialMatrix(const Epetra_SerialMatrix& rhs)
00330 {
00331 }
00332
00333 Epetra_SerialMatrix& operator=(const Epetra_SerialMatrix& rhs)
00334 {
00335 return(*this);
00336 }
00337
00338 int NumMyRows_;
00339 int NumMyCols_;
00340 int NumMyDiagonals_;
00341 int NumMyNonzeros_;
00342
00343 mutable std::map<int,double>::iterator where_;
00344 mutable std::vector<std::map<int,double> > ptr_;
00345
00346 Teuchos::RefCountPtr<Epetra_Map> RowMap_;
00347 Teuchos::RefCountPtr<Epetra_Map> ColMap_;
00348
00349 };
00350
00351 class SerialMatrix : public Operator
00352 {
00353 public:
00354 SerialMatrix()
00355 {
00356 Matrix_ = 0;
00357 }
00358
00359 SerialMatrix& operator()(const SerialMatrix& rhs)
00360 {
00361 Matrix_ = rhs.Matrix_;
00362 Operator::operator=(rhs);
00363 return(*this);
00364 }
00365
00366 SerialMatrix(const Space& RowSpace, const Space& ColSpace)
00367 {
00368 Matrix_ = new Epetra_SerialMatrix(RowSpace, ColSpace);
00369
00370 Reshape(RowSpace, ColSpace, Matrix_, true);
00371 }
00372
00373 inline double& operator()(const int row, const int col)
00374 {
00375 return((*Matrix_)(row, col));
00376 }
00377
00378 std::ostream& Print(std::ostream& os, const bool verbose = true) const
00379 {
00380 int Length = Matrix_->MaxNumEntries();
00381 std::vector<double> Values(Length);
00382 std::vector<int> Indices(Length);
00383
00384 os << endl;
00385 os << "*** MLAPI::SerialMatrix ***" << endl;
00386 os << "Label = " << GetLabel() << endl;
00387 os << "Number of rows = " << Matrix_->NumMyRows() << endl;
00388 os << "Number of columns = " << Matrix_->NumMyCols() << endl;
00389 os << endl;
00390 os.width(10); os << "row ID";
00391 os.width(10); os << "col ID";
00392 os.width(30); os << "value";
00393 os << endl;
00394 os << endl;
00395
00396 for (int i = 0 ; i < Matrix_->NumMyRows() ; ++i) {
00397 int NnzRow = 0;
00398 Matrix_->ExtractMyRowCopy(i, Length, NnzRow, &Values[0], &Indices[0]);
00399 for (int j = 0 ; j < NnzRow ; ++j) {
00400 os.width(10); os << i;
00401 os.width(10); os << Indices[j];
00402 os.width(30); os << Values[j];
00403 os << endl;
00404 }
00405 }
00406 return(os);
00407 }
00408
00409 private:
00410 Epetra_SerialMatrix* Matrix_;
00411 };
00412
00413 }
00414
00415 #endif // ifndef MLAPI_SERIALMATRIX_H