00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037
00038
00039
00040
00041
00042 #include "PlayaDenseSerialMatrix.hpp"
00043 #include "PlayaDenseSerialMatrixFactory.hpp"
00044 #include "PlayaSerialVector.hpp"
00045 #include "PlayaVectorSpaceDecl.hpp"
00046 #include "PlayaVectorDecl.hpp"
00047 #include "PlayaLinearOperatorDecl.hpp"
00048 #include "Teuchos_BLAS.hpp"
00049
00050 #ifndef HAVE_TEUCHOS_EXPLICIT_INSTANTIATION
00051 #include "PlayaLinearOperatorImpl.hpp"
00052 #include "PlayaVectorImpl.hpp"
00053 #endif
00054
00055 extern "C"
00056 {
00057 void dgesv_(int *n, int *nrhs, double *a, int* lda,
00058 int *ipiv, double *b, int *ldb, int *info);
00059
00060 void dgesvd_( char* jobu, char* jobvt, int* m, int* n, double* a,
00061 int* lda, double* s, double* u, int* ldu, double* vt, int* ldvt,
00062 double* work, int* lwork, int* info );
00063 }
00064 using std::max;
00065 using std::min;
00066
00067 using namespace Playa;
00068 using namespace Teuchos;
00069
00070 using std::setw;
00071
00072 DenseSerialMatrix::DenseSerialMatrix(
00073 const VectorSpace<double>& domain,
00074 const VectorSpace<double>& range)
00075 : LinearOpWithSpaces<double>(domain, range),
00076 nRows_(range.dim()),
00077 nCols_(domain.dim()),
00078 data_(nRows_*nCols_)
00079 {}
00080
00081
00082 void DenseSerialMatrix::apply(
00083 Teuchos::ETransp transApplyType,
00084 const Vector<double>& in,
00085 Vector<double> out) const
00086 {
00087 const SerialVector* rvIn = SerialVector::getConcrete(in);
00088 SerialVector* rvOut = SerialVector::getConcrete(out);
00089
00090 Teuchos::BLAS<int, double> blas;
00091 int lda = numRows();
00092 blas.GEMV(transApplyType, numRows(), numCols(), 1.0, dataPtr(),
00093 lda, rvIn->dataPtr(), 1, 1.0, rvOut->dataPtr(), 1);
00094 }
00095
00096 void DenseSerialMatrix::addToRow(int globalRowIndex,
00097 int nElemsToInsert,
00098 const int* globalColumnIndices,
00099 const double* elementValues)
00100 {
00101 int r = globalRowIndex;
00102 for (int k=0; k<nElemsToInsert; k++)
00103 {
00104 int c = globalColumnIndices[k];
00105 double x = elementValues[k];
00106 data_[r + c*numRows()] = x;
00107 }
00108 }
00109
00110 void DenseSerialMatrix::zero()
00111 {
00112 for (int i=0; i<data_.size(); i++) data_[i] = 0.0;
00113 }
00114
00115
00116 void DenseSerialMatrix::print(std::ostream& os) const
00117 {
00118 if (numCols() <= 4)
00119 {
00120 for (int i=0; i<numRows(); i++)
00121 {
00122 for (int j=0; j<numCols(); j++)
00123 {
00124 os << setw(16) << data_[i+numRows()*j];
00125 }
00126 os << std::endl;
00127 }
00128 }
00129 else
00130 {
00131 for (int i=0; i<numRows(); i++)
00132 {
00133 for (int j=0; j<numCols(); j++)
00134 {
00135 os << setw(6) << i << setw(6) << j << setw(20) << data_[i+numRows()*j]
00136 << std::endl;
00137 }
00138 }
00139 }
00140 }
00141
00142 void DenseSerialMatrix::setRow(int row, const Array<double>& rowVals)
00143 {
00144 TEUCHOS_TEST_FOR_EXCEPT(rowVals.size() != numCols());
00145 TEUCHOS_TEST_FOR_EXCEPT(row < 0);
00146 TEUCHOS_TEST_FOR_EXCEPT(row >= numRows());
00147
00148 for (int i=0; i<rowVals.size(); i++)
00149 {
00150 data_[row+numRows()*i] = rowVals[i];
00151 }
00152 }
00153
00154
00155 namespace Playa
00156 {
00157
00158
00159 SolverState<double> denseSolve(const LinearOperator<double>& A,
00160 const Vector<double>& b,
00161 Vector<double>& x)
00162 {
00163 const DenseSerialMatrix* Aptr
00164 = dynamic_cast<const DenseSerialMatrix*>(A.ptr().get());
00165 TEUCHOS_TEST_FOR_EXCEPT(Aptr==0);
00166
00167 DenseSerialMatrix tmp = *Aptr;
00168
00169 x = b.copy();
00170 SerialVector* xptr
00171 = dynamic_cast<SerialVector*>(x.ptr().get());
00172
00173 int N = Aptr->numRows();
00174 int nRHS = 1;
00175 int LDA = N;
00176 Array<int> iPiv(N);
00177 int LDB = N;
00178 int info = 0;
00179 dgesv_(&N, &nRHS, tmp.dataPtr(), &LDA, &(iPiv[0]), xptr->dataPtr(),
00180 &LDB, &info);
00181
00182 if (info == 0)
00183 {
00184 return SolverState<double>(SolveConverged, "solve OK",
00185 0, 0.0);
00186 }
00187 else
00188 {
00189 return SolverState<double>(SolveCrashed, "solve crashed with dgesv info="
00190 + Teuchos::toString(info),
00191 0, 0.0);
00192 }
00193 }
00194
00195
00196 void denseSVD(const LinearOperator<double>& A,
00197 LinearOperator<double>& U,
00198 Vector<double>& Sigma,
00199 LinearOperator<double>& Vt)
00200 {
00201 VectorSpace<double> mSpace = A.range();
00202 VectorSpace<double> nSpace = A.domain();
00203
00204 const DenseSerialMatrix* Aptr
00205 = dynamic_cast<const DenseSerialMatrix*>(A.ptr().get());
00206 TEUCHOS_TEST_FOR_EXCEPT(Aptr==0);
00207
00208 DenseSerialMatrix ATmp = *Aptr;
00209
00210 int M = ATmp.numRows();
00211 int N = ATmp.numCols();
00212 int S = min(M, N);
00213
00214 VectorSpace<double> sSpace;
00215 if (S==M) sSpace = mSpace;
00216 else sSpace = nSpace;
00217
00218 Sigma = sSpace.createMember();
00219 SerialVector* sigPtr
00220 = dynamic_cast<SerialVector*>(Sigma.ptr().get());
00221 TEUCHOS_TEST_FOR_EXCEPT(sigPtr==0);
00222
00223 DenseSerialMatrixFactory umf(sSpace, mSpace);
00224 DenseSerialMatrixFactory vtmf(nSpace, sSpace);
00225
00226 U = umf.createMatrix();
00227 Vt = vtmf.createMatrix();
00228
00229 DenseSerialMatrix* UPtr
00230 = dynamic_cast<DenseSerialMatrix*>(U.ptr().get());
00231 TEUCHOS_TEST_FOR_EXCEPT(UPtr==0);
00232
00233 DenseSerialMatrix* VtPtr
00234 = dynamic_cast<DenseSerialMatrix*>(Vt.ptr().get());
00235 TEUCHOS_TEST_FOR_EXCEPT(VtPtr==0);
00236
00237 double* uData = UPtr->dataPtr();
00238 double* vtData = VtPtr->dataPtr();
00239 double* aData = ATmp.dataPtr();
00240 double* sData = sigPtr->dataPtr();
00241
00242 char jobu = 'S';
00243 char jobvt = 'S';
00244
00245 int LDA = M;
00246 int LDU = M;
00247 int LDVT = S;
00248
00249 int LWORK = max(1, max(3*min(M,N)+max(M,N), 5*min(M,N)));
00250 Array<double> work(LWORK);
00251
00252 int info = 0;
00253
00254 dgesvd_(&jobu, &jobvt, &M, &N, aData, &LDA, sData, uData, &LDU,
00255 vtData, &LDVT, &(work[0]), &LWORK, &info);
00256
00257 TEUCHOS_TEST_FOR_EXCEPTION(info != 0, std::runtime_error,
00258 "dgesvd failed with error code info=" << info);
00259
00260
00261
00262 }
00263
00264 }