diff options
author | Jeremy Johnson <jeremy.johnson@arm.com> | 2024-01-25 12:53:21 +0000 |
---|---|---|
committer | Eric Kunze <eric.kunze@arm.com> | 2024-01-30 16:11:26 +0000 |
commit | fc4bde92120567a98189f95cfe90bb1699d25809 (patch) | |
tree | 48ab320c288440816734fc771a817257f3a229f9 | |
parent | 95a6710ffb8cadcb8658a967ab29cac1bffad930 (diff) | |
download | reference_model-fc4bde92120567a98189f95cfe90bb1699d25809.tar.gz |
Fix up shape operator test errors
Update serialization_lib to store SHAPE as INT64.
Signed-off-by: Jeremy Johnson <jeremy.johnson@arm.com>
Change-Id: Ie589cd6670dc79b77df981c81cd7c27b982f20fa
m--------- | thirdparty/serialization_lib | 0 | ||||
-rw-r--r-- | verif/conformance/tosa_base_profile_ops_info.json | 1 | ||||
-rw-r--r-- | verif/generator/tosa_arg_gen.py | 22 | ||||
-rw-r--r-- | verif/generator/tosa_test_gen.py | 14 | ||||
-rw-r--r-- | verif/generator/tosa_utils.py | 2 |
5 files changed, 12 insertions, 27 deletions
diff --git a/thirdparty/serialization_lib b/thirdparty/serialization_lib -Subproject 5d580faec02bcef56164587accb5fd88a3e80d8 +Subproject 7c22d77a71eb59885dab1cbc3b957384c10a2af diff --git a/verif/conformance/tosa_base_profile_ops_info.json b/verif/conformance/tosa_base_profile_ops_info.json index ec51324..54bce21 100644 --- a/verif/conformance/tosa_base_profile_ops_info.json +++ b/verif/conformance/tosa_base_profile_ops_info.json @@ -1011,6 +1011,7 @@ ], "generation": { "standard": { + "no_negative_tests": "true", "generator_args": [ [ "--target-dtype", diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py index 91d2d62..386e243 100644 --- a/verif/generator/tosa_arg_gen.py +++ b/verif/generator/tosa_arg_gen.py @@ -622,28 +622,6 @@ class TosaTensorGen: return new_shapeList - @staticmethod - def tgShape(testGen, opName, rank, error_name=None): - pl, const = opName["operands"] - shape = [rank] - - # Constrict the overall size of the shape when creating ERROR_IF tests - if error_name: - shape = TosaErrorIfArgGen.eiRestrictDimensions(shape) - - shape_list = [] - for i in range(pl + const): - shape_list.append(shape.copy()) - - # Generates an input rank mismatch for operators with more than one input - if error_name == ErrorIf.RankMismatch: - if rank == 1 and i != 1: - shape = testGen.makeShape(rank + testGen.rng.choice([1, 2, 3])) - elif i != 1: - shape = testGen.makeShape(rank + testGen.rng.choice([-1, 1])) - - return shape_list - class TosaTensorValuesGen: """Tensor Value generators create the random data for each tensor in each test.""" diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py index 39b064d..a347b13 100644 --- a/verif/generator/tosa_test_gen.py +++ b/verif/generator/tosa_test_gen.py @@ -4714,9 +4714,10 @@ class TosaTestGen: "add_shape": { "op": Op.ADD_SHAPE, "operands": (2, 0), + "rank": (1, 1), "build_fcn": ( build_shape_op, - TosaTensorGen.tgShape, + TosaTensorGen.tgBasic, TosaTensorValuesGen.tvgAddSub, TosaArgGen.agNone, ), @@ -4726,9 +4727,10 @@ class TosaTestGen: "sub_shape": { "op": Op.SUB_SHAPE, "operands": (2, 0), + "rank": (1, 1), "build_fcn": ( build_shape_op, - TosaTensorGen.tgShape, + TosaTensorGen.tgBasic, TosaTensorValuesGen.tvgAddSub, TosaArgGen.agNone, ), @@ -4738,9 +4740,10 @@ class TosaTestGen: "mul_shape": { "op": Op.MUL_SHAPE, "operands": (2, 0), + "rank": (1, 1), "build_fcn": ( build_shape_op, - TosaTensorGen.tgShape, + TosaTensorGen.tgBasic, TosaTensorValuesGen.tvgMul, TosaArgGen.agNone, ), @@ -4750,9 +4753,10 @@ class TosaTestGen: "div_shape": { "op": Op.DIV_SHAPE, "operands": (2, 0), + "rank": (1, 1), "build_fcn": ( build_shape_op, - TosaTensorGen.tgShape, + TosaTensorGen.tgBasic, TosaTensorValuesGen.tvgIntDiv, TosaArgGen.agNone, ), @@ -4762,6 +4766,7 @@ class TosaTestGen: "concat_shape": { "op": Op.CONCAT_SHAPE, "operands": (2, 0), + "rank": (1, 1), "build_fcn": ( build_concat, TosaTensorGen.tgConcat, @@ -4774,6 +4779,7 @@ class TosaTestGen: "const_shape": { "op": Op.CONST_SHAPE, "operands": (0, 1), + "rank": (1, 1), "build_fcn": ( build_const, TosaTensorGen.tgBasic, diff --git a/verif/generator/tosa_utils.py b/verif/generator/tosa_utils.py index 33db95f..6387d06 100644 --- a/verif/generator/tosa_utils.py +++ b/verif/generator/tosa_utils.py @@ -56,7 +56,7 @@ def dtypeIsSupportedByCompliance(dtype): """Types supported by the new data generation and compliance flow.""" if isinstance(dtype, list) or isinstance(dtype, tuple): dtype = dtype[0] - return dtype in (DType.FP32, DType.FP16, DType.SHAPE) + return dtype in (DType.FP32, DType.FP16) def getOpNameFromOpListName(opName): |