diff options
Diffstat (limited to 'reference_model/src/ops/control_flow.cc')
-rw-r--r-- | reference_model/src/ops/control_flow.cc | 21 |
1 files changed, 10 insertions, 11 deletions
diff --git a/reference_model/src/ops/control_flow.cc b/reference_model/src/ops/control_flow.cc index f573d5b..03ad6c6 100644 --- a/reference_model/src/ops/control_flow.cc +++ b/reference_model/src/ops/control_flow.cc @@ -174,8 +174,8 @@ int OpCondIf::checkTensorAttributes() { ERROR_IF(getInputs().size() < 1, "OpCondIf: must have at least 1 operand"); - ERROR_IF(inputs[0]->getDtype() != DType_BOOL || inputs[0]->getRank() != 0, - "OpCondIf: invalid tensor dtype=%s, rank=%d", EnumNamesDType()[inputs[0]->getDtype()], + ERROR_IF(inputs[0]->getDtype() != TOSA_REF_TYPE_BOOL || inputs[0]->getRank() != 0, + "OpCondIf: invalid tensor dtype=%s, rank=%d", EnumNameTOSAREFTYPE(inputs[0]->getDtype()), inputs[0]->getRank()); cond = dynamic_cast<TosaReference::Tensor0<bool>*>(inputs[0]); @@ -223,9 +223,9 @@ int OpCondIf::checkTensorAttributes() std::string else_block_input_name = else_block->GetInputs()[i]; TosaSerializationTensor* then_block_input = then_block->GetTensorByName(then_block_input_name); TosaSerializationTensor* else_block_input = else_block->GetTensorByName(else_block_input_name); - ERROR_IF(operator_input->getDtype() != then_block_input->GetDtype(), + ERROR_IF(operator_input->getDtype() != ConvertDType(then_block_input->GetDtype()), "OpCondIf: input tensor type mismatch with then_block input type"); - ERROR_IF(operator_input->getDtype() != else_block_input->GetDtype(), + ERROR_IF(operator_input->getDtype() != ConvertDType(else_block_input->GetDtype()), "OpCondIf: input tensor type mismatch with else_block input type"); ERROR_IF(operator_input->getRank() != (int32_t)then_block_input->GetShape().size(), "OpCondIf: input tensor rank mismatch with then_block input rank"); @@ -247,9 +247,9 @@ int OpCondIf::checkTensorAttributes() std::string else_block_output_name = else_block->GetOutputs()[i]; TosaSerializationTensor* then_block_output = then_block->GetTensorByName(then_block_output_name); TosaSerializationTensor* else_block_output = else_block->GetTensorByName(else_block_output_name); - ERROR_IF(operator_output->getDtype() != then_block_output->GetDtype(), + ERROR_IF(operator_output->getDtype() != ConvertDType(then_block_output->GetDtype()), "OpCondIf: output tensor type mismatch with then_block output type"); - ERROR_IF(operator_output->getDtype() != else_block_output->GetDtype(), + ERROR_IF(operator_output->getDtype() != ConvertDType(else_block_output->GetDtype()), "OpCondIf: output tensor type mismatch with else_block output type"); ERROR_IF(operator_output->getRank() != (int32_t)then_block_output->GetShape().size(), "OpCondIf: output tensor rank mismatch with then_block output rank"); @@ -364,11 +364,11 @@ int OpWhileLoop::checkTensorAttributes() TosaSerializationTensor* body_block_input = body_block->GetTensorByName(body_block_input_name); TosaSerializationTensor* body_block_output = body_block->GetTensorByName(body_block_output_name); - ERROR_IF(operator_input->getDtype() != cond_block_input->GetDtype(), + ERROR_IF(operator_input->getDtype() != ConvertDType(cond_block_input->GetDtype()), "OpWhileLoop: input tensor type mismatch with cond_block input type"); - ERROR_IF(operator_input->getDtype() != body_block_input->GetDtype(), + ERROR_IF(operator_input->getDtype() != ConvertDType(body_block_input->GetDtype()), "OpWhileLoop: input tensor type mismatch with body_block input type"); - ERROR_IF(operator_input->getDtype() != body_block_output->GetDtype(), + ERROR_IF(operator_input->getDtype() != ConvertDType(body_block_output->GetDtype()), "OpWhileLoop: input tensor type mismatch with body_block output type"); ERROR_IF(operator_input->getRank() != (int32_t)cond_block_input->GetShape().size(), "OpWhileLoop: input tensor rank mismatch with cond_block input rank"); @@ -399,8 +399,7 @@ int OpWhileLoop::checkTensorAttributes() int OpWhileLoop::eval() { - - TosaReference::Tensor0<bool> cond_output_ctensor(std::string("cond_output"), DType_BOOL, std::vector<int32_t>({})); + TosaReference::Tensor0<bool> cond_output_ctensor("cond_output", DType_BOOL, std::vector<int32_t>({})); cond_output_ctensor.allocate(); std::vector<TosaReference::Tensor*> cond_block_outputs; |