diff options
-rw-r--r-- | verif/tosa_test_gen.py | 8 |
1 files changed, 6 insertions, 2 deletions
diff --git a/verif/tosa_test_gen.py b/verif/tosa_test_gen.py index 0d5169c..b5ddbd7 100644 --- a/verif/tosa_test_gen.py +++ b/verif/tosa_test_gen.py @@ -177,13 +177,17 @@ class TosaTensorGen: values_in_shape = testGen.makeShape(rank) - # Constrict the batch size? - if testGen.args.max_batch_size: + # ignore max batch size if target shape is set + if testGen.args.max_batch_size and not testGen.args.target_shapes: values_in_shape[0] = (values_in_shape[0] % testGen.args.max_batch_size) + 1 W = testGen.randInt( testGen.args.tensor_shape_range[0], testGen.args.tensor_shape_range[1] ) + # Constrict W if one dimension is too large to keep tensor size reasonable + if max(values_in_shape) > 5000: + W = testGen.randInt(0, 16) + input_shape = [values_in_shape[0], W, values_in_shape[2]] shape_list = [] |