From d6b48db112ef6d3840d1b695b3c3652fe047437d Mon Sep 17 00:00:00 2001
From: Jonathan Lambrechts <jonathan.lambrechts@uclouvain.be>
Date: Sat, 19 Nov 2011 07:02:10 +0000
Subject: [PATCH] Solver/function : add numpy function, combined with blitz
 those function are >5 times faster than previous "functionPython"

---
 CMakeLists.txt         | 21 +++++++++++
 Common/GmshConfig.h.in |  1 +
 Solver/functionNumpy.h | 80 ++++++++++++++++++++++++++++++++++++++++++
 gmshpy/CMakeLists.txt  | 52 ++++++++++++---------------
 gmshpy/gmshSolver.i    |  2 ++
 5 files changed, 126 insertions(+), 30 deletions(-)
 create mode 100644 Solver/functionNumpy.h

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 2bb0932573..2a88e89592 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -41,6 +41,7 @@ option(ENABLE_MPEG_ENCODE "Enable built-in MPEG encoder" ON)
 option(ENABLE_MPI "Enable MPI parallelization" OFF)
 option(ENABLE_MSVC_STATIC_RUNTIME "Use static Visual C++ runtime" OFF)
 option(ENABLE_NATIVE_FILE_CHOOSER "Enable native file chooser in GUI" ON)
+option(ENABLE_NUMPY "Enable Numpy function in solvers (requires SWIG)" ON)
 option(ENABLE_NETGEN "Enable Netgen mesh generator" ON)
 option(ENABLE_OCC "Enable Open CASCADE geometrical models" ON)
 option(ENABLE_OSMESA "Use OSMesa for offscreen rendering" OFF)
@@ -799,6 +800,26 @@ if(ENABLE_OSMESA)
   endif(OSMESA_LIB)
 endif(ENABLE_OSMESA)
 
+if(ENABLE_SWIG)
+  find_package(SWIG)
+  find_package(PythonLibs)
+  if(SWIG_FOUND AND PYTHONLIBS_FOUND)
+    message(STATUS "Found SWIG version " ${SWIG_VERSION})
+    string(SUBSTRING ${SWIG_VERSION} 0 1 SWIG_MAJOR_VERSION)
+    if(SWIG_MAJOR_VERSION EQUAL 1)
+      message("WARNING: Python bindings require SWIG >= 2: disabling Python")
+    else(SWIG_MAJOR_VERSION EQUAL 1)
+      set_config_option(HAVE_SWIG "Swig")
+      if(ENABLE_NUMPY)
+        find_path(NUMPY_INC "numpyconfig.h" PATH_SUFFIXES include/numpy)
+        if(NUMPY_INC)
+          set_config_option(HAVE_NUMPY "Numpy")
+        endif(NUMPY_INC)
+      endif(ENABLE_NUMPY)
+    endif(SWIG_MAJOR_VERSION EQUAL 1)
+  endif(SWIG_FOUND AND PYTHONLIBS_FOUND)
+endif(ENABLE_SWIG)
+
 check_function_exists(vsnprintf HAVE_VSNPRINTF)
 if(NOT HAVE_VSNPRINTF)
   set_config_option(HAVE_NO_VSNPRINTF "NoVsnprintf")
diff --git a/Common/GmshConfig.h.in b/Common/GmshConfig.h.in
index 30771075fb..edc5de2f8e 100644
--- a/Common/GmshConfig.h.in
+++ b/Common/GmshConfig.h.in
@@ -38,6 +38,7 @@
 #cmakedefine HAVE_NETGEN
 #cmakedefine HAVE_NO_SOCKLEN_T
 #cmakedefine HAVE_NO_VSNPRINTF
+#cmakedefine HAVE_NUMPY
 #cmakedefine HAVE_OCC
 #cmakedefine HAVE_OPENGL
 #cmakedefine HAVE_OSMESA
diff --git a/Solver/functionNumpy.h b/Solver/functionNumpy.h
new file mode 100644
index 0000000000..bfce74bc37
--- /dev/null
+++ b/Solver/functionNumpy.h
@@ -0,0 +1,80 @@
+#ifndef _FUNCTION_NUMPY_H_
+#define _FUNCTION_NUMPY_H_
+#include "GmshConfig.h"
+#ifdef HAVE_NUMPY
+#include "Python.h"
+#include "numpy/arrayobject.h"
+
+class functionNumpy : public function {
+  PyObject *_pycallback;
+  std::vector<fullMatrix<double> > args;
+public:
+  static PyObject *pyArrayFromFullMatrix(fullMatrix<double> &m) {
+    long int n[2] = {m.size1(), m.size2()};
+    return PyArray_New(&PyArray_Type, 2, n, NPY_DOUBLE, NULL, &m(0, 0), 0, NPY_FARRAY, NULL);
+  }
+  void call (dataCacheMap *m, fullMatrix<double> &res) 
+  {
+    PyObject *swigR;
+    PyObject *pyargs;
+    std::vector<PyObject*> swigA(args.size());
+    swigR = pyArrayFromFullMatrix(res);
+    for (int i = 0; i < args.size(); i++) {
+      swigA[i] = pyArrayFromFullMatrix(args[i]);
+    }
+    switch(args.size()) {
+      case 0 : pyargs = Py_BuildValue("(N)", swigR); break;
+      case 1 : pyargs = Py_BuildValue("(NN)", swigR, swigA[0]); break;
+      case 2 : pyargs = Py_BuildValue("(NNN)", swigR, swigA[0], swigA[1]); break;
+      case 3 : pyargs = Py_BuildValue("(NNNN)", swigR, swigA[0], swigA[1], swigA[2]); break;
+      case 4 : pyargs = Py_BuildValue("(NNNNN)", swigR, swigA[0], swigA[1], swigA[2], swigA[3]); break;
+      case 5 : pyargs = Py_BuildValue("(NNNNNN)", swigR, swigA[0], swigA[1], swigA[2], swigA[3], swigA[4]); break;
+      case 6 : pyargs = Py_BuildValue("(NNNNNNN)", swigR, swigA[0], swigA[1], swigA[2], swigA[3], swigA[4], swigA[5]); break;
+      case 7 : pyargs = Py_BuildValue("(NNNNNNNN)", swigR, swigA[0], swigA[1], swigA[2], swigA[3], swigA[4], swigA[5], swigA[6]); break;
+      case 8 : pyargs = Py_BuildValue("(NNNNNNNNN)", swigR, swigA[0], swigA[1], swigA[2], swigA[3], swigA[4], swigA[5], swigA[6], swigA[7]); break;
+      default:Msg::Error("python function not implemented for more than 8 arguments");
+    }
+    PyObject *result = PyEval_CallObject(_pycallback, pyargs);
+    if (result) {
+      Py_DECREF(result);
+    }
+    else {
+      PyErr_Print();
+      Msg::Fatal("An error occurs in the python function.");
+    }
+/*    for (int i = 0; i < args.size(); i++) {
+      Py_DECREF(swigA[i]);
+    }
+    Py_DECREF(swigR);*/
+  }
+  functionNumpy (int nbCol, PyObject *callback, std::vector<const function*> dependencies)
+    : function(nbCol), _pycallback(callback)
+  {
+    args.resize(dependencies.size());
+    for (int i = 0; i < dependencies.size(); i++) {
+      setArgument(args[i], dependencies[i]);
+    }
+    static bool _arrayImported = false;
+    if (! _arrayImported){
+      _import_array();
+      _arrayImported = true;
+    }
+  }
+  functionNumpy (int nbCol, PyObject *callback, std::vector<std::pair<const function*, int> > dependencies)
+    : function(nbCol), _pycallback(callback)
+  {
+    args.resize(dependencies.size());
+    for (int i = 0; i < dependencies.size(); i++) {
+      setArgument(args[i], dependencies[i].first, dependencies[i].second);
+    }
+    printf("import array !!!\n");
+    static bool _arrayImported = false;
+    if (! _arrayImported) {
+      printf("import array !!!\n");
+      _import_array();
+      _arrayImported = true;
+    }
+  }
+};
+#endif
+#endif
diff --git a/gmshpy/CMakeLists.txt b/gmshpy/CMakeLists.txt
index 9521b8628b..4933e7db41 100644
--- a/gmshpy/CMakeLists.txt
+++ b/gmshpy/CMakeLists.txt
@@ -56,37 +56,29 @@ MACRO(SWIG_GET_WRAPPER_DEPENDENCIES swigFile genWrapper language DEST_VARIABLE)
   ENDIF(NOT ${swig_getdeps_error} EQUAL 0)
 ENDMACRO(SWIG_GET_WRAPPER_DEPENDENCIES)
 
-if(ENABLE_SWIG)
-  find_package(SWIG)
-  if(SWIG_FOUND)
-    message(STATUS "Found SWIG version " ${SWIG_VERSION})
-    string(SUBSTRING ${SWIG_VERSION} 0 1 SWIG_MAJOR_VERSION)
-    if(SWIG_MAJOR_VERSION EQUAL 1)
-      message("WARNING: Python bindings require SWIG >= 2: disabling Python")
-    else(SWIG_MAJOR_VERSION EQUAL 1)
-      include(${SWIG_USE_FILE})
-      find_package(PythonLibs)
-      include_directories(${PYTHON_INCLUDE_DIR})
+if(HAVE_SWIG)
+  include(${SWIG_USE_FILE})
+  include_directories(${PYTHON_INCLUDE_DIR})
+  if(HAVE_NUMPY)
+    include_directories(${NUMPY_INC})
+    add_definitions(-DHAVE_NUMPY)
+  endif(HAVE_NUMPY)
 
-      foreach(module ${SWIG_MODULES})
-        set_source_files_properties(${module}.i PROPERTIES CPLUSPLUS ON)
+  foreach(module ${SWIG_MODULES})
+    set_source_files_properties(${module}.i PROPERTIES CPLUSPLUS ON)
 
-        # code backported from CMake git version, see CMake bug 4147
-        SWIG_GET_WRAPPER_DEPENDENCIES(${CMAKE_CURRENT_SOURCE_DIR}/\${module}.i ${CMAKE_CURRENT_BINARY_DIR}/${module}PYTHON_wrap.cxx python swig_extra_dependencies)
-        LIST(APPEND SWIG_MODULE_${module}_EXTRA_DEPS ${swig_extra_dependencies})
+    # code backported from CMake git version, see CMake bug 4147
+    SWIG_GET_WRAPPER_DEPENDENCIES(${CMAKE_CURRENT_SOURCE_DIR}/\${module}.i ${CMAKE_CURRENT_BINARY_DIR}/${module}PYTHON_wrap.cxx python swig_extra_dependencies)
+    LIST(APPEND SWIG_MODULE_${module}_EXTRA_DEPS ${swig_extra_dependencies})
 
-        swig_add_module(${module} python ${module}.i)
-        swig_link_libraries(${module} ${PYTHON_LIBRARIES} shared)
-        SET(GMSH_PYTHON_MODULES_INCLUDE_CODE 
-            "${GMSH_PYTHON_MODULES_INCLUDE_CODE}from ${module} import *\n")
-        list(APPEND GMSHPY_DEPENDS "_${module}")
-      endforeach(module)
-      add_custom_target("_gmshpy" DEPENDS ${GMSHPY_DEPENDS})
+    swig_add_module(${module} python ${module}.i)
+    swig_link_libraries(${module} ${PYTHON_LIBRARIES} shared)
+    SET(GMSH_PYTHON_MODULES_INCLUDE_CODE 
+        "${GMSH_PYTHON_MODULES_INCLUDE_CODE}from ${module} import *\n")
+    list(APPEND GMSHPY_DEPENDS "_${module}")
+  endforeach(module)
+  add_custom_target("_gmshpy" DEPENDS ${GMSHPY_DEPENDS})
 
-      configure_file(${CMAKE_CURRENT_SOURCE_DIR}/__init__.py.in 
-                     ${CMAKE_CURRENT_BINARY_DIR}/__init__.py)
-
-
-    endif(SWIG_MAJOR_VERSION EQUAL 1)
-  endif(SWIG_FOUND)
-endif(ENABLE_SWIG)
+  configure_file(${CMAKE_CURRENT_SOURCE_DIR}/__init__.py.in 
+                 ${CMAKE_CURRENT_BINARY_DIR}/__init__.py)
+endif(HAVE_SWIG)
diff --git a/gmshpy/gmshSolver.i b/gmshpy/gmshSolver.i
index 1d350c8b66..97832aa66a 100644
--- a/gmshpy/gmshSolver.i
+++ b/gmshpy/gmshSolver.i
@@ -13,6 +13,7 @@
   #include "function.h"
   #include "functionDerivator.h"
   #include "functionPython.h"
+  #include "functionNumpy.h"
   #include "linearSystem.h"
   #include "linearSystemCSR.h"
   #include "linearSystemFull.h"
@@ -31,6 +32,7 @@ namespace std {
 %include "function.h"
 %include "functionDerivator.h"
 %include "functionPython.h"
+%include "functionNumpy.h"
 %include "linearSystem.h"
 %template(linearSystemDouble) linearSystem<double>;
 %template(linearSystemFullMatrixDouble) linearSystem<fullMatrix<double> >;
-- 
GitLab