aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src/ops/control_flow.cc
diff options
context:
space:
mode:
Diffstat (limited to 'reference_model/src/ops/control_flow.cc')
-rw-r--r--reference_model/src/ops/control_flow.cc16
1 files changed, 13 insertions, 3 deletions
diff --git a/reference_model/src/ops/control_flow.cc b/reference_model/src/ops/control_flow.cc
index ac09bbb..4b0553e 100644
--- a/reference_model/src/ops/control_flow.cc
+++ b/reference_model/src/ops/control_flow.cc
@@ -1,5 +1,5 @@
-// Copyright (c) 2020-2023, ARM Limited.
+// Copyright (c) 2020-2024, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -172,12 +172,17 @@ int OpCondIf::checkTensorAttributes()
{
ERROR_IF(!tsh, "OpCondIf: tosa serialization handler must not be null");
- ERROR_IF(getInputs().size() < 1, "OpCondIf: must have at least 1 operand");
+ int32_t num_inputs = getInputs().size();
+ ERROR_IF(num_inputs < 1, "OpCondIf: must have at least 1 operand");
ERROR_IF(inputs[0]->getDtype() != TOSA_REF_TYPE_BOOL || inputs[0]->getRank() != 0,
"OpCondIf: invalid tensor dtype=%s, rank=%d", EnumNameTOSAREFTYPE(inputs[0]->getDtype()),
inputs[0]->getRank());
+ auto tosa_level = g_func_config.tosa_level;
+ LEVEL_CHECK(num_inputs <= tosa_level.MAX_TENSOR_LIST_SIZE,
+ "num_inputs should be smaller than or equal to MAX_TENSOR_LIST_SIZE");
+
cond = dynamic_cast<TosaReference::Tensor0<bool>*>(inputs[0]);
ASSERT_MEM(cond);
@@ -315,7 +320,8 @@ int OpWhileLoop::checkTensorAttributes()
return 1;
}
- if (getInputs().size() <= 0)
+ int32_t num_inputs = getInputs().size();
+ if (num_inputs <= 0)
{
WARNING("OpWhileLoop: must have at least 1 operands");
return 1;
@@ -327,6 +333,10 @@ int OpWhileLoop::checkTensorAttributes()
return 1;
}
+ auto tosa_level = g_func_config.tosa_level;
+ LEVEL_CHECK(num_inputs <= tosa_level.MAX_TENSOR_LIST_SIZE,
+ "num_inputs should be smaller than or equal to MAX_TENSOR_LIST_SIZE");
+
auto cond_region = tsh->GetRegionByName(attribute->cond_graph());
auto body_region = tsh->GetRegionByName(attribute->body_graph());
if (cond_region && body_region)