00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037
00038
00039
00040
00041
00042 #include "SundanceChainRuleEvaluator.hpp"
00043 #include "SundanceCombinatorialUtils.hpp"
00044
00045 #include "SundanceUnknownFuncElement.hpp"
00046 #include "SundanceEvalManager.hpp"
00047 #include "PlayaExceptions.hpp"
00048 #include "SundanceSet.hpp"
00049 #include "PlayaTabs.hpp"
00050 #include "SundanceOut.hpp"
00051
00052
00053 using namespace Sundance;
00054 using namespace Sundance;
00055
00056 using namespace Sundance;
00057 using namespace Teuchos;
00058
00059
00060 ChainRuleEvaluator::ChainRuleEvaluator(const ExprWithChildren* expr,
00061 const EvalContext& context)
00062 : SubtypeEvaluator<ExprWithChildren>(expr, context),
00063 expansions_(),
00064 childEvaluators_(expr->numChildren()),
00065 childSparsity_(expr->numChildren()),
00066 constArgDerivMap_(),
00067 varArgDerivMap_(),
00068 zerothDerivResultIndex_(-1),
00069 zerothDerivIsConstant_(false)
00070 {
00071 Tabs tabs;
00072 SUNDANCE_MSG1(context.setupVerbosity(),
00073 tabs << "ChainRuleEvaluator base class ctor for "
00074 << expr->toString());
00075 for (int i=0; i<numChildren(); i++)
00076 {
00077 childEvaluators_[i] = expr->evaluatableChild(i)->evaluator(context);
00078 childEvaluators_[i]->addClient();
00079 childSparsity_[i] = expr->evaluatableChild(i)->sparsitySuperset(context);
00080 }
00081 }
00082
00083 Sundance::Map<OrderedPair<int, int>, Array<Array<int> > >& ChainRuleEvaluator::compMap()
00084 {
00085 static Map<OrderedPair<int, int>, Array<Array<int> > > rtn;
00086 return rtn;
00087 }
00088
00089 void ChainRuleEvaluator::resetNumCalls() const
00090 {
00091 for (int i=0; i<numChildren(); i++)
00092 {
00093 childEvaluators_[i]->resetNumCalls();
00094 }
00095 Evaluator::resetNumCalls();
00096 }
00097
00098
00099 void ChainRuleEvaluator::addConstArgDeriv(const MultiSet<int>& df, int index)
00100 {
00101 constArgDerivMap_.put(df, index);
00102 }
00103
00104 void ChainRuleEvaluator::addVarArgDeriv(const MultiSet<int>& df, int index)
00105 {
00106 varArgDerivMap_.put(df, index);
00107 }
00108
00109 int ChainRuleEvaluator::constArgDerivIndex(const MultiSet<int>& df) const
00110 {
00111 TEUCHOS_TEST_FOR_EXCEPTION(!constArgDerivMap_.containsKey(df), std::logic_error,
00112 "argument derivative " << df << " not found in constant "
00113 "argument derivative map");
00114
00115 return constArgDerivMap_.get(df);
00116 }
00117
00118 int ChainRuleEvaluator::varArgDerivIndex(const MultiSet<int>& df) const
00119 {
00120 TEUCHOS_TEST_FOR_EXCEPTION(!varArgDerivMap_.containsKey(df), std::logic_error,
00121 "argument derivative " << df << " not found in variable "
00122 "argument derivative map");
00123
00124 return varArgDerivMap_.get(df);
00125 }
00126
00127
00128 const Array<Array<int> >& ChainRuleEvaluator::nComps(int N, int n) const
00129 {
00130 OrderedPair<int,int> key(n,N);
00131 if (!compMap().containsKey(key))
00132 {
00133 compMap().put(key, compositions(N)[n-1]);
00134 }
00135 return compMap().get(key);
00136 }
00137
00138
00139 double ChainRuleEvaluator::fact(int n) const
00140 {
00141 TEUCHOS_TEST_FOR_EXCEPTION(n<0, std::logic_error, "negative argument " << n << " to factorial");
00142 if (n==0 || n==1) return 1.0;
00143 return n*fact(n-1);
00144 }
00145
00146 double ChainRuleEvaluator::choose(int N, int n) const
00147 {
00148 return fact(N)/fact(n)/fact(N-n);
00149 }
00150
00151 double ChainRuleEvaluator::stirling2(int n, int k) const
00152 {
00153 if (n < k) return 0;
00154 if (n == k) return 1;
00155 if (k<=0) return 0;
00156 if (k==1) return 1;
00157 if (n-1 == k) return choose(n, 2);
00158 return k*stirling2(n-1, k) + stirling2(n-1, k-1);
00159 }
00160
00161
00162 MultipleDeriv ChainRuleEvaluator::makeMD(const Array<Deriv>& d)
00163 {
00164 MultipleDeriv rtn;
00165 for (int i=0; i<d.size(); i++)
00166 {
00167 rtn.put(d[i]);
00168 }
00169 return rtn;
00170 }
00171
00172
00173 Set<MultiSet<MultipleDeriv> > ChainRuleEvaluator::chainRuleBins(const MultipleDeriv& d,
00174 const MultiSet<int>& q)
00175 {
00176 int n = q.size();
00177 Array<Array<Array<Deriv> > > bins = binnings(d, n);
00178
00179 Set<MultiSet<MultipleDeriv> > rtn;
00180
00181 for (int i=0; i<bins.size(); i++)
00182 {
00183 MultiSet<MultipleDeriv> b;
00184 for (int j=0; j<bins[i].size(); j++)
00185 {
00186 b.put(makeMD(bins[i][j]));
00187 }
00188 rtn.put(b);
00189 }
00190
00191
00192 return rtn;
00193 }
00194
00195
00196 int ChainRuleEvaluator::derivComboMultiplicity(const MultiSet<MultipleDeriv>& b) const
00197 {
00198
00199
00200 MultipleDeriv dTot;
00201 Array<MultiSet<Deriv> > derivSets(b.size());
00202 Array<Array<Deriv> > derivArrays(b.size());
00203 Set<Deriv> allDerivs;
00204 int k=0;
00205 bool allDerivsAreDistinct = true;
00206 bool allDerivsAreIdentical = true;
00207 for (MultiSet<MultipleDeriv>::const_iterator i=b.begin(); i!=b.end(); i++, k++)
00208 {
00209 for (MultipleDeriv::const_iterator j=i->begin(); j!=i->end(); j++)
00210 {
00211 derivSets[k].put(*j);
00212 derivArrays[k].append(*j);
00213 dTot.put(*j);
00214 if (allDerivs.contains(*j)) allDerivsAreDistinct = false;
00215 if (allDerivs.size()>0 && !allDerivs.contains(*j)) allDerivsAreIdentical = false;
00216 allDerivs.put(*j);
00217 }
00218 }
00219 int totOrder = dTot.order();
00220
00221
00222 TEUCHOS_TEST_FOR_EXCEPTION(totOrder > 3, std::logic_error,
00223 "deriv order " << totOrder << " not supported");
00224
00225 if (b.size()==1) return 1;
00226 if (totOrder == (int) b.size()) return 1;
00227
00228
00229
00230
00231
00232
00233
00234
00235
00236
00237
00238 TEUCHOS_TEST_FOR_EXCEPTION(derivArrays.size() != 2, std::logic_error,
00239 "unexpected size=" << derivArrays.size());
00240
00241 if (allDerivsAreIdentical) return 3;
00242 if (allDerivsAreDistinct) return 1;
00243
00244 if (derivArrays[0].size()==1)
00245 {
00246 if (derivSets[1].contains(derivArrays[0][0])) return 2;
00247 return 1;
00248 }
00249 else
00250 {
00251 if (derivSets[0].contains(derivArrays[1][0])) return 2;
00252 return 1;
00253 }
00254 }
00255
00256
00257 void ChainRuleEvaluator::init(const ExprWithChildren* expr,
00258 const EvalContext& context)
00259 {
00260 int verb = context.setupVerbosity();
00261
00262 typedef Array<OrderedPair<Array<MultiSet<int> >, Array<MultipleDeriv> > > CR;
00263 Tabs tabs;
00264 SUNDANCE_MSG1(verb, tabs << "ChainRuleEvaluator::init() for "
00265 << expr->toString());
00266
00267 const Set<MultipleDeriv>& C = expr->findC(context);
00268 const Set<MultipleDeriv>& R = expr->findR(context);
00269
00270 Array<Set<MultipleDeriv> > argV(expr->numChildren());
00271 Array<Set<MultipleDeriv> > argC(expr->numChildren());
00272 Array<Set<MultipleDeriv> > argR(expr->numChildren());
00273
00274 for (int i=0; i<numChildren(); i++)
00275 {
00276 argV[i] = expr->evaluatableChild(i)->findV(context);
00277 argC[i] = expr->evaluatableChild(i)->findC(context);
00278 argR[i] = expr->evaluatableChild(i)->findR(context);
00279 }
00280 SUNDANCE_MSG3(verb, tabs << "sparsity = " << *(this->sparsity()));
00281 typedef Set<MultipleDeriv>::const_iterator iter;
00282
00283 int count=0;
00284 int vecResultIndex = 0;
00285 int constResultIndex = 0;
00286 for (iter md=R.begin(); md!=R.end(); md++, count++)
00287 {
00288 Tabs tab1;
00289 SUNDANCE_MSG3(verb, tab1 << "working out evaluator for " << *md);
00290 int N = md->order();
00291 bool resultIsConstant = C.contains(*md);
00292 int resultIndex;
00293 if (resultIsConstant)
00294 {
00295 Tabs tab2;
00296 SUNDANCE_MSG3(verb, tab2 << "result is constant, const index=" << constResultIndex);
00297 addConstantIndex(count, constResultIndex);
00298 resultIndex = constResultIndex;
00299 constResultIndex++;
00300 }
00301 else
00302 {
00303 Tabs tab2;
00304 SUNDANCE_MSG3(verb, tab2 << "result is variable, vec index=" << vecResultIndex);
00305 addVectorIndex(count, vecResultIndex);
00306 resultIndex = vecResultIndex;
00307 vecResultIndex++;
00308 }
00309
00310 SUNDANCE_MSG3(verb, tab1 << "order=" << N);
00311
00312 if (N==0)
00313 {
00314 Tabs tab2;
00315 SUNDANCE_MSG3(verb, tab2 << "zeroth deriv index=" << resultIndex);
00316 zerothDerivIsConstant_ = resultIsConstant;
00317 zerothDerivResultIndex_ = resultIndex;
00318 continue;
00319 }
00320
00321
00322
00323 RCP<ChainRuleSum> sum
00324 = rcp(new ChainRuleSum(*md, resultIndex, resultIsConstant));
00325
00326 const MultipleDeriv& nu = *md;
00327
00328 for (int n=1; n<=N; n++)
00329 {
00330 Tabs tab2;
00331 SUNDANCE_MSG3(verb, tab2 << "n=" << n);
00332 const Set<MultiSet<int> >& QW = expr->findQ_W(n, context);
00333 const Set<MultiSet<int> >& QC = expr->findQ_C(n, context);
00334 SUNDANCE_MSG3(verb, tab2 << "Q_W=" << QW);
00335 SUNDANCE_MSG3(verb, tab2 << "Q_C=" << QC);
00336 for (Set<MultiSet<int> >::const_iterator
00337 j=QW.begin(); j!=QW.end(); j++)
00338 {
00339 Tabs tab3;
00340 const MultiSet<int>& lambda = *j;
00341 SUNDANCE_MSG3(verb, tab3 << "arg index set =" << lambda);
00342 bool argDerivIsConstant = QC.contains(lambda);
00343 int argDerivIndex = -1;
00344 if (argDerivIsConstant)
00345 {
00346 argDerivIndex = constArgDerivIndex(lambda);
00347 }
00348 else
00349 {
00350 argDerivIndex = varArgDerivIndex(lambda);
00351 }
00352 Array<DerivProduct> pSum;
00353 for (int s=1; s<=N; s++)
00354 {
00355 Tabs tab4;
00356 SUNDANCE_MSG3(verb, tab4 << "preparing chain rule terms for "
00357 "s=" << s << ", lambda=" << lambda << ", nu=" << nu);
00358 CR p = chainRuleTerms(s, lambda, nu);
00359 for (CR::const_iterator j=p.begin(); j!=p.end(); j++)
00360 {
00361 Tabs tab5;
00362 Array<MultiSet<int> > K = j->first();
00363 Array<MultipleDeriv> L = j->second();
00364 SUNDANCE_MSG3(verb, tab5 << "K=" << K << std::endl << tab5 << "L=" << L);
00365 double weight = chainRuleMultiplicity(nu, K, L);
00366 SUNDANCE_MSG3(verb, tab5 << "weight=" << weight);
00367 DerivProduct prod(weight);
00368 bool termIsZero = false;
00369 for (int j=0; j<K.size(); j++)
00370 {
00371 for (MultiSet<int>::const_iterator
00372 k=K[j].begin(); k!=K[j].end(); k++)
00373 {
00374 int argIndex = *k;
00375 const MultipleDeriv& derivOfArg = L[j];
00376 const RCP<SparsitySuperset>& argSp
00377 = childSparsity_[argIndex];
00378 const RCP<Evaluator>& argEv
00379 = childEvaluators_[argIndex];
00380
00381 int rawValIndex = argSp->getIndex(derivOfArg);
00382 SUNDANCE_MSG3(verb, tab5 << "argR="
00383 << argR[argIndex]);
00384 if (argV[argIndex].contains(derivOfArg))
00385 {
00386 SUNDANCE_MSG3(verb, tab5 << "mdArg is variable");
00387 int valIndex
00388 = argEv->vectorIndexMap().get(rawValIndex);
00389 prod.addVariableFactor(IndexPair(argIndex, valIndex));
00390 }
00391 else if (argC[argIndex].contains(derivOfArg))
00392 {
00393 SUNDANCE_MSG3(verb, tab5 << "mdArg is constant");
00394 int valIndex
00395 = argEv->constantIndexMap().get(rawValIndex);
00396 prod.addConstantFactor(IndexPair(argIndex, valIndex));
00397 }
00398 else
00399 {
00400 SUNDANCE_MSG3(verb, tab5 << "mdArg is zero");
00401 termIsZero = true;
00402 break;
00403 }
00404 }
00405 if (termIsZero) break;
00406 }
00407 if (!termIsZero) pSum.append(prod);
00408 }
00409 }
00410 sum->addTerm(argDerivIndex, argDerivIsConstant, pSum);
00411 }
00412 }
00413 TEUCHOS_TEST_FOR_EXCEPTION(sum->numTerms()==0, std::logic_error,
00414 "Empty sum in chain rule expansion for derivative "
00415 << *md);
00416 expansions_.append(sum);
00417 }
00418
00419 SUNDANCE_MSG3(verb, tabs << "num constant results: "
00420 << this->sparsity()->numConstantDerivs());
00421
00422 SUNDANCE_MSG3(verb, tabs << "num var results: "
00423 << this->sparsity()->numVectorDerivs());
00424
00425
00426 }
00427
00428
00429
00430 void ChainRuleEvaluator::internalEval(const EvalManager& mgr,
00431 Array<double>& constantResults,
00432 Array<RCP<EvalVector> >& vectorResults) const
00433 {
00434 TimeMonitor timer(chainRuleEvalTimer());
00435 Tabs tabs(0);
00436
00437 SUNDANCE_MSG1(mgr.verb(), tabs << "ChainRuleEvaluator::eval() expr="
00438 << expr()->toString());
00439
00440
00441 SUNDANCE_MSG2(mgr.verb(), tabs << "max diff order = " << mgr.getRegion().topLevelDiffOrder());
00442 SUNDANCE_MSG2(mgr.verb(), tabs << "return sparsity " << std::endl << tabs << *(this->sparsity()));
00443
00444 constantResults.resize(this->sparsity()->numConstantDerivs());
00445 vectorResults.resize(this->sparsity()->numVectorDerivs());
00446
00447 SUNDANCE_MSG3(mgr.verb(),tabs << "num constant results: "
00448 << this->sparsity()->numConstantDerivs());
00449
00450 SUNDANCE_MSG3(mgr.verb(),tabs << "num var results: "
00451 << this->sparsity()->numVectorDerivs());
00452
00453 Array<RCP<Array<double> > > constantArgResults(numChildren());
00454 Array<RCP<Array<RCP<EvalVector> > > > varArgResults(numChildren());
00455
00456 Array<double> constantArgDerivs;
00457 Array<RCP<EvalVector> > varArgDerivs;
00458
00459 for (int i=0; i<numChildren(); i++)
00460 {
00461 Tabs tab1;
00462 SUNDANCE_MSG3(mgr.verb(), tab1 << "computing results for child #"
00463 << i);
00464
00465 constantArgResults[i] = rcp(new Array<double>());
00466 varArgResults[i] = rcp(new Array<RCP<EvalVector> >());
00467 childEvaluators_[i]->eval(mgr, *(constantArgResults[i]),
00468 *(varArgResults[i]));
00469 if (mgr.verb() > 3)
00470 {
00471 Out::os() << tabs << "constant arg #" << i <<
00472 " results:" << *(constantArgResults[i]) << std::endl;
00473 Out::os() << tabs << "variable arg #" << i << " derivs:" << std::endl;
00474 for (int j=0; j<varArgResults[i]->size(); j++)
00475 {
00476 Tabs tab1;
00477 Out::os() << tab1 << j << " ";
00478 (*(varArgResults[i]))[j]->print(Out::os());
00479 Out::os() << std::endl;
00480 }
00481 }
00482 }
00483
00484 evalArgDerivs(mgr, constantArgResults, varArgResults,
00485 constantArgDerivs, varArgDerivs);
00486
00487
00488 if (mgr.verb() > 2)
00489 {
00490 Out::os() << tabs << "constant arg derivs:" << constantArgDerivs << std::endl;
00491 Out::os() << tabs << "variable arg derivs:" << std::endl;
00492 for (int i=0; i<varArgDerivs.size(); i++)
00493 {
00494 Tabs tab1;
00495 Out::os() << tab1 << i << " ";
00496 varArgDerivs[i]->print(Out::os());
00497 Out::os() << std::endl;
00498 }
00499 }
00500
00501
00502 for (int i=0; i<expansions_.size(); i++)
00503 {
00504 Tabs tab1;
00505 int resultIndex = expansions_[i]->resultIndex();
00506 bool isConstant = expansions_[i]->resultIsConstant();
00507 SUNDANCE_MSG3(mgr.verb(), tab1 << "doing expansion for deriv #" << i
00508 << ", result index="
00509 << resultIndex << std::endl << tab1
00510 << "deriv=" << expansions_[i]->deriv());
00511 if (isConstant)
00512 {
00513 expansions_[i]->evalConstant(mgr, constantArgResults, constantArgDerivs,
00514 constantResults[resultIndex]);
00515 }
00516 else
00517 {
00518 expansions_[i]->evalVar(mgr, constantArgResults, varArgResults,
00519 constantArgDerivs, varArgDerivs,
00520 vectorResults[resultIndex]);
00521 }
00522 }
00523
00524 if (zerothDerivResultIndex_ >= 0)
00525 {
00526 SUNDANCE_MSG3(mgr.verb(), tabs << "processing zeroth-order deriv");
00527 Tabs tab1;
00528 SUNDANCE_MSG3(mgr.verb(), tab1 << "result index = " << zerothDerivResultIndex_);
00529 if (zerothDerivIsConstant_)
00530 {
00531 Tabs tab2;
00532 SUNDANCE_MSG3(mgr.verb(), tab2 << "zeroth-order deriv is constant");
00533 constantResults[zerothDerivResultIndex_] = constantArgDerivs[0];
00534 }
00535 else
00536 {
00537 Tabs tab2;
00538 SUNDANCE_MSG3(mgr.verb(), tab2 << "zeroth-order deriv is variable");
00539 vectorResults[zerothDerivResultIndex_] = varArgDerivs[0];
00540 }
00541 }
00542
00543
00544 if (mgr.verb() > 1)
00545 {
00546 Tabs tab1;
00547 Out::os() << tab1 << "chain rule results " << std::endl;
00548 mgr.showResults(Out::os(), this->sparsity(), vectorResults,
00549 constantResults);
00550 }
00551
00552 SUNDANCE_MSG1(mgr.verb(), tabs << "ChainRuleEvaluator::eval() done");
00553 }
00554
00555
00556
00557
00558 namespace Sundance {
00559
00560 MultipleDeriv makeDeriv(const Expr& a)
00561 {
00562 const UnknownFuncElement* aPtr
00563 = dynamic_cast<const UnknownFuncElement*>(a[0].ptr().get());
00564
00565 TEUCHOS_TEST_FOR_EXCEPT(aPtr==0);
00566
00567 Deriv d = funcDeriv(aPtr);
00568 MultipleDeriv rtn;
00569 rtn.put(d);
00570 return rtn;
00571 }
00572
00573
00574 MultipleDeriv makeDeriv(const Expr& a, const Expr& b)
00575 {
00576 const UnknownFuncElement* aPtr
00577 = dynamic_cast<const UnknownFuncElement*>(a[0].ptr().get());
00578
00579 TEUCHOS_TEST_FOR_EXCEPT(aPtr==0);
00580
00581 const UnknownFuncElement* bPtr
00582 = dynamic_cast<const UnknownFuncElement*>(b[0].ptr().get());
00583
00584 TEUCHOS_TEST_FOR_EXCEPT(bPtr==0);
00585
00586 Deriv da = funcDeriv(aPtr);
00587 Deriv db = funcDeriv(bPtr);
00588 MultipleDeriv rtn;
00589 rtn.put(da);
00590 rtn.put(db);
00591 return rtn;
00592 }
00593
00594
00595
00596 MultipleDeriv makeDeriv(const Expr& a, const Expr& b, const Expr& c)
00597 {
00598 const UnknownFuncElement* aPtr
00599 = dynamic_cast<const UnknownFuncElement*>(a[0].ptr().get());
00600
00601 TEUCHOS_TEST_FOR_EXCEPT(aPtr==0);
00602
00603 const UnknownFuncElement* bPtr
00604 = dynamic_cast<const UnknownFuncElement*>(b[0].ptr().get());
00605
00606 TEUCHOS_TEST_FOR_EXCEPT(bPtr==0);
00607
00608 const UnknownFuncElement* cPtr
00609 = dynamic_cast<const UnknownFuncElement*>(c[0].ptr().get());
00610
00611 TEUCHOS_TEST_FOR_EXCEPT(cPtr==0);
00612
00613 Deriv da = funcDeriv(aPtr);
00614 Deriv db = funcDeriv(bPtr);
00615 Deriv dc = funcDeriv(cPtr);
00616 MultipleDeriv rtn;
00617 rtn.put(da);
00618 rtn.put(db);
00619 rtn.put(dc);
00620 return rtn;
00621 }
00622
00623 }