aboutsummaryrefslogtreecommitdiff
path: root/verif/generator/tosa_test_gen.py
diff options
context:
space:
mode:
authorJerry Ge <jerry.ge@arm.com>2023-05-23 20:59:32 +0000
committerDominic Symes <dominic.symes@arm.com>2023-06-15 18:25:54 +0000
commit135c95544fda260e8ce622cff7835b886a97663f (patch)
tree5d46f8f48978112abff037309a827b5844ee80de /verif/generator/tosa_test_gen.py
parentcb7201e173961760c042cade591afe763c949c8f (diff)
downloadreference_model-135c95544fda260e8ce622cff7835b886a97663f.tar.gz
Add ERROR_IF to incorrect broadcast shapes
Signed-off-by: Jerry Ge <jerry.ge@arm.com> Change-Id: I7460ad9eed3ed5c7cec6e855a0303753ed28eb1c
Diffstat (limited to 'verif/generator/tosa_test_gen.py')
-rw-r--r--verif/generator/tosa_test_gen.py32
1 files changed, 32 insertions, 0 deletions
diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py
index 7691fdd..66084b4 100644
--- a/verif/generator/tosa_test_gen.py
+++ b/verif/generator/tosa_test_gen.py
@@ -2925,6 +2925,7 @@ class TosaTestGen:
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
TosaErrorValidator.evDimensionMismatch,
+ TosaErrorValidator.evBroadcastShapesMismatch,
),
},
"arithmetic_right_shift": {
@@ -2944,6 +2945,7 @@ class TosaTestGen:
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
TosaErrorValidator.evDimensionMismatch,
+ TosaErrorValidator.evBroadcastShapesMismatch,
),
},
"bitwise_and": {
@@ -2963,6 +2965,7 @@ class TosaTestGen:
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
TosaErrorValidator.evDimensionMismatch,
+ TosaErrorValidator.evBroadcastShapesMismatch,
),
},
"bitwise_or": {
@@ -2982,6 +2985,7 @@ class TosaTestGen:
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
TosaErrorValidator.evDimensionMismatch,
+ TosaErrorValidator.evBroadcastShapesMismatch,
),
},
"bitwise_xor": {
@@ -3001,6 +3005,7 @@ class TosaTestGen:
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
TosaErrorValidator.evDimensionMismatch,
+ TosaErrorValidator.evBroadcastShapesMismatch,
),
},
"intdiv": {
@@ -3020,6 +3025,7 @@ class TosaTestGen:
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
TosaErrorValidator.evDimensionMismatch,
+ TosaErrorValidator.evBroadcastShapesMismatch,
),
},
"logical_and": {
@@ -3039,6 +3045,7 @@ class TosaTestGen:
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
TosaErrorValidator.evDimensionMismatch,
+ TosaErrorValidator.evBroadcastShapesMismatch,
),
},
"logical_left_shift": {
@@ -3058,6 +3065,7 @@ class TosaTestGen:
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
TosaErrorValidator.evDimensionMismatch,
+ TosaErrorValidator.evBroadcastShapesMismatch,
),
},
"logical_right_shift": {
@@ -3077,6 +3085,7 @@ class TosaTestGen:
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
TosaErrorValidator.evDimensionMismatch,
+ TosaErrorValidator.evBroadcastShapesMismatch,
),
},
"logical_or": {
@@ -3096,6 +3105,7 @@ class TosaTestGen:
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
TosaErrorValidator.evDimensionMismatch,
+ TosaErrorValidator.evBroadcastShapesMismatch,
),
},
"logical_xor": {
@@ -3115,6 +3125,7 @@ class TosaTestGen:
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
TosaErrorValidator.evDimensionMismatch,
+ TosaErrorValidator.evBroadcastShapesMismatch,
),
},
"maximum": {
@@ -3134,6 +3145,7 @@ class TosaTestGen:
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
TosaErrorValidator.evDimensionMismatch,
+ TosaErrorValidator.evBroadcastShapesMismatch,
),
},
"minimum": {
@@ -3153,6 +3165,7 @@ class TosaTestGen:
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
TosaErrorValidator.evDimensionMismatch,
+ TosaErrorValidator.evBroadcastShapesMismatch,
),
},
"mul": {
@@ -3172,6 +3185,7 @@ class TosaTestGen:
TosaErrorValidator.evWrongOutputList,
TosaErrorValidator.evRankMismatch,
TosaErrorValidator.evDimensionMismatch,
+ TosaErrorValidator.evBroadcastShapesMismatch,
),
},
"pow": {
@@ -3191,6 +3205,7 @@ class TosaTestGen:
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
TosaErrorValidator.evDimensionMismatch,
+ TosaErrorValidator.evBroadcastShapesMismatch,
),
},
"sub": {
@@ -3210,6 +3225,7 @@ class TosaTestGen:
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
TosaErrorValidator.evDimensionMismatch,
+ TosaErrorValidator.evBroadcastShapesMismatch,
),
},
"table": {
@@ -3441,6 +3457,7 @@ class TosaTestGen:
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
TosaErrorValidator.evDimensionMismatch,
+ TosaErrorValidator.evBroadcastShapesMismatch,
),
},
# Comparison operators
@@ -3461,6 +3478,7 @@ class TosaTestGen:
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
TosaErrorValidator.evDimensionMismatch,
+ TosaErrorValidator.evBroadcastShapesMismatch,
),
},
"greater_equal": {
@@ -3480,6 +3498,7 @@ class TosaTestGen:
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
TosaErrorValidator.evDimensionMismatch,
+ TosaErrorValidator.evBroadcastShapesMismatch,
),
},
"greater": {
@@ -3499,6 +3518,7 @@ class TosaTestGen:
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
TosaErrorValidator.evDimensionMismatch,
+ TosaErrorValidator.evBroadcastShapesMismatch,
),
},
# Reduction operators
@@ -4078,6 +4098,10 @@ class OutputShaper:
else:
shape.append(a.shape[i])
+ fuzz_idx = rng.integers(0, len(a.shape))
+ if error_name == ErrorIf.DimensionMismatch:
+ shape[fuzz_idx] += 1
+
if error_name == ErrorIf.WrongOutputType:
all_dtypes = [
DType.INT8,
@@ -4139,6 +4163,10 @@ class OutputShaper:
else:
shape.append(cond.shape[i])
+ fuzz_idx = rng.integers(0, len(a.shape))
+ if error_name == ErrorIf.DimensionMismatch:
+ shape[fuzz_idx] += 1
+
if error_name == ErrorIf.WrongOutputType:
all_dtypes = [
DType.INT8,
@@ -4170,6 +4198,10 @@ class OutputShaper:
else:
shape.append(a.shape[i])
+ fuzz_idx = rng.integers(0, len(a.shape))
+ if error_name == ErrorIf.DimensionMismatch:
+ shape[fuzz_idx] += 1
+
if error_name == ErrorIf.WrongOutputType:
wrong_dtypes = [
DType.INT8,