diff options
-rw-r--r-- | verif/tosa_test_gen.py | 8 |
1 files changed, 7 insertions, 1 deletions
diff --git a/verif/tosa_test_gen.py b/verif/tosa_test_gen.py index 99dc5f8..e08add3 100644 --- a/verif/tosa_test_gen.py +++ b/verif/tosa_test_gen.py @@ -304,7 +304,13 @@ class TosaTensorGen: assert rank == 2 input_shape = testGen.makeShape(rank) - filter_oc = testGen.makeShape(1)[0] + filter_oc = ( + testGen.rng.integers( + low=testGen.args.tensor_shape_range[0], + high=testGen.args.tensor_shape_range[1], + size=1, + )[0] + ) filter_shape = np.asarray([filter_oc, input_shape[1]]) bias_shape = np.asarray([filter_oc]) |