diff --git a/common/optimization/LocalMinimumSearchVisitor.h b/common/optimization/LocalMinimumSearchVisitor.h new file mode 100644 index 0000000000000000000000000000000000000000..428e8bac0c7beda535cf352b253cf27bb1f4f5d5 --- /dev/null +++ b/common/optimization/LocalMinimumSearchVisitor.h @@ -0,0 +1,67 @@ +#ifndef H_COMMON_OPTIMIZATION_LOCALMINIMUMSEARCH_VISITOR +#define H_COMMON_OPTIMIZATION_LOCALMINIMUMSEARCH_VISITOR + +#include "localminimumsearch.h" +#include "../statefunctional.h" +#include "../../specific/data/objective/l2distance.h" +#include "../../specific/wave/equation/simplewave.h" + +class LocalMinimumSearchVisitor { + + + public: + bool done = false; + + void iteration(const LocalMinimumSearchInterface& lms, FunctionalInterface* const functional, const DescentSearchInterface* const descentsearch, const LineSearchInterface* const linesearch) { + + auto acousticStateFunctional = dynamic_cast<StateFunctional<Physic::acoustic>*> (functional); + if (!acousticStateFunctional) + return; + + auto objective = dynamic_cast<l2distance::Objective<Physic::acoustic>* const> (acousticStateFunctional->_objective); + if (!objective) + return; + + + std::array<bool,5> DataToBeUpdated = {true,false,false,false,false}; + std::array<bool,4> ModelToBeUpdated = {true,false,false,false}; + auto alphas = objective->improvedShotIntensity(acousticStateFunctional->_du.get(DataToBeUpdated,acousticStateFunctional->_mu.get(ModelToBeUpdated)).state(Type::FS)); + + gmshfem::msg::print << "Alphas : " << gmshfem::msg::endl; + gmshfem::msg::indent(); + for (const auto& freq : alphas) { + for (auto complex: freq) { + gmshfem::msg::print << real(complex) << " + " << imag(complex) << "i." << gmshfem::msg::endl; + } + gmshfem::msg::print << gmshfem::msg::endl; + } + gmshfem::msg::unindent(); + + if (!done) + { + gmshfem::msg::print << "Setting ideal alpha" << gmshfem::msg::endl; + + done = true; + gmshfem::msg::print << "There are " << acousticStateFunctional->_equation.size() << " equations" << gmshfem::msg::endl; + for (EquationInterface<Physic::acoustic>* eq : acousticStateFunctional->_equation) + { + auto parametrized = dynamic_cast<ParametrizedEquation<Physic::acoustic>*>(eq); + gmshfem::msg::print << "TYPE OF EQ : " << typeid(parametrized).name() << gmshfem::msg::endl; + auto casted = dynamic_cast < DifferentialEquationInterface<Physic::acoustic>*>(parametrized->_equation); + gmshfem::msg::print << "TYPE OF EQ : " << typeid(casted).name() << gmshfem::msg::endl; + + if (casted) { + gmshfem::msg::print << "Setting ideal alpha once" << gmshfem::msg::endl; + casted->setShotIntensities(alphas); + + } + else + gmshfem::msg::print << "(not a differential equation)" << gmshfem::msg::endl; + + } + } + } + +}; + +#endif \ No newline at end of file diff --git a/common/optimization/localminimumsearch.h b/common/optimization/localminimumsearch.h index 5f2f78ff328f0ed05f02b1a8bb356f8fa98e2cef..1664abbb9a52879852d12efccea18e1f2f68b9ac 100644 --- a/common/optimization/localminimumsearch.h +++ b/common/optimization/localminimumsearch.h @@ -84,6 +84,8 @@ public: virtual void operator()(ModelField* const m, FunctionalInterface* const functional, const DescentSearchInterface* const descentsearch, const LineSearchInterface* const linesearch) const = 0; virtual const LocalMinimumSearchHistoryInterface* const history() const = 0; + + friend class LocalMinimumSearchVisitor; }; /* diff --git a/common/statefunctional.h b/common/statefunctional.h index 28031e7e9f921192e733d05e9c65f60242c06b1d..2f8b40ca5a4b1f092f9069d67932ace120f1e272 100644 --- a/common/statefunctional.h +++ b/common/statefunctional.h @@ -17,7 +17,7 @@ class StateFunctional: public FunctionalInterface private: InnerProductInterface* const _innerproduct; RegularizationInterface* const _regularization; - const std::vector<EquationInterface<T_Physic>*> _equation; + std::vector<EquationInterface<T_Physic>*> _equation; ObjectiveInterface<T_Physic>* const _objective; const unsigned int _nf; @@ -86,6 +86,9 @@ public: /* directional2 */ virtual double directional2(const ModelFunction &dm2); virtual double directional2(const ModelField &dm2); + + // Friend class for visitors + friend class LocalMinimumSearchVisitor; }; #endif // H_STATEFUNCTIONAL diff --git a/common/wave/equation/equation.h b/common/wave/equation/equation.h index ca89a7b722dc33a60a40a808a298ac0267459a73..fe2000092e6165601aa7943e45a7396739062c06 100644 --- a/common/wave/equation/equation.h +++ b/common/wave/equation/equation.h @@ -26,6 +26,11 @@ protected: unsigned int _integrationDegreeBnd; bool _boundary; + + + std::vector<std::vector<std::complex<double>>> _shotIntensities; + bool _useVariableShots = false; + public: EquationInterface(const ConfigurationInterface* const config, const gmshfem::common::GmshFem& gmshFem,std::string suffix=""); @@ -43,6 +48,15 @@ public: virtual bool modelIsObsolete() {return false;};//Returns true, if sensitivity[DIAG] depends on model virtual void modelPerturbationIsObsolete() {}; + + + // Variable shot intensity + virtual bool canUseVariableShots() const {return false;} + bool enabledVariableShots() const {return _useVariableShots;} + void enableVariableShots() {if (!canUseVariableShots()) throw gmshfem::common::Exception("Cannot use variable shots on this equation."); _useVariableShots = true;} + void disableVariableShots() {if (!canUseVariableShots()) throw gmshfem::common::Exception("Cannot use variable shots on this equation."); _useVariableShots = false;} + void setShotIntensities(const std::vector<std::vector<std::complex<double>>>& val) {enableVariableShots(); gmshfem::msg::print << "HI" << gmshfem::msg::endl; _shotIntensities = val;} + }; /* @@ -134,6 +148,8 @@ public: virtual Sensitivity update_sensitivity(Order order, Support support, const DataStateEvaluator<T_Physic>& ds, const ModelStateEvaluator& ms, const WaveStateEvaluator<T_Physic>& ws); virtual bool modelIsObsolete() override;//Returns true, if sensitivity[DIAG] depends on model virtual void modelPerturbationIsObsolete() override; + + friend class LocalMinimumSearchVisitor; }; #endif //H_COMMON_EQUATION diff --git a/specific/optimization/localminimumsearch/classic.cpp b/specific/optimization/localminimumsearch/classic.cpp index 3c345564968a2a42a365d96ab1ad4540d6cf5404..60ce5d88225f46af48611b22fe8170430395bdd6 100644 --- a/specific/optimization/localminimumsearch/classic.cpp +++ b/specific/optimization/localminimumsearch/classic.cpp @@ -3,6 +3,7 @@ //GmshFWI Library #include "classic.h" +#include "../../../common/optimization/LocalMinimumSearchVisitor.h" using namespace gmshfem; using namespace gmshfem::common; @@ -96,12 +97,16 @@ namespace classic double mean_rel_djn = 0.; unsigned int n = 0; bool success = false; + + LocalMinimumSearchVisitor visitor; while(true) { msg::print << "--- Iteration #"+ std::to_string(n)+" --- " << msg::endl; msg::indent(); jn = functional->performance(); msg::print << "performance = " << jn << msg::endl; + // VISTIOR + visitor.iteration(*this, functional, descentsearch, linesearch); jjn = std::real( functional->innerproduct()->product(functional->gradient(),functional->gradient()) ); msg::print << "gradient norm 2 = " << jjn << msg::endl; mmn = std::real(functional->innerproduct()->penalization(*m,*m)); diff --git a/specific/optimization/localminimumsearch/eisenstat.cpp b/specific/optimization/localminimumsearch/eisenstat.cpp index 496ec90ece9b10347436f51ae0a8aa266ad239af..595c85eafb765c4388c99eb99ae069743a3bd914 100644 --- a/specific/optimization/localminimumsearch/eisenstat.cpp +++ b/specific/optimization/localminimumsearch/eisenstat.cpp @@ -4,6 +4,7 @@ //GmshFWI Library #include "eisenstat.h" +#include "../../../common/optimization/LocalMinimumSearchVisitor.h" using namespace gmshfem; using namespace gmshfem::common; @@ -61,12 +62,17 @@ namespace eisenstat double mean_rel_djn = 0.; unsigned int n = 0; bool success = false; + + LocalMinimumSearchVisitor visitor; while(true) { msg::print << "--- Iteration #"+ std::to_string(n)+" --- " << msg::endl; msg::indent(); jn = functional->performance(); msg::print << "performance = " << jn << msg::endl; + // VISTIOR + visitor.iteration(*this, functional, ncg_descentsearch, linesearch); + jjn_1=jjn; jjn = std::real( functional->innerproduct()->product(functional->gradient(),functional->gradient()) ); msg::print << "gradient norm 2 = " << jjn << msg::endl; diff --git a/specific/wave/equation/simplewave.cpp b/specific/wave/equation/simplewave.cpp index b2fe3c972cdd03307245449e0c39aa6cfe79019c..c267257d4d1cd52bbd3b7c49a7b5dcc6350857b5 100644 --- a/specific/wave/equation/simplewave.cpp +++ b/specific/wave/equation/simplewave.cpp @@ -166,7 +166,11 @@ namespace simplewave case Type::FS: for (unsigned int e = 0; e < _config->ne(s); e++) { - _formulation.integral(-1.,tf(_v), _config->emitter(s,e), integrationType(0)); + std::complex<double> intensity = 1.; + if (enabledVariableShots()) { + intensity = _shotIntensities[_f_idx][s]; + } + _formulation.integral(-intensity,tf(_v), _config->emitter(s,e), integrationType(0)); } break; case Type::AS: @@ -203,6 +207,9 @@ namespace simplewave } void Equation::setGreenRHS(unsigned int p) { + if (enabledVariableShots()) { + throw Exception("TODO: GreenRHS and variable shots ?"); + } _formulation.integral(-1.,tf(_v), _config->point(p), integrationType(0)); } diff --git a/specific/wave/equation/simplewave.h b/specific/wave/equation/simplewave.h index 2cbe35d648b5d907c4ebe90f62096b9ffa384909..f523abada45156a15a8eabb1460159832d36dc82 100644 --- a/specific/wave/equation/simplewave.h +++ b/specific/wave/equation/simplewave.h @@ -30,6 +30,9 @@ namespace simplewave //Model related virtual Sensitivity update_sensitivity(Order order, Support support, const DataStateEvaluator<Physic::acoustic>& ds, const ModelStateEvaluator& ms, const WaveStateEvaluator<Physic::acoustic>& ws); + + // Variable shots + virtual bool canUseVariableShots() const override {return true;} private: virtual bool modelIsObsolete() override;//Returns true, if sensitivity[DIAG] depends on model Sensitivity preconditioner(const DataStateEvaluator<Physic::acoustic>& ds, const ModelStateEvaluator& ms);