//
// C++ Interface: ANN
//
//
// Author: <V.-D. Nguyen>, (C) 2020
//
// Copyright: See COPYING file that comes with this distribution
//
//
#include "ANNUtils.h"
DenseLayer::DenseLayer(int numIn, int numOut): numInput(numIn), numOutput(numOut), activationFunc(NULL),W(numIn, numOut),b(1,numOut), WT(NULL)
{
}
DenseLayer::DenseLayer(const DenseLayer& src):numInput(src.numInput), numOutput(src.numOutput),W(src.W),b(src.b), WT(NULL)
{
activationFunc = NULL;
if (src.activationFunc != NULL)
{
activationFunc = src.activationFunc->clone();
}
}
DenseLayer::~DenseLayer()
{
if (WT) delete WT;
WT = NULL;
if (activationFunc!=NULL) delete activationFunc;
activationFunc = NULL;
}
void DenseLayer::setWeights(int i, int j, double val)
{
if (i < 0 || i >= W.size1() || j < 0 || j >= W.size2())
{
Msg::Error("indices %d %d exceed the matrix dimensions %d %d",i,j,W.size1(),W.size2());
Msg::Exit(0);
}
else
{
W(i,j) = val;
}
}
void DenseLayer::setBias(int i, double val)
{
if (i < 0 || i >= b.size2())
{
Msg::Error("index %d exceeds the vector dimension %d",i,b.size2());
Msg::Exit(0);
}
else
{
b(0,i) = val;
}
}
void DenseLayer::print_infos() const
{
printf("general dense layer: nbInput = %d, nbOutput=%d\n",numInput,numOutput);
}
DenseLayer* DenseLayer::clone() const
{
return new DenseLayer(*this);
}
void DenseLayer::predict(const fullMatrix<double>& y0, fullMatrix<double>& y1, bool stiff, fullMatrix<double>* Dy1Dy0) const
{
// x1 = y0*W + b and y1 = activationFunc(x1)
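// NOTE: the scratch matrices below are function-local statics reused across
// calls to avoid reallocation; as a result, predict() is not reentrant and
// not thread-safe.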
static fullMatrix<double> x1;
if (x1.size1() != 1 || x1.size2() != numOutput)
{
x1.resize(1,numOutput);
}
//
y0.mult(W,x1);
x1.add(b);
//
if (y1.size1() != 1 || y1.size2() != numOutput)
{
y1.resize(1,numOutput);
}
for (int i=0; i< numOutput; i++)
{
y1(0,i) = activationFunc->getVal(x1(0,i));
}
if (stiff)
{
if (Dy1Dy0->size1() != numOutput || Dy1Dy0->size2() != numInput)
{
Dy1Dy0->resize(numOutput,numInput);
}
static fullMatrix<double> Dy1Dx1;
if (Dy1Dx1.size1() != numOutput || Dy1Dx1.size2() != numOutput)
{
Dy1Dx1.resize(numOutput,numOutput);
}
Dy1Dx1.setAll(0.);
for (int i=0; i< numOutput; i++)
{
Dy1Dx1(i,i) = activationFunc->getDiff(x1(0,i));
}
if (WT == NULL)
{
// lazily cache the transpose of W for the Jacobian product below
WT = new fullMatrix<double>(W.transpose());
}
Dy1Dx1.mult(*WT,*Dy1Dy0);
}
}
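// Derivation sketch for the layer Jacobian computed above (uses only the
// quantities defined in this file): with a row-vector input y0,
//   x1 = y0*W + b  and  y1_i = f(x1_i),
// so the chain rule gives
//   d y1_i / d y0_j = f'(x1_i) * W(j,i),
// i.e. Dy1Dy0 = diag(f'(x1)) * W^T, which is the product Dy1Dx1 * WT above.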
ReluDenseLayer::ReluDenseLayer(int numIn, int numOut): DenseLayer(numIn,numOut)
{
if (activationFunc!=NULL) delete activationFunc;
activationFunc = new ReluActivationFunction();
}
ReluDenseLayer::~ReluDenseLayer()
{
}
void ReluDenseLayer::print_infos() const
{
printf("relu dense layer: nbInput = %d, nbOutput=%d\n",numInput,numOutput);
}
DenseLayer* ReluDenseLayer::clone() const
{
return new ReluDenseLayer(*this);
}
LeakyReluDenseLayer::LeakyReluDenseLayer(int numIn, int numOut, double a): DenseLayer(numIn,numOut)
{
if (activationFunc!=NULL) delete activationFunc;
activationFunc = new LeakyReluActivationFunction(a);
}
LeakyReluDenseLayer::~LeakyReluDenseLayer()
{
}
void LeakyReluDenseLayer::print_infos() const
{
printf("leaky relu dense layer: nbInput = %d, nbOutput=%d\n",numInput,numOutput);
}
DenseLayer* LeakyReluDenseLayer::clone() const
{
return new LeakyReluDenseLayer(*this);
}
LinearDenseLayer::LinearDenseLayer(int numIn, int numOut): DenseLayer(numIn,numOut)
{
if (activationFunc!=NULL) delete activationFunc;
activationFunc = new LinearActivationFunction();
}
LinearDenseLayer::~LinearDenseLayer()
{
}
void LinearDenseLayer::print_infos() const
{
printf("linear dense layer: nbInput = %d, nbOutput=%d\n",numInput,numOutput);
}
DenseLayer* LinearDenseLayer::clone() const
{
return new LinearDenseLayer(*this);
}
VecDouble::VecDouble():vals(){}
VecDouble::~VecDouble(){}
void VecDouble::set(double v)
{
// appends a value to the stored vector
vals.push_back(v);
}
void VecDouble::clear()
{
vals.clear();
}
void VecDouble::printVal()
{
printf("vals: [");
for (size_t i=0; i< vals.size(); i++)
{
printf(" %.5f",vals[i]);
}
printf("]\n");
}
ArtificialNN::ArtificialNN(int nbL): _allLayers(nbL,NULL),_numLayers(nbL){}
ArtificialNN::~ArtificialNN()
{
for (int i=0; i< _numLayers; i++)
{
if (_allLayers[i]) delete _allLayers[i];
}
}
void ArtificialNN::addInputName(const std::string iv)
{
_inputNames.push_back(iv);
}
void ArtificialNN::addOutputName(const std::string iv)
{
_outputNames.push_back(iv);
}
void ArtificialNN::addLayer(int index, const DenseLayer& layer)
{
if (index < 0 || index >= _numLayers)
{
Msg::Error("index %d exceeds the number of layers %d",index,_numLayers);
Msg::Exit(0);
}
else
{
if (_allLayers[index] != NULL) delete _allLayers[index];
_allLayers[index] = layer.clone();
Msg::Info("add layer: %d",index);
_allLayers[index]->print_infos();
}
}
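// Usage sketch for building a network (illustrative only; the layer sizes
// below are made up):
//   ArtificialNN net(2);
//   net.addLayer(0, ReluDenseLayer(3, 4));   // 3 inputs -> 4 hidden units
//   net.addLayer(1, LinearDenseLayer(4, 1)); // 4 hidden units -> 1 output
// addLayer stores a clone of its argument, so passing temporaries is safe.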
void ArtificialNN::print_infos() const
{
printf("ANN with %d layers: \n",_numLayers);
for (int i=0; i< _numLayers; i++)
{
printf("layer %d:",i);
if (_allLayers[i] !=NULL)
{
_allLayers[i]->print_infos();
}
else
{
Msg::Error("layer %d is NULL !!!");
}
}
}
int ArtificialNN::getNumInputs() const {return _allLayers[0]->numInput;}
int ArtificialNN::getNumOutputs() const {return _allLayers[_numLayers-1]->numOutput;}
void ArtificialNN::predict(const fullMatrix<double>& yIn, fullMatrix<double>& yOut, bool stiff, fullMatrix<double>* DyoutDyInt) const
{
// check that the input size matches the first layer
int numIn = _allLayers[0]->numInput;
if (yIn.size2() != numIn)
{
Msg::Error("input size %d does not match the expected size %d",yIn.size2(),numIn);
Msg::Exit(0);
}
yOut = yIn;
// per-layer Jacobians, filled only when stiff derivatives are requested
static std::vector<fullMatrix<double> > all_Dy;
if (stiff && all_Dy.size() != (size_t)_numLayers)
{
all_Dy.resize(_numLayers);
}
for (int i=0; i< _numLayers; i++)
{
static fullMatrix<double> y1;
// only pass a Jacobian slot when stiff, so all_Dy is never indexed while empty
_allLayers[i]->predict(yOut,y1,stiff,stiff ? &all_Dy[i] : NULL);
yOut = y1;
}
if (stiff)
{
(*DyoutDyInt) = all_Dy[_numLayers-1];
for (int i= _numLayers-2; i >= 0; i--)
{
static fullMatrix<double> temp;
if (temp.size1() != DyoutDyInt->size1()|| temp.size2() != all_Dy[i].size2())
{
temp.resize(DyoutDyInt->size1(),all_Dy[i].size2());
}
DyoutDyInt->mult(all_Dy[i],temp);
(*DyoutDyInt) = temp;
}
}
}
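// Chain rule used above: with per-layer Jacobians D_i = d y_{i+1} / d y_i
// stored in all_Dy, the network Jacobian accumulates right-to-left as
//   DyOut/DyIn = D_{L-1} * D_{L-2} * ... * D_0,
// which is what the backward loop over all_Dy computes.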
void ArtificialNN::test_predict(const VecDouble& v) const
{
fullMatrix<double> x(1,v.vals.size());
for (int i=0; i< (int)v.vals.size(); i++)
{
x(0,i) = v.vals[i];
}
fullMatrix<double> y, DyDx;
predict(x,y,true,&DyDx);
//
x.print("input");
y.print("prediction");
DyDx.print("derivative");
}
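// End-to-end usage sketch (hypothetical driver; the weight values are made
// up for illustration):
//   ArtificialNN net(1);
//   LinearDenseLayer layer(2, 1);
//   layer.setWeights(0, 0, 0.5);
//   layer.setWeights(1, 0, -0.25);
//   layer.setBias(0, 1.0);
//   net.addLayer(0, layer);
//   VecDouble v;
//   v.set(2.0);
//   v.set(4.0);
//   net.test_predict(v); // prediction: 2*0.5 + 4*(-0.25) + 1.0 = 1.0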