diff --git a/Numeric/fullMatrix.cpp b/Numeric/fullMatrix.cpp index bd42f7611834fb280dd54435a17049c459bc31fb..73767d23ce4b651095389c543bf7c618b464783f 100644 --- a/Numeric/fullMatrix.cpp +++ b/Numeric/fullMatrix.cpp @@ -20,9 +20,10 @@ // Specialisation of fullVector/Matrix operations using BLAS and LAPACK #if defined(HAVE_BLAS) - extern "C" { void F77NAME(daxpy)(int *n, double *alpha, double *x, int *incx, double *y, int *incy); + void F77NAME(dcopy)(int *n, double *a, int *inca, double *b, int *incb); + void F77NAME(zcopy)(int *n, std::complex<double> *a, int *inca, std::complex<double> *b, int *incb); void F77NAME(dgemm)(const char *transa, const char *transb, int *m, int *n, int *k, double *alpha, double *a, int *lda, double *b, int *ldb, double *beta, @@ -50,6 +51,40 @@ void fullVector<double>::axpy(const fullVector<double> &x,double alpha) F77NAME(daxpy)(&M, &alpha, x._data,&INCX, _data, &INCY); } +template<> +void fullVector<double>::setAll(const fullVector<double> &m) +{ + int stride = 1; + F77NAME(dcopy)(&_r, m._data, &stride, _data, &stride); +} + +template<> +void fullVector<std::complex<double> >::setAll(const fullVector<std::complex<double> > &m) +{ + int stride = 1; + F77NAME(zcopy)(&_r, m._data, &stride, _data, &stride); +} + +template<> +void fullMatrix<double>::setAll(const fullMatrix<double> &m) +{ + if (_r != m._r || _c != m._c ) + Msg::Fatal("fullMatrix size does not match"); + int N = _r * _c; + int stride = 1; + F77NAME(dcopy)(&N, m._data, &stride, _data, &stride); +} + +template<> +void fullMatrix<std::complex<double> >::setAll(const fullMatrix<std::complex<double > > &m) +{ + if (_r != m._r || _c != m._c ) + Msg::Fatal("fullMatrix size does not match"); + int N = _r * _c; + int stride = 1; + F77NAME(zcopy)(&N, m._data, &stride, _data, &stride); +} + template<> void fullMatrix<double>::scale(const double s) { diff --git a/Numeric/fullMatrix.h b/Numeric/fullMatrix.h index ed8637191f612bd8c49a51d4ed02a9bce7d7e4bb..34de2ae24b0b7d39d86c4d1591a63f0a26379a5c 100644 --- a/Numeric/fullMatrix.h +++ b/Numeric/fullMatrix.h @@ -277,10 +277,13 @@ class fullVector m.size() must be greater or equal to @f$ N @f$. */ - inline void setAll(const fullVector<scalar> &m) + void setAll(const fullVector<scalar> &m) +#if !defined(HAVE_BLAS) { for(int i = 0; i < _r; i++) _data[i] = m._data[i]; } +#endif + ; /** @param other A fullVector. @@ -404,11 +407,12 @@ class fullMatrix _own_data = false; _data = original._data + c_start * _r; } - fullMatrix(int r, int c) : _r(r), _c(c) + fullMatrix(int r, int c, bool init0 = true) : _r(r), _c(c) { _data = new scalar[_r * _c]; _own_data = true; - setAll(scalar(0.)); + if (init0) + setAll(scalar(0.)); } fullMatrix(int r, int c, double *data) : _r(r), _c(c), _data(data), _own_data(false) @@ -476,6 +480,10 @@ class fullMatrix return false; // no reallocation } void reshape(int nbRows, int nbColumns){ + if (nbRows == -1 && nbColumns != -1) + nbRows = _r * _c / nbColumns; + if (nbRows != -1 && nbColumns == -1) + nbColumns = _r * _c / nbRows; if (nbRows*nbColumns != size1()*size2()) Msg::Error("Invalid reshape, total number of entries must be equal"); _r = nbRows; @@ -606,12 +614,15 @@ class fullMatrix { for(int i = 0; i < _r * _c; i++) _data[i] = m; } - inline void setAll(const fullMatrix<scalar> &m) + void setAll(const fullMatrix<scalar> &m) +#if !defined(HAVE_BLAS) { if (_r != m._r || _c != m._c ) Msg::Fatal("fullMatrix size does not match"); for(int i = 0; i < _r * _c; i++) _data[i] = m._data[i]; } +#endif + ; void scale(const double s) #if !defined(HAVE_BLAS) {