SundanceUserDefOpEvaluator.cpp
Go to the documentation of this file.
00001 /* @HEADER@ */
00002 // ************************************************************************
00003 // 
00004 //                             Sundance
00005 //                 Copyright 2011 Sandia Corporation
00006 // 
00007 // Under the terms of Contract DE-AC04-94AL85000 with Sandia Corporation,
00008 // the U.S. Government retains certain rights in this software.
00009 //
00010 // Redistribution and use in source and binary forms, with or without
00011 // modification, are permitted provided that the following conditions are
00012 // met:
00013 //
00014 // 1. Redistributions of source code must retain the above copyright
00015 // notice, this list of conditions and the following disclaimer.
00016 //
00017 // 2. Redistributions in binary form must reproduce the above copyright
00018 // notice, this list of conditions and the following disclaimer in the
00019 // documentation and/or other materials provided with the distribution.
00020 //
00021 // 3. Neither the name of the Corporation nor the names of the
00022 // contributors may be used to endorse or promote products derived from
00023 // this software without specific prior written permission.
00024 //
00025 // THIS SOFTWARE IS PROVIDED BY SANDIA CORPORATION "AS IS" AND ANY
00026 // EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
00027 // IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
00028 // PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL SANDIA CORPORATION OR THE
00029 // CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
00030 // EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
00031 // PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
00032 // PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
00033 // LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
00034 // NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
00035 // SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
00036 //
00037 // Questions? Contact Kevin Long (kevin.long@ttu.edu)
00038 // 
00039 
00040 /* @HEADER@ */
00041 
00042 #include "SundanceUserDefOpEvaluator.hpp"
00043 #include "SundanceUserDefOpCommonEvaluator.hpp"
00044 #include "SundanceUserDefOpElement.hpp"
00045 #include "SundanceEvalManager.hpp"
00046 
00047 #include "PlayaTabs.hpp"
00048 #include "SundanceOut.hpp"
00049 #include "SundanceUserDefOp.hpp"
00050 
00051 using namespace Sundance;
00052 using namespace Sundance;
00053 using namespace Sundance;
00054 using namespace Teuchos;
00055 
00056 
00057 
00058 
00059 
00060 UserDefOpEvaluator
00061 ::UserDefOpEvaluator(const UserDefOpElement* expr,
00062                      const RCP<const UserDefOpCommonEvaluator>& commonEval,
00063                      const EvalContext& context)
00064   : ChainRuleEvaluator(expr, context),
00065     argValueIndex_(expr->numChildren()),
00066     argValueIsConstant_(expr->numChildren()),
00067     functor_(expr->functorElement()),
00068     commonEval_(commonEval),
00069     maxOrder_(0),
00070     numVarArgDerivs_(0),
00071     numConstArgDerivs_(0),
00072     allArgsAreConstant_(true)
00073 {
00074   Tabs tab1;
00075   SUNDANCE_VERB_LOW(tab1 << "initializing user defined op evaluator for " 
00076                     << expr->toString());
00077   Array<int> orders = findRequiredOrders(expr, context);
00078 
00079   for (int i=0; i<orders.size(); i++) 
00080     {
00081       if (orders[i] > maxOrder_) maxOrder_ = orders[i];
00082     }
00083   commonEval->updateMaxOrder(maxOrder_);
00084 
00085   SUNDANCE_VERB_HIGH(tab1 << "setting arg deriv indices");
00086   
00087   
00088   /* Find the mapping from argument derivatives to indices in the 
00089    * functor's vector of return values */
00090   Map<MultiSet<int>, int> varArgDerivs;
00091   Map<MultiSet<int>, int> constArgDerivs;
00092   expr->getArgDerivIndices(orders, varArgDerivs, constArgDerivs);
00093   numVarArgDerivs_ = varArgDerivs.size();
00094   numConstArgDerivs_ = constArgDerivs.size();
00095   typedef Map<MultiSet<int>, int>::const_iterator iter;
00096   for (iter i=varArgDerivs.begin(); i!=varArgDerivs.end(); i++)
00097     {
00098       Tabs tab2;
00099       SUNDANCE_VERB_EXTREME(tab2 << "variable arg deriv " << i->first 
00100                             << " will be at index "
00101                             << i->second);
00102       addVarArgDeriv(i->first, i->second);
00103     }
00104   
00105   for (iter i=constArgDerivs.begin(); i!=constArgDerivs.end(); i++)
00106     {
00107       Tabs tab2;
00108       SUNDANCE_VERB_EXTREME(tab2 << "constant arg deriv " << i->first 
00109                             << " will be at index "
00110                             << i->second);
00111       addConstArgDeriv(i->first, i->second);
00112     }
00113 
00114   /* Find the indices to the zeroth derivative of each argument */
00115   
00116   for (int i=0; i<expr->numChildren(); i++)
00117     {
00118       const SparsitySuperset* sArg = childSparsity(i);
00119       int numConst=0;
00120       int numVec=0;
00121       for (int j=0; j<sArg->numDerivs(); j++)
00122         {
00123           if (sArg->deriv(j).order() == 0) 
00124             {
00125               if (sArg->state(j)==VectorDeriv)
00126                 {
00127                   argValueIndex_[i] = numVec;              
00128                   allArgsAreConstant_ = false;
00129                 }
00130               else
00131                 {
00132                   argValueIndex_[i] = numConst;              
00133                 }
00134               break;
00135             }
00136           if (sArg->state(j) == VectorDeriv) 
00137             {
00138               numVec++;
00139             }
00140           else
00141             {
00142               numConst++;
00143             }
00144         }
00145     }
00146   
00147   /* Call init() at the base class to set up chain rule evaluation */
00148   init(expr, context);
00149 }
00150 
00151 
00152 
00153 
00154 void UserDefOpEvaluator::resetNumCalls() const
00155 {
00156   commonEval()->markCacheAsInvalid();
00157   ChainRuleEvaluator::resetNumCalls();
00158 }
00159 
00160 
00161 
00162 
00163 Array<int> UserDefOpEvaluator::findRequiredOrders(const ExprWithChildren* expr, 
00164                                                   const EvalContext& context)
00165 {
00166   Tabs tab0;
00167   SUNDANCE_VERB_HIGH(tab0 << "finding required arg deriv orders");
00168 
00169   Set<int> orders;
00170   
00171   const Set<MultipleDeriv>& R = expr->findR(context);
00172   typedef Set<MultipleDeriv>::const_iterator iter;
00173 
00174   for (iter md=R.begin(); md!=R.end(); md++)
00175     {
00176       Tabs tab1;
00177       
00178       int N = md->order();
00179       if (N > maxOrder_) maxOrder_ = N;
00180       if (N==0) orders.put(N);
00181       for (int n=1; n<=N; n++)
00182         {
00183           const Set<MultiSet<int> >& QW = expr->findQ_W(n, context);
00184           for (Set<MultiSet<int> >::const_iterator q=QW.begin(); q!=QW.end(); q++)
00185             {
00186               orders.put(q->size());
00187             }
00188         }
00189     }
00190   SUNDANCE_VERB_HIGH(tab0 << "arg deriv orders=" << orders);
00191   return orders.elements();
00192 }
00193 
00194 
00195 
00196 
00197 void UserDefOpEvaluator
00198 ::evalArgDerivs(const EvalManager& mgr,
00199                 const Array<RCP<Array<double> > >& constArgVals,
00200                 const Array<RCP<Array<RCP<EvalVector> > > >& varArgVals,
00201                 Array<double>& constArgDerivs,
00202                 Array<RCP<EvalVector> >& varArgDerivs) const
00203 {
00204   if (!commonEval()->cacheIsValid())
00205     {
00206       commonEval()->evalAllComponents(mgr, constArgVals, varArgVals);
00207     }
00208   if (allArgsAreConstant_)
00209     {
00210       constArgDerivs = commonEval()->constArgDerivCache(myIndex());
00211     }
00212   else
00213     {
00214       varArgDerivs = commonEval()->varArgDerivCache(myIndex());
00215     }
00216 }
00217 

Site Contact