diff options
author | Luke Hutton <luke.hutton@arm.com> | 2023-02-22 11:53:48 +0000 |
---|---|---|
committer | Eric Kunze <eric.kunze@arm.com> | 2023-02-28 20:08:57 +0000 |
commit | a4e48ca7b032992ca0110900935c08d7cf860cd3 (patch) | |
tree | a58c8617390225ecc107721d9b5ff87c2bdb01b0 /verif/generator/tosa_test_gen.py | |
parent | 2226f90d5a6c48a975045bc9e0419113ce764aaf (diff) | |
download | reference_model-a4e48ca7b032992ca0110900935c08d7cf860cd3.tar.gz |
Update rank limits for SLICE, TILE and TRANSPOSE
Updated to align with corresponding changes to the
spec.
In addition, some ERROR_IF tests have been updated to
match the checks specified by the spec, including:
PAD, SLICE, TILE, TRANSPOSE.
Signed-off-by: Luke Hutton <luke.hutton@arm.com>
Change-Id: Ie2c5f48e79a5610eb82739170e25057a63dac1d8
Diffstat (limited to 'verif/generator/tosa_test_gen.py')
-rw-r--r-- | verif/generator/tosa_test_gen.py | 48 |
1 files changed, 36 insertions, 12 deletions
diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py index a768da0..7fef942 100644 --- a/verif/generator/tosa_test_gen.py +++ b/verif/generator/tosa_test_gen.py @@ -14,6 +14,7 @@ from generator.tosa_error_if import TosaErrorIfArgGen from generator.tosa_error_if import TosaErrorValidator from generator.tosa_error_if import TosaInvalidValidator from generator.tosa_utils import DTYPE_ATTRIBUTES +from generator.tosa_utils import get_rank_mismatch_shape from generator.tosa_utils import get_wrong_output_type from generator.tosa_utils import MAX_RESIZE_DIMENSION from generator.tosa_utils import usableDTypes @@ -1263,6 +1264,7 @@ class TosaTestGen: input_list=input_list, output_list=output_list, num_operands=num_operands, + input1=a, ): return None @@ -1369,6 +1371,7 @@ class TosaTestGen: input_list=input_list, output_list=output_list, num_operands=num_operands, + input1=a, ): return None @@ -1404,6 +1407,7 @@ class TosaTestGen: input_list=input_list, output_list=output_list, num_operands=num_operands, + input1=a, ): return None @@ -1438,6 +1442,7 @@ class TosaTestGen: input_list=input_list, output_list=output_list, num_operands=num_operands, + input1=a, ): return None @@ -3657,6 +3662,8 @@ class TosaTestGen: TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, + TosaErrorValidator.evRankMismatch, + TosaErrorValidator.evWrongRank, ), }, "reshape": { @@ -3699,7 +3706,7 @@ class TosaTestGen: "slice": { "op": Op.SLICE, "operands": (1, 0), - "rank": (1, 4), + "rank": (1, 6), "build_fcn": ( build_slice, TosaTensorGen.tgBasic, @@ -3718,11 +3725,13 @@ class TosaTestGen: TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, + TosaErrorValidator.evRankMismatch, ), }, "tile": { "op": Op.TILE, "operands": (1, 0), + "rank": (1, 6), "build_fcn": ( build_tile, TosaTensorGen.tgBasic, @@ -3735,12 +3744,14 @@ class TosaTestGen: TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, + TosaErrorValidator.evRankMismatch, + TosaErrorValidator.evWrongRank, ), }, "transpose": { "op": Op.TRANSPOSE, "operands": (1, 0), - "rank": (1, 4), + "rank": (1, 6), "build_fcn": ( build_transpose, TosaTensorGen.tgBasic, @@ -3755,6 +3766,9 @@ class TosaTestGen: TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, + TosaErrorValidator.evWrongRank, + TosaErrorValidator.evRankMismatch, + TosaErrorValidator.evTensorSizeInputOutputMismatch, ), }, # Data nodes @@ -4539,6 +4553,8 @@ class OutputShaper: if error_name == ErrorIf.PadOutputShapeMismatch: bad_dim = rng.choice(range(len(output_shape))) output_shape[bad_dim] -= rng.choice([1, 2]) + elif error_name == ErrorIf.RankMismatch: + output_shape = get_rank_mismatch_shape(rng, output_shape) if error_name == ErrorIf.WrongOutputType: all_dtypes = [ @@ -4583,7 +4599,7 @@ class OutputShaper: return ser.addOutput(output_shape, outputDType) @staticmethod - def sliceOp(ser, rng, a, start, size, error_name=None): + def sliceOp(ser, rng, input, start, size, error_name=None): if error_name == ErrorIf.WrongOutputType: all_dtypes = [ @@ -4595,13 +4611,13 @@ class OutputShaper: DType.FP16, DType.BF16, ] - wrong_dtypes = list(set(all_dtypes) - set([a.dtype])) + wrong_dtypes = list(set(all_dtypes) - set([input.dtype])) outputDType = rng.choice(wrong_dtypes) else: - outputDType = a.dtype + outputDType = input.dtype + output_shape = size.copy() if error_name == ErrorIf.SizeOutputShapeMismatch: - output_shape = size.copy() for index in range(len(output_shape)): if output_shape[index] <= 2: output_shape[index] = output_shape[index] + rng.choice([1, 2]) @@ -4609,8 +4625,10 @@ class OutputShaper: output_shape[index] = output_shape[index] + rng.choice( [-2, -1, 1, 2] ) - else: - output_shape = size.copy() + elif error_name == ErrorIf.InputSizeStartLengthMismatch: + output_shape = input.shape.copy() + elif error_name == ErrorIf.RankMismatch: + output_shape = get_rank_mismatch_shape(rng, output_shape) return ser.addOutput(output_shape, outputDType) @@ -4623,6 +4641,9 @@ class OutputShaper: for i in range(len(output_shape)): output_shape[i] = a.shape[i] * multiples[i] + if error_name == ErrorIf.RankMismatch: + output_shape = get_rank_mismatch_shape(rng, output_shape) + if error_name == ErrorIf.WrongOutputType: all_dtypes = [ DType.INT8, @@ -4646,13 +4667,16 @@ class OutputShaper: assert len(perms) == len(output_shape) - if error_name == ErrorIf.IndexOutsideBounds: - for i in range(len(output_shape)): - output_shape[i] = a.shape[0] - else: + if error_name not in [ErrorIf.IndexOutsideBounds, ErrorIf.IndexUsedTwice]: for i in range(len(output_shape)): output_shape[i] = a.shape[perms[i]] + if error_name == ErrorIf.TensorSizeInputOutputMismatch: + for i in range(len(output_shape)): + output_shape[i] += rng.integers(1, 10) + elif error_name == ErrorIf.RankMismatch: + output_shape = get_rank_mismatch_shape(rng, output_shape) + if error_name == ErrorIf.WrongOutputType: all_dtypes = [ DType.INT8, |