aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMatthew Haddon <matthew.haddon@arm.com>2021-08-24 14:25:43 +0100
committerEric Kunze <eric.kunze@arm.com>2021-09-07 13:52:39 +0000
commit4b2881a7c1cbb2a4b0b24cafcdef28af0f4975c1 (patch)
tree1e98f0b9cf03a111572a14434d88e6ed74c3f8d0
parent459443c59fcfb5eb2ec3df4579cfe87f3a45db1c (diff)
downloadreference_model-4b2881a7c1cbb2a4b0b24cafcdef28af0f4975c1.tar.gz
Fix batch size if target shape set for SCATTER operator
* max batch size ignored if target shape is set * W size reduced if large input tensor used Signed-off-by: Matthew Haddon <matthew.haddon@arm.com> Change-Id: I13472ab768fa93a1d0b9e28964f56ec4a06dbdfd
-rw-r--r--verif/tosa_test_gen.py8
1 files changed, 6 insertions, 2 deletions
diff --git a/verif/tosa_test_gen.py b/verif/tosa_test_gen.py
index 0d5169c..b5ddbd7 100644
--- a/verif/tosa_test_gen.py
+++ b/verif/tosa_test_gen.py
@@ -177,13 +177,17 @@ class TosaTensorGen:
values_in_shape = testGen.makeShape(rank)
- # Constrict the batch size?
- if testGen.args.max_batch_size:
+ # ignore max batch size if target shape is set
+ if testGen.args.max_batch_size and not testGen.args.target_shapes:
values_in_shape[0] = (values_in_shape[0] % testGen.args.max_batch_size) + 1
W = testGen.randInt(
testGen.args.tensor_shape_range[0], testGen.args.tensor_shape_range[1]
)
+ # Constrict W if one dimension is too large to keep tensor size reasonable
+ if max(values_in_shape) > 5000:
+ W = testGen.randInt(0, 16)
+
input_shape = [values_in_shape[0], W, values_in_shape[2]]
shape_list = []