|
EpetraExt
Development
|
00001 //@HEADER 00002 // *********************************************************************** 00003 // 00004 // EpetraExt: Epetra Extended - Linear Algebra Services Package 00005 // Copyright (2011) Sandia Corporation 00006 // 00007 // Under the terms of Contract DE-AC04-94AL85000 with Sandia Corporation, 00008 // the U.S. Government retains certain rights in this software. 00009 // 00010 // Redistribution and use in source and binary forms, with or without 00011 // modification, are permitted provided that the following conditions are 00012 // met: 00013 // 00014 // 1. Redistributions of source code must retain the above copyright 00015 // notice, this list of conditions and the following disclaimer. 00016 // 00017 // 2. Redistributions in binary form must reproduce the above copyright 00018 // notice, this list of conditions and the following disclaimer in the 00019 // documentation and/or other materials provided with the distribution. 00020 // 00021 // 3. Neither the name of the Corporation nor the names of the 00022 // contributors may be used to endorse or promote products derived from 00023 // this software without specific prior written permission. 00024 // 00025 // THIS SOFTWARE IS PROVIDED BY SANDIA CORPORATION "AS IS" AND ANY 00026 // EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 00027 // IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR 00028 // PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL SANDIA CORPORATION OR THE 00029 // CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 00030 // EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 00031 // PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR 00032 // PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF 00033 // LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING 00034 // NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 00035 // SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 00036 // 00037 // Questions? Contact Michael A. Heroux (maherou@sandia.gov) 00038 // 00039 // *********************************************************************** 00040 //@HEADER 00041 00042 #include "EpetraExt_PutRowMatrix.h" 00043 #include "Epetra_Comm.h" 00044 #include "Epetra_Map.h" 00045 #include "Epetra_Vector.h" 00046 #include "Epetra_IntVector.h" 00047 #include "Epetra_SerialDenseVector.h" 00048 #include "Epetra_IntSerialDenseVector.h" 00049 #include "Epetra_Import.h" 00050 #include "Epetra_RowMatrix.h" 00051 #include "Epetra_CrsMatrix.h" 00052 00053 using namespace Matlab; 00054 namespace Matlab { 00055 00056 int CopyRowMatrix(mxArray* matlabA, const Epetra_RowMatrix& A) { 00057 int valueCount = 0; 00058 //int* valueCount = &temp; 00059 00060 Epetra_Map map = A.RowMatrixRowMap(); 00061 const Epetra_Comm & comm = map.Comm(); 00062 int numProc = comm.NumProc(); 00063 00064 if (numProc==1) 00065 DoCopyRowMatrix(matlabA, valueCount, A); 00066 else { 00067 int numRows = map.NumMyElements(); 00068 00069 //cout << "creating allGidsMap\n"; 00070 Epetra_Map allGidsMap(-1, numRows, 0,comm); 00071 //cout << "done creating allGidsMap\n"; 00072 00073 Epetra_IntVector allGids(allGidsMap); 00074 for (int i=0; i<numRows; i++) allGids[i] = map.GID(i); 00075 00076 // Now construct a RowMatrix on PE 0 by strip-mining the rows of the input matrix A. 00077 int numChunks = numProc; 00078 int stripSize = allGids.GlobalLength()/numChunks; 00079 int remainder = allGids.GlobalLength()%numChunks; 00080 int curStart = 0; 00081 int curStripSize = 0; 00082 Epetra_IntSerialDenseVector importGidList; 00083 int numImportGids = 0; 00084 if (comm.MyPID()==0) 00085 importGidList.Size(stripSize+1); // Set size of vector to max needed 00086 for (int i=0; i<numChunks; i++) { 00087 if (comm.MyPID()==0) { // Only PE 0 does this part 00088 curStripSize = stripSize; 00089 if (i<remainder) curStripSize++; // handle leftovers 00090 for (int j=0; j<curStripSize; j++) importGidList[j] = j + curStart; 00091 curStart += curStripSize; 00092 } 00093 // The following import map will be non-trivial only on PE 0. 00094 //cout << "creating importGidMap\n"; 00095 Epetra_Map importGidMap(-1, curStripSize, importGidList.Values(), 0, comm); 00096 //cout << "done creating importGidMap\n"; 00097 Epetra_Import gidImporter(importGidMap, allGidsMap); 00098 Epetra_IntVector importGids(importGidMap); 00099 if (importGids.Import(allGids, gidImporter, Insert)) return(-1); 00100 00101 // importGids now has a list of GIDs for the current strip of matrix rows. 00102 // Use these values to build another importer that will get rows of the matrix. 00103 00104 // The following import map will be non-trivial only on PE 0. 00105 //cout << "creating importMap\n"; 00106 //cout << "A.RowMatrixRowMap().MinAllGID: " << A.RowMatrixRowMap().MinAllGID() << "\n"; 00107 Epetra_Map importMap(-1, importGids.MyLength(), importGids.Values(), A.RowMatrixRowMap().MinAllGID(), comm); 00108 //cout << "done creating importMap\n"; 00109 Epetra_Import importer(importMap, map); 00110 Epetra_CrsMatrix importA(Copy, importMap, 0); 00111 if (importA.Import(A, importer, Insert)) return(-1); 00112 if (importA.FillComplete()) return(-1); 00113 00114 // Finally we are ready to write this strip of the matrix to ostream 00115 if (DoCopyRowMatrix(matlabA, valueCount, importA)) return(-1); 00116 } 00117 } 00118 00119 if (A.RowMatrixRowMap().Comm().MyPID() == 0) { 00120 // set max cap 00121 int* matlabAcolumnIndicesPtr = mxGetJc(matlabA); 00122 matlabAcolumnIndicesPtr[A.NumGlobalRows()] = valueCount; 00123 } 00124 00125 return(0); 00126 } 00127 00128 int DoCopyRowMatrix(mxArray* matlabA, int& valueCount, const Epetra_RowMatrix& A) { 00129 //cout << "doing DoCopyRowMatrix\n"; 00130 int ierr = 0; 00131 int numRows = A.NumGlobalRows(); 00132 //cout << "numRows: " << numRows << "\n"; 00133 Epetra_Map rowMap = A.RowMatrixRowMap(); 00134 Epetra_Map colMap = A.RowMatrixColMap(); 00135 int minAllGID = rowMap.MinAllGID(); 00136 00137 const Epetra_Comm & comm = rowMap.Comm(); 00138 //cout << "did global setup\n"; 00139 if (comm.MyPID()!=0) { 00140 if (A.NumMyRows()!=0) ierr = -1; 00141 if (A.NumMyCols()!=0) ierr = -1; 00142 } 00143 else { 00144 // declare and get initial values of all matlabA pointers 00145 double* matlabAvaluesPtr = mxGetPr(matlabA); 00146 int* matlabAcolumnIndicesPtr = mxGetJc(matlabA); 00147 int* matlabArowIndicesPtr = mxGetIr(matlabA); 00148 00149 // set all matlabA pointers to the proper offset 00150 matlabAvaluesPtr += valueCount; 00151 matlabArowIndicesPtr += valueCount; 00152 00153 if (numRows!=A.NumMyRows()) ierr = -1; 00154 Epetra_SerialDenseVector values(A.MaxNumEntries()); 00155 Epetra_IntSerialDenseVector indices(A.MaxNumEntries()); 00156 //cout << "did proc0 setup\n"; 00157 for (int i=0; i<numRows; i++) { 00158 //cout << "extracting a row\n"; 00159 int I = rowMap.GID(i); 00160 int numEntries = 0; 00161 if (A.ExtractMyRowCopy(i, values.Length(), numEntries, 00162 values.Values(), indices.Values())) return(-1); 00163 matlabAcolumnIndicesPtr[I - minAllGID] = valueCount; // set the starting index of column I 00164 double* serialValuesPtr = values.Values(); 00165 for (int j=0; j<numEntries; j++) { 00166 int J = colMap.GID(indices[j]); 00167 *matlabAvaluesPtr = *serialValuesPtr++; 00168 *matlabArowIndicesPtr = J; 00169 // increment matlabA pointers 00170 matlabAvaluesPtr++; 00171 matlabArowIndicesPtr++; 00172 valueCount++; 00173 } 00174 } 00175 //cout << "proc0 row extraction for this chunck is done\n"; 00176 } 00177 00178 /* 00179 if (comm.MyPID() == 0) { 00180 cout << "printing matlabA pointers\n"; 00181 double* matlabAvaluesPtr = mxGetPr(matlabA); 00182 int* matlabAcolumnIndicesPtr = mxGetJc(matlabA); 00183 int* matlabArowIndicesPtr = mxGetIr(matlabA); 00184 for(int i=0; i < numRows; i++) { 00185 for(int j=0; j < A.MaxNumEntries(); j++) { 00186 cout << "*matlabAvaluesPtr: " << *matlabAvaluesPtr++ << " *matlabAcolumnIndicesPtr: " << *matlabAcolumnIndicesPtr++ << " *matlabArowIndicesPtr" << *matlabArowIndicesPtr++ << "\n"; 00187 } 00188 } 00189 00190 cout << "done printing matlabA pointers\n"; 00191 } 00192 */ 00193 00194 int ierrGlobal; 00195 comm.MinAll(&ierr, &ierrGlobal, 1); // If any processor has -1, all return -1 00196 return(ierrGlobal); 00197 } 00198 00199 } // namespace Matlab
1.7.6.1