aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--verif/tosa_test_gen.py8
1 files changed, 6 insertions, 2 deletions
diff --git a/verif/tosa_test_gen.py b/verif/tosa_test_gen.py
index 0d5169c..b5ddbd7 100644
--- a/verif/tosa_test_gen.py
+++ b/verif/tosa_test_gen.py
@@ -177,13 +177,17 @@ class TosaTensorGen:
values_in_shape = testGen.makeShape(rank)
- # Constrict the batch size?
- if testGen.args.max_batch_size:
+ # ignore max batch size if target shape is set
+ if testGen.args.max_batch_size and not testGen.args.target_shapes:
values_in_shape[0] = (values_in_shape[0] % testGen.args.max_batch_size) + 1
W = testGen.randInt(
testGen.args.tensor_shape_range[0], testGen.args.tensor_shape_range[1]
)
+ # Constrict W if one dimension is too large to keep tensor size reasonable
+ if max(values_in_shape) > 5000:
+ W = testGen.randInt(0, 16)
+
input_shape = [values_in_shape[0], W, values_in_shape[2]]
shape_list = []