SundanceChainRuleSum.cpp
Go to the documentation of this file.
00001 /* @HEADER@ */
00002 // ************************************************************************
00003 // 
00004 //                             Sundance
00005 //                 Copyright 2011 Sandia Corporation
00006 // 
00007 // Under the terms of Contract DE-AC04-94AL85000 with Sandia Corporation,
00008 // the U.S. Government retains certain rights in this software.
00009 //
00010 // Redistribution and use in source and binary forms, with or without
00011 // modification, are permitted provided that the following conditions are
00012 // met:
00013 //
00014 // 1. Redistributions of source code must retain the above copyright
00015 // notice, this list of conditions and the following disclaimer.
00016 //
00017 // 2. Redistributions in binary form must reproduce the above copyright
00018 // notice, this list of conditions and the following disclaimer in the
00019 // documentation and/or other materials provided with the distribution.
00020 //
00021 // 3. Neither the name of the Corporation nor the names of the
00022 // contributors may be used to endorse or promote products derived from
00023 // this software without specific prior written permission.
00024 //
00025 // THIS SOFTWARE IS PROVIDED BY SANDIA CORPORATION "AS IS" AND ANY
00026 // EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
00027 // IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
00028 // PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL SANDIA CORPORATION OR THE
00029 // CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
00030 // EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
00031 // PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
00032 // PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
00033 // LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
00034 // NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
00035 // SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
00036 //
00037 // Questions? Contact Kevin Long (kevin.long@ttu.edu)
00038 // 
00039 
00040 /* @HEADER@ */
00041 
00042 #include "SundanceChainRuleSum.hpp"
00043 #include "SundanceEvalManager.hpp"
00044 #include "SundanceEvalVector.hpp"
00045 #include "PlayaExceptions.hpp"
00046 #include "SundanceSet.hpp"
00047 #include "PlayaTabs.hpp"
00048 #include "SundanceOut.hpp"
00049 
00050 using namespace Sundance;
00051 using namespace Sundance;
00052 
00053 using namespace Sundance;
00054 using namespace Teuchos;
00055 
00056 
00057 ChainRuleSum::ChainRuleSum(const MultipleDeriv& md,
00058                            int resultIndex,
00059                            bool resultIsConstant)
00060   : md_(md),
00061     resultIndex_(resultIndex),
00062     resultIsConstant_(resultIsConstant),
00063     argDerivIndex_(),
00064     argDerivIsConstant_(),
00065     terms_()
00066 {;}
00067 
00068 
00069 void ChainRuleSum::addTerm(int argDerivIndex, 
00070                            bool argDerivIsConstant,
00071                            const Array<DerivProduct>& sum)
00072 {
00073   argDerivIndex_.append(argDerivIndex);
00074   argDerivIsConstant_.append(argDerivIsConstant);
00075   terms_.append(sum);
00076 }
00077 
00078 
00079 void ChainRuleSum
00080 ::evalConstant(const EvalManager& mgr,
00081                const Array<RCP<Array<double> > >& constantArgResults,
00082                const Array<double>& constantArgDerivs,
00083                double& result) const
00084 {
00085   Tabs tabs;
00086   SUNDANCE_VERB_HIGH(tabs << "ChainRuleSum::evalConstant()");
00087   result = 0.0;
00088   for (int i=0; i<numTerms(); i++)
00089     {
00090       const double& argDeriv = constantArgDerivs[argDerivIndex(i)];
00091       const Array<DerivProduct>& sumOfDerivProducts = terms(i);
00092       double innerSum = 0.0;
00093       for (int j=0; j<sumOfDerivProducts.size(); j++)
00094         {
00095           double prod = 1.0;
00096           const DerivProduct& p = sumOfDerivProducts[j];
00097           for (int k=0; k<p.numConstants(); k++)
00098             {
00099               const IndexPair& ip = p.constant(k);
00100               prod *= (*(constantArgResults[ip.argIndex()]))[ip.valueIndex()];
00101             }
00102           innerSum += prod;
00103         }
00104       result += innerSum*argDeriv;
00105     }
00106 }
00107 
00108 
00109 void ChainRuleSum
00110 ::evalVar(const EvalManager& mgr,
00111           const Array<RCP<Array<double> > >& constantArgResults,
00112           const Array<RCP<Array<RCP<EvalVector> > > > & vArgResults,
00113           const Array<double>& constantArgDerivs,
00114           const Array<RCP<EvalVector> >& varArgDerivs,
00115           RCP<EvalVector>& varResult) const
00116 {
00117   Tabs tabs;
00118   SUNDANCE_VERB_HIGH(tabs << "ChainRuleSum::evalVar()");
00119   int vecSize=-1;
00120   for (int i=0; i<varArgDerivs.size(); i++)
00121     {
00122       int s = varArgDerivs[i]->length();
00123       TEUCHOS_TEST_FOR_EXCEPTION(vecSize != -1 && s != vecSize, std::logic_error,
00124                          "inconsistent vector sizes " << vecSize
00125                          << " and " << s);
00126       vecSize = s;
00127     } 
00128   for (int i=0; i<vArgResults.size(); i++)
00129     {
00130       for (int j=0; j<vArgResults[i]->size(); j++)
00131         {
00132           int s = (*(vArgResults[i]))[j]->length();
00133           TEUCHOS_TEST_FOR_EXCEPTION(vecSize != -1 && s != vecSize, std::logic_error,
00134                              "inconsistent vector sizes " << vecSize
00135                              << " and " << s);
00136           vecSize = s;
00137         }
00138     } 
00139   TEUCHOS_TEST_FOR_EXCEPT(vecSize==-1);
00140   
00141   varResult = mgr.popVector();
00142   varResult->resize(vecSize);
00143   varResult->setToConstant(0.0);
00144 
00145   for (int i=0; i<numTerms(); i++)
00146     {
00147       Tabs tab1;
00148       SUNDANCE_VERB_HIGH(tab1 << "term=" << i << " of " << numTerms());
00149       RCP<EvalVector> innerSum = mgr.popVector();
00150       innerSum->resize(vecSize);
00151       innerSum->setToConstant(0.0);
00152       const Array<DerivProduct>& sumOfDerivProducts = terms(i);
00153 
00154       SUNDANCE_VERB_HIGH(tab1 << "inner sum init = " << *innerSum
00155                          << ", num terms = " << terms(i).size());
00156 
00157       for (int j=0; j<sumOfDerivProducts.size(); j++)
00158         {
00159           Tabs tab2;
00160           SUNDANCE_VERB_HIGH(tab2 << "dp=" << j << " of " << sumOfDerivProducts.size());
00161           const DerivProduct& p = sumOfDerivProducts[j];
00162           double cc = p.coeff();
00163           SUNDANCE_VERB_HIGH(tab2 << "multiplicity=" << cc);
00164           for (int k=0; k<p.numConstants(); k++)
00165             {
00166               const IndexPair& ip = p.constant(k);
00167               cc *= (*(constantArgResults[ip.argIndex()]))[ip.valueIndex()];
00168             }
00169           if (p.numVariables()==0)
00170             {
00171               innerSum->add_S(cc);
00172             }
00173           else if (p.numVariables()==1)
00174             {
00175               const IndexPair& ip = p.variable(0);
00176               const EvalVector* v 
00177                 = (*(vArgResults[ip.argIndex()]))[ip.valueIndex()].get();
00178               if (cc==1.0) innerSum->add_V(v);
00179               else innerSum->add_SV(cc, v);
00180             }
00181           else if (p.numVariables()==2)
00182             {
00183               const IndexPair& ip0 = p.variable(0);
00184               const EvalVector* v0 
00185                 = (*(vArgResults[ip0.argIndex()]))[ip0.valueIndex()].get();
00186               const IndexPair& ip1 = p.variable(1);
00187               const EvalVector* v1
00188                 = (*(vArgResults[ip1.argIndex()]))[ip1.valueIndex()].get();
00189               if (cc==1.0) innerSum->add_VV(v0, v1);
00190               else innerSum->add_SVV(cc, v0, v1);
00191             }
00192           else
00193             {
00194               const IndexPair& ip0 = p.variable(0);
00195               const EvalVector* v0 
00196                 = (*(vArgResults[ip0.argIndex()]))[ip0.valueIndex()].get();
00197               RCP<EvalVector> tmp = v0->clone();
00198               for (int k=1; k<p.numVariables(); k++)
00199                 {
00200                   const IndexPair& ip1 = p.variable(k);
00201                   const EvalVector* v1
00202                     = (*(vArgResults[ip1.argIndex()]))[ip1.valueIndex()].get();
00203                   tmp->multiply_V(v1);
00204                 }
00205               if (cc==1.0) innerSum->add_V(tmp.get());
00206               else innerSum->add_SV(cc, tmp.get());
00207             }
00208           SUNDANCE_VERB_HIGH(tab2 << "inner sum=" << *innerSum);
00209         }
00210 
00211       int adi = argDerivIndex(i);
00212       if (argDerivIsConstant(i))
00213         {
00214           const double& df_dq = constantArgDerivs[adi];
00215           varResult->add_SV(df_dq, innerSum.get());
00216         }
00217       else
00218         {
00219           const EvalVector* df_dq = varArgDerivs[adi].get();
00220           SUNDANCE_VERB_HIGH(tab1 << "arg deriv=" << *df_dq);
00221           varResult->add_VV(df_dq, innerSum.get());
00222           SUNDANCE_VERB_HIGH(tab1 << "outer sum=" << *varResult);
00223         }
00224     }
00225 }
00226 
00227 

Site Contact