SundanceUserDefOpCommonEvaluator.cpp
Go to the documentation of this file.
00001 /* @HEADER@ */
00002 // ************************************************************************
00003 // 
00004 //                             Sundance
00005 //                 Copyright 2011 Sandia Corporation
00006 // 
00007 // Under the terms of Contract DE-AC04-94AL85000 with Sandia Corporation,
00008 // the U.S. Government retains certain rights in this software.
00009 //
00010 // Redistribution and use in source and binary forms, with or without
00011 // modification, are permitted provided that the following conditions are
00012 // met:
00013 //
00014 // 1. Redistributions of source code must retain the above copyright
00015 // notice, this list of conditions and the following disclaimer.
00016 //
00017 // 2. Redistributions in binary form must reproduce the above copyright
00018 // notice, this list of conditions and the following disclaimer in the
00019 // documentation and/or other materials provided with the distribution.
00020 //
00021 // 3. Neither the name of the Corporation nor the names of the
00022 // contributors may be used to endorse or promote products derived from
00023 // this software without specific prior written permission.
00024 //
00025 // THIS SOFTWARE IS PROVIDED BY SANDIA CORPORATION "AS IS" AND ANY
00026 // EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
00027 // IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
00028 // PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL SANDIA CORPORATION OR THE
00029 // CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
00030 // EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
00031 // PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
00032 // PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
00033 // LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
00034 // NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
00035 // SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
00036 //
00037 // Questions? Contact Kevin Long (kevin.long@ttu.edu)
00038 // 
00039 
00040 /* @HEADER@ */
00041 
00042 #include "SundanceUserDefOpCommonEvaluator.hpp"
00043 #include "SundanceEvalManager.hpp"
00044 #include "SundanceEvalVector.hpp"
00045 #include "SundanceUserDefOp.hpp"
00046 #include "SundanceUserDefOpElement.hpp"
00047 #include "SundanceSparsitySuperset.hpp"
00048 #include "PlayaTabs.hpp"
00049 #include "SundanceOut.hpp"
00050 
00051 using namespace Sundance;
00052 using namespace Sundance;
00053 
00054 using namespace Sundance;
00055 using namespace Teuchos;
00056 
00057 
00058 UserDefOpCommonEvaluator
00059 ::UserDefOpCommonEvaluator(const UserDefFunctor* functor,
00060                            const UserDefOpElement* expr,
00061                            const EvalContext& context)
00062   :  maxOrder_(0),
00063      argValueIndex_(functor->domainDim(), -1),
00064      argValueIsConstant_(functor->domainDim()),
00065      constArgDerivCache_(functor->rangeDim()),
00066      varArgDerivCache_(functor->rangeDim()),
00067      cacheIsValid_(false),
00068      functor_(functor)
00069 {
00070   /* Find the indices to the zeroth derivative of each argument */
00071   for (int i=0; i<functor->domainDim(); i++)
00072     {
00073       const SparsitySuperset* sArg = expr->evaluatableChild(i)->sparsitySuperset(context).get();
00074       int numConst=0;
00075       int numVec=0;
00076       for (int j=0; j<sArg->numDerivs(); j++)
00077         {
00078           if (sArg->deriv(j).order() == 0) 
00079             {
00080               if (sArg->state(j)==VectorDeriv)
00081                 {
00082                   argValueIndex_[i] = numVec;              
00083                 }
00084               else
00085                 {
00086                   argValueIndex_[i] = numConst;              
00087                 }
00088               break;
00089             }
00090           if (sArg->state(j) == VectorDeriv) 
00091             {
00092               numVec++;
00093             }
00094           else
00095             {
00096               numConst++;
00097             }
00098         }
00099       /* Check to make sure a zeroth derivative has been found. */
00100       TEUCHOS_TEST_FOR_EXCEPTION(argValueIndex_[i]==-1, std::runtime_error,
00101                          "no zeroth derivative found for argument #" << i
00102                          << " of " << expr->toString());
00103     }
00104 }
00105 
00106 
00107 
00108 
00109 void UserDefOpCommonEvaluator
00110 ::evalAllComponents(const EvalManager& mgr,
00111                     const Array<RCP<Array<double> > >& constArgVals,
00112                     const Array<RCP<Array<RCP<EvalVector> > > >& vArgVals) const 
00113 {
00114   Tabs tab0;
00115   int numPoints = EvalManager::stack().vecSize();
00116   SUNDANCE_MSG3(mgr.verb(), tab0 << "UDOpCommonEval::evalAllComponents()");
00117   SUNDANCE_MSG3(mgr.verb(), tab0 << "num points = " << numPoints);
00118   SUNDANCE_MSG2(mgr.verb(), tab0 << "max diff order = " << maxOrder_);
00119 
00120   TEUCHOS_TEST_FOR_EXCEPTION(numPoints==0, std::logic_error,
00121                      "Empty vector detected in evalArgDerivs()"); 
00122 
00123   /* Get an array of pointers for the argument vectors.
00124    * If some of the arguments are constant, copy them into vectors. */
00125   Array<RCP<EvalVector> > argVals(argValueIndex_.size());
00126   Array<double*> argPtrs(argValueIndex_.size());
00127   for (int q=0; q<argValueIndex_.size(); q++)
00128     {
00129       Tabs tab1;
00130       if (argValueIsConstant_[q]) 
00131         {
00132           argVals[q] = mgr.popVector();
00133           double* ptr = argVals[q]->start();
00134           double c =  (*(constArgVals[q]))[argValueIndex_[q]];
00135           for (int p=0; p<numPoints; p++)
00136             {
00137               ptr[p] = c;
00138             }
00139           argPtrs[q] = ptr;
00140         }
00141       else
00142         {
00143           argVals[q] = (*(vArgVals[q]))[argValueIndex_[q]];
00144           argPtrs[q] = argVals[q]->start();
00145         }
00146       SUNDANCE_MSG3(mgr.verb(), tab1 << "argument #" << q << " is:");
00147       Tabs tab2;
00148       SUNDANCE_MSG3(mgr.verb(), tab2 << argVals[q]->str());
00149     }
00150 
00151   /* Allocate vectors for the function values and derivatives */
00152   TEUCHOS_TEST_FOR_EXCEPTION(maxOrder_ > 2, std::runtime_error,
00153                      "Differentiation order " << maxOrder_ << ">2 not supported "
00154                      "for user-defined operators");
00155   int rangeDim = functor_->rangeDim();
00156   int domainDim = functor_->domainDim();
00157   int nTotal = 1;
00158   int numFirst = domainDim;
00159   int numSecond = domainDim*(domainDim+1)/2;
00160   if (maxOrder_ > 0) nTotal += numFirst;
00161   if (maxOrder_ > 1) nTotal += numSecond;
00162   int numResultVecs = nTotal * rangeDim;
00163   
00164   /* The resultVecs array contains pointers to the numerical vectors in the
00165    * cache of vector-valued arg derivs.
00166    */
00167   Array<double*> resultVecs(numResultVecs);
00168 
00169   
00170   for (int i=0; i<rangeDim; i++)
00171     {
00172       varArgDerivCache_[i].resize(nTotal);
00173       varArgDerivCache_[i][0] = mgr.popVector();
00174       varArgDerivCache_[i][0]->resize(numPoints);
00175       varArgDerivCache_[i][0]->setString(functor_->name(i));
00176       int d0Pos = i;
00177       SUNDANCE_MSG3(mgr.verb(), "zeroth deriv of elem #" << i << " is at " << d0Pos);
00178       resultVecs[d0Pos] = varArgDerivCache_[i][0]->start();
00179       if (maxOrder_ > 0)
00180         {
00181           int ptr = 0;
00182           for (int j=0; j<domainDim; j++)
00183             {
00184               varArgDerivCache_[i][j+1] = mgr.popVector();
00185               varArgDerivCache_[i][j+1]->resize(numPoints);
00186               int d1Pos = rangeDim + domainDim*i + j;
00187               SUNDANCE_MSG3(mgr.verb(), "first deriv (" << j << ") of elem #" << i 
00188                                  << " is at " << d1Pos);
00189               resultVecs[d1Pos] 
00190                 = varArgDerivCache_[i][j+1]->start();
00191               varArgDerivCache_[i][j+1]->setString("D[" + functor_->name(i) 
00192                                                    + ", " 
00193                                                    + argVals[j]->str() + "]");
00194               if (maxOrder_ > 1)
00195                 {
00196                   for (int k=0; k<=j; k++, ptr++)
00197                     {
00198                       int m = (1 + numFirst);
00199                       varArgDerivCache_[i][m+ptr] = mgr.popVector();
00200                       varArgDerivCache_[i][m+ptr]->resize(numPoints);
00201                       varArgDerivCache_[i][m+ptr]->setString("D[" 
00202                                                              + functor_->name(i) 
00203                                                              + ", {" 
00204                                                              + argVals[j]->str() 
00205                                                              + ", "
00206                                                              + argVals[k]->str() 
00207                                                              + "}]");
00208                       int d2Pos = rangeDim + domainDim*rangeDim 
00209                         + i*numSecond + ptr;
00210                       SUNDANCE_MSG3(mgr.verb(), "second deriv (" << j << ", " << k << 
00211                                          ") of elem #" << i << " is at " << d2Pos);
00212                       resultVecs[d2Pos] 
00213                         = varArgDerivCache_[i][m+ptr]->start();
00214                     }
00215                 }
00216             }
00217         }
00218     }
00219 
00220   
00221   /* Call the user's callback function. The results will be written
00222    * into the cache of argument derivatives. */
00223   const double** in = const_cast<const double**>(&(argPtrs[0]));
00224   double** out = &(resultVecs[0]);
00225 
00226   functor_->evaluationCallback(numPoints, maxOrder_, in, out);
00227 
00228  
00229 
00230   markCacheAsValid();
00231 }
00232 
00233 

Site Contact