From 194fe314695bdfeba5b12b837b70f392db91995b Mon Sep 17 00:00:00 2001 From: Jeremy Johnson Date: Thu, 7 Dec 2023 14:17:57 +0000 Subject: Enforce no output rewrite REQUIRE in SCATTER Signed-off-by: Jeremy Johnson Change-Id: I3555e7216d403d436bf6e39d4b16bb000645c4bb --- verif/generator/tosa_arg_gen.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) (limited to 'verif/generator/tosa_arg_gen.py') diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py index 193da73..35253e0 100644 --- a/verif/generator/tosa_arg_gen.py +++ b/verif/generator/tosa_arg_gen.py @@ -213,17 +213,17 @@ class TosaTensorGen: assert rank == 3 values_in_shape = testGen.makeShape(rank) + K = values_in_shape[1] # ignore max batch size if target shape is set if testGen.args.max_batch_size and not testGen.args.target_shapes: values_in_shape[0] = min(values_in_shape[0], testGen.args.max_batch_size) - W = testGen.randInt( - testGen.args.tensor_shape_range[0], testGen.args.tensor_shape_range[1] - ) - # Constrict W if one dimension is too large to keep tensor size reasonable - if max(values_in_shape) > 5000: - W = testGen.randInt(0, 16) + # Make sure W is not greater than K, as we can only write each output index + # once (having a W greater than K means that you have to repeat a K index) + W_min = min(testGen.args.tensor_shape_range[0], K) + W_max = min(testGen.args.tensor_shape_range[1], K) + W = testGen.randInt(W_min, W_max) if W_min < W_max else W_min input_shape = [values_in_shape[0], W, values_in_shape[2]] -- cgit v1.2.1