diff options
author | Jerry Ge <jerry.ge@arm.com> | 2023-05-23 20:59:32 +0000 |
---|---|---|
committer | Dominic Symes <dominic.symes@arm.com> | 2023-06-15 18:25:54 +0000 |
commit | 135c95544fda260e8ce622cff7835b886a97663f (patch) | |
tree | 5d46f8f48978112abff037309a827b5844ee80de /verif/generator/tosa_test_gen.py | |
parent | cb7201e173961760c042cade591afe763c949c8f (diff) | |
download | reference_model-135c95544fda260e8ce622cff7835b886a97663f.tar.gz |
Add ERROR_IF to incorrect broadcast shapes
Signed-off-by: Jerry Ge <jerry.ge@arm.com>
Change-Id: I7460ad9eed3ed5c7cec6e855a0303753ed28eb1c
Diffstat (limited to 'verif/generator/tosa_test_gen.py')
-rw-r--r-- | verif/generator/tosa_test_gen.py | 32 |
1 files changed, 32 insertions, 0 deletions
diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py index 7691fdd..66084b4 100644 --- a/verif/generator/tosa_test_gen.py +++ b/verif/generator/tosa_test_gen.py @@ -2925,6 +2925,7 @@ class TosaTestGen: TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch, + TosaErrorValidator.evBroadcastShapesMismatch, ), }, "arithmetic_right_shift": { @@ -2944,6 +2945,7 @@ class TosaTestGen: TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch, + TosaErrorValidator.evBroadcastShapesMismatch, ), }, "bitwise_and": { @@ -2963,6 +2965,7 @@ class TosaTestGen: TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch, + TosaErrorValidator.evBroadcastShapesMismatch, ), }, "bitwise_or": { @@ -2982,6 +2985,7 @@ class TosaTestGen: TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch, + TosaErrorValidator.evBroadcastShapesMismatch, ), }, "bitwise_xor": { @@ -3001,6 +3005,7 @@ class TosaTestGen: TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch, + TosaErrorValidator.evBroadcastShapesMismatch, ), }, "intdiv": { @@ -3020,6 +3025,7 @@ class TosaTestGen: TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch, + TosaErrorValidator.evBroadcastShapesMismatch, ), }, "logical_and": { @@ -3039,6 +3045,7 @@ class TosaTestGen: TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch, + TosaErrorValidator.evBroadcastShapesMismatch, ), }, "logical_left_shift": { @@ -3058,6 +3065,7 @@ class TosaTestGen: TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch, + TosaErrorValidator.evBroadcastShapesMismatch, ), }, "logical_right_shift": { @@ -3077,6 +3085,7 @@ class TosaTestGen: TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch, + TosaErrorValidator.evBroadcastShapesMismatch, ), }, "logical_or": { @@ -3096,6 +3105,7 @@ class TosaTestGen: TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch, + TosaErrorValidator.evBroadcastShapesMismatch, ), }, "logical_xor": { @@ -3115,6 +3125,7 @@ class TosaTestGen: TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch, + TosaErrorValidator.evBroadcastShapesMismatch, ), }, "maximum": { @@ -3134,6 +3145,7 @@ class TosaTestGen: TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch, + TosaErrorValidator.evBroadcastShapesMismatch, ), }, "minimum": { @@ -3153,6 +3165,7 @@ class TosaTestGen: TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch, + TosaErrorValidator.evBroadcastShapesMismatch, ), }, "mul": { @@ -3172,6 +3185,7 @@ class TosaTestGen: TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evRankMismatch, TosaErrorValidator.evDimensionMismatch, + TosaErrorValidator.evBroadcastShapesMismatch, ), }, "pow": { @@ -3191,6 +3205,7 @@ class TosaTestGen: TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch, + TosaErrorValidator.evBroadcastShapesMismatch, ), }, "sub": { @@ -3210,6 +3225,7 @@ class TosaTestGen: TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch, + TosaErrorValidator.evBroadcastShapesMismatch, ), }, "table": { @@ -3441,6 +3457,7 @@ class TosaTestGen: TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch, + TosaErrorValidator.evBroadcastShapesMismatch, ), }, # Comparison operators @@ -3461,6 +3478,7 @@ class TosaTestGen: TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch, + TosaErrorValidator.evBroadcastShapesMismatch, ), }, "greater_equal": { @@ -3480,6 +3498,7 @@ class TosaTestGen: TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch, + TosaErrorValidator.evBroadcastShapesMismatch, ), }, "greater": { @@ -3499,6 +3518,7 @@ class TosaTestGen: TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch, + TosaErrorValidator.evBroadcastShapesMismatch, ), }, # Reduction operators @@ -4078,6 +4098,10 @@ class OutputShaper: else: shape.append(a.shape[i]) + fuzz_idx = rng.integers(0, len(a.shape)) + if error_name == ErrorIf.DimensionMismatch: + shape[fuzz_idx] += 1 + if error_name == ErrorIf.WrongOutputType: all_dtypes = [ DType.INT8, @@ -4139,6 +4163,10 @@ class OutputShaper: else: shape.append(cond.shape[i]) + fuzz_idx = rng.integers(0, len(a.shape)) + if error_name == ErrorIf.DimensionMismatch: + shape[fuzz_idx] += 1 + if error_name == ErrorIf.WrongOutputType: all_dtypes = [ DType.INT8, @@ -4170,6 +4198,10 @@ class OutputShaper: else: shape.append(a.shape[i]) + fuzz_idx = rng.integers(0, len(a.shape)) + if error_name == ErrorIf.DimensionMismatch: + shape[fuzz_idx] += 1 + if error_name == ErrorIf.WrongOutputType: wrong_dtypes = [ DType.INT8, |