aboutsummaryrefslogtreecommitdiff
path: root/verif/generator/tosa_test_gen.py
diff options
context:
space:
mode:
Diffstat (limited to 'verif/generator/tosa_test_gen.py')
-rw-r--r--verif/generator/tosa_test_gen.py32
1 files changed, 32 insertions, 0 deletions
diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py
index 7691fdd..66084b4 100644
--- a/verif/generator/tosa_test_gen.py
+++ b/verif/generator/tosa_test_gen.py
@@ -2925,6 +2925,7 @@ class TosaTestGen:
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
TosaErrorValidator.evDimensionMismatch,
+ TosaErrorValidator.evBroadcastShapesMismatch,
),
},
"arithmetic_right_shift": {
@@ -2944,6 +2945,7 @@ class TosaTestGen:
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
TosaErrorValidator.evDimensionMismatch,
+ TosaErrorValidator.evBroadcastShapesMismatch,
),
},
"bitwise_and": {
@@ -2963,6 +2965,7 @@ class TosaTestGen:
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
TosaErrorValidator.evDimensionMismatch,
+ TosaErrorValidator.evBroadcastShapesMismatch,
),
},
"bitwise_or": {
@@ -2982,6 +2985,7 @@ class TosaTestGen:
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
TosaErrorValidator.evDimensionMismatch,
+ TosaErrorValidator.evBroadcastShapesMismatch,
),
},
"bitwise_xor": {
@@ -3001,6 +3005,7 @@ class TosaTestGen:
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
TosaErrorValidator.evDimensionMismatch,
+ TosaErrorValidator.evBroadcastShapesMismatch,
),
},
"intdiv": {
@@ -3020,6 +3025,7 @@ class TosaTestGen:
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
TosaErrorValidator.evDimensionMismatch,
+ TosaErrorValidator.evBroadcastShapesMismatch,
),
},
"logical_and": {
@@ -3039,6 +3045,7 @@ class TosaTestGen:
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
TosaErrorValidator.evDimensionMismatch,
+ TosaErrorValidator.evBroadcastShapesMismatch,
),
},
"logical_left_shift": {
@@ -3058,6 +3065,7 @@ class TosaTestGen:
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
TosaErrorValidator.evDimensionMismatch,
+ TosaErrorValidator.evBroadcastShapesMismatch,
),
},
"logical_right_shift": {
@@ -3077,6 +3085,7 @@ class TosaTestGen:
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
TosaErrorValidator.evDimensionMismatch,
+ TosaErrorValidator.evBroadcastShapesMismatch,
),
},
"logical_or": {
@@ -3096,6 +3105,7 @@ class TosaTestGen:
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
TosaErrorValidator.evDimensionMismatch,
+ TosaErrorValidator.evBroadcastShapesMismatch,
),
},
"logical_xor": {
@@ -3115,6 +3125,7 @@ class TosaTestGen:
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
TosaErrorValidator.evDimensionMismatch,
+ TosaErrorValidator.evBroadcastShapesMismatch,
),
},
"maximum": {
@@ -3134,6 +3145,7 @@ class TosaTestGen:
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
TosaErrorValidator.evDimensionMismatch,
+ TosaErrorValidator.evBroadcastShapesMismatch,
),
},
"minimum": {
@@ -3153,6 +3165,7 @@ class TosaTestGen:
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
TosaErrorValidator.evDimensionMismatch,
+ TosaErrorValidator.evBroadcastShapesMismatch,
),
},
"mul": {
@@ -3172,6 +3185,7 @@ class TosaTestGen:
TosaErrorValidator.evWrongOutputList,
TosaErrorValidator.evRankMismatch,
TosaErrorValidator.evDimensionMismatch,
+ TosaErrorValidator.evBroadcastShapesMismatch,
),
},
"pow": {
@@ -3191,6 +3205,7 @@ class TosaTestGen:
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
TosaErrorValidator.evDimensionMismatch,
+ TosaErrorValidator.evBroadcastShapesMismatch,
),
},
"sub": {
@@ -3210,6 +3225,7 @@ class TosaTestGen:
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
TosaErrorValidator.evDimensionMismatch,
+ TosaErrorValidator.evBroadcastShapesMismatch,
),
},
"table": {
@@ -3441,6 +3457,7 @@ class TosaTestGen:
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
TosaErrorValidator.evDimensionMismatch,
+ TosaErrorValidator.evBroadcastShapesMismatch,
),
},
# Comparison operators
@@ -3461,6 +3478,7 @@ class TosaTestGen:
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
TosaErrorValidator.evDimensionMismatch,
+ TosaErrorValidator.evBroadcastShapesMismatch,
),
},
"greater_equal": {
@@ -3480,6 +3498,7 @@ class TosaTestGen:
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
TosaErrorValidator.evDimensionMismatch,
+ TosaErrorValidator.evBroadcastShapesMismatch,
),
},
"greater": {
@@ -3499,6 +3518,7 @@ class TosaTestGen:
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
TosaErrorValidator.evDimensionMismatch,
+ TosaErrorValidator.evBroadcastShapesMismatch,
),
},
# Reduction operators
@@ -4078,6 +4098,10 @@ class OutputShaper:
else:
shape.append(a.shape[i])
+ fuzz_idx = rng.integers(0, len(a.shape))
+ if error_name == ErrorIf.DimensionMismatch:
+ shape[fuzz_idx] += 1
+
if error_name == ErrorIf.WrongOutputType:
all_dtypes = [
DType.INT8,
@@ -4139,6 +4163,10 @@ class OutputShaper:
else:
shape.append(cond.shape[i])
+ fuzz_idx = rng.integers(0, len(a.shape))
+ if error_name == ErrorIf.DimensionMismatch:
+ shape[fuzz_idx] += 1
+
if error_name == ErrorIf.WrongOutputType:
all_dtypes = [
DType.INT8,
@@ -4170,6 +4198,10 @@ class OutputShaper:
else:
shape.append(a.shape[i])
+ fuzz_idx = rng.integers(0, len(a.shape))
+ if error_name == ErrorIf.DimensionMismatch:
+ shape[fuzz_idx] += 1
+
if error_name == ErrorIf.WrongOutputType:
wrong_dtypes = [
DType.INT8,