diff --git a/dG3D/src/dG3DIPVariable.cpp b/dG3D/src/dG3DIPVariable.cpp
index 0eb550cf7ffa2c71c697315d8dd003095ce25680..a91ef710f2a2af5a264a317f9cc046e335e62419 100644
--- a/dG3D/src/dG3DIPVariable.cpp
+++ b/dG3D/src/dG3DIPVariable.cpp
@@ -1811,14 +1811,17 @@ double ANNBasedDG3DIPVariable::get(const int comp) const
 
 
 torchANNBasedDG3DIPVariable::torchANNBasedDG3DIPVariable(const int n, const bool createBodyForceHO, const bool oninter):
-                 dG3DIPVariable(createBodyForceHO, oninter), _n(n),restart_internalVars(1,n), _kinematicVariables(1,6)
+                 dG3DIPVariable(createBodyForceHO, oninter), _n(n),restart_internalVars(1,n), _kinematicVariables(1,6),
+                 _latticeParams(1,2)
 {
 #if defined(HAVE_TORCH)
    _internalVars = -1.0*torch::ones({1, 1, _n});
 #endif
+    _latticeParams(0,0) = 1.0;
+    _latticeParams(0,1) = 1.0;
 }
 torchANNBasedDG3DIPVariable::torchANNBasedDG3DIPVariable(const torchANNBasedDG3DIPVariable &source):dG3DIPVariable(source),
-  _internalVars(source._internalVars), _kinematicVariables(source._kinematicVariables){};
+  _internalVars(source._internalVars), _kinematicVariables(source._kinematicVariables), _latticeParams(source._latticeParams){};
 torchANNBasedDG3DIPVariable& torchANNBasedDG3DIPVariable::operator =(const IPVariable& src)
 {
   dG3DIPVariable::operator=(src);
@@ -1828,6 +1831,7 @@ torchANNBasedDG3DIPVariable& torchANNBasedDG3DIPVariable::operator =(const IPVar
   {
     _internalVars = psrc->_internalVars;
     _kinematicVariables = psrc->_kinematicVariables;
+    _latticeParams = psrc->_latticeParams;
   }
 #else
   Msg::Error("NOT COMPILED WITH TORCH");
diff --git a/dG3D/src/dG3DIPVariable.h b/dG3D/src/dG3DIPVariable.h
index 016efa98814b6fc3917c75e598fb7a560cd2ddc0..6e60c0edf78658caf99045b2ad2e2b02899d004c 100644
--- a/dG3D/src/dG3DIPVariable.h
+++ b/dG3D/src/dG3DIPVariable.h
@@ -2242,6 +2242,7 @@ class torchANNBasedDG3DIPVariable :public dG3DIPVariable
 #endif
     fullMatrix<double> _kinematicVariables; // kinematic variable
     fullMatrix<double> restart_internalVars; // internal variable
+    fullMatrix<double> _latticeParams;
 
   public:
     torchANNBasedDG3DIPVariable(const int n, const bool createBodyForceHO,  const bool oninter);
diff --git a/dG3D/src/dG3DMaterialLaw.cpp b/dG3D/src/dG3DMaterialLaw.cpp
index d271db0cdc8300a0bc05106c0765dd62321c5910..e45f9aa9176f5e739ff56b9cb74444bce7588acb 100644
--- a/dG3D/src/dG3DMaterialLaw.cpp
+++ b/dG3D/src/dG3DMaterialLaw.cpp
@@ -1930,7 +1930,23 @@ torchANNBasedDG3DMaterialLaw::torchANNBasedDG3DMaterialLaw(const int num, const
         VYY[0][0][2] = 1.0;
         VZZ[0][0][3] = 1.0;
      }
-     else if(_numberOfInput == 6){ // for 2D case
+     else if(_numberOfInput == 6){ // for 3D case
+        VXX = torch::zeros({1, 1, 6});
+        VXY = torch::zeros({1, 1, 6});
+        VYY = torch::zeros({1, 1, 6});
+        VYZ = torch::zeros({1, 1, 6});
+        VZZ = torch::zeros({1, 1, 6});
+        VZX = torch::zeros({1, 1, 6});
+
+        VXX[0][0][0] = 1.0;
+        VYY[0][0][1] = 1.0;
+        VZZ[0][0][2] = 1.0;
+        VXY[0][0][3] = 1.0;
+        VZX[0][0][4] = 1.0;
+        VYZ[0][0][5] = 1.0;
+     }
+
+     else if(_numberOfInput == 9){ // for 3D case with Lattice Params @Mohib
         VXX = torch::zeros({1, 1, 6});
         VXY = torch::zeros({1, 1, 6});
         VYY = torch::zeros({1, 1, 6});
@@ -1993,7 +2009,15 @@ torchANNBasedDG3DMaterialLaw::torchANNBasedDG3DMaterialLaw(const torchANNBasedDG
         VYY = src.VYY;
         VZZ = src.VZZ;
        }
-       else if(_numberOfInput == 6){ // for 2D case
+       else if(_numberOfInput == 6){ // for 3D case
+        VXX = src.VXX;
+        VXY = src.VXY;
+        VYY = src.VYY;
+        VYZ = src.VYZ;
+        VZZ = src.VZZ;
+        VZX = src.VZX;
+       }
+       else if(_numberOfInput == 9){ // for 3D case with Lattice Params @Mohib
         VXX = src.VXX;
         VXY = src.VXY;
         VYY = src.VYY;
@@ -2032,34 +2056,71 @@ void torchANNBasedDG3DMaterialLaw::setKinematicInput(const int i)
   }
 }
 
+
+void torchANNBasedDG3DMaterialLaw::setLatticeParameters(const double radius, const double size){
+    _latticeRadius = radius;
+    _latticeSize = size;
+}
+
 void torchANNBasedDG3DMaterialLaw::initLaws(const std::map<int,materialLaw*> &maplaw){
 #if defined(HAVE_TORCH)
     if (!_initialized){
-        fullMatrix<double> E1(1,6), S(1,6);
+        fullMatrix<double> E1(1,6), S(1,6), Geo1(1, 2);
         torch::Tensor h_init = -1.0*torch::ones({1, 1, _numberOfInternalVariables});
         static torch::Tensor h_tmp = torch::zeros({1, 1, _numberOfInternalVariables});
         fullMatrix<double> DSDE(6,6);
 
-       if (_tangentByPerturbation){
-           static fullMatrix<double> E1_plus(1,6), S_plus(1,6);
-          for (int i=0; i< 6; i++)
+        //Read from torch Material Definition @Mohib
+        Geo1(0,0) = _latticeRadius;
+        Geo1(0,1) = _latticeSize;
+
+        if (_tangentByPerturbation){
+           //TODO: Refractor this for Lattice @Mohib
+          static fullMatrix<double> E1_plus(1,6), S_plus(1,6);
+          if(_numberOfInput==6 || _numberOfInput==3)
           {
-             E1_plus = E1;
-             E1_plus(0,i) += _pertTol;
-             RNNstress_stiff(E1_plus, h_init, h_tmp, S_plus, false, DSDE);
-             for (int j=0; j<6; j++)
-             {
-               DSDE(j,i) = (S_plus(0,j) - S(0,j))/_pertTol;
-             }
+              for (int i=0; i< 6; i++)
+              {
+                 E1_plus = E1;
+                 E1_plus(0,i) += _pertTol;
+                 RNNstress_stiff(E1_plus, h_init, h_tmp, S_plus, false, DSDE);
+                 for (int j=0; j<6; j++)
+                 {
+                   DSDE(j,i) = (S_plus(0,j) - S(0,j))/_pertTol;
+                 }
+              }
           }
-      }
-      else
-      {
-          RNNstress_stiff(E1, h_init, h_tmp, S, true, DSDE);
-      }
+          if(_numberOfInput==9)
+          {
+              for (int i=0; i< 6; i++)
+              {
+                 E1_plus = E1;
+                 E1_plus(0,i) += _pertTol;
+                 RNNstressGeo_stiff(E1_plus, Geo1, h_init, h_tmp, S_plus, false, DSDE);
+                 for (int j=0; j<6; j++)
+                 {
+                   DSDE(j,i) = (S_plus(0,j) - S(0,j))/_pertTol;
+                 }
+              }
 
-      convertFullMatrixToSTensor43(DSDE,elasticStiffness);
-      _initialized = true;
+          }
+
+        }
+        else
+        {
+            if(_numberOfInput==6 || _numberOfInput==3)
+            {
+                RNNstress_stiff(E1, h_init, h_tmp, S, true, DSDE);
+            }
+            if(_numberOfInput==9)
+            {
+                RNNstressGeo_stiff(E1, Geo1, h_init, h_tmp, S, true, DSDE);
+            }
+
+        }
+
+        convertFullMatrixToSTensor43(DSDE,elasticStiffness);
+        _initialized = true;
     };
 #else
   Msg::Error("NOT COMPILED WITH TORCH");
@@ -2137,30 +2198,67 @@ void torchANNBasedDG3DMaterialLaw::stress(IPVariable* ipv, const IPVariable* ipv
     Msg::Error("kinematic type %d has not been defined",_kinematicInput);
   }
 
+  // Construct Lattice geometry input for torch model @Mohib
+  fullMatrix<double> Geo1(1, 2);
+
+  Geo1(0, 0) = _latticeRadius;
+  Geo1(0, 1) = _latticeSize;
+
   static fullMatrix<double> S(1,6), DSDE(6,6);
   const torch::Tensor& h0 = ipvprev->getConstRefToInternalVariables();
   torch::Tensor& h1 = ipvcur->getRefToInternalVariables();
 
   if (stiff && _tangentByPerturbation)
   {
-    RNNstress_stiff(E1, h0, h1, S, false, DSDE);
-    static fullMatrix<double> E1_plus(1,6), S_plus(1,6);
-    static torch::Tensor h_tmp = torch::zeros({1, 1, _numberOfInternalVariables});
 
-    for (int i=0; i< 6; i++)
-    {
-      E1_plus = E1;
-      E1_plus(0,i) += _pertTol;
-      RNNstress_stiff(E1_plus, h0, h_tmp, S_plus, false, DSDE);
-      for (int j=0; j<6; j++)
+    if(_numberOfInput==6 || _numberOfInput==3)
       {
-        DSDE(j,i) = (S_plus(0,j) - S(0,j))/_pertTol;
+        RNNstress_stiff(E1, h0, h1, S, false, DSDE);
+        static fullMatrix<double> E1_plus(1,6), S_plus(1,6);
+        static torch::Tensor h_tmp = torch::zeros({1, 1, _numberOfInternalVariables});
+
+        for (int i=0; i< 6; i++)
+        {
+          E1_plus = E1;
+          E1_plus(0,i) += _pertTol;
+          RNNstress_stiff(E1_plus, h0, h_tmp, S_plus, false, DSDE);
+          for (int j=0; j<6; j++)
+          {
+            DSDE(j,i) = (S_plus(0,j) - S(0,j))/_pertTol;
+          }
+        }
+      }
+      else if(_numberOfInput==9)
+      {
+        RNNstressGeo_stiff(E1, Geo1, h0, h1, S, false, DSDE);
+        S.print("original");
+        static fullMatrix<double> E1_plus(1,6), S_plus(1,6);
+        static torch::Tensor h_tmp = torch::zeros({1, 1, _numberOfInternalVariables});
+
+        for (int i=0; i< 6; i++)
+        {
+          E1_plus = E1;
+          E1_plus(0,i) += _pertTol;
+          RNNstressGeo_stiff(E1_plus, Geo1, h0, h_tmp, S_plus, false, DSDE);
+          S_plus.print("perturbated");
+          for (int j=0; j<6; j++)
+          {
+            DSDE(j,i) = (S_plus(0,j) - S(0,j))/_pertTol;
+          }
+        }
+
       }
-    }
   }
   else
   {
-       RNNstress_stiff(E1, h0, h1, S, true, DSDE);
+      if(_numberOfInput==3 || _numberOfInput==6){
+          RNNstress_stiff(E1, h0, h1, S, true, DSDE);
+      }
+      else
+      {
+        RNNstressGeo_stiff(E1, Geo1, h0, h1, S, true, DSDE);
+      }
+
   }
 
   if (_kinematicInput == smallStrain)
@@ -2373,6 +2471,181 @@ void torchANNBasedDG3DMaterialLaw::RNNstress_stiff(const fullMatrix<double>& E1,
 }
 #endif
 
+#if defined(HAVE_TORCH)
+void torchANNBasedDG3DMaterialLaw::RNNstressGeo_stiff(const fullMatrix<double>& E1, const fullMatrix<double>& Geo1,const torch::Tensor& h0, torch::Tensor& h1,
+                                      fullMatrix<double>& S, const bool stiff, fullMatrix<double>& DSDE)
+{
+
+   static vector<float> E_vec(_numberOfInput - 2);
+   static vector<float> Geo_vec(2);
+   static vector<float> EG_vec;
+   static torch::Tensor E_norm;
+   static torch::Tensor Geo_norm;
+   static torch::Tensor EG_norm;
+   // use RNN to predict stress-------------
+   vector<torch::jit::IValue> inputs;
+
+   Normalize_strain(E1, E_vec);
+   E_norm = torch::from_blob(E_vec.data(), {1,1, _numberOfInput - 2}, torch::requires_grad());
+
+   Normalize_geo(Geo1, Geo_vec);
+   Geo_norm = torch::from_blob(Geo_vec.data(), {1,1,2}, torch::requires_grad());
+
+   EG_vec.insert(EG_vec.end(), Geo_vec.begin(), Geo_vec.end());
+   EG_vec.insert(EG_vec.end(), E_vec.begin(), E_vec.end());
+   EG_norm = torch::from_blob(EG_vec.data(), {1,1, _numberOfInput}, torch::requires_grad());
+
+//   inputs.push_back(Geo_norm);
+//   inputs.push_back(E_norm);
+   inputs.push_back(EG_norm);
+   inputs.push_back(h0);
+
+   auto outputs= module.forward(inputs).toTuple();
+   torch::Tensor S_norm = outputs->elements()[0].toTensor();
+   h1 = outputs->elements()[1].toTensor();
+
+   InverseNormalize_stress(S_norm, S);
+
+   if(stiff){
+     for (int i=0; i<6; i++){
+        for (int j=0; j<6; j++){
+           DSDE(i,j) = 0.0;
+        }
+     }
+     if(_numberOfInput == 3){
+         S_norm.backward(VXX,true);                   //  dS/dE
+         auto EnormGrad_a = E_norm.grad().accessor<float,3>();
+         DSDE(0,0) = EnormGrad_a[0][0][0];
+         DSDE(0,1) = EnormGrad_a[0][0][2];
+         DSDE(0,3) = EnormGrad_a[0][0][1];
+         E_norm.grad().zero_();
+         S_norm.backward(VYY,true);              //  dS/dE
+         DSDE(1,0) = EnormGrad_a[0][0][0];
+         DSDE(1,1) = EnormGrad_a[0][0][2];
+         DSDE(1,3) = EnormGrad_a[0][0][1];
+         E_norm.grad().zero_();
+         S_norm.backward(VZZ,true);                   //  dS/dE
+         DSDE(2,0) = EnormGrad_a[0][0][0];
+         DSDE(2,1) = EnormGrad_a[0][0][2];
+         DSDE(2,3) = EnormGrad_a[0][0][1];
+         E_norm.grad().zero_();
+         S_norm.backward(VXY,true);                    //  dS/dE
+         DSDE(3,0) = EnormGrad_a[0][0][0];
+         DSDE(3,1) = EnormGrad_a[0][0][2];
+         DSDE(3,3) = EnormGrad_a[0][0][1];
+         E_norm.grad().zero_();
+
+       }
+       else if(_numberOfInput >= 6){
+         S_norm.backward(VXX,true);                   //  dS/dE
+         auto EnormGrad_a = EG_norm.grad().accessor<float,3>();
+//         DSDE(0,0) = EnormGrad_a[0][0][12];
+//         DSDE(0,1) = EnormGrad_a[0][0][14];
+//         DSDE(0,2) = EnormGrad_a[0][0][20];
+//         DSDE(0,3) = EnormGrad_a[0][0][24];
+//         DSDE(0,4) = EnormGrad_a[0][0][30];
+//         DSDE(0,5) = EnormGrad_a[0][0][32];
+         for (int i=0; i<6; i++){
+             DSDE(0,i) = EnormGrad_a[0][0][i];
+//             printf("%3.5f \n", EnormGrad_a[0][0][0]);
+//             printf("%3.5f \n", EnormGrad_a[0][0][1]);
+//             printf("%3.5f \n", EnormGrad_a[0][0][2]);
+//             printf("%3.5f \n", EnormGrad_a[0][0][3]);
+//             printf("%3.5f \n", EnormGrad_a[0][0][4]);
+//             printf("%3.5f \n", EnormGrad_a[0][0][5]);
+//             printf("%3.5f \n", EnormGrad_a[0][0][6]);
+//             printf("%3.5f \n", EnormGrad_a[0][0][7]);
+//             printf("%3.5f \n", EnormGrad_a[0][0][8]);
+//             printf("%3.5f \n", EnormGrad_a[0][0][9]);
+//             printf("%3.5f \n", EnormGrad_a[0][0][10]);
+//             printf("%3.5f \n", EnormGrad_a[0][0][11]);
+//             printf("%3.5f \n", EnormGrad_a[0][0][12]);
+//             printf("%3.5f \n", EnormGrad_a[0][0][13]);
+//             printf("%3.5f \n", EnormGrad_a[0][0][14]);
+//             printf("%3.5f \n", EnormGrad_a[0][0][15]);
+//             printf("%3.5f \n", EnormGrad_a[0][0][16]);
+//             printf("%3.5f \n", EnormGrad_a[0][0][17]);
+//             printf("%3.5f \n", EnormGrad_a[0][0][18]);
+//             printf("%3.5f \n", EnormGrad_a[0][0][19]);
+//             printf("%3.5f \n", EnormGrad_a[0][0][20]);
+//             printf("%3.5f \n", EnormGrad_a[0][0][21]);
+//             printf("%3.5f \n", EnormGrad_a[0][0][22]);
+//             printf("%3.5f \n", EnormGrad_a[0][0][23]);
+//             printf("%3.5f \n", EnormGrad_a[0][0][24]);
+//             printf("%3.5f \n", EnormGrad_a[0][0][25]);
+//             printf("%3.5f \n", EnormGrad_a[0][0][26]);
+//             printf("%3.5f \n", EnormGrad_a[0][0][27]);
+//             printf("%3.5f \n", EnormGrad_a[0][0][28]);
+//             printf("%3.5f \n", EnormGrad_a[0][0][29]);
+//             printf("%3.5f \n", EnormGrad_a[0][0][30]);
+//             printf("%3.5f \n", EnormGrad_a[0][0][31]);
+//             printf("%3.5f \n", EnormGrad_a[0][0][32]);
+//             printf("%3.5f \n", EnormGrad_a[0][0][33]);
+//             printf("%3.5f \n", EnormGrad_a[0][0][34]);
+//             printf("%3.5f \n", EnormGrad_a[0][0][35]);
+//             printf("%3.5f \n", EnormGrad_a[0][0][36]);
+//             printf("%3.5f \n", EnormGrad_a[0][0][37]);
+//             printf("%3.5f \n", EnormGrad_a[0][0][38]);
+//             printf("%3.5f \n", EnormGrad_a[0][0][39]);
+//
+//             printf("\n---------------\n");
+         }
+         EG_norm.grad().zero_();
+         S_norm.backward(VYY,true);                   //  dS/dE
+         for (int i=0; i<6; i++){
+             DSDE(1,i) = EnormGrad_a[0][0][i+12];
+         }
+         EG_norm.grad().zero_();
+         S_norm.backward(VZZ,true);                   //  dS/dE
+         for (int i=0; i<6; i++){
+             DSDE(2,i) = EnormGrad_a[0][0][i+12];
+         }
+         EG_norm.grad().zero_();
+         S_norm.backward(VXY,true);                   //  dS/dE
+         for (int i=0; i<6; i++){
+             DSDE(3,i) = EnormGrad_a[0][0][i+12];
+         }
+
+         EG_norm.grad().zero_();
+         S_norm.backward(VZX,true);                   //  dS/dE
+         for (int i=0; i<6; i++){
+             DSDE(4,i) = EnormGrad_a[0][0][i+12];
+         }
+         EG_norm.grad().zero_();
+         S_norm.backward(VYZ,true);                   //  dS/dE
+         for (int i=0; i<6; i++){
+             DSDE(5,i) = EnormGrad_a[0][0][i+1];
+         }
+         EG_norm.grad().zero_();
+      }
+
+
+      for (int i=0; i<6; i++){
+         DSDE(i,0) = DSDE(i,0)/_EXXstd;
+         DSDE(i,1) = DSDE(i,1)/_EYYstd;
+         DSDE(i,2) = DSDE(i,2)/_EZZstd;
+         DSDE(i,3) = DSDE(i,3)/_EXYstd;
+         DSDE(i,4) = DSDE(i,4)/_EZXstd;
+         DSDE(i,5) = DSDE(i,5)/_EYZstd;
+      }
+      for (int j=0; j<6; j++){
+         DSDE(0,j) = DSDE(0,j)*_SXXstd;
+         DSDE(1,j) = DSDE(1,j)*_SYYstd;
+         DSDE(2,j) = DSDE(2,j)*_SZZstd;
+         DSDE(3,j) = DSDE(3,j)*_SXYstd;
+         DSDE(4,j) = DSDE(4,j)*_SZXstd;
+         DSDE(5,j) = DSDE(5,j)*_SYZstd;
+      }
+
+
+   /*    for( i=0;i<6; i++){
+        cout << DSDE(i,0)<< " " << DSDE(i,1)<<" " << DSDE(i,2)<<" " << DSDE(i,3)<<" " << DSDE(i,4)<<" " << DSDE(i,5)<< endl;
+       }
+         cout << "--------------------------------- "<< endl;*/
+   }
+}
+#endif
+
 void torchANNBasedDG3DMaterialLaw::convertFullMatrixToSTensor43(const fullMatrix<double>& DSDE, STensor43& DsecondPKDE) const
 {
   if (DSDE.size1() != 6 or DSDE.size2() != 6)
@@ -2478,7 +2751,7 @@ void torchANNBasedDG3DMaterialLaw::Normalize_strain(const fullMatrix<double>& E1
         E_norm[1] =  (E1(0,3)- _EXYmean)/_EXYstd;
         E_norm[2] =  (E1(0,1)- _EYYmean)/_EYYstd;
     }
-    else if(_numberOfInput == 6){
+    else if(_numberOfInput >= 6){
         E_norm[0] =  (E1(0,0)- _EXXmean)/_EXXstd;
         E_norm[1] =  (E1(0,1)- _EYYmean)/_EYYstd;
         E_norm[2] =  (E1(0,2)- _EZZmean)/_EZZstd;
@@ -2487,6 +2760,23 @@ void torchANNBasedDG3DMaterialLaw::Normalize_strain(const fullMatrix<double>& E1
         E_norm[5] =  (E1(0,5)- _EYZmean)/_EYZstd;
     }
 };
+
+void torchANNBasedDG3DMaterialLaw::Normalize_geo(const fullMatrix<double>& Geo1, vector<float>& Geo_norm) const
+{
+
+    Geo_norm[0] = (Geo1(0, 0) - _Radiusmean) / _Radiusstd;
+    // TODO: Remove Hard Code @ Mohib
+    if(_CellSizestd == 0.0)
+    {
+        Geo_norm[1] = Geo1(0, 1) - _CellSizemean;
+    }
+    else
+    {
+        Geo_norm[1] = Geo1(0, 1) - _CellSizemean;
+    }
+
+};
+
 #if defined(HAVE_TORCH)
 void torchANNBasedDG3DMaterialLaw::InverseNormalize_stress(const torch::Tensor& S_norm, fullMatrix<double>& S) const
 {
@@ -2499,7 +2789,7 @@ void torchANNBasedDG3DMaterialLaw::InverseNormalize_stress(const torch::Tensor&
         S(0,4) =  0.0*_SZXstd + _SZXmean;
         S(0,5) =  0.0*_SYZstd + _SYZmean;
     }
-    else if(_numberOfInput == 6){
+    else if(_numberOfInput >= 6){
         S(0,0) =  Snorm_a[0][0][0]*_SXXstd + _SXXmean;
         S(0,1) =  Snorm_a[0][0][1]*_SYYstd + _SYYmean;
         S(0,2) =  Snorm_a[0][0][2]*_SZZstd + _SZZmean;
diff --git a/dG3D/src/dG3DMaterialLaw.h b/dG3D/src/dG3DMaterialLaw.h
index b327c39a32956434bfb50b77a2f95afae4965df4..c07e01cf918537fe271ed8a54e3ecd3a34980c1b 100644
--- a/dG3D/src/dG3DMaterialLaw.h
+++ b/dG3D/src/dG3DMaterialLaw.h
@@ -473,6 +473,10 @@ class torchANNBasedDG3DMaterialLaw : public dG3DMaterialLaw{
     #ifndef SWIG
     int _numberOfInput;
     int _numberOfInternalVariables;
+
+    // Lattice Parameters for Optimizer IO @Mohib
+    double _latticeRadius;
+    double _latticeSize;
 #if defined(HAVE_TORCH)
     torch::jit::script::Module module;
 #endif
@@ -546,6 +550,9 @@ class torchANNBasedDG3DMaterialLaw : public dG3DMaterialLaw{
                                      bool pert=false, double tol = 1e-5);
 
 		void setKinematicInput(const int i);
+
+        // Setter initializer for Lattice Parameters @Mohib
+        void setLatticeParameters(const double radius, const double size);
     #ifndef SWIG
 		torchANNBasedDG3DMaterialLaw(const torchANNBasedDG3DMaterialLaw& src);
 		virtual ~torchANNBasedDG3DMaterialLaw();
@@ -575,9 +582,11 @@ class torchANNBasedDG3DMaterialLaw : public dG3DMaterialLaw{
   private:
     void convertFullMatrixToSTensor43(const fullMatrix<double>& DSDE, STensor43& K) const;
     void Normalize_strain(const fullMatrix<double>& E1, vector<float>& E_norm) const;
+    void Normalize_geo(const fullMatrix<double>& Geo1, vector<float>& Geo_norm) const;
 #if defined(HAVE_TORCH)
     void InverseNormalize_stress(const torch::Tensor& S_norm, fullMatrix<double>& S) const;
     void RNNstress_stiff(const fullMatrix<double>& E1, const torch::Tensor& h0, torch::Tensor& h1, fullMatrix<double>& S, const bool stiff, fullMatrix<double>& DSDE);
+    void RNNstressGeo_stiff(const fullMatrix<double>& E1, const fullMatrix<double>& Geo1, const torch::Tensor& h0, torch::Tensor& h1, fullMatrix<double>& S, const bool stiff, fullMatrix<double>& DSDE);
 #endif
     #endif //SWIG