preprocessor: fix bug in adl implementation
parent
0d31c7a893
commit
2deb4b42fb
|
@ -264,10 +264,9 @@ DataTree::AddDiff(expr_t iArg1)
|
|||
}
|
||||
|
||||
expr_t
|
||||
DataTree::AddAdl(expr_t iArg1, string &name, expr_t iArg2)
|
||||
DataTree::AddAdl(expr_t iArg1, const string &name, expr_t iArg2)
|
||||
{
|
||||
expr_t adlnode = AddBinaryOp(iArg1, oAdl, iArg2);
|
||||
adl_map[adlnode] = new string(name);
|
||||
expr_t adlnode = AddBinaryOp(iArg1, oAdl, iArg2, 0, string(name));
|
||||
return adlnode;
|
||||
}
|
||||
|
||||
|
|
18
DataTree.hh
18
DataTree.hh
|
@ -57,10 +57,6 @@ protected:
|
|||
//! A reference to the external functions table
|
||||
ExternalFunctionsTable &external_functions_table;
|
||||
|
||||
//! A reference to the adl table
|
||||
typedef map<expr_t, string *> adl_map_t;
|
||||
adl_map_t adl_map;
|
||||
|
||||
typedef map<int, NumConstNode *> num_const_node_map_t;
|
||||
num_const_node_map_t num_const_node_map;
|
||||
//! Pair (symbol_id, lag) used as key
|
||||
|
@ -70,7 +66,7 @@ protected:
|
|||
typedef map<pair<pair<expr_t, UnaryOpcode>, pair<int, pair<int, int> > >, UnaryOpNode *> unary_op_node_map_t;
|
||||
unary_op_node_map_t unary_op_node_map;
|
||||
//! Pair( Pair( Pair(arg1, arg2), order of Power Derivative), opCode)
|
||||
typedef map<pair<pair<pair<expr_t, expr_t>, int>, BinaryOpcode>, BinaryOpNode *> binary_op_node_map_t;
|
||||
typedef map<pair<pair<pair<expr_t, expr_t>, pair<int, string> >, BinaryOpcode>, BinaryOpNode *> binary_op_node_map_t;
|
||||
binary_op_node_map_t binary_op_node_map;
|
||||
typedef map<pair<pair<pair<expr_t, expr_t>, expr_t>, TrinaryOpcode>, TrinaryOpNode *> trinary_op_node_map_t;
|
||||
trinary_op_node_map_t trinary_op_node_map;
|
||||
|
@ -108,7 +104,7 @@ private:
|
|||
|
||||
inline expr_t AddPossiblyNegativeConstant(double val);
|
||||
inline expr_t AddUnaryOp(UnaryOpcode op_code, expr_t arg, int arg_exp_info_set = 0, int param1_symb_id = 0, int param2_symb_id = 0);
|
||||
inline expr_t AddBinaryOp(expr_t arg1, BinaryOpcode op_code, expr_t arg2, int powerDerivOrder = 0);
|
||||
inline expr_t AddBinaryOp(expr_t arg1, BinaryOpcode op_code, expr_t arg2, int powerDerivOrder = 0, const string &adlparam = "");
|
||||
inline expr_t AddTrinaryOp(expr_t arg1, TrinaryOpcode op_code, expr_t arg2, expr_t arg3);
|
||||
|
||||
public:
|
||||
|
@ -169,7 +165,7 @@ public:
|
|||
//! Adds "diff(arg)" to model tree
|
||||
expr_t AddDiff(expr_t iArg1);
|
||||
//! Adds "adl(arg1, arg2)" to model tree
|
||||
expr_t AddAdl(expr_t iArg1, string &name, expr_t iArg2);
|
||||
expr_t AddAdl(expr_t iArg1, const string &name, expr_t iArg2);
|
||||
//! Adds "exp(arg)" to model tree
|
||||
expr_t AddExp(expr_t iArg1);
|
||||
//! Adds "log(arg)" to model tree
|
||||
|
@ -346,9 +342,11 @@ DataTree::AddUnaryOp(UnaryOpcode op_code, expr_t arg, int arg_exp_info_set, int
|
|||
}
|
||||
|
||||
inline expr_t
|
||||
DataTree::AddBinaryOp(expr_t arg1, BinaryOpcode op_code, expr_t arg2, int powerDerivOrder)
|
||||
DataTree::AddBinaryOp(expr_t arg1, BinaryOpcode op_code, expr_t arg2, int powerDerivOrder, const string &adlparam)
|
||||
{
|
||||
binary_op_node_map_t::iterator it = binary_op_node_map.find(make_pair(make_pair(make_pair(arg1, arg2), powerDerivOrder), op_code));
|
||||
binary_op_node_map_t::iterator it = binary_op_node_map.find(make_pair(make_pair(make_pair(arg1, arg2),
|
||||
make_pair(powerDerivOrder, adlparam)),
|
||||
op_code));
|
||||
if (it != binary_op_node_map.end())
|
||||
return it->second;
|
||||
|
||||
|
@ -363,7 +361,7 @@ DataTree::AddBinaryOp(expr_t arg1, BinaryOpcode op_code, expr_t arg2, int powerD
|
|||
catch (ExprNode::EvalException &e)
|
||||
{
|
||||
}
|
||||
return new BinaryOpNode(*this, arg1, op_code, arg2, powerDerivOrder);
|
||||
return new BinaryOpNode(*this, arg1, op_code, arg2, powerDerivOrder, adlparam);
|
||||
}
|
||||
|
||||
inline expr_t
|
||||
|
|
35
ExprNode.cc
35
ExprNode.cc
|
@ -2887,9 +2887,10 @@ BinaryOpNode::BinaryOpNode(DataTree &datatree_arg, const expr_t arg1_arg,
|
|||
arg1(arg1_arg),
|
||||
arg2(arg2_arg),
|
||||
op_code(op_code_arg),
|
||||
powerDerivOrder(0)
|
||||
powerDerivOrder(0),
|
||||
adlparam("")
|
||||
{
|
||||
datatree.binary_op_node_map[make_pair(make_pair(make_pair(arg1, arg2), powerDerivOrder), op_code)] = this;
|
||||
datatree.binary_op_node_map[make_pair(make_pair(make_pair(arg1, arg2), make_pair(powerDerivOrder, adlparam)), op_code)] = this;
|
||||
}
|
||||
|
||||
BinaryOpNode::BinaryOpNode(DataTree &datatree_arg, const expr_t arg1_arg,
|
||||
|
@ -2898,10 +2899,25 @@ BinaryOpNode::BinaryOpNode(DataTree &datatree_arg, const expr_t arg1_arg,
|
|||
arg1(arg1_arg),
|
||||
arg2(arg2_arg),
|
||||
op_code(op_code_arg),
|
||||
powerDerivOrder(powerDerivOrder_arg)
|
||||
powerDerivOrder(powerDerivOrder_arg),
|
||||
adlparam("")
|
||||
{
|
||||
assert(powerDerivOrder >= 0);
|
||||
datatree.binary_op_node_map[make_pair(make_pair(make_pair(arg1, arg2), powerDerivOrder), op_code)] = this;
|
||||
datatree.binary_op_node_map[make_pair(make_pair(make_pair(arg1, arg2), make_pair(powerDerivOrder, adlparam)), op_code)] = this;
|
||||
}
|
||||
|
||||
BinaryOpNode::BinaryOpNode(DataTree &datatree_arg, const expr_t arg1_arg,
|
||||
BinaryOpcode op_code_arg, const expr_t arg2_arg,
|
||||
int powerDerivOrder_arg, string adlparam_arg) :
|
||||
ExprNode(datatree_arg),
|
||||
arg1(arg1_arg),
|
||||
arg2(arg2_arg),
|
||||
op_code(op_code_arg),
|
||||
powerDerivOrder(powerDerivOrder_arg),
|
||||
adlparam(adlparam_arg)
|
||||
{
|
||||
assert(powerDerivOrder >= 0);
|
||||
datatree.binary_op_node_map[make_pair(make_pair(make_pair(arg1, arg2), make_pair(powerDerivOrder, adlparam)), op_code)] = this;
|
||||
}
|
||||
|
||||
void
|
||||
|
@ -4101,9 +4117,7 @@ BinaryOpNode::buildSimilarBinaryOpNode(expr_t alt_arg1, expr_t alt_arg2, DataTre
|
|||
case oPowerDeriv:
|
||||
return alt_datatree.AddPowerDeriv(alt_arg1, alt_arg2, powerDerivOrder);
|
||||
case oAdl:
|
||||
DataTree::adl_map_t::const_iterator it = datatree.adl_map.find(const_cast<BinaryOpNode *>(this));
|
||||
assert (it != datatree.adl_map.end());
|
||||
return alt_datatree.AddAdl(alt_arg1, *(it->second), alt_arg2);
|
||||
return alt_datatree.AddAdl(alt_arg1, adlparam, alt_arg2);
|
||||
}
|
||||
// Suppress GCC warning
|
||||
exit(EXIT_FAILURE);
|
||||
|
@ -4301,16 +4315,13 @@ BinaryOpNode::substituteAdlAndDiff() const
|
|||
}
|
||||
|
||||
expr_t arg1subst = arg1->substituteAdlAndDiff();
|
||||
DataTree::adl_map_t::const_iterator it = datatree.adl_map.find(const_cast<BinaryOpNode *>(this));
|
||||
assert (it != datatree.adl_map.end());
|
||||
|
||||
int i = 1;
|
||||
expr_t retval = datatree.AddTimes(datatree.AddVariable(datatree.symbol_table.addAdlParameter(*(it->second), i), 0),
|
||||
expr_t retval = datatree.AddTimes(datatree.AddVariable(datatree.symbol_table.addAdlParameter(adlparam, i), 0),
|
||||
arg1subst->decreaseLeadsLags(i));
|
||||
i++;
|
||||
for (; i <= (int) arg2->eval(eval_context_t()); i++)
|
||||
retval = datatree.AddPlus(retval,
|
||||
datatree.AddTimes(datatree.AddVariable(datatree.symbol_table.addAdlParameter(*(it->second), i), 0),
|
||||
datatree.AddTimes(datatree.AddVariable(datatree.symbol_table.addAdlParameter(adlparam, i), 0),
|
||||
arg1subst->decreaseLeadsLags(i)));
|
||||
return retval;
|
||||
}
|
||||
|
|
|
@ -730,11 +730,14 @@ private:
|
|||
//! Returns the derivative of this node if darg1 and darg2 are the derivatives of the arguments
|
||||
expr_t composeDerivatives(expr_t darg1, expr_t darg2);
|
||||
const int powerDerivOrder;
|
||||
const string adlparam;
|
||||
public:
|
||||
BinaryOpNode(DataTree &datatree_arg, const expr_t arg1_arg,
|
||||
BinaryOpcode op_code_arg, const expr_t arg2_arg);
|
||||
BinaryOpNode(DataTree &datatree_arg, const expr_t arg1_arg,
|
||||
BinaryOpcode op_code_arg, const expr_t arg2_arg, int powerDerivOrder);
|
||||
BinaryOpNode(DataTree &datatree_arg, const expr_t arg1_arg,
|
||||
BinaryOpcode op_code_arg, const expr_t arg2_arg, int powerDerivOrder_arg, string adlparam_arg);
|
||||
virtual void prepareForDerivation();
|
||||
virtual int precedenceJson(const temporary_terms_t &temporary_terms) const;
|
||||
virtual int precedence(ExprNodeOutputType output_type, const temporary_terms_t &temporary_terms) const;
|
||||
|
|
|
@ -650,7 +650,7 @@ SymbolTable::addLagAuxiliaryVarInternal(bool endo, int orig_symb_id, int orig_le
|
|||
}
|
||||
|
||||
int
|
||||
SymbolTable::addAdlParameter(string &basename, int lag) throw (FrozenException)
|
||||
SymbolTable::addAdlParameter(const string &basename, int lag) throw (FrozenException)
|
||||
{
|
||||
ostringstream varname;
|
||||
varname << basename << "_lag_" << lag;
|
||||
|
|
|
@ -287,7 +287,7 @@ public:
|
|||
/*
|
||||
// Adds a parameter for the transformation of the adl operator
|
||||
*/
|
||||
int addAdlParameter(string &basename, int lag) throw (FrozenException);
|
||||
int addAdlParameter(const string &basename, int lag) throw (FrozenException);
|
||||
//! Returns the number of auxiliary variables
|
||||
int
|
||||
AuxVarsSize() const
|
||||
|
|
Loading…
Reference in New Issue