diff --git a/src/macro/Expressions.cc b/src/macro/Expressions.cc index 5fd216e4..16a4553a 100644 --- a/src/macro/Expressions.cc +++ b/src/macro/Expressions.cc @@ -39,25 +39,33 @@ Bool::is_equal(const BaseTypePtr &btp) const } BoolPtr -Bool::logical_and(const BaseTypePtr &btp) const +Bool::logical_and(const ExpressionPtr &ep) const { + if (!value) + return make_shared(false, env); + + auto btp = ep->eval(); if (auto btp2 = dynamic_pointer_cast(btp); btp2) - return make_shared(value && *btp2, env); + return make_shared(*btp2, env); if (auto btp2 = dynamic_pointer_cast(btp); btp2) - return make_shared(value && *btp2, env); + return make_shared(*btp2, env); throw StackTrace("Type mismatch for operands of && operator"); } BoolPtr -Bool::logical_or(const BaseTypePtr &btp) const +Bool::logical_or(const ExpressionPtr &ep) const { + if (value) + return make_shared(true, env); + + auto btp = ep->eval(); if (auto btp2 = dynamic_pointer_cast(btp); btp2) - return make_shared(value || *btp2, env); + return make_shared(*btp2, env); if (auto btp2 = dynamic_pointer_cast(btp); btp2) - return make_shared(value || *btp2, env); + return make_shared(*btp2, env); throw StackTrace("Type mismatch for operands of || operator"); } @@ -159,25 +167,33 @@ Real::is_equal(const BaseTypePtr &btp) const } BoolPtr -Real::logical_and(const BaseTypePtr &btp) const +Real::logical_and(const ExpressionPtr &ep) const { + if (!value) + return make_shared(false, env); + + auto btp = ep->eval(); if (auto btp2 = dynamic_pointer_cast(btp); btp2) - return make_shared(value && *btp2, env); + return make_shared(*btp2, env); if (auto btp2 = dynamic_pointer_cast(btp); btp2) - return make_shared(value && *btp2, env); + return make_shared(*btp2, env); throw StackTrace("Type mismatch for operands of && operator"); } BoolPtr -Real::logical_or(const BaseTypePtr &btp) const +Real::logical_or(const ExpressionPtr &ep) const { + if (value) + return make_shared(true, env); + + auto btp = ep->eval(); if (auto btp2 = dynamic_pointer_cast(btp); btp2) - return make_shared(value || *btp2, env); + return make_shared(*btp2, env); if (auto btp2 = dynamic_pointer_cast(btp); btp2) - return make_shared(value || *btp2, env); + return make_shared(*btp2, env); throw StackTrace("Type mismatch for operands of || operator"); } @@ -859,47 +875,46 @@ BinaryOp::eval() try { auto arg1bt = arg1->eval(); - auto arg2bt = arg2->eval(); switch(op_code) { case codes::BinaryOp::plus: - return arg1bt->plus(arg2bt); + return arg1bt->plus(arg2->eval()); case codes::BinaryOp::minus: - return arg1bt->minus(arg2bt); + return arg1bt->minus(arg2->eval()); case codes::BinaryOp::times: - return arg1bt->times(arg2bt); + return arg1bt->times(arg2->eval()); case codes::BinaryOp::divide: - return arg1bt->divide(arg2bt); + return arg1bt->divide(arg2->eval()); case codes::BinaryOp::power: - return arg1bt->power( arg2bt); + return arg1bt->power( arg2->eval()); case codes::BinaryOp::equal_equal: - return arg1bt->is_equal(arg2bt); + return arg1bt->is_equal(arg2->eval()); case codes::BinaryOp::not_equal: - return arg1bt->is_different(arg2bt); + return arg1bt->is_different(arg2->eval()); case codes::BinaryOp::less: - return arg1bt->is_less(arg2bt); + return arg1bt->is_less(arg2->eval()); case codes::BinaryOp::greater: - return arg1bt->is_greater(arg2bt); + return arg1bt->is_greater(arg2->eval()); case codes::BinaryOp::less_equal: - return arg1bt->is_less_equal(arg2bt); + return arg1bt->is_less_equal(arg2->eval()); case codes::BinaryOp::greater_equal: - return arg1bt->is_greater_equal(arg2bt); + return arg1bt->is_greater_equal(arg2->eval()); case codes::BinaryOp::logical_and: - return arg1bt->logical_and(arg2bt); + return arg1bt->logical_and(arg2); case codes::BinaryOp::logical_or: - return arg1bt->logical_or(arg2bt); + return arg1bt->logical_or(arg2); case codes::BinaryOp::in: - return arg2bt->contains(arg1bt); + return arg2->eval()->contains(arg1bt); case codes::BinaryOp::set_union: - return arg1bt->set_union(arg2bt); + return arg1bt->set_union(arg2->eval()); case codes::BinaryOp::set_intersection: - return arg1bt->set_intersection(arg2bt); + return arg1bt->set_intersection(arg2->eval()); case codes::BinaryOp::max: - return arg1bt->max(arg2bt); + return arg1bt->max(arg2->eval()); case codes::BinaryOp::min: - return arg1bt->min(arg2bt); + return arg1bt->min(arg2->eval()); case codes::BinaryOp::mod: - return arg1bt->mod(arg2bt); + return arg1bt->mod(arg2->eval()); } } catch (StackTrace &ex) diff --git a/src/macro/Expressions.hh b/src/macro/Expressions.hh index fdc3843a..668051e4 100644 --- a/src/macro/Expressions.hh +++ b/src/macro/Expressions.hh @@ -145,8 +145,8 @@ namespace macro virtual BoolPtr is_greater_equal(const BaseTypePtr &btp) const { throw StackTrace("Operator >= does not exist for this type"); } virtual BoolPtr is_equal(const BaseTypePtr &btp) const = 0; virtual BoolPtr is_different(const BaseTypePtr &btp) const final; - virtual BoolPtr logical_and(const BaseTypePtr &btp) const { throw StackTrace("Operator && does not exist for this type"); } - virtual BoolPtr logical_or(const BaseTypePtr &btp) const { throw StackTrace("Operator || does not exist for this type"); } + virtual BoolPtr logical_and(const ExpressionPtr &ep) const { throw StackTrace("Operator && does not exist for this type"); } + virtual BoolPtr logical_or(const ExpressionPtr &ep) const { throw StackTrace("Operator || does not exist for this type"); } virtual BoolPtr logical_not() const { throw StackTrace("Operator ! does not exist for this type"); } virtual ArrayPtr set_union(const BaseTypePtr &btp) const { throw StackTrace("Operator | does not exist for this type"); } virtual ArrayPtr set_intersection(const BaseTypePtr &btp) const { throw StackTrace("Operator & does not exist for this type"); } @@ -216,8 +216,8 @@ namespace macro public: operator bool() const { return value; } BoolPtr is_equal(const BaseTypePtr &btp) const override; - BoolPtr logical_and(const BaseTypePtr &btp) const override; - BoolPtr logical_or(const BaseTypePtr &btp) const override; + BoolPtr logical_and(const ExpressionPtr &ep) const override; + BoolPtr logical_or(const ExpressionPtr &ep) const override; BoolPtr logical_not() const override; inline BoolPtr isboolean() const noexcept override { return make_shared(true, env, location); } inline BoolPtr cast_bool() const override { return make_shared(value, env); } @@ -278,8 +278,8 @@ namespace macro double intpart; return make_shared(modf(value, &intpart) == 0.0, env, location); } - BoolPtr logical_and(const BaseTypePtr &btp) const override; - BoolPtr logical_or(const BaseTypePtr &btp) const override; + BoolPtr logical_and(const ExpressionPtr &ep) const override; + BoolPtr logical_or(const ExpressionPtr &ep) const override; BoolPtr logical_not() const override; RealPtr max(const BaseTypePtr &btp) const override; RealPtr min(const BaseTypePtr &btp) const override;