From daa8d016868f8502bd708d39fd39fba51141cdcb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Villemot?= Date: Thu, 2 Apr 2020 14:36:26 +0200 Subject: [PATCH] Complete rewrite of the equation normalization symbolic engine --- src/ExprNode.cc | 634 ++++++++++++++++------------------------------- src/ExprNode.hh | 41 ++- src/ModelTree.cc | 20 +- 3 files changed, 250 insertions(+), 445 deletions(-) diff --git a/src/ExprNode.cc b/src/ExprNode.cc index 66e111a1..218f0793 100644 --- a/src/ExprNode.cc +++ b/src/ExprNode.cc @@ -171,13 +171,6 @@ ExprNode::computeTemporaryTerms(map &reference_count, // Nothing to do for a terminal node } -pair -ExprNode::normalizeEquation(int var_endo, vector> &List_of_Op_RHS) const -{ - /* nothing to do */ - return { 0, nullptr }; -} - void ExprNode::writeOutput(ostream &output) const { @@ -497,11 +490,16 @@ NumConstNode::collectDynamicVariables(SymbolType type_arg, set> & { } -pair -NumConstNode::normalizeEquation(int var_endo, vector> &List_of_Op_RHS) const +void +NumConstNode::computeSubExprContainingVariable(int symb_id, int lag, set &contain_var) const { - /* return the numercial constant */ - return { 0, datatree.AddNonNegativeConstant(datatree.num_constants.get(id)) }; +} + +BinaryOpNode * +NumConstNode::normalizeEquationHelper(const set &contain_var, expr_t rhs) const +{ + cerr << "NumConstNode::normalizeEquation: this should not happen" << endl; + exit(EXIT_FAILURE); } expr_t @@ -1360,36 +1358,20 @@ VariableNode::collectDynamicVariables(SymbolType type_arg, set> & datatree.getLocalVariable(symb_id)->collectDynamicVariables(type_arg, result); } -pair -VariableNode::normalizeEquation(int var_endo, vector> &List_of_Op_RHS) const +void +VariableNode::computeSubExprContainingVariable(int symb_id_arg, int lag_arg, set &contain_var) const { - /* The equation has to be normalized with respect to the current endogenous variable ascribed to it. - The two input arguments are : - - The ID of the endogenous variable associated to the equation. - - The list of operators and operands needed to normalize the equation* + if (symb_id == symb_id_arg && lag == lag_arg) + contain_var.insert(const_cast(this)); +} - The pair returned by NormalizeEquation is composed of - - a flag indicating if the expression returned contains (flag = 1) or not (flag = 0) - the endogenous variable related to the equation. - If the expression contains more than one occurence of the associated endogenous variable, - the flag is equal to 2. - - an expression equal to the RHS if flag = 0 and equal to NULL elsewhere - */ - if (get_type() == SymbolType::endogenous) - { - if (datatree.symbol_table.getTypeSpecificID(symb_id) == var_endo && lag == 0) - /* the endogenous variable */ - return { 1, nullptr }; - else - return { 0, datatree.AddVariable(symb_id, lag) }; - } - else - { - if (get_type() == SymbolType::parameter) - return { 0, datatree.AddVariable(symb_id, 0) }; - else - return { 0, datatree.AddVariable(symb_id, lag) }; - } +BinaryOpNode * +VariableNode::normalizeEquationHelper(const set &contain_var, expr_t rhs) const +{ + assert(contain_var.count(const_cast(this)) > 0); + + // This the LHS variable: we have finished the normalization + return datatree.AddEqual(const_cast(this), rhs); } expr_t @@ -3073,144 +3055,80 @@ UnaryOpNode::collectDynamicVariables(SymbolType type_arg, set> &r arg->collectDynamicVariables(type_arg, result); } -pair -UnaryOpNode::normalizeEquation(int var_endo, vector> &List_of_Op_RHS) const +void +UnaryOpNode::computeSubExprContainingVariable(int symb_id, int lag, set &contain_var) const { - pair res = arg->normalizeEquation(var_endo, List_of_Op_RHS); - int is_endogenous_present = res.first; - expr_t New_expr_t = res.second; + arg->computeSubExprContainingVariable(symb_id, lag, contain_var); + if (contain_var.count(arg) > 0) + contain_var.insert(const_cast(this)); +} - if (is_endogenous_present == 2) /* The equation could not be normalized and the process is given-up*/ - return { 2, nullptr }; - else if (is_endogenous_present) /* The argument of the function contains the current values of - the endogenous variable associated to the equation. - In order to normalized, we have to apply the invert function to the RHS.*/ +BinaryOpNode * +UnaryOpNode::normalizeEquationHelper(const set &contain_var, expr_t rhs) const +{ + assert(contain_var.count(const_cast(this)) > 0); + + switch (op_code) { - switch (op_code) - { - case UnaryOpcode::uminus: - List_of_Op_RHS.emplace_back(static_cast(UnaryOpcode::uminus), nullptr, nullptr); - return { 1, nullptr }; - case UnaryOpcode::exp: - List_of_Op_RHS.emplace_back(static_cast(UnaryOpcode::log), nullptr, nullptr); - return { 1, nullptr }; - case UnaryOpcode::log: - List_of_Op_RHS.emplace_back(static_cast(UnaryOpcode::exp), nullptr, nullptr); - return { 1, nullptr }; - case UnaryOpcode::log10: - List_of_Op_RHS.emplace_back(static_cast(BinaryOpcode::power), nullptr, datatree.AddNonNegativeConstant("10")); - return { 1, nullptr }; - case UnaryOpcode::cos: - List_of_Op_RHS.emplace_back(static_cast(UnaryOpcode::acos), nullptr, nullptr); - return { 1, nullptr }; - case UnaryOpcode::sin: - List_of_Op_RHS.emplace_back(static_cast(UnaryOpcode::asin), nullptr, nullptr); - return { 1, nullptr }; - case UnaryOpcode::tan: - List_of_Op_RHS.emplace_back(static_cast(UnaryOpcode::atan), nullptr, nullptr); - return { 1, nullptr }; - case UnaryOpcode::acos: - List_of_Op_RHS.emplace_back(static_cast(UnaryOpcode::cos), nullptr, nullptr); - return { 1, nullptr }; - case UnaryOpcode::asin: - List_of_Op_RHS.emplace_back(static_cast(UnaryOpcode::sin), nullptr, nullptr); - return { 1, nullptr }; - case UnaryOpcode::atan: - List_of_Op_RHS.emplace_back(static_cast(UnaryOpcode::tan), nullptr, nullptr); - return { 1, nullptr }; - case UnaryOpcode::cosh: - List_of_Op_RHS.emplace_back(static_cast(UnaryOpcode::acosh), nullptr, nullptr); - return { 1, nullptr }; - case UnaryOpcode::sinh: - List_of_Op_RHS.emplace_back(static_cast(UnaryOpcode::asinh), nullptr, nullptr); - return { 1, nullptr }; - case UnaryOpcode::tanh: - List_of_Op_RHS.emplace_back(static_cast(UnaryOpcode::atanh), nullptr, nullptr); - return { 1, nullptr }; - case UnaryOpcode::acosh: - List_of_Op_RHS.emplace_back(static_cast(UnaryOpcode::cosh), nullptr, nullptr); - return { 1, nullptr }; - case UnaryOpcode::asinh: - List_of_Op_RHS.emplace_back(static_cast(UnaryOpcode::sinh), nullptr, nullptr); - return { 1, nullptr }; - case UnaryOpcode::atanh: - List_of_Op_RHS.emplace_back(static_cast(UnaryOpcode::tanh), nullptr, nullptr); - return { 1, nullptr }; - case UnaryOpcode::sqrt: - List_of_Op_RHS.emplace_back(static_cast(BinaryOpcode::power), nullptr, datatree.Two); - return { 1, nullptr }; - case UnaryOpcode::cbrt: - List_of_Op_RHS.emplace_back(static_cast(BinaryOpcode::power), nullptr, datatree.Three); - return { 1, nullptr }; - case UnaryOpcode::abs: - return { 2, nullptr }; - case UnaryOpcode::sign: - return { 2, nullptr }; - case UnaryOpcode::steadyState: - return { 2, nullptr }; - case UnaryOpcode::erf: - return { 2, nullptr }; - default: - cerr << "Unary operator not handled during the normalization process" << endl; - return { 2, nullptr }; // Could not be normalized - } + case UnaryOpcode::uminus: + rhs = datatree.AddUMinus(rhs); + break; + case UnaryOpcode::exp: + rhs = datatree.AddLog(rhs); + break; + case UnaryOpcode::log: + rhs = datatree.AddExp(rhs); + break; + case UnaryOpcode::log10: + rhs = datatree.AddPower(datatree.AddNonNegativeConstant("10"), rhs); + break; + case UnaryOpcode::cos: + rhs = datatree.AddAcos(rhs); + break; + case UnaryOpcode::sin: + rhs = datatree.AddAsin(rhs); + break; + case UnaryOpcode::tan: + rhs = datatree.AddAtan(rhs); + break; + case UnaryOpcode::acos: + rhs = datatree.AddCos(rhs); + break; + case UnaryOpcode::asin: + rhs = datatree.AddSin(rhs); + break; + case UnaryOpcode::atan: + rhs = datatree.AddTan(rhs); + break; + case UnaryOpcode::cosh: + rhs = datatree.AddAcosh(rhs); + break; + case UnaryOpcode::sinh: + rhs = datatree.AddAsinh(rhs); + break; + case UnaryOpcode::tanh: + rhs = datatree.AddAtanh(rhs); + break; + case UnaryOpcode::acosh: + rhs = datatree.AddCosh(rhs); + break; + case UnaryOpcode::asinh: + rhs = datatree.AddSinh(rhs); + break; + case UnaryOpcode::atanh: + rhs = datatree.AddTanh(rhs); + break; + case UnaryOpcode::sqrt: + rhs = datatree.AddPower(rhs, datatree.Two); + break; + case UnaryOpcode::cbrt: + rhs = datatree.AddPower(rhs, datatree.Three); + break; + default: + throw NormalizationFailed(); } - else - { /* If the argument of the function do not contain the current values of the endogenous variable - related to the equation, the function with its argument is stored in the RHS*/ - switch (op_code) - { - case UnaryOpcode::uminus: - return { 0, datatree.AddUMinus(New_expr_t) }; - case UnaryOpcode::exp: - return { 0, datatree.AddExp(New_expr_t) }; - case UnaryOpcode::log: - return { 0, datatree.AddLog(New_expr_t) }; - case UnaryOpcode::log10: - return { 0, datatree.AddLog10(New_expr_t) }; - case UnaryOpcode::cos: - return { 0, datatree.AddCos(New_expr_t) }; - case UnaryOpcode::sin: - return { 0, datatree.AddSin(New_expr_t) }; - case UnaryOpcode::tan: - return { 0, datatree.AddTan(New_expr_t) }; - case UnaryOpcode::acos: - return { 0, datatree.AddAcos(New_expr_t) }; - case UnaryOpcode::asin: - return { 0, datatree.AddAsin(New_expr_t) }; - case UnaryOpcode::atan: - return { 0, datatree.AddAtan(New_expr_t) }; - case UnaryOpcode::cosh: - return { 0, datatree.AddCosh(New_expr_t) }; - case UnaryOpcode::sinh: - return { 0, datatree.AddSinh(New_expr_t) }; - case UnaryOpcode::tanh: - return { 0, datatree.AddTanh(New_expr_t) }; - case UnaryOpcode::acosh: - return { 0, datatree.AddAcosh(New_expr_t) }; - case UnaryOpcode::asinh: - return { 0, datatree.AddAsinh(New_expr_t) }; - case UnaryOpcode::atanh: - return { 0, datatree.AddAtanh(New_expr_t) }; - case UnaryOpcode::sqrt: - return { 0, datatree.AddSqrt(New_expr_t) }; - case UnaryOpcode::cbrt: - return { 0, datatree.AddCbrt(New_expr_t) }; - case UnaryOpcode::abs: - return { 0, datatree.AddAbs(New_expr_t) }; - case UnaryOpcode::sign: - return { 0, datatree.AddSign(New_expr_t) }; - case UnaryOpcode::steadyState: - return { 0, datatree.AddSteadyState(New_expr_t) }; - case UnaryOpcode::erf: - return { 0, datatree.AddErf(New_expr_t) }; - default: - cerr << "Unary operator not handled during the normalization process" << endl; - return { 2, nullptr }; // Could not be normalized - } - } - cerr << "UnaryOpNode::normalizeEquation: impossible case" << endl; - exit(EXIT_FAILURE); + + return arg->normalizeEquationHelper(contain_var, rhs); } expr_t @@ -4852,24 +4770,19 @@ BinaryOpNode::collectDynamicVariables(SymbolType type_arg, set> & expr_t BinaryOpNode::Compute_RHS(expr_t arg1, expr_t arg2, int op, int op_type) const { - temporary_terms_t temp; switch (op_type) { case 0: /*Unary Operator*/ switch (static_cast(op)) { case UnaryOpcode::uminus: - return (datatree.AddUMinus(arg1)); - break; + return datatree.AddUMinus(arg1); case UnaryOpcode::exp: - return (datatree.AddExp(arg1)); - break; + return datatree.AddExp(arg1); case UnaryOpcode::log: - return (datatree.AddLog(arg1)); - break; + return datatree.AddLog(arg1); case UnaryOpcode::log10: - return (datatree.AddLog10(arg1)); - break; + return datatree.AddLog10(arg1); default: cerr << "BinaryOpNode::Compute_RHS: case not handled"; exit(EXIT_FAILURE); @@ -4879,20 +4792,15 @@ BinaryOpNode::Compute_RHS(expr_t arg1, expr_t arg2, int op, int op_type) const switch (static_cast(op)) { case BinaryOpcode::plus: - return (datatree.AddPlus(arg1, arg2)); - break; + return datatree.AddPlus(arg1, arg2); case BinaryOpcode::minus: - return (datatree.AddMinus(arg1, arg2)); - break; + return datatree.AddMinus(arg1, arg2); case BinaryOpcode::times: - return (datatree.AddTimes(arg1, arg2)); - break; + return datatree.AddTimes(arg1, arg2); case BinaryOpcode::divide: - return (datatree.AddDivide(arg1, arg2)); - break; + return datatree.AddDivide(arg1, arg2); case BinaryOpcode::power: - return (datatree.AddPower(arg1, arg2)); - break; + return datatree.AddPower(arg1, arg2); default: cerr << "BinaryOpNode::Compute_RHS: case not handled"; exit(EXIT_FAILURE); @@ -4902,224 +4810,90 @@ BinaryOpNode::Compute_RHS(expr_t arg1, expr_t arg2, int op, int op_type) const return nullptr; } -pair -BinaryOpNode::normalizeEquation(int var_endo, vector> &List_of_Op_RHS) const +void +BinaryOpNode::computeSubExprContainingVariable(int symb_id, int lag, set &contain_var) const { - /* Checks if the current value of the endogenous variable related to the equation - is present in the arguments of the binary operator. */ - vector> List_of_Op_RHS1, List_of_Op_RHS2; - pair res = arg1->normalizeEquation(var_endo, List_of_Op_RHS1); - int is_endogenous_present_1 = res.first; - expr_t expr_t_1 = res.second; + arg1->computeSubExprContainingVariable(symb_id, lag, contain_var); + arg2->computeSubExprContainingVariable(symb_id, lag, contain_var); + if (contain_var.count(arg1) > 0 || contain_var.count(arg2) > 0) + contain_var.insert(const_cast(this)); +} - res = arg2->normalizeEquation(var_endo, List_of_Op_RHS2); - int is_endogenous_present_2 = res.first; - expr_t expr_t_2 = res.second; +BinaryOpNode * +BinaryOpNode::normalizeEquationHelper(const set &contain_var, expr_t rhs) const +{ + assert(contain_var.count(const_cast(this)) > 0); + + bool arg1_contains_var = contain_var.count(arg1) > 0; + bool arg2_contains_var = contain_var.count(arg2) > 0; + assert(arg1_contains_var || arg2_contains_var); + + if (arg1_contains_var && arg2_contains_var) + throw NormalizationFailed(); - /* If the two expressions contains the current value of the endogenous variable associated to the equation - the equation could not be normalized and the process is given-up.*/ - if (is_endogenous_present_1 == 2 || is_endogenous_present_2 == 2) - return { 2, nullptr }; - else if (is_endogenous_present_1 && is_endogenous_present_2) - return { 2, nullptr }; - else if (is_endogenous_present_1) /*If the current values of the endogenous variable associated to the equation - is present only in the first operand of the expression, we try to normalize the equation*/ - { - if (op_code == BinaryOpcode::equal) /* The end of the normalization process : - All the operations needed to normalize the equation are applied. */ - while (!List_of_Op_RHS1.empty()) - { - tuple it = List_of_Op_RHS1.back(); - List_of_Op_RHS1.pop_back(); - if (get<1>(it) && !get<2>(it)) /*Binary operator*/ - expr_t_2 = Compute_RHS(expr_t_2, static_cast(get<1>(it)), get<0>(it), 1); - else if (get<2>(it) && !get<1>(it)) /*Binary operator*/ - expr_t_2 = Compute_RHS(get<2>(it), expr_t_2, get<0>(it), 1); - else if (get<2>(it) && get<1>(it)) /*Binary operator*/ - expr_t_2 = Compute_RHS(get<1>(it), get<2>(it), get<0>(it), 1); - else /*Unary operator*/ - expr_t_2 = Compute_RHS(static_cast(expr_t_2), static_cast(get<1>(it)), get<0>(it), 0); - } - else - List_of_Op_RHS = List_of_Op_RHS1; - } - else if (is_endogenous_present_2) - { - if (op_code == BinaryOpcode::equal) - while (!List_of_Op_RHS2.empty()) - { - tuple it = List_of_Op_RHS2.back(); - List_of_Op_RHS2.pop_back(); - if (get<1>(it) && !get<2>(it)) /*Binary operator*/ - expr_t_1 = Compute_RHS(static_cast(expr_t_1), static_cast(get<1>(it)), get<0>(it), 1); - else if (get<2>(it) && !get<1>(it)) /*Binary operator*/ - expr_t_1 = Compute_RHS(static_cast(get<2>(it)), static_cast(expr_t_1), get<0>(it), 1); - else if (get<2>(it) && get<1>(it)) /*Binary operator*/ - expr_t_1 = Compute_RHS(get<1>(it), get<2>(it), get<0>(it), 1); - else - expr_t_1 = Compute_RHS(static_cast(expr_t_1), static_cast(get<1>(it)), get<0>(it), 0); - } - else - List_of_Op_RHS = List_of_Op_RHS2; - } switch (op_code) { case BinaryOpcode::plus: - if (!is_endogenous_present_1 && !is_endogenous_present_2) - return { 0, datatree.AddPlus(expr_t_1, expr_t_2) }; - else if (is_endogenous_present_1 && is_endogenous_present_2) - return { 2, nullptr }; - else if (!is_endogenous_present_1 && is_endogenous_present_2) - { - List_of_Op_RHS.emplace_back(static_cast(BinaryOpcode::minus), expr_t_1, nullptr); - return { 1, expr_t_1 }; - } - else if (is_endogenous_present_1 && !is_endogenous_present_2) - { - List_of_Op_RHS.emplace_back(static_cast(BinaryOpcode::minus), expr_t_2, nullptr); - return { 1, expr_t_2 }; - } + if (arg1_contains_var) + rhs = datatree.AddMinus(rhs, arg2); + else + rhs = datatree.AddMinus(rhs, arg1); break; case BinaryOpcode::minus: - if (!is_endogenous_present_1 && !is_endogenous_present_2) - return { 0, datatree.AddMinus(expr_t_1, expr_t_2) }; - else if (is_endogenous_present_1 && is_endogenous_present_2) - return { 2, nullptr }; - else if (!is_endogenous_present_1 && is_endogenous_present_2) - { - List_of_Op_RHS.emplace_back(static_cast(UnaryOpcode::uminus), nullptr, nullptr); - List_of_Op_RHS.emplace_back(static_cast(BinaryOpcode::minus), expr_t_1, nullptr); - return { 1, expr_t_1 }; - } - else if (is_endogenous_present_1 && !is_endogenous_present_2) - { - List_of_Op_RHS.emplace_back(static_cast(BinaryOpcode::plus), expr_t_2, nullptr); - return { 1, datatree.AddUMinus(expr_t_2) }; - } + if (arg1_contains_var) + rhs = datatree.AddPlus(rhs, arg2); + else + rhs = datatree.AddMinus(arg1, rhs); break; case BinaryOpcode::times: - if (!is_endogenous_present_1 && !is_endogenous_present_2) - return { 0, datatree.AddTimes(expr_t_1, expr_t_2) }; - else if (!is_endogenous_present_1 && is_endogenous_present_2) - { - List_of_Op_RHS.emplace_back(static_cast(BinaryOpcode::divide), expr_t_1, nullptr); - return { 1, expr_t_1 }; - } - else if (is_endogenous_present_1 && !is_endogenous_present_2) - { - List_of_Op_RHS.emplace_back(static_cast(BinaryOpcode::divide), expr_t_2, nullptr); - return { 1, expr_t_2 }; - } + if (arg1_contains_var) + rhs = datatree.AddDivide(rhs, arg2); else - return { 2, nullptr }; + rhs = datatree.AddDivide(rhs, arg1); break; case BinaryOpcode::divide: - if (!is_endogenous_present_1 && !is_endogenous_present_2) - return { 0, datatree.AddDivide(expr_t_1, expr_t_2) }; - else if (!is_endogenous_present_1 && is_endogenous_present_2) - { - List_of_Op_RHS.emplace_back(static_cast(BinaryOpcode::divide), nullptr, expr_t_1); - return { 1, expr_t_1 }; - } - else if (is_endogenous_present_1 && !is_endogenous_present_2) - { - List_of_Op_RHS.emplace_back(static_cast(BinaryOpcode::times), expr_t_2, nullptr); - return { 1, expr_t_2 }; - } + if (arg1_contains_var) + rhs = datatree.AddTimes(rhs, arg2); else - return { 2, nullptr }; + rhs = datatree.AddDivide(arg1, rhs); break; case BinaryOpcode::power: - if (!is_endogenous_present_1 && !is_endogenous_present_2) - return { 0, datatree.AddPower(expr_t_1, expr_t_2) }; - else if (is_endogenous_present_1 && !is_endogenous_present_2) - { - List_of_Op_RHS.emplace_back(static_cast(BinaryOpcode::power), datatree.AddDivide(datatree.One, expr_t_2), nullptr); - return { 1, nullptr }; - } - else if (!is_endogenous_present_1 && is_endogenous_present_2) - { - /* we have to nomalize a^f(X) = RHS */ - /* First computes the ln(RHS)*/ - List_of_Op_RHS.emplace_back(static_cast(UnaryOpcode::log), nullptr, nullptr); - /* Second computes f(X) = ln(RHS) / ln(a)*/ - List_of_Op_RHS.emplace_back(static_cast(BinaryOpcode::divide), nullptr, datatree.AddLog(expr_t_1)); - return { 1, nullptr }; - } + if (arg1_contains_var) + rhs = datatree.AddPower(rhs, datatree.AddDivide(datatree.One, arg2)); + else + // a^f(X)=rhs is normalized in f(X)=ln(rhs)/ln(a) + rhs = datatree.AddDivide(datatree.AddLog(rhs), datatree.AddLog(arg1)); break; case BinaryOpcode::equal: - if (!is_endogenous_present_1 && !is_endogenous_present_2) - { - return { 0, datatree.AddEqual(datatree.AddVariable(datatree.symbol_table.getID(SymbolType::endogenous, var_endo), 0), datatree.AddMinus(expr_t_2, expr_t_1)) }; - } - else if (is_endogenous_present_1 && is_endogenous_present_2) - { - return { 0, datatree.AddEqual(datatree.AddVariable(datatree.symbol_table.getID(SymbolType::endogenous, var_endo), 0), datatree.Zero) }; - } - else if (!is_endogenous_present_1 && is_endogenous_present_2) - { - return { 0, datatree.AddEqual(datatree.AddVariable(datatree.symbol_table.getID(SymbolType::endogenous, var_endo), 0), /*datatree.AddUMinus(expr_t_1)*/ expr_t_1) }; - } - else if (is_endogenous_present_1 && !is_endogenous_present_2) - { - return { 0, datatree.AddEqual(datatree.AddVariable(datatree.symbol_table.getID(SymbolType::endogenous, var_endo), 0), expr_t_2) }; - } - break; - case BinaryOpcode::max: - if (!is_endogenous_present_1 && !is_endogenous_present_2) - return { 0, datatree.AddMax(expr_t_1, expr_t_2) }; - else - return { 2, nullptr }; - break; - case BinaryOpcode::min: - if (!is_endogenous_present_1 && !is_endogenous_present_2) - return { 0, datatree.AddMin(expr_t_1, expr_t_2) }; - else - return { 2, nullptr }; - break; - case BinaryOpcode::less: - if (!is_endogenous_present_1 && !is_endogenous_present_2) - return { 0, datatree.AddLess(expr_t_1, expr_t_2) }; - else - return { 2, nullptr }; - break; - case BinaryOpcode::greater: - if (!is_endogenous_present_1 && !is_endogenous_present_2) - return { 0, datatree.AddGreater(expr_t_1, expr_t_2) }; - else - return { 2, nullptr }; - break; - case BinaryOpcode::lessEqual: - if (!is_endogenous_present_1 && !is_endogenous_present_2) - return { 0, datatree.AddLessEqual(expr_t_1, expr_t_2) }; - else - return { 2, nullptr }; - break; - case BinaryOpcode::greaterEqual: - if (!is_endogenous_present_1 && !is_endogenous_present_2) - return { 0, datatree.AddGreaterEqual(expr_t_1, expr_t_2) }; - else - return { 2, nullptr }; - break; - case BinaryOpcode::equalEqual: - if (!is_endogenous_present_1 && !is_endogenous_present_2) - return { 0, datatree.AddEqualEqual(expr_t_1, expr_t_2) }; - else - return { 2, nullptr }; - break; - case BinaryOpcode::different: - if (!is_endogenous_present_1 && !is_endogenous_present_2) - return { 0, datatree.AddDifferent(expr_t_1, expr_t_2) }; - else - return { 2, nullptr }; - break; + cerr << "BinaryOpCode::normalizeEquationHelper: this case should not happen" << endl; + exit(EXIT_FAILURE); default: - cerr << "Binary operator not handled during the normalization process" << endl; - return { 2, nullptr }; // Could not be normalized + throw NormalizationFailed(); } - // Suppress GCC warning - cerr << "BinaryOpNode::normalizeEquation: impossible case" << endl; - exit(EXIT_FAILURE); + + if (arg1_contains_var) + return arg1->normalizeEquationHelper(contain_var, rhs); + else + return arg2->normalizeEquationHelper(contain_var, rhs); +} + +BinaryOpNode * +BinaryOpNode::normalizeEquation(int symb_id, int lag) const +{ + assert(op_code == BinaryOpcode::equal); + + set contain_var; + computeSubExprContainingVariable(symb_id, lag, contain_var); + + bool arg1_contains_var = contain_var.count(arg1) > 0; + bool arg2_contains_var = contain_var.count(arg2) > 0; + assert(arg1_contains_var || arg2_contains_var); + + if (arg1_contains_var && arg2_contains_var) + throw NormalizationFailed(); + + return arg1_contains_var ? arg1->normalizeEquationHelper(contain_var, arg2) + : arg2->normalizeEquationHelper(contain_var, arg1); } expr_t @@ -6477,22 +6251,20 @@ TrinaryOpNode::collectDynamicVariables(SymbolType type_arg, set> arg3->collectDynamicVariables(type_arg, result); } -pair -TrinaryOpNode::normalizeEquation(int var_endo, vector> &List_of_Op_RHS) const +void +TrinaryOpNode::computeSubExprContainingVariable(int symb_id, int lag, set &contain_var) const { - pair res = arg1->normalizeEquation(var_endo, List_of_Op_RHS); - bool is_endogenous_present_1 = res.first; - expr_t expr_t_1 = res.second; - res = arg2->normalizeEquation(var_endo, List_of_Op_RHS); - bool is_endogenous_present_2 = res.first; - expr_t expr_t_2 = res.second; - res = arg3->normalizeEquation(var_endo, List_of_Op_RHS); - bool is_endogenous_present_3 = res.first; - expr_t expr_t_3 = res.second; - if (!is_endogenous_present_1 && !is_endogenous_present_2 && !is_endogenous_present_3) - return { 0, datatree.AddNormcdf(expr_t_1, expr_t_2, expr_t_3) }; - else - return { 2, nullptr }; + arg1->computeSubExprContainingVariable(symb_id, lag, contain_var); + arg2->computeSubExprContainingVariable(symb_id, lag, contain_var); + arg3->computeSubExprContainingVariable(symb_id, lag, contain_var); + if (contain_var.count(arg1) > 0 || contain_var.count(arg2) > 0 || contain_var.count(arg3) > 0) + contain_var.insert(const_cast(this)); +} + +BinaryOpNode * +TrinaryOpNode::normalizeEquationHelper(const set &contain_var, expr_t rhs) const +{ + throw NormalizationFailed(); } expr_t @@ -7405,22 +7177,24 @@ AbstractExternalFunctionNode::getEndosAndMaxLags(map &model_endos_a argument->getEndosAndMaxLags(model_endos_and_lags); } -pair -AbstractExternalFunctionNode::normalizeEquation(int var_endo, vector> &List_of_Op_RHS) const + +void +AbstractExternalFunctionNode::computeSubExprContainingVariable(int symb_id, int lag, set &contain_var) const { - vector> V_arguments; - vector V_expr_t; - bool present = false; - for (auto argument : arguments) + bool var_present = false; + for (auto arg : arguments) { - V_arguments.emplace_back(argument->normalizeEquation(var_endo, List_of_Op_RHS)); - present = present || V_arguments[V_arguments.size()-1].first; - V_expr_t.push_back(V_arguments[V_arguments.size()-1].second); + arg->computeSubExprContainingVariable(symb_id, lag, contain_var); + var_present = var_present || contain_var.count(arg) > 0; } - if (!present) - return { 0, datatree.AddExternalFunction(symb_id, V_expr_t) }; - else - return { 2, nullptr }; + if (var_present) + contain_var.insert(const_cast(this)); +} + +BinaryOpNode * +AbstractExternalFunctionNode::normalizeEquationHelper(const set &contain_var, expr_t rhs) const +{ + throw NormalizationFailed(); } void @@ -8806,11 +8580,15 @@ VarExpectationNode::compile(ostream &CompileCode, unsigned int &instruction_numb exit(EXIT_FAILURE); } -pair -VarExpectationNode::normalizeEquation(int var_endo, vector> &List_of_Op_RHS) const +void +VarExpectationNode::computeSubExprContainingVariable(int symb_id, int lag, set &contain_var) const { - cerr << "VarExpectationNode::normalizeEquation not implemented." << endl; - exit(EXIT_FAILURE); +} + +BinaryOpNode * +VarExpectationNode::normalizeEquationHelper(const set &contain_var, expr_t rhs) const +{ + throw NormalizationFailed(); } expr_t @@ -9238,11 +9016,15 @@ PacExpectationNode::countDiffs() const return 0; } -pair -PacExpectationNode::normalizeEquation(int var_endo, vector> &List_of_Op_RHS) const +void +PacExpectationNode::computeSubExprContainingVariable(int symb_id, int lag, set &contain_var) const { - cerr << "PacExpectationNode::normalizeEquation not implemented." << endl; - exit(EXIT_FAILURE); +} + +BinaryOpNode * +PacExpectationNode::normalizeEquationHelper(const set &contain_var, expr_t rhs) const +{ + throw NormalizationFailed(); } expr_t diff --git a/src/ExprNode.hh b/src/ExprNode.hh index 210c0360..9f8bcd15 100644 --- a/src/ExprNode.hh +++ b/src/ExprNode.hh @@ -431,8 +431,18 @@ public: */ // virtual void computeXrefs(set ¶m, set &endo, set &exo, set &exo_det) const = 0; virtual void computeXrefs(EquationInfo &ei) const = 0; - //! Try to normalize an equation linear in its endogenous variable - virtual pair normalizeEquation(int symb_id_endo, vector> &List_of_Op_RHS) const = 0; + + // Computes the set of all sub-expressions that contain the variable (symb_id, lag) + virtual void computeSubExprContainingVariable(int symb_id, int lag, set &contain_var) const = 0; + + //! Helper for normalization of equations + /*! Normalize the equation this = rhs. + Must be called on a node containing the desired LHS variable. + Returns an equal node of the form: LHS variable = new RHS. + Must be given the set of all subexpressions that contain the desired LHS variable. + Throws a NormallizationFailed() exception if normalization is not possible. */ + virtual BinaryOpNode *normalizeEquationHelper(const set &contain_var, expr_t rhs) const = 0; + class NormalizationFailed {}; //! Returns the maximum lead of endogenous in this expression /*! Always returns a non-negative value */ @@ -744,7 +754,8 @@ public: void compile(ostream &CompileCode, unsigned int &instruction_number, bool lhs_rhs, const temporary_terms_t &temporary_terms, const map_idx_t &map_idx, bool dynamic, bool steady_dynamic, const deriv_node_temp_terms_t &tef_terms) const override; expr_t toStatic(DataTree &static_datatree) const override; void computeXrefs(EquationInfo &ei) const override; - pair normalizeEquation(int symb_id_endo, vector> &List_of_Op_RHS) const override; + void computeSubExprContainingVariable(int symb_id, int lag, set &contain_var) const override; + BinaryOpNode *normalizeEquationHelper(const set &contain_var, expr_t rhs) const override; expr_t getChainRuleDerivative(int deriv_id, const map &recursive_variables) override; int maxEndoLead() const override; int maxExoLead() const override; @@ -827,7 +838,8 @@ public: expr_t toStatic(DataTree &static_datatree) const override; void computeXrefs(EquationInfo &ei) const override; SymbolType get_type() const; - pair normalizeEquation(int symb_id_endo, vector> &List_of_Op_RHS) const override; + void computeSubExprContainingVariable(int symb_id, int lag, set &contain_var) const override; + BinaryOpNode *normalizeEquationHelper(const set &contain_var, expr_t rhs) const override; expr_t getChainRuleDerivative(int deriv_id, const map &recursive_variables) override; int maxEndoLead() const override; int maxExoLead() const override; @@ -935,7 +947,8 @@ public: void compile(ostream &CompileCode, unsigned int &instruction_number, bool lhs_rhs, const temporary_terms_t &temporary_terms, const map_idx_t &map_idx, bool dynamic, bool steady_dynamic, const deriv_node_temp_terms_t &tef_terms) const override; expr_t toStatic(DataTree &static_datatree) const override; void computeXrefs(EquationInfo &ei) const override; - pair normalizeEquation(int symb_id_endo, vector> &List_of_Op_RHS) const override; + void computeSubExprContainingVariable(int symb_id, int lag, set &contain_var) const override; + BinaryOpNode *normalizeEquationHelper(const set &contain_var, expr_t rhs) const override; expr_t getChainRuleDerivative(int deriv_id, const map &recursive_variables) override; int maxEndoLead() const override; int maxExoLead() const override; @@ -1047,7 +1060,11 @@ public: expr_t Compute_RHS(expr_t arg1, expr_t arg2, int op, int op_type) const; expr_t toStatic(DataTree &static_datatree) const override; void computeXrefs(EquationInfo &ei) const override; - pair normalizeEquation(int symb_id_endo, vector> &List_of_Op_RHS) const override; + void computeSubExprContainingVariable(int symb_id, int lag, set &contain_var) const override; + BinaryOpNode *normalizeEquationHelper(const set &contain_var, expr_t rhs) const override; + //! Try to normalize an equation with respect to a given dynamic variable. + /*! Should only be called on Equal nodes. The variable must appear in the equation. */ + BinaryOpNode *normalizeEquation(int symb_id, int lag) const; expr_t getChainRuleDerivative(int deriv_id, const map &recursive_variables) override; int maxEndoLead() const override; int maxExoLead() const override; @@ -1178,7 +1195,8 @@ public: void compile(ostream &CompileCode, unsigned int &instruction_number, bool lhs_rhs, const temporary_terms_t &temporary_terms, const map_idx_t &map_idx, bool dynamic, bool steady_dynamic, const deriv_node_temp_terms_t &tef_terms) const override; expr_t toStatic(DataTree &static_datatree) const override; void computeXrefs(EquationInfo &ei) const override; - pair normalizeEquation(int symb_id_endo, vector> &List_of_Op_RHS) const override; + void computeSubExprContainingVariable(int symb_id, int lag, set &contain_var) const override; + BinaryOpNode *normalizeEquationHelper(const set &contain_var, expr_t rhs) const override; expr_t getChainRuleDerivative(int deriv_id, const map &recursive_variables) override; int maxEndoLead() const override; int maxExoLead() const override; @@ -1298,7 +1316,8 @@ public: void compile(ostream &CompileCode, unsigned int &instruction_number, bool lhs_rhs, const temporary_terms_t &temporary_terms, const map_idx_t &map_idx, bool dynamic, bool steady_dynamic, const deriv_node_temp_terms_t &tef_terms) const override = 0; expr_t toStatic(DataTree &static_datatree) const override = 0; void computeXrefs(EquationInfo &ei) const override = 0; - pair normalizeEquation(int symb_id_endo, vector> &List_of_Op_RHS) const override; + void computeSubExprContainingVariable(int symb_id, int lag, set &contain_var) const override; + BinaryOpNode *normalizeEquationHelper(const set &contain_var, expr_t rhs) const override; expr_t getChainRuleDerivative(int deriv_id, const map &recursive_variables) override; int maxEndoLead() const override; int maxExoLead() const override; @@ -1529,7 +1548,8 @@ public: expr_t substituteDiff(const lag_equivalence_table_t &nodes, subst_table_t &subst_table, vector &neweqs) const override; expr_t substituteUnaryOpNodes(const lag_equivalence_table_t &nodes, subst_table_t &subst_table, vector &neweqs) const override; expr_t substitutePacExpectation(const string &name, expr_t subexpr) override; - pair normalizeEquation(int symb_id_endo, vector> &List_of_Op_RHS) const override; + void computeSubExprContainingVariable(int symb_id, int lag, set &contain_var) const override; + BinaryOpNode *normalizeEquationHelper(const set &contain_var, expr_t rhs) const override; void compile(ostream &CompileCode, unsigned int &instruction_number, bool lhs_rhs, const temporary_terms_t &temporary_terms, const map_idx_t &map_idx, bool dynamic, bool steady_dynamic, @@ -1610,7 +1630,8 @@ public: expr_t substituteDiff(const lag_equivalence_table_t &nodes, subst_table_t &subst_table, vector &neweqs) const override; expr_t substituteUnaryOpNodes(const lag_equivalence_table_t &nodes, subst_table_t &subst_table, vector &neweqs) const override; expr_t substitutePacExpectation(const string &name, expr_t subexpr) override; - pair normalizeEquation(int symb_id_endo, vector> &List_of_Op_RHS) const override; + void computeSubExprContainingVariable(int symb_id, int lag, set &contain_var) const override; + BinaryOpNode *normalizeEquationHelper(const set &contain_var, expr_t rhs) const override; void compile(ostream &CompileCode, unsigned int &instruction_number, bool lhs_rhs, const temporary_terms_t &temporary_terms, const map_idx_t &map_idx, bool dynamic, bool steady_dynamic, diff --git a/src/ModelTree.cc b/src/ModelTree.cc index a25b1539..769665f9 100644 --- a/src/ModelTree.cc +++ b/src/ModelTree.cc @@ -661,7 +661,7 @@ ModelTree::equationTypeDetermination(const map, expr_t> &fi int var = variable_reordered[i]; expr_t lhs = equations[eq]->arg1; EquationType Equation_Simulation_Type = EquationType::solve; - pair res; + BinaryOpNode *normalized_eq = nullptr; if (auto it = first_order_endo_derivatives.find({ eq, var, 0 }); it != first_order_endo_derivatives.end()) { @@ -676,16 +676,18 @@ ModelTree::equationTypeDetermination(const map, expr_t> &fi derivative->collectEndogenous(result); bool variable_not_in_derivative = result.find({ var, 0 }) == result.end(); - vector> List_of_Op_RHS; - res = equations[eq]->normalizeEquation(var, List_of_Op_RHS); - - if (mfs == 2 && variable_not_in_derivative && res.second) - Equation_Simulation_Type = EquationType::evaluate_s; - else if (mfs == 3 && res.second) // The equation could be solved analytically - Equation_Simulation_Type = EquationType::evaluate_s; + try + { + normalized_eq = equations[eq]->normalizeEquation(symbol_table.getID(SymbolType::endogenous, var), 0); + if ((mfs == 2 && variable_not_in_derivative) || mfs == 3) + Equation_Simulation_Type = EquationType::evaluate_s; + } + catch (ExprNode::NormalizationFailed &e) + { + } } } - equation_type_and_normalized_equation[eq] = { Equation_Simulation_Type, dynamic_cast(res.second) }; + equation_type_and_normalized_equation[eq] = { Equation_Simulation_Type, normalized_eq }; } }