aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJeremy Johnson <jeremy.johnson@arm.com>2024-01-25 12:53:21 +0000
committerEric Kunze <eric.kunze@arm.com>2024-01-30 16:11:26 +0000
commitfc4bde92120567a98189f95cfe90bb1699d25809 (patch)
tree48ab320c288440816734fc771a817257f3a229f9
parent95a6710ffb8cadcb8658a967ab29cac1bffad930 (diff)
downloadreference_model-fc4bde92120567a98189f95cfe90bb1699d25809.tar.gz
Fix up shape operator test errors
Update serialization_lib to store SHAPE as INT64. Signed-off-by: Jeremy Johnson <jeremy.johnson@arm.com> Change-Id: Ie589cd6670dc79b77df981c81cd7c27b982f20fa
m---------thirdparty/serialization_lib0
-rw-r--r--verif/conformance/tosa_base_profile_ops_info.json1
-rw-r--r--verif/generator/tosa_arg_gen.py22
-rw-r--r--verif/generator/tosa_test_gen.py14
-rw-r--r--verif/generator/tosa_utils.py2
5 files changed, 12 insertions, 27 deletions
diff --git a/thirdparty/serialization_lib b/thirdparty/serialization_lib
-Subproject 5d580faec02bcef56164587accb5fd88a3e80d8
+Subproject 7c22d77a71eb59885dab1cbc3b957384c10a2af
diff --git a/verif/conformance/tosa_base_profile_ops_info.json b/verif/conformance/tosa_base_profile_ops_info.json
index ec51324..54bce21 100644
--- a/verif/conformance/tosa_base_profile_ops_info.json
+++ b/verif/conformance/tosa_base_profile_ops_info.json
@@ -1011,6 +1011,7 @@
],
"generation": {
"standard": {
+ "no_negative_tests": "true",
"generator_args": [
[
"--target-dtype",
diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py
index 91d2d62..386e243 100644
--- a/verif/generator/tosa_arg_gen.py
+++ b/verif/generator/tosa_arg_gen.py
@@ -622,28 +622,6 @@ class TosaTensorGen:
return new_shapeList
- @staticmethod
- def tgShape(testGen, opName, rank, error_name=None):
- pl, const = opName["operands"]
- shape = [rank]
-
- # Constrict the overall size of the shape when creating ERROR_IF tests
- if error_name:
- shape = TosaErrorIfArgGen.eiRestrictDimensions(shape)
-
- shape_list = []
- for i in range(pl + const):
- shape_list.append(shape.copy())
-
- # Generates an input rank mismatch for operators with more than one input
- if error_name == ErrorIf.RankMismatch:
- if rank == 1 and i != 1:
- shape = testGen.makeShape(rank + testGen.rng.choice([1, 2, 3]))
- elif i != 1:
- shape = testGen.makeShape(rank + testGen.rng.choice([-1, 1]))
-
- return shape_list
-
class TosaTensorValuesGen:
"""Tensor Value generators create the random data for each tensor in each test."""
diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py
index 39b064d..a347b13 100644
--- a/verif/generator/tosa_test_gen.py
+++ b/verif/generator/tosa_test_gen.py
@@ -4714,9 +4714,10 @@ class TosaTestGen:
"add_shape": {
"op": Op.ADD_SHAPE,
"operands": (2, 0),
+ "rank": (1, 1),
"build_fcn": (
build_shape_op,
- TosaTensorGen.tgShape,
+ TosaTensorGen.tgBasic,
TosaTensorValuesGen.tvgAddSub,
TosaArgGen.agNone,
),
@@ -4726,9 +4727,10 @@ class TosaTestGen:
"sub_shape": {
"op": Op.SUB_SHAPE,
"operands": (2, 0),
+ "rank": (1, 1),
"build_fcn": (
build_shape_op,
- TosaTensorGen.tgShape,
+ TosaTensorGen.tgBasic,
TosaTensorValuesGen.tvgAddSub,
TosaArgGen.agNone,
),
@@ -4738,9 +4740,10 @@ class TosaTestGen:
"mul_shape": {
"op": Op.MUL_SHAPE,
"operands": (2, 0),
+ "rank": (1, 1),
"build_fcn": (
build_shape_op,
- TosaTensorGen.tgShape,
+ TosaTensorGen.tgBasic,
TosaTensorValuesGen.tvgMul,
TosaArgGen.agNone,
),
@@ -4750,9 +4753,10 @@ class TosaTestGen:
"div_shape": {
"op": Op.DIV_SHAPE,
"operands": (2, 0),
+ "rank": (1, 1),
"build_fcn": (
build_shape_op,
- TosaTensorGen.tgShape,
+ TosaTensorGen.tgBasic,
TosaTensorValuesGen.tvgIntDiv,
TosaArgGen.agNone,
),
@@ -4762,6 +4766,7 @@ class TosaTestGen:
"concat_shape": {
"op": Op.CONCAT_SHAPE,
"operands": (2, 0),
+ "rank": (1, 1),
"build_fcn": (
build_concat,
TosaTensorGen.tgConcat,
@@ -4774,6 +4779,7 @@ class TosaTestGen:
"const_shape": {
"op": Op.CONST_SHAPE,
"operands": (0, 1),
+ "rank": (1, 1),
"build_fcn": (
build_const,
TosaTensorGen.tgBasic,
diff --git a/verif/generator/tosa_utils.py b/verif/generator/tosa_utils.py
index 33db95f..6387d06 100644
--- a/verif/generator/tosa_utils.py
+++ b/verif/generator/tosa_utils.py
@@ -56,7 +56,7 @@ def dtypeIsSupportedByCompliance(dtype):
"""Types supported by the new data generation and compliance flow."""
if isinstance(dtype, list) or isinstance(dtype, tuple):
dtype = dtype[0]
- return dtype in (DType.FP32, DType.FP16, DType.SHAPE)
+ return dtype in (DType.FP32, DType.FP16)
def getOpNameFromOpListName(opName):