Tpetra Matrix/Vector Services  Version of the Day
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Defines
Kokkos_MV_GEMM.hpp
00001 /*
00002 //@HEADER
00003 // ************************************************************************
00004 //
00005 //          Kokkos: Node API and Parallel Node Kernels
00006 //              Copyright (2008) Sandia Corporation
00007 //
00008 // Under the terms of Contract DE-AC04-94AL85000 with Sandia Corporation,
00009 // the U.S. Government retains certain rights in this software.
00010 //
00011 // Redistribution and use in source and binary forms, with or without
00012 // modification, are permitted provided that the following conditions are
00013 // met:
00014 //
00015 // 1. Redistributions of source code must retain the above copyright
00016 // notice, this list of conditions and the following disclaimer.
00017 //
00018 // 2. Redistributions in binary form must reproduce the above copyright
00019 // notice, this list of conditions and the following disclaimer in the
00020 // documentation and/or other materials provided with the distribution.
00021 //
00022 // 3. Neither the name of the Corporation nor the names of the
00023 // contributors may be used to endorse or promote products derived from
00024 // this software without specific prior written permission.
00025 //
00026 // THIS SOFTWARE IS PROVIDED BY SANDIA CORPORATION "AS IS" AND ANY
00027 // EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
00028 // IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
00029 // PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL SANDIA CORPORATION OR THE
00030 // CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
00031 // EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
00032 // PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
00033 // PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
00034 // LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
00035 // NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
00036 // SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
00037 //
00038 // Questions? Contact Michael A. Heroux (maherou@sandia.gov)
00039 //
00040 // ************************************************************************
00041 //@HEADER
00042 */
00043 
00044 //Note this code lives only temporarily in Tpetra
00045 //As soon as GEMM kernels exist in KokkosLinAlg and thus a dependency on Teuchos
00046 //can be eliminated the code will move to KokkosLinAlg.
00047 
00048 #if defined(KOKKOS_MULTIVECTOR_H_) && defined(TPETRA_KOKKOS_REFACTOR_MULTIVECTOR_DEF_HPP)
00049 
00050 #include<Teuchos_BLAS.hpp>
00051 #include<Teuchos_as.hpp>
00052 
00053 #ifdef KOKKOS_HAVE_CUDA
00054 #include<cublas.h>
00055 #endif
00056 namespace Kokkos {
00057   namespace Impl {
00058 
00059     template<class ViewType>
00060     size_t getStride2DView (ViewType A) {
00061       size_t stride[8];
00062       A.stride (stride);
00063       return A.dimension_1 () > 1 ? stride[1] : A.dimension_0 ();
00064     }
00065   }
00066 
00073   template <typename Scalar, typename DeviceType>
00074   struct DeviceGEMM {
00075   public:
00076     static void
00077     GEMM (const Teuchos::ETransp transA,
00078           const Teuchos::ETransp transB,
00079           const Scalar alpha,
00080           View<const Scalar**, LayoutLeft, DeviceType> A,
00081           View<const Scalar**, LayoutLeft, DeviceType> B,
00082           const Scalar beta,
00083           View<Scalar**, LayoutLeft, DeviceType> C)
00084     {
00085       Teuchos::BLAS<int,Scalar> blas;
00086       const int m = static_cast<int> (C.dimension_0 ()),
00087         n = static_cast<int> (C.dimension_1 ()),
00088         k = (transA == Teuchos::NO_TRANS ? A.dimension_1 () : A.dimension_0 ()),
00089         lda = static_cast<int> (Impl::getStride2DView (A)),
00090         ldb = static_cast<int> (Impl::getStride2DView (B)),
00091         ldc = static_cast<int> (Impl::getStride2DView (C));
00092       // For some BLAS implementations (e.g., MKL), GEMM when B has
00093       // one column may be signficantly less efficient than GEMV.
00094       if (n == 1 && transB == Teuchos::NO_TRANS) {
00095         blas.GEMV (transA, A.dimension_0 (), A.dimension_1 (), alpha,
00096                    A.ptr_on_device(), lda,
00097                    B.ptr_on_device(), static_cast<int> (1),
00098                    beta, C.ptr_on_device(), static_cast<int> (1));
00099       }
00100       else {
00101         blas.GEMM (transA, transB, m, n, k, alpha,
00102                    A.ptr_on_device(), lda,
00103                    B.ptr_on_device(), ldb,
00104                    beta, C.ptr_on_device(), ldc);
00105       }
00106     }
00107   };
00108 
00109 //   template <typename Scalar>
00110 //   struct DeviceGEMM<Scalar,Serial> {
00111 //     public:
00112 //       static void GEMM(Teuchos::ETransp transA, Teuchos::ETransp transB, Scalar alpha,
00113 //           View<const Scalar**,LayoutLeft,Serial> A, View<const Scalar**,LayoutLeft,Serial> B,
00114 //           Scalar beta, View<Scalar**,Serial> C) {
00115 //         Teuchos::BLAS<int,Scalar> blas;
00116 //         const int m = Teuchos::as<int>(C.dimension_0()),
00117 //                   n = Teuchos::as<int>(C.dimension_1()),
00118 //                   k = (transA == Teuchos::NO_TRANS ? A.dimension_1() : A.dimension_0()),
00119 //                   lda = Teuchos::as<int>(Impl::getStride2DView(A)),
00120 //                   ldb = Teuchos::as<int>(Impl::getStride2DView(B)),
00121 //                   ldc = Teuchos::as<int>(Impl::getStride2DView(C));
00122 //         // For some BLAS implementations (e.g., MKL), GEMM when B has one column
00123 //         // is significantly less efficient than GEMV
00124 //         if (n == 1 && transB == Teuchos::NO_TRANS)
00125 //           blas.GEMV(transA, A.dimension_0(), A.dimension_1(), alpha, A.ptr_on_device(), lda, B.ptr_on_device(), Teuchos::as<int>(1), beta, C.ptr_on_device(), Teuchos::as<int>(1));
00126 //         else
00127 //           blas.GEMM(transA, transB, m, n, k, alpha, A.ptr_on_device(), lda, B.ptr_on_device(), ldb, beta, C.ptr_on_device(), ldc);
00128 //       }
00129 //   };
00130 
00131 // #ifdef KOKKOS_HAVE_PTHREAD
00132 //   template <typename Scalar>
00133 //   struct DeviceGEMM<Scalar,Threads> {
00134 //     public:
00135 //       static void GEMM(Teuchos::ETransp transA, Teuchos::ETransp transB, Scalar alpha,
00136 //           View<const Scalar**,LayoutLeft,Threads> A, View<const Scalar**,LayoutLeft,Threads> B,
00137 //           Scalar beta, View<Scalar**,LayoutLeft,Threads> C) {
00138 //         Teuchos::BLAS<int,Scalar> blas;
00139 //         const int m = Teuchos::as<int>(C.dimension_0()),
00140 //                   n = Teuchos::as<int>(C.dimension_1()),
00141 //                   k = (transA == Teuchos::NO_TRANS ? A.dimension_1() : A.dimension_0()),
00142 //                   lda = Teuchos::as<int>(Impl::getStride2DView(A)),
00143 //                   ldb = Teuchos::as<int>(Impl::getStride2DView(B)),
00144 //                   ldc = Teuchos::as<int>(Impl::getStride2DView(C));
00145 //         blas.GEMM(transA, transB, m, n, k, alpha, A.ptr_on_device(), lda, B.ptr_on_device(), ldb, beta, C.ptr_on_device(), ldc);
00146 //       }
00147 //   };
00148 // #endif
00149 
00150 // #ifdef KOKKOS_HAVE_OPENMP
00151 //   template <typename Scalar>
00152 //   struct DeviceGEMM<Scalar,OpenMP> {
00153 //     public:
00154 //       static void GEMM(Teuchos::ETransp transA, Teuchos::ETransp transB, Scalar alpha,
00155 //           View<const Scalar**,LayoutLeft,OpenMP> A, View<const Scalar**,LayoutLeft,OpenMP> B,
00156 //           Scalar beta, View<Scalar**,LayoutLeft,OpenMP> C) {
00157 //         Teuchos::BLAS<int,Scalar> blas;
00158 //         const int m = Teuchos::as<int>(C.dimension_0()),
00159 //                   n = Teuchos::as<int>(C.dimension_1()),
00160 //                   k = (transA == Teuchos::NO_TRANS ? A.dimension_1() : A.dimension_0()),
00161 //                   lda = Teuchos::as<int>(Impl::getStride2DView(A)),
00162 //                   ldb = Teuchos::as<int>(Impl::getStride2DView(B)),
00163 //                   ldc = Teuchos::as<int>(Impl::getStride2DView(C));
00164 //         blas.GEMM(transA, transB, m, n, k, alpha, A.ptr_on_device(), lda, B.ptr_on_device(), ldb, beta, C.ptr_on_device(), ldc);
00165 //       }
00166 //   };
00167 // #endif
00168 
00169 #ifdef KOKKOS_HAVE_CUDA
00170   template <typename Scalar>
00171   struct DeviceGEMM<Scalar,Cuda> {
00172     public:
00173       static void GEMM(Teuchos::ETransp transA, Teuchos::ETransp transB, Scalar alpha,
00174           View<const Scalar**,LayoutLeft,Cuda> A, View<const Scalar**,LayoutLeft,Cuda> B,
00175           Scalar beta, View<Scalar**,LayoutLeft,Cuda> C) {
00176         TEUCHOS_TEST_FOR_EXCEPTION(true, std::logic_error, "DeviceGEMM: Kokkos::Cuda has no support for GEMM operations over Scalar=" << Teuchos::typeName(alpha) << ".");
00177       }
00178   };
00179 
00180 
00181   template <>
00182   struct DeviceGEMM<float,Cuda> {
00183     public:
00184       static void GEMM(Teuchos::ETransp transA, Teuchos::ETransp transB, float alpha,
00185           View<const float**,LayoutLeft,Cuda> A, View<const float**,LayoutLeft,Cuda> B,
00186           float beta, View<float**,LayoutLeft,Cuda> C) {
00187         const int m = Teuchos::as<int>(C.dimension_0()),
00188                   n = Teuchos::as<int>(C.dimension_1()),
00189                   k = (transA == Teuchos::NO_TRANS ? A.dimension_1() : A.dimension_0()),
00190                   lda = Teuchos::as<int>(Impl::getStride2DView(A)),
00191                   ldb = Teuchos::as<int>(Impl::getStride2DView(B)),
00192                   ldc = Teuchos::as<int>(Impl::getStride2DView(C));
00193         const char char_transA = (transA == Teuchos::NO_TRANS ? 'N' : 'T'),
00194                    char_transB = (transB == Teuchos::NO_TRANS ? 'N' : 'T');
00195         cublasSgemm(char_transA, char_transB, m, n, k, alpha, A.ptr_on_device(), lda, B.ptr_on_device(), ldb, beta, C.ptr_on_device(), ldc);
00196 #ifdef HAVE_KOKKOS_DEBUG
00197         cublasStatus info = cublasGetError();
00198         TEUCHOS_TEST_FOR_EXCEPTION( info != CUBLAS_STATUS_SUCCESS, std::runtime_error, "cublasSgemm failed with status " << info << "." );
00199 #endif
00200       }
00201   };
00202 
00203   template <>
00204   struct DeviceGEMM<double,Cuda> {
00205     public:
00206       static void GEMM(Teuchos::ETransp transA, Teuchos::ETransp transB, double alpha,
00207           View<const double**,LayoutLeft,Cuda> A, View<const double**,LayoutLeft,Cuda> B,
00208           double beta, View<double**,LayoutLeft,Cuda> C) {
00209         const int m = Teuchos::as<int>(C.dimension_0()),
00210                   n = Teuchos::as<int>(C.dimension_1()),
00211                   k = (transA == Teuchos::NO_TRANS ? A.dimension_1() : A.dimension_0()),
00212                   lda = Teuchos::as<int>(Impl::getStride2DView(A)),
00213                   ldb = Teuchos::as<int>(Impl::getStride2DView(B)),
00214                   ldc = Teuchos::as<int>(Impl::getStride2DView(C));
00215         const char char_transA = (transA == Teuchos::NO_TRANS ? 'N' : 'T'),
00216                    char_transB = (transB == Teuchos::NO_TRANS ? 'N' : 'T');
00217         cublasDgemm(char_transA, char_transB, m, n, k, alpha, A.ptr_on_device(), lda, B.ptr_on_device(), ldb, beta, C.ptr_on_device(), ldc);
00218 #ifdef HAVE_KOKKOS_DEBUG
00219         cublasStatus info = cublasGetError();
00220         TEUCHOS_TEST_FOR_EXCEPTION( info != CUBLAS_STATUS_SUCCESS, std::runtime_error, "cublasDgemm failed with status " << info << "." );
00221 #endif
00222       }
00223   };
00224 
00225 
00226 #endif
00227 }
00228 #endif
00229 
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Defines