aboutsummaryrefslogtreecommitdiff
path: root/verif/generator/tosa_arg_gen.py
diff options
context:
space:
mode:
Diffstat (limited to 'verif/generator/tosa_arg_gen.py')
-rw-r--r--verif/generator/tosa_arg_gen.py13
1 files changed, 9 insertions, 4 deletions
diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py
index 9386ec2..97ff237 100644
--- a/verif/generator/tosa_arg_gen.py
+++ b/verif/generator/tosa_arg_gen.py
@@ -246,15 +246,18 @@ class TosaTensorGen:
# Choose one of the inputs to broadcast
# Note: Simplifies OutputShaper code if we don't change first shape for errors
bcast_idx = testGen.randInt(0 if error_name is None else 1, pl + const)
+ fuzz_idx = testGen.randInt(0, rank)
+
for i in range(pl + const):
shape_bcast = shape.copy()
+ # To test broadcasting, the chosen fuzz index dimension should not be 1
+ if shape_bcast[fuzz_idx] == 1:
+ shape_bcast[fuzz_idx] += 1
+
# If the chosen input, pick a random index to broadcast
if i == bcast_idx:
- fuzz_idx = testGen.randInt(0, rank)
- if error_name == ErrorIf.DimensionMismatch:
- shape_bcast[fuzz_idx] += 1
- elif error_name == ErrorIf.RankMismatch:
+ if error_name == ErrorIf.RankMismatch:
# Add one rank to the shape (or more for rank of 1)
extra_ranks = testGen.rng.choice([1, 2, 3]) if rank == 1 else 1
shape_bcast = np.concatenate(
@@ -264,6 +267,8 @@ class TosaTensorGen:
# Either keep the extra rank, or remove it
new_len = testGen.rng.choice([-2, len(shape_bcast)])
shape_bcast = shape_bcast[:new_len]
+ elif error_name == ErrorIf.BroadcastShapesMismatch:
+ shape_bcast[fuzz_idx] += 2
else:
shape_bcast[fuzz_idx] = 1