SundanceChainRuleEvaluator.hpp
Go to the documentation of this file.
00001 /* @HEADER@ */
00002 // ************************************************************************
00003 // 
00004 //                              Sundance
00005 //                 Copyright (2005) Sandia Corporation
00006 // 
00007 // Copyright (year first published) Sandia Corporation.  Under the terms 
00008 // of Contract DE-AC04-94AL85000 with Sandia Corporation, the U.S. Government 
00009 // retains certain rights in this software.
00010 // 
00011 // This library is free software; you can redistribute it and/or modify
00012 // it under the terms of the GNU Lesser General Public License as
00013 // published by the Free Software Foundation; either version 2.1 of the
00014 // License, or (at your option) any later version.
00015 //  
00016 // This library is distributed in the hope that it will be useful, but
00017 // WITHOUT ANY WARRANTY; without even the implied warranty of
00018 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
00019 // Lesser General Public License for more details.
00020 //                                                                                 
00021 // You should have received a copy of the GNU Lesser General Public
00022 // License along with this library; if not, write to the Free Software
00023 // Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307
00024 // USA                                                                                
00025 // Questions? Contact Kevin Long (krlong@sandia.gov), 
00026 // Sandia National Laboratories, Livermore, California, USA
00027 // 
00028 // ************************************************************************
00029 /* @HEADER@ */
00030 
00031 #ifndef SUNDANCE_CHAINRULEEVALUATOR_H
00032 #define SUNDANCE_CHAINRULEEVALUATOR_H
00033 
00034 #include "SundanceDefs.hpp"
00035 #include "SundanceSubtypeEvaluator.hpp"
00036 #include "SundanceExprWithChildren.hpp"
00037 #include "SundanceChainRuleSum.hpp"
00038 
00039 namespace Sundance 
00040 {
00041 /** 
00042  *
00043  */
00044 class ChainRuleEvaluator : public SubtypeEvaluator<ExprWithChildren>
00045 {
00046 public:
00047 
00048   /** */
00049   ChainRuleEvaluator(const ExprWithChildren* expr, 
00050     const EvalContext& context);
00051 
00052   /** */
00053   virtual ~ChainRuleEvaluator(){;}
00054 
00055   /** */
00056   virtual void internalEval(const EvalManager& mgr,
00057     Array<double>& constantResults,
00058     Array<RCP<EvalVector> >& vectorResults) const ;
00059 
00060   /** */
00061   int numChildren() const {return childEvaluators_.size();}
00062 
00063   /** */
00064   int constArgDerivIndex(const MultiSet<int>& df) const ;
00065 
00066   /** */
00067   int varArgDerivIndex(const MultiSet<int>& df) const ;
00068 
00069   /** */
00070   TEUCHOS_TIMER(chainRuleEvalTimer, "chain rule evaluation");
00071 
00072 
00073   /** */
00074   const Array<Array<int> >& nComps(int N, int n) const ;
00075 
00076   /** */
00077   void resetNumCalls() const ;
00078 
00079   /** 
00080    * Evaluate the derivatives of the expression with respect to
00081    * the arguments.
00082    *
00083    * \param mgr Manager for this evaluation step
00084    *
00085    * \param constDerivsOfArgs Constant values and functional
00086    * derivatives of arguments.  The outer array index is over
00087    * arguments. The inner array index is over functional
00088    * derivatives of that argument.
00089    *
00090    * \param varDerivsOfArgs Variable values and functional
00091    * derivatives of arguments.  The outer array index is over
00092    * arguments. The inner array index is over functional
00093    * derivatives of that argument.
00094    *
00095    * \param constArgDerivs Constant-valued derivatives of expr wrt
00096    * arguments.
00097    *
00098    * \param varArgDerivs Variable-valued derivatives of expr wrt
00099    * arguments.
00100    */
00101 
00102   virtual void evalArgDerivs(const EvalManager& mgr,
00103     const Array<RCP<Array<double> > >& constDerivsOfArgs,
00104     const Array<RCP<Array<RCP<EvalVector> > > >& varDerivOfArgs,
00105     Array<double>& constArgDerivs,
00106     Array<RCP<EvalVector> >& varArgDerivs) const = 0 ;
00107 
00108 
00109   static Set<MultiSet<MultipleDeriv> > chainRuleBins(const MultipleDeriv& d,
00110     const MultiSet<int>& q);
00111       
00112 protected:
00113   /** The init() function should be called from the derived class ctors */
00114   void init(const ExprWithChildren* expr, 
00115     const EvalContext& context);
00116 
00117   /** */
00118   void addConstArgDeriv(const MultiSet<int>& df, int index);
00119 
00120   /** */
00121   void addVarArgDeriv(const MultiSet<int>& df, int index);
00122 
00123   /** */
00124   const Evaluator* childEvaluator(int i) const {return childEvaluators_[i].get();}
00125 
00126   /** */
00127   const SparsitySuperset* childSparsity(int i) const {return childSparsity_[i].get();}
00128 
00129   static MultipleDeriv makeMD(const Array<Deriv>& d) ;
00130 
00131   /** Returns the binomial coefficient */
00132   double choose(int N, int n) const ;
00133       
00134   /** Returns the factorial of n */
00135   double fact(int n) const ;
00136 
00137 
00138   /** Returns the stirling number of the second kind */
00139   double stirling2(int n, int k) const ;
00140 
00141   /** */
00142   int derivComboMultiplicity(const MultiSet<MultipleDeriv>& b) const ;
00143 
00144 
00145 private:
00146 
00147       
00148   Array<RCP<ChainRuleSum> > expansions_;
00149 
00150   Array<RCP<Evaluator> > childEvaluators_;
00151 
00152   Array<RCP<SparsitySuperset> > childSparsity_;
00153 
00154   Map<MultiSet<int>, int> constArgDerivMap_;
00155 
00156   Map<MultiSet<int>, int> varArgDerivMap_;
00157 
00158   int zerothDerivResultIndex_;
00159 
00160   bool zerothDerivIsConstant_;
00161 
00162   static Map<OrderedPair<int, int>, Array<Array<int> > >& compMap() ;
00163 };
00164 
00165 /** */
00166 MultipleDeriv makeDeriv(const Expr& a);
00167 
00168 /** */
00169 MultipleDeriv makeDeriv(const Expr& a, const Expr& b);
00170 
00171 /** */
00172 MultipleDeriv makeDeriv(const Expr& a, const Expr& b, const Expr& c);
00173 
00174     
00175 }
00176 
00177 
00178 #endif

Site Contact