SundanceChainRuleSum.cpp
Go to the documentation of this file.
00001 /* @HEADER@ */
00002 // ************************************************************************
00003 // 
00004 //                              Sundance
00005 //                 Copyright (2005) Sandia Corporation
00006 // 
00007 // Copyright (year first published) Sandia Corporation.  Under the terms 
00008 // of Contract DE-AC04-94AL85000 with Sandia Corporation, the U.S. Government 
00009 // retains certain rights in this software.
00010 // 
00011 // This library is free software; you can redistribute it and/or modify
00012 // it under the terms of the GNU Lesser General Public License as
00013 // published by the Free Software Foundation; either version 2.1 of the
00014 // License, or (at your option) any later version.
00015 //  
00016 // This library is distributed in the hope that it will be useful, but
00017 // WITHOUT ANY WARRANTY; without even the implied warranty of
00018 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
00019 // Lesser General Public License for more details.
00020 //                                                                                 
00021 // You should have received a copy of the GNU Lesser General Public
00022 // License along with this library; if not, write to the Free Software
00023 // Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307
00024 // USA                                                                                
00025 // Questions? Contact Kevin Long (krlong@sandia.gov), 
00026 // Sandia National Laboratories, Livermore, California, USA
00027 // 
00028 // ************************************************************************
00029 /* @HEADER@ */
00030 
00031 #include "SundanceChainRuleSum.hpp"
00032 #include "SundanceEvalManager.hpp"
00033 #include "SundanceEvalVector.hpp"
00034 #include "PlayaExceptions.hpp"
00035 #include "SundanceSet.hpp"
00036 #include "PlayaTabs.hpp"
00037 #include "SundanceOut.hpp"
00038 
00039 using namespace Sundance;
00040 using namespace Sundance;
00041 
00042 using namespace Sundance;
00043 using namespace Teuchos;
00044 
00045 
00046 ChainRuleSum::ChainRuleSum(const MultipleDeriv& md,
00047                            int resultIndex,
00048                            bool resultIsConstant)
00049   : md_(md),
00050     resultIndex_(resultIndex),
00051     resultIsConstant_(resultIsConstant),
00052     argDerivIndex_(),
00053     argDerivIsConstant_(),
00054     terms_()
00055 {;}
00056 
00057 
00058 void ChainRuleSum::addTerm(int argDerivIndex, 
00059                            bool argDerivIsConstant,
00060                            const Array<DerivProduct>& sum)
00061 {
00062   argDerivIndex_.append(argDerivIndex);
00063   argDerivIsConstant_.append(argDerivIsConstant);
00064   terms_.append(sum);
00065 }
00066 
00067 
00068 void ChainRuleSum
00069 ::evalConstant(const EvalManager& mgr,
00070                const Array<RCP<Array<double> > >& constantArgResults,
00071                const Array<double>& constantArgDerivs,
00072                double& result) const
00073 {
00074   Tabs tabs;
00075   SUNDANCE_VERB_HIGH(tabs << "ChainRuleSum::evalConstant()");
00076   result = 0.0;
00077   for (int i=0; i<numTerms(); i++)
00078     {
00079       const double& argDeriv = constantArgDerivs[argDerivIndex(i)];
00080       const Array<DerivProduct>& sumOfDerivProducts = terms(i);
00081       double innerSum = 0.0;
00082       for (int j=0; j<sumOfDerivProducts.size(); j++)
00083         {
00084           double prod = 1.0;
00085           const DerivProduct& p = sumOfDerivProducts[j];
00086           for (int k=0; k<p.numConstants(); k++)
00087             {
00088               const IndexPair& ip = p.constant(k);
00089               prod *= (*(constantArgResults[ip.argIndex()]))[ip.valueIndex()];
00090             }
00091           innerSum += prod;
00092         }
00093       result += innerSum*argDeriv;
00094     }
00095 }
00096 
00097 
00098 void ChainRuleSum
00099 ::evalVar(const EvalManager& mgr,
00100           const Array<RCP<Array<double> > >& constantArgResults,
00101           const Array<RCP<Array<RCP<EvalVector> > > > & vArgResults,
00102           const Array<double>& constantArgDerivs,
00103           const Array<RCP<EvalVector> >& varArgDerivs,
00104           RCP<EvalVector>& varResult) const
00105 {
00106   Tabs tabs;
00107   SUNDANCE_VERB_HIGH(tabs << "ChainRuleSum::evalVar()");
00108   int vecSize=-1;
00109   for (int i=0; i<varArgDerivs.size(); i++)
00110     {
00111       int s = varArgDerivs[i]->length();
00112       TEUCHOS_TEST_FOR_EXCEPTION(vecSize != -1 && s != vecSize, std::logic_error,
00113                          "inconsistent vector sizes " << vecSize
00114                          << " and " << s);
00115       vecSize = s;
00116     } 
00117   for (int i=0; i<vArgResults.size(); i++)
00118     {
00119       for (int j=0; j<vArgResults[i]->size(); j++)
00120         {
00121           int s = (*(vArgResults[i]))[j]->length();
00122           TEUCHOS_TEST_FOR_EXCEPTION(vecSize != -1 && s != vecSize, std::logic_error,
00123                              "inconsistent vector sizes " << vecSize
00124                              << " and " << s);
00125           vecSize = s;
00126         }
00127     } 
00128   TEUCHOS_TEST_FOR_EXCEPT(vecSize==-1);
00129   
00130   varResult = mgr.popVector();
00131   varResult->resize(vecSize);
00132   varResult->setToConstant(0.0);
00133 
00134   for (int i=0; i<numTerms(); i++)
00135     {
00136       Tabs tab1;
00137       SUNDANCE_VERB_HIGH(tab1 << "term=" << i << " of " << numTerms());
00138       RCP<EvalVector> innerSum = mgr.popVector();
00139       innerSum->resize(vecSize);
00140       innerSum->setToConstant(0.0);
00141       const Array<DerivProduct>& sumOfDerivProducts = terms(i);
00142 
00143       SUNDANCE_VERB_HIGH(tab1 << "inner sum init = " << *innerSum
00144                          << ", num terms = " << terms(i).size());
00145 
00146       for (int j=0; j<sumOfDerivProducts.size(); j++)
00147         {
00148           Tabs tab2;
00149           SUNDANCE_VERB_HIGH(tab2 << "dp=" << j << " of " << sumOfDerivProducts.size());
00150           const DerivProduct& p = sumOfDerivProducts[j];
00151           double cc = p.coeff();
00152           SUNDANCE_VERB_HIGH(tab2 << "multiplicity=" << cc);
00153           for (int k=0; k<p.numConstants(); k++)
00154             {
00155               const IndexPair& ip = p.constant(k);
00156               cc *= (*(constantArgResults[ip.argIndex()]))[ip.valueIndex()];
00157             }
00158           if (p.numVariables()==0)
00159             {
00160               innerSum->add_S(cc);
00161             }
00162           else if (p.numVariables()==1)
00163             {
00164               const IndexPair& ip = p.variable(0);
00165               const EvalVector* v 
00166                 = (*(vArgResults[ip.argIndex()]))[ip.valueIndex()].get();
00167               if (cc==1.0) innerSum->add_V(v);
00168               else innerSum->add_SV(cc, v);
00169             }
00170           else if (p.numVariables()==2)
00171             {
00172               const IndexPair& ip0 = p.variable(0);
00173               const EvalVector* v0 
00174                 = (*(vArgResults[ip0.argIndex()]))[ip0.valueIndex()].get();
00175               const IndexPair& ip1 = p.variable(1);
00176               const EvalVector* v1
00177                 = (*(vArgResults[ip1.argIndex()]))[ip1.valueIndex()].get();
00178               if (cc==1.0) innerSum->add_VV(v0, v1);
00179               else innerSum->add_SVV(cc, v0, v1);
00180             }
00181           else
00182             {
00183               const IndexPair& ip0 = p.variable(0);
00184               const EvalVector* v0 
00185                 = (*(vArgResults[ip0.argIndex()]))[ip0.valueIndex()].get();
00186               RCP<EvalVector> tmp = v0->clone();
00187               for (int k=1; k<p.numVariables(); k++)
00188                 {
00189                   const IndexPair& ip1 = p.variable(k);
00190                   const EvalVector* v1
00191                     = (*(vArgResults[ip1.argIndex()]))[ip1.valueIndex()].get();
00192                   tmp->multiply_V(v1);
00193                 }
00194               if (cc==1.0) innerSum->add_V(tmp.get());
00195               else innerSum->add_SV(cc, tmp.get());
00196             }
00197           SUNDANCE_VERB_HIGH(tab2 << "inner sum=" << *innerSum);
00198         }
00199 
00200       int adi = argDerivIndex(i);
00201       if (argDerivIsConstant(i))
00202         {
00203           const double& df_dq = constantArgDerivs[adi];
00204           varResult->add_SV(df_dq, innerSum.get());
00205         }
00206       else
00207         {
00208           const EvalVector* df_dq = varArgDerivs[adi].get();
00209           SUNDANCE_VERB_HIGH(tab1 << "arg deriv=" << *df_dq);
00210           varResult->add_VV(df_dq, innerSum.get());
00211           SUNDANCE_VERB_HIGH(tab1 << "outer sum=" << *varResult);
00212         }
00213     }
00214 }
00215 
00216 

Site Contact