From c3f0d7d9aa4fa0680d88da7e8403a11ffd5f334e Mon Sep 17 00:00:00 2001
From: Van Dung Nguyen <vdg.nguyen@gmail.com>
Date: Wed, 28 Oct 2020 14:41:10 +0100
Subject: [PATCH] correct

---
 NonLinearSolver/modelReduction/ANNUtils.h               | 7 ++++---
 NonLinearSolver/modelReduction/DeepMaterialNetworks.cpp | 5 +++++
 NonLinearSolver/modelReduction/Tree.cpp                 | 3 ++-
 3 files changed, 11 insertions(+), 4 deletions(-)

diff --git a/NonLinearSolver/modelReduction/ANNUtils.h b/NonLinearSolver/modelReduction/ANNUtils.h
index a113ff127..1620de713 100644
--- a/NonLinearSolver/modelReduction/ANNUtils.h
+++ b/NonLinearSolver/modelReduction/ANNUtils.h
@@ -65,14 +65,15 @@ class RectifierActivationFunction: public activationFunction
 {
   protected:
     double _fact;
+    double _offset;
   public:
-    RectifierActivationFunction(double f=10.): activationFunction(),_fact(f){}
+    RectifierActivationFunction(double f=10.): activationFunction(),_fact(f),_offset(0.){}
     RectifierActivationFunction(const RectifierActivationFunction& src): activationFunction(src){}
     virtual ~RectifierActivationFunction(){}
     virtual activationFunction::functionType getType() const {return activationFunction::rectifier;};
     virtual std::string getName() const;
-    virtual double getVal(double x) const  {return (log(1.+exp(_fact*x)))/_fact;};
-    virtual double getReciprocalVal(double y) const {return (log(exp(_fact*y)-1.))/_fact;}; // inverse function
+    virtual double getVal(double x) const  {return (_offset+log(1.+exp(_fact*x)))/_fact;};
+    virtual double getReciprocalVal(double y) const {return (log(exp(_fact*(y-_offset))-1.))/_fact;}; // inverse function
     virtual double getDiff(double x) const {return exp(_fact*x)/(1.+exp(_fact*x));};
     virtual activationFunction* clone() const {return new RectifierActivationFunction(*this);};
 };
diff --git a/NonLinearSolver/modelReduction/DeepMaterialNetworks.cpp b/NonLinearSolver/modelReduction/DeepMaterialNetworks.cpp
index 1fdead313..f990a778e 100644
--- a/NonLinearSolver/modelReduction/DeepMaterialNetworks.cpp
+++ b/NonLinearSolver/modelReduction/DeepMaterialNetworks.cpp
@@ -1885,8 +1885,10 @@ void TrainingDeepMaterialNetwork::train(double lr, int maxEpoch, std::string los
     
     if (removeZeroContribution)
     {
+      
       if ((epoch+1) %numStepRemove==0)
       {
+        double costfuncPrev_pre = evaluateTrainingSet(*lossEval);
         Msg::Info("tree removing");
         bool removed =  _T->removeLeavesWithZeroContribution(tolRemove);
         if (removed)
@@ -1899,7 +1901,10 @@ void TrainingDeepMaterialNetwork::train(double lr, int maxEpoch, std::string los
           Wprev = Wcur;
           g.resize(numDof,true);
         }
+        costfuncPrev = evaluateTrainingSet(*lossEval);
+        Msg::Info("pre-removing costfuncPrev = %e and after removing costfuncPrev = %e",costfuncPrev_pre,costfuncPrev);
       }
+      
     }
     
     epoch++;
diff --git a/NonLinearSolver/modelReduction/Tree.cpp b/NonLinearSolver/modelReduction/Tree.cpp
index 3f178125e..c41223e04 100644
--- a/NonLinearSolver/modelReduction/Tree.cpp
+++ b/NonLinearSolver/modelReduction/Tree.cpp
@@ -1904,7 +1904,8 @@ void Tree::initialize(bool rand)
   for (int i=0; i< allLeaves.size(); i++)
   {
     TreeNode* n = allLeaves[i];
-    n->weight *= (1./toltalWeightLeaves);
+    double vv = (n->af->getVal(n->weight))/toltalWeightLeaves;
+    n->weight *= n->af->getReciprocalVal(vv);
     // propagate data
     double w = n->weight;
     while (n->parent != NULL)
-- 
GitLab