From a6a935f152cde89bf1e0054d3941de59b255fc7e Mon Sep 17 00:00:00 2001
From: Jonathan Lambrechts <jonathan.lambrechts@uclouvain.be>
Date: Wed, 8 Dec 2010 12:34:17 +0000
Subject: [PATCH] change solutionFunction Mechanism

---
 Solver/function.cpp | 23 ++++++++++++++++++-----
 Solver/function.h   |  8 +++++++-
 2 files changed, 25 insertions(+), 6 deletions(-)

diff --git a/Solver/function.cpp b/Solver/function.cpp
index 8e73c08497..1ed574ab0e 100644
--- a/Solver/function.cpp
+++ b/Solver/function.cpp
@@ -229,8 +229,19 @@ void dataCacheDouble::_eval()
   _valid = true;
 }
 
-dataCacheDouble &dataCacheMap::get(const function *f, dataCacheDouble *caller)
+dataCacheDouble *dataCacheMap::get(const function *f, dataCacheDouble *caller, bool createIfNotPresent)
 {
+  //special case
+  if (f == function::getSolution()) {
+    f = _functionSolution;
+    if (f == NULL)
+      Msg::Error ("solution function has not been set");
+  } else if (f == function::getSolutionGradient()) {
+    f = _functionSolutionGradient;
+    if (f == NULL)
+      Msg::Error ("solution function gradient has not been set");
+  }
+
   // do I have a cache for this function ?
   dataCacheDouble *&r = _cacheDoubleMap[f];
   // can I use the cache of my parent ?
@@ -247,15 +258,17 @@ dataCacheDouble &dataCacheMap::get(const function *f, dataCacheDouble *caller)
       }
     }
     if (okFromParent)
-      r = &_parent->get (f,caller);
+      r = _parent->get (f,caller);
   }
   // no cache found, create a new one
   if (r==NULL) {
+    if (!createIfNotPresent) 
+      return NULL;
     r = new dataCacheDouble (this, (function*)(f));
     r->_directDependencies.resize (f->arguments.size());
     for (unsigned int i = 0; i < f->arguments.size(); i++) {
       r->_directDependencies[i] = 
-        &getSecondaryCache(f->arguments[i].iMap)->get(f->arguments[i].f, r);
+        getSecondaryCache(f->arguments[i].iMap)->get(f->arguments[i].f, r);
     }
     for (unsigned i = 0; i < f->_functionReplaces.size(); i++) {
       functionReplaceCache replaceCache;
@@ -274,7 +287,7 @@ dataCacheDouble &dataCacheMap::get(const function *f, dataCacheDouble *caller)
       }
       for (std::vector<function::argument>::iterator it = replace->_toCompute.begin();
            it!= replace->_toCompute.end(); it++ ) {
-        replaceCache.toCompute.push_back(&rMap->getSecondaryCache(it->iMap)->get(it->f, r));
+        replaceCache.toCompute.push_back(rMap->getSecondaryCache(it->iMap)->get(it->f, r));
       }
       replaceCache.map = rMap;
       r->functionReplaceCaches.push_back (replaceCache); 
@@ -292,7 +305,7 @@ dataCacheDouble &dataCacheMap::get(const function *f, dataCacheDouble *caller)
       caller->_iDependOn.insert(*it);
     }
   }
-  return *r;
+  return r;
 }
 
 // dataCacheMap
diff --git a/Solver/function.h b/Solver/function.h
index d2946bc029..bff835b5e0 100644
--- a/Solver/function.h
+++ b/Solver/function.h
@@ -201,6 +201,7 @@ class dataCacheDouble {
 };
 
 class dataCacheMap {
+  const function *_functionSolution, *_functionSolutionGradient;
  public:
   dataCacheMap  *_parent;
   std::list<dataCacheMap*> _children;
@@ -211,6 +212,7 @@ class dataCacheMap {
   std::set<dataCacheDouble*> _toInvalidateOnElement;
   MElement *_element;
   dataCacheMap() {
+    _functionSolution = _functionSolutionGradient = NULL;
     _nbEvaluationPoints = 0;
     _parent=NULL;
   }
@@ -236,7 +238,7 @@ class dataCacheMap {
   {
     _secondaryCaches.push_back(s);
   }
-  dataCacheDouble &get(const function *f, dataCacheDouble *caller=0);
+  dataCacheDouble *get(const function *f, dataCacheDouble *caller=0, bool createIfNotPresent = true);
   virtual void setElement(MElement *element)
   {
     _element=element;
@@ -257,6 +259,10 @@ class dataCacheMap {
     m->_nbEvaluationPoints = 0;
     return m;
   }
+  inline void setSolutionFunction(const function *functionSolution, const function *functionSolutionGradient) {
+    _functionSolution = functionSolution;
+    _functionSolutionGradient = functionSolutionGradient;
+  }
   void setNbEvaluationPoints(int nbEvaluationPoints);
   inline int getNbEvaluationPoints() { return _nbEvaluationPoints; }
 };
-- 
GitLab