From a4e48ca7b032992ca0110900935c08d7cf860cd3 Mon Sep 17 00:00:00 2001 From: Luke Hutton Date: Wed, 22 Feb 2023 11:53:48 +0000 Subject: Update rank limits for SLICE, TILE and TRANSPOSE Updated to align with corresponding changes to the spec. In addition, some ERROR_IF tests have been updated to match the checks specified by the spec, including: PAD, SLICE, TILE, TRANSPOSE. Signed-off-by: Luke Hutton Change-Id: Ie2c5f48e79a5610eb82739170e25057a63dac1d8 --- verif/generator/tosa_arg_gen.py | 1 + verif/generator/tosa_error_if.py | 16 +++++++++----- verif/generator/tosa_test_gen.py | 48 ++++++++++++++++++++++++++++++---------- verif/generator/tosa_utils.py | 11 +++++++++ 4 files changed, 58 insertions(+), 18 deletions(-) (limited to 'verif/generator') diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py index 75ca634..9209d9c 100644 --- a/verif/generator/tosa_arg_gen.py +++ b/verif/generator/tosa_arg_gen.py @@ -176,6 +176,7 @@ class TosaTensorGen: for i in range(pl + const): shape_list.append(shape.copy()) + # Generates an input rank mismatch for operators with more than one input if error_name == ErrorIf.RankMismatch: if rank == 1 and i != 1: shape = testGen.makeShape(rank + testGen.rng.choice([1, 2, 3])) diff --git a/verif/generator/tosa_error_if.py b/verif/generator/tosa_error_if.py index ee227b3..b19d5e9 100644 --- a/verif/generator/tosa_error_if.py +++ b/verif/generator/tosa_error_if.py @@ -1067,7 +1067,9 @@ class TosaErrorValidator: if check: input1_shape = kwargs["input1"].shape - input2_shape = kwargs["input2"].shape + input2_shape = ( + kwargs["input2"].shape if "input2" in kwargs else input1_shape + ) # In case of SELECT op input3_shape = ( kwargs["input3"].shape if "input3" in kwargs else input2_shape @@ -1921,11 +1923,13 @@ class TosaErrorValidator: input_shape = kwargs["input_shape"] output_shape = kwargs["output_shape"] size = kwargs["size"] - rank = len(input_shape) - if len(size) == rank: - for index in range(rank): - if size[index] != output_shape[index]: - error_result = True + + if len(input_shape) == len(output_shape): + rank = len(input_shape) + if len(size) == rank: + for index in range(rank): + if size[index] != output_shape[index]: + error_result = True info_dict = { "error_name": error_name, diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py index a768da0..7fef942 100644 --- a/verif/generator/tosa_test_gen.py +++ b/verif/generator/tosa_test_gen.py @@ -14,6 +14,7 @@ from generator.tosa_error_if import TosaErrorIfArgGen from generator.tosa_error_if import TosaErrorValidator from generator.tosa_error_if import TosaInvalidValidator from generator.tosa_utils import DTYPE_ATTRIBUTES +from generator.tosa_utils import get_rank_mismatch_shape from generator.tosa_utils import get_wrong_output_type from generator.tosa_utils import MAX_RESIZE_DIMENSION from generator.tosa_utils import usableDTypes @@ -1263,6 +1264,7 @@ class TosaTestGen: input_list=input_list, output_list=output_list, num_operands=num_operands, + input1=a, ): return None @@ -1369,6 +1371,7 @@ class TosaTestGen: input_list=input_list, output_list=output_list, num_operands=num_operands, + input1=a, ): return None @@ -1404,6 +1407,7 @@ class TosaTestGen: input_list=input_list, output_list=output_list, num_operands=num_operands, + input1=a, ): return None @@ -1438,6 +1442,7 @@ class TosaTestGen: input_list=input_list, output_list=output_list, num_operands=num_operands, + input1=a, ): return None @@ -3657,6 +3662,8 @@ class TosaTestGen: TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, + TosaErrorValidator.evRankMismatch, + TosaErrorValidator.evWrongRank, ), }, "reshape": { @@ -3699,7 +3706,7 @@ class TosaTestGen: "slice": { "op": Op.SLICE, "operands": (1, 0), - "rank": (1, 4), + "rank": (1, 6), "build_fcn": ( build_slice, TosaTensorGen.tgBasic, @@ -3718,11 +3725,13 @@ class TosaTestGen: TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, + TosaErrorValidator.evRankMismatch, ), }, "tile": { "op": Op.TILE, "operands": (1, 0), + "rank": (1, 6), "build_fcn": ( build_tile, TosaTensorGen.tgBasic, @@ -3735,12 +3744,14 @@ class TosaTestGen: TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, + TosaErrorValidator.evRankMismatch, + TosaErrorValidator.evWrongRank, ), }, "transpose": { "op": Op.TRANSPOSE, "operands": (1, 0), - "rank": (1, 4), + "rank": (1, 6), "build_fcn": ( build_transpose, TosaTensorGen.tgBasic, @@ -3755,6 +3766,9 @@ class TosaTestGen: TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, + TosaErrorValidator.evWrongRank, + TosaErrorValidator.evRankMismatch, + TosaErrorValidator.evTensorSizeInputOutputMismatch, ), }, # Data nodes @@ -4539,6 +4553,8 @@ class OutputShaper: if error_name == ErrorIf.PadOutputShapeMismatch: bad_dim = rng.choice(range(len(output_shape))) output_shape[bad_dim] -= rng.choice([1, 2]) + elif error_name == ErrorIf.RankMismatch: + output_shape = get_rank_mismatch_shape(rng, output_shape) if error_name == ErrorIf.WrongOutputType: all_dtypes = [ @@ -4583,7 +4599,7 @@ class OutputShaper: return ser.addOutput(output_shape, outputDType) @staticmethod - def sliceOp(ser, rng, a, start, size, error_name=None): + def sliceOp(ser, rng, input, start, size, error_name=None): if error_name == ErrorIf.WrongOutputType: all_dtypes = [ @@ -4595,13 +4611,13 @@ class OutputShaper: DType.FP16, DType.BF16, ] - wrong_dtypes = list(set(all_dtypes) - set([a.dtype])) + wrong_dtypes = list(set(all_dtypes) - set([input.dtype])) outputDType = rng.choice(wrong_dtypes) else: - outputDType = a.dtype + outputDType = input.dtype + output_shape = size.copy() if error_name == ErrorIf.SizeOutputShapeMismatch: - output_shape = size.copy() for index in range(len(output_shape)): if output_shape[index] <= 2: output_shape[index] = output_shape[index] + rng.choice([1, 2]) @@ -4609,8 +4625,10 @@ class OutputShaper: output_shape[index] = output_shape[index] + rng.choice( [-2, -1, 1, 2] ) - else: - output_shape = size.copy() + elif error_name == ErrorIf.InputSizeStartLengthMismatch: + output_shape = input.shape.copy() + elif error_name == ErrorIf.RankMismatch: + output_shape = get_rank_mismatch_shape(rng, output_shape) return ser.addOutput(output_shape, outputDType) @@ -4623,6 +4641,9 @@ class OutputShaper: for i in range(len(output_shape)): output_shape[i] = a.shape[i] * multiples[i] + if error_name == ErrorIf.RankMismatch: + output_shape = get_rank_mismatch_shape(rng, output_shape) + if error_name == ErrorIf.WrongOutputType: all_dtypes = [ DType.INT8, @@ -4646,13 +4667,16 @@ class OutputShaper: assert len(perms) == len(output_shape) - if error_name == ErrorIf.IndexOutsideBounds: - for i in range(len(output_shape)): - output_shape[i] = a.shape[0] - else: + if error_name not in [ErrorIf.IndexOutsideBounds, ErrorIf.IndexUsedTwice]: for i in range(len(output_shape)): output_shape[i] = a.shape[perms[i]] + if error_name == ErrorIf.TensorSizeInputOutputMismatch: + for i in range(len(output_shape)): + output_shape[i] += rng.integers(1, 10) + elif error_name == ErrorIf.RankMismatch: + output_shape = get_rank_mismatch_shape(rng, output_shape) + if error_name == ErrorIf.WrongOutputType: all_dtypes = [ DType.INT8, diff --git a/verif/generator/tosa_utils.py b/verif/generator/tosa_utils.py index 29ae898..8ff62f1 100644 --- a/verif/generator/tosa_utils.py +++ b/verif/generator/tosa_utils.py @@ -148,6 +148,17 @@ def get_wrong_output_type(op_name, rng, input_dtype): return rng.choice(a=incorrect_types) +def get_rank_mismatch_shape(rng, output_shape): + """ + Extends the rank of the provided output_shape by + an arbitrary amount but ensures the total element + count remains the same. + """ + rank_modifier = rng.choice([1, 2, 3]) + output_shape += [1] * rank_modifier + return output_shape + + def float32_is_valid_bfloat16(f): """Return True if float value is valid bfloat16.""" f32_bits = get_float32_bitstring(f) -- cgit v1.2.1