diff options
-rw-r--r-- | reference_model/include/func_config.h | 17 | ||||
-rw-r--r-- | reference_model/src/ops/control_flow.cc | 16 | ||||
-rw-r--r-- | reference_model/src/ops/custom.cc | 11 | ||||
-rw-r--r-- | reference_model/src/ops/data_layout.cc | 2 |
4 files changed, 33 insertions, 13 deletions
diff --git a/reference_model/include/func_config.h b/reference_model/include/func_config.h index 97afa82..554f8dc 100644 --- a/reference_model/include/func_config.h +++ b/reference_model/include/func_config.h @@ -1,5 +1,5 @@ -// Copyright (c) 2020-2023, ARM Limited. +// Copyright (c) 2020-2024, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -21,15 +21,16 @@ struct tosa_level_t { - int32_t MAX_RANK = 0; - int32_t MAX_KERNEL = 0; - int32_t MAX_STRIDE = 0; - int32_t MAX_SCALE = 0; + int32_t MAX_RANK = 0; + int32_t MAX_KERNEL = 0; + int32_t MAX_STRIDE = 0; + int32_t MAX_SCALE = 0; + int32_t MAX_TENSOR_LIST_SIZE = 0; bool operator!=(const tosa_level_t& rhs) { return !(MAX_RANK == rhs.MAX_RANK && MAX_KERNEL == rhs.MAX_KERNEL && MAX_STRIDE == rhs.MAX_STRIDE && - MAX_SCALE == rhs.MAX_SCALE); + MAX_SCALE == rhs.MAX_SCALE && MAX_TENSOR_LIST_SIZE == rhs.MAX_TENSOR_LIST_SIZE); } }; @@ -60,8 +61,8 @@ struct func_config_t bool float_is_big_endian = false; // Set in arith_util.h by float_is_big_endian() tosa_level_t tosa_level; - static constexpr tosa_level_t EIGHTK = { 6, 8192, 8192, 256 }; - static constexpr tosa_level_t NONE = { 0, 0, 0, 0 }; + static constexpr tosa_level_t EIGHTK = { 6, 8192, 8192, 256, 64 }; + static constexpr tosa_level_t NONE = { 0, 0, 0, 0, 0 }; }; #endif diff --git a/reference_model/src/ops/control_flow.cc b/reference_model/src/ops/control_flow.cc index ac09bbb..4b0553e 100644 --- a/reference_model/src/ops/control_flow.cc +++ b/reference_model/src/ops/control_flow.cc @@ -1,5 +1,5 @@ -// Copyright (c) 2020-2023, ARM Limited. +// Copyright (c) 2020-2024, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -172,12 +172,17 @@ int OpCondIf::checkTensorAttributes() { ERROR_IF(!tsh, "OpCondIf: tosa serialization handler must not be null"); - ERROR_IF(getInputs().size() < 1, "OpCondIf: must have at least 1 operand"); + int32_t num_inputs = getInputs().size(); + ERROR_IF(num_inputs < 1, "OpCondIf: must have at least 1 operand"); ERROR_IF(inputs[0]->getDtype() != TOSA_REF_TYPE_BOOL || inputs[0]->getRank() != 0, "OpCondIf: invalid tensor dtype=%s, rank=%d", EnumNameTOSAREFTYPE(inputs[0]->getDtype()), inputs[0]->getRank()); + auto tosa_level = g_func_config.tosa_level; + LEVEL_CHECK(num_inputs <= tosa_level.MAX_TENSOR_LIST_SIZE, + "num_inputs should be smaller than or equal to MAX_TENSOR_LIST_SIZE"); + cond = dynamic_cast<TosaReference::Tensor0<bool>*>(inputs[0]); ASSERT_MEM(cond); @@ -315,7 +320,8 @@ int OpWhileLoop::checkTensorAttributes() return 1; } - if (getInputs().size() <= 0) + int32_t num_inputs = getInputs().size(); + if (num_inputs <= 0) { WARNING("OpWhileLoop: must have at least 1 operands"); return 1; @@ -327,6 +333,10 @@ int OpWhileLoop::checkTensorAttributes() return 1; } + auto tosa_level = g_func_config.tosa_level; + LEVEL_CHECK(num_inputs <= tosa_level.MAX_TENSOR_LIST_SIZE, + "num_inputs should be smaller than or equal to MAX_TENSOR_LIST_SIZE"); + auto cond_region = tsh->GetRegionByName(attribute->cond_graph()); auto body_region = tsh->GetRegionByName(attribute->body_graph()); if (cond_region && body_region) diff --git a/reference_model/src/ops/custom.cc b/reference_model/src/ops/custom.cc index 39a6f87..3773592 100644 --- a/reference_model/src/ops/custom.cc +++ b/reference_model/src/ops/custom.cc @@ -1,5 +1,5 @@ -// Copyright (c) 2020, 2023, ARM Limited. +// Copyright (c) 2020, 2023-2024, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -51,8 +51,15 @@ int OpCustom::checkTensorAttributes() int OpCustom::eval() { + auto inputs = getInputs(); + int32_t num_inputs = inputs.size(); + auto tosa_level = g_func_config.tosa_level; + LEVEL_CHECK(num_inputs <= tosa_level.MAX_TENSOR_LIST_SIZE, + "num_inputs should be smaller than or equal to MAX_TENSOR_LIST_SIZE"); + auto implementation_attrs_vec = attribute->implementation_attrs(); std::string implementation_attrs(implementation_attrs_vec.begin(), implementation_attrs_vec.end()); - custom_op_ptr->eval(getInputs(), getOutputs(), implementation_attrs); + custom_op_ptr->eval(inputs, getOutputs(), implementation_attrs); + return GraphNode::eval(); } diff --git a/reference_model/src/ops/data_layout.cc b/reference_model/src/ops/data_layout.cc index 3e3770e..3b8d13a 100644 --- a/reference_model/src/ops/data_layout.cc +++ b/reference_model/src/ops/data_layout.cc @@ -54,6 +54,8 @@ int OpConcat<Rank, Dtype>::checkTensorAttributes() } int32_t num_inputs = inputs.size(); + LEVEL_CHECK(num_inputs <= tosa_level.MAX_TENSOR_LIST_SIZE, + "num_inputs should be smaller than or equal to MAX_TENSOR_LIST_SIZE"); // output and input must be the same types and rank for (int32_t i = 0; i < num_inputs; i++) |