preprocessor: fix bug in adl implementation

issue#70
houtanb 2017-06-23 18:01:07 +02:00
parent 0d31c7a893
commit 2deb4b42fb
6 changed files with 38 additions and 27 deletions

View File

@ -264,10 +264,9 @@ DataTree::AddDiff(expr_t iArg1)
} }
expr_t expr_t
DataTree::AddAdl(expr_t iArg1, string &name, expr_t iArg2) DataTree::AddAdl(expr_t iArg1, const string &name, expr_t iArg2)
{ {
expr_t adlnode = AddBinaryOp(iArg1, oAdl, iArg2); expr_t adlnode = AddBinaryOp(iArg1, oAdl, iArg2, 0, string(name));
adl_map[adlnode] = new string(name);
return adlnode; return adlnode;
} }

View File

@ -57,10 +57,6 @@ protected:
//! A reference to the external functions table //! A reference to the external functions table
ExternalFunctionsTable &external_functions_table; ExternalFunctionsTable &external_functions_table;
//! A reference to the adl table
typedef map<expr_t, string *> adl_map_t;
adl_map_t adl_map;
typedef map<int, NumConstNode *> num_const_node_map_t; typedef map<int, NumConstNode *> num_const_node_map_t;
num_const_node_map_t num_const_node_map; num_const_node_map_t num_const_node_map;
//! Pair (symbol_id, lag) used as key //! Pair (symbol_id, lag) used as key
@ -70,7 +66,7 @@ protected:
typedef map<pair<pair<expr_t, UnaryOpcode>, pair<int, pair<int, int> > >, UnaryOpNode *> unary_op_node_map_t; typedef map<pair<pair<expr_t, UnaryOpcode>, pair<int, pair<int, int> > >, UnaryOpNode *> unary_op_node_map_t;
unary_op_node_map_t unary_op_node_map; unary_op_node_map_t unary_op_node_map;
//! Pair( Pair( Pair(arg1, arg2), order of Power Derivative), opCode) //! Pair( Pair( Pair(arg1, arg2), order of Power Derivative), opCode)
typedef map<pair<pair<pair<expr_t, expr_t>, int>, BinaryOpcode>, BinaryOpNode *> binary_op_node_map_t; typedef map<pair<pair<pair<expr_t, expr_t>, pair<int, string> >, BinaryOpcode>, BinaryOpNode *> binary_op_node_map_t;
binary_op_node_map_t binary_op_node_map; binary_op_node_map_t binary_op_node_map;
typedef map<pair<pair<pair<expr_t, expr_t>, expr_t>, TrinaryOpcode>, TrinaryOpNode *> trinary_op_node_map_t; typedef map<pair<pair<pair<expr_t, expr_t>, expr_t>, TrinaryOpcode>, TrinaryOpNode *> trinary_op_node_map_t;
trinary_op_node_map_t trinary_op_node_map; trinary_op_node_map_t trinary_op_node_map;
@ -108,7 +104,7 @@ private:
inline expr_t AddPossiblyNegativeConstant(double val); inline expr_t AddPossiblyNegativeConstant(double val);
inline expr_t AddUnaryOp(UnaryOpcode op_code, expr_t arg, int arg_exp_info_set = 0, int param1_symb_id = 0, int param2_symb_id = 0); inline expr_t AddUnaryOp(UnaryOpcode op_code, expr_t arg, int arg_exp_info_set = 0, int param1_symb_id = 0, int param2_symb_id = 0);
inline expr_t AddBinaryOp(expr_t arg1, BinaryOpcode op_code, expr_t arg2, int powerDerivOrder = 0); inline expr_t AddBinaryOp(expr_t arg1, BinaryOpcode op_code, expr_t arg2, int powerDerivOrder = 0, const string &adlparam = "");
inline expr_t AddTrinaryOp(expr_t arg1, TrinaryOpcode op_code, expr_t arg2, expr_t arg3); inline expr_t AddTrinaryOp(expr_t arg1, TrinaryOpcode op_code, expr_t arg2, expr_t arg3);
public: public:
@ -169,7 +165,7 @@ public:
//! Adds "diff(arg)" to model tree //! Adds "diff(arg)" to model tree
expr_t AddDiff(expr_t iArg1); expr_t AddDiff(expr_t iArg1);
//! Adds "adl(arg1, arg2)" to model tree //! Adds "adl(arg1, arg2)" to model tree
expr_t AddAdl(expr_t iArg1, string &name, expr_t iArg2); expr_t AddAdl(expr_t iArg1, const string &name, expr_t iArg2);
//! Adds "exp(arg)" to model tree //! Adds "exp(arg)" to model tree
expr_t AddExp(expr_t iArg1); expr_t AddExp(expr_t iArg1);
//! Adds "log(arg)" to model tree //! Adds "log(arg)" to model tree
@ -346,9 +342,11 @@ DataTree::AddUnaryOp(UnaryOpcode op_code, expr_t arg, int arg_exp_info_set, int
} }
inline expr_t inline expr_t
DataTree::AddBinaryOp(expr_t arg1, BinaryOpcode op_code, expr_t arg2, int powerDerivOrder) DataTree::AddBinaryOp(expr_t arg1, BinaryOpcode op_code, expr_t arg2, int powerDerivOrder, const string &adlparam)
{ {
binary_op_node_map_t::iterator it = binary_op_node_map.find(make_pair(make_pair(make_pair(arg1, arg2), powerDerivOrder), op_code)); binary_op_node_map_t::iterator it = binary_op_node_map.find(make_pair(make_pair(make_pair(arg1, arg2),
make_pair(powerDerivOrder, adlparam)),
op_code));
if (it != binary_op_node_map.end()) if (it != binary_op_node_map.end())
return it->second; return it->second;
@ -363,7 +361,7 @@ DataTree::AddBinaryOp(expr_t arg1, BinaryOpcode op_code, expr_t arg2, int powerD
catch (ExprNode::EvalException &e) catch (ExprNode::EvalException &e)
{ {
} }
return new BinaryOpNode(*this, arg1, op_code, arg2, powerDerivOrder); return new BinaryOpNode(*this, arg1, op_code, arg2, powerDerivOrder, adlparam);
} }
inline expr_t inline expr_t

View File

@ -2887,9 +2887,10 @@ BinaryOpNode::BinaryOpNode(DataTree &datatree_arg, const expr_t arg1_arg,
arg1(arg1_arg), arg1(arg1_arg),
arg2(arg2_arg), arg2(arg2_arg),
op_code(op_code_arg), op_code(op_code_arg),
powerDerivOrder(0) powerDerivOrder(0),
adlparam("")
{ {
datatree.binary_op_node_map[make_pair(make_pair(make_pair(arg1, arg2), powerDerivOrder), op_code)] = this; datatree.binary_op_node_map[make_pair(make_pair(make_pair(arg1, arg2), make_pair(powerDerivOrder, adlparam)), op_code)] = this;
} }
BinaryOpNode::BinaryOpNode(DataTree &datatree_arg, const expr_t arg1_arg, BinaryOpNode::BinaryOpNode(DataTree &datatree_arg, const expr_t arg1_arg,
@ -2898,10 +2899,25 @@ BinaryOpNode::BinaryOpNode(DataTree &datatree_arg, const expr_t arg1_arg,
arg1(arg1_arg), arg1(arg1_arg),
arg2(arg2_arg), arg2(arg2_arg),
op_code(op_code_arg), op_code(op_code_arg),
powerDerivOrder(powerDerivOrder_arg) powerDerivOrder(powerDerivOrder_arg),
adlparam("")
{ {
assert(powerDerivOrder >= 0); assert(powerDerivOrder >= 0);
datatree.binary_op_node_map[make_pair(make_pair(make_pair(arg1, arg2), powerDerivOrder), op_code)] = this; datatree.binary_op_node_map[make_pair(make_pair(make_pair(arg1, arg2), make_pair(powerDerivOrder, adlparam)), op_code)] = this;
}
BinaryOpNode::BinaryOpNode(DataTree &datatree_arg, const expr_t arg1_arg,
BinaryOpcode op_code_arg, const expr_t arg2_arg,
int powerDerivOrder_arg, string adlparam_arg) :
ExprNode(datatree_arg),
arg1(arg1_arg),
arg2(arg2_arg),
op_code(op_code_arg),
powerDerivOrder(powerDerivOrder_arg),
adlparam(adlparam_arg)
{
assert(powerDerivOrder >= 0);
datatree.binary_op_node_map[make_pair(make_pair(make_pair(arg1, arg2), make_pair(powerDerivOrder, adlparam)), op_code)] = this;
} }
void void
@ -4101,9 +4117,7 @@ BinaryOpNode::buildSimilarBinaryOpNode(expr_t alt_arg1, expr_t alt_arg2, DataTre
case oPowerDeriv: case oPowerDeriv:
return alt_datatree.AddPowerDeriv(alt_arg1, alt_arg2, powerDerivOrder); return alt_datatree.AddPowerDeriv(alt_arg1, alt_arg2, powerDerivOrder);
case oAdl: case oAdl:
DataTree::adl_map_t::const_iterator it = datatree.adl_map.find(const_cast<BinaryOpNode *>(this)); return alt_datatree.AddAdl(alt_arg1, adlparam, alt_arg2);
assert (it != datatree.adl_map.end());
return alt_datatree.AddAdl(alt_arg1, *(it->second), alt_arg2);
} }
// Suppress GCC warning // Suppress GCC warning
exit(EXIT_FAILURE); exit(EXIT_FAILURE);
@ -4301,16 +4315,13 @@ BinaryOpNode::substituteAdlAndDiff() const
} }
expr_t arg1subst = arg1->substituteAdlAndDiff(); expr_t arg1subst = arg1->substituteAdlAndDiff();
DataTree::adl_map_t::const_iterator it = datatree.adl_map.find(const_cast<BinaryOpNode *>(this));
assert (it != datatree.adl_map.end());
int i = 1; int i = 1;
expr_t retval = datatree.AddTimes(datatree.AddVariable(datatree.symbol_table.addAdlParameter(*(it->second), i), 0), expr_t retval = datatree.AddTimes(datatree.AddVariable(datatree.symbol_table.addAdlParameter(adlparam, i), 0),
arg1subst->decreaseLeadsLags(i)); arg1subst->decreaseLeadsLags(i));
i++; i++;
for (; i <= (int) arg2->eval(eval_context_t()); i++) for (; i <= (int) arg2->eval(eval_context_t()); i++)
retval = datatree.AddPlus(retval, retval = datatree.AddPlus(retval,
datatree.AddTimes(datatree.AddVariable(datatree.symbol_table.addAdlParameter(*(it->second), i), 0), datatree.AddTimes(datatree.AddVariable(datatree.symbol_table.addAdlParameter(adlparam, i), 0),
arg1subst->decreaseLeadsLags(i))); arg1subst->decreaseLeadsLags(i)));
return retval; return retval;
} }

View File

@ -730,11 +730,14 @@ private:
//! Returns the derivative of this node if darg1 and darg2 are the derivatives of the arguments //! Returns the derivative of this node if darg1 and darg2 are the derivatives of the arguments
expr_t composeDerivatives(expr_t darg1, expr_t darg2); expr_t composeDerivatives(expr_t darg1, expr_t darg2);
const int powerDerivOrder; const int powerDerivOrder;
const string adlparam;
public: public:
BinaryOpNode(DataTree &datatree_arg, const expr_t arg1_arg, BinaryOpNode(DataTree &datatree_arg, const expr_t arg1_arg,
BinaryOpcode op_code_arg, const expr_t arg2_arg); BinaryOpcode op_code_arg, const expr_t arg2_arg);
BinaryOpNode(DataTree &datatree_arg, const expr_t arg1_arg, BinaryOpNode(DataTree &datatree_arg, const expr_t arg1_arg,
BinaryOpcode op_code_arg, const expr_t arg2_arg, int powerDerivOrder); BinaryOpcode op_code_arg, const expr_t arg2_arg, int powerDerivOrder);
BinaryOpNode(DataTree &datatree_arg, const expr_t arg1_arg,
BinaryOpcode op_code_arg, const expr_t arg2_arg, int powerDerivOrder_arg, string adlparam_arg);
virtual void prepareForDerivation(); virtual void prepareForDerivation();
virtual int precedenceJson(const temporary_terms_t &temporary_terms) const; virtual int precedenceJson(const temporary_terms_t &temporary_terms) const;
virtual int precedence(ExprNodeOutputType output_type, const temporary_terms_t &temporary_terms) const; virtual int precedence(ExprNodeOutputType output_type, const temporary_terms_t &temporary_terms) const;

View File

@ -650,7 +650,7 @@ SymbolTable::addLagAuxiliaryVarInternal(bool endo, int orig_symb_id, int orig_le
} }
int int
SymbolTable::addAdlParameter(string &basename, int lag) throw (FrozenException) SymbolTable::addAdlParameter(const string &basename, int lag) throw (FrozenException)
{ {
ostringstream varname; ostringstream varname;
varname << basename << "_lag_" << lag; varname << basename << "_lag_" << lag;

View File

@ -287,7 +287,7 @@ public:
/* /*
// Adds a parameter for the transformation of the adl operator // Adds a parameter for the transformation of the adl operator
*/ */
int addAdlParameter(string &basename, int lag) throw (FrozenException); int addAdlParameter(const string &basename, int lag) throw (FrozenException);
//! Returns the number of auxiliary variables //! Returns the number of auxiliary variables
int int
AuxVarsSize() const AuxVarsSize() const