diff --git a/Geo/GFaceCompound.cpp b/Geo/GFaceCompound.cpp index 299ffd0fa5efac5c17baf82bfa5e93b17591c844..1e240a9697cb6692214df0e95811ed76ca26e0c8 100644 --- a/Geo/GFaceCompound.cpp +++ b/Geo/GFaceCompound.cpp @@ -2360,7 +2360,7 @@ void GFaceCompound::computeHessianMapping() const by(i) = vv[i]->y(); bz(i) = vv[i]->z(); } - ATAx.gemmWithAtranspose(A,A,1.,0.); + ATAx.gemm(A,A,1.,0., true, false); ATAy = ATAx; ATAz = ATAx; A.multWithATranspose(bx,1.,0.,ATbx); A.multWithATranspose(by,1.,0.,ATby); diff --git a/Mesh/meshMetric.cpp b/Mesh/meshMetric.cpp index 2ebb96843bb60d1ccce1d95330732984c7011cef..1bb47b88356e9f884f5596ae7f44592da3d340a1 100644 --- a/Mesh/meshMetric.cpp +++ b/Mesh/meshMetric.cpp @@ -285,7 +285,7 @@ void meshMetric::computeHessian() } b(i) = vals[vv[i]]; } - ATA.gemmWithAtranspose(A,A,1.,0.); + ATA.gemm(A,A,1.,0., true, false); A.multWithATranspose(b,1.,0.,ATb); ATA.luSolve(ATb,coeffs); const double &x = ver->x(), &y = ver->y(), &z = ver->z(); diff --git a/Numeric/fullMatrix.cpp b/Numeric/fullMatrix.cpp index b4325e0ef689f50373300ecf07721e86794f8488..bd42f7611834fb280dd54435a17049c459bc31fb 100644 --- a/Numeric/fullMatrix.cpp +++ b/Numeric/fullMatrix.cpp @@ -91,11 +91,11 @@ void fullMatrix<std::complex<double> >::mult(const fullMatrix<std::complex<doubl template<> void fullMatrix<double>::gemm(const fullMatrix<double> &a, const fullMatrix<double> &b, - double alpha, double beta) + double alpha, double beta, bool transposeA, bool transposeB) { - int M = size1(), N = size2(), K = a.size2(); + int M = size1(), N = size2(), K = transposeA ? a.size1() : a.size2(); int LDA = a.size1(), LDB = b.size1(), LDC = size1(); - F77NAME(dgemm)("N", "N", &M, &N, &K, &alpha, a._data, &LDA, b._data, &LDB, + F77NAME(dgemm)(transposeA ? "T" : "N", transposeB ? "T" :"N", &M, &N, &K, &alpha, a._data, &LDA, b._data, &LDB, &beta, _data, &LDC); } @@ -103,11 +103,13 @@ template<> void fullMatrix<std::complex<double> >::gemm(const fullMatrix<std::complex<double> > &a, const fullMatrix<std::complex<double> > &b, std::complex<double> alpha, - std::complex<double> beta) + std::complex<double> beta, + bool transposeA, + bool transposeB) { - int M = size1(), N = size2(), K = a.size2(); + int M = size1(), N = size2(), K = transposeA ? a.size1() : a.size2(); int LDA = a.size1(), LDB = b.size1(), LDC = size1(); - F77NAME(zgemm)("N", "N", &M, &N, &K, &alpha, a._data, &LDA, b._data, &LDB, + F77NAME(zgemm)(transposeA ? "T" : "N", transposeB ? "T" :"N", &M, &N, &K, &alpha, a._data, &LDA, b._data, &LDB, &beta, _data, &LDC); } @@ -175,29 +177,6 @@ void fullMatrix<double>::multWithATranspose(const fullVector<double> &x, double &beta, y._data, &INCY); } - -template<> -void fullMatrix<double>::gemmWithAtranspose(const fullMatrix<double> &a, const fullMatrix<double> &b, - double alpha, double beta) -{ - int M = size2(), N = size2(), K = a.size1(); - int LDA = a.size1(), LDB = b.size1(), LDC = size1(); - F77NAME(dgemm)("T", "N", &M, &N, &K, &alpha, a._data, &LDA, b._data, &LDB, - &beta, _data, &LDC); -} - -template<> -void fullMatrix<std::complex<double> >::gemmWithAtranspose(const fullMatrix<std::complex<double> > &a, - const fullMatrix<std::complex<double> > &b, - std::complex<double> alpha, - std::complex<double> beta) -{ - int M = size2(), N = size2(), K = a.size1(); - int LDA = a.size1(), LDB = b.size1(), LDC = size1(); - F77NAME(zgemm)("T", "N", &M, &N, &K, &alpha, a._data, &LDA, b._data, &LDB, - &beta, _data, &LDC); -} - #endif diff --git a/Numeric/fullMatrix.h b/Numeric/fullMatrix.h index 74debb530eefa04985ea788cf681b3cf529af794..ed8637191f612bd8c49a51d4ed02a9bce7d7e4bb 100644 --- a/Numeric/fullMatrix.h +++ b/Numeric/fullMatrix.h @@ -595,10 +595,10 @@ class fullMatrix add(temp); } void gemm(const fullMatrix<scalar> &a, const fullMatrix<scalar> &b, - scalar alpha=1., scalar beta=1.) + scalar alpha=1., scalar beta=1., bool transposeA = false, bool transposeB = false) #if !defined(HAVE_BLAS) { - gemm_naive(a,b,alpha,beta); + gemm_naive(transposeA ? a.transpose() : a, transposeB ? b.transpose() : b, alpha, beta); } #endif ; @@ -809,15 +809,7 @@ class fullMatrix _data[cind+i] = x(i); } - void gemmWithAtranspose(const fullMatrix<scalar> &a, const fullMatrix<scalar> &b, - scalar alpha=1., scalar beta=1.) -#if !defined(HAVE_BLAS) - { - Msg::Error("gemmWithAtranspose is only available with blas. If blas is not " - "installed please transpose a before used gemm_naive"); - } -#endif - ; - + bool getOwnData() {return _own_data;}; + void setOwnData(bool ownData) {_own_data = ownData;}; }; #endif