From fc4bde92120567a98189f95cfe90bb1699d25809 Mon Sep 17 00:00:00 2001 From: Jeremy Johnson Date: Thu, 25 Jan 2024 12:53:21 +0000 Subject: Fix up shape operator test errors Update serialization_lib to store SHAPE as INT64. Signed-off-by: Jeremy Johnson Change-Id: Ie589cd6670dc79b77df981c81cd7c27b982f20fa --- thirdparty/serialization_lib | 2 +- verif/conformance/tosa_base_profile_ops_info.json | 1 + verif/generator/tosa_arg_gen.py | 22 ---------------------- verif/generator/tosa_test_gen.py | 14 ++++++++++---- verif/generator/tosa_utils.py | 2 +- 5 files changed, 13 insertions(+), 28 deletions(-) diff --git a/thirdparty/serialization_lib b/thirdparty/serialization_lib index 5d580fa..7c22d77 160000 --- a/thirdparty/serialization_lib +++ b/thirdparty/serialization_lib @@ -1 +1 @@ -Subproject commit 5d580faec02bcef56164587accb5fd88a3e80d86 +Subproject commit 7c22d77a71eb59885dab1cbc3b957384c10a2af7 diff --git a/verif/conformance/tosa_base_profile_ops_info.json b/verif/conformance/tosa_base_profile_ops_info.json index ec51324..54bce21 100644 --- a/verif/conformance/tosa_base_profile_ops_info.json +++ b/verif/conformance/tosa_base_profile_ops_info.json @@ -1011,6 +1011,7 @@ ], "generation": { "standard": { + "no_negative_tests": "true", "generator_args": [ [ "--target-dtype", diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py index 91d2d62..386e243 100644 --- a/verif/generator/tosa_arg_gen.py +++ b/verif/generator/tosa_arg_gen.py @@ -622,28 +622,6 @@ class TosaTensorGen: return new_shapeList - @staticmethod - def tgShape(testGen, opName, rank, error_name=None): - pl, const = opName["operands"] - shape = [rank] - - # Constrict the overall size of the shape when creating ERROR_IF tests - if error_name: - shape = TosaErrorIfArgGen.eiRestrictDimensions(shape) - - shape_list = [] - for i in range(pl + const): - shape_list.append(shape.copy()) - - # Generates an input rank mismatch for operators with more than one input - if error_name == ErrorIf.RankMismatch: - if rank == 1 and i != 1: - shape = testGen.makeShape(rank + testGen.rng.choice([1, 2, 3])) - elif i != 1: - shape = testGen.makeShape(rank + testGen.rng.choice([-1, 1])) - - return shape_list - class TosaTensorValuesGen: """Tensor Value generators create the random data for each tensor in each test.""" diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py index 39b064d..a347b13 100644 --- a/verif/generator/tosa_test_gen.py +++ b/verif/generator/tosa_test_gen.py @@ -4714,9 +4714,10 @@ class TosaTestGen: "add_shape": { "op": Op.ADD_SHAPE, "operands": (2, 0), + "rank": (1, 1), "build_fcn": ( build_shape_op, - TosaTensorGen.tgShape, + TosaTensorGen.tgBasic, TosaTensorValuesGen.tvgAddSub, TosaArgGen.agNone, ), @@ -4726,9 +4727,10 @@ class TosaTestGen: "sub_shape": { "op": Op.SUB_SHAPE, "operands": (2, 0), + "rank": (1, 1), "build_fcn": ( build_shape_op, - TosaTensorGen.tgShape, + TosaTensorGen.tgBasic, TosaTensorValuesGen.tvgAddSub, TosaArgGen.agNone, ), @@ -4738,9 +4740,10 @@ class TosaTestGen: "mul_shape": { "op": Op.MUL_SHAPE, "operands": (2, 0), + "rank": (1, 1), "build_fcn": ( build_shape_op, - TosaTensorGen.tgShape, + TosaTensorGen.tgBasic, TosaTensorValuesGen.tvgMul, TosaArgGen.agNone, ), @@ -4750,9 +4753,10 @@ class TosaTestGen: "div_shape": { "op": Op.DIV_SHAPE, "operands": (2, 0), + "rank": (1, 1), "build_fcn": ( build_shape_op, - TosaTensorGen.tgShape, + TosaTensorGen.tgBasic, TosaTensorValuesGen.tvgIntDiv, TosaArgGen.agNone, ), @@ -4762,6 +4766,7 @@ class TosaTestGen: "concat_shape": { "op": Op.CONCAT_SHAPE, "operands": (2, 0), + "rank": (1, 1), "build_fcn": ( build_concat, TosaTensorGen.tgConcat, @@ -4774,6 +4779,7 @@ class TosaTestGen: "const_shape": { "op": Op.CONST_SHAPE, "operands": (0, 1), + "rank": (1, 1), "build_fcn": ( build_const, TosaTensorGen.tgBasic, diff --git a/verif/generator/tosa_utils.py b/verif/generator/tosa_utils.py index 33db95f..6387d06 100644 --- a/verif/generator/tosa_utils.py +++ b/verif/generator/tosa_utils.py @@ -56,7 +56,7 @@ def dtypeIsSupportedByCompliance(dtype): """Types supported by the new data generation and compliance flow.""" if isinstance(dtype, list) or isinstance(dtype, tuple): dtype = dtype[0] - return dtype in (DType.FP32, DType.FP16, DType.SHAPE) + return dtype in (DType.FP32, DType.FP16) def getOpNameFromOpListName(opName): -- cgit v1.2.1