From b808c6b09461d18259db8e4a8dc68a5c0ef6c580 Mon Sep 17 00:00:00 2001
From: Van Dung NGUYEN <vdg.nguyen@gmail.com>
Date: Sun, 8 Nov 2020 10:12:15 +0100
Subject: [PATCH] add to

---
 .../modelReduction/DeepMaterialNetworks.cpp   | 20 +++++++++++++++----
 1 file changed, 16 insertions(+), 4 deletions(-)

diff --git a/NonLinearSolver/modelReduction/DeepMaterialNetworks.cpp b/NonLinearSolver/modelReduction/DeepMaterialNetworks.cpp
index ca01520d3..9ff495bae 100644
--- a/NonLinearSolver/modelReduction/DeepMaterialNetworks.cpp
+++ b/NonLinearSolver/modelReduction/DeepMaterialNetworks.cpp
@@ -1616,7 +1616,10 @@ void TrainingDeepMaterialNetwork::getKeys(const TreeNode* node, std::vector<Dof>
   }
   else
   {
-    for (int i=0; i< node->direction.size(); i++)
+    int numberNormal = node->childs.size()-1;
+    int totalNumberDirVars = node->direction.size();
+    int numberVarsPerNormal = totalNumberDirVars/numberNormal;
+    for (int i=0; i< numberVarsPerNormal; i++)
     {
       keys.push_back(Dof(node->location, Dof::createTypeWithTwoInts(i, node->depth+1)));
     }
@@ -1738,9 +1741,15 @@ void TrainingDeepMaterialNetwork::updateFieldFromUnknown(const fullVector<double
     }
     else
     {
-      for (int j=0; j< node.direction.size(); j++)
+      int numberNormal = node.childs.size()-1;
+      int totalNumberDirVars = node.direction.size();
+      int numberVarsPerNormal = totalNumberDirVars/numberNormal;
+      for (int j=0; j< numberNormal; j++)
       {
-        node.direction[j] = val(j);
+        for (int k=0; k< numberVarsPerNormal; k++)
+        {
+          node.direction[j*numberVarsPerNormal+k] = val(k);
+        }
       }
     };
   };
@@ -1766,7 +1775,10 @@ void TrainingDeepMaterialNetwork::updateFieldFromUnknown(const fullVector<double
 void TrainingDeepMaterialNetwork::getNormal(const TreeNode* node, SVector3& vec, bool stiff, std::vector<SVector3>* DnormalDunknown) const
 {
   double pi = 3.14159265359;
-  int ncomp = node->direction.size(); 
+  int numberNormal = node.childs.size()-1;
+  int totalNumberDirVars = node.direction.size();
+  int ncomp = totalNumberDirVars/numberNormal;
+
   if (ncomp == 1)
   {
     double theta = 2.*pi*node->direction[0];
-- 
GitLab