aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJerry Ge <jerry.ge@arm.com>2024-06-18 13:26:23 -0700
committerJerry Ge <jerry.ge@arm.com>2024-06-24 16:48:25 -0700
commit1c64867fc19cbd5109c3f7a184687a38afb6cd0f (patch)
tree041ab4cb477b5dc0091eec562b530ae9954caab3
parent5e1d8b07478d06ddc3bcf1fdd7ad0e5276c1e333 (diff)
downloadreference_model-1c64867fc19cbd5109c3f7a184687a38afb6cd0f.tar.gz
Update tosa Level_check to include MAX_TENSOR_LIST_SIZE
Signed-off-by: Jerry Ge <jerry.ge@arm.com> Change-Id: Ie729c68a334431d89abb021125dc77819329eb4c
-rw-r--r--reference_model/include/func_config.h17
-rw-r--r--reference_model/src/ops/control_flow.cc16
-rw-r--r--reference_model/src/ops/custom.cc11
-rw-r--r--reference_model/src/ops/data_layout.cc2
4 files changed, 33 insertions, 13 deletions
diff --git a/reference_model/include/func_config.h b/reference_model/include/func_config.h
index 97afa82..554f8dc 100644
--- a/reference_model/include/func_config.h
+++ b/reference_model/include/func_config.h
@@ -1,5 +1,5 @@
-// Copyright (c) 2020-2023, ARM Limited.
+// Copyright (c) 2020-2024, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -21,15 +21,16 @@
struct tosa_level_t
{
- int32_t MAX_RANK = 0;
- int32_t MAX_KERNEL = 0;
- int32_t MAX_STRIDE = 0;
- int32_t MAX_SCALE = 0;
+ int32_t MAX_RANK = 0;
+ int32_t MAX_KERNEL = 0;
+ int32_t MAX_STRIDE = 0;
+ int32_t MAX_SCALE = 0;
+ int32_t MAX_TENSOR_LIST_SIZE = 0;
bool operator!=(const tosa_level_t& rhs)
{
return !(MAX_RANK == rhs.MAX_RANK && MAX_KERNEL == rhs.MAX_KERNEL && MAX_STRIDE == rhs.MAX_STRIDE &&
- MAX_SCALE == rhs.MAX_SCALE);
+ MAX_SCALE == rhs.MAX_SCALE && MAX_TENSOR_LIST_SIZE == rhs.MAX_TENSOR_LIST_SIZE);
}
};
@@ -60,8 +61,8 @@ struct func_config_t
bool float_is_big_endian = false; // Set in arith_util.h by float_is_big_endian()
tosa_level_t tosa_level;
- static constexpr tosa_level_t EIGHTK = { 6, 8192, 8192, 256 };
- static constexpr tosa_level_t NONE = { 0, 0, 0, 0 };
+ static constexpr tosa_level_t EIGHTK = { 6, 8192, 8192, 256, 64 };
+ static constexpr tosa_level_t NONE = { 0, 0, 0, 0, 0 };
};
#endif
diff --git a/reference_model/src/ops/control_flow.cc b/reference_model/src/ops/control_flow.cc
index ac09bbb..4b0553e 100644
--- a/reference_model/src/ops/control_flow.cc
+++ b/reference_model/src/ops/control_flow.cc
@@ -1,5 +1,5 @@
-// Copyright (c) 2020-2023, ARM Limited.
+// Copyright (c) 2020-2024, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -172,12 +172,17 @@ int OpCondIf::checkTensorAttributes()
{
ERROR_IF(!tsh, "OpCondIf: tosa serialization handler must not be null");
- ERROR_IF(getInputs().size() < 1, "OpCondIf: must have at least 1 operand");
+ int32_t num_inputs = getInputs().size();
+ ERROR_IF(num_inputs < 1, "OpCondIf: must have at least 1 operand");
ERROR_IF(inputs[0]->getDtype() != TOSA_REF_TYPE_BOOL || inputs[0]->getRank() != 0,
"OpCondIf: invalid tensor dtype=%s, rank=%d", EnumNameTOSAREFTYPE(inputs[0]->getDtype()),
inputs[0]->getRank());
+ auto tosa_level = g_func_config.tosa_level;
+ LEVEL_CHECK(num_inputs <= tosa_level.MAX_TENSOR_LIST_SIZE,
+ "num_inputs should be smaller than or equal to MAX_TENSOR_LIST_SIZE");
+
cond = dynamic_cast<TosaReference::Tensor0<bool>*>(inputs[0]);
ASSERT_MEM(cond);
@@ -315,7 +320,8 @@ int OpWhileLoop::checkTensorAttributes()
return 1;
}
- if (getInputs().size() <= 0)
+ int32_t num_inputs = getInputs().size();
+ if (num_inputs <= 0)
{
WARNING("OpWhileLoop: must have at least 1 operands");
return 1;
@@ -327,6 +333,10 @@ int OpWhileLoop::checkTensorAttributes()
return 1;
}
+ auto tosa_level = g_func_config.tosa_level;
+ LEVEL_CHECK(num_inputs <= tosa_level.MAX_TENSOR_LIST_SIZE,
+ "num_inputs should be smaller than or equal to MAX_TENSOR_LIST_SIZE");
+
auto cond_region = tsh->GetRegionByName(attribute->cond_graph());
auto body_region = tsh->GetRegionByName(attribute->body_graph());
if (cond_region && body_region)
diff --git a/reference_model/src/ops/custom.cc b/reference_model/src/ops/custom.cc
index 39a6f87..3773592 100644
--- a/reference_model/src/ops/custom.cc
+++ b/reference_model/src/ops/custom.cc
@@ -1,5 +1,5 @@
-// Copyright (c) 2020, 2023, ARM Limited.
+// Copyright (c) 2020, 2023-2024, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -51,8 +51,15 @@ int OpCustom::checkTensorAttributes()
int OpCustom::eval()
{
+ auto inputs = getInputs();
+ int32_t num_inputs = inputs.size();
+ auto tosa_level = g_func_config.tosa_level;
+ LEVEL_CHECK(num_inputs <= tosa_level.MAX_TENSOR_LIST_SIZE,
+ "num_inputs should be smaller than or equal to MAX_TENSOR_LIST_SIZE");
+
auto implementation_attrs_vec = attribute->implementation_attrs();
std::string implementation_attrs(implementation_attrs_vec.begin(), implementation_attrs_vec.end());
- custom_op_ptr->eval(getInputs(), getOutputs(), implementation_attrs);
+ custom_op_ptr->eval(inputs, getOutputs(), implementation_attrs);
+
return GraphNode::eval();
}
diff --git a/reference_model/src/ops/data_layout.cc b/reference_model/src/ops/data_layout.cc
index 3e3770e..3b8d13a 100644
--- a/reference_model/src/ops/data_layout.cc
+++ b/reference_model/src/ops/data_layout.cc
@@ -54,6 +54,8 @@ int OpConcat<Rank, Dtype>::checkTensorAttributes()
}
int32_t num_inputs = inputs.size();
+ LEVEL_CHECK(num_inputs <= tosa_level.MAX_TENSOR_LIST_SIZE,
+ "num_inputs should be smaller than or equal to MAX_TENSOR_LIST_SIZE");
// output and input must be the same types and rank
for (int32_t i = 0; i < num_inputs; i++)