From 39692e640e3eac338898d920c7838ef890135f7d Mon Sep 17 00:00:00 2001
From: Boris Martin <boris.martin.be@gmail.com>
Date: Wed, 29 Mar 2023 08:24:28 +0200
Subject: [PATCH] First prototype

---
 .../optimization/LocalMinimumSearchVisitor.h  | 67 +++++++++++++++++++
 common/optimization/localminimumsearch.h      |  2 +
 common/statefunctional.h                      |  5 +-
 common/wave/equation/equation.h               | 16 +++++
 .../localminimumsearch/classic.cpp            |  5 ++
 .../localminimumsearch/eisenstat.cpp          |  6 ++
 specific/wave/equation/simplewave.cpp         |  9 ++-
 specific/wave/equation/simplewave.h           |  3 +
 8 files changed, 111 insertions(+), 2 deletions(-)
 create mode 100644 common/optimization/LocalMinimumSearchVisitor.h

diff --git a/common/optimization/LocalMinimumSearchVisitor.h b/common/optimization/LocalMinimumSearchVisitor.h
new file mode 100644
index 0000000..428e8ba
--- /dev/null
+++ b/common/optimization/LocalMinimumSearchVisitor.h
@@ -0,0 +1,67 @@
+#ifndef H_COMMON_OPTIMIZATION_LOCALMINIMUMSEARCH_VISITOR
+#define H_COMMON_OPTIMIZATION_LOCALMINIMUMSEARCH_VISITOR
+
+#include "localminimumsearch.h"
+#include "../statefunctional.h"
+#include "../../specific/data/objective/l2distance.h"
+#include "../../specific/wave/equation/simplewave.h"
+
+class LocalMinimumSearchVisitor {
+
+
+    public:
+    bool done = false;
+
+    void iteration(const LocalMinimumSearchInterface& lms,  FunctionalInterface* const functional, const DescentSearchInterface* const descentsearch, const LineSearchInterface* const linesearch) {
+        
+        auto acousticStateFunctional = dynamic_cast<StateFunctional<Physic::acoustic>*> (functional);
+        if (!acousticStateFunctional)
+            return;
+
+        auto objective = dynamic_cast<l2distance::Objective<Physic::acoustic>* const> (acousticStateFunctional->_objective);
+        if (!objective)
+            return;
+            
+
+        std::array<bool,5> DataToBeUpdated = {true,false,false,false,false};
+        std::array<bool,4> ModelToBeUpdated = {true,false,false,false};
+        auto alphas = objective->improvedShotIntensity(acousticStateFunctional->_du.get(DataToBeUpdated,acousticStateFunctional->_mu.get(ModelToBeUpdated)).state(Type::FS));
+
+        gmshfem::msg::print << "Alphas : " << gmshfem::msg::endl;
+        gmshfem::msg::indent();
+        for (const auto& freq : alphas) {
+            for (auto complex: freq) {
+                gmshfem::msg::print << real(complex) << " + " << imag(complex) << "i." << gmshfem::msg::endl;
+            }
+            gmshfem::msg::print << gmshfem::msg::endl;
+        }
+        gmshfem::msg::unindent();
+
+        if (!done)
+        {
+            gmshfem::msg::print << "Setting ideal alpha" << gmshfem::msg::endl;
+
+            done = true;
+            gmshfem::msg::print << "There are " << acousticStateFunctional->_equation.size() << " equations" << gmshfem::msg::endl;
+            for (EquationInterface<Physic::acoustic>* eq : acousticStateFunctional->_equation)
+            {
+                auto parametrized = dynamic_cast<ParametrizedEquation<Physic::acoustic>*>(eq);
+                gmshfem::msg::print << "TYPE OF EQ : " << typeid(parametrized).name() << gmshfem::msg::endl;
+                auto casted = dynamic_cast < DifferentialEquationInterface<Physic::acoustic>*>(parametrized->_equation);
+                gmshfem::msg::print << "TYPE OF EQ : " << typeid(casted).name() << gmshfem::msg::endl;
+                    
+                if (casted) {
+                    gmshfem::msg::print << "Setting ideal alpha once" << gmshfem::msg::endl;
+                                        casted->setShotIntensities(alphas);
+
+                }
+                else
+                    gmshfem::msg::print << "(not a differential equation)" << gmshfem::msg::endl;
+
+            }
+        }
+    }
+
+};
+
+#endif
\ No newline at end of file
diff --git a/common/optimization/localminimumsearch.h b/common/optimization/localminimumsearch.h
index 5f2f78f..1664abb 100644
--- a/common/optimization/localminimumsearch.h
+++ b/common/optimization/localminimumsearch.h
@@ -84,6 +84,8 @@ public:
     virtual void operator()(ModelField* const m, FunctionalInterface* const functional, const DescentSearchInterface* const descentsearch, const LineSearchInterface* const linesearch) const = 0;
 
     virtual const LocalMinimumSearchHistoryInterface* const history() const = 0;
+
+    friend class LocalMinimumSearchVisitor;
 };
 
 /*
diff --git a/common/statefunctional.h b/common/statefunctional.h
index 28031e7..2f8b40c 100644
--- a/common/statefunctional.h
+++ b/common/statefunctional.h
@@ -17,7 +17,7 @@ class StateFunctional: public FunctionalInterface
 private:
     InnerProductInterface* const _innerproduct;
     RegularizationInterface* const _regularization;
-    const std::vector<EquationInterface<T_Physic>*> _equation;
+    std::vector<EquationInterface<T_Physic>*> _equation;
     ObjectiveInterface<T_Physic>* const _objective;
     const unsigned int _nf;
 
@@ -86,6 +86,9 @@ public:
     /* directional2 */
     virtual double directional2(const ModelFunction &dm2);
     virtual double directional2(const ModelField &dm2);
+
+    // Friend class for visitors
+    friend class LocalMinimumSearchVisitor;
 };
 
 #endif // H_STATEFUNCTIONAL
diff --git a/common/wave/equation/equation.h b/common/wave/equation/equation.h
index ca89a7b..fe20000 100644
--- a/common/wave/equation/equation.h
+++ b/common/wave/equation/equation.h
@@ -26,6 +26,11 @@ protected:
     unsigned int _integrationDegreeBnd;
 
     bool _boundary;
+
+
+    std::vector<std::vector<std::complex<double>>> _shotIntensities;
+    bool _useVariableShots = false;
+
 public:
     EquationInterface(const ConfigurationInterface* const config, const gmshfem::common::GmshFem& gmshFem,std::string suffix="");
 
@@ -43,6 +48,15 @@ public:
 
     virtual bool modelIsObsolete() {return false;};//Returns true, if sensitivity[DIAG] depends on model
     virtual void modelPerturbationIsObsolete() {};
+
+
+    // Variable shot intensity
+    virtual bool canUseVariableShots() const {return false;}
+    bool enabledVariableShots() const {return _useVariableShots;}
+    void enableVariableShots() {if (!canUseVariableShots()) throw gmshfem::common::Exception("Cannot use variable shots on this equation."); _useVariableShots = true;}
+    void disableVariableShots() {if (!canUseVariableShots()) throw gmshfem::common::Exception("Cannot use variable shots on this equation."); _useVariableShots = false;}
+    void setShotIntensities(const std::vector<std::vector<std::complex<double>>>& val) {enableVariableShots(); gmshfem::msg::print << "HI" << gmshfem::msg::endl; _shotIntensities = val;}
+
 };
 
 /*
@@ -134,6 +148,8 @@ public:
     virtual Sensitivity update_sensitivity(Order order, Support support, const DataStateEvaluator<T_Physic>& ds, const ModelStateEvaluator& ms, const WaveStateEvaluator<T_Physic>& ws);
     virtual bool modelIsObsolete() override;//Returns true, if sensitivity[DIAG] depends on model
     virtual void modelPerturbationIsObsolete() override;
+
+    friend class LocalMinimumSearchVisitor;
 };
 
 #endif //H_COMMON_EQUATION
diff --git a/specific/optimization/localminimumsearch/classic.cpp b/specific/optimization/localminimumsearch/classic.cpp
index 3c34556..60ce5d8 100644
--- a/specific/optimization/localminimumsearch/classic.cpp
+++ b/specific/optimization/localminimumsearch/classic.cpp
@@ -3,6 +3,7 @@
 
 //GmshFWI Library
 #include "classic.h"
+#include "../../../common/optimization/LocalMinimumSearchVisitor.h"
 
 using namespace gmshfem;
 using namespace gmshfem::common;
@@ -96,12 +97,16 @@ namespace classic
         double mean_rel_djn = 0.;
         unsigned int n = 0;
         bool success = false;
+
+        LocalMinimumSearchVisitor visitor;
         while(true)
         {
             msg::print << "--- Iteration #"+ std::to_string(n)+" --- " << msg::endl;
             msg::indent();
             jn = functional->performance();
             msg::print << "performance = " <<  jn  << msg::endl;
+             // VISTIOR
+            visitor.iteration(*this, functional, descentsearch, linesearch);
             jjn = std::real( functional->innerproduct()->product(functional->gradient(),functional->gradient()) );
             msg::print << "gradient norm 2 = " <<  jjn  << msg::endl;
             mmn = std::real(functional->innerproduct()->penalization(*m,*m));
diff --git a/specific/optimization/localminimumsearch/eisenstat.cpp b/specific/optimization/localminimumsearch/eisenstat.cpp
index 496ec90..595c85e 100644
--- a/specific/optimization/localminimumsearch/eisenstat.cpp
+++ b/specific/optimization/localminimumsearch/eisenstat.cpp
@@ -4,6 +4,7 @@
 
 //GmshFWI Library
 #include "eisenstat.h"
+#include "../../../common/optimization/LocalMinimumSearchVisitor.h"
 
 using namespace gmshfem;
 using namespace gmshfem::common;
@@ -61,12 +62,17 @@ namespace eisenstat
         double mean_rel_djn = 0.;
         unsigned int n = 0;
         bool success = false;
+
+        LocalMinimumSearchVisitor visitor;
         while(true)
         {
             msg::print << "--- Iteration #"+ std::to_string(n)+" --- " << msg::endl;
             msg::indent();
             jn = functional->performance();
             msg::print << "performance = " <<  jn  << msg::endl;
+            // VISTIOR
+            visitor.iteration(*this, functional, ncg_descentsearch, linesearch);
+            
             jjn_1=jjn;
             jjn = std::real( functional->innerproduct()->product(functional->gradient(),functional->gradient()) );
             msg::print << "gradient norm 2 = " <<  jjn  << msg::endl;
diff --git a/specific/wave/equation/simplewave.cpp b/specific/wave/equation/simplewave.cpp
index b2fe3c9..c267257 100644
--- a/specific/wave/equation/simplewave.cpp
+++ b/specific/wave/equation/simplewave.cpp
@@ -166,7 +166,11 @@ namespace simplewave
             case Type::FS:
                 for (unsigned int e = 0; e < _config->ne(s); e++)
                 {
-                  _formulation.integral(-1.,tf(_v), _config->emitter(s,e), integrationType(0));
+                  std::complex<double> intensity = 1.;
+                  if (enabledVariableShots()) {
+                    intensity = _shotIntensities[_f_idx][s];
+                  }
+                  _formulation.integral(-intensity,tf(_v), _config->emitter(s,e), integrationType(0));
                 }
                 break;
             case Type::AS:
@@ -203,6 +207,9 @@ namespace simplewave
     }
     void Equation::setGreenRHS(unsigned int p)
     {
+        if (enabledVariableShots()) {
+            throw Exception("TODO: GreenRHS and variable shots ?");
+        }
         _formulation.integral(-1.,tf(_v), _config->point(p), integrationType(0));
     }
 
diff --git a/specific/wave/equation/simplewave.h b/specific/wave/equation/simplewave.h
index 2cbe35d..f523aba 100644
--- a/specific/wave/equation/simplewave.h
+++ b/specific/wave/equation/simplewave.h
@@ -30,6 +30,9 @@ namespace simplewave
 
         //Model related
         virtual Sensitivity update_sensitivity(Order order, Support support, const DataStateEvaluator<Physic::acoustic>& ds, const ModelStateEvaluator& ms, const WaveStateEvaluator<Physic::acoustic>& ws);
+    
+        // Variable shots
+        virtual bool canUseVariableShots() const override {return true;}
     private:
         virtual bool modelIsObsolete() override;//Returns true, if sensitivity[DIAG] depends on model
         Sensitivity preconditioner(const DataStateEvaluator<Physic::acoustic>& ds, const ModelStateEvaluator& ms);
-- 
GitLab