aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--verif/tosa_test_gen.py2
1 files changed, 1 insertions, 1 deletions
diff --git a/verif/tosa_test_gen.py b/verif/tosa_test_gen.py
index 655cdfc..0071b9f 100644
--- a/verif/tosa_test_gen.py
+++ b/verif/tosa_test_gen.py
@@ -4761,7 +4761,7 @@ class TosaTestGen:
if error_name in [ErrorIf.CondIfOutputListThenGraphMismatch, ErrorIf.CondIfOutputListElseGraphMismatch]:
incorrect_shape = deepcopy(then_tens.shape)
for i in range(len(incorrect_shape)):
- incorrect_shape[i] = incorrect_shape[i] + self.rng.choice([-3, -2, 2, 3])
+ incorrect_shape[i] += self.rng.choice([-3, -2, 2, 3]) if incorrect_shape[i] > 3 else self.rng.choice([1, 2, 4])
incorrect_arr = np.int32(self.rng.integers(0, 256, size=incorrect_shape))
then_arr = np.int32(self.rng.integers(0, 256, size=out_shape))