From b7d956b527febeb3b3657834ea0374461519f57e Mon Sep 17 00:00:00 2001
From: Mohib <mohib.mustafa@gmail.com>
Date: Tue, 9 May 2023 17:05:33 +0200
Subject: [PATCH] [BugFix - Cleanup] - changes in torchbasedANNmaterial

- Removed custom class initaitor, instead replaced with a setter that
  keeps additional inputs

- Replaced additional inputs with a runtime dynamic container aka
  std::vector

- Contains 4 Class props of type double that store initail values, will
  be removed in a later commmit

- implemented child virtual class func, setTime that accesses time at GP

- renamed some containers for readability

- removed some debugging statements

- Removed a bug which caused dt instead of t to be input into .pt model

- Contains TODO statements (harmless) that will be removed in future commits.

- Test Pass okay !

Mohib
---
 dG3D/src/dG3DMaterialLaw.cpp | 183 ++++++++++++++++++++---------------
 dG3D/src/dG3DMaterialLaw.h   |  50 ++++++----
 2 files changed, 137 insertions(+), 96 deletions(-)

diff --git a/dG3D/src/dG3DMaterialLaw.cpp b/dG3D/src/dG3DMaterialLaw.cpp
index e45f9aa91..bda667625 100644
--- a/dG3D/src/dG3DMaterialLaw.cpp
+++ b/dG3D/src/dG3DMaterialLaw.cpp
@@ -1907,8 +1907,7 @@ torchANNBasedDG3DMaterialLaw::torchANNBasedDG3DMaterialLaw(const int num, const
                 _EXXmean(EXXmean), _EXXstd(EXXstd), _EXYmean(EXYmean), _EXYstd(EXYstd), _EYYmean(EYYmean), _EYYstd(EYYstd), _EYZmean(EYZmean),
                 _EYZstd(EYZstd), _EZZmean(EZZmean), _EZZstd(EZZstd), _EZXmean(EZXmean), _EZXstd(EZXstd), _SXXmean(SXXmean), _SXXstd(SXXstd), _SXYmean(SXYmean),
                 _SXYstd(SXYstd), _SYYmean(SYYmean), _SYYstd(SYYstd), _SYZmean(SYZmean), _SYZstd(SYZstd), _SZZmean(SZZmean), _SZZstd(SZZstd),
-                _SZXmean(SZXmean), _SZXstd(SZXstd), _tangentByPerturbation(pert), _pertTol(tol), _kinematicInput(EGL),
-                _Radiusmean(-1.0), _Radiusstd(-1.0), _CellSizemean(-1.0), _CellSizestd(-1.0) // Constructor initializes these last 4 Class parameters, specific to Lattices on negative values @Mohib
+                _SZXmean(SZXmean), _SZXstd(SZXstd), _tangentByPerturbation(pert), _pertTol(tol), _kinematicInput(EGL)
 {
 #if defined(HAVE_TORCH)
     try{
@@ -1967,39 +1966,39 @@ torchANNBasedDG3DMaterialLaw::torchANNBasedDG3DMaterialLaw(const int num, const
 
 }
 
-// Constructor definition for Lattice based torch (.pt) model @Mohib
-torchANNBasedDG3DMaterialLaw::torchANNBasedDG3DMaterialLaw(const int num, const double rho,
-                                                           const int numberOfInput, const int numInternalVars, const char* nameTorch,
-                                                           const double Radiusmean, const double Radiusstd,
-                                                           const double CellSizemean, const double CellSizestd,
-                                                           const double EXXmean, const double EXXstd, const double EYYmean, const double EYYstd, const double EZZmean, const double EZZstd,
-                                                           const double EXYmean, const double EXYstd, const double EYZmean, const double EYZstd, const double EXZmean, const double EXZstd,
-                                                           const double SXXmean, const double SXXstd, const double SYYmean, const double SYYstd, const double SZZmean, const double SZZstd,
-                                                           const double SXYmean, const double SXYstd, const double SYZmean, const double SYZstd, const double SXZmean, const double SXZstd,
-                                                           bool pert, double tol): torchANNBasedDG3DMaterialLaw(num, rho, numberOfInput, numInternalVars, nameTorch,
-                                                                                                                EXXmean,  EXXstd,  EXYmean, EXYstd,  EYYmean,  EYYstd,
-                                                                                                                EYZmean,  EYZstd,  EZZmean, EZZstd, EXZmean,EXZstd,
-                                                                                                                SXXmean,  SXXstd,  SXYmean, SXYstd, SYYmean,  SYYstd,
-                                                                                                                SYZmean,  SYZstd,  SZZmean, SZZstd, SXZmean,SXZstd,
-                                                                                                                pert, tol)
-{
-
-  _Radiusmean = Radiusmean;
-  _Radiusstd = Radiusstd;
-  _CellSizemean = CellSizemean;
-  _CellSizestd = CellSizestd;
-
-}
-
+// Constructor definition for Lattice based torch (.pt) model @Mohib TODO: Remove comments b4 commit retundent
+//torchANNBasedDG3DMaterialLaw::torchANNBasedDG3DMaterialLaw(const int num, const double rho,
+//                                                           const int numberOfInput, const int numInternalVars, const char* nameTorch,
+//                                                           const double Radiusmean, const double Radiusstd,
+//                                                           const double CellSizemean, const double CellSizestd,
+//                                                           const double EXXmean, const double EXXstd, const double EYYmean, const double EYYstd, const double EZZmean, const double EZZstd,
+//                                                           const double EXYmean, const double EXYstd, const double EYZmean, const double EYZstd, const double EXZmean, const double EXZstd,
+//                                                           const double SXXmean, const double SXXstd, const double SYYmean, const double SYYstd, const double SZZmean, const double SZZstd,
+//                                                           const double SXYmean, const double SXYstd, const double SYZmean, const double SYZstd, const double SXZmean, const double SXZstd,
+//                                                           bool pert, double tol): torchANNBasedDG3DMaterialLaw(num, rho, numberOfInput, numInternalVars, nameTorch,
+//                                                                                                                EXXmean,  EXXstd,  EXYmean, EXYstd,  EYYmean,  EYYstd,
+//                                                                                                                EYZmean,  EYZstd,  EZZmean, EZZstd, EXZmean,EXZstd,
+//                                                                                                                SXXmean,  SXXstd,  SXYmean, SXYstd, SYYmean,  SYYstd,
+//                                                                                                                SYZmean,  SYZstd,  SZZmean, SZZstd, SXZmean,SXZstd,
+//                                                                                                                pert, tol)
+//{
+//
+//  _Radiusmean = Radiusmean;
+//  _Radiusstd = Radiusstd;
+//  _CellSizemean = CellSizemean;
+//  _CellSizestd = CellSizestd;
+//
+//}
 
+// Modified to add _extra_inp, _latticeRadius & _latticeSize @Mohib. TODO: refractor latticeRadius and size to IPvariable
 torchANNBasedDG3DMaterialLaw::torchANNBasedDG3DMaterialLaw(const torchANNBasedDG3DMaterialLaw& src):
       dG3DMaterialLaw(src), _numberOfInput(src._numberOfInput), _numberOfInternalVariables(src._numberOfInternalVariables),
                 _EXXmean(src._EXXmean), _EXXstd(src._EXXstd), _EXYmean(src._EXYmean), _EXYstd(src._EXYstd), _EYYmean(src._EYYmean), _EYYstd(src._EYYstd),
                 _EYZmean(src._EYZmean), _EYZstd(src._EYZstd), _EZZmean(src._EZZmean), _EZZstd(src._EZZstd), _EZXmean(src._EZXmean), _EZXstd(src._EZXstd),
                 _SXXmean(src._SXXmean), _SXXstd(src._SXXstd), _SXYmean(src._SXYmean), _SXYstd(src._SXYstd), _SYYmean(src._SYYmean), _SYYstd(src._SYYstd),
                 _SYZmean(src._SYZmean), _SYZstd(src._SYZstd), _SZZmean(src._SZZmean), _SZZstd(src._SZZstd), _SZXmean(src._SZXmean), _SZXstd(src._SZXstd),
-                _tangentByPerturbation(src._tangentByPerturbation),_pertTol(src._pertTol), _kinematicInput(src._kinematicInput),
-                _Radiusmean(src._Radiusmean), _Radiusstd(src._Radiusstd), _CellSizemean(src._CellSizemean), _CellSizestd(src._CellSizestd) // Constructor initializes these last 4 Class parameters, specific to Lattice based torch model @Mohib
+                _tangentByPerturbation(src._tangentByPerturbation), _pertTol(src._pertTol), _kinematicInput(src._kinematicInput),
+                _extra_inp(src._extra_inp), _latticeRadius(src._latticeRadius), _latticeSize(src._latticeSize)
 {
 #if defined(HAVE_TORCH)
        module = src.module;
@@ -2062,17 +2061,37 @@ void torchANNBasedDG3DMaterialLaw::setLatticeParameters(const double radius, con
     _latticeSize = size;
 }
 
+void torchANNBasedDG3DMaterialLaw::setExtraInp(const double mean, const double std){
+    _extra_inp.push_back(mean);
+    _extra_inp.push_back(std);
+}
+
+void torchANNBasedDG3DMaterialLaw::setInitTime(const double curr_t){
+    _pTime = curr_t;
+    _cTime = curr_t;
+}
+
+void torchANNBasedDG3DMaterialLaw::setTime(const double ctime, const double dtime){
+
+    dG3DMaterialLaw::setTime(ctime,dtime);
+    _pTime = _cTime;
+    _cTime = ctime;
+}
+
 void torchANNBasedDG3DMaterialLaw::initLaws(const std::map<int,materialLaw*> &maplaw){
 #if defined(HAVE_TORCH)
     if (!_initialized){
-        fullMatrix<double> E1(1,6), S(1,6), Geo1(1, 2);
+        // Added additional container Extra1 for additonal inputs @Mohib TODO: Remove hardcode
+        fullMatrix<double> E1(1,6), S(1,6), Extra1(1, 3);
         torch::Tensor h_init = -1.0*torch::ones({1, 1, _numberOfInternalVariables});
         static torch::Tensor h_tmp = torch::zeros({1, 1, _numberOfInternalVariables});
         fullMatrix<double> DSDE(6,6);
 
-        //Read from torch Material Definition @Mohib
-        Geo1(0,0) = _latticeRadius;
-        Geo1(0,1) = _latticeSize;
+        //TODO: Read from torch Material Definition @Mohib
+        Extra1(0,0) = _latticeRadius;
+        Extra1(0,1) = _latticeSize;
+        //TODO: Remove hardcode @Mohib
+        Extra1(0,2) = _cTime;
 
         if (_tangentByPerturbation){
            //TODO: Refractor this for Lattice @Mohib
@@ -2090,13 +2109,13 @@ void torchANNBasedDG3DMaterialLaw::initLaws(const std::map<int,materialLaw*> &ma
                  }
               }
           }
-          if(_numberOfInput==9)
+          if(_numberOfInput>6)
           {
               for (int i=0; i< 6; i++)
               {
                  E1_plus = E1;
                  E1_plus(0,i) += _pertTol;
-                 RNNstressGeo_stiff(E1_plus, Geo1, h_init, h_tmp, S_plus, false, DSDE);
+                 RNNstressGeo_stiff(E1_plus, Extra1, h_init, h_tmp, S_plus, false, DSDE);
                  for (int j=0; j<6; j++)
                  {
                    DSDE(j,i) = (S_plus(0,j) - S(0,j))/_pertTol;
@@ -2114,7 +2133,7 @@ void torchANNBasedDG3DMaterialLaw::initLaws(const std::map<int,materialLaw*> &ma
             }
             if(_numberOfInput==9)
             {
-                RNNstressGeo_stiff(E1, Geo1, h_init, h_tmp, S, true, DSDE);
+                RNNstressGeo_stiff(E1, Extra1, h_init, h_tmp, S, true, DSDE);
             }
 
         }
@@ -2198,11 +2217,12 @@ void torchANNBasedDG3DMaterialLaw::stress(IPVariable* ipv, const IPVariable* ipv
     Msg::Error("kinematic type %d has not been defined",_kinematicInput);
   }
 
-  // Construct Lattice geometry input for torch model @Mohib
-  fullMatrix<double> Geo1(1, 2);
+  // Construct Lattice geometry input for torch model @Mohib, TODO: Remove hardcode! set size of exta by size of _extra
+  fullMatrix<double> Extra1(1, 3);
 
-  Geo1(0, 0) = _latticeRadius;
-  Geo1(0, 1) = _latticeSize;
+  Extra1(0, 0) = _latticeRadius;
+  Extra1(0, 1) = _latticeSize;
+  Extra1(0, 2) = _cTime;
 
   static fullMatrix<double> S(1,6), DSDE(6,6);
   const torch::Tensor& h0 = ipvprev->getConstRefToInternalVariables();
@@ -2230,8 +2250,8 @@ void torchANNBasedDG3DMaterialLaw::stress(IPVariable* ipv, const IPVariable* ipv
       }
       else if(_numberOfInput==9)
       {
-        RNNstressGeo_stiff(E1, Geo1, h0, h1, S, false, DSDE);
-        S.print("original");
+        RNNstressGeo_stiff(E1, Extra1, h0, h1, S, false, DSDE);
+        //S.print("original");
         static fullMatrix<double> E1_plus(1,6), S_plus(1,6);
         static torch::Tensor h_tmp = torch::zeros({1, 1, _numberOfInternalVariables});
 
@@ -2239,8 +2259,8 @@ void torchANNBasedDG3DMaterialLaw::stress(IPVariable* ipv, const IPVariable* ipv
         {
           E1_plus = E1;
           E1_plus(0,i) += _pertTol;
-          RNNstressGeo_stiff(E1_plus, Geo1, h0, h_tmp, S_plus, false, DSDE);
-          S_plus.print("perturbated");
+          RNNstressGeo_stiff(E1_plus, Extra1, h0, h_tmp, S_plus, false, DSDE);
+          //S_plus.print("perturbated");
           for (int j=0; j<6; j++)
           {
             DSDE(j,i) = (S_plus(0,j) - S(0,j))/_pertTol;
@@ -2256,7 +2276,7 @@ void torchANNBasedDG3DMaterialLaw::stress(IPVariable* ipv, const IPVariable* ipv
       }
       else
       {
-        RNNstressGeo_stiff(E1, Geo1, h0, h1, S, true, DSDE);
+        RNNstressGeo_stiff(E1, Extra1, h0, h1, S, true, DSDE);
       }
 
   }
@@ -2472,32 +2492,35 @@ void torchANNBasedDG3DMaterialLaw::RNNstress_stiff(const fullMatrix<double>& E1,
 #endif
 
 #if defined(HAVE_TORCH)
-void torchANNBasedDG3DMaterialLaw::RNNstressGeo_stiff(const fullMatrix<double>& E1, const fullMatrix<double>& Geo1,const torch::Tensor& h0, torch::Tensor& h1,
+void torchANNBasedDG3DMaterialLaw::RNNstressGeo_stiff(const fullMatrix<double>& E1, const fullMatrix<double>& Extra1,const torch::Tensor& h0, torch::Tensor& h1,
                                       fullMatrix<double>& S, const bool stiff, fullMatrix<double>& DSDE)
 {
-
-   static vector<float> E_vec(_numberOfInput - 2);
-   static vector<float> Geo_vec(2);
-   static vector<float> EG_vec;
+    //TODO: Remove Hardcode read from size of extra @Mohib
+   static vector<float> E_vec(_numberOfInput - 3);
+   static vector<float> Extra_vec(3);
+   // Dont pre initialize Combine_vec! It needs to be populated as it grows @Mohib
+   vector<float> Combine_vec;
    static torch::Tensor E_norm;
-   static torch::Tensor Geo_norm;
-   static torch::Tensor EG_norm;
+   static torch::Tensor Extra_norm;
+   torch::Tensor Combine_norm;
    // use RNN to predict stress-------------
    vector<torch::jit::IValue> inputs;
 
    Normalize_strain(E1, E_vec);
-   E_norm = torch::from_blob(E_vec.data(), {1,1, _numberOfInput - 2}, torch::requires_grad());
+   //E_norm = torch::from_blob(E_vec.data(), {1,1, _numberOfInput - 3}, torch::requires_grad());
 
-   Normalize_geo(Geo1, Geo_vec);
-   Geo_norm = torch::from_blob(Geo_vec.data(), {1,1,2}, torch::requires_grad());
+   Normalize_geo(Extra1, Extra_vec);
+   //Extra_norm = torch::from_blob(Geo_vec.data(), {1,1, 3}, torch::requires_grad());
 
-   EG_vec.insert(EG_vec.end(), Geo_vec.begin(), Geo_vec.end());
-   EG_vec.insert(EG_vec.end(), E_vec.begin(), E_vec.end());
-   EG_norm = torch::from_blob(EG_vec.data(), {1,1, _numberOfInput}, torch::requires_grad());
+   Combine_vec.insert(Combine_vec.end(), Extra_vec.begin(), Extra_vec.end());
+   Combine_vec.pop_back();
+   Combine_vec.insert(Combine_vec.end(), E_vec.begin(), E_vec.end());
+   Combine_vec.push_back(Extra_vec[2]);
+   Combine_norm = torch::from_blob(Combine_vec.data(), {1,1, _numberOfInput}, torch::requires_grad());
 
 //   inputs.push_back(Geo_norm);
 //   inputs.push_back(E_norm);
-   inputs.push_back(EG_norm);
+   inputs.push_back(Combine_norm);
    inputs.push_back(h0);
 
    auto outputs= module.forward(inputs).toTuple();
@@ -2538,7 +2561,7 @@ void torchANNBasedDG3DMaterialLaw::RNNstressGeo_stiff(const fullMatrix<double>&
        }
        else if(_numberOfInput >= 6){
          S_norm.backward(VXX,true);                   //  dS/dE
-         auto EnormGrad_a = EG_norm.grad().accessor<float,3>();
+         auto EnormGrad_a = Combine_norm.grad().accessor<float,3>();
 //         DSDE(0,0) = EnormGrad_a[0][0][12];
 //         DSDE(0,1) = EnormGrad_a[0][0][14];
 //         DSDE(0,2) = EnormGrad_a[0][0][20];
@@ -2590,33 +2613,33 @@ void torchANNBasedDG3DMaterialLaw::RNNstressGeo_stiff(const fullMatrix<double>&
 //
 //             printf("\n---------------\n");
          }
-         EG_norm.grad().zero_();
+         Combine_norm.grad().zero_();
          S_norm.backward(VYY,true);                   //  dS/dE
          for (int i=0; i<6; i++){
              DSDE(1,i) = EnormGrad_a[0][0][i+12];
          }
-         EG_norm.grad().zero_();
+         Combine_norm.grad().zero_();
          S_norm.backward(VZZ,true);                   //  dS/dE
          for (int i=0; i<6; i++){
              DSDE(2,i) = EnormGrad_a[0][0][i+12];
          }
-         EG_norm.grad().zero_();
+         Combine_norm.grad().zero_();
          S_norm.backward(VXY,true);                   //  dS/dE
          for (int i=0; i<6; i++){
              DSDE(3,i) = EnormGrad_a[0][0][i+12];
          }
 
-         EG_norm.grad().zero_();
+         Combine_norm.grad().zero_();
          S_norm.backward(VZX,true);                   //  dS/dE
          for (int i=0; i<6; i++){
              DSDE(4,i) = EnormGrad_a[0][0][i+12];
          }
-         EG_norm.grad().zero_();
+         Combine_norm.grad().zero_();
          S_norm.backward(VYZ,true);                   //  dS/dE
          for (int i=0; i<6; i++){
              DSDE(5,i) = EnormGrad_a[0][0][i+1];
          }
-         EG_norm.grad().zero_();
+         Combine_norm.grad().zero_();
       }
 
 
@@ -2761,19 +2784,23 @@ void torchANNBasedDG3DMaterialLaw::Normalize_strain(const fullMatrix<double>& E1
     }
 };
 
-void torchANNBasedDG3DMaterialLaw::Normalize_geo(const fullMatrix<double>& Geo1, vector<float>& Geo_norm) const
-{
-
-    Geo_norm[0] = (Geo1(0, 0) - _Radiusmean) / _Radiusstd;
-    // TODO: Remove Hard Code @ Mohib
-    if(_CellSizestd == 0.0)
-    {
-        Geo_norm[1] = Geo1(0, 1) - _CellSizemean;
-    }
-    else
-    {
-        Geo_norm[1] = Geo1(0, 1) - _CellSizemean;
-    }
+// Normalize extra input params @Mohib TODO: Refractor to Normalizeextra()
+void torchANNBasedDG3DMaterialLaw::Normalize_geo(const fullMatrix<double>& Extra1, vector<float>& Geo_norm) const
+{
+    //TODO: Remove Hard code and loop in _extra_inp @Mohib
+    //TODO: Put a check for 0 std !! @Mohib
+    Geo_norm[0] = (Extra1(0, 0) - _extra_inp[0]) / _extra_inp[1];
+    Geo_norm[1] = (Extra1(0, 1) - _extra_inp[2]) / _extra_inp[3];
+    Geo_norm[2] = (Extra1(0, 2) - _extra_inp[4]) / _extra_inp[5];
+    // TODO: Remove comment before commit @ Mohib
+//    if(_CellSizestd == 0.0)
+//    {
+//        Geo_norm[1] = Geo1(0, 1) - _CellSizemean;
+//    }
+//    else
+//    {
+//        Geo_norm[1] = Geo1(0, 1) - _CellSizemean;
+//    }
 
 };
 
diff --git a/dG3D/src/dG3DMaterialLaw.h b/dG3D/src/dG3DMaterialLaw.h
index c07e01cf9..dc89322a6 100644
--- a/dG3D/src/dG3DMaterialLaw.h
+++ b/dG3D/src/dG3DMaterialLaw.h
@@ -475,8 +475,12 @@ class torchANNBasedDG3DMaterialLaw : public dG3DMaterialLaw{
     int _numberOfInternalVariables;
 
     // Lattice Parameters for Optimizer IO @Mohib
+    // TODO: Read from ipvariable
     double _latticeRadius;
     double _latticeSize;
+    double _pTime;
+    double _cTime;
+
 #if defined(HAVE_TORCH)
     torch::jit::script::Module module;
 #endif
@@ -506,12 +510,14 @@ class torchANNBasedDG3DMaterialLaw : public dG3DMaterialLaw{
     double _SZXmean;
     double _SZXstd;
 
-    // Lattice Parameter inside of torch (.pt) model @ Mohib
-    double _Radiusmean;
-    double _Radiusstd;
-
-    double _CellSizemean;
-    double _CellSizestd;
+    // Container for additional inputs normalization @ Mohib
+    std::vector<double> _extra_inp;
+    // TODO: Remove comments b4 commit @Mohib
+//    double _Radiusmean;
+//    double _Radiusstd;
+//
+//    double _CellSizemean;
+//    double _CellSizestd;
 
     bool _tangentByPerturbation;
     double _pertTol;
@@ -538,21 +544,27 @@ class torchANNBasedDG3DMaterialLaw : public dG3DMaterialLaw{
                 const double SXYstd, const double SYYmean, const double SYYstd, const double SYZmean, const double SYZstd, const double SZZmean,
                 const double SZZstd, const double SZXmean, const double SZXstd, bool pert=false, double tol = 1e-5);
 
-        // Constructor declaration for Lattice based torch model @Mohib
-        torchANNBasedDG3DMaterialLaw(const int num, const double rho,
-                                     const int numberOfInput, const int numInternalVars, const char* nameTorch,
-                                     const double Radiusmean, const double Radiusstd,
-                                     const double CellSizemean, const double CellSizestd,
-                                     const double EXXmean, const double EXXstd, const double EYYmean, const double EYYstd, const double EZZmean, const double EZZstd,
-                                     const double EXYmean, const double EXYstd, const double EYZmean, const double EYZstd, const double EXZmean, const double EXZstd,
-                                     const double SXXmean, const double SXXstd, const double SYYmean, const double SYYstd, const double SZZmean, const double SZZstd,
-                                     const double SXYmean, const double SXYstd, const double SYZmean, const double SYZstd, const double SXZmean, const double SXZstd,
-                                     bool pert=false, double tol = 1e-5);
+        // Constructor declaration for Lattice based torch model @Mohib TODO: Remove b4 commit retundent
+//        torchANNBasedDG3DMaterialLaw(const int num, const double rho,
+//                                     const int numberOfInput, const int numInternalVars, const char* nameTorch,
+//                                     const double Radiusmean, const double Radiusstd,
+//                                     const double CellSizemean, const double CellSizestd,
+//                                     const double EXXmean, const double EXXstd, const double EYYmean, const double EYYstd, const double EZZmean, const double EZZstd,
+//                                     const double EXYmean, const double EXYstd, const double EYZmean, const double EYZstd, const double EXZmean, const double EXZstd,
+//                                     const double SXXmean, const double SXXstd, const double SYYmean, const double SYYstd, const double SZZmean, const double SZZstd,
+//                                     const double SXYmean, const double SXYstd, const double SYZmean, const double SYZstd, const double SXZmean, const double SXZstd,
+//                                     bool pert=false, double tol = 1e-5);
 
 		void setKinematicInput(const int i);
 
-        // Setter initializer for Lattice Parameters @Mohib
+        // Setter for Lattice Parameters @Mohib
         void setLatticeParameters(const double radius, const double size);
+
+        // Setter for extra inputs, populates the _extra_inp vector with normalization parameters
+        void setExtraInp(const double mean, const double std);
+        void setInitTime(const double curr_t);
+
+
     #ifndef SWIG
 		torchANNBasedDG3DMaterialLaw(const torchANNBasedDG3DMaterialLaw& src);
 		virtual ~torchANNBasedDG3DMaterialLaw();
@@ -562,7 +574,9 @@ class torchANNBasedDG3DMaterialLaw : public dG3DMaterialLaw{
 		virtual void createIPVariable(IPVariable* &ipv,  bool hasBodyForce, const MElement *ele, const int nbFF_, const IntPt *GP, const int gpt) const;
 		virtual void initLaws(const std::map<int,materialLaw*> &maplaw);
 		virtual void stress(IPVariable* ipv, const IPVariable* ipvprev, const bool stiff=true, const bool checkfrac=true, const bool dTangent=false);
-		virtual double scaleFactor() const {return 1.;}
+        // Set time at GP @Mohib
+        virtual void setTime(const double ctime, const double dtime);
+        virtual double scaleFactor() const {return 1.;}
 		virtual double soundSpeed() const;
 		virtual materialLaw* clone() const {return new torchANNBasedDG3DMaterialLaw(*this);};
 		virtual void checkInternalState(IPVariable* ipv, const IPVariable* ipvprev) const{}; // do nothing
-- 
GitLab