|
Tpetra Matrix/Vector Services
Version of the Day
|
/*
//@HEADER
// ************************************************************************
//
//          Kokkos: Node API and Parallel Node Kernels
//              Copyright (2008) Sandia Corporation
//
// Under the terms of Contract DE-AC04-94AL85000 with Sandia Corporation,
// the U.S. Government retains certain rights in this software.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
//
// 1. Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
//
// 2. Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
//
// 3. Neither the name of the Corporation nor the names of the
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY SANDIA CORPORATION "AS IS" AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL SANDIA CORPORATION OR THE
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
// LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
// NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
//
// Questions? Contact Michael A. Heroux (maherou@sandia.gov)
//
// ************************************************************************
//@HEADER
*/

// Note: this code lives only temporarily in Tpetra.
// As soon as GEMM kernels exist in KokkosLinAlg, and thus the dependency on
// Teuchos can be eliminated, the code will move to KokkosLinAlg.

#if defined(KOKKOS_MULTIVECTOR_H_) && defined(TPETRA_KOKKOS_REFACTOR_MULTIVECTOR_DEF_HPP)

#include <algorithm>
#include <cstddef>
#include <stdexcept>

#include <Teuchos_BLAS.hpp>
#include <Teuchos_as.hpp>
#include <Teuchos_TestForException.hpp>
#include <Teuchos_TypeNameTraits.hpp>

#ifdef KOKKOS_HAVE_CUDA
#include <cublas.h>
#endif

namespace Kokkos {
namespace Impl {

/// \brief Stride value used as the BLAS leading dimension of a rank-2 View.
///
/// Returns \c stride[1] as reported by <tt>A.stride()</tt> when \c A has more
/// than one column (presumably the column stride, since callers pass
/// LayoutLeft Views). Otherwise it returns <tt>A.dimension_0()</tt>, the
/// number of rows.
///
/// \note The result may be 0 for an empty View; BLAS callers must clamp it
///   to at least 1 themselves.
template<class ViewType>
size_t getStride2DView (ViewType A) {
  size_t stride[8];
  A.stride (stride);
  return A.dimension_1 () > 1 ? stride[1] : A.dimension_0 ();
}

} // namespace Impl

/// \brief Dense matrix-matrix multiply
///   <tt>C = alpha * op(A) * op(B) + beta * C</tt> on Kokkos LayoutLeft Views.
///
/// The primary template forwards to Teuchos::BLAS using the Views' raw
/// pointers. It therefore assumes the data is accessible from the host
/// BLAS; TODO confirm for each DeviceType it is instantiated with.
///
/// \tparam Scalar     Entry type of the matrices.
/// \tparam DeviceType Kokkos device type of the Views.
template <typename Scalar, typename DeviceType>
struct DeviceGEMM {
public:
  /// \param transA [in] Whether to use A, A^T, or A^H.
  /// \param transB [in] Whether to use B, B^T, or B^H.
  /// \param alpha  [in] Scaling factor for op(A) * op(B).
  /// \param A      [in] Left factor.
  /// \param B      [in] Right factor.
  /// \param beta   [in] Scaling factor for the incoming C.
  /// \param C      [in/out] Result; its dimensions define m and n.
  static void
  GEMM (const Teuchos::ETransp transA,
        const Teuchos::ETransp transB,
        const Scalar alpha,
        View<const Scalar**, LayoutLeft, DeviceType> A,
        View<const Scalar**, LayoutLeft, DeviceType> B,
        const Scalar beta,
        View<Scalar**, LayoutLeft, DeviceType> C)
  {
    Teuchos::BLAS<int,Scalar> blas;
    const int m = static_cast<int> (C.dimension_0 ());
    const int n = static_cast<int> (C.dimension_1 ());
    // Inner dimension: number of columns of op(A).
    const int k = static_cast<int> (transA == Teuchos::NO_TRANS ?
                                    A.dimension_1 () : A.dimension_0 ());
    // BLAS rejects any leading dimension below 1, even for empty matrices,
    // and getStride2DView may return 0 for an empty View.
    const int lda = std::max (1, static_cast<int> (Impl::getStride2DView (A)));
    const int ldb = std::max (1, static_cast<int> (Impl::getStride2DView (B)));
    const int ldc = std::max (1, static_cast<int> (Impl::getStride2DView (C)));

    // For some BLAS implementations (e.g., MKL), GEMM when B has
    // one column may be significantly less efficient than GEMV.
    if (n == 1 && transB == Teuchos::NO_TRANS) {
      // GEMV takes the dimensions of A itself (not of op(A)). B and C are
      // single LayoutLeft columns here, so their entries are contiguous
      // (increment 1).
      blas.GEMV (transA,
                 static_cast<int> (A.dimension_0 ()),
                 static_cast<int> (A.dimension_1 ()),
                 alpha, A.ptr_on_device (), lda,
                 B.ptr_on_device (), 1,
                 beta, C.ptr_on_device (), 1);
    }
    else {
      blas.GEMM (transA, transB, m, n, k, alpha,
                 A.ptr_on_device (), lda,
                 B.ptr_on_device (), ldb,
                 beta, C.ptr_on_device (), ldc);
    }
  }
};

// No specializations are needed for host device types (e.g. Serial, Threads,
// OpenMP): they use the primary template above.

#ifdef KOKKOS_HAVE_CUDA
/// \brief Fallback for Kokkos::Cuda with a Scalar type that has no cuBLAS
///   GEMM binding here (only float and double do). It always throws
///   std::logic_error.
template <typename Scalar>
struct DeviceGEMM<Scalar,Cuda> {
public:
  static void GEMM(Teuchos::ETransp transA, Teuchos::ETransp transB, Scalar alpha,
                   View<const Scalar**,LayoutLeft,Cuda> A, View<const Scalar**,LayoutLeft,Cuda> B,
                   Scalar beta, View<Scalar**,LayoutLeft,Cuda> C) {
    TEUCHOS_TEST_FOR_EXCEPTION(true, std::logic_error, "DeviceGEMM: Kokkos::Cuda has no support for GEMM operations over Scalar=" << Teuchos::typeName(alpha) << ".");
  }
};

/// \brief Single-precision GEMM on Kokkos::Cuda via the legacy cuBLAS
///   cublasSgemm. With HAVE_KOKKOS_DEBUG, a failing cuBLAS status is turned
///   into std::runtime_error.
template <>
struct DeviceGEMM<float,Cuda> {
public:
  static void GEMM(Teuchos::ETransp transA, Teuchos::ETransp transB, float alpha,
                   View<const float**,LayoutLeft,Cuda> A, View<const float**,LayoutLeft,Cuda> B,
                   float beta, View<float**,LayoutLeft,Cuda> C) {
    const int m = Teuchos::as<int>(C.dimension_0());
    const int n = Teuchos::as<int>(C.dimension_1());
    const int k = Teuchos::as<int>(transA == Teuchos::NO_TRANS ? A.dimension_1() : A.dimension_0());
    // cuBLAS requires every leading dimension to be at least 1.
    const int lda = std::max(1, Teuchos::as<int>(Impl::getStride2DView(A)));
    const int ldb = std::max(1, Teuchos::as<int>(Impl::getStride2DView(B)));
    const int ldc = std::max(1, Teuchos::as<int>(Impl::getStride2DView(C)));
    // For real scalars, transpose and conjugate transpose coincide, so any
    // non-NO_TRANS mode maps to 'T'.
    const char char_transA = (transA == Teuchos::NO_TRANS ? 'N' : 'T');
    const char char_transB = (transB == Teuchos::NO_TRANS ? 'N' : 'T');
    cublasSgemm(char_transA, char_transB, m, n, k, alpha, A.ptr_on_device(), lda, B.ptr_on_device(), ldb, beta, C.ptr_on_device(), ldc);
#ifdef HAVE_KOKKOS_DEBUG
    cublasStatus info = cublasGetError();
    TEUCHOS_TEST_FOR_EXCEPTION( info != CUBLAS_STATUS_SUCCESS, std::runtime_error, "cublasSgemm failed with status " << info << "." );
#endif
  }
};

/// \brief Double-precision GEMM on Kokkos::Cuda via the legacy cuBLAS
///   cublasDgemm. With HAVE_KOKKOS_DEBUG, a failing cuBLAS status is turned
///   into std::runtime_error.
template <>
struct DeviceGEMM<double,Cuda> {
public:
  static void GEMM(Teuchos::ETransp transA, Teuchos::ETransp transB, double alpha,
                   View<const double**,LayoutLeft,Cuda> A, View<const double**,LayoutLeft,Cuda> B,
                   double beta, View<double**,LayoutLeft,Cuda> C) {
    const int m = Teuchos::as<int>(C.dimension_0());
    const int n = Teuchos::as<int>(C.dimension_1());
    const int k = Teuchos::as<int>(transA == Teuchos::NO_TRANS ? A.dimension_1() : A.dimension_0());
    // cuBLAS requires every leading dimension to be at least 1.
    const int lda = std::max(1, Teuchos::as<int>(Impl::getStride2DView(A)));
    const int ldb = std::max(1, Teuchos::as<int>(Impl::getStride2DView(B)));
    const int ldc = std::max(1, Teuchos::as<int>(Impl::getStride2DView(C)));
    // For real scalars, transpose and conjugate transpose coincide, so any
    // non-NO_TRANS mode maps to 'T'.
    const char char_transA = (transA == Teuchos::NO_TRANS ? 'N' : 'T');
    const char char_transB = (transB == Teuchos::NO_TRANS ? 'N' : 'T');
    cublasDgemm(char_transA, char_transB, m, n, k, alpha, A.ptr_on_device(), lda, B.ptr_on_device(), ldb, beta, C.ptr_on_device(), ldc);
#ifdef HAVE_KOKKOS_DEBUG
    cublasStatus info = cublasGetError();
    TEUCHOS_TEST_FOR_EXCEPTION( info != CUBLAS_STATUS_SUCCESS, std::runtime_error, "cublasDgemm failed with status " << info << "." );
#endif
  }
};

#endif // KOKKOS_HAVE_CUDA

} // namespace Kokkos

#endif // defined(KOKKOS_MULTIVECTOR_H_) && defined(TPETRA_KOKKOS_REFACTOR_MULTIVECTOR_DEF_HPP)
1.7.6.1