diff options
Diffstat (limited to 'verif/generator/tosa_arg_gen.py')
-rw-r--r-- | verif/generator/tosa_arg_gen.py | 13 |
1 files changed, 9 insertions, 4 deletions
diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py index 9386ec2..97ff237 100644 --- a/verif/generator/tosa_arg_gen.py +++ b/verif/generator/tosa_arg_gen.py @@ -246,15 +246,18 @@ class TosaTensorGen: # Choose one of the inputs to broadcast # Note: Simplifies OutputShaper code if we don't change first shape for errors bcast_idx = testGen.randInt(0 if error_name is None else 1, pl + const) + fuzz_idx = testGen.randInt(0, rank) + for i in range(pl + const): shape_bcast = shape.copy() + # To test broadcasting, the chosen fuzz index dimension should not be 1 + if shape_bcast[fuzz_idx] == 1: + shape_bcast[fuzz_idx] += 1 + # If the chosen input, pick a random index to broadcast if i == bcast_idx: - fuzz_idx = testGen.randInt(0, rank) - if error_name == ErrorIf.DimensionMismatch: - shape_bcast[fuzz_idx] += 1 - elif error_name == ErrorIf.RankMismatch: + if error_name == ErrorIf.RankMismatch: # Add one rank to the shape (or more for rank of 1) extra_ranks = testGen.rng.choice([1, 2, 3]) if rank == 1 else 1 shape_bcast = np.concatenate( @@ -264,6 +267,8 @@ class TosaTensorGen: # Either keep the extra rank, or remove it new_len = testGen.rng.choice([-2, len(shape_bcast)]) shape_bcast = shape_bcast[:new_len] + elif error_name == ErrorIf.BroadcastShapesMismatch: + shape_bcast[fuzz_idx] += 2 else: shape_bcast[fuzz_idx] = 1 |