aboutsummaryrefslogtreecommitdiff
path: root/verif/generator/tosa_arg_gen.py
diff options
context:
space:
mode:
Diffstat (limited to 'verif/generator/tosa_arg_gen.py')
-rw-r--r--verif/generator/tosa_arg_gen.py38
1 files changed, 16 insertions, 22 deletions
diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py
index b5e68dd..a27d849 100644
--- a/verif/generator/tosa_arg_gen.py
+++ b/verif/generator/tosa_arg_gen.py
@@ -1134,10 +1134,6 @@ class TosaArgGen:
else:
s_vals = [x for x in range(1, testGen.args.max_conv_stride + 1)]
strides = {x for x in itertools.product(*([s_vals] * 2))}
- # Dilation is not supported by the specification for transpose conv2d
- # TODO: Remove this completely when schema has been updated
- d_vals = [1]
- dilations = {x for x in itertools.product(*([d_vals] * 2))}
if not error_name and testGen.args.oversize:
# add some oversize argument values
@@ -1152,7 +1148,7 @@ class TosaArgGen:
# There are too many parameter combinations, so generate them sparsely,
# very sparse for negative tests
sparsity_factor = 2 if error_name else 10
- sparsity = len(paddings) * len(strides) * len(dilations) // sparsity_factor + 1
+ sparsity = len(paddings) * len(strides) // sparsity_factor + 1
# If there are only a small number of tests, just select them all
if sparsity < 13:
sparsity = 1
@@ -1164,24 +1160,22 @@ class TosaArgGen:
n = 0
for s in sorted(list(strides)):
for p in sorted(list(paddings)):
- for d in sorted(list(dilations)):
- if n % sparsity == 0:
- # Determine the output shape
- oh = (ifm_shape[1] - 1) * s[0] - p[0] - p[1] + filter_shape[1]
- ow = (ifm_shape[2] - 1) * s[1] - p[2] - p[3] + filter_shape[2]
- os = [ifm_shape[0], oh, ow, filter_shape[0]]
- arg_list.append(
- (
- "st{}_pad{}_dilat{}_os{}".format(
- "".join([str(x) for x in s]),
- "".join([str(x) for x in p]),
- "".join([str(x) for x in d]),
- "x".join([str(x) for x in os]),
- ),
- [s, p, d, os],
- )
+ if n % sparsity == 0:
+ # Determine the output shape
+ oh = (ifm_shape[1] - 1) * s[0] - p[0] - p[1] + filter_shape[1]
+ ow = (ifm_shape[2] - 1) * s[1] - p[2] - p[3] + filter_shape[2]
+ os = [ifm_shape[0], oh, ow, filter_shape[0]]
+ arg_list.append(
+ (
+ "st{}_pad{}_os{}".format(
+ "".join([str(x) for x in s]),
+ "".join([str(x) for x in p]),
+ "x".join([str(x) for x in os]),
+ ),
+ [s, p, os],
)
- n += 1
+ )
+ n += 1
return arg_list