aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src/ops/control_flow.cc
diff options
context:
space:
mode:
authorTai Ly <tai.ly@arm.com>2023-03-28 22:06:56 +0000
committerTai Ly <tai.ly@arm.com>2023-05-05 19:23:15 +0000
commita4d748b08accce06fab93e2d2b96e499b35ae89b (patch)
tree20a3957e1f45f65f35d5d67ecce1618659e388f0 /reference_model/src/ops/control_flow.cc
parent0c71686875618b2e11290273b7a05b88ef8a8aae (diff)
downloadreference_model-a4d748b08accce06fab93e2d2b96e499b35ae89b.tar.gz
[reference model] Add precise mode
This adds --precise_mode=1 option to tosa_referece_model, which will cause reference model to convert all floating point tensors to FP64 tensors and compute all operators accordingly. Also adds optional -p arguments to test runners tosa_verif_run_tests.py and tosa_verif_framework_compiler_runner.py to run tests in precise mode Signed-off-by: Tai Ly <tai.ly@arm.com> Change-Id: I156055216ad61710096497a8fa1a653be2a602a3
Diffstat (limited to 'reference_model/src/ops/control_flow.cc')
-rw-r--r--reference_model/src/ops/control_flow.cc21
1 files changed, 10 insertions, 11 deletions
diff --git a/reference_model/src/ops/control_flow.cc b/reference_model/src/ops/control_flow.cc
index f573d5b..03ad6c6 100644
--- a/reference_model/src/ops/control_flow.cc
+++ b/reference_model/src/ops/control_flow.cc
@@ -174,8 +174,8 @@ int OpCondIf::checkTensorAttributes()
{
ERROR_IF(getInputs().size() < 1, "OpCondIf: must have at least 1 operand");
- ERROR_IF(inputs[0]->getDtype() != DType_BOOL || inputs[0]->getRank() != 0,
- "OpCondIf: invalid tensor dtype=%s, rank=%d", EnumNamesDType()[inputs[0]->getDtype()],
+ ERROR_IF(inputs[0]->getDtype() != TOSA_REF_TYPE_BOOL || inputs[0]->getRank() != 0,
+ "OpCondIf: invalid tensor dtype=%s, rank=%d", EnumNameTOSAREFTYPE(inputs[0]->getDtype()),
inputs[0]->getRank());
cond = dynamic_cast<TosaReference::Tensor0<bool>*>(inputs[0]);
@@ -223,9 +223,9 @@ int OpCondIf::checkTensorAttributes()
std::string else_block_input_name = else_block->GetInputs()[i];
TosaSerializationTensor* then_block_input = then_block->GetTensorByName(then_block_input_name);
TosaSerializationTensor* else_block_input = else_block->GetTensorByName(else_block_input_name);
- ERROR_IF(operator_input->getDtype() != then_block_input->GetDtype(),
+ ERROR_IF(operator_input->getDtype() != ConvertDType(then_block_input->GetDtype()),
"OpCondIf: input tensor type mismatch with then_block input type");
- ERROR_IF(operator_input->getDtype() != else_block_input->GetDtype(),
+ ERROR_IF(operator_input->getDtype() != ConvertDType(else_block_input->GetDtype()),
"OpCondIf: input tensor type mismatch with else_block input type");
ERROR_IF(operator_input->getRank() != (int32_t)then_block_input->GetShape().size(),
"OpCondIf: input tensor rank mismatch with then_block input rank");
@@ -247,9 +247,9 @@ int OpCondIf::checkTensorAttributes()
std::string else_block_output_name = else_block->GetOutputs()[i];
TosaSerializationTensor* then_block_output = then_block->GetTensorByName(then_block_output_name);
TosaSerializationTensor* else_block_output = else_block->GetTensorByName(else_block_output_name);
- ERROR_IF(operator_output->getDtype() != then_block_output->GetDtype(),
+ ERROR_IF(operator_output->getDtype() != ConvertDType(then_block_output->GetDtype()),
"OpCondIf: output tensor type mismatch with then_block output type");
- ERROR_IF(operator_output->getDtype() != else_block_output->GetDtype(),
+ ERROR_IF(operator_output->getDtype() != ConvertDType(else_block_output->GetDtype()),
"OpCondIf: output tensor type mismatch with else_block output type");
ERROR_IF(operator_output->getRank() != (int32_t)then_block_output->GetShape().size(),
"OpCondIf: output tensor rank mismatch with then_block output rank");
@@ -364,11 +364,11 @@ int OpWhileLoop::checkTensorAttributes()
TosaSerializationTensor* body_block_input = body_block->GetTensorByName(body_block_input_name);
TosaSerializationTensor* body_block_output = body_block->GetTensorByName(body_block_output_name);
- ERROR_IF(operator_input->getDtype() != cond_block_input->GetDtype(),
+ ERROR_IF(operator_input->getDtype() != ConvertDType(cond_block_input->GetDtype()),
"OpWhileLoop: input tensor type mismatch with cond_block input type");
- ERROR_IF(operator_input->getDtype() != body_block_input->GetDtype(),
+ ERROR_IF(operator_input->getDtype() != ConvertDType(body_block_input->GetDtype()),
"OpWhileLoop: input tensor type mismatch with body_block input type");
- ERROR_IF(operator_input->getDtype() != body_block_output->GetDtype(),
+ ERROR_IF(operator_input->getDtype() != ConvertDType(body_block_output->GetDtype()),
"OpWhileLoop: input tensor type mismatch with body_block output type");
ERROR_IF(operator_input->getRank() != (int32_t)cond_block_input->GetShape().size(),
"OpWhileLoop: input tensor rank mismatch with cond_block input rank");
@@ -399,8 +399,7 @@ int OpWhileLoop::checkTensorAttributes()
int OpWhileLoop::eval()
{
-
- TosaReference::Tensor0<bool> cond_output_ctensor(std::string("cond_output"), DType_BOOL, std::vector<int32_t>({}));
+ TosaReference::Tensor0<bool> cond_output_ctensor("cond_output", DType_BOOL, std::vector<int32_t>({}));
cond_output_ctensor.allocate();
std::vector<TosaReference::Tensor*> cond_block_outputs;