aboutsummaryrefslogtreecommitdiff
path: root/verif/generator/tosa_arg_gen.py
diff options
context:
space:
mode:
authorJeremy Johnson <jeremy.johnson@arm.com>2023-12-07 14:17:57 +0000
committerEric Kunze <eric.kunze@arm.com>2023-12-11 15:01:06 +0000
commit194fe314695bdfeba5b12b837b70f392db91995b (patch)
tree86c8933f3dcf05f1eff94f6edfacd7666b28192f /verif/generator/tosa_arg_gen.py
parentaba79525d1348e0d964de22cef445089efaf3126 (diff)
downloadreference_model-194fe314695bdfeba5b12b837b70f392db91995b.tar.gz
Enforce no output rewrite REQUIRE in SCATTER
Signed-off-by: Jeremy Johnson <jeremy.johnson@arm.com> Change-Id: I3555e7216d403d436bf6e39d4b16bb000645c4bb
Diffstat (limited to 'verif/generator/tosa_arg_gen.py')
-rw-r--r--verif/generator/tosa_arg_gen.py12
1 files changed, 6 insertions, 6 deletions
diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py
index 193da73..35253e0 100644
--- a/verif/generator/tosa_arg_gen.py
+++ b/verif/generator/tosa_arg_gen.py
@@ -213,17 +213,17 @@ class TosaTensorGen:
assert rank == 3
values_in_shape = testGen.makeShape(rank)
+ K = values_in_shape[1]
# ignore max batch size if target shape is set
if testGen.args.max_batch_size and not testGen.args.target_shapes:
values_in_shape[0] = min(values_in_shape[0], testGen.args.max_batch_size)
- W = testGen.randInt(
- testGen.args.tensor_shape_range[0], testGen.args.tensor_shape_range[1]
- )
- # Constrict W if one dimension is too large to keep tensor size reasonable
- if max(values_in_shape) > 5000:
- W = testGen.randInt(0, 16)
+ # Make sure W is not greater than K, as we can only write each output index
+ # once (having a W greater than K means that you have to repeat a K index)
+ W_min = min(testGen.args.tensor_shape_range[0], K)
+ W_max = min(testGen.args.tensor_shape_range[1], K)
+ W = testGen.randInt(W_min, W_max) if W_min < W_max else W_min
input_shape = [values_in_shape[0], W, values_in_shape[2]]