Select Git revision
gpu_support.hpp
-
Matteo Cicuttin authoredMatteo Cicuttin authored
node.cpp 22.52 KiB
/*
* Copyright (C) 1999, 2002, 2003 Free Software Foundation, Inc.
*
* This file is part of GNU libmatheval
*
* GNU libmatheval is free software; you can redistribute it and/or modify it
* under the terms of the GNU General Public License as published by the Free
* Software Foundation; either version 2, or (at your option) any later
* version.
*
* GNU libmatheval is distributed in the hope that it will be useful, but
* WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
* or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
* for more details.
*
* You should have received a copy of the GNU General Public License along with
* program; see the file COPYING. If not, write to the Free Software
* Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.
*/
#include <assert.h>
#include <stdarg.h>
#include "common.h"
#include "node.h"
/*
 * Allocate a new node of the given type and initialise its payload from
 * the variable argument list:
 *   'c'  double            constant value
 *   'v'  Record *          symbol table record of the variable
 *   'f'  Record *, Node *  symbol table record of the function, argument
 *   'u'  int, Node *       operator character, operand
 *   'b'  int, Node *, Node *  operator character, left and right operands
 * Operator characters are read back as int because a char passed through
 * "..." is promoted to int. Child nodes are stored as given (not copied),
 * so node_destroy() on the returned node also frees them. Any other type
 * trips the assertion.
 *
 * NOTE(review): `type` is a char, which undergoes default argument
 * promotion; C makes va_start with such a last named parameter undefined
 * behaviour. Fixing that requires changing the prototype (presumably in
 * node.h) to int, so it is left unchanged here.
 */
Node *
node_create(char type,...)
{
	Node		*node;	/* New node. */
	va_list		ap;	/* Variable argument list. */

	/* Allocate memory for node and record its type. */
	node = XMALLOC(Node, 1);
	node->type = type;

	va_start(ap, type);
	switch (node->type) {
	case 'c':
		node->data.constant = va_arg(ap, double);
		break;
	case 'v':
		node->data.variable = va_arg(ap, Record *);
		break;
	case 'f':
		node->data.function.record = va_arg(ap, Record *);
		node->data.function.child = va_arg(ap, Node *);
		break;
	case 'u':
		node->data.un_op.operatorr = (char) va_arg(ap, int);
		node->data.un_op.child = va_arg(ap, Node *);
		break;
	case 'b':
		/*
		 * Store the operator in the bin_op member; writing it through
		 * un_op only worked if both union members shared a layout.
		 */
		node->data.bin_op.operatorr = (char) va_arg(ap, int);
		node->data.bin_op.left = va_arg(ap, Node *);
		node->data.bin_op.right = va_arg(ap, Node *);
		break;
	default:
		assert(0);
	}
	va_end(ap);
	return node;
}
/*
 * Free the whole subtree rooted at node, children first, then the node
 * itself. A NULL argument is accepted and ignored: node_simplify() clears
 * a child pointer to NULL before destroying the parent it detached it from.
 */
void
node_destroy(Node * node)
{
	if (node == NULL)
		return;

	/* Constants and variables own no child nodes. */
	if (node->type == 'f') {
		node_destroy(node->data.function.child);
	} else if (node->type == 'u') {
		node_destroy(node->data.un_op.child);
	} else if (node->type == 'b') {
		node_destroy(node->data.bin_op.left);
		node_destroy(node->data.bin_op.right);
	}

	XFREE(node);
}
/*
 * Return a freshly allocated deep copy of the subtree rooted at node.
 * Symbol table records (variables and functions) are shared, not
 * duplicated. node must be non-NULL. Returns NULL for an unrecognised
 * node type.
 */
Node *
node_copy(Node * node)
{
	const char	kind = node->type;

	if (kind == 'c')
		return node_create('c', node->data.constant);
	if (kind == 'v')
		return node_create('v', node->data.variable);
	if (kind == 'f')
		return node_create('f', node->data.function.record,
				   node_copy(node->data.function.child));
	if (kind == 'u')
		return node_create('u', node->data.un_op.operatorr,
				   node_copy(node->data.un_op.child));
	if (kind == 'b')
		return node_create('b', node->data.bin_op.operatorr,
				   node_copy(node->data.bin_op.left),
				   node_copy(node->data.bin_op.right));
	return NULL;
}
/*
 * Nonzero when n is a constant node whose value equals value exactly.
 */
static int
node_simplify_is_constant(Node * n, double value)
{
	return n->type == 'c' && n->data.constant == value;
}

/*
 * Free node (and its subtree) and return a new constant node holding value.
 */
static Node *
node_simplify_replace(Node * node, double value)
{
	node_destroy(node);
	return node_create('c', value);
}

/*
 * Detach the left operand of binary node, free the rest of node, and
 * return the detached operand.
 */
static Node *
node_simplify_keep_left(Node * node)
{
	Node		*kept = node->data.bin_op.left;

	node->data.bin_op.left = NULL;
	node_destroy(node);
	return kept;
}

/*
 * Detach the right operand of binary node, free the rest of node, and
 * return the detached operand.
 */
static Node *
node_simplify_keep_right(Node * node)
{
	Node		*kept = node->data.bin_op.right;

	node->data.bin_op.right = NULL;
	node_destroy(node);
	return kept;
}

/*
 * Simplify the tree rooted at node and return the root of the result.
 * The argument may be freed and replaced, so callers must continue with
 * the returned pointer. Rules applied bottom-up:
 *   - constant folding of functions (except Rand), unary minus, and binary
 *     operators whose operands are all constants;
 *   - 0+x, x+0, x-0, 1*x, x*1, x/1 -> x;
 *   - 0^x -> 0, 1^x -> 1, x^0 -> 1, x^1 -> x.
 * Returns NULL for an unrecognised node type.
 */
Node *
node_simplify(Node * node)
{
	switch (node->type) {
	case 'c':
	case 'v':
		return node;

	case 'f':
		node->data.function.child = node_simplify(node->data.function.child);
		/*
		 * Rand(x) is never folded (presumably because its value is not
		 * a pure function of its argument).
		 */
		if (node->data.function.child->type != 'c' ||
		    !strcmp(node->data.function.record->name, "Rand"))
			return node;
		return node_simplify_replace(node, node_evaluate(node));

	case 'u':
		node->data.un_op.child = node_simplify(node->data.un_op.child);
		if (node->data.un_op.operatorr != '-' ||
		    node->data.un_op.child->type != 'c')
			return node;
		return node_simplify_replace(node, node_evaluate(node));

	case 'b':
		{
			Node		*lhs;
			Node		*rhs;

			node->data.bin_op.left = lhs = node_simplify(node->data.bin_op.left);
			node->data.bin_op.right = rhs = node_simplify(node->data.bin_op.right);

			if (lhs->type == 'c' && rhs->type == 'c')
				return node_simplify_replace(node, node_evaluate(node));

			switch (node->data.bin_op.operatorr) {
			case '+':
				if (node_simplify_is_constant(lhs, 0))
					return node_simplify_keep_right(node);
				if (node_simplify_is_constant(rhs, 0))
					return node_simplify_keep_left(node);
				break;
			case '-':
				if (node_simplify_is_constant(rhs, 0))
					return node_simplify_keep_left(node);
				break;
			case '*':
				if (node_simplify_is_constant(lhs, 1))
					return node_simplify_keep_right(node);
				if (node_simplify_is_constant(rhs, 1))
					return node_simplify_keep_left(node);
				break;
			case '/':
				if (node_simplify_is_constant(rhs, 1))
					return node_simplify_keep_left(node);
				break;
			case '^':
				/*
				 * NOTE(review): 0^x -> 0 only holds for x > 0
				 * (0^0 is 1, 0^-1 is inf); kept as the existing rule.
				 */
				if (node_simplify_is_constant(lhs, 0))
					return node_simplify_replace(node, 0.0);
				if (node_simplify_is_constant(lhs, 1))
					return node_simplify_replace(node, 1.0);
				if (node_simplify_is_constant(rhs, 0))
					return node_simplify_replace(node, 1.0);
				if (node_simplify_is_constant(rhs, 1))
					return node_simplify_keep_left(node);
				break;
			}
			return node;
		}
	}
	return NULL;
}
/*
 * Evaluate the subtree rooted at node and return its value.
 * Variables take their current value from their symbol table record;
 * functions are called through the function pointer stored in their
 * record. Returns 0 for an unrecognised node type or operator.
 */
double
node_evaluate(Node * node)
{
	switch (node->type) {
	case 'c':
		return node->data.constant;
	case 'v':
		return node->data.variable->data.value;
	case 'f':
		return (*node->data.function.record->data.function)
			(node_evaluate(node->data.function.child));
	case 'u':
		if (node->data.un_op.operatorr == '-')
			return -node_evaluate(node->data.un_op.child);
		/* Unknown unary operator: do not fall into the binary case. */
		break;
	case 'b':
		{
			/*
			 * Evaluate operands left-to-right into locals. C leaves
			 * argument evaluation order unspecified, which would make
			 * expressions containing Rand() compiler-dependent.
			 */
			double		lhs = node_evaluate(node->data.bin_op.left);
			double		rhs = node_evaluate(node->data.bin_op.right);

			/* Read the operator from the bin_op member, not un_op. */
			switch (node->data.bin_op.operatorr) {
			case '+':
				return lhs + rhs;
			case '-':
				return lhs - rhs;
			case '*':
				return lhs * rhs;
			case '/':
				return lhs / rhs;
			case '^':
				return pow(lhs, rhs);
			}
			break;
		}
	}
	return 0;
}
/*
 * Build a new tree for the derivative of the subtree rooted at node with
 * respect to the variable called name. The input tree is left untouched;
 * every part of the result is freshly created (via node_create/node_copy),
 * so the caller owns the returned tree. Function nodes in the result are
 * bound to records looked up by name in symbol_table.
 *
 * Returns NULL when no rule applies: an unrecognised node type, function
 * (e.g. Rand) or operator. NOTE(review): a NULL from a nested call is
 * embedded as-is into the enclosing result node.
 */
Node *
node_derivative(Node * node, char *name, SymbolTable * symbol_table)
{
	switch (node->type) {
	case 'c':
		/* c' = 0 */
		return node_create('c', 0.0);
	case 'v':
		/* x' = 1 for the differentiation variable, 0 for any other. */
		return node_create('c', (!strcmp(name, node->data.variable->name)) ? 1.0 : 0.0);
	case 'f':
		{
			const char	*fname = node->data.function.record->name;
			Node		*child = node->data.function.child;

			/* (exp f)' = f' * exp f */
			if (!strcmp(fname, "Exp"))
				return node_create('b', '*', node_derivative(child, name, symbol_table),
						   node_copy(node));
			/* (log f)' = f' / f */
			if (!strcmp(fname, "Log"))
				return node_create('b', '/', node_derivative(child, name, symbol_table),
						   node_copy(child));
			/* (log10 f)' = f' / (log(10) * f) */
			if (!strcmp(fname, "Log10"))
				return node_create('b', '/', node_derivative(child, name, symbol_table),
						   node_create('b', '*', node_create('c', log(10.)), node_copy(child)));
			/* (sqrt f)' = f' / (2 * sqrt f) */
			if (!strcmp(fname, "Sqrt"))
				return node_create('b', '/', node_derivative(child, name, symbol_table),
						   node_create('b', '*', node_create('c', 2.0), node_copy(node)));
			/* (sin f)' = f' * cos f */
			if (!strcmp(fname, "Sin"))
				return node_create('b', '*', node_derivative(child, name, symbol_table),
						   node_create('f', symbol_table_lookup(symbol_table, "Cos"), node_copy(child)));
			/* (cos f)' = -(f' * sin f) */
			if (!strcmp(fname, "Cos"))
				return node_create('u', '-',
						   node_create('b', '*', node_derivative(child, name, symbol_table),
							       node_create('f', symbol_table_lookup(symbol_table, "Sin"), node_copy(child))));
			/* (tan f)' = f' / (cos f)^2 */
			if (!strcmp(fname, "Tan"))
				return node_create('b', '/', node_derivative(child, name, symbol_table),
						   node_create('b', '^',
							       node_create('f', symbol_table_lookup(symbol_table, "Cos"), node_copy(child)),
							       node_create('c', 2.0)));
			/* (ctan f)' = -(f' / (sin f)^2) */
			if (!strcmp(fname, "Ctan"))
				return node_create('u', '-',
						   node_create('b', '/', node_derivative(child, name, symbol_table),
							       node_create('b', '^',
									   node_create('f', symbol_table_lookup(symbol_table, "Sin"), node_copy(child)),
									   node_create('c', 2.0))));
			/* (asin f)' = f' / sqrt(1 - f^2) */
			if (!strcmp(fname, "Asin"))
				return node_create('b', '/', node_derivative(child, name, symbol_table),
						   node_create('f', symbol_table_lookup(symbol_table, "Sqrt"),
							       node_create('b', '-', node_create('c', 1.0),
									   node_create('b', '^', node_copy(child), node_create('c', 2.0)))));
			/* (acos f)' = -(f' / sqrt(1 - f^2)) */
			if (!strcmp(fname, "Acos"))
				return node_create('u', '-',
						   node_create('b', '/', node_derivative(child, name, symbol_table),
							       node_create('f', symbol_table_lookup(symbol_table, "Sqrt"),
									   node_create('b', '-', node_create('c', 1.0),
										       node_create('b', '^', node_copy(child), node_create('c', 2.0))))));
			/* (atan f)' = f' / (1 + f^2) */
			if (!strcmp(fname, "Atan"))
				return node_create('b', '/', node_derivative(child, name, symbol_table),
						   node_create('b', '+', node_create('c', 1.0),
							       node_create('b', '^', node_copy(child), node_create('c', 2.0))));
			/* (actan f)' = -(f' / (1 + f^2)) */
			if (!strcmp(fname, "Actan"))
				return node_create('u', '-',
						   node_create('b', '/', node_derivative(child, name, symbol_table),
							       node_create('b', '+', node_create('c', 1.0),
									   node_create('b', '^', node_copy(child), node_create('c', 2.0)))));
			/* (sinh f)' = f' * cosh f */
			if (!strcmp(fname, "Sinh"))
				return node_create('b', '*', node_derivative(child, name, symbol_table),
						   node_create('f', symbol_table_lookup(symbol_table, "Cosh"), node_copy(child)));
			/* (cosh f)' = f' * sinh f */
			if (!strcmp(fname, "Cosh"))
				return node_create('b', '*', node_derivative(child, name, symbol_table),
						   node_create('f', symbol_table_lookup(symbol_table, "Sinh"), node_copy(child)));
			/* (tanh f)' = f' / (cosh f)^2 */
			if (!strcmp(fname, "Tanh"))
				return node_create('b', '/', node_derivative(child, name, symbol_table),
						   node_create('b', '^',
							       node_create('f', symbol_table_lookup(symbol_table, "Cosh"), node_copy(child)),
							       node_create('c', 2.0)));
			/*
			 * (ctanh f)' = -(f' / (sinh f)^2). The record is looked up
			 * as "Sinh", matching the capitalised names used by every
			 * other rule.
			 */
			if (!strcmp(fname, "Ctanh"))
				return node_create('u', '-',
						   node_create('b', '/', node_derivative(child, name, symbol_table),
							       node_create('b', '^',
									   node_create('f', symbol_table_lookup(symbol_table, "Sinh"), node_copy(child)),
									   node_create('c', 2.0))));
			/* (asinh f)' = f' / sqrt(f^2 + 1) */
			if (!strcmp(fname, "Asinh"))
				return node_create('b', '/', node_derivative(child, name, symbol_table),
						   node_create('f', symbol_table_lookup(symbol_table, "Sqrt"),
							       node_create('b', '+',
									   node_create('b', '^', node_copy(child), node_create('c', 2.0)),
									   node_create('c', 1.0))));
			/* (acosh f)' = f' / sqrt(f^2 - 1) */
			if (!strcmp(fname, "Acosh"))
				return node_create('b', '/', node_derivative(child, name, symbol_table),
						   node_create('f', symbol_table_lookup(symbol_table, "Sqrt"),
							       node_create('b', '-',
									   node_create('b', '^', node_copy(child), node_create('c', 2.0)),
									   node_create('c', 1.0))));
			/* (atanh f)' = f' / (1 - f^2) */
			if (!strcmp(fname, "Atanh"))
				return node_create('b', '/', node_derivative(child, name, symbol_table),
						   node_create('b', '-', node_create('c', 1.0),
							       node_create('b', '^', node_copy(child), node_create('c', 2.0))));
			/* (actanh f)' = f' / (1 - f^2), the same rule as atanh. */
			if (!strcmp(fname, "Actanh"))
				return node_create('b', '/', node_derivative(child, name, symbol_table),
						   node_create('b', '-', node_create('c', 1.0),
							       node_create('b', '^', node_copy(child), node_create('c', 2.0))));
			/* (fabs f)' = f' * f / sqrt(f^2) */
			if (!strcmp(fname, "Fabs"))
				return node_create('b', '/',
						   node_create('b', '*', node_derivative(child, name, symbol_table), node_copy(child)),
						   node_create('f', symbol_table_lookup(symbol_table, "Sqrt"),
							       node_create('b', '^', node_copy(child), node_create('c', 2.0))));
			/* No rule for this function: do not fall into the operator cases. */
			break;
		}
	case 'u':
		/* (-f)' = -f' */
		if (node->data.un_op.operatorr == '-')
			return node_create('u', '-', node_derivative(node->data.un_op.child, name, symbol_table));
		break;
	case 'b':
		switch (node->data.bin_op.operatorr) {
		case '+':
			/* (f+g)' = f' + g' */
			return node_create('b', '+', node_derivative(node->data.bin_op.left, name, symbol_table),
					   node_derivative(node->data.bin_op.right, name, symbol_table));
		case '-':
			/* (f-g)' = f' - g' */
			return node_create('b', '-', node_derivative(node->data.bin_op.left, name, symbol_table),
					   node_derivative(node->data.bin_op.right, name, symbol_table));
		case '*':
			/* (f*g)' = f'*g + f*g' */
			return node_create('b', '+',
					   node_create('b', '*', node_derivative(node->data.bin_op.left, name, symbol_table),
						       node_copy(node->data.bin_op.right)),
					   node_create('b', '*', node_copy(node->data.bin_op.left),
						       node_derivative(node->data.bin_op.right, name, symbol_table)));
		case '/':
			/* (f/g)' = (f'*g - f*g') / g^2 */
			return node_create('b', '/',
					   node_create('b', '-',
						       node_create('b', '*', node_derivative(node->data.bin_op.left, name, symbol_table),
								   node_copy(node->data.bin_op.right)),
						       node_create('b', '*', node_copy(node->data.bin_op.left),
								   node_derivative(node->data.bin_op.right, name, symbol_table))),
					   node_create('b', '^', node_copy(node->data.bin_op.right), node_create('c', 2.0)));
		case '^':
			/* Constant exponent: (f^n)' = n * f' * f^(n-1) */
			if (node->data.bin_op.right->type == 'c')
				return node_create('b', '*',
						   node_create('b', '*', node_create('c', node->data.bin_op.right->data.constant),
							       node_derivative(node->data.bin_op.left, name, symbol_table)),
						   node_create('b', '^', node_copy(node->data.bin_op.left),
							       node_create('c', node->data.bin_op.right->data.constant - 1.0)));
			/*
			 * General case, logarithmic differentiation:
			 * (f^g)' = f^g * (g * log f)'
			 * The temporary g*log(f) tree is freed once differentiated.
			 */
			else {
				Node		*log_node, *derivative;

				log_node = node_create('b', '*', node_copy(node->data.bin_op.right),
						       node_create('f', symbol_table_lookup(symbol_table, "Log"),
								   node_copy(node->data.bin_op.left)));
				derivative = node_create('b', '*', node_copy(node),
							 node_derivative(log_node, name, symbol_table));
				node_destroy(log_node);
				return derivative;
			}
		}
		break;
	}
	return NULL;
}
/*
 * Return the number of characters in the textual form of the subtree
 * rooted at node, using the layout produced by node_write() (terminating
 * NUL not counted). Returns 0 for an unrecognised node type.
 */
int
node_calculate_length(Node * node)
{
	char		digits[1024];	/* Scratch buffer for a "%g" constant. */

	if (node->type == 'c') {
		sprintf(digits, "%g", node->data.constant);
		/* Negative constants are wrapped as "(" value ")". */
		return (int) strlen(digits) + (node->data.constant < 0 ? 2 : 0);
	}
	if (node->type == 'v')
		return (int) strlen(node->data.variable->name);
	if (node->type == 'f')
		/* name "(" argument ")" */
		return (int) strlen(node->data.function.record->name) +
			node_calculate_length(node->data.function.child) + 2;
	if (node->type == 'u')
		/* "(" operator operand ")" */
		return node_calculate_length(node->data.un_op.child) + 3;
	if (node->type == 'b')
		/* "(" left operator right ")" */
		return node_calculate_length(node->data.bin_op.left) +
			node_calculate_length(node->data.bin_op.right) + 3;
	return 0;
}
/*
 * Print the subtree rooted at node into string as fully parenthesised
 * infix text, so no operator precedence is needed to read it back.
 * No bounds checking is done: string must have room for
 * node_calculate_length(node) characters plus the terminating NUL.
 * Nothing is written for an unrecognised node type.
 */
void
node_write(Node * node, char *string)
{
	char		*end = string;	/* Where the next piece of text goes. */

	switch (node->type) {
	case 'c':
		/* Negative constants are parenthesised, e.g. "(-2)". */
		sprintf(end, node->data.constant < 0 ? "(%g)" : "%g", node->data.constant);
		break;
	case 'v':
		sprintf(end, "%s", node->data.variable->name);
		break;
	case 'f':
		end += sprintf(end, "%s(", node->data.function.record->name);
		node_write(node->data.function.child, end);
		end += strlen(end);
		sprintf(end, ")");
		break;
	case 'u':
		sprintf(end, "(%c", node->data.un_op.operatorr);
		end += strlen(end);
		node_write(node->data.un_op.child, end);
		end += strlen(end);
		sprintf(end, ")");
		break;
	case 'b':
		end += sprintf(end, "(");
		node_write(node->data.bin_op.left, end);
		end += strlen(end);
		sprintf(end, "%c", node->data.bin_op.operatorr);
		end += strlen(end);
		node_write(node->data.bin_op.right, end);
		end += strlen(end);
		sprintf(end, ")");
		break;
	}
}