diff options
Diffstat (limited to 'verif/generator/tosa_test_gen.py')
-rw-r--r-- | verif/generator/tosa_test_gen.py | 32 |
1 files changed, 32 insertions, 0 deletions
diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py index 7691fdd..66084b4 100644 --- a/verif/generator/tosa_test_gen.py +++ b/verif/generator/tosa_test_gen.py @@ -2925,6 +2925,7 @@ class TosaTestGen: TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch, + TosaErrorValidator.evBroadcastShapesMismatch, ), }, "arithmetic_right_shift": { @@ -2944,6 +2945,7 @@ class TosaTestGen: TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch, + TosaErrorValidator.evBroadcastShapesMismatch, ), }, "bitwise_and": { @@ -2963,6 +2965,7 @@ class TosaTestGen: TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch, + TosaErrorValidator.evBroadcastShapesMismatch, ), }, "bitwise_or": { @@ -2982,6 +2985,7 @@ class TosaTestGen: TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch, + TosaErrorValidator.evBroadcastShapesMismatch, ), }, "bitwise_xor": { @@ -3001,6 +3005,7 @@ class TosaTestGen: TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch, + TosaErrorValidator.evBroadcastShapesMismatch, ), }, "intdiv": { @@ -3020,6 +3025,7 @@ class TosaTestGen: TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch, + TosaErrorValidator.evBroadcastShapesMismatch, ), }, "logical_and": { @@ -3039,6 +3045,7 @@ class TosaTestGen: TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch, + TosaErrorValidator.evBroadcastShapesMismatch, ), }, "logical_left_shift": { @@ -3058,6 +3065,7 @@ class TosaTestGen: TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch, + TosaErrorValidator.evBroadcastShapesMismatch, ), }, "logical_right_shift": { @@ -3077,6 +3085,7 @@ class TosaTestGen: TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch, + TosaErrorValidator.evBroadcastShapesMismatch, ), }, "logical_or": { @@ -3096,6 +3105,7 @@ class TosaTestGen: TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch, + TosaErrorValidator.evBroadcastShapesMismatch, ), }, "logical_xor": { @@ -3115,6 +3125,7 @@ class TosaTestGen: TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch, + TosaErrorValidator.evBroadcastShapesMismatch, ), }, "maximum": { @@ -3134,6 +3145,7 @@ class TosaTestGen: TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch, + TosaErrorValidator.evBroadcastShapesMismatch, ), }, "minimum": { @@ -3153,6 +3165,7 @@ class TosaTestGen: TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch, + TosaErrorValidator.evBroadcastShapesMismatch, ), }, "mul": { @@ -3172,6 +3185,7 @@ class TosaTestGen: TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evRankMismatch, TosaErrorValidator.evDimensionMismatch, + TosaErrorValidator.evBroadcastShapesMismatch, ), }, "pow": { @@ -3191,6 +3205,7 @@ class TosaTestGen: TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch, + TosaErrorValidator.evBroadcastShapesMismatch, ), }, "sub": { @@ -3210,6 +3225,7 @@ class TosaTestGen: TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch, + TosaErrorValidator.evBroadcastShapesMismatch, ), }, "table": { @@ -3441,6 +3457,7 @@ class TosaTestGen: TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch, + TosaErrorValidator.evBroadcastShapesMismatch, ), }, # Comparison operators @@ -3461,6 +3478,7 @@ class TosaTestGen: TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch, + TosaErrorValidator.evBroadcastShapesMismatch, ), }, "greater_equal": { @@ -3480,6 +3498,7 @@ class TosaTestGen: TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch, + TosaErrorValidator.evBroadcastShapesMismatch, ), }, "greater": { @@ -3499,6 +3518,7 @@ class TosaTestGen: TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch, + TosaErrorValidator.evBroadcastShapesMismatch, ), }, # Reduction operators @@ -4078,6 +4098,10 @@ class OutputShaper: else: shape.append(a.shape[i]) + fuzz_idx = rng.integers(0, len(a.shape)) + if error_name == ErrorIf.DimensionMismatch: + shape[fuzz_idx] += 1 + if error_name == ErrorIf.WrongOutputType: all_dtypes = [ DType.INT8, @@ -4139,6 +4163,10 @@ class OutputShaper: else: shape.append(cond.shape[i]) + fuzz_idx = rng.integers(0, len(a.shape)) + if error_name == ErrorIf.DimensionMismatch: + shape[fuzz_idx] += 1 + if error_name == ErrorIf.WrongOutputType: all_dtypes = [ DType.INT8, @@ -4170,6 +4198,10 @@ class OutputShaper: else: shape.append(a.shape[i]) + fuzz_idx = rng.integers(0, len(a.shape)) + if error_name == ErrorIf.DimensionMismatch: + shape[fuzz_idx] += 1 + if error_name == ErrorIf.WrongOutputType: wrong_dtypes = [ DType.INT8, |