SundanceUserDefOpEvaluator.cpp
Go to the documentation of this file.
00001 /* @HEADER@ */
00002 // ************************************************************************
00003 // 
00004 //                              Sundance
00005 //                 Copyright (2005) Sandia Corporation
00006 // 
00007 // Copyright (year first published) Sandia Corporation.  Under the terms 
00008 // of Contract DE-AC04-94AL85000 with Sandia Corporation, the U.S. Government 
00009 // retains certain rights in this software.
00010 // 
00011 // This library is free software; you can redistribute it and/or modify
00012 // it under the terms of the GNU Lesser General Public License as
00013 // published by the Free Software Foundation; either version 2.1 of the
00014 // License, or (at your option) any later version.
00015 //  
00016 // This library is distributed in the hope that it will be useful, but
00017 // WITHOUT ANY WARRANTY; without even the implied warranty of
00018 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
00019 // Lesser General Public License for more details.
00020 //                                                                                 
00021 // You should have received a copy of the GNU Lesser General Public
00022 // License along with this library; if not, write to the Free Software
00023 // Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307
00024 // USA                                                                                
00025 // Questions? Contact Kevin Long (krlong@sandia.gov), 
00026 // Sandia National Laboratories, Livermore, California, USA
00027 // 
00028 // ************************************************************************
00029 /* @HEADER@ */
00030 
00031 #include "SundanceUserDefOpEvaluator.hpp"
00032 #include "SundanceUserDefOpCommonEvaluator.hpp"
00033 #include "SundanceUserDefOpElement.hpp"
00034 #include "SundanceEvalManager.hpp"
00035 
00036 #include "PlayaTabs.hpp"
00037 #include "SundanceOut.hpp"
00038 #include "SundanceUserDefOp.hpp"
00039 
00040 using namespace Sundance;
00041 using namespace Sundance;
00042 using namespace Sundance;
00043 using namespace Teuchos;
00044 
00045 
00046 
00047 
00048 
00049 UserDefOpEvaluator
00050 ::UserDefOpEvaluator(const UserDefOpElement* expr,
00051                      const RCP<const UserDefOpCommonEvaluator>& commonEval,
00052                      const EvalContext& context)
00053   : ChainRuleEvaluator(expr, context),
00054     argValueIndex_(expr->numChildren()),
00055     argValueIsConstant_(expr->numChildren()),
00056     functor_(expr->functorElement()),
00057     commonEval_(commonEval),
00058     maxOrder_(0),
00059     numVarArgDerivs_(0),
00060     numConstArgDerivs_(0),
00061     allArgsAreConstant_(true)
00062 {
00063   Tabs tab1;
00064   SUNDANCE_VERB_LOW(tab1 << "initializing user defined op evaluator for " 
00065                     << expr->toString());
00066   Array<int> orders = findRequiredOrders(expr, context);
00067 
00068   for (int i=0; i<orders.size(); i++) 
00069     {
00070       if (orders[i] > maxOrder_) maxOrder_ = orders[i];
00071     }
00072   commonEval->updateMaxOrder(maxOrder_);
00073 
00074   SUNDANCE_VERB_HIGH(tab1 << "setting arg deriv indices");
00075   
00076   
00077   /* Find the mapping from argument derivatives to indices in the 
00078    * functor's vector of return values */
00079   Map<MultiSet<int>, int> varArgDerivs;
00080   Map<MultiSet<int>, int> constArgDerivs;
00081   expr->getArgDerivIndices(orders, varArgDerivs, constArgDerivs);
00082   numVarArgDerivs_ = varArgDerivs.size();
00083   numConstArgDerivs_ = constArgDerivs.size();
00084   typedef Map<MultiSet<int>, int>::const_iterator iter;
00085   for (iter i=varArgDerivs.begin(); i!=varArgDerivs.end(); i++)
00086     {
00087       Tabs tab2;
00088       SUNDANCE_VERB_EXTREME(tab2 << "variable arg deriv " << i->first 
00089                             << " will be at index "
00090                             << i->second);
00091       addVarArgDeriv(i->first, i->second);
00092     }
00093   
00094   for (iter i=constArgDerivs.begin(); i!=constArgDerivs.end(); i++)
00095     {
00096       Tabs tab2;
00097       SUNDANCE_VERB_EXTREME(tab2 << "constant arg deriv " << i->first 
00098                             << " will be at index "
00099                             << i->second);
00100       addConstArgDeriv(i->first, i->second);
00101     }
00102 
00103   /* Find the indices to the zeroth derivative of each argument */
00104   
00105   for (int i=0; i<expr->numChildren(); i++)
00106     {
00107       const SparsitySuperset* sArg = childSparsity(i);
00108       int numConst=0;
00109       int numVec=0;
00110       for (int j=0; j<sArg->numDerivs(); j++)
00111         {
00112           if (sArg->deriv(j).order() == 0) 
00113             {
00114               if (sArg->state(j)==VectorDeriv)
00115                 {
00116                   argValueIndex_[i] = numVec;              
00117                   allArgsAreConstant_ = false;
00118                 }
00119               else
00120                 {
00121                   argValueIndex_[i] = numConst;              
00122                 }
00123               break;
00124             }
00125           if (sArg->state(j) == VectorDeriv) 
00126             {
00127               numVec++;
00128             }
00129           else
00130             {
00131               numConst++;
00132             }
00133         }
00134     }
00135   
00136   /* Call init() at the base class to set up chain rule evaluation */
00137   init(expr, context);
00138 }
00139 
00140 
00141 
00142 
00143 void UserDefOpEvaluator::resetNumCalls() const
00144 {
00145   commonEval()->markCacheAsInvalid();
00146   ChainRuleEvaluator::resetNumCalls();
00147 }
00148 
00149 
00150 
00151 
00152 Array<int> UserDefOpEvaluator::findRequiredOrders(const ExprWithChildren* expr, 
00153                                                   const EvalContext& context)
00154 {
00155   Tabs tab0;
00156   SUNDANCE_VERB_HIGH(tab0 << "finding required arg deriv orders");
00157 
00158   Set<int> orders;
00159   
00160   const Set<MultipleDeriv>& R = expr->findR(context);
00161   typedef Set<MultipleDeriv>::const_iterator iter;
00162 
00163   for (iter md=R.begin(); md!=R.end(); md++)
00164     {
00165       Tabs tab1;
00166       
00167       int N = md->order();
00168       if (N > maxOrder_) maxOrder_ = N;
00169       if (N==0) orders.put(N);
00170       for (int n=1; n<=N; n++)
00171         {
00172           const Set<MultiSet<int> >& QW = expr->findQ_W(n, context);
00173           for (Set<MultiSet<int> >::const_iterator q=QW.begin(); q!=QW.end(); q++)
00174             {
00175               orders.put(q->size());
00176             }
00177         }
00178     }
00179   SUNDANCE_VERB_HIGH(tab0 << "arg deriv orders=" << orders);
00180   return orders.elements();
00181 }
00182 
00183 
00184 
00185 
00186 void UserDefOpEvaluator
00187 ::evalArgDerivs(const EvalManager& mgr,
00188                 const Array<RCP<Array<double> > >& constArgVals,
00189                 const Array<RCP<Array<RCP<EvalVector> > > >& varArgVals,
00190                 Array<double>& constArgDerivs,
00191                 Array<RCP<EvalVector> >& varArgDerivs) const
00192 {
00193   if (!commonEval()->cacheIsValid())
00194     {
00195       commonEval()->evalAllComponents(mgr, constArgVals, varArgVals);
00196     }
00197   if (allArgsAreConstant_)
00198     {
00199       constArgDerivs = commonEval()->constArgDerivCache(myIndex());
00200     }
00201   else
00202     {
00203       varArgDerivs = commonEval()->varArgDerivCache(myIndex());
00204     }
00205 }
00206 

Site Contact