diff --git a/NonLinearSolver/modelReduction/DeepMaterialNetworks.cpp b/NonLinearSolver/modelReduction/DeepMaterialNetworks.cpp
index 63577897ccc0631d22e9532a0c7b38d8e8a77e7f..a9a5f35f4397ecce35dbd620b72591c4e49eb495 100644
--- a/NonLinearSolver/modelReduction/DeepMaterialNetworks.cpp
+++ b/NonLinearSolver/modelReduction/DeepMaterialNetworks.cpp
@@ -2544,6 +2544,12 @@ DeepMaterialNetwork::DeepMaterialNetwork(const Tree& T): _T(&T), _isInitialized(
   #if defined(HAVE_PETSC) || defined(HAVE_GMM)
   _sys = NULL;
   #endif //
+  
+  if (_T->checkDuplicate())
+  {
+    Msg::Error("Duplicating node exists, the computation cannot continue !!!");
+    Msg::Exit(0);
+  }
 };
 
 
diff --git a/NonLinearSolver/modelReduction/Tree.cpp b/NonLinearSolver/modelReduction/Tree.cpp
index fa1b4f61447e3731f9b239786610c66c7c1b6642..0460ae75ced60f8485b76738ac700792093a13df 100644
--- a/NonLinearSolver/modelReduction/Tree.cpp
+++ b/NonLinearSolver/modelReduction/Tree.cpp
@@ -958,6 +958,34 @@ int Tree::getNumberOfNodes() const
   return allNodes.size();
 };
 
+bool Tree::checkDuplicate() const
+{
+  bool ok = false;
+  nodeContainer allNodes;
+  getAllNodes(allNodes);
+  std::set<int> nodeId;
+  for (int i=0; i< allNodes.size(); i++)
+  {
+    const TreeNode* node = allNodes[i];
+    std::set<int>::const_iterator itF = nodeId.find(node->getNodeId());
+    if (itF == nodeId.end())
+    {
+      nodeId.insert(node->getNodeId());
+    }
+    else
+    {
+      Msg::Error("node duplicating:");
+      node->printData();
+      ok = true;
+    }
+  }
+  if (!ok)
+  {
+    Msg::Info("no duplicatint node is found !!!");
+  }
+  return ok;
+}
+
 bool Tree::removeZeroLeaves(double tol, bool iteration)
 {
   bool ok = false;
@@ -1020,88 +1048,49 @@ bool Tree::removeZeroLeaves(double tol, bool iteration)
       {
         Msg::Info("----------------------------------");
         parent->printData();
-        for (int i=0; i< parent->childs.size(); i++)
-        {
-          parent->childs[i]->printData();
-        }
-        
-        std::vector<TreeNode*> newChilds;
-        bool found = false;
-        int removePos = -1;
-        for (int i=0; i< parent->childs.size(); i++)
-        {
-          if (parent->childs[i] == leaf)
-          {
-            found = true;
-            removePos = i;
-          }
-          if (!found)
-          {
-            newChilds.push_back(parent->childs[i]);
-          }
-          else
-          {
-            for (int j=i+1; j<parent->childs.size(); j++)
-            {
-              TreeNode* n = parent->childs[j];
-              n->childOrder = j-1;
-              newChilds.push_back(n);
-            };
-            break;
-          }
-        };
-        
         //
+        std::vector<TreeNode*> newChilds;
         std::vector<double> newDirection;
-        if (removePos > -1)
+        //
+        int numberNormal = parent->childs.size()-1;
+        int totalNumberDirVars = parent->direction.size();
+        int numberVarsPerNormal = totalNumberDirVars/numberNormal;
+        
+        for (int j=0; j< parent->childs.size(); j++)
         {
-          int numberNormal = parent->childs.size()-1;
-          int totalNumberDirVars = parent->direction.size();
-          int numberVarsPerNormal = totalNumberDirVars/numberNormal;
-          //
-          if (removePos == 0)
-          {
-            for (int j=numberVarsPerNormal; j< parent->direction.size(); j++)
-            {
-              newDirection.push_back(parent->direction[j]);
-            }
-          }
-          else if (removePos == parent->childs.size()-1)
+          if (parent->childs[j]->location != leaf->location)
           {
-            for (int j=0; j< parent->direction.size()-numberVarsPerNormal; j++)
-            {
-              newDirection.push_back(parent->direction[j]);
-            }
-          }
-          else
-          {
-            for (int j=0; j< (removePos-1)*numberVarsPerNormal; j++)
-            {
-              newDirection.push_back(parent->direction[j]);
-            }
-            for (int j=(removePos)*numberVarsPerNormal; j< parent->direction.size(); j++)
+            newChilds.push_back(parent->childs[j]);
+            if (newChilds.size() > 1)
             {
-              newDirection.push_back(parent->direction[j]);
+              if (numberVarsPerNormal == 1)
+              {
+                newDirection.push_back(parent->direction[j-1]);
+              }
+              else if (numberVarsPerNormal == 2)
+              {
+                newDirection.push_back(parent->direction[(j-1)*numberVarsPerNormal]);
+                newDirection.push_back(parent->direction[(j-1)*numberVarsPerNormal+1]);
+              }
+              else
+              {
+                Msg::Error("numberVarsPerNormal = %d is not correct !!!",numberVarsPerNormal);
+                Msg::Exit(0);
+              }
             }
           }
-          Msg::Info("old size = %d, new size = %d",parent->direction.size(),newDirection.size());
         }
-        else
-        {
-          Msg::Error("leaf is not removed correctly");
-          Msg::Exit(0);
-        };
-        
-        parent->childs = newChilds;
-        parent->direction = newDirection;
-       
-        Msg::Info("after removing");
+        Msg::Info("\nSTART removing ");
         parent->printData();
-        for (int i=0; i< parent->childs.size(); i++)
+        parent->childs = newChilds;
+        for (int j=0; j <parent->childs.size(); j++)
         {
-          parent->childs[i]->printData();
+          parent->childs[j]->childOrder = j;
         }
-        Msg::Info("----------------------------------");
+        parent->direction = newDirection;   
+        Msg::Info("------------------");
+        parent->printData();     
+        Msg::Info("DONE removing\n");
       }
       else
       {
diff --git a/NonLinearSolver/modelReduction/Tree.h b/NonLinearSolver/modelReduction/Tree.h
index 3087d7297f76d3d67b9ec5ec558ef9d328513513..20610abe277fb260a7af7bda56618a14c1ab5866 100644
--- a/NonLinearSolver/modelReduction/Tree.h
+++ b/NonLinearSolver/modelReduction/Tree.h
@@ -72,6 +72,7 @@ class Tree
     void saveDataToFile(std::string filename) const;
     void loadDataFromFile(std::string filename);
     int getNumberOfNodes() const;
+    bool checkDuplicate() const;
     double getPhaseFraction(int i) const;
     void printPhaseFraction() const;
     void printLeafFraction() const;