/* * Copyright (C) 2003-2010 Dynare Team * * This file is part of Dynare. * * Dynare is free software: you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation, either version 3 of the License, or * (at your option) any later version. * * Dynare is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public License * along with Dynare. If not, see . */ #include #include #include #include "DataTree.hh" DataTree::DataTree(SymbolTable &symbol_table_arg, NumericalConstants &num_constants_arg, ExternalFunctionsTable &external_functions_table_arg) : symbol_table(symbol_table_arg), num_constants(num_constants_arg), external_functions_table(external_functions_table_arg), node_counter(0) { Zero = AddNumConstant("0"); One = AddNumConstant("1"); Two = AddNumConstant("2"); MinusOne = AddUMinus(One); NaN = AddNumConstant("NaN"); Infinity = AddNumConstant("Inf"); MinusInfinity = AddUMinus(Infinity); Pi = AddNumConstant("3.141592653589793"); } DataTree::~DataTree() { for (node_list_t::iterator it = node_list.begin(); it != node_list.end(); it++) delete *it; } expr_t DataTree::AddNumConstant(const string &value) { int id = num_constants.AddConstant(value); num_const_node_map_t::iterator it = num_const_node_map.find(id); if (it != num_const_node_map.end()) return it->second; else return new NumConstNode(*this, id); } VariableNode * DataTree::AddVariableInternal(int symb_id, int lag) { variable_node_map_t::iterator it = variable_node_map.find(make_pair(symb_id, lag)); if (it != variable_node_map.end()) return it->second; else return new VariableNode(*this, symb_id, lag); } VariableNode * DataTree::AddVariable(int symb_id, int lag) { assert(lag == 0); return AddVariableInternal(symb_id, lag); } expr_t DataTree::AddPlus(expr_t iArg1, expr_t iArg2) { if (iArg1 != Zero && iArg2 != Zero) { // Simplify x+(-y) in x-y UnaryOpNode *uarg2 = dynamic_cast(iArg2); if (uarg2 != NULL && uarg2->get_op_code() == oUminus) return AddMinus(iArg1, uarg2->get_arg()); // To treat commutativity of "+" // Nodes iArg1 and iArg2 are sorted by index if (iArg1->idx > iArg2->idx) { expr_t tmp = iArg1; iArg1 = iArg2; iArg2 = tmp; } return AddBinaryOp(iArg1, oPlus, iArg2); } else if (iArg1 != Zero) return iArg1; else if (iArg2 != Zero) return iArg2; else return Zero; } expr_t DataTree::AddMinus(expr_t iArg1, expr_t iArg2) { if (iArg2 == Zero) return iArg1; if (iArg1 == Zero) return AddUMinus(iArg2); if (iArg1 == iArg2) return Zero; return AddBinaryOp(iArg1, oMinus, iArg2); } expr_t DataTree::AddUMinus(expr_t iArg1) { if (iArg1 != Zero) { // Simplify -(-x) in x UnaryOpNode *uarg = dynamic_cast(iArg1); if (uarg != NULL && uarg->get_op_code() == oUminus) return uarg->get_arg(); return AddUnaryOp(oUminus, iArg1); } else return Zero; } expr_t DataTree::AddTimes(expr_t iArg1, expr_t iArg2) { if (iArg1 == MinusOne) return AddUMinus(iArg2); else if (iArg2 == MinusOne) return AddUMinus(iArg1); else if (iArg1 != Zero && iArg1 != One && iArg2 != Zero && iArg2 != One) { // To treat commutativity of "*" // Nodes iArg1 and iArg2 are sorted by index if (iArg1->idx > iArg2->idx) { expr_t tmp = iArg1; iArg1 = iArg2; iArg2 = tmp; } return AddBinaryOp(iArg1, oTimes, iArg2); } else if (iArg1 != Zero && iArg1 != One && iArg2 == One) return iArg1; else if (iArg2 != Zero && iArg2 != One && iArg1 == One) return iArg2; else if (iArg2 == One && iArg1 == One) return One; else return Zero; } expr_t DataTree::AddDivide(expr_t iArg1, expr_t iArg2) { if (iArg2 == One) return iArg1; // This test should be before the next two, otherwise 0/0 won't be rejected if (iArg2 == Zero) { cerr << "ERROR: Division by zero!" << endl; exit(EXIT_FAILURE); } if (iArg1 == Zero) return Zero; if (iArg1 == iArg2) return One; return AddBinaryOp(iArg1, oDivide, iArg2); } expr_t DataTree::AddLess(expr_t iArg1, expr_t iArg2) { return AddBinaryOp(iArg1, oLess, iArg2); } expr_t DataTree::AddGreater(expr_t iArg1, expr_t iArg2) { return AddBinaryOp(iArg1, oGreater, iArg2); } expr_t DataTree::AddLessEqual(expr_t iArg1, expr_t iArg2) { return AddBinaryOp(iArg1, oLessEqual, iArg2); } expr_t DataTree::AddGreaterEqual(expr_t iArg1, expr_t iArg2) { return AddBinaryOp(iArg1, oGreaterEqual, iArg2); } expr_t DataTree::AddEqualEqual(expr_t iArg1, expr_t iArg2) { return AddBinaryOp(iArg1, oEqualEqual, iArg2); } expr_t DataTree::AddDifferent(expr_t iArg1, expr_t iArg2) { return AddBinaryOp(iArg1, oDifferent, iArg2); } expr_t DataTree::AddPower(expr_t iArg1, expr_t iArg2) { if (iArg1 != Zero && iArg2 != Zero && iArg2 != One) return AddBinaryOp(iArg1, oPower, iArg2); else if (iArg2 == One) return iArg1; else if (iArg2 == Zero) return One; else return Zero; } expr_t DataTree::AddExp(expr_t iArg1) { if (iArg1 != Zero) return AddUnaryOp(oExp, iArg1); else return One; } expr_t DataTree::AddLog(expr_t iArg1) { if (iArg1 != Zero && iArg1 != One) return AddUnaryOp(oLog, iArg1); else if (iArg1 == One) return Zero; else { cerr << "ERROR: log(0) not defined!" << endl; exit(EXIT_FAILURE); } } expr_t DataTree::AddLog10(expr_t iArg1) { if (iArg1 != Zero && iArg1 != One) return AddUnaryOp(oLog10, iArg1); else if (iArg1 == One) return Zero; else { cerr << "ERROR: log10(0) not defined!" << endl; exit(EXIT_FAILURE); } } expr_t DataTree::AddCos(expr_t iArg1) { if (iArg1 != Zero) return AddUnaryOp(oCos, iArg1); else return One; } expr_t DataTree::AddSin(expr_t iArg1) { if (iArg1 != Zero) return AddUnaryOp(oSin, iArg1); else return Zero; } expr_t DataTree::AddTan(expr_t iArg1) { if (iArg1 != Zero) return AddUnaryOp(oTan, iArg1); else return Zero; } expr_t DataTree::AddAcos(expr_t iArg1) { if (iArg1 != One) return AddUnaryOp(oAcos, iArg1); else return Zero; } expr_t DataTree::AddAsin(expr_t iArg1) { if (iArg1 != Zero) return AddUnaryOp(oAsin, iArg1); else return Zero; } expr_t DataTree::AddAtan(expr_t iArg1) { if (iArg1 != Zero) return AddUnaryOp(oAtan, iArg1); else return Zero; } expr_t DataTree::AddCosh(expr_t iArg1) { if (iArg1 != Zero) return AddUnaryOp(oCosh, iArg1); else return One; } expr_t DataTree::AddSinh(expr_t iArg1) { if (iArg1 != Zero) return AddUnaryOp(oSinh, iArg1); else return Zero; } expr_t DataTree::AddTanh(expr_t iArg1) { if (iArg1 != Zero) return AddUnaryOp(oTanh, iArg1); else return Zero; } expr_t DataTree::AddAcosh(expr_t iArg1) { if (iArg1 != One) return AddUnaryOp(oAcosh, iArg1); else return Zero; } expr_t DataTree::AddAsinh(expr_t iArg1) { if (iArg1 != Zero) return AddUnaryOp(oAsinh, iArg1); else return Zero; } expr_t DataTree::AddAtanh(expr_t iArg1) { if (iArg1 != Zero) return AddUnaryOp(oAtanh, iArg1); else return Zero; } expr_t DataTree::AddSqrt(expr_t iArg1) { if (iArg1 != Zero) return AddUnaryOp(oSqrt, iArg1); else return Zero; } expr_t DataTree::AddErf(expr_t iArg1) { if (iArg1 != Zero) return AddUnaryOp(oErf, iArg1); else return Zero; } expr_t DataTree::AddMax(expr_t iArg1, expr_t iArg2) { return AddBinaryOp(iArg1, oMax, iArg2); } expr_t DataTree::AddMin(expr_t iArg1, expr_t iArg2) { return AddBinaryOp(iArg1, oMin, iArg2); } expr_t DataTree::AddNormcdf(expr_t iArg1, expr_t iArg2, expr_t iArg3) { return AddTrinaryOp(iArg1, oNormcdf, iArg2, iArg3); } expr_t DataTree::AddNormpdf(expr_t iArg1, expr_t iArg2, expr_t iArg3) { return AddTrinaryOp(iArg1, oNormpdf, iArg2, iArg3); } expr_t DataTree::AddSteadyState(expr_t iArg1) { return AddUnaryOp(oSteadyState, iArg1); } expr_t DataTree::AddExpectation(int iArg1, expr_t iArg2) { return AddUnaryOp(oExpectation, iArg2, iArg1); } expr_t DataTree::AddExpectation(string *iArg1, expr_t iArg2) { return AddUnaryOp(oExpectation, iArg2, 0, *iArg1); } expr_t DataTree::AddEqual(expr_t iArg1, expr_t iArg2) { return AddBinaryOp(iArg1, oEqual, iArg2); } void DataTree::AddLocalVariable(int symb_id, expr_t value) throw (LocalVariableException) { assert(symbol_table.getType(symb_id) == eModelLocalVariable); // Throw an exception if symbol already declared map::iterator it = local_variables_table.find(symb_id); if (it != local_variables_table.end()) throw LocalVariableException(symbol_table.getName(symb_id)); local_variables_table[symb_id] = value; } expr_t DataTree::AddExternalFunction(int symb_id, const vector &arguments) { assert(symbol_table.getType(symb_id) == eExternalFunction); external_function_node_map_t::iterator it = external_function_node_map.find(make_pair(arguments, symb_id)); if (it != external_function_node_map.end()) return it->second; return new ExternalFunctionNode(*this, symb_id, arguments); } expr_t DataTree::AddFirstDerivExternalFunctionNode(int top_level_symb_id, const vector &arguments, int input_index) { assert(symbol_table.getType(top_level_symb_id) == eExternalFunction); first_deriv_external_function_node_map_t::iterator it = first_deriv_external_function_node_map.find(make_pair(make_pair(arguments, input_index), top_level_symb_id)); if (it != first_deriv_external_function_node_map.end()) return it->second; return new FirstDerivExternalFunctionNode(*this, top_level_symb_id, arguments, input_index); } expr_t DataTree::AddSecondDerivExternalFunctionNode(int top_level_symb_id, const vector &arguments, int input_index1, int input_index2) { assert(symbol_table.getType(top_level_symb_id) == eExternalFunction); second_deriv_external_function_node_map_t::iterator it = second_deriv_external_function_node_map.find(make_pair(make_pair(arguments, make_pair(input_index1, input_index2)), top_level_symb_id)); if (it != second_deriv_external_function_node_map.end()) return it->second; return new SecondDerivExternalFunctionNode(*this, top_level_symb_id, arguments, input_index1, input_index2); } bool DataTree::isSymbolUsed(int symb_id) const { for (variable_node_map_t::const_iterator it = variable_node_map.begin(); it != variable_node_map.end(); it++) if (it->first.first == symb_id) return true; if (local_variables_table.find(symb_id) != local_variables_table.end()) return true; return false; } int DataTree::getDerivID(int symb_id, int lag) const throw (UnknownDerivIDException) { throw UnknownDerivIDException(); } int DataTree::getDynJacobianCol(int deriv_id) const throw (UnknownDerivIDException) { throw UnknownDerivIDException(); } bool DataTree::isUnaryOpUsed(UnaryOpcode opcode) const { for (unary_op_node_map_t::const_iterator it = unary_op_node_map.begin(); it != unary_op_node_map.end(); it++) if (it->first.second == opcode) return true; return false; } bool DataTree::isBinaryOpUsed(BinaryOpcode opcode) const { for (binary_op_node_map_t::const_iterator it = binary_op_node_map.begin(); it != binary_op_node_map.end(); it++) if (it->first.second == opcode) return true; return false; } bool DataTree::isTrinaryOpUsed(TrinaryOpcode opcode) const { for (trinary_op_node_map_t::const_iterator it = trinary_op_node_map.begin(); it != trinary_op_node_map.end(); it++) if (it->first.second == opcode) return true; return false; } bool DataTree::isExternalFunctionUsed(int symb_id) const { for (external_function_node_map_t::const_iterator it = external_function_node_map.begin(); it != external_function_node_map.end(); it++) if (it->first.second == symb_id) return true; return false; } bool DataTree::isFirstDerivExternalFunctionUsed(int symb_id) const { for (first_deriv_external_function_node_map_t::const_iterator it = first_deriv_external_function_node_map.begin(); it != first_deriv_external_function_node_map.end(); it++) if (it->first.second == symb_id) return true; return false; } bool DataTree::isSecondDerivExternalFunctionUsed(int symb_id) const { for (second_deriv_external_function_node_map_t::const_iterator it = second_deriv_external_function_node_map.begin(); it != second_deriv_external_function_node_map.end(); it++) if (it->first.second == symb_id) return true; return false; }