aboutsummaryrefslogtreecommitdiff
path: root/verif/tosa_test_gen.py
diff options
context:
space:
mode:
Diffstat (limited to 'verif/tosa_test_gen.py')
-rw-r--r--verif/tosa_test_gen.py54
1 files changed, 30 insertions, 24 deletions
diff --git a/verif/tosa_test_gen.py b/verif/tosa_test_gen.py
index 9555195..a19c5f4 100644
--- a/verif/tosa_test_gen.py
+++ b/verif/tosa_test_gen.py
@@ -502,30 +502,36 @@ class TosaArgGen:
assert len(ifm_shape) == 5
assert len(filter_shape) == 5
- # Generate basic argument list now
- # TODO: increase coverage
- s = [1, 1, 1]
- p = [0, 0, 0, 0, 0, 0]
- d = [1, 1, 1]
- arg_list.append(
- (
- "st{}{}{}_pad{}{}{}{}{}{}_dilat{}{}{}".format(
- s[0],
- s[1],
- s[2],
- p[0],
- p[1],
- p[2],
- p[3],
- p[4],
- p[5],
- d[0],
- d[1],
- d[2],
- ),
- [s, p, d],
- )
- )
+ # Generate comprehensive argument list
+ p_range = [x for x in range(0, testGen.args.max_conv_padding + 1)]
+ paddings = [x for x in itertools.product(*([p_range] * 6))]
+ s_range = [x for x in range(1, testGen.args.max_conv_stride + 1)]
+ strides = [x for x in itertools.product(*([s_range] * 3))]
+ d_range = [x for x in range(1, testGen.args.max_conv_dilation + 1)]
+ dilations = [x for x in itertools.product(*([d_range] * 3))]
+
+ # There are too many parameter combinations, so generate them sparsely
+ # To get a variety of parameter combinations sparsity should not be a multiple of 2, or 3
+ # TODO: make sparsity a CLI option
+ sparsity = 37
+ n = 0
+
+ for s in strides:
+ for p in paddings:
+ for d in dilations:
+ if n % sparsity == 0:
+ arg_list.append(
+ (
+ "st{}_pad{}_dilat{}".format(
+ "".join([str(x) for x in s]),
+ "".join([str(x) for x in p]),
+ "".join([str(x) for x in d]),
+ ),
+ [s, p, d],
+ )
+ )
+ n += 1
+
return arg_list
@staticmethod