Julia: adapt DynamicSetAuxiliarySeries.jl for TimeDataFrame objects

fix-tolerance-parameters
Sébastien Villemot 2022-04-01 17:34:18 +02:00
parent 71edfd05e4
commit bfdcc546ec
No known key found for this signature in database
GPG Key ID: 2CECE9350ECEBE4A
3 changed files with 71 additions and 14 deletions

View File

@ -4595,7 +4595,8 @@ void
DynamicModel::writeSetAuxiliaryVariables(const string &basename, bool julia) const DynamicModel::writeSetAuxiliaryVariables(const string &basename, bool julia) const
{ {
ostringstream output_func_body; ostringstream output_func_body;
writeAuxVarRecursiveDefinitions(output_func_body, ExprNodeOutputType::matlabDseries); writeAuxVarRecursiveDefinitions(output_func_body, julia ? ExprNodeOutputType::juliaTimeDataFrame
: ExprNodeOutputType::matlabDseries);
if (output_func_body.str().empty()) if (output_func_body.str().empty())
return; return;

View File

@ -948,6 +948,19 @@ VariableNode::writeOutput(ostream &output, ExprNodeOutputType output_type,
return; return;
} }
auto juliaTimeDataFrameHelper = [&]()
{
if (lag != 0)
output << "lag(";
output << "ds." << datatree.symbol_table.getName(symb_id);
if (lag != 0)
{
if (lag != -1)
output << "," << -lag;
output << ")";
}
};
int i; int i;
switch (type) switch (type)
{ {
@ -1012,6 +1025,9 @@ VariableNode::writeOutput(ostream &output, ExprNodeOutputType output_type,
if (lag != 0) if (lag != 0)
output << LEFT_ARRAY_SUBSCRIPT(output_type) << lag << RIGHT_ARRAY_SUBSCRIPT(output_type); output << LEFT_ARRAY_SUBSCRIPT(output_type) << lag << RIGHT_ARRAY_SUBSCRIPT(output_type);
break; break;
case ExprNodeOutputType::juliaTimeDataFrame:
juliaTimeDataFrameHelper();
break;
case ExprNodeOutputType::epilogueFile: case ExprNodeOutputType::epilogueFile:
output << "ds." << datatree.symbol_table.getName(symb_id); output << "ds." << datatree.symbol_table.getName(symb_id);
output << LEFT_ARRAY_SUBSCRIPT(output_type) << "t"; output << LEFT_ARRAY_SUBSCRIPT(output_type) << "t";
@ -1073,6 +1089,9 @@ VariableNode::writeOutput(ostream &output, ExprNodeOutputType output_type,
if (lag != 0) if (lag != 0)
output << LEFT_ARRAY_SUBSCRIPT(output_type) << lag << RIGHT_ARRAY_SUBSCRIPT(output_type); output << LEFT_ARRAY_SUBSCRIPT(output_type) << lag << RIGHT_ARRAY_SUBSCRIPT(output_type);
break; break;
case ExprNodeOutputType::juliaTimeDataFrame:
juliaTimeDataFrameHelper();
break;
case ExprNodeOutputType::epilogueFile: case ExprNodeOutputType::epilogueFile:
output << "ds." << datatree.symbol_table.getName(symb_id); output << "ds." << datatree.symbol_table.getName(symb_id);
output << LEFT_ARRAY_SUBSCRIPT(output_type) << "t"; output << LEFT_ARRAY_SUBSCRIPT(output_type) << "t";
@ -1131,6 +1150,9 @@ VariableNode::writeOutput(ostream &output, ExprNodeOutputType output_type,
if (lag != 0) if (lag != 0)
output << LEFT_ARRAY_SUBSCRIPT(output_type) << lag << RIGHT_ARRAY_SUBSCRIPT(output_type); output << LEFT_ARRAY_SUBSCRIPT(output_type) << lag << RIGHT_ARRAY_SUBSCRIPT(output_type);
break; break;
case ExprNodeOutputType::juliaTimeDataFrame:
juliaTimeDataFrameHelper();
break;
case ExprNodeOutputType::epilogueFile: case ExprNodeOutputType::epilogueFile:
output << "ds." << datatree.symbol_table.getName(symb_id); output << "ds." << datatree.symbol_table.getName(symb_id);
output << LEFT_ARRAY_SUBSCRIPT(output_type) << "t"; output << LEFT_ARRAY_SUBSCRIPT(output_type) << "t";
@ -1152,7 +1174,8 @@ VariableNode::writeOutput(ostream &output, ExprNodeOutputType output_type,
output << lag; output << lag;
output << RIGHT_ARRAY_SUBSCRIPT(output_type); output << RIGHT_ARRAY_SUBSCRIPT(output_type);
} }
else if (output_type == ExprNodeOutputType::matlabDseries) else if (output_type == ExprNodeOutputType::matlabDseries
|| output_type == ExprNodeOutputType::juliaTimeDataFrame)
// Only writing dseries for epilogue_static, hence no need to check lag // Only writing dseries for epilogue_static, hence no need to check lag
output << "ds." << datatree.symbol_table.getName(symb_id); output << "ds." << datatree.symbol_table.getName(symb_id);
else else
@ -2842,6 +2865,10 @@ UnaryOpNode::writeOutput(ostream &output, ExprNodeOutputType output_type,
break; break;
} }
if (output_type == ExprNodeOutputType::juliaTimeDataFrame
&& op_code != UnaryOpcode::uminus)
output << "."; // Use vectorized form of the function
bool close_parenthesis = false; bool close_parenthesis = false;
/* Enclose argument with parentheses if: /* Enclose argument with parentheses if:
@ -4569,60 +4596,81 @@ BinaryOpNode::writeOutput(ostream &output, ExprNodeOutputType output_type,
case BinaryOpcode::times: case BinaryOpcode::times:
if (isLatexOutput(output_type)) if (isLatexOutput(output_type))
output << R"(\, )"; output << R"(\, )";
else if (output_type == ExprNodeOutputType::occbinDifferenceFile) else if (output_type == ExprNodeOutputType::occbinDifferenceFile // This file operates on vectors, see dynare#1826
output << ".*"; // This file operates on vectors, see dynare#1826 || output_type == ExprNodeOutputType::juliaTimeDataFrame)
output << ".*";
else else
output << "*"; output << "*";
break; break;
case BinaryOpcode::divide: case BinaryOpcode::divide:
if (!isLatexOutput(output_type)) if (!isLatexOutput(output_type))
{ {
if (output_type == ExprNodeOutputType::occbinDifferenceFile) if (output_type == ExprNodeOutputType::occbinDifferenceFile // This file operates on vectors, see dynare#1826
|| output_type == ExprNodeOutputType::juliaTimeDataFrame)
output << "./"; // This file operates on vectors, see dynare#1826 output << "./"; // This file operates on vectors, see dynare#1826
else else
output << "/"; output << "/";
} }
break; break;
case BinaryOpcode::power: case BinaryOpcode::power:
if (output_type == ExprNodeOutputType::occbinDifferenceFile) if (output_type == ExprNodeOutputType::occbinDifferenceFile // This file operates on vectors, see dynare#1826
|| output_type == ExprNodeOutputType::juliaTimeDataFrame)
output << ".^"; // This file operates on vectors, see dynare#1826 output << ".^"; // This file operates on vectors, see dynare#1826
else else
output << "^"; output << "^";
break; break;
case BinaryOpcode::less: case BinaryOpcode::less:
output << "<"; if (output_type == ExprNodeOutputType::juliaTimeDataFrame)
output << ".<";
else
output << "<";
break; break;
case BinaryOpcode::greater: case BinaryOpcode::greater:
output << ">"; if (output_type == ExprNodeOutputType::juliaTimeDataFrame)
output << ".>";
else
output << ">";
break; break;
case BinaryOpcode::lessEqual: case BinaryOpcode::lessEqual:
if (isLatexOutput(output_type)) if (isLatexOutput(output_type))
output << R"(\leq )"; output << R"(\leq )";
else if (output_type == ExprNodeOutputType::juliaTimeDataFrame)
output << ".<=";
else else
output << "<="; output << "<=";
break; break;
case BinaryOpcode::greaterEqual: case BinaryOpcode::greaterEqual:
if (isLatexOutput(output_type)) if (isLatexOutput(output_type))
output << R"(\geq )"; output << R"(\geq )";
else if (output_type == ExprNodeOutputType::juliaTimeDataFrame)
output << ".>=";
else else
output << ">="; output << ">=";
break; break;
case BinaryOpcode::equalEqual: case BinaryOpcode::equalEqual:
output << "=="; if (output_type == ExprNodeOutputType::juliaTimeDataFrame)
output << ".==";
else
output << "==";
break; break;
case BinaryOpcode::different: case BinaryOpcode::different:
if (isMatlabOutput(output_type)) if (isMatlabOutput(output_type))
output << "~="; output << "~=";
else else
{ {
if (isCOutput(output_type) || isJuliaOutput(output_type)) if (output_type == ExprNodeOutputType::juliaTimeDataFrame)
output << ".!=";
else if (isCOutput(output_type) || isJuliaOutput(output_type))
output << "!="; output << "!=";
else else
output << R"(\neq )"; output << R"(\neq )";
} }
break; break;
case BinaryOpcode::equal: case BinaryOpcode::equal:
output << "="; if (output_type == ExprNodeOutputType::juliaTimeDataFrame)
output << ".=";
else
output << "=";
break; break;
default: default:
; ;
@ -6029,7 +6077,10 @@ TrinaryOpNode::writeOutput(ostream &output, ExprNodeOutputType output_type,
else if (isJuliaOutput(output_type)) else if (isJuliaOutput(output_type))
{ {
// Julia API is normcdf(mu, sigma, x) ! // Julia API is normcdf(mu, sigma, x) !
output << "normcdf("; output << "normcdf";
if (output_type == ExprNodeOutputType::juliaTimeDataFrame)
output << ".";
output << "(";
arg2->writeOutput(output, output_type, temporary_terms, temporary_terms_idxs, tef_terms); arg2->writeOutput(output, output_type, temporary_terms, temporary_terms_idxs, tef_terms);
output << ","; output << ",";
arg3->writeOutput(output, output_type, temporary_terms, temporary_terms_idxs, tef_terms); arg3->writeOutput(output, output_type, temporary_terms, temporary_terms_idxs, tef_terms);
@ -6075,7 +6126,10 @@ TrinaryOpNode::writeOutput(ostream &output, ExprNodeOutputType output_type,
} }
else else
{ {
output << "normpdf("; output << "normpdf";
if (output_type == ExprNodeOutputType::juliaTimeDataFrame)
output << ".";
output << "(";
arg1->writeOutput(output, output_type, temporary_terms, temporary_terms_idxs, tef_terms); arg1->writeOutput(output, output_type, temporary_terms, temporary_terms_idxs, tef_terms);
output << ","; output << ",";
arg2->writeOutput(output, output_type, temporary_terms, temporary_terms_idxs, tef_terms); arg2->writeOutput(output, output_type, temporary_terms, temporary_terms_idxs, tef_terms);

View File

@ -97,6 +97,7 @@ enum class ExprNodeOutputType
steadyStateFile, //!< Matlab code, in the generated steady state file steadyStateFile, //!< Matlab code, in the generated steady state file
juliaSteadyStateFile, //!< Julia code, in the generated steady state file juliaSteadyStateFile, //!< Julia code, in the generated steady state file
matlabDseries, //!< Matlab code for dseries matlabDseries, //!< Matlab code for dseries
juliaTimeDataFrame, //!< Julia code for TimeDataFrame objects
epilogueFile, //!< Matlab code, in the generated epilogue file epilogueFile, //!< Matlab code, in the generated epilogue file
occbinDifferenceFile //!< MATLAB, in the generated occbin_difference file occbinDifferenceFile //!< MATLAB, in the generated occbin_difference file
}; };
@ -120,7 +121,8 @@ isJuliaOutput(ExprNodeOutputType output_type)
return output_type == ExprNodeOutputType::juliaStaticModel return output_type == ExprNodeOutputType::juliaStaticModel
|| output_type == ExprNodeOutputType::juliaDynamicModel || output_type == ExprNodeOutputType::juliaDynamicModel
|| output_type == ExprNodeOutputType::juliaDynamicSteadyStateOperator || output_type == ExprNodeOutputType::juliaDynamicSteadyStateOperator
|| output_type == ExprNodeOutputType::juliaSteadyStateFile; || output_type == ExprNodeOutputType::juliaSteadyStateFile
|| output_type == ExprNodeOutputType::juliaTimeDataFrame;
} }
inline bool inline bool