aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src/ops/control_flow.cc
diff options
context:
space:
mode:
Diffstat (limited to 'reference_model/src/ops/control_flow.cc')
-rw-r--r--reference_model/src/ops/control_flow.cc21
1 files changed, 10 insertions, 11 deletions
diff --git a/reference_model/src/ops/control_flow.cc b/reference_model/src/ops/control_flow.cc
index f573d5b..03ad6c6 100644
--- a/reference_model/src/ops/control_flow.cc
+++ b/reference_model/src/ops/control_flow.cc
@@ -174,8 +174,8 @@ int OpCondIf::checkTensorAttributes()
{
ERROR_IF(getInputs().size() < 1, "OpCondIf: must have at least 1 operand");
- ERROR_IF(inputs[0]->getDtype() != DType_BOOL || inputs[0]->getRank() != 0,
- "OpCondIf: invalid tensor dtype=%s, rank=%d", EnumNamesDType()[inputs[0]->getDtype()],
+ ERROR_IF(inputs[0]->getDtype() != TOSA_REF_TYPE_BOOL || inputs[0]->getRank() != 0,
+ "OpCondIf: invalid tensor dtype=%s, rank=%d", EnumNameTOSAREFTYPE(inputs[0]->getDtype()),
inputs[0]->getRank());
cond = dynamic_cast<TosaReference::Tensor0<bool>*>(inputs[0]);
@@ -223,9 +223,9 @@ int OpCondIf::checkTensorAttributes()
std::string else_block_input_name = else_block->GetInputs()[i];
TosaSerializationTensor* then_block_input = then_block->GetTensorByName(then_block_input_name);
TosaSerializationTensor* else_block_input = else_block->GetTensorByName(else_block_input_name);
- ERROR_IF(operator_input->getDtype() != then_block_input->GetDtype(),
+ ERROR_IF(operator_input->getDtype() != ConvertDType(then_block_input->GetDtype()),
"OpCondIf: input tensor type mismatch with then_block input type");
- ERROR_IF(operator_input->getDtype() != else_block_input->GetDtype(),
+ ERROR_IF(operator_input->getDtype() != ConvertDType(else_block_input->GetDtype()),
"OpCondIf: input tensor type mismatch with else_block input type");
ERROR_IF(operator_input->getRank() != (int32_t)then_block_input->GetShape().size(),
"OpCondIf: input tensor rank mismatch with then_block input rank");
@@ -247,9 +247,9 @@ int OpCondIf::checkTensorAttributes()
std::string else_block_output_name = else_block->GetOutputs()[i];
TosaSerializationTensor* then_block_output = then_block->GetTensorByName(then_block_output_name);
TosaSerializationTensor* else_block_output = else_block->GetTensorByName(else_block_output_name);
- ERROR_IF(operator_output->getDtype() != then_block_output->GetDtype(),
+ ERROR_IF(operator_output->getDtype() != ConvertDType(then_block_output->GetDtype()),
"OpCondIf: output tensor type mismatch with then_block output type");
- ERROR_IF(operator_output->getDtype() != else_block_output->GetDtype(),
+ ERROR_IF(operator_output->getDtype() != ConvertDType(else_block_output->GetDtype()),
"OpCondIf: output tensor type mismatch with else_block output type");
ERROR_IF(operator_output->getRank() != (int32_t)then_block_output->GetShape().size(),
"OpCondIf: output tensor rank mismatch with then_block output rank");
@@ -364,11 +364,11 @@ int OpWhileLoop::checkTensorAttributes()
TosaSerializationTensor* body_block_input = body_block->GetTensorByName(body_block_input_name);
TosaSerializationTensor* body_block_output = body_block->GetTensorByName(body_block_output_name);
- ERROR_IF(operator_input->getDtype() != cond_block_input->GetDtype(),
+ ERROR_IF(operator_input->getDtype() != ConvertDType(cond_block_input->GetDtype()),
"OpWhileLoop: input tensor type mismatch with cond_block input type");
- ERROR_IF(operator_input->getDtype() != body_block_input->GetDtype(),
+ ERROR_IF(operator_input->getDtype() != ConvertDType(body_block_input->GetDtype()),
"OpWhileLoop: input tensor type mismatch with body_block input type");
- ERROR_IF(operator_input->getDtype() != body_block_output->GetDtype(),
+ ERROR_IF(operator_input->getDtype() != ConvertDType(body_block_output->GetDtype()),
"OpWhileLoop: input tensor type mismatch with body_block output type");
ERROR_IF(operator_input->getRank() != (int32_t)cond_block_input->GetShape().size(),
"OpWhileLoop: input tensor rank mismatch with cond_block input rank");
@@ -399,8 +399,7 @@ int OpWhileLoop::checkTensorAttributes()
int OpWhileLoop::eval()
{
-
- TosaReference::Tensor0<bool> cond_output_ctensor(std::string("cond_output"), DType_BOOL, std::vector<int32_t>({}));
+ TosaReference::Tensor0<bool> cond_output_ctensor("cond_output", DType_BOOL, std::vector<int32_t>({}));
cond_output_ctensor.allocate();
std::vector<TosaReference::Tensor*> cond_block_outputs;