diff options
Diffstat (limited to 'reference_model/src/ops/control_flow.cc')
-rw-r--r-- | reference_model/src/ops/control_flow.cc | 16 |
1 files changed, 13 insertions, 3 deletions
diff --git a/reference_model/src/ops/control_flow.cc b/reference_model/src/ops/control_flow.cc index ac09bbb..4b0553e 100644 --- a/reference_model/src/ops/control_flow.cc +++ b/reference_model/src/ops/control_flow.cc @@ -1,5 +1,5 @@ -// Copyright (c) 2020-2023, ARM Limited. +// Copyright (c) 2020-2024, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -172,12 +172,17 @@ int OpCondIf::checkTensorAttributes() { ERROR_IF(!tsh, "OpCondIf: tosa serialization handler must not be null"); - ERROR_IF(getInputs().size() < 1, "OpCondIf: must have at least 1 operand"); + int32_t num_inputs = getInputs().size(); + ERROR_IF(num_inputs < 1, "OpCondIf: must have at least 1 operand"); ERROR_IF(inputs[0]->getDtype() != TOSA_REF_TYPE_BOOL || inputs[0]->getRank() != 0, "OpCondIf: invalid tensor dtype=%s, rank=%d", EnumNameTOSAREFTYPE(inputs[0]->getDtype()), inputs[0]->getRank()); + auto tosa_level = g_func_config.tosa_level; + LEVEL_CHECK(num_inputs <= tosa_level.MAX_TENSOR_LIST_SIZE, + "num_inputs should be smaller than or equal to MAX_TENSOR_LIST_SIZE"); + cond = dynamic_cast<TosaReference::Tensor0<bool>*>(inputs[0]); ASSERT_MEM(cond); @@ -315,7 +320,8 @@ int OpWhileLoop::checkTensorAttributes() return 1; } - if (getInputs().size() <= 0) + int32_t num_inputs = getInputs().size(); + if (num_inputs <= 0) { WARNING("OpWhileLoop: must have at least 1 operands"); return 1; @@ -327,6 +333,10 @@ int OpWhileLoop::checkTensorAttributes() return 1; } + auto tosa_level = g_func_config.tosa_level; + LEVEL_CHECK(num_inputs <= tosa_level.MAX_TENSOR_LIST_SIZE, + "num_inputs should be smaller than or equal to MAX_TENSOR_LIST_SIZE"); + auto cond_region = tsh->GetRegionByName(attribute->cond_graph()); auto body_region = tsh->GetRegionByName(attribute->body_graph()); if (cond_region && body_region) |