PlayaDenseSerialMatrix.cpp
Go to the documentation of this file.
00001 /* @HEADER@ */
00002 // ************************************************************************
00003 // 
00004 //                 Playa: Programmable Linear Algebra
00005 //                 Copyright 2012 Sandia Corporation
00006 // 
00007 // Under the terms of Contract DE-AC04-94AL85000 with Sandia Corporation,
00008 // the U.S. Government retains certain rights in this software.
00009 //
00010 // Redistribution and use in source and binary forms, with or without
00011 // modification, are permitted provided that the following conditions are
00012 // met:
00013 //
00014 // 1. Redistributions of source code must retain the above copyright
00015 // notice, this list of conditions and the following disclaimer.
00016 //
00017 // 2. Redistributions in binary form must reproduce the above copyright
00018 // notice, this list of conditions and the following disclaimer in the
00019 // documentation and/or other materials provided with the distribution.
00020 //
00021 // 3. Neither the name of the Corporation nor the names of the
00022 // contributors may be used to endorse or promote products derived from
00023 // this software without specific prior written permission.
00024 //
00025 // THIS SOFTWARE IS PROVIDED BY SANDIA CORPORATION "AS IS" AND ANY
00026 // EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
00027 // IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
00028 // PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL SANDIA CORPORATION OR THE
00029 // CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
00030 // EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
00031 // PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
00032 // PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
00033 // LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
00034 // NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
00035 // SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
00036 //
00037 // Questions? Contact Kevin Long (kevin.long@ttu.edu)
00038 // 
00039 
00040 /* @HEADER@ */
00041 
00042 #include "PlayaDenseSerialMatrix.hpp"
00043 #include "PlayaDenseSerialMatrixFactory.hpp"
00044 #include "PlayaSerialVector.hpp"
00045 #include "PlayaVectorSpaceDecl.hpp"  
00046 #include "PlayaVectorDecl.hpp"
00047 #include "PlayaLinearOperatorDecl.hpp"
00048 #include "Teuchos_BLAS.hpp"
00049 
00050 #ifndef HAVE_TEUCHOS_EXPLICIT_INSTANTIATION
00051 #include "PlayaLinearOperatorImpl.hpp"
00052 #include "PlayaVectorImpl.hpp"
00053 #endif
00054 
extern "C"
{
  /* LAPACK dgesv: solve the general dense system A*X = B via an LU
   * factorization with partial pivoting. Following the Fortran calling
   * convention, every argument (including scalars) is passed by pointer. */
  void dgesv_(int* N, int* NRHS, double* A, int* LDA,
    int* IPIV, double* B, int* LDB, int* INFO);

  /* LAPACK dgesvd: singular value decomposition A = U * diag(S) * VT of a
   * general dense M x N matrix.
   * NOTE(review): some Fortran compilers pass hidden length arguments for
   * CHARACTER parameters (JOBU, JOBVT); confirm this prototype matches the
   * ABI of the LAPACK actually linked. */
  void dgesvd_(char* JOBU, char* JOBVT, int* M, int* N, double* A,
    int* LDA, double* S, double* U, int* LDU, double* VT, int* LDVT,
    double* WORK, int* LWORK, int* INFO);
}
00064 using std::max;
00065 using std::min;
00066 
00067 using namespace Playa;
00068 using namespace Teuchos;
00069 
00070 using std::setw;
00071 
00072 DenseSerialMatrix::DenseSerialMatrix(
00073   const VectorSpace<double>& domain,
00074   const VectorSpace<double>& range)
00075   : LinearOpWithSpaces<double>(domain, range),
00076     nRows_(range.dim()),
00077     nCols_(domain.dim()),
00078     data_(nRows_*nCols_)
00079 {}
00080 
00081 
00082 void DenseSerialMatrix::apply(
00083   Teuchos::ETransp transApplyType,
00084   const Vector<double>& in,
00085   Vector<double> out) const
00086 {
00087   const SerialVector* rvIn = SerialVector::getConcrete(in);
00088   SerialVector* rvOut = SerialVector::getConcrete(out);
00089 
00090   Teuchos::BLAS<int, double> blas;
00091   int lda = numRows();
00092   blas.GEMV(transApplyType, numRows(), numCols(), 1.0, dataPtr(), 
00093     lda, rvIn->dataPtr(), 1, 1.0, rvOut->dataPtr(), 1);
00094 }
00095 
00096 void DenseSerialMatrix::addToRow(int globalRowIndex,
00097   int nElemsToInsert,
00098   const int* globalColumnIndices,
00099   const double* elementValues)
00100 {
00101   int r = globalRowIndex;
00102   for (int k=0; k<nElemsToInsert; k++)
00103   {
00104     int c = globalColumnIndices[k];
00105     double x = elementValues[k];
00106     data_[r + c*numRows()] = x;
00107   }
00108 }
00109 
00110 void DenseSerialMatrix::zero()
00111 {
00112   for (int i=0; i<data_.size(); i++) data_[i] = 0.0;
00113 }
00114 
00115 
00116 void DenseSerialMatrix::print(std::ostream& os) const
00117 {
00118   if (numCols() <= 4)
00119   {
00120     for (int i=0; i<numRows(); i++)
00121     {
00122       for (int j=0; j<numCols(); j++)
00123       {
00124         os << setw(16) << data_[i+numRows()*j];
00125       }
00126       os << std::endl;
00127     }
00128   }
00129   else
00130   {
00131     for (int i=0; i<numRows(); i++)
00132     {
00133       for (int j=0; j<numCols(); j++)
00134       {
00135         os << setw(6) << i << setw(6) << j << setw(20) << data_[i+numRows()*j]
00136            << std::endl;
00137       }
00138     }
00139   }
00140 }
00141 
00142 void DenseSerialMatrix::setRow(int row, const Array<double>& rowVals)
00143 {
00144   TEUCHOS_TEST_FOR_EXCEPT(rowVals.size() != numCols());
00145   TEUCHOS_TEST_FOR_EXCEPT(row < 0);
00146   TEUCHOS_TEST_FOR_EXCEPT(row >= numRows());
00147 
00148   for (int i=0; i<rowVals.size(); i++)
00149   {
00150     data_[row+numRows()*i] = rowVals[i];
00151   }
00152 }
00153 
00154 
00155 namespace Playa
00156 {
00157 
00158 
00159 SolverState<double> denseSolve(const LinearOperator<double>& A,
00160   const Vector<double>& b,
00161   Vector<double>& x)
00162 {
00163   const DenseSerialMatrix* Aptr 
00164     = dynamic_cast<const DenseSerialMatrix*>(A.ptr().get());
00165   TEUCHOS_TEST_FOR_EXCEPT(Aptr==0);
00166   /* make a working copy, because dgesv will overwrite the matrix */
00167   DenseSerialMatrix tmp = *Aptr;
00168   /* Allocate a vector for the solution */
00169   x = b.copy();
00170   SerialVector* xptr 
00171     = dynamic_cast<SerialVector*>(x.ptr().get());
00172   
00173   int N = Aptr->numRows();
00174   int nRHS = 1;
00175   int LDA = N;
00176   Array<int> iPiv(N);
00177   int LDB = N;
00178   int info = 0;
00179   dgesv_(&N, &nRHS, tmp.dataPtr(), &LDA, &(iPiv[0]), xptr->dataPtr(),
00180     &LDB, &info);
00181 
00182   if (info == 0)
00183   {
00184     return SolverState<double>(SolveConverged, "solve OK",
00185       0, 0.0);
00186   }
00187   else 
00188   {
00189     return SolverState<double>(SolveCrashed, "solve crashed with dgesv info="
00190       + Teuchos::toString(info),
00191       0, 0.0);
00192   }
00193 }
00194 
00195 
00196 void denseSVD(const LinearOperator<double>& A,
00197   LinearOperator<double>& U,  
00198   Vector<double>& Sigma,
00199   LinearOperator<double>& Vt)
00200 {
00201   VectorSpace<double> mSpace = A.range();
00202   VectorSpace<double> nSpace = A.domain();
00203 
00204   const DenseSerialMatrix* Aptr 
00205     = dynamic_cast<const DenseSerialMatrix*>(A.ptr().get());
00206   TEUCHOS_TEST_FOR_EXCEPT(Aptr==0);
00207   /* make a working copy, because dgesvd will overwrite the matrix */
00208   DenseSerialMatrix ATmp = *Aptr;
00209 
00210   int M = ATmp.numRows();
00211   int N = ATmp.numCols();
00212   int S = min(M, N);
00213   
00214   VectorSpace<double> sSpace;
00215   if (S==M) sSpace = mSpace;
00216   else sSpace = nSpace;
00217 
00218   Sigma = sSpace.createMember();
00219   SerialVector* sigPtr
00220     = dynamic_cast<SerialVector*>(Sigma.ptr().get());
00221   TEUCHOS_TEST_FOR_EXCEPT(sigPtr==0);
00222 
00223   DenseSerialMatrixFactory umf(sSpace, mSpace);
00224   DenseSerialMatrixFactory vtmf(nSpace, sSpace);
00225   
00226   U = umf.createMatrix();
00227   Vt = vtmf.createMatrix();
00228 
00229   DenseSerialMatrix* UPtr 
00230     = dynamic_cast<DenseSerialMatrix*>(U.ptr().get());
00231   TEUCHOS_TEST_FOR_EXCEPT(UPtr==0);
00232 
00233   DenseSerialMatrix* VtPtr 
00234     = dynamic_cast<DenseSerialMatrix*>(Vt.ptr().get());
00235   TEUCHOS_TEST_FOR_EXCEPT(VtPtr==0);
00236   
00237   double* uData = UPtr->dataPtr();
00238   double* vtData = VtPtr->dataPtr();
00239   double* aData = ATmp.dataPtr();
00240   double* sData = sigPtr->dataPtr();
00241 
00242   char jobu = 'S';
00243   char jobvt = 'S';
00244  
00245   int LDA = M;
00246   int LDU = M;
00247   int LDVT = S;
00248 
00249   int LWORK = max(1, max(3*min(M,N)+max(M,N), 5*min(M,N)));
00250   Array<double> work(LWORK);
00251   
00252   int info = 0;
00253 
00254   dgesvd_(&jobu, &jobvt, &M, &N, aData, &LDA, sData, uData, &LDU, 
00255     vtData, &LDVT, &(work[0]), &LWORK, &info);
00256 
00257   TEUCHOS_TEST_FOR_EXCEPTION(info != 0, std::runtime_error,
00258     "dgesvd failed with error code info=" << info);
00259 
00260   
00261   
00262 }
00263 
00264 }

Site Contact