aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJerry Ge <jerry.ge@arm.com>2023-05-01 18:36:43 +0000
committerJerry Ge <jerry.ge@arm.com>2023-05-10 02:40:49 +0000
commit0bd4ec89d52cc1fd36e92dff2fb496b3550ee7f5 (patch)
treed2662a0e62aec08a648edf61da62ee789a481080
parenta4d748b08accce06fab93e2d2b96e499b35ae89b (diff)
downloadreference_model-0bd4ec89d52cc1fd36e92dff2fb496b3550ee7f5.tar.gz
Refactor ref_model rank checking and add level check to argmax
Signed-off-by: Jerry Ge <jerry.ge@arm.com> Change-Id: Iad035b31d5e5e83040068e6311501490765bfff7
-rw-r--r--reference_model/src/graph_node.h18
-rw-r--r--reference_model/src/ops/data_layout.cc18
-rw-r--r--reference_model/src/ops/ewise_binary.cc6
-rw-r--r--reference_model/src/ops/ewise_ternary.cc7
-rw-r--r--reference_model/src/ops/ewise_unary.cc6
-rw-r--r--reference_model/src/ops/image.cc10
-rw-r--r--reference_model/src/ops/reduction.cc2
-rw-r--r--reference_model/src/ops/tensor_ops.cc26
-rw-r--r--reference_model/src/ops/type_conversion.cc6
-rw-r--r--reference_model/src/tensor.h4
-rw-r--r--verif/generator/tosa_test_gen.py2
-rw-r--r--verif/runner/tosa_refmodel_sut_run.py1
-rw-r--r--verif/tests/test_tosa_refmodel.py3
13 files changed, 41 insertions, 68 deletions
diff --git a/reference_model/src/graph_node.h b/reference_model/src/graph_node.h
index 3433192..aafc07f 100644
--- a/reference_model/src/graph_node.h
+++ b/reference_model/src/graph_node.h
@@ -270,20 +270,16 @@ protected:
int setRequiredRank(const int min, const int max = -1)
{
- if (max == -1)
- {
- requiredRankMin = requiredRankMax = min;
- }
- else
+ requiredRankMin = min;
+ requiredRankMax = max;
+
+ if (requiredRankMin >= 0 && requiredRankMax >= 0)
{
- requiredRankMin = min;
- requiredRankMax = max;
+ ASSERT_MSG(requiredRankMin <= requiredRankMax,
+ "GraphNode::setRequiredRank: requiredRankMin %d must be <= requiredRankMax %d", requiredRankMin,
+ requiredRankMax);
}
- ASSERT_MSG(requiredRankMin <= requiredRankMax,
- "GraphNode::setRequiredRank: requiredRankMin %d must be <= requiredRankMax %d", requiredRankMin,
- requiredRankMax);
-
return 0;
}
diff --git a/reference_model/src/ops/data_layout.cc b/reference_model/src/ops/data_layout.cc
index 442cef8..fd19f96 100644
--- a/reference_model/src/ops/data_layout.cc
+++ b/reference_model/src/ops/data_layout.cc
@@ -27,7 +27,7 @@ OpConcat<Rank, Dtype>::OpConcat(SubgraphTraverser* sgt_,
: GraphNode(sgt_, Op_CONCAT, id_)
{
setRequiredOperands(-1, 1);
- setRequiredRank(1, 6);
+ setRequiredRank(1);
INIT_ATTRIBUTE(Axis);
}
@@ -131,7 +131,7 @@ OpPad<Rank, Dtype>::OpPad(SubgraphTraverser* sgt_,
: GraphNode(sgt_, Op_PAD, id_)
{
setRequiredOperands(1, 1);
- setRequiredRank(1, 6);
+ setRequiredRank(1);
INIT_ATTRIBUTE(Pad);
}
@@ -221,7 +221,6 @@ OpReshape<InRank, OutRank, Dtype>::OpReshape(SubgraphTraverser* sgt_,
: GraphNode(sgt_, Op_RESHAPE, id_)
{
setRequiredOperands(1, 1);
- setRequiredRank(0, 6);
INIT_ATTRIBUTE(Reshape);
}
@@ -244,11 +243,6 @@ int OpReshape<InRank, OutRank, Dtype>::checkTensorAttributes()
if (validateRequiredOperands())
return 1;
- if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
- {
- return 1;
- }
-
// output and input must be the same types
if (inputs[0]->matchType(*outputs[0]))
{
@@ -321,7 +315,7 @@ OpReverse<Rank, Dtype>::OpReverse(SubgraphTraverser* sgt_,
: GraphNode(sgt_, Op_REVERSE, id_)
{
setRequiredOperands(1, 1);
- setRequiredRank(1, 6);
+ setRequiredRank(1);
INIT_ATTRIBUTE(Axis);
}
@@ -392,7 +386,7 @@ OpSlice<Rank, Dtype>::OpSlice(SubgraphTraverser* sgt_,
: GraphNode(sgt_, Op_SLICE, id_)
{
setRequiredOperands(1, 1);
- setRequiredRank(1, 6);
+ setRequiredRank(1);
INIT_ATTRIBUTE(Slice);
}
@@ -465,7 +459,7 @@ OpTileBase<Rank, Dtype>::OpTileBase(SubgraphTraverser* sgt_,
: GraphNode(sgt_, Op_TILE, id_)
{
setRequiredOperands(1, 1);
- setRequiredRank(1, 6);
+ setRequiredRank(1);
INIT_ATTRIBUTE(Tile);
}
@@ -667,7 +661,7 @@ OpTranspose<Rank, Dtype>::OpTranspose(SubgraphTraverser* sgt_,
: GraphNode(sgt_, Op_TRANSPOSE, id_)
{
setRequiredOperands(1, 1);
- setRequiredRank(1, 6);
+ setRequiredRank(1);
INIT_ATTRIBUTE(Transpose);
}
diff --git a/reference_model/src/ops/ewise_binary.cc b/reference_model/src/ops/ewise_binary.cc
index c5801e7..1e873e7 100644
--- a/reference_model/src/ops/ewise_binary.cc
+++ b/reference_model/src/ops/ewise_binary.cc
@@ -29,7 +29,6 @@ BinaryNodeBase<Rank, InDtype, OutDtype>::BinaryNodeBase(SubgraphTraverser* sgt_,
: GraphNode(sgt_, op_, id_)
{
setRequiredOperands(2, 1);
- setRequiredRank(0, 6);
a = b = nullptr;
result = nullptr;
@@ -51,11 +50,6 @@ int BinaryNodeBase<Rank, InDtype, OutDtype>::checkTensorAttributes()
if (validateRequiredOperands())
return 1;
- if (validateRequiredRank(inputs[0]) || validateRequiredRank(inputs[1]) || validateRequiredRank(outputs[0]))
- {
- return 1;
- }
-
// A & B must be the same rank and types
if (inputs[0]->matchRankType(*inputs[1]))
{
diff --git a/reference_model/src/ops/ewise_ternary.cc b/reference_model/src/ops/ewise_ternary.cc
index 090ce29..16554b5 100644
--- a/reference_model/src/ops/ewise_ternary.cc
+++ b/reference_model/src/ops/ewise_ternary.cc
@@ -26,7 +26,6 @@ OpSelectBase<Rank, Dtype>::OpSelectBase(SubgraphTraverser* sgt_,
: GraphNode(sgt_, Op_SELECT, id_)
{
setRequiredOperands(3, 1);
- setRequiredRank(0, 6);
}
template <int Rank, TOSA_REF_TYPE Dtype>
@@ -43,12 +42,6 @@ int OpSelectBase<Rank, Dtype>::checkTensorAttributes()
if (validateRequiredOperands())
return 1;
- if (validateRequiredRank(inputs[0]) || validateRequiredRank(inputs[1]) || validateRequiredRank(inputs[2]) ||
- validateRequiredRank(outputs[0]))
- {
- return 1;
- }
-
// output and input must be the same types
if (inputs[0]->matchRankShape(*outputs[0], true /* broadcastOk */) ||
inputs[1]->matchRankTypeShape(*outputs[0], true /* broadcastOk */) ||
diff --git a/reference_model/src/ops/ewise_unary.cc b/reference_model/src/ops/ewise_unary.cc
index 514cb84..e6e870e 100644
--- a/reference_model/src/ops/ewise_unary.cc
+++ b/reference_model/src/ops/ewise_unary.cc
@@ -27,7 +27,6 @@ UnaryNode<Rank, Dtype>::UnaryNode(SubgraphTraverser* sgt_, const Op& op_, uint64
: GraphNode(sgt_, op_, id_)
{
setRequiredOperands(1, 1);
- setRequiredRank(0, 6);
fcn = [](InEigenType a) -> OutEigenType {
ASSERT_MSG(0, "In default UnaryNode function, missing function registration");
@@ -49,11 +48,6 @@ int UnaryNode<Rank, Dtype>::checkTensorAttributes()
if (validateRequiredOperands())
return 1;
- if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
- {
- return 1;
- }
-
// output and input must be the same types
if (inputs[0]->matchRankTypeShape(*outputs[0]))
{
diff --git a/reference_model/src/ops/image.cc b/reference_model/src/ops/image.cc
index ca12cfe..575a500 100644
--- a/reference_model/src/ops/image.cc
+++ b/reference_model/src/ops/image.cc
@@ -113,11 +113,6 @@ int OpResize<InDtype, OutDtype, resize_t>::eval()
int16_t border_y = border[0];
int16_t border_x = border[1];
- // Check Tosa Level
- auto tosa_level = g_func_config.tosa_level;
- LEVEL_CHECK(scale_y_n / scale_y_d <= tosa_level.MAX_SCALE, "scale_y_n / scale_y_d should be smaller than or equal to MAX_SCALE");
- LEVEL_CHECK(scale_x_n / scale_x_d <= tosa_level.MAX_SCALE, "scale_x_n / scale_x_d should be smaller than or equal to MAX_SCALE");
-
ERROR_IF(std::max<int>({ in_height, in_width, out_height, out_width }) >= 16384,
"OpResize: exceeds maximum dimension");
ERROR_IF(in_batch != out_batch, "OpResize: output tensor batch mismatch");
@@ -137,6 +132,11 @@ int OpResize<InDtype, OutDtype, resize_t>::eval()
ERROR_IF((border_x < -16 * scale_x_n || border_x >= scale_x_n),
"OpResize: invalid attribute border width dimension");
+ // Check Tosa Level
+ auto tosa_level = g_func_config.tosa_level;
+ LEVEL_CHECK(scale_y_n / scale_y_d <= tosa_level.MAX_SCALE, "scale_y_n / scale_y_d should be smaller than or equal to MAX_SCALE");
+ LEVEL_CHECK(scale_x_n / scale_x_d <= tosa_level.MAX_SCALE, "scale_x_n / scale_x_d should be smaller than or equal to MAX_SCALE");
+
int32_t res_height = 0;
int32_t res_width = 0;
diff --git a/reference_model/src/ops/reduction.cc b/reference_model/src/ops/reduction.cc
index bf8ba57..fd48472 100644
--- a/reference_model/src/ops/reduction.cc
+++ b/reference_model/src/ops/reduction.cc
@@ -25,7 +25,7 @@ ReduceNode<Rank, Dtype>::ReduceNode(SubgraphTraverser* sgt_, const Op& op_, Tosa
: GraphNode(sgt_, op_, id_)
{
setRequiredOperands(1, 1);
- setRequiredRank(0, 4);
+ setRequiredRank(1, 4);
INIT_ATTRIBUTE(Axis);
}
diff --git a/reference_model/src/ops/tensor_ops.cc b/reference_model/src/ops/tensor_ops.cc
index f8fd323..a60819d 100644
--- a/reference_model/src/ops/tensor_ops.cc
+++ b/reference_model/src/ops/tensor_ops.cc
@@ -327,7 +327,7 @@ OpArgMax<Rank, Dtype>::OpArgMax(SubgraphTraverser* sgt_,
: GraphNode(sgt_, Op_ARGMAX, id_)
{
setRequiredOperands(1, 1);
- setRequiredRank(1, 4);
+ setRequiredRank(1);
INIT_ATTRIBUTE(Axis);
}
@@ -405,6 +405,10 @@ int OpArgMax<Rank, Dtype>::checkTensorAttributes()
template <int Rank, TOSA_REF_TYPE Dtype>
int OpArgMax<Rank, Dtype>::eval()
{
+ // Check Tosa Level
+ auto tosa_level = g_func_config.tosa_level;
+ LEVEL_CHECK(Rank <= tosa_level.MAX_RANK, "Rank should be smaller than or equal to MAX_RANK");
+
Eigen::Tensor<DenseIndex, Rank - 1> index = this->input->getTensor().argmax(attribute->axis());
this->output->getTensor() = index.unaryExpr([](DenseIndex in) -> OutEigenType { return (OutEigenType)in; });
@@ -419,7 +423,7 @@ OpAvgPool2d<Dtype, AccDtype>::OpAvgPool2d(SubgraphTraverser* sgt_,
: GraphNode(sgt_, Op_AVG_POOL2D, id_)
{
setRequiredOperands(1, 1);
- setRequiredRank(4);
+ setRequiredRank(4, 4);
INIT_ATTRIBUTE(Pool);
}
@@ -645,7 +649,7 @@ OpConv2d<InDtype, WeightDtype, OutDtype>::OpConv2d(SubgraphTraverser* sgt_,
: GraphNode(sgt_, Op_CONV2D, id_)
{
setRequiredOperands(3, 1);
- setRequiredRank(4);
+ setRequiredRank(4, 4);
INIT_ATTRIBUTE(Conv);
}
@@ -839,7 +843,7 @@ OpConv3d<InDtype, WeightDtype, OutDtype>::OpConv3d(SubgraphTraverser* sgt_,
: GraphNode(sgt_, Op_CONV3D, id_)
{
setRequiredOperands(3, 1);
- setRequiredRank(5);
+ setRequiredRank(5, 5);
INIT_ATTRIBUTE(Conv);
}
@@ -1042,7 +1046,7 @@ OpDepthwiseConv2d<InDtype, WeightDtype, OutDtype>::OpDepthwiseConv2d(SubgraphTra
: GraphNode(sgt_, Op_DEPTHWISE_CONV2D, id_)
{
setRequiredOperands(3, 1);
- setRequiredRank(4);
+ setRequiredRank(4, 4);
INIT_ATTRIBUTE(Conv);
}
@@ -1227,7 +1231,7 @@ OpFullyConnected<InDtype, WeightDtype, OutDtype>::OpFullyConnected(SubgraphTrave
: GraphNode(sgt_, Op_FULLY_CONNECTED, id_)
{
setRequiredOperands(3, 1);
- setRequiredRank(2);
+ setRequiredRank(2, 2);
INIT_ATTRIBUTE(FullyConnected);
}
@@ -1322,7 +1326,7 @@ OpMatMul<Dtype, OutDtype>::OpMatMul(SubgraphTraverser* sgt_,
: GraphNode(sgt_, Op_MATMUL, id_)
{
setRequiredOperands(2, 1);
- setRequiredRank(3);
+ setRequiredRank(3, 3);
INIT_ATTRIBUTE(MatMul);
}
@@ -1460,7 +1464,7 @@ OpMaxPool2d<Dtype>::OpMaxPool2d(SubgraphTraverser* sgt_,
: GraphNode(sgt_, Op_MAX_POOL2D, id_)
{
setRequiredOperands(1, 1);
- setRequiredRank(4);
+ setRequiredRank(4, 4);
INIT_ATTRIBUTE(Pool);
}
@@ -1601,7 +1605,7 @@ OpFFT2d<Dtype>::OpFFT2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_,
: GraphNode(sgt_, Op_FFT2D, id_)
{
setRequiredOperands(2, 2);
- setRequiredRank(3);
+ setRequiredRank(3, 3);
INIT_ATTRIBUTE(FFT);
}
@@ -1724,7 +1728,7 @@ OpRFFT2d<Dtype>::OpRFFT2d(SubgraphTraverser* sgt_,
: GraphNode(sgt_, Op_RFFT2D, id_)
{
setRequiredOperands(1, 2);
- setRequiredRank(3);
+ setRequiredRank(3, 3);
}
template <TOSA_REF_TYPE Dtype>
@@ -1830,7 +1834,7 @@ OpTransposeConv2d<InDtype, WeightDtype, OutDtype>::OpTransposeConv2d(SubgraphTra
: GraphNode(sgt_, Op_TRANSPOSE_CONV2D, id_)
{
setRequiredOperands(3, 1);
- setRequiredRank(4);
+ setRequiredRank(4, 4);
INIT_ATTRIBUTE(TransposeConv);
}
diff --git a/reference_model/src/ops/type_conversion.cc b/reference_model/src/ops/type_conversion.cc
index 68ffb1f..fce8e7c 100644
--- a/reference_model/src/ops/type_conversion.cc
+++ b/reference_model/src/ops/type_conversion.cc
@@ -31,7 +31,6 @@ OpRescale<Rank, InDtype, OutDtype>::OpRescale(SubgraphTraverser* sgt_,
: GraphNode(sgt_, Op_RESCALE, id_)
{
setRequiredOperands(1, 1);
- setRequiredRank(0, 6);
INIT_ATTRIBUTE(Rescale);
}
@@ -52,11 +51,6 @@ int OpRescale<Rank, InDtype, OutDtype>::checkTensorAttributes()
if (validateRequiredOperands())
return 1;
- if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
- {
- return 1;
- }
-
// output and input must be the same rank and size
if (inputs[0]->matchRankSize(*outputs[0]))
{
diff --git a/reference_model/src/tensor.h b/reference_model/src/tensor.h
index 08ee8bf..b68a9b6 100644
--- a/reference_model/src/tensor.h
+++ b/reference_model/src/tensor.h
@@ -197,9 +197,9 @@ public:
}
// Unary check to make sure rank matches
- const int checkRequiredRank(const int exactRank) const
+ const int checkRequiredRank(const int minRank) const
{
- return (shape.size() == (size_t)exactRank) ? 0 : 1;
+ return (shape.size() >= (size_t)minRank) ? 0 : 1;
}
const int checkRequiredRank(const int minRank, const int maxRank) const
diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py
index 65bdeb7..c8c22c2 100644
--- a/verif/generator/tosa_test_gen.py
+++ b/verif/generator/tosa_test_gen.py
@@ -2614,7 +2614,7 @@ class TosaTestGen:
"argmax": {
"op": Op.ARGMAX,
"operands": (1, 0),
- "rank": (1, 4),
+ "rank": (1, 6),
"build_fcn": (
build_argmax,
TosaTensorGen.tgBasic,
diff --git a/verif/runner/tosa_refmodel_sut_run.py b/verif/runner/tosa_refmodel_sut_run.py
index df5c0db..7b129da 100644
--- a/verif/runner/tosa_refmodel_sut_run.py
+++ b/verif/runner/tosa_refmodel_sut_run.py
@@ -34,6 +34,7 @@ class TosaSUTRunner(TosaTestRunner):
# Call Reference model with description file to provide all file details
cmd = [
args.ref_model_path,
+ "--tosa_level={}".format(args.tosa_level),
"--operator_fbs={}".format(args.operator_fbs),
"--test_desc={}".format(self.descFile),
]
diff --git a/verif/tests/test_tosa_refmodel.py b/verif/tests/test_tosa_refmodel.py
index 1f9cd3e..79e6720 100644
--- a/verif/tests/test_tosa_refmodel.py
+++ b/verif/tests/test_tosa_refmodel.py
@@ -37,6 +37,7 @@ OUTPUT_RESULT_FILE = "result_numpy_pytest.npy"
OUTPUT_CONST_GLOB = "const-*.npy"
TEST_DESC_FILENAME = "desc.json"
+TOSA_LEVEL = "EIGHTK"
# Conversion from refmodel type into the type abbreviation used in the test output
REF_MODEL_TYPE_TO_OUT = {
@@ -182,6 +183,8 @@ def test_refmodel_simple_op(tosaTest):
str(desc_file),
"--ofm_file",
OUTPUT_OFM_FILE,
+ "--tosa_level",
+ TOSA_LEVEL,
]
try:
run_sh_command(refmodel_cmd, verbose=True, capture_output=True)