fix bug in macro processor ensuring short-circuit functionality of `||` and `&&` statements

closes dynare#1676
issue#70
Houtan Bastani 2019-12-10 16:25:28 +01:00
parent 5c081db76f
commit 952e899f3a
No known key found for this signature in database
GPG Key ID: 000094FB955BE169
2 changed files with 53 additions and 38 deletions

View File

@ -39,25 +39,33 @@ Bool::is_equal(const BaseTypePtr &btp) const
}
BoolPtr
Bool::logical_and(const BaseTypePtr &btp) const
Bool::logical_and(const ExpressionPtr &ep) const
{
if (!value)
return make_shared<Bool>(false, env);
auto btp = ep->eval();
if (auto btp2 = dynamic_pointer_cast<Bool>(btp); btp2)
return make_shared<Bool>(value && *btp2, env);
return make_shared<Bool>(*btp2, env);
if (auto btp2 = dynamic_pointer_cast<Real>(btp); btp2)
return make_shared<Bool>(value && *btp2, env);
return make_shared<Bool>(*btp2, env);
throw StackTrace("Type mismatch for operands of && operator");
}
BoolPtr
Bool::logical_or(const BaseTypePtr &btp) const
Bool::logical_or(const ExpressionPtr &ep) const
{
if (value)
return make_shared<Bool>(true, env);
auto btp = ep->eval();
if (auto btp2 = dynamic_pointer_cast<Bool>(btp); btp2)
return make_shared<Bool>(value || *btp2, env);
return make_shared<Bool>(*btp2, env);
if (auto btp2 = dynamic_pointer_cast<Real>(btp); btp2)
return make_shared<Bool>(value || *btp2, env);
return make_shared<Bool>(*btp2, env);
throw StackTrace("Type mismatch for operands of || operator");
}
@ -159,25 +167,33 @@ Real::is_equal(const BaseTypePtr &btp) const
}
BoolPtr
Real::logical_and(const BaseTypePtr &btp) const
Real::logical_and(const ExpressionPtr &ep) const
{
if (!value)
return make_shared<Bool>(false, env);
auto btp = ep->eval();
if (auto btp2 = dynamic_pointer_cast<Real>(btp); btp2)
return make_shared<Bool>(value && *btp2, env);
return make_shared<Bool>(*btp2, env);
if (auto btp2 = dynamic_pointer_cast<Bool>(btp); btp2)
return make_shared<Bool>(value && *btp2, env);
return make_shared<Bool>(*btp2, env);
throw StackTrace("Type mismatch for operands of && operator");
}
BoolPtr
Real::logical_or(const BaseTypePtr &btp) const
Real::logical_or(const ExpressionPtr &ep) const
{
if (value)
return make_shared<Bool>(true, env);
auto btp = ep->eval();
if (auto btp2 = dynamic_pointer_cast<Real>(btp); btp2)
return make_shared<Bool>(value || *btp2, env);
return make_shared<Bool>(*btp2, env);
if (auto btp2 = dynamic_pointer_cast<Bool>(btp); btp2)
return make_shared<Bool>(value || *btp2, env);
return make_shared<Bool>(*btp2, env);
throw StackTrace("Type mismatch for operands of || operator");
}
@ -859,47 +875,46 @@ BinaryOp::eval()
try
{
auto arg1bt = arg1->eval();
auto arg2bt = arg2->eval();
switch(op_code)
{
case codes::BinaryOp::plus:
return arg1bt->plus(arg2bt);
return arg1bt->plus(arg2->eval());
case codes::BinaryOp::minus:
return arg1bt->minus(arg2bt);
return arg1bt->minus(arg2->eval());
case codes::BinaryOp::times:
return arg1bt->times(arg2bt);
return arg1bt->times(arg2->eval());
case codes::BinaryOp::divide:
return arg1bt->divide(arg2bt);
return arg1bt->divide(arg2->eval());
case codes::BinaryOp::power:
return arg1bt->power( arg2bt);
return arg1bt->power( arg2->eval());
case codes::BinaryOp::equal_equal:
return arg1bt->is_equal(arg2bt);
return arg1bt->is_equal(arg2->eval());
case codes::BinaryOp::not_equal:
return arg1bt->is_different(arg2bt);
return arg1bt->is_different(arg2->eval());
case codes::BinaryOp::less:
return arg1bt->is_less(arg2bt);
return arg1bt->is_less(arg2->eval());
case codes::BinaryOp::greater:
return arg1bt->is_greater(arg2bt);
return arg1bt->is_greater(arg2->eval());
case codes::BinaryOp::less_equal:
return arg1bt->is_less_equal(arg2bt);
return arg1bt->is_less_equal(arg2->eval());
case codes::BinaryOp::greater_equal:
return arg1bt->is_greater_equal(arg2bt);
return arg1bt->is_greater_equal(arg2->eval());
case codes::BinaryOp::logical_and:
return arg1bt->logical_and(arg2bt);
return arg1bt->logical_and(arg2);
case codes::BinaryOp::logical_or:
return arg1bt->logical_or(arg2bt);
return arg1bt->logical_or(arg2);
case codes::BinaryOp::in:
return arg2bt->contains(arg1bt);
return arg2->eval()->contains(arg1bt);
case codes::BinaryOp::set_union:
return arg1bt->set_union(arg2bt);
return arg1bt->set_union(arg2->eval());
case codes::BinaryOp::set_intersection:
return arg1bt->set_intersection(arg2bt);
return arg1bt->set_intersection(arg2->eval());
case codes::BinaryOp::max:
return arg1bt->max(arg2bt);
return arg1bt->max(arg2->eval());
case codes::BinaryOp::min:
return arg1bt->min(arg2bt);
return arg1bt->min(arg2->eval());
case codes::BinaryOp::mod:
return arg1bt->mod(arg2bt);
return arg1bt->mod(arg2->eval());
}
}
catch (StackTrace &ex)

View File

@ -145,8 +145,8 @@ namespace macro
virtual BoolPtr is_greater_equal(const BaseTypePtr &btp) const { throw StackTrace("Operator >= does not exist for this type"); }
virtual BoolPtr is_equal(const BaseTypePtr &btp) const = 0;
virtual BoolPtr is_different(const BaseTypePtr &btp) const final;
virtual BoolPtr logical_and(const BaseTypePtr &btp) const { throw StackTrace("Operator && does not exist for this type"); }
virtual BoolPtr logical_or(const BaseTypePtr &btp) const { throw StackTrace("Operator || does not exist for this type"); }
virtual BoolPtr logical_and(const ExpressionPtr &ep) const { throw StackTrace("Operator && does not exist for this type"); }
virtual BoolPtr logical_or(const ExpressionPtr &ep) const { throw StackTrace("Operator || does not exist for this type"); }
virtual BoolPtr logical_not() const { throw StackTrace("Operator ! does not exist for this type"); }
virtual ArrayPtr set_union(const BaseTypePtr &btp) const { throw StackTrace("Operator | does not exist for this type"); }
virtual ArrayPtr set_intersection(const BaseTypePtr &btp) const { throw StackTrace("Operator & does not exist for this type"); }
@ -216,8 +216,8 @@ namespace macro
public:
operator bool() const { return value; }
BoolPtr is_equal(const BaseTypePtr &btp) const override;
BoolPtr logical_and(const BaseTypePtr &btp) const override;
BoolPtr logical_or(const BaseTypePtr &btp) const override;
BoolPtr logical_and(const ExpressionPtr &ep) const override;
BoolPtr logical_or(const ExpressionPtr &ep) const override;
BoolPtr logical_not() const override;
inline BoolPtr isboolean() const noexcept override { return make_shared<Bool>(true, env, location); }
inline BoolPtr cast_bool() const override { return make_shared<Bool>(value, env); }
@ -278,8 +278,8 @@ namespace macro
double intpart;
return make_shared<Bool>(modf(value, &intpart) == 0.0, env, location);
}
BoolPtr logical_and(const BaseTypePtr &btp) const override;
BoolPtr logical_or(const BaseTypePtr &btp) const override;
BoolPtr logical_and(const ExpressionPtr &ep) const override;
BoolPtr logical_or(const ExpressionPtr &ep) const override;
BoolPtr logical_not() const override;
RealPtr max(const BaseTypePtr &btp) const override;
RealPtr min(const BaseTypePtr &btp) const override;