|
Teuchos - Trilinos Tools Package
Version of the Day
|
00001 // @HEADER 00002 // *********************************************************************** 00003 // 00004 // Teuchos: Common Tools Package 00005 // Copyright (2004) Sandia Corporation 00006 // 00007 // Under terms of Contract DE-AC04-94AL85000, there is a non-exclusive 00008 // license for use of this work by or on behalf of the U.S. Government. 00009 // 00010 // Redistribution and use in source and binary forms, with or without 00011 // modification, are permitted provided that the following conditions are 00012 // met: 00013 // 00014 // 1. Redistributions of source code must retain the above copyright 00015 // notice, this list of conditions and the following disclaimer. 00016 // 00017 // 2. Redistributions in binary form must reproduce the above copyright 00018 // notice, this list of conditions and the following disclaimer in the 00019 // documentation and/or other materials provided with the distribution. 00020 // 00021 // 3. Neither the name of the Corporation nor the names of the 00022 // contributors may be used to endorse or promote products derived from 00023 // this software without specific prior written permission. 00024 // 00025 // THIS SOFTWARE IS PROVIDED BY SANDIA CORPORATION "AS IS" AND ANY 00026 // EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 00027 // IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR 00028 // PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL SANDIA CORPORATION OR THE 00029 // CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 00030 // EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 00031 // PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR 00032 // PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF 00033 // LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING 00034 // NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 00035 // SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
00036 // 00037 // Questions? Contact Michael A. Heroux (maherou@sandia.gov) 00038 // 00039 // *********************************************************************** 00040 // @HEADER 00041 00042 // Kris 00043 // 06.16.03 -- Start over from scratch 00044 // 06.16.03 -- Initial templatization (Tpetra_BLAS.cpp is no longer needed) 00045 // 06.18.03 -- Changed xxxxx_() function calls to XXXXX_F77() 00046 // -- Added warning messages for generic calls 00047 // 07.08.03 -- Move into Teuchos package/namespace 00048 // 07.24.03 -- The first iteration of BLAS generics is nearing completion. Caveats: 00049 // * TRSM isn't finished yet; it works for one case at the moment (left side, upper tri., no transpose, no unit diag.) 00050 // * Many of the generic implementations are quite inefficient, ugly, or both. I wrote these to be easy to debug, not for efficiency or legibility. The next iteration will improve both of these aspects as much as possible. 00051 // * Very little verification of input parameters is done, save for the character-type arguments (TRANS, etc.) which is quite robust. 00052 // * All of the routines that make use of both an incx and incy parameter (which includes much of the L1 BLAS) are set up to work iff incx == incy && incx > 0. Allowing for differing/negative values of incx/incy should be relatively trivial. 00053 // * All of the L2/L3 routines assume that the entire matrix is being used (that is, if A is mxn, lda = m); they don't work on submatrices yet. This *should* be a reasonably trivial thing to fix, as well. 00054 // -- Removed warning messages for generic calls 00055 // 08.08.03 -- TRSM now works for all cases where SIDE == L and DIAG == N. DIAG == U is implemented but does not work correctly; SIDE == R is not yet implemented. 00056 // 08.14.03 -- TRSM now works for all cases and accepts (and uses) leading-dimension information. 
00057 // 09.26.03 -- character input replaced with enumerated input to cause compiling errors and not run-time errors ( suggested by RAB ). 00058 00059 #ifndef _TEUCHOS_BLAS_HPP_ 00060 #define _TEUCHOS_BLAS_HPP_ 00061 00069 #include "Teuchos_ConfigDefs.hpp" 00070 #include "Teuchos_ScalarTraits.hpp" 00071 #include "Teuchos_OrdinalTraits.hpp" 00072 #include "Teuchos_BLAS_types.hpp" 00073 #include "Teuchos_Assert.hpp" 00074 00107 namespace Teuchos 00108 { 00109 extern TEUCHOSNUMERICS_LIB_DLL_EXPORT const char ESideChar[]; 00110 extern TEUCHOSNUMERICS_LIB_DLL_EXPORT const char ETranspChar[]; 00111 extern TEUCHOSNUMERICS_LIB_DLL_EXPORT const char EUploChar[]; 00112 extern TEUCHOSNUMERICS_LIB_DLL_EXPORT const char EDiagChar[]; 00113 extern TEUCHOSNUMERICS_LIB_DLL_EXPORT const char ETypeChar[]; 00114 00116 00121 template<typename OrdinalType, typename ScalarType> 00122 class DefaultBLASImpl 00123 { 00124 00125 typedef typename Teuchos::ScalarTraits<ScalarType>::magnitudeType MagnitudeType; 00126 00127 public: 00129 00130 00132 inline DefaultBLASImpl(void) {} 00133 00135 inline DefaultBLASImpl(const DefaultBLASImpl<OrdinalType, ScalarType>& /*BLAS_source*/) {} 00136 00138 inline virtual ~DefaultBLASImpl(void) {} 00140 00142 00143 00145 void ROTG(ScalarType* da, ScalarType* db, MagnitudeType* c, ScalarType* s) const; 00146 00148 void ROT(const OrdinalType n, ScalarType* dx, const OrdinalType incx, ScalarType* dy, const OrdinalType incy, MagnitudeType* c, ScalarType* s) const; 00149 00151 void SCAL(const OrdinalType n, const ScalarType alpha, ScalarType* x, const OrdinalType incx) const; 00152 00154 void COPY(const OrdinalType n, const ScalarType* x, const OrdinalType incx, ScalarType* y, const OrdinalType incy) const; 00155 00157 template <typename alpha_type, typename x_type> 00158 void AXPY(const OrdinalType n, const alpha_type alpha, const x_type* x, const OrdinalType incx, ScalarType* y, const OrdinalType incy) const; 00159 00161 typename 
ScalarTraits<ScalarType>::magnitudeType ASUM(const OrdinalType n, const ScalarType* x, const OrdinalType incx) const; 00162 00164 template <typename x_type, typename y_type> 00165 ScalarType DOT(const OrdinalType n, const x_type* x, const OrdinalType incx, const y_type* y, const OrdinalType incy) const; 00166 00168 typename ScalarTraits<ScalarType>::magnitudeType NRM2(const OrdinalType n, const ScalarType* x, const OrdinalType incx) const; 00169 00171 OrdinalType IAMAX(const OrdinalType n, const ScalarType* x, const OrdinalType incx) const; 00173 00175 00176 00178 template <typename alpha_type, typename A_type, typename x_type, typename beta_type> 00179 void GEMV(ETransp trans, const OrdinalType m, const OrdinalType n, const alpha_type alpha, const A_type* A, 00180 const OrdinalType lda, const x_type* x, const OrdinalType incx, const beta_type beta, ScalarType* y, const OrdinalType incy) const; 00181 00183 template <typename A_type> 00184 void TRMV(EUplo uplo, ETransp trans, EDiag diag, const OrdinalType n, const A_type* A, 00185 const OrdinalType lda, ScalarType* x, const OrdinalType incx) const; 00186 00189 template <typename alpha_type, typename x_type, typename y_type> 00190 void GER(const OrdinalType m, const OrdinalType n, const alpha_type alpha, const x_type* x, const OrdinalType incx, 00191 const y_type* y, const OrdinalType incy, ScalarType* A, const OrdinalType lda) const; 00193 00195 00196 00203 template <typename alpha_type, typename A_type, typename B_type, typename beta_type> 00204 void GEMM(ETransp transa, ETransp transb, const OrdinalType m, const OrdinalType n, const OrdinalType k, const alpha_type alpha, const A_type* A, const OrdinalType lda, const B_type* B, const OrdinalType ldb, const beta_type beta, ScalarType* C, const OrdinalType ldc) const; 00205 00207 template <typename alpha_type, typename A_type, typename B_type, typename beta_type> 00208 void SYMM(ESide side, EUplo uplo, const OrdinalType m, const OrdinalType n, const alpha_type alpha, 
const A_type* A, const OrdinalType lda, const B_type* B, const OrdinalType ldb, const beta_type beta, ScalarType* C, const OrdinalType ldc) const; 00209 00211 template <typename alpha_type, typename A_type, typename beta_type> 00212 void SYRK(EUplo uplo, ETransp trans, const OrdinalType n, const OrdinalType k, const alpha_type alpha, const A_type* A, const OrdinalType lda, const beta_type beta, ScalarType* C, const OrdinalType ldc) const; 00213 00215 template <typename alpha_type, typename A_type> 00216 void TRMM(ESide side, EUplo uplo, ETransp transa, EDiag diag, const OrdinalType m, const OrdinalType n, 00217 const alpha_type alpha, const A_type* A, const OrdinalType lda, ScalarType* B, const OrdinalType ldb) const; 00218 00220 template <typename alpha_type, typename A_type> 00221 void TRSM(ESide side, EUplo uplo, ETransp transa, EDiag diag, const OrdinalType m, const OrdinalType n, 00222 const alpha_type alpha, const A_type* A, const OrdinalType lda, ScalarType* B, const OrdinalType ldb) const; 00224 }; 00225 00226 template<typename OrdinalType, typename ScalarType> 00227 class TEUCHOSNUMERICS_LIB_DLL_EXPORT BLAS : public DefaultBLASImpl<OrdinalType,ScalarType> 00228 { 00229 00230 typedef typename Teuchos::ScalarTraits<ScalarType>::magnitudeType MagnitudeType; 00231 00232 public: 00234 00235 00237 inline BLAS(void) {} 00238 00240 inline BLAS(const BLAS<OrdinalType, ScalarType>& /*BLAS_source*/) {} 00241 00243 inline virtual ~BLAS(void) {} 00245 }; 00246 00247 //------------------------------------------------------------------------------------------ 00248 // LEVEL 1 BLAS ROUTINES 00249 //------------------------------------------------------------------------------------------ 00250 00258 namespace details { 00259 00260 // Compute magnitude. 
00261 template<typename ScalarType, bool isComplex> 00262 class MagValue { 00263 public: 00264 void 00265 blas_dabs1(const ScalarType* a, typename ScalarTraits<ScalarType>::magnitudeType* ret) const; 00266 }; 00267 00268 // Complex-arithmetic specialization. 00269 template<typename ScalarType> 00270 class MagValue<ScalarType, true> { 00271 public: 00272 void 00273 blas_dabs1(const ScalarType* a, typename ScalarTraits<ScalarType>::magnitudeType* ret) const; 00274 }; 00275 00276 // Real-arithmetic specialization. 00277 template<typename ScalarType> 00278 class MagValue<ScalarType, false> { 00279 public: 00280 void 00281 blas_dabs1(const ScalarType* a, ScalarType* ret) const; 00282 }; 00283 00284 template<typename ScalarType, bool isComplex> 00285 class GivensRotator { 00286 public: 00287 void 00288 ROTG (ScalarType* a, 00289 ScalarType* b, 00290 typename ScalarTraits<ScalarType>::magnitudeType* c, 00291 ScalarType* s) const; 00292 }; 00293 00294 // Complex-arithmetic specialization. 00295 template<typename ScalarType> 00296 class GivensRotator<ScalarType, true> { 00297 public: 00298 void 00299 ROTG (ScalarType* ca, 00300 ScalarType* cb, 00301 typename ScalarTraits<ScalarType>::magnitudeType* c, 00302 ScalarType* s) const; 00303 }; 00304 00305 // Real-arithmetic specialization. 00306 template<typename ScalarType> 00307 class GivensRotator<ScalarType, false> { 00308 public: 00309 void 00310 ROTG (ScalarType* da, 00311 ScalarType* db, 00312 ScalarType* c, 00313 ScalarType* s) const; 00314 private: 00327 ScalarType SIGN (ScalarType x, ScalarType y) const { 00328 typedef ScalarTraits<ScalarType> STS; 00329 00330 if (y > STS::zero()) { 00331 return STS::magnitude (x); 00332 } else if (y < STS::zero()) { 00333 return -STS::magnitude (x); 00334 } else { // y == STS::zero() 00335 // Suppose that ScalarType implements signed zero, as IEEE 00336 // 754 - compliant floating-point numbers should. You can't 00337 // use == to test for signed zero, since +0 == -0. 
However, 00338 // 1/0 = Inf > 0 and 1/-0 = -Inf < 0. Let's hope ScalarType 00339 // supports Inf... we don't need to test for Inf, just see 00340 // if it's greater than or less than zero. 00341 // 00342 // NOTE: This ONLY works if ScalarType is real. Complex 00343 // infinity doesn't have a sign, so we can't compare it with 00344 // zero. That's OK, because finite complex numbers don't 00345 // have a sign either; they have an angle. 00346 ScalarType signedInfinity = STS::one() / y; 00347 if (signedInfinity > STS::zero()) { 00348 return STS::magnitude (x); 00349 } else { 00350 // Even if ScalarType doesn't implement signed zero, 00351 // Fortran's SIGN intrinsic returns -ABS(X) if the second 00352 // argument Y is zero. We imitate this behavior here. 00353 return -STS::magnitude (x); 00354 } 00355 } 00356 } 00357 }; 00358 00359 // Implementation of complex-arithmetic specialization. 00360 template<typename ScalarType> 00361 void 00362 GivensRotator<ScalarType, true>:: 00363 ROTG (ScalarType* ca, 00364 ScalarType* cb, 00365 typename ScalarTraits<ScalarType>::magnitudeType* c, 00366 ScalarType* s) const 00367 { 00368 typedef ScalarTraits<ScalarType> STS; 00369 typedef typename STS::magnitudeType MagnitudeType; 00370 typedef ScalarTraits<MagnitudeType> STM; 00371 00372 // This is a straightforward translation into C++ of the 00373 // reference BLAS' implementation of ZROTG. 
You can get 00374 // the Fortran 77 source code of ZROTG here: 00375 // 00376 // http://www.netlib.org/blas/zrotg.f 00377 // 00378 // I used the following rules to translate Fortran types and 00379 // intrinsic functions into C++: 00380 // 00381 // DOUBLE PRECISION -> MagnitudeType 00382 // DOUBLE COMPLEX -> ScalarType 00383 // CDABS -> STS::magnitude 00384 // DCMPLX -> ScalarType constructor (assuming that ScalarType 00385 // is std::complex<MagnitudeType>) 00386 // DCONJG -> STS::conjugate 00387 // DSQRT -> STM::squareroot 00388 ScalarType alpha; 00389 MagnitudeType norm, scale; 00390 00391 if (STS::magnitude (*ca) == STM::zero()) { 00392 *c = STM::zero(); 00393 *s = STS::one(); 00394 *ca = *cb; 00395 } else { 00396 scale = STS::magnitude (*ca) + STS::magnitude (*cb); 00397 { // I introduced temporaries into the translated BLAS code in 00398 // order to make the expression easier to read and also save a 00399 // few floating-point operations. 00400 const MagnitudeType ca_scaled = 00401 STS::magnitude (*ca / ScalarType(scale, STM::zero())); 00402 const MagnitudeType cb_scaled = 00403 STS::magnitude (*cb / ScalarType(scale, STM::zero())); 00404 norm = scale * 00405 STM::squareroot (ca_scaled*ca_scaled + cb_scaled*cb_scaled); 00406 } 00407 alpha = *ca / STS::magnitude (*ca); 00408 *c = STS::magnitude (*ca) / norm; 00409 *s = alpha * STS::conjugate (*cb) / norm; 00410 *ca = alpha * norm; 00411 } 00412 } 00413 00414 // Implementation of real-arithmetic specialization. 00415 template<typename ScalarType> 00416 void 00417 GivensRotator<ScalarType, false>:: 00418 ROTG (ScalarType* da, 00419 ScalarType* db, 00420 ScalarType* c, 00421 ScalarType* s) const 00422 { 00423 typedef ScalarTraits<ScalarType> STS; 00424 00425 // This is a straightforward translation into C++ of the 00426 // reference BLAS' implementation of DROTG. 
You can get 00427 // the Fortran 77 source code of DROTG here: 00428 // 00429 // http://www.netlib.org/blas/drotg.f 00430 // 00431 // I used the following rules to translate Fortran types and 00432 // intrinsic functions into C++: 00433 // 00434 // DOUBLE PRECISION -> ScalarType 00435 // DABS -> STS::magnitude 00436 // DSQRT -> STM::squareroot 00437 // DSIGN -> SIGN (see below) 00438 // 00439 // DSIGN(x,y) (the old DOUBLE PRECISION type-specific form of 00440 // the Fortran type-generic SIGN intrinsic) required special 00441 // translation, which we did in a separate utility function in 00442 // the specializaton of GivensRotator for real arithmetic. 00443 // (ROTG for complex arithmetic doesn't require this function.) 00444 // C99 provides a copysign() math library function, but we are 00445 // not able to rely on the existence of C99 functions here. 00446 ScalarType r, roe, scale, z; 00447 00448 roe = *db; 00449 if (STS::magnitude (*da) > STS::magnitude (*db)) { 00450 roe = *da; 00451 } 00452 scale = STS::magnitude (*da) + STS::magnitude (*db); 00453 if (scale == STS::zero()) { 00454 *c = STS::one(); 00455 *s = STS::zero(); 00456 r = STS::zero(); 00457 z = STS::zero(); 00458 } else { 00459 // I introduced temporaries into the translated BLAS code in 00460 // order to make the expression easier to read and also save 00461 // a few floating-point operations. 
00462 const ScalarType da_scaled = *da / scale; 00463 const ScalarType db_scaled = *db / scale; 00464 r = scale * STS::squareroot (da_scaled*da_scaled + db_scaled*db_scaled); 00465 r = SIGN (STS::one(), roe) * r; 00466 *c = *da / r; 00467 *s = *db / r; 00468 z = STS::one(); 00469 if (STS::magnitude (*da) > STS::magnitude (*db)) { 00470 z = *s; 00471 } 00472 if (STS::magnitude (*db) >= STS::magnitude (*da) && *c != STS::zero()) { 00473 z = STS::one() / *c; 00474 } 00475 } 00476 00477 *da = r; 00478 *db = z; 00479 } 00480 00481 // Real-valued implementation of MagValue 00482 template<typename ScalarType> 00483 void 00484 MagValue<ScalarType, false>:: 00485 blas_dabs1(const ScalarType* a, ScalarType* ret) const 00486 { 00487 *ret = Teuchos::ScalarTraits<ScalarType>::magnitude( *a ); 00488 } 00489 00490 // Complex-valued implementation of MagValue 00491 template<typename ScalarType> 00492 void 00493 MagValue<ScalarType, true>:: 00494 blas_dabs1(const ScalarType* a, typename ScalarTraits<ScalarType>::magnitudeType* ret) const 00495 { 00496 *ret = ScalarTraits<typename ScalarTraits<ScalarType>::magnitudeType>::magnitude(a->real()); 00497 *ret += ScalarTraits<typename ScalarTraits<ScalarType>::magnitudeType>::magnitude(a->imag()); 00498 } 00499 00500 } // namespace details 00501 00502 template<typename OrdinalType, typename ScalarType> 00503 void 00504 DefaultBLASImpl<OrdinalType, ScalarType>:: 00505 ROTG (ScalarType* da, 00506 ScalarType* db, 00507 MagnitudeType* c, 00508 ScalarType* s) const 00509 { 00510 typedef ScalarTraits<ScalarType> STS; 00511 details::GivensRotator<ScalarType, STS::isComplex> rotator; 00512 rotator.ROTG (da, db, c, s); 00513 } 00514 00515 template<typename OrdinalType, typename ScalarType> 00516 void DefaultBLASImpl<OrdinalType,ScalarType>::ROT(const OrdinalType n, ScalarType* dx, const OrdinalType incx, ScalarType* dy, const OrdinalType incy, MagnitudeType* c, ScalarType* s) const 00517 { 00518 OrdinalType izero = 
OrdinalTraits<OrdinalType>::zero(); 00519 ScalarType sconj = Teuchos::ScalarTraits<ScalarType>::conjugate(*s); 00520 if (n <= 0) return; 00521 if (incx==1 && incy==1) { 00522 for(OrdinalType i=0; i<n; ++i) { 00523 ScalarType temp = *c*dx[i] + sconj*dy[i]; 00524 dy[i] = *c*dy[i] - sconj*dx[i]; 00525 dx[i] = temp; 00526 } 00527 } 00528 else { 00529 OrdinalType ix = 0, iy = 0; 00530 if (incx < izero) ix = (-n+1)*incx; 00531 if (incy < izero) iy = (-n+1)*incy; 00532 for(OrdinalType i=0; i<n; ++i) { 00533 ScalarType temp = *c*dx[ix] + sconj*dy[iy]; 00534 dy[iy] = *c*dy[iy] - sconj*dx[ix]; 00535 dx[ix] = temp; 00536 ix += incx; 00537 iy += incy; 00538 } 00539 } 00540 } 00541 00542 template<typename OrdinalType, typename ScalarType> 00543 void DefaultBLASImpl<OrdinalType, ScalarType>::SCAL(const OrdinalType n, const ScalarType alpha, ScalarType* x, const OrdinalType incx) const 00544 { 00545 OrdinalType izero = OrdinalTraits<OrdinalType>::zero(); 00546 OrdinalType ione = OrdinalTraits<OrdinalType>::one(); 00547 OrdinalType i, ix = izero; 00548 00549 if ( n < ione || incx < ione ) 00550 return; 00551 00552 // Scale the vector. 00553 for(i = izero; i < n; i++) 00554 { 00555 x[ix] *= alpha; 00556 ix += incx; 00557 } 00558 } /* end SCAL */ 00559 00560 template<typename OrdinalType, typename ScalarType> 00561 void DefaultBLASImpl<OrdinalType, ScalarType>::COPY(const OrdinalType n, const ScalarType* x, const OrdinalType incx, ScalarType* y, const OrdinalType incy) const 00562 { 00563 OrdinalType izero = OrdinalTraits<OrdinalType>::zero(); 00564 OrdinalType ione = OrdinalTraits<OrdinalType>::one(); 00565 OrdinalType i, ix = izero, iy = izero; 00566 if ( n > izero ) { 00567 // Set the initial indices (ix, iy). 
00568 if (incx < izero) { ix = (-n+ione)*incx; } 00569 if (incy < izero) { iy = (-n+ione)*incy; } 00570 00571 for(i = izero; i < n; i++) 00572 { 00573 y[iy] = x[ix]; 00574 ix += incx; 00575 iy += incy; 00576 } 00577 } 00578 } /* end COPY */ 00579 00580 template<typename OrdinalType, typename ScalarType> 00581 template <typename alpha_type, typename x_type> 00582 void DefaultBLASImpl<OrdinalType, ScalarType>::AXPY(const OrdinalType n, const alpha_type alpha, const x_type* x, const OrdinalType incx, ScalarType* y, const OrdinalType incy) const 00583 { 00584 OrdinalType izero = OrdinalTraits<OrdinalType>::zero(); 00585 OrdinalType ione = OrdinalTraits<OrdinalType>::one(); 00586 OrdinalType i, ix = izero, iy = izero; 00587 if( n > izero && alpha != ScalarTraits<alpha_type>::zero()) 00588 { 00589 // Set the initial indices (ix, iy). 00590 if (incx < izero) { ix = (-n+ione)*incx; } 00591 if (incy < izero) { iy = (-n+ione)*incy; } 00592 00593 for(i = izero; i < n; i++) 00594 { 00595 y[iy] += alpha * x[ix]; 00596 ix += incx; 00597 iy += incy; 00598 } 00599 } 00600 } /* end AXPY */ 00601 00602 template<typename OrdinalType, typename ScalarType> 00603 typename ScalarTraits<ScalarType>::magnitudeType DefaultBLASImpl<OrdinalType, ScalarType>::ASUM(const OrdinalType n, const ScalarType* x, const OrdinalType incx) const 00604 { 00605 OrdinalType izero = OrdinalTraits<OrdinalType>::zero(); 00606 OrdinalType ione = OrdinalTraits<OrdinalType>::one(); 00607 typename ScalarTraits<ScalarType>::magnitudeType temp, result = 00608 ScalarTraits<typename ScalarTraits<ScalarType>::magnitudeType>::zero(); 00609 OrdinalType i, ix = izero; 00610 00611 if ( n < ione || incx < ione ) 00612 return result; 00613 00614 details::MagValue<ScalarType, ScalarTraits<ScalarType>::isComplex> mval; 00615 for (i = izero; i < n; i++) 00616 { 00617 mval.blas_dabs1( &x[ix], &temp ); 00618 result += temp; 00619 ix += incx; 00620 } 00621 00622 return result; 00623 } /* end ASUM */ 00624 00625 template<typename 
OrdinalType, typename ScalarType> 00626 template <typename x_type, typename y_type> 00627 ScalarType DefaultBLASImpl<OrdinalType, ScalarType>::DOT(const OrdinalType n, const x_type* x, const OrdinalType incx, const y_type* y, const OrdinalType incy) const 00628 { 00629 OrdinalType izero = OrdinalTraits<OrdinalType>::zero(); 00630 OrdinalType ione = OrdinalTraits<OrdinalType>::one(); 00631 ScalarType result = ScalarTraits<ScalarType>::zero(); 00632 OrdinalType i, ix = izero, iy = izero; 00633 if( n > izero ) 00634 { 00635 // Set the initial indices (ix,iy). 00636 if (incx < izero) { ix = (-n+ione)*incx; } 00637 if (incy < izero) { iy = (-n+ione)*incy; } 00638 00639 for(i = izero; i < n; i++) 00640 { 00641 result += ScalarTraits<x_type>::conjugate(x[ix]) * y[iy]; 00642 ix += incx; 00643 iy += incy; 00644 } 00645 } 00646 return result; 00647 } /* end DOT */ 00648 00649 template<typename OrdinalType, typename ScalarType> 00650 typename ScalarTraits<ScalarType>::magnitudeType DefaultBLASImpl<OrdinalType, ScalarType>::NRM2(const OrdinalType n, const ScalarType* x, const OrdinalType incx) const 00651 { 00652 OrdinalType izero = OrdinalTraits<OrdinalType>::zero(); 00653 OrdinalType ione = OrdinalTraits<OrdinalType>::one(); 00654 typename ScalarTraits<ScalarType>::magnitudeType result = 00655 ScalarTraits<typename ScalarTraits<ScalarType>::magnitudeType>::zero(); 00656 OrdinalType i, ix = izero; 00657 00658 if ( n < ione || incx < ione ) 00659 return result; 00660 00661 for(i = izero; i < n; i++) 00662 { 00663 result += ScalarTraits<ScalarType>::magnitude(ScalarTraits<ScalarType>::conjugate(x[ix]) * x[ix]); 00664 ix += incx; 00665 } 00666 result = ScalarTraits<typename ScalarTraits<ScalarType>::magnitudeType>::squareroot(result); 00667 return result; 00668 } /* end NRM2 */ 00669 00670 template<typename OrdinalType, typename ScalarType> 00671 OrdinalType DefaultBLASImpl<OrdinalType, ScalarType>::IAMAX(const OrdinalType n, const ScalarType* x, const OrdinalType incx) const 
00672 { 00673 OrdinalType izero = OrdinalTraits<OrdinalType>::zero(); 00674 OrdinalType ione = OrdinalTraits<OrdinalType>::one(); 00675 OrdinalType result = izero, ix = izero, i; 00676 typename ScalarTraits<ScalarType>::magnitudeType curval = 00677 ScalarTraits<typename ScalarTraits<ScalarType>::magnitudeType>::zero(); 00678 typename ScalarTraits<ScalarType>::magnitudeType maxval = 00679 ScalarTraits<typename ScalarTraits<ScalarType>::magnitudeType>::zero(); 00680 00681 if ( n < ione || incx < ione ) 00682 return result; 00683 00684 details::MagValue<ScalarType, ScalarTraits<ScalarType>::isComplex> mval; 00685 00686 mval.blas_dabs1( &x[ix], &maxval ); 00687 ix += incx; 00688 for(i = ione; i < n; i++) 00689 { 00690 mval.blas_dabs1( &x[ix], &curval ); 00691 if(curval > maxval) 00692 { 00693 result = i; 00694 maxval = curval; 00695 } 00696 ix += incx; 00697 } 00698 00699 return result + 1; // the BLAS I?AMAX functions return 1-indexed (Fortran-style) values 00700 } /* end IAMAX */ 00701 00702 //------------------------------------------------------------------------------------------ 00703 // LEVEL 2 BLAS ROUTINES 00704 //------------------------------------------------------------------------------------------ 00705 template<typename OrdinalType, typename ScalarType> 00706 template <typename alpha_type, typename A_type, typename x_type, typename beta_type> 00707 void DefaultBLASImpl<OrdinalType, ScalarType>::GEMV(ETransp trans, const OrdinalType m, const OrdinalType n, const alpha_type alpha, const A_type* A, const OrdinalType lda, const x_type* x, const OrdinalType incx, const beta_type beta, ScalarType* y, const OrdinalType incy) const 00708 { 00709 OrdinalType izero = OrdinalTraits<OrdinalType>::zero(); 00710 OrdinalType ione = OrdinalTraits<OrdinalType>::one(); 00711 alpha_type alpha_zero = ScalarTraits<alpha_type>::zero(); 00712 beta_type beta_zero = ScalarTraits<beta_type>::zero(); 00713 x_type x_zero = ScalarTraits<x_type>::zero(); 00714 ScalarType y_zero = 
ScalarTraits<ScalarType>::zero(); 00715 beta_type beta_one = ScalarTraits<beta_type>::one(); 00716 bool noConj = true; 00717 bool BadArgument = false; 00718 00719 // Quick return if there is nothing to do! 00720 if( m == izero || n == izero || ( alpha == alpha_zero && beta == beta_one ) ){ return; } 00721 00722 // Otherwise, we need to check the argument list. 00723 if( m < izero ) { 00724 std::cout << "BLAS::GEMV Error: M == " << m << std::endl; 00725 BadArgument = true; 00726 } 00727 if( n < izero ) { 00728 std::cout << "BLAS::GEMV Error: N == " << n << std::endl; 00729 BadArgument = true; 00730 } 00731 if( lda < m ) { 00732 std::cout << "BLAS::GEMV Error: LDA < MAX(1,M)"<< std::endl; 00733 BadArgument = true; 00734 } 00735 if( incx == izero ) { 00736 std::cout << "BLAS::GEMV Error: INCX == 0"<< std::endl; 00737 BadArgument = true; 00738 } 00739 if( incy == izero ) { 00740 std::cout << "BLAS::GEMV Error: INCY == 0"<< std::endl; 00741 BadArgument = true; 00742 } 00743 00744 if(!BadArgument) { 00745 OrdinalType i, j, lenx, leny, ix, iy, jx, jy; 00746 OrdinalType kx = izero, ky = izero; 00747 ScalarType temp; 00748 00749 // Determine the lengths of the vectors x and y. 00750 if(ETranspChar[trans] == 'N') { 00751 lenx = n; 00752 leny = m; 00753 } else { 00754 lenx = m; 00755 leny = n; 00756 } 00757 00758 // Determine if this is a conjugate tranpose 00759 noConj = (ETranspChar[trans] == 'T'); 00760 00761 // Set the starting pointers for the vectors x and y if incx/y < 0. 
00762 if (incx < izero ) { kx = (ione - lenx)*incx; } 00763 if (incy < izero ) { ky = (ione - leny)*incy; } 00764 00765 // Form y = beta*y 00766 ix = kx; iy = ky; 00767 if(beta != beta_one) { 00768 if (incy == ione) { 00769 if (beta == beta_zero) { 00770 for(i = izero; i < leny; i++) { y[i] = y_zero; } 00771 } else { 00772 for(i = izero; i < leny; i++) { y[i] *= beta; } 00773 } 00774 } else { 00775 if (beta == beta_zero) { 00776 for(i = izero; i < leny; i++) { 00777 y[iy] = y_zero; 00778 iy += incy; 00779 } 00780 } else { 00781 for(i = izero; i < leny; i++) { 00782 y[iy] *= beta; 00783 iy += incy; 00784 } 00785 } 00786 } 00787 } 00788 00789 // Return if we don't have to do anything more. 00790 if(alpha == alpha_zero) { return; } 00791 00792 if( ETranspChar[trans] == 'N' ) { 00793 // Form y = alpha*A*y 00794 jx = kx; 00795 if (incy == ione) { 00796 for(j = izero; j < n; j++) { 00797 if (x[jx] != x_zero) { 00798 temp = alpha*x[jx]; 00799 for(i = izero; i < m; i++) { 00800 y[i] += temp*A[j*lda + i]; 00801 } 00802 } 00803 jx += incx; 00804 } 00805 } else { 00806 for(j = izero; j < n; j++) { 00807 if (x[jx] != x_zero) { 00808 temp = alpha*x[jx]; 00809 iy = ky; 00810 for(i = izero; i < m; i++) { 00811 y[iy] += temp*A[j*lda + i]; 00812 iy += incy; 00813 } 00814 } 00815 jx += incx; 00816 } 00817 } 00818 } else { 00819 jy = ky; 00820 if (incx == ione) { 00821 for(j = izero; j < n; j++) { 00822 temp = y_zero; 00823 if ( noConj ) { 00824 for(i = izero; i < m; i++) { 00825 temp += A[j*lda + i]*x[i]; 00826 } 00827 } else { 00828 for(i = izero; i < m; i++) { 00829 temp += ScalarTraits<A_type>::conjugate(A[j*lda + i])*x[i]; 00830 } 00831 } 00832 y[jy] += alpha*temp; 00833 jy += incy; 00834 } 00835 } else { 00836 for(j = izero; j < n; j++) { 00837 temp = y_zero; 00838 ix = kx; 00839 if ( noConj ) { 00840 for (i = izero; i < m; i++) { 00841 temp += A[j*lda + i]*x[ix]; 00842 ix += incx; 00843 } 00844 } else { 00845 for (i = izero; i < m; i++) { 00846 temp += 
ScalarTraits<A_type>::conjugate(A[j*lda + i])*x[ix]; 00847 ix += incx; 00848 } 00849 } 00850 y[jy] += alpha*temp; 00851 jy += incy; 00852 } 00853 } 00854 } 00855 } /* if (!BadArgument) */ 00856 } /* end GEMV */ 00857 00858 template<typename OrdinalType, typename ScalarType> 00859 template <typename A_type> 00860 void DefaultBLASImpl<OrdinalType, ScalarType>::TRMV(EUplo uplo, ETransp trans, EDiag diag, const OrdinalType n, const A_type* A, const OrdinalType lda, ScalarType* x, const OrdinalType incx) const 00861 { 00862 OrdinalType izero = OrdinalTraits<OrdinalType>::zero(); 00863 OrdinalType ione = OrdinalTraits<OrdinalType>::one(); 00864 ScalarType zero = ScalarTraits<ScalarType>::zero(); 00865 bool BadArgument = false; 00866 bool noConj = true; 00867 00868 // Quick return if there is nothing to do! 00869 if( n == izero ){ return; } 00870 00871 // Otherwise, we need to check the argument list. 00872 if( n < izero ) { 00873 std::cout << "BLAS::TRMV Error: N == " << n << std::endl; 00874 BadArgument = true; 00875 } 00876 if( lda < n ) { 00877 std::cout << "BLAS::TRMV Error: LDA < MAX(1,N)"<< std::endl; 00878 BadArgument = true; 00879 } 00880 if( incx == izero ) { 00881 std::cout << "BLAS::TRMV Error: INCX == 0"<< std::endl; 00882 BadArgument = true; 00883 } 00884 00885 if(!BadArgument) { 00886 OrdinalType i, j, ix, jx, kx = izero; 00887 ScalarType temp; 00888 bool noUnit = (EDiagChar[diag] == 'N'); 00889 00890 // Determine if this is a conjugate tranpose 00891 noConj = (ETranspChar[trans] == 'T'); 00892 00893 // Set the starting pointer for the vector x if incx < 0. 
00894 if (incx < izero) { kx = (-n+ione)*incx; } 00895 00896 // Start the operations for a nontransposed triangular matrix 00897 if (ETranspChar[trans] == 'N') { 00898 /* Compute x = A*x */ 00899 if (EUploChar[uplo] == 'U') { 00900 /* A is an upper triangular matrix */ 00901 if (incx == ione) { 00902 for (j=izero; j<n; j++) { 00903 if (x[j] != zero) { 00904 temp = x[j]; 00905 for (i=izero; i<j; i++) { 00906 x[i] += temp*A[j*lda + i]; 00907 } 00908 if ( noUnit ) 00909 x[j] *= A[j*lda + j]; 00910 } 00911 } 00912 } else { 00913 jx = kx; 00914 for (j=izero; j<n; j++) { 00915 if (x[jx] != zero) { 00916 temp = x[jx]; 00917 ix = kx; 00918 for (i=izero; i<j; i++) { 00919 x[ix] += temp*A[j*lda + i]; 00920 ix += incx; 00921 } 00922 if ( noUnit ) 00923 x[jx] *= A[j*lda + j]; 00924 } 00925 jx += incx; 00926 } 00927 } /* if (incx == ione) */ 00928 } else { /* A is a lower triangular matrix */ 00929 if (incx == ione) { 00930 for (j=n-ione; j>-ione; j--) { 00931 if (x[j] != zero) { 00932 temp = x[j]; 00933 for (i=n-ione; i>j; i--) { 00934 x[i] += temp*A[j*lda + i]; 00935 } 00936 if ( noUnit ) 00937 x[j] *= A[j*lda + j]; 00938 } 00939 } 00940 } else { 00941 kx += (n-ione)*incx; 00942 jx = kx; 00943 for (j=n-ione; j>-ione; j--) { 00944 if (x[jx] != zero) { 00945 temp = x[jx]; 00946 ix = kx; 00947 for (i=n-ione; i>j; i--) { 00948 x[ix] += temp*A[j*lda + i]; 00949 ix -= incx; 00950 } 00951 if ( noUnit ) 00952 x[jx] *= A[j*lda + j]; 00953 } 00954 jx -= incx; 00955 } 00956 } 00957 } /* if (EUploChar[uplo]=='U') */ 00958 } else { /* A is transposed/conjugated */ 00959 /* Compute x = A'*x */ 00960 if (EUploChar[uplo]=='U') { 00961 /* A is an upper triangular matrix */ 00962 if (incx == ione) { 00963 for (j=n-ione; j>-ione; j--) { 00964 temp = x[j]; 00965 if ( noConj ) { 00966 if ( noUnit ) 00967 temp *= A[j*lda + j]; 00968 for (i=j-ione; i>-ione; i--) { 00969 temp += A[j*lda + i]*x[i]; 00970 } 00971 } else { 00972 if ( noUnit ) 00973 temp *= ScalarTraits<A_type>::conjugate(A[j*lda + j]); 
00974 for (i=j-ione; i>-ione; i--) { 00975 temp += ScalarTraits<A_type>::conjugate(A[j*lda + i])*x[i]; 00976 } 00977 } 00978 x[j] = temp; 00979 } 00980 } else { 00981 jx = kx + (n-ione)*incx; 00982 for (j=n-ione; j>-ione; j--) { 00983 temp = x[jx]; 00984 ix = jx; 00985 if ( noConj ) { 00986 if ( noUnit ) 00987 temp *= A[j*lda + j]; 00988 for (i=j-ione; i>-ione; i--) { 00989 ix -= incx; 00990 temp += A[j*lda + i]*x[ix]; 00991 } 00992 } else { 00993 if ( noUnit ) 00994 temp *= ScalarTraits<A_type>::conjugate(A[j*lda + j]); 00995 for (i=j-ione; i>-ione; i--) { 00996 ix -= incx; 00997 temp += ScalarTraits<A_type>::conjugate(A[j*lda + i])*x[ix]; 00998 } 00999 } 01000 x[jx] = temp; 01001 jx -= incx; 01002 } 01003 } 01004 } else { 01005 /* A is a lower triangular matrix */ 01006 if (incx == ione) { 01007 for (j=izero; j<n; j++) { 01008 temp = x[j]; 01009 if ( noConj ) { 01010 if ( noUnit ) 01011 temp *= A[j*lda + j]; 01012 for (i=j+ione; i<n; i++) { 01013 temp += A[j*lda + i]*x[i]; 01014 } 01015 } else { 01016 if ( noUnit ) 01017 temp *= ScalarTraits<A_type>::conjugate(A[j*lda + j]); 01018 for (i=j+ione; i<n; i++) { 01019 temp += ScalarTraits<A_type>::conjugate(A[j*lda + i])*x[i]; 01020 } 01021 } 01022 x[j] = temp; 01023 } 01024 } else { 01025 jx = kx; 01026 for (j=izero; j<n; j++) { 01027 temp = x[jx]; 01028 ix = jx; 01029 if ( noConj ) { 01030 if ( noUnit ) 01031 temp *= A[j*lda + j]; 01032 for (i=j+ione; i<n; i++) { 01033 ix += incx; 01034 temp += A[j*lda + i]*x[ix]; 01035 } 01036 } else { 01037 if ( noUnit ) 01038 temp *= ScalarTraits<A_type>::conjugate(A[j*lda + j]); 01039 for (i=j+ione; i<n; i++) { 01040 ix += incx; 01041 temp += ScalarTraits<A_type>::conjugate(A[j*lda + i])*x[ix]; 01042 } 01043 } 01044 x[jx] = temp; 01045 jx += incx; 01046 } 01047 } 01048 } /* if (EUploChar[uplo]=='U') */ 01049 } /* if (ETranspChar[trans]=='N') */ 01050 } /* if (!BadArgument) */ 01051 } /* end TRMV */ 01052 01053 template<typename OrdinalType, typename ScalarType> 01054 template 
<typename alpha_type, typename x_type, typename y_type> 01055 void DefaultBLASImpl<OrdinalType, ScalarType>::GER(const OrdinalType m, const OrdinalType n, const alpha_type alpha, const x_type* x, const OrdinalType incx, const y_type* y, const OrdinalType incy, ScalarType* A, const OrdinalType lda) const 01056 { 01057 OrdinalType izero = OrdinalTraits<OrdinalType>::zero(); 01058 OrdinalType ione = OrdinalTraits<OrdinalType>::one(); 01059 alpha_type alpha_zero = ScalarTraits<alpha_type>::zero(); 01060 y_type y_zero = ScalarTraits<y_type>::zero(); 01061 bool BadArgument = false; 01062 01063 // Quick return if there is nothing to do! 01064 if( m == izero || n == izero || alpha == alpha_zero ){ return; } 01065 01066 // Otherwise, we need to check the argument list. 01067 if( m < izero ) { 01068 std::cout << "BLAS::GER Error: M == " << m << std::endl; 01069 BadArgument = true; 01070 } 01071 if( n < izero ) { 01072 std::cout << "BLAS::GER Error: N == " << n << std::endl; 01073 BadArgument = true; 01074 } 01075 if( lda < m ) { 01076 std::cout << "BLAS::GER Error: LDA < MAX(1,M)"<< std::endl; 01077 BadArgument = true; 01078 } 01079 if( incx == 0 ) { 01080 std::cout << "BLAS::GER Error: INCX == 0"<< std::endl; 01081 BadArgument = true; 01082 } 01083 if( incy == 0 ) { 01084 std::cout << "BLAS::GER Error: INCY == 0"<< std::endl; 01085 BadArgument = true; 01086 } 01087 01088 if(!BadArgument) { 01089 OrdinalType i, j, ix, jy = izero, kx = izero; 01090 ScalarType temp; 01091 01092 // Set the starting pointers for the vectors x and y if incx/y < 0. 
01093 if (incx < izero) { kx = (-m+ione)*incx; } 01094 if (incy < izero) { jy = (-n+ione)*incy; } 01095 01096 // Start the operations for incx == 1 01097 if( incx == ione ) { 01098 for( j=izero; j<n; j++ ) { 01099 if ( y[jy] != y_zero ) { 01100 temp = alpha*y[jy]; 01101 for ( i=izero; i<m; i++ ) { 01102 A[j*lda + i] += x[i]*temp; 01103 } 01104 } 01105 jy += incy; 01106 } 01107 } 01108 else { // Start the operations for incx != 1 01109 for( j=izero; j<n; j++ ) { 01110 if ( y[jy] != y_zero ) { 01111 temp = alpha*y[jy]; 01112 ix = kx; 01113 for( i=izero; i<m; i++ ) { 01114 A[j*lda + i] += x[ix]*temp; 01115 ix += incx; 01116 } 01117 } 01118 jy += incy; 01119 } 01120 } 01121 } /* if(!BadArgument) */ 01122 } /* end GER */ 01123 01124 //------------------------------------------------------------------------------------------ 01125 // LEVEL 3 BLAS ROUTINES 01126 //------------------------------------------------------------------------------------------ 01127 01128 template<typename OrdinalType, typename ScalarType> 01129 template <typename alpha_type, typename A_type, typename B_type, typename beta_type> 01130 void DefaultBLASImpl<OrdinalType, ScalarType>::GEMM(ETransp transa, ETransp transb, const OrdinalType m, const OrdinalType n, const OrdinalType k, const alpha_type alpha, const A_type* A, const OrdinalType lda, const B_type* B, const OrdinalType ldb, const beta_type beta, ScalarType* C, const OrdinalType ldc) const 01131 { 01132 01133 OrdinalType izero = OrdinalTraits<OrdinalType>::zero(); 01134 alpha_type alpha_zero = ScalarTraits<alpha_type>::zero(); 01135 beta_type beta_zero = ScalarTraits<beta_type>::zero(); 01136 B_type B_zero = ScalarTraits<B_type>::zero(); 01137 ScalarType C_zero = ScalarTraits<ScalarType>::zero(); 01138 beta_type beta_one = ScalarTraits<beta_type>::one(); 01139 OrdinalType i, j, p; 01140 OrdinalType NRowA = m, NRowB = k; 01141 ScalarType temp; 01142 bool BadArgument = false; 01143 bool conjA = false, conjB = false; 01144 01145 // Change 
// dimensions of matrix if either matrix is transposed
  // (NRowA / NRowB are the stored row counts used for the lda / ldb checks.)
  if( !(ETranspChar[transa]=='N') ) {
    NRowA = k;
  }
  if( !(ETranspChar[transb]=='N') ) {
    NRowB = n;
  }

  // Quick return if there is nothing to do!
  // C is left untouched when it is empty, or when the product term vanishes
  // (alpha == 0 or k == 0) and beta == 1.
  if( (m==izero) || (n==izero) || (((alpha==alpha_zero)||(k==izero)) && (beta==beta_one)) ){ return; }
  // Argument checks: each failure is reported on std::cout and suppresses
  // the computation below.
  if( m < izero ) {
    std::cout << "BLAS::GEMM Error: M == " << m << std::endl;
    BadArgument = true;
  }
  if( n < izero ) {
    std::cout << "BLAS::GEMM Error: N == " << n << std::endl;
    BadArgument = true;
  }
  if( k < izero ) {
    std::cout << "BLAS::GEMM Error: K == " << k << std::endl;
    BadArgument = true;
  }
  if( lda < NRowA ) {
    std::cout << "BLAS::GEMM Error: LDA < "<<NRowA<<std::endl;
    BadArgument = true;
  }
  if( ldb < NRowB ) {
    std::cout << "BLAS::GEMM Error: LDB < "<<NRowB<<std::endl;
    BadArgument = true;
  }
  if( ldc < m ) {
    std::cout << "BLAS::GEMM Error: LDC < MAX(1,M)"<< std::endl;
    BadArgument = true;
  }

  if(!BadArgument) {

    // Determine if this is a conjugate transpose
    conjA = (ETranspChar[transa] == 'C');
    conjB = (ETranspChar[transb] == 'C');

    // Only need to scale the resulting matrix C.
    if( alpha == alpha_zero ) {
      if( beta == beta_zero ) {
        // beta == 0 assigns zero instead of multiplying, so the previous
        // contents of C are never read.
        for (j=izero; j<n; j++) {
          for (i=izero; i<m; i++) {
            C[j*ldc + i] = C_zero;
          }
        }
      } else {
        for (j=izero; j<n; j++) {
          for (i=izero; i<m; i++) {
            C[j*ldc + i] *= beta;
          }
        }
      }
      return;
    }
    //
    // Now start the operations.
    //
    if ( ETranspChar[transb]=='N' ) {
      if ( ETranspChar[transa]=='N' ) {
        // Compute C = alpha*A*B + beta*C
        // Column j of C is first scaled by beta, then receives one scaled
        // column of A per non-zero entry of column j of B.
        for (j=izero; j<n; j++) {
          if( beta == beta_zero ) {
            for (i=izero; i<m; i++){
              C[j*ldc + i] = C_zero;
            }
          } else if( beta != beta_one ) {
            for (i=izero; i<m; i++){
              C[j*ldc + i] *= beta;
            }
          }
          for (p=izero; p<k; p++){
            if (B[j*ldb + p] != B_zero ){
              temp = alpha*B[j*ldb + p];
              for (i=izero; i<m; i++) {
                C[j*ldc + i] += temp*A[p*lda + i];
              }
            }
          }
        }
      } else if ( conjA ) {
        // Compute C = alpha*conj(A')*B + beta*C
        // Each entry of C is an inner product accumulated in temp.
        for (j=izero; j<n; j++) {
          for (i=izero; i<m; i++) {
            temp = C_zero;
            for (p=izero; p<k; p++) {
              temp += ScalarTraits<A_type>::conjugate(A[i*lda + p])*B[j*ldb + p];
            }
            if (beta == beta_zero) {
              C[j*ldc + i] = alpha*temp;
            } else {
              C[j*ldc + i] = alpha*temp + beta*C[j*ldc + i];
            }
          }
        }
      } else {
        // Compute C = alpha*A'*B + beta*C
        for (j=izero; j<n; j++) {
          for (i=izero; i<m; i++) {
            temp = C_zero;
            for (p=izero; p<k; p++) {
              temp += A[i*lda + p]*B[j*ldb + p];
            }
            if (beta == beta_zero) {
              C[j*ldc + i] = alpha*temp;
            } else {
              C[j*ldc + i] = alpha*temp + beta*C[j*ldc + i];
            }
          }
        }
      }
    } else if ( ETranspChar[transa]=='N' ) {
      if ( conjB ) {
        // Compute C = alpha*A*conj(B') + beta*C
        for (j=izero; j<n; j++) {
          if (beta == beta_zero) {
            for (i=izero; i<m; i++) {
              C[j*ldc + i] = C_zero;
            }
          } else if ( beta != beta_one ) {
            for (i=izero; i<m; i++) {
              C[j*ldc + i] *= beta;
            }
          }
          for (p=izero; p<k; p++) {
            if (B[p*ldb + j] != B_zero) {
              temp = alpha*ScalarTraits<B_type>::conjugate(B[p*ldb + j]);
              for (i=izero; i<m; i++) {
                C[j*ldc + i] += temp*A[p*lda + i];
              }
            }
          }
        }
      } else {
        // Compute C = alpha*A*B' + beta*C
        for (j=izero; j<n; j++) {
          if (beta == beta_zero) {
            for (i=izero; i<m; i++) {
              C[j*ldc + i] = C_zero;
            }
          } else if ( beta != beta_one ) {
            for (i=izero; i<m; i++) {
              C[j*ldc + i] *= beta;
            }
          }
          for (p=izero; p<k; p++) {
            if (B[p*ldb + j] != B_zero) {
              temp = alpha*B[p*ldb + j];
              for (i=izero; i<m; i++) {
                C[j*ldc + i] += temp*A[p*lda + i];
              }
            }
          }
        }
      }
    } else if ( conjA ) {
      if ( conjB ) {
        // Compute C = alpha*conj(A')*conj(B') + beta*C
        for (j=izero; j<n; j++) {
          for (i=izero; i<m; i++) {
            temp = C_zero;
            for (p=izero; p<k; p++) {
              temp += ScalarTraits<A_type>::conjugate(A[i*lda + p])
                    * ScalarTraits<B_type>::conjugate(B[p*ldb + j]);
            }
            if (beta == beta_zero) {
              C[j*ldc + i] = alpha*temp;
            } else {
              C[j*ldc + i] = alpha*temp + beta*C[j*ldc + i];
            }
          }
        }
      } else {
        // Compute C = alpha*conj(A')*B' + beta*C
        for (j=izero; j<n; j++) {
          for (i=izero; i<m; i++) {
            temp = C_zero;
            for (p=izero; p<k; p++) {
              temp += ScalarTraits<A_type>::conjugate(A[i*lda + p])
                    * B[p*ldb + j];
            }
            if (beta == beta_zero) {
              C[j*ldc + i] = alpha*temp;
            } else {
              C[j*ldc + i] = alpha*temp + beta*C[j*ldc + i];
            }
          }
        }
      }
    } else {
      if ( conjB ) {
        // Compute C = alpha*A'*conj(B') + beta*C
        for (j=izero; j<n; j++) {
          for (i=izero; i<m; i++) {
            temp = C_zero;
            for (p=izero; p<k; p++) {
              temp += A[i*lda + p]
                    * ScalarTraits<B_type>::conjugate(B[p*ldb + j]);
            }
            if (beta == beta_zero) {
              C[j*ldc + i] = alpha*temp;
            } else {
              C[j*ldc + i] = alpha*temp + beta*C[j*ldc + i];
            }
          }
        }
      } else {
        // Compute C = alpha*A'*B' + beta*C
        for (j=izero; j<n; j++) {
          for (i=izero; i<m; i++) {
            temp = C_zero;
            for (p=izero; p<k; p++) {
              temp += A[i*lda + p]*B[p*ldb + j];
            }
            if (beta == beta_zero) {
              C[j*ldc + i] = alpha*temp;
            } else {
              C[j*ldc + i] = alpha*temp + beta*C[j*ldc + i];
            }
          }
        }
      } // end if (ETranspChar[transa]=='N') ...
    } // end if (ETranspChar[transb]=='N') ...
  } // end if (!BadArgument) ...
} // end of GEMM


// SYMM: C := alpha*A*B + beta*C when ESideChar[side] == 'L', otherwise
// C := alpha*B*A + beta*C, with C of size m-by-n.  A is treated as symmetric:
// only its upper triangle (EUploChar[uplo] == 'U') or its lower triangle is
// read.  No conjugation is applied.
template<typename OrdinalType, typename ScalarType>
template <typename alpha_type, typename A_type, typename B_type, typename beta_type>
void DefaultBLASImpl<OrdinalType, ScalarType>::SYMM(ESide side, EUplo uplo, const OrdinalType m, const OrdinalType n, const alpha_type alpha, const A_type* A, const OrdinalType lda, const B_type* B, const OrdinalType ldb, const beta_type beta, ScalarType* C, const OrdinalType ldc) const
{
  OrdinalType izero = OrdinalTraits<OrdinalType>::zero();
  OrdinalType ione = OrdinalTraits<OrdinalType>::one();
  alpha_type alpha_zero = ScalarTraits<alpha_type>::zero();
  beta_type beta_zero = ScalarTraits<beta_type>::zero();
  ScalarType C_zero = ScalarTraits<ScalarType>::zero();
  beta_type beta_one = ScalarTraits<beta_type>::one();
  OrdinalType i, j, k, NRowA = m;
  ScalarType temp1, temp2;
  bool BadArgument = false;
  bool Upper = (EUploChar[uplo] == 'U');
  // A is n-by-n when it multiplies from the right.
  if (ESideChar[side] == 'R') { NRowA = n; }

  // Quick return.
// (SYMM continues; its signature and local declarations precede this span.)
  // C is left untouched when it is empty, or when alpha == 0 and beta == 1.
  if ( (m==izero) || (n==izero) || ( (alpha==alpha_zero)&&(beta==beta_one) ) ) { return; }
  // Argument checks: each failure is reported on std::cout and suppresses
  // the computation below.
  if( m < izero ) {
    std::cout << "BLAS::SYMM Error: M == "<< m << std::endl;
    BadArgument = true; }
  if( n < izero ) {
    std::cout << "BLAS::SYMM Error: N == "<< n << std::endl;
    BadArgument = true; }
  if( lda < NRowA ) {
    std::cout << "BLAS::SYMM Error: LDA < "<<NRowA<<std::endl;
    BadArgument = true; }
  if( ldb < m ) {
    std::cout << "BLAS::SYMM Error: LDB < MAX(1,M)"<<std::endl;
    BadArgument = true; }
  if( ldc < m ) {
    std::cout << "BLAS::SYMM Error: LDC < MAX(1,M)"<<std::endl;
    BadArgument = true; }

  if(!BadArgument) {

    // Only need to scale C and return.
    if (alpha == alpha_zero) {
      if (beta == beta_zero ) {
        for (j=izero; j<n; j++) {
          for (i=izero; i<m; i++) {
            C[j*ldc + i] = C_zero;
          }
        }
      } else {
        for (j=izero; j<n; j++) {
          for (i=izero; i<m; i++) {
            C[j*ldc + i] *= beta;
          }
        }
      }
      return;
    }

    if ( ESideChar[side] == 'L') {
      // Compute C = alpha*A*B + beta*C

      if (Upper) {
        // The symmetric part of A is stored in the upper triangular part of the matrix.
        // For each (i, j): the stored entries A[i*lda + k], k < i, are used
        // twice -- once to update the already-finalised rows k < i of
        // column j of C, and once (via temp2) for row i itself.  Rows are
        // visited in increasing order so that C[j*ldc + k] has been scaled
        // by beta before it is accumulated into.
        for (j=izero; j<n; j++) {
          for (i=izero; i<m; i++) {
            temp1 = alpha*B[j*ldb + i];
            temp2 = C_zero;
            for (k=izero; k<i; k++) {
              C[j*ldc + k] += temp1*A[i*lda + k];
              temp2 += B[j*ldb + k]*A[i*lda + k];
            }
            if (beta == beta_zero) {
              C[j*ldc + i] = temp1*A[i*lda + i] + alpha*temp2;
            } else {
              C[j*ldc + i] = beta*C[j*ldc + i] + temp1*A[i*lda + i] + alpha*temp2;
            }
          }
        }
      } else {
        // The symmetric part of A is stored in the lower triangular part of the matrix.
        // Mirror image of the upper case: rows are visited in decreasing
        // order and the stored entries k > i are used.
        for (j=izero; j<n; j++) {
          for (i=m-ione; i>-ione; i--) {
            temp1 = alpha*B[j*ldb + i];
            temp2 = C_zero;
            for (k=i+ione; k<m; k++) {
              C[j*ldc + k] += temp1*A[i*lda + k];
              temp2 += B[j*ldb + k]*A[i*lda + k];
            }
            if (beta == beta_zero) {
              C[j*ldc + i] = temp1*A[i*lda + i] + alpha*temp2;
            } else {
              C[j*ldc + i] = beta*C[j*ldc + i] + temp1*A[i*lda + i] + alpha*temp2;
            }
          }
        }
      }
    } else {
      // Compute C = alpha*B*A + beta*C.
      for (j=izero; j<n; j++) {
        // Diagonal contribution, combined with the beta scaling of column j.
        temp1 = alpha*A[j*lda + j];
        if (beta == beta_zero) {
          for (i=izero; i<m; i++) {
            C[j*ldc + i] = temp1*B[j*ldb + i];
          }
        } else {
          for (i=izero; i<m; i++) {
            C[j*ldc + i] = beta*C[j*ldc + i] + temp1*B[j*ldb + i];
          }
        }
        // Off-diagonal contributions with k < j: pick whichever of the two
        // symmetric positions lies in the stored triangle.
        for (k=izero; k<j; k++) {
          if (Upper) {
            temp1 = alpha*A[j*lda + k];
          } else {
            temp1 = alpha*A[k*lda + j];
          }
          for (i=izero; i<m; i++) {
            C[j*ldc + i] += temp1*B[k*ldb + i];
          }
        }
        // Off-diagonal contributions with k > j.
        for (k=j+ione; k<n; k++) {
          if (Upper) {
            temp1 = alpha*A[k*lda + j];
          } else {
            temp1 = alpha*A[j*lda + k];
          }
          for (i=izero; i<m; i++) {
            C[j*ldc + i] += temp1*B[k*ldb + i];
          }
        }
      }
    } // end if (ESideChar[side]=='L') ...
  } // end if(!BadArgument) ...
} // end SYMM

// SYRK: symmetric rank-k update of the n-by-n matrix C.
//   ETranspChar[trans] == 'N':  C := alpha*A*A' + beta*C
//   otherwise:                  C := alpha*A'*A + beta*C
// Only the upper triangle of C (EUploChar[uplo] == 'U') or its lower
// triangle is read and written.  For a complex ScalarType, trans ==
// CONJ_TRANS is rejected with std::logic_error through
// TEUCHOS_TEST_FOR_EXCEPTION.
template<typename OrdinalType, typename ScalarType>
template <typename alpha_type, typename A_type, typename beta_type>
void DefaultBLASImpl<OrdinalType, ScalarType>::SYRK(EUplo uplo, ETransp trans, const OrdinalType n, const OrdinalType k, const alpha_type alpha, const A_type* A, const OrdinalType lda, const beta_type beta, ScalarType* C, const OrdinalType ldc) const
{
  typedef TypeNameTraits<OrdinalType> OTNT;
  typedef TypeNameTraits<ScalarType> STNT;

  OrdinalType izero = OrdinalTraits<OrdinalType>::zero();
  alpha_type alpha_zero = ScalarTraits<alpha_type>::zero();
  beta_type beta_zero = ScalarTraits<beta_type>::zero();
  beta_type beta_one = ScalarTraits<beta_type>::one();
  A_type temp, A_zero = ScalarTraits<A_type>::zero();
  ScalarType C_zero = ScalarTraits<ScalarType>::zero();
  OrdinalType i, j, l, NRowA = n;
  bool BadArgument = false;
  bool Upper = (EUploChar[uplo] == 'U');

  TEUCHOS_TEST_FOR_EXCEPTION(
    Teuchos::ScalarTraits<ScalarType>::isComplex
    && (trans == CONJ_TRANS),
    std::logic_error,
    "Teuchos::BLAS<"<<OTNT::name()<<","<<STNT::name()<<">::SYRK()"
    " does not support CONJ_TRANS for complex data types."
    );

  // A is stored with k rows (instead of n) when it is used transposed.
  if( !(ETranspChar[trans]=='N') ) {
    NRowA = k;
  }

  // Quick return.
// (SYRK continues; its signature and local declarations precede this span.)
  // C is left untouched when it is empty, or when the update term vanishes
  // (alpha == 0 or k == 0) and beta == 1.
  if ( n==izero ) { return; }
  if ( ( (alpha==alpha_zero) || (k==izero) ) && (beta==beta_one) ) { return; }
  // Argument checks: each failure is reported on std::cout and suppresses
  // the computation below.
  if( n < izero ) {
    std::cout << "BLAS::SYRK Error: N == "<< n <<std::endl;
    BadArgument = true; }
  if( k < izero ) {
    std::cout << "BLAS::SYRK Error: K == "<< k <<std::endl;
    BadArgument = true; }
  if( lda < NRowA ) {
    std::cout << "BLAS::SYRK Error: LDA < "<<NRowA<<std::endl;
    BadArgument = true; }
  if( ldc < n ) {
    std::cout << "BLAS::SYRK Error: LDC < MAX(1,N)"<<std::endl;
    BadArgument = true; }

  if(!BadArgument) {

    // Scale C when alpha is zero
    // (only the selected triangle is touched: rows 0..j of column j for the
    // upper triangle, rows j..n-1 for the lower one).
    if (alpha == alpha_zero) {
      if (Upper) {
        if (beta==beta_zero) {
          for (j=izero; j<n; j++) {
            for (i=izero; i<=j; i++) {
              C[j*ldc + i] = C_zero;
            }
          }
        }
        else {
          for (j=izero; j<n; j++) {
            for (i=izero; i<=j; i++) {
              C[j*ldc + i] *= beta;
            }
          }
        }
      }
      else {
        if (beta==beta_zero) {
          for (j=izero; j<n; j++) {
            for (i=j; i<n; i++) {
              C[j*ldc + i] = C_zero;
            }
          }
        }
        else {
          for (j=izero; j<n; j++) {
            for (i=j; i<n; i++) {
              C[j*ldc + i] *= beta;
            }
          }
        }
      }
      return;
    }

    // Now we can start the computation

    if ( ETranspChar[trans]=='N' ) {

      // Form C <- alpha*A*A' + beta*C
      // Column j of the triangle is scaled by beta and then receives one
      // scaled column of A per non-zero entry A[l*lda + j].
      if (Upper) {
        for (j=izero; j<n; j++) {
          if (beta==beta_zero) {
            for (i=izero; i<=j; i++) {
              C[j*ldc + i] = C_zero;
            }
          }
          else if (beta!=beta_one) {
            for (i=izero; i<=j; i++) {
              C[j*ldc + i] *= beta;
            }
          }
          for (l=izero; l<k; l++) {
            if (A[l*lda + j] != A_zero) {
              temp = alpha*A[l*lda + j];
              for (i = izero; i <=j; i++) {
                C[j*ldc + i] += temp*A[l*lda + i];
              }
            }
          }
        }
      }
      else {
        for (j=izero; j<n; j++) {
          if (beta==beta_zero) {
            for (i=j; i<n; i++) {
              C[j*ldc + i] = C_zero;
            }
          }
          else if (beta!=beta_one) {
            for (i=j; i<n; i++) {
              C[j*ldc + i] *= beta;
            }
          }
          for (l=izero; l<k; l++) {
            if (A[l*lda + j] != A_zero) {
              temp = alpha*A[l*lda + j];
              for (i=j; i<n; i++) {
                C[j*ldc + i] += temp*A[l*lda + i];
              }
            }
          }
        }
      }
    }
    else {

      // Form C <- alpha*A'*A + beta*C
      // Each entry of the triangle is an inner product of two columns of A,
      // accumulated in temp (of type A_type).
      if (Upper) {
        for (j=izero; j<n; j++) {
          for (i=izero; i<=j; i++) {
            temp = A_zero;
            for (l=izero; l<k; l++) {
              temp += A[i*lda + l]*A[j*lda + l];
            }
            if (beta==beta_zero) {
              C[j*ldc + i] = alpha*temp;
            }
            else {
              C[j*ldc + i] = alpha*temp + beta*C[j*ldc + i];
            }
          }
        }
      }
      else {
        for (j=izero; j<n; j++) {
          for (i=j; i<n; i++) {
            temp = A_zero;
            for (l=izero; l<k; ++l) {
              temp += A[i*lda + l]*A[j*lda + l];
            }
            if (beta==beta_zero) {
              C[j*ldc + i] = alpha*temp;
            }
            else {
              C[j*ldc + i] = alpha*temp + beta*C[j*ldc + i];
            }
          }
        }
      }
    }
  } /* if (!BadArgument) */
} /* END SYRK */

// TRMM: in-place triangular matrix product on the m-by-n matrix B.
//   ESideChar[side] == 'L':  B := alpha*op(A)*B
//   otherwise:               B := alpha*B*op(A)
// op(A) is A for ETranspChar[transa] == 'N', A' for 'T', and conj(A')
// otherwise.  Only the triangle of A selected by EUploChar[uplo] is read;
// when EDiagChar[diag] != 'N' the diagonal of A is not read (it is treated
// as all ones).
template<typename OrdinalType, typename ScalarType>
template <typename alpha_type, typename A_type>
void DefaultBLASImpl<OrdinalType, ScalarType>::TRMM(ESide side, EUplo uplo, ETransp transa, EDiag diag, const OrdinalType m, const OrdinalType n, const alpha_type alpha, const A_type* A, const OrdinalType lda, ScalarType* B, const OrdinalType ldb) const
{
  OrdinalType izero = OrdinalTraits<OrdinalType>::zero();
  OrdinalType ione = OrdinalTraits<OrdinalType>::one();
  alpha_type alpha_zero = ScalarTraits<alpha_type>::zero();
  A_type A_zero = ScalarTraits<A_type>::zero();
  ScalarType B_zero = ScalarTraits<ScalarType>::zero();
  ScalarType one = ScalarTraits<ScalarType>::one();
  OrdinalType i, j, k, NRowA = m;
// (TRMM continues; its signature and first local declarations precede this span.)
  ScalarType temp;
  bool BadArgument = false;
  bool LSide = (ESideChar[side] == 'L');
  bool noUnit = (EDiagChar[diag] == 'N');
  bool Upper = (EUploChar[uplo] == 'U');
  // Within the transposed branches: true selects A', false selects conj(A').
  bool noConj = (ETranspChar[transa] == 'T');

  // A is n-by-n when it multiplies from the right.
  if(!LSide) { NRowA = n; }

  // Quick return.
  if (n==izero || m==izero) { return; }
  // Argument checks: each failure is reported on std::cout and suppresses
  // the computation below.
  if( m < izero ) {
    std::cout << "BLAS::TRMM Error: M == "<< m <<std::endl;
    BadArgument = true; }
  if( n < izero ) {
    std::cout << "BLAS::TRMM Error: N == "<< n <<std::endl;
    BadArgument = true; }
  if( lda < NRowA ) {
    std::cout << "BLAS::TRMM Error: LDA < "<<NRowA<<std::endl;
    BadArgument = true; }
  if( ldb < m ) {
    std::cout << "BLAS::TRMM Error: LDB < MAX(1,M)"<<std::endl;
    BadArgument = true; }

  if(!BadArgument) {

    // B only needs to be zeroed out.
    if( alpha == alpha_zero ) {
      for( j=izero; j<n; j++ ) {
        for( i=izero; i<m; i++ ) {
          B[j*ldb + i] = B_zero;
        }
      }
      return;
    }

    // Start the computations.
    if ( LSide ) {
      // A is on the left side of B.

      if ( ETranspChar[transa]=='N' ) {
        // Compute B = alpha*A*B

        if ( Upper ) {
          // A is upper triangular
          // Rows are processed top-down: entry k of column j is still the
          // original value when it is read, because only rows i < k have
          // been overwritten so far.
          for( j=izero; j<n; j++ ) {
            for( k=izero; k<m; k++) {
              if ( B[j*ldb + k] != B_zero ) {
                temp = alpha*B[j*ldb + k];
                for( i=izero; i<k; i++ ) {
                  B[j*ldb + i] += temp*A[k*lda + i];
                }
                if ( noUnit )
                  temp *=A[k*lda + k];
                B[j*ldb + k] = temp;
              }
            }
          }
        } else {
          // A is lower triangular
          // Rows are processed bottom-up for the same reason.
          for( j=izero; j<n; j++ ) {
            for( k=m-ione; k>-ione; k-- ) {
              if( B[j*ldb + k] != B_zero ) {
                temp = alpha*B[j*ldb + k];
                B[j*ldb + k] = temp;
                if ( noUnit )
                  B[j*ldb + k] *= A[k*lda + k];
                for( i=k+ione; i<m; i++ ) {
                  B[j*ldb + i] += temp*A[k*lda + i];
                }
              }
            }
          }
        }
      } else {
        // Compute B = alpha*A'*B or B = alpha*conj(A')*B
        if( Upper ) {
          // Row i of the result needs the original rows k < i, so rows are
          // overwritten from the bottom up.
          for( j=izero; j<n; j++ ) {
            for( i=m-ione; i>-ione; i-- ) {
              temp = B[j*ldb + i];
              if ( noConj ) {
                if( noUnit )
                  temp *= A[i*lda + i];
                for( k=izero; k<i; k++ ) {
                  temp += A[i*lda + k]*B[j*ldb + k];
                }
              } else {
                if( noUnit )
                  temp *= ScalarTraits<A_type>::conjugate(A[i*lda + i]);
                for( k=izero; k<i; k++ ) {
                  temp += ScalarTraits<A_type>::conjugate(A[i*lda + k])*B[j*ldb + k];
                }
              }
              B[j*ldb + i] = alpha*temp;
            }
          }
        } else {
          // Lower triangle: row i needs the original rows k > i, so rows are
          // overwritten from the top down.
          for( j=izero; j<n; j++ ) {
            for( i=izero; i<m; i++ ) {
              temp = B[j*ldb + i];
              if ( noConj ) {
                if( noUnit )
                  temp *= A[i*lda + i];
                for( k=i+ione; k<m; k++ ) {
                  temp += A[i*lda + k]*B[j*ldb + k];
                }
              } else {
                if( noUnit )
                  temp *= ScalarTraits<A_type>::conjugate(A[i*lda + i]);
                for( k=i+ione; k<m; k++ ) {
                  temp += ScalarTraits<A_type>::conjugate(A[i*lda + k])*B[j*ldb + k];
                }
              }
              B[j*ldb + i] = alpha*temp;
            }
          }
        }
      }
    } else
    {
      // A is on the right hand side of B.

      if( ETranspChar[transa] == 'N' ) {
        // Compute B = alpha*B*A

        if( Upper ) {
          // A is upper triangular.
          // Column j of the result uses the original columns k < j, so
          // columns are overwritten from last to first.
          for( j=n-ione; j>-ione; j-- ) {
            temp = alpha;
            if( noUnit )
              temp *= A[j*lda + j];
            for( i=izero; i<m; i++ ) {
              B[j*ldb + i] *= temp;
            }
            for( k=izero; k<j; k++ ) {
              if( A[j*lda + k] != A_zero ) {
                temp = alpha*A[j*lda + k];
                for( i=izero; i<m; i++ ) {
                  B[j*ldb + i] += temp*B[k*ldb + i];
                }
              }
            }
          }
        } else {
          // A is lower triangular.
          // Column j uses the original columns k > j, so columns are
          // overwritten from first to last.
          for( j=izero; j<n; j++ ) {
            temp = alpha;
            if( noUnit )
              temp *= A[j*lda + j];
            for( i=izero; i<m; i++ ) {
              B[j*ldb + i] *= temp;
            }
            for( k=j+ione; k<n; k++ ) {
              if( A[j*lda + k] != A_zero ) {
                temp = alpha*A[j*lda + k];
                for( i=izero; i<m; i++ ) {
                  B[j*ldb + i] += temp*B[k*ldb + i];
                }
              }
            }
          }
        }
      } else {
        // Compute B = alpha*B*A' or B = alpha*B*conj(A')

        if( Upper ) {
          // For each column k: first add its (still original) contents into
          // the columns j < k, then scale column k itself by alpha times the
          // diagonal entry.
          for( k=izero; k<n; k++ ) {
            for( j=izero; j<k; j++ ) {
              if( A[k*lda + j] != A_zero ) {
                if ( noConj )
                  temp = alpha*A[k*lda + j];
                else
                  temp = alpha*ScalarTraits<A_type>::conjugate(A[k*lda + j]);
                for( i=izero; i<m; i++ ) {
                  B[j*ldb + i] += temp*B[k*ldb + i];
                }
              }
            }
            temp = alpha;
            if( noUnit ) {
              if ( noConj )
                temp *= A[k*lda + k];
              else
                temp *= ScalarTraits<A_type>::conjugate(A[k*lda + k]);
            }
            if( temp != one ) {
              for( i=izero; i<m; i++ ) {
                B[k*ldb + i] *= temp;
              }
            }
          }
        } else {
          // Lower triangle: same scheme with columns visited last to first
          // and contributions going into the columns j > k.
          for( k=n-ione; k>-ione; k-- ) {
            for( j=k+ione; j<n; j++ ) {
              if( A[k*lda + j] != A_zero ) {
                if ( noConj )
                  temp = alpha*A[k*lda + j];
                else
                  temp = alpha*ScalarTraits<A_type>::conjugate(A[k*lda + j]);
                for( i=izero; i<m; i++ ) {
                  B[j*ldb + i] += temp*B[k*ldb + i];
                }
              }
            }
            temp = alpha;
            if( noUnit ) {
              if ( noConj )
                temp *= A[k*lda + k];
              else
                temp *= ScalarTraits<A_type>::conjugate(A[k*lda + k]);
            }
            if( temp != one ) {
              for( i=izero; i<m; i++ ) {
                B[k*ldb + i] *= temp;
              }
            }
          }
        }
      } // end if( ETranspChar[transa] == 'N' ) ...
    } // end if ( LSide ) ...
  } // end if (!BadArgument)
} // end TRMM

// TRSM: in-place triangular solve with multiple right-hand sides.  The
// m-by-n matrix B is overwritten with the solution X of
//   op(A)*X = alpha*B   when ESideChar[side] == 'L',
//   X*op(A) = alpha*B   otherwise.
// op(A) is A for ETranspChar[transa] == 'N', A' for 'T', and conj(A')
// otherwise.  Only the triangle of A selected by EUploChar[uplo] is read;
// when EDiagChar[diag] != 'N' the diagonal of A is not read (it is treated
// as all ones).
template<typename OrdinalType, typename ScalarType>
template <typename alpha_type, typename A_type>
void DefaultBLASImpl<OrdinalType, ScalarType>::TRSM(ESide side, EUplo uplo, ETransp transa, EDiag diag, const OrdinalType m, const OrdinalType n, const alpha_type alpha, const A_type* A, const OrdinalType lda, ScalarType* B, const OrdinalType ldb) const
{
  OrdinalType izero = OrdinalTraits<OrdinalType>::zero();
  OrdinalType ione = OrdinalTraits<OrdinalType>::one();
  alpha_type alpha_zero = ScalarTraits<alpha_type>::zero();
  A_type A_zero = ScalarTraits<A_type>::zero();
  ScalarType B_zero = ScalarTraits<ScalarType>::zero();
  alpha_type alpha_one = ScalarTraits<alpha_type>::one();
  ScalarType B_one = ScalarTraits<ScalarType>::one();
  ScalarType temp;
  OrdinalType NRowA = m;
  bool BadArgument = false;
  bool noUnit = (EDiagChar[diag]=='N');
  // Within the transposed branches: true selects A', false selects conj(A').
  bool noConj = (ETranspChar[transa] == 'T');

  // A is n-by-n when it appears on the right.
  if (!(ESideChar[side] == 'L')) { NRowA = n; }

  // Quick return.
01932 if (n == izero || m == izero) { return; } 01933 if( m < izero ) { 01934 std::cout << "BLAS::TRSM Error: M == "<<m<<std::endl; 01935 BadArgument = true; } 01936 if( n < izero ) { 01937 std::cout << "BLAS::TRSM Error: N == "<<n<<std::endl; 01938 BadArgument = true; } 01939 if( lda < NRowA ) { 01940 std::cout << "BLAS::TRSM Error: LDA < "<<NRowA<<std::endl; 01941 BadArgument = true; } 01942 if( ldb < m ) { 01943 std::cout << "BLAS::TRSM Error: LDB < MAX(1,M)"<<std::endl; 01944 BadArgument = true; } 01945 01946 if(!BadArgument) 01947 { 01948 int i, j, k; 01949 // Set the solution to the zero vector. 01950 if(alpha == alpha_zero) { 01951 for(j = izero; j < n; j++) { 01952 for( i = izero; i < m; i++) { 01953 B[j*ldb+i] = B_zero; 01954 } 01955 } 01956 } 01957 else 01958 { // Start the operations. 01959 if(ESideChar[side] == 'L') { 01960 // 01961 // Perform computations for OP(A)*X = alpha*B 01962 // 01963 if(ETranspChar[transa] == 'N') { 01964 // 01965 // Compute B = alpha*inv( A )*B 01966 // 01967 if(EUploChar[uplo] == 'U') { 01968 // A is upper triangular. 01969 for(j = izero; j < n; j++) { 01970 // Perform alpha*B if alpha is not 1. 01971 if(alpha != alpha_one) { 01972 for( i = izero; i < m; i++) { 01973 B[j*ldb+i] *= alpha; 01974 } 01975 } 01976 // Perform a backsolve for column j of B. 01977 for(k = (m - ione); k > -ione; k--) { 01978 // If this entry is zero, we don't have to do anything. 01979 if (B[j*ldb + k] != B_zero) { 01980 if ( noUnit ) { 01981 B[j*ldb + k] /= A[k*lda + k]; 01982 } 01983 for(i = izero; i < k; i++) { 01984 B[j*ldb + i] -= B[j*ldb + k] * A[k*lda + i]; 01985 } 01986 } 01987 } 01988 } 01989 } 01990 else 01991 { // A is lower triangular. 01992 for(j = izero; j < n; j++) { 01993 // Perform alpha*B if alpha is not 1. 01994 if(alpha != alpha_one) { 01995 for( i = izero; i < m; i++) { 01996 B[j*ldb+i] *= alpha; 01997 } 01998 } 01999 // Perform a forward solve for column j of B. 
02000 for(k = izero; k < m; k++) { 02001 // If this entry is zero, we don't have to do anything. 02002 if (B[j*ldb + k] != B_zero) { 02003 if ( noUnit ) { 02004 B[j*ldb + k] /= A[k*lda + k]; 02005 } 02006 for(i = k+ione; i < m; i++) { 02007 B[j*ldb + i] -= B[j*ldb + k] * A[k*lda + i]; 02008 } 02009 } 02010 } 02011 } 02012 } // end if (uplo == 'U') 02013 } // if (transa =='N') 02014 else { 02015 // 02016 // Compute B = alpha*inv( A' )*B 02017 // or B = alpha*inv( conj(A') )*B 02018 // 02019 if(EUploChar[uplo] == 'U') { 02020 // A is upper triangular. 02021 for(j = izero; j < n; j++) { 02022 for( i = izero; i < m; i++) { 02023 temp = alpha*B[j*ldb+i]; 02024 if ( noConj ) { 02025 for(k = izero; k < i; k++) { 02026 temp -= A[i*lda + k] * B[j*ldb + k]; 02027 } 02028 if ( noUnit ) { 02029 temp /= A[i*lda + i]; 02030 } 02031 } else { 02032 for(k = izero; k < i; k++) { 02033 temp -= ScalarTraits<A_type>::conjugate(A[i*lda + k]) 02034 * B[j*ldb + k]; 02035 } 02036 if ( noUnit ) { 02037 temp /= ScalarTraits<A_type>::conjugate(A[i*lda + i]); 02038 } 02039 } 02040 B[j*ldb + i] = temp; 02041 } 02042 } 02043 } 02044 else 02045 { // A is lower triangular. 
02046 for(j = izero; j < n; j++) { 02047 for(i = (m - ione) ; i > -ione; i--) { 02048 temp = alpha*B[j*ldb+i]; 02049 if ( noConj ) { 02050 for(k = i+ione; k < m; k++) { 02051 temp -= A[i*lda + k] * B[j*ldb + k]; 02052 } 02053 if ( noUnit ) { 02054 temp /= A[i*lda + i]; 02055 } 02056 } else { 02057 for(k = i+ione; k < m; k++) { 02058 temp -= ScalarTraits<A_type>::conjugate(A[i*lda + k]) 02059 * B[j*ldb + k]; 02060 } 02061 if ( noUnit ) { 02062 temp /= ScalarTraits<A_type>::conjugate(A[i*lda + i]); 02063 } 02064 } 02065 B[j*ldb + i] = temp; 02066 } 02067 } 02068 } 02069 } 02070 } // if (side == 'L') 02071 else { 02072 // side == 'R' 02073 // 02074 // Perform computations for X*OP(A) = alpha*B 02075 // 02076 if (ETranspChar[transa] == 'N') { 02077 // 02078 // Compute B = alpha*B*inv( A ) 02079 // 02080 if(EUploChar[uplo] == 'U') { 02081 // A is upper triangular. 02082 // Perform a backsolve for column j of B. 02083 for(j = izero; j < n; j++) { 02084 // Perform alpha*B if alpha is not 1. 02085 if(alpha != alpha_one) { 02086 for( i = izero; i < m; i++) { 02087 B[j*ldb+i] *= alpha; 02088 } 02089 } 02090 for(k = izero; k < j; k++) { 02091 // If this entry is zero, we don't have to do anything. 02092 if (A[j*lda + k] != A_zero) { 02093 for(i = izero; i < m; i++) { 02094 B[j*ldb + i] -= A[j*lda + k] * B[k*ldb + i]; 02095 } 02096 } 02097 } 02098 if ( noUnit ) { 02099 temp = B_one/A[j*lda + j]; 02100 for(i = izero; i < m; i++) { 02101 B[j*ldb + i] *= temp; 02102 } 02103 } 02104 } 02105 } 02106 else 02107 { // A is lower triangular. 02108 for(j = (n - ione); j > -ione; j--) { 02109 // Perform alpha*B if alpha is not 1. 02110 if(alpha != alpha_one) { 02111 for( i = izero; i < m; i++) { 02112 B[j*ldb+i] *= alpha; 02113 } 02114 } 02115 // Perform a forward solve for column j of B. 02116 for(k = j+ione; k < n; k++) { 02117 // If this entry is zero, we don't have to do anything. 
02118 if (A[j*lda + k] != A_zero) { 02119 for(i = izero; i < m; i++) { 02120 B[j*ldb + i] -= A[j*lda + k] * B[k*ldb + i]; 02121 } 02122 } 02123 } 02124 if ( noUnit ) { 02125 temp = B_one/A[j*lda + j]; 02126 for(i = izero; i < m; i++) { 02127 B[j*ldb + i] *= temp; 02128 } 02129 } 02130 } 02131 } // end if (uplo == 'U') 02132 } // if (transa =='N') 02133 else { 02134 // 02135 // Compute B = alpha*B*inv( A' ) 02136 // or B = alpha*B*inv( conj(A') ) 02137 // 02138 if(EUploChar[uplo] == 'U') { 02139 // A is upper triangular. 02140 for(k = (n - ione); k > -ione; k--) { 02141 if ( noUnit ) { 02142 if ( noConj ) 02143 temp = B_one/A[k*lda + k]; 02144 else 02145 temp = B_one/ScalarTraits<A_type>::conjugate(A[k*lda + k]); 02146 for(i = izero; i < m; i++) { 02147 B[k*ldb + i] *= temp; 02148 } 02149 } 02150 for(j = izero; j < k; j++) { 02151 if (A[k*lda + j] != A_zero) { 02152 if ( noConj ) 02153 temp = A[k*lda + j]; 02154 else 02155 temp = ScalarTraits<A_type>::conjugate(A[k*lda + j]); 02156 for(i = izero; i < m; i++) { 02157 B[j*ldb + i] -= temp*B[k*ldb + i]; 02158 } 02159 } 02160 } 02161 if (alpha != alpha_one) { 02162 for (i = izero; i < m; i++) { 02163 B[k*ldb + i] *= alpha; 02164 } 02165 } 02166 } 02167 } 02168 else 02169 { // A is lower triangular. 
// Continuation of TRSM (the function starts above this point): right-side
// solve B = alpha*B*inv( A' ) or B = alpha*B*inv( conj(A') ) for the branch
// in which A is lower triangular.
//
// Columns of B are processed in increasing k.  "Column k of B" means the
// entries B[k*ldb + i] for i = izero .. m-1.
for(k = izero; k < n; k++) {
  if ( noUnit ) {
    // Multiply column k of B by the reciprocal of the diagonal entry
    // A[k*lda + k]; the entry is conjugated first when noConj is false.
    if ( noConj )
      temp = B_one/A[k*lda + k];
    else
      temp = B_one/ScalarTraits<A_type>::conjugate(A[k*lda + k]);
    for (i = izero; i < m; i++) {
      B[k*ldb + i] *= temp;
    }
  }
  // Subtract the contribution of column k from every later column j > k.
  for(j = k+ione; j < n; j++) {
    // Entries of A equal to A_zero contribute nothing, so they are skipped.
    if(A[k*lda + j] != A_zero) {
      // temp is reused here for the (possibly conjugated) off-diagonal entry.
      if ( noConj )
        temp = A[k*lda + j];
      else
        temp = ScalarTraits<A_type>::conjugate(A[k*lda + j]);
      for(i = izero; i < m; i++) {
        B[j*ldb + i] -= temp*B[k*ldb + i];
      }
    }
  }
  // alpha is applied to column k last, after that column has been used to
  // update the later columns; the scaling is skipped when alpha equals
  // alpha_one.
  if (alpha != alpha_one) {
    for (i = izero; i < m; i++) {
      B[k*ldb + i] *= alpha;
    }
  }
}
} // end of the "A is lower triangular" branch
} // end of the transposed / conjugate-transposed branch (ETranspChar[transa] != 'N')
} // end of the side == 'R' branch
} // closes a scope opened above the visible part of this function
} // closes a scope opened above the visible part of this function
} // closes a scope opened above the visible part of this function

// Explicit specialization of BLAS for <int, float>.
//
// The class only declares its routines; their definitions are not in this
// part of the header.  It declares no data members.
// NOTE(review): the routines are presumably implemented by forwarding to the
// single-precision real BLAS library routines -- confirm in the source file
// that defines them.
// NOTE(review): TEUCHOSNUMERICS_LIB_DLL_EXPORT is defined elsewhere;
// presumably a shared-library export macro -- confirm.

template <>
class TEUCHOSNUMERICS_LIB_DLL_EXPORT BLAS<int, float>
{
public:
  // Default constructor, copy constructor (its argument is unused) and
  // virtual destructor; all three have empty bodies.
  inline BLAS(void) {}
  inline BLAS(const BLAS<int, float>& /*BLAS_source*/) {}
  inline virtual ~BLAS(void) {}

  // Names follow the Level 1 (vector-vector) BLAS routines.
  void ROTG(float* da, float* db, float* c, float* s) const;
  void ROT(const int n, float* dx, const int incx, float* dy, const int incy, float* c, float* s) const;
  float ASUM(const int n, const float* x, const int incx) const;
  void AXPY(const int n, const float alpha, const float* x, const int incx, float* y, const int incy) const;
  void COPY(const int n, const float* x, const int incx, float* y, const int incy) const;
  float DOT(const int n, const float* x, const int incx, const float* y, const int incy) const;
  float NRM2(const int n, const float* x, const int incx) const;
  void SCAL(const int n, const float alpha, float* x, const int incx) const;
  int IAMAX(const int n, const float* x, const int incx) const;

  // Names follow the Level 2 (matrix-vector) BLAS routines.
  void GEMV(ETransp trans, const int m, const int n, const float alpha, const float* A, const int lda, const float* x, const int incx, const float beta, float* y, const int incy) const;
  void TRMV(EUplo uplo, ETransp trans, EDiag diag, const int n, const float* A, const int lda, float* x, const int incx) const;
  void GER(const int m, const int n, const float alpha, const float* x, const int incx, const float* y, const int incy, float* A, const int lda) const;

  // Names follow the Level 3 (matrix-matrix) BLAS routines.
  void GEMM(ETransp transa, ETransp transb, const int m, const int n, const int k, const float alpha, const float* A, const int lda, const float* B, const int ldb, const float beta, float* C, const int ldc) const;
  void SYMM(ESide side, EUplo uplo, const int m, const int n, const float alpha, const float* A, const int lda, const float *B, const int ldb, const float beta, float *C, const int ldc) const;
  void SYRK(EUplo uplo, ETransp trans, const int n, const int k, const float alpha, const float* A, const int lda, const float beta, float* C, const int ldc) const;
  void TRMM(ESide side, EUplo uplo, ETransp transa, EDiag diag, const int m, const int n, const float alpha, const float* A, const int lda, float* B, const int ldb) const;
  void TRSM(ESide side, EUplo uplo, ETransp transa, EDiag diag, const int m, const int n, const float alpha, const float* A, const int lda, float* B, const int ldb) const;
};

// Explicit specialization of BLAS for <int, double>.
//
// Same set of declarations as the <int, float> specialization, with double
// in place of float.  Only declarations appear here; the class declares no
// data members.

template<>
class TEUCHOSNUMERICS_LIB_DLL_EXPORT BLAS<int, double>
{
public:
  // Default constructor, copy constructor (its argument is unused) and
  // virtual destructor; all three have empty bodies.
  inline BLAS(void) {}
  inline BLAS(const BLAS<int, double>& /*BLAS_source*/) {}
  inline virtual ~BLAS(void) {}

  // Names follow the Level 1 (vector-vector) BLAS routines.
  void ROTG(double* da, double* db, double* c, double* s) const;
  void ROT(const int n, double* dx, const int incx, double* dy, const int incy, double* c, double* s) const;
  double ASUM(const int n, const double* x, const int incx) const;
  void AXPY(const int n, const double alpha, const double* x, const int incx, double* y, const int incy) const;
  void COPY(const int n, const double* x, const int incx, double* y, const int incy) const;
  double DOT(const int n, const double* x, const int incx, const double* y, const int incy) const;
  double NRM2(const int n, const double* x, const int incx) const;
  void SCAL(const int n, const double alpha, double* x, const int incx) const;
  int IAMAX(const int n, const double* x, const int incx) const;

  // Names follow the Level 2 (matrix-vector) BLAS routines.
  void GEMV(ETransp trans, const int m, const int n, const double alpha, const double* A, const int lda, const double* x, const int incx, const double beta, double* y, const int incy) const;
  void TRMV(EUplo uplo, ETransp trans, EDiag diag, const int n, const double* A, const int lda, double* x, const int incx) const;
  void GER(const int m, const int n, const double alpha, const double* x, const int incx, const double* y, const int incy, double* A, const int lda) const;

  // Names follow the Level 3 (matrix-matrix) BLAS routines.
  void GEMM(ETransp transa, ETransp transb, const int m, const int n, const int k, const double alpha, const double* A, const int lda, const double* B, const int ldb, const double beta, double* C, const int ldc) const;
  void SYMM(ESide side, EUplo uplo, const int m, const int n, const double alpha, const double* A, const int lda, const double *B, const int ldb, const double beta, double *C, const int ldc) const;
  void SYRK(EUplo uplo, ETransp trans, const int n, const int k, const double alpha, const double* A, const int lda, const double beta, double* C, const int ldc) const;
  void TRMM(ESide side, EUplo uplo, ETransp transa, EDiag diag, const int m, const int n, const double alpha, const double* A, const int lda, double* B, const int ldb) const;
  void TRSM(ESide side, EUplo uplo, ETransp transa, EDiag diag, const int m, const int n, const double alpha, const double* A, const int lda, double* B, const int ldb) const;
};

// Explicit specialization of BLAS for <int, std::complex<float> >.
//
// Only declarations appear here; the class declares no data members.
// Signature differences from the real specializations, as declared below:
//   - ROTG and ROT take the `c` argument as float* (real), while da/db/dx/dy
//     and `s` are complex;
//   - ASUM and NRM2 return float (real); DOT returns std::complex<float>.

template<>
class TEUCHOSNUMERICS_LIB_DLL_EXPORT BLAS<int, std::complex<float> >
{
public:
  // Default constructor, copy constructor (its argument is unused) and
  // virtual destructor; all three have empty bodies.
  inline BLAS(void) {}
  inline BLAS(const BLAS<int, std::complex<float> >& /*BLAS_source*/) {}
  inline virtual ~BLAS(void) {}

  // Names follow the Level 1 (vector-vector) BLAS routines.
  void ROTG(std::complex<float>* da, std::complex<float>* db, float* c, std::complex<float>* s) const;
  void ROT(const int n, std::complex<float>* dx, const int incx, std::complex<float>* dy, const int incy, float* c, std::complex<float>* s) const;
  float ASUM(const int n, const std::complex<float>* x, const int incx) const;
  void AXPY(const int n, const std::complex<float> alpha, const std::complex<float>* x, const int incx, std::complex<float>* y, const int incy) const;
  void COPY(const int n, const std::complex<float>* x, const int incx, std::complex<float>* y, const int incy) const;
  std::complex<float> DOT(const int n, const std::complex<float>* x, const int incx, const std::complex<float>* y, const int incy) const;
  float NRM2(const int n, const std::complex<float>* x, const int incx) const;
  void SCAL(const int n, const std::complex<float> alpha, std::complex<float>* x, const int incx) const;
  int IAMAX(const int n, const std::complex<float>* x, const int incx) const;

  // Names follow the Level 2 (matrix-vector) BLAS routines.
  void GEMV(ETransp trans, const int m, const int n, const std::complex<float> alpha, const std::complex<float>* A, const int lda, const std::complex<float>* x, const int incx, const std::complex<float> beta, std::complex<float>* y, const int incy) const;
  void TRMV(EUplo uplo, ETransp trans, EDiag diag, const int n, const std::complex<float>* A, const int lda, std::complex<float>* x, const int incx) const;
  void GER(const int m, const int n, const std::complex<float> alpha, const std::complex<float>* x, const int incx, const std::complex<float>* y, const int incy, std::complex<float>* A, const int lda) const;

  // Names follow the Level 3 (matrix-matrix) BLAS routines.
  void GEMM(ETransp transa, ETransp transb, const int m, const int n, const int k, const std::complex<float> alpha, const std::complex<float>* A, const int lda, const std::complex<float>* B, const int ldb, const std::complex<float> beta, std::complex<float>* C, const int ldc) const;
  void SYMM(ESide side, EUplo uplo, const int m, const int n, const std::complex<float> alpha, const std::complex<float>* A, const int lda, const std::complex<float> *B, const int ldb, const std::complex<float> beta, std::complex<float> *C, const int ldc) const;
  void SYRK(EUplo uplo, ETransp trans, const int n, const int k, const std::complex<float> alpha, const std::complex<float>* A, const int lda, const std::complex<float> beta, std::complex<float>* C, const int ldc) const;
  void TRMM(ESide side, EUplo uplo, ETransp transa, EDiag diag, const int m, const int n, const std::complex<float> alpha, const std::complex<float>* A, const int lda, std::complex<float>* B, const int ldb) const;
  void TRSM(ESide side, EUplo uplo, ETransp transa, EDiag diag, const int m, const int n, const std::complex<float> alpha, const std::complex<float>* A, const int lda, std::complex<float>* B, const int ldb) const;
};

// Explicit specialization of BLAS for <int, std::complex<double> >.
//
// Only declarations appear here; the class declares no data members.
// Signature differences from the real specializations, as declared below:
//   - ROTG and ROT take the `c` argument as double* (real), while da/db/dx/dy
//     and `s` are complex;
//   - ASUM and NRM2 return double (real); DOT returns std::complex<double>.

template<>
class TEUCHOSNUMERICS_LIB_DLL_EXPORT BLAS<int, std::complex<double> >
{
public:
  // Default constructor, copy constructor (its argument is unused) and
  // virtual destructor; all three have empty bodies.
  inline BLAS(void) {}
  inline BLAS(const BLAS<int, std::complex<double> >& /*BLAS_source*/) {}
  inline virtual ~BLAS(void) {}

  // Names follow the Level 1 (vector-vector) BLAS routines.
  void ROTG(std::complex<double>* da, std::complex<double>* db, double* c, std::complex<double>* s) const;
  void ROT(const int n, std::complex<double>* dx, const int incx, std::complex<double>* dy, const int incy, double* c, std::complex<double>* s) const;
  double ASUM(const int n, const std::complex<double>* x, const int incx) const;
  void AXPY(const int n, const std::complex<double> alpha, const std::complex<double>* x, const int incx, std::complex<double>* y, const int incy) const;
  void COPY(const int n, const std::complex<double>* x, const int incx, std::complex<double>* y, const int incy) const;
  std::complex<double> DOT(const int n, const std::complex<double>* x, const int incx, const std::complex<double>* y, const int incy) const;
  double NRM2(const int n, const std::complex<double>* x, const int incx) const;
  void SCAL(const int n, const std::complex<double> alpha, std::complex<double>* x, const int incx) const;
  int IAMAX(const int n, const std::complex<double>* x, const int incx) const;

  // Names follow the Level 2 (matrix-vector) BLAS routines.
  void GEMV(ETransp trans, const int m, const int n, const std::complex<double> alpha, const std::complex<double>* A, const int lda, const std::complex<double>* x, const int incx, const std::complex<double> beta, std::complex<double>* y, const int incy) const;
  void TRMV(EUplo uplo, ETransp trans, EDiag diag, const int n, const std::complex<double>* A, const int lda, std::complex<double>* x, const int incx) const;
  void GER(const int m, const int n, const std::complex<double> alpha, const std::complex<double>* x, const int incx, const std::complex<double>* y, const int incy, std::complex<double>* A, const int lda) const;

  // Names follow the Level 3 (matrix-matrix) BLAS routines.
  void GEMM(ETransp transa, ETransp transb, const int m, const int n, const int k, const std::complex<double> alpha, const std::complex<double>* A, const int lda, const std::complex<double>* B, const int ldb, const std::complex<double> beta, std::complex<double>* C, const int ldc) const;
  void SYMM(ESide side, EUplo uplo, const int m, const int n, const std::complex<double> alpha, const std::complex<double>* A, const int lda, const std::complex<double> *B, const int ldb, const std::complex<double> beta, std::complex<double> *C, const int ldc) const;
  void SYRK(EUplo uplo, ETransp trans, const int n, const int k, const std::complex<double> alpha, const std::complex<double>* A, const int lda, const std::complex<double> beta, std::complex<double>* C, const int ldc) const;
  void TRMM(ESide side, EUplo uplo, ETransp transa, EDiag diag, const int m, const int n, const std::complex<double> alpha, const std::complex<double>* A, const int lda, std::complex<double>* B, const int ldb) const;
  void TRSM(ESide side, EUplo uplo, ETransp transa, EDiag diag, const int m, const int n, const std::complex<double> alpha, const std::complex<double>* A, const int lda, std::complex<double>* B, const int ldb) const;
};

} // namespace Teuchos

#endif // _TEUCHOS_BLAS_HPP_
1.7.6.1