Skip to content
Snippets Groups Projects
Commit d6b48db1 authored by Jonathan Lambrechts's avatar Jonathan Lambrechts
Browse files

Solver/function : add numpy function, combined with blitz those function

are >5 times faster than previous "functionPython"
parent 654cb0e1
No related branches found
No related tags found
No related merge requests found
......@@ -41,6 +41,7 @@ option(ENABLE_MPEG_ENCODE "Enable built-in MPEG encoder" ON)
option(ENABLE_MPI "Enable MPI parallelization" OFF)
option(ENABLE_MSVC_STATIC_RUNTIME "Use static Visual C++ runtime" OFF)
option(ENABLE_NATIVE_FILE_CHOOSER "Enable native file chooser in GUI" ON)
option(ENABLE_NUMPY "Enable Numpy function in solvers (requires SWIG)" ON)
option(ENABLE_NETGEN "Enable Netgen mesh generator" ON)
option(ENABLE_OCC "Enable Open CASCADE geometrical models" ON)
option(ENABLE_OSMESA "Use OSMesa for offscreen rendering" OFF)
......@@ -799,6 +800,26 @@ if(ENABLE_OSMESA)
endif(OSMESA_LIB)
endif(ENABLE_OSMESA)
if(ENABLE_SWIG)
find_package(SWIG)
find_package(PythonLibs)
if(SWIG_FOUND AND PYTHONLIBS_FOUND)
message(STATUS "Found SWIG version " ${SWIG_VERSION})
string(SUBSTRING ${SWIG_VERSION} 0 1 SWIG_MAJOR_VERSION)
if(SWIG_MAJOR_VERSION EQUAL 1)
message("WARNING: Python bindings require SWIG >= 2: disabling Python")
else(SWIG_MAJOR_VERSION EQUAL 1)
set_config_option(HAVE_SWIG "Swig")
if(ENABLE_NUMPY)
find_path(NUMPY_INC "numpyconfig.h" PATH_SUFFIXES include/numpy)
if(NUMPY_INC)
set_config_option(HAVE_NUMPY "Numpy")
endif(NUMPY_INC)
endif(ENABLE_NUMPY)
endif(SWIG_MAJOR_VERSION EQUAL 1)
endif(SWIG_FOUND AND PYTHONLIBS_FOUND)
endif(ENABLE_SWIG)
check_function_exists(vsnprintf HAVE_VSNPRINTF)
if(NOT HAVE_VSNPRINTF)
set_config_option(HAVE_NO_VSNPRINTF "NoVsnprintf")
......
......@@ -38,6 +38,7 @@
#cmakedefine HAVE_NETGEN
#cmakedefine HAVE_NO_SOCKLEN_T
#cmakedefine HAVE_NO_VSNPRINTF
#cmakedefine HAVE_NUMPY
#cmakedefine HAVE_OCC
#cmakedefine HAVE_OPENGL
#cmakedefine HAVE_OSMESA
......
#ifndef _FUNCTION_NUMPY_H_
#define _FUNCTION_NUMPY_H_
#include "GmshConfig.h"
#ifdef HAVE_NUMPY
#include "Python.h"
#include "numpy/arrayobject.h"
class functionNumpy : public function {
PyObject *_pycallback;
std::vector<fullMatrix<double> > args;
public:
static PyObject *pyArrayFromFullMatrix(fullMatrix<double> &m) {
long int n[2] = {m.size1(), m.size2()};
return PyArray_New(&PyArray_Type, 2, n, NPY_DOUBLE, NULL, &m(0, 0), 0, NPY_FARRAY, NULL);
}
void call (dataCacheMap *m, fullMatrix<double> &res)
{
PyObject *swigR;
PyObject *pyargs;
std::vector<PyObject*> swigA(args.size());
swigR = pyArrayFromFullMatrix(res);
for (int i = 0; i < args.size(); i++) {
swigA[i] = pyArrayFromFullMatrix(args[i]);
}
switch(args.size()) {
case 0 : pyargs = Py_BuildValue("(N)", swigR); break;
case 1 : pyargs = Py_BuildValue("(NN)", swigR, swigA[0]); break;
case 2 : pyargs = Py_BuildValue("(NNN)", swigR, swigA[0], swigA[1]); break;
case 3 : pyargs = Py_BuildValue("(NNNN)", swigR, swigA[0], swigA[1], swigA[2]); break;
case 4 : pyargs = Py_BuildValue("(NNNNN)", swigR, swigA[0], swigA[1], swigA[2], swigA[3]); break;
case 5 : pyargs = Py_BuildValue("(NNNNNN)", swigR, swigA[0], swigA[1], swigA[2], swigA[3], swigA[4]); break;
case 6 : pyargs = Py_BuildValue("(NNNNNNN)", swigR, swigA[0], swigA[1], swigA[2], swigA[3], swigA[4], swigA[5]); break;
case 7 : pyargs = Py_BuildValue("(NNNNNNNN)", swigR, swigA[0], swigA[1], swigA[2], swigA[3], swigA[4], swigA[5], swigA[6]); break;
case 8 : pyargs = Py_BuildValue("(NNNNNNNNN)", swigR, swigA[0], swigA[1], swigA[2], swigA[3], swigA[4], swigA[5], swigA[6], swigA[7]); break;
default:Msg::Error("python function not implemented for more than 8 arguments");
}
PyObject *result = PyEval_CallObject(_pycallback, pyargs);
if (result) {
Py_DECREF(result);
}
else {
PyErr_Print();
Msg::Fatal("An error occurs in the python function.");
}
/* for (int i = 0; i < args.size(); i++) {
Py_DECREF(swigA[i]);
}
Py_DECREF(swigR);*/
}
functionNumpy (int nbCol, PyObject *callback, std::vector<const function*> dependencies)
: function(nbCol), _pycallback(callback)
{
args.resize(dependencies.size());
for (int i = 0; i < dependencies.size(); i++) {
setArgument(args[i], dependencies[i]);
}
static bool _arrayImported = false;
if (! _arrayImported){
_import_array();
_arrayImported = true;
}
}
functionNumpy (int nbCol, PyObject *callback, std::vector<std::pair<const function*, int> > dependencies)
: function(nbCol), _pycallback(callback)
{
args.resize(dependencies.size());
for (int i = 0; i < dependencies.size(); i++) {
setArgument(args[i], dependencies[i].first, dependencies[i].second);
}
printf("import array !!!\n");
static bool _arrayImported = false;
if (! _arrayImported) {
printf("import array !!!\n");
_import_array();
_arrayImported = true;
}
}
};
#endif
#endif
......@@ -56,17 +56,13 @@ MACRO(SWIG_GET_WRAPPER_DEPENDENCIES swigFile genWrapper language DEST_VARIABLE)
ENDIF(NOT ${swig_getdeps_error} EQUAL 0)
ENDMACRO(SWIG_GET_WRAPPER_DEPENDENCIES)
if(ENABLE_SWIG)
find_package(SWIG)
if(SWIG_FOUND)
message(STATUS "Found SWIG version " ${SWIG_VERSION})
string(SUBSTRING ${SWIG_VERSION} 0 1 SWIG_MAJOR_VERSION)
if(SWIG_MAJOR_VERSION EQUAL 1)
message("WARNING: Python bindings require SWIG >= 2: disabling Python")
else(SWIG_MAJOR_VERSION EQUAL 1)
if(HAVE_SWIG)
include(${SWIG_USE_FILE})
find_package(PythonLibs)
include_directories(${PYTHON_INCLUDE_DIR})
if(HAVE_NUMPY)
include_directories(${NUMPY_INC})
add_definitions(-DHAVE_NUMPY)
endif(HAVE_NUMPY)
foreach(module ${SWIG_MODULES})
set_source_files_properties(${module}.i PROPERTIES CPLUSPLUS ON)
......@@ -85,8 +81,4 @@ if(ENABLE_SWIG)
configure_file(${CMAKE_CURRENT_SOURCE_DIR}/__init__.py.in
${CMAKE_CURRENT_BINARY_DIR}/__init__.py)
endif(SWIG_MAJOR_VERSION EQUAL 1)
endif(SWIG_FOUND)
endif(ENABLE_SWIG)
endif(HAVE_SWIG)
......@@ -13,6 +13,7 @@
#include "function.h"
#include "functionDerivator.h"
#include "functionPython.h"
#include "functionNumpy.h"
#include "linearSystem.h"
#include "linearSystemCSR.h"
#include "linearSystemFull.h"
......@@ -31,6 +32,7 @@ namespace std {
%include "function.h"
%include "functionDerivator.h"
%include "functionPython.h"
%include "functionNumpy.h"
%include "linearSystem.h"
%template(linearSystemDouble) linearSystem<double>;
%template(linearSystemFullMatrixDouble) linearSystem<fullMatrix<double> >;
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment