preprocessor: fix bug in adl implementation

issue#70
houtanb 2017-06-23 18:01:07 +02:00
parent 0d31c7a893
commit 2deb4b42fb
6 changed files with 38 additions and 27 deletions

View File

@ -264,10 +264,9 @@ DataTree::AddDiff(expr_t iArg1)
}
expr_t
DataTree::AddAdl(expr_t iArg1, string &name, expr_t iArg2)
DataTree::AddAdl(expr_t iArg1, const string &name, expr_t iArg2)
{
expr_t adlnode = AddBinaryOp(iArg1, oAdl, iArg2);
adl_map[adlnode] = new string(name);
expr_t adlnode = AddBinaryOp(iArg1, oAdl, iArg2, 0, string(name));
return adlnode;
}

View File

@ -57,10 +57,6 @@ protected:
//! A reference to the external functions table
ExternalFunctionsTable &external_functions_table;
//! A reference to the adl table
typedef map<expr_t, string *> adl_map_t;
adl_map_t adl_map;
typedef map<int, NumConstNode *> num_const_node_map_t;
num_const_node_map_t num_const_node_map;
//! Pair (symbol_id, lag) used as key
@ -70,7 +66,7 @@ protected:
typedef map<pair<pair<expr_t, UnaryOpcode>, pair<int, pair<int, int> > >, UnaryOpNode *> unary_op_node_map_t;
unary_op_node_map_t unary_op_node_map;
//! Pair( Pair( Pair(arg1, arg2), order of Power Derivative), opCode)
typedef map<pair<pair<pair<expr_t, expr_t>, int>, BinaryOpcode>, BinaryOpNode *> binary_op_node_map_t;
typedef map<pair<pair<pair<expr_t, expr_t>, pair<int, string> >, BinaryOpcode>, BinaryOpNode *> binary_op_node_map_t;
binary_op_node_map_t binary_op_node_map;
typedef map<pair<pair<pair<expr_t, expr_t>, expr_t>, TrinaryOpcode>, TrinaryOpNode *> trinary_op_node_map_t;
trinary_op_node_map_t trinary_op_node_map;
@ -108,7 +104,7 @@ private:
inline expr_t AddPossiblyNegativeConstant(double val);
inline expr_t AddUnaryOp(UnaryOpcode op_code, expr_t arg, int arg_exp_info_set = 0, int param1_symb_id = 0, int param2_symb_id = 0);
inline expr_t AddBinaryOp(expr_t arg1, BinaryOpcode op_code, expr_t arg2, int powerDerivOrder = 0);
inline expr_t AddBinaryOp(expr_t arg1, BinaryOpcode op_code, expr_t arg2, int powerDerivOrder = 0, const string &adlparam = "");
inline expr_t AddTrinaryOp(expr_t arg1, TrinaryOpcode op_code, expr_t arg2, expr_t arg3);
public:
@ -169,7 +165,7 @@ public:
//! Adds "diff(arg)" to model tree
expr_t AddDiff(expr_t iArg1);
//! Adds "adl(arg1, arg2)" to model tree
expr_t AddAdl(expr_t iArg1, string &name, expr_t iArg2);
expr_t AddAdl(expr_t iArg1, const string &name, expr_t iArg2);
//! Adds "exp(arg)" to model tree
expr_t AddExp(expr_t iArg1);
//! Adds "log(arg)" to model tree
@ -346,9 +342,11 @@ DataTree::AddUnaryOp(UnaryOpcode op_code, expr_t arg, int arg_exp_info_set, int
}
inline expr_t
DataTree::AddBinaryOp(expr_t arg1, BinaryOpcode op_code, expr_t arg2, int powerDerivOrder)
DataTree::AddBinaryOp(expr_t arg1, BinaryOpcode op_code, expr_t arg2, int powerDerivOrder, const string &adlparam)
{
binary_op_node_map_t::iterator it = binary_op_node_map.find(make_pair(make_pair(make_pair(arg1, arg2), powerDerivOrder), op_code));
binary_op_node_map_t::iterator it = binary_op_node_map.find(make_pair(make_pair(make_pair(arg1, arg2),
make_pair(powerDerivOrder, adlparam)),
op_code));
if (it != binary_op_node_map.end())
return it->second;
@ -363,7 +361,7 @@ DataTree::AddBinaryOp(expr_t arg1, BinaryOpcode op_code, expr_t arg2, int powerD
catch (ExprNode::EvalException &e)
{
}
return new BinaryOpNode(*this, arg1, op_code, arg2, powerDerivOrder);
return new BinaryOpNode(*this, arg1, op_code, arg2, powerDerivOrder, adlparam);
}
inline expr_t

View File

@ -2887,9 +2887,10 @@ BinaryOpNode::BinaryOpNode(DataTree &datatree_arg, const expr_t arg1_arg,
arg1(arg1_arg),
arg2(arg2_arg),
op_code(op_code_arg),
powerDerivOrder(0)
powerDerivOrder(0),
adlparam("")
{
datatree.binary_op_node_map[make_pair(make_pair(make_pair(arg1, arg2), powerDerivOrder), op_code)] = this;
datatree.binary_op_node_map[make_pair(make_pair(make_pair(arg1, arg2), make_pair(powerDerivOrder, adlparam)), op_code)] = this;
}
BinaryOpNode::BinaryOpNode(DataTree &datatree_arg, const expr_t arg1_arg,
@ -2898,10 +2899,25 @@ BinaryOpNode::BinaryOpNode(DataTree &datatree_arg, const expr_t arg1_arg,
arg1(arg1_arg),
arg2(arg2_arg),
op_code(op_code_arg),
powerDerivOrder(powerDerivOrder_arg)
powerDerivOrder(powerDerivOrder_arg),
adlparam("")
{
assert(powerDerivOrder >= 0);
datatree.binary_op_node_map[make_pair(make_pair(make_pair(arg1, arg2), powerDerivOrder), op_code)] = this;
datatree.binary_op_node_map[make_pair(make_pair(make_pair(arg1, arg2), make_pair(powerDerivOrder, adlparam)), op_code)] = this;
}
BinaryOpNode::BinaryOpNode(DataTree &datatree_arg, const expr_t arg1_arg,
BinaryOpcode op_code_arg, const expr_t arg2_arg,
int powerDerivOrder_arg, string adlparam_arg) :
ExprNode(datatree_arg),
arg1(arg1_arg),
arg2(arg2_arg),
op_code(op_code_arg),
powerDerivOrder(powerDerivOrder_arg),
adlparam(adlparam_arg)
{
assert(powerDerivOrder >= 0);
datatree.binary_op_node_map[make_pair(make_pair(make_pair(arg1, arg2), make_pair(powerDerivOrder, adlparam)), op_code)] = this;
}
void
@ -4101,9 +4117,7 @@ BinaryOpNode::buildSimilarBinaryOpNode(expr_t alt_arg1, expr_t alt_arg2, DataTre
case oPowerDeriv:
return alt_datatree.AddPowerDeriv(alt_arg1, alt_arg2, powerDerivOrder);
case oAdl:
DataTree::adl_map_t::const_iterator it = datatree.adl_map.find(const_cast<BinaryOpNode *>(this));
assert (it != datatree.adl_map.end());
return alt_datatree.AddAdl(alt_arg1, *(it->second), alt_arg2);
return alt_datatree.AddAdl(alt_arg1, adlparam, alt_arg2);
}
// Suppress GCC warning
exit(EXIT_FAILURE);
@ -4301,16 +4315,13 @@ BinaryOpNode::substituteAdlAndDiff() const
}
expr_t arg1subst = arg1->substituteAdlAndDiff();
DataTree::adl_map_t::const_iterator it = datatree.adl_map.find(const_cast<BinaryOpNode *>(this));
assert (it != datatree.adl_map.end());
int i = 1;
expr_t retval = datatree.AddTimes(datatree.AddVariable(datatree.symbol_table.addAdlParameter(*(it->second), i), 0),
expr_t retval = datatree.AddTimes(datatree.AddVariable(datatree.symbol_table.addAdlParameter(adlparam, i), 0),
arg1subst->decreaseLeadsLags(i));
i++;
for (; i <= (int) arg2->eval(eval_context_t()); i++)
retval = datatree.AddPlus(retval,
datatree.AddTimes(datatree.AddVariable(datatree.symbol_table.addAdlParameter(*(it->second), i), 0),
datatree.AddTimes(datatree.AddVariable(datatree.symbol_table.addAdlParameter(adlparam, i), 0),
arg1subst->decreaseLeadsLags(i)));
return retval;
}

View File

@ -730,11 +730,14 @@ private:
//! Returns the derivative of this node if darg1 and darg2 are the derivatives of the arguments
expr_t composeDerivatives(expr_t darg1, expr_t darg2);
const int powerDerivOrder;
const string adlparam;
public:
BinaryOpNode(DataTree &datatree_arg, const expr_t arg1_arg,
BinaryOpcode op_code_arg, const expr_t arg2_arg);
BinaryOpNode(DataTree &datatree_arg, const expr_t arg1_arg,
BinaryOpcode op_code_arg, const expr_t arg2_arg, int powerDerivOrder);
BinaryOpNode(DataTree &datatree_arg, const expr_t arg1_arg,
BinaryOpcode op_code_arg, const expr_t arg2_arg, int powerDerivOrder_arg, string adlparam_arg);
virtual void prepareForDerivation();
virtual int precedenceJson(const temporary_terms_t &temporary_terms) const;
virtual int precedence(ExprNodeOutputType output_type, const temporary_terms_t &temporary_terms) const;

View File

@ -650,7 +650,7 @@ SymbolTable::addLagAuxiliaryVarInternal(bool endo, int orig_symb_id, int orig_le
}
int
SymbolTable::addAdlParameter(string &basename, int lag) throw (FrozenException)
SymbolTable::addAdlParameter(const string &basename, int lag) throw (FrozenException)
{
ostringstream varname;
varname << basename << "_lag_" << lag;

View File

@ -287,7 +287,7 @@ public:
/*
// Adds a parameter for the transformation of the adl operator
*/
int addAdlParameter(string &basename, int lag) throw (FrozenException);
int addAdlParameter(const string &basename, int lag) throw (FrozenException);
//! Returns the number of auxiliary variables
int
AuxVarsSize() const