diff options
Diffstat (limited to 'verif/generator/tosa_arg_gen.py')
-rw-r--r-- | verif/generator/tosa_arg_gen.py | 38 |
1 files changed, 16 insertions, 22 deletions
diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py index b5e68dd..a27d849 100644 --- a/verif/generator/tosa_arg_gen.py +++ b/verif/generator/tosa_arg_gen.py @@ -1134,10 +1134,6 @@ class TosaArgGen: else: s_vals = [x for x in range(1, testGen.args.max_conv_stride + 1)] strides = {x for x in itertools.product(*([s_vals] * 2))} - # Dilation is not supported by the specification for transpose conv2d - # TODO: Remove this completely when schema has been updated - d_vals = [1] - dilations = {x for x in itertools.product(*([d_vals] * 2))} if not error_name and testGen.args.oversize: # add some oversize argument values @@ -1152,7 +1148,7 @@ class TosaArgGen: # There are too many parameter combinations, so generate them sparsely, # very sparse for negative tests sparsity_factor = 2 if error_name else 10 - sparsity = len(paddings) * len(strides) * len(dilations) // sparsity_factor + 1 + sparsity = len(paddings) * len(strides) // sparsity_factor + 1 # If there are only a small number of tests, just select them all if sparsity < 13: sparsity = 1 @@ -1164,24 +1160,22 @@ class TosaArgGen: n = 0 for s in sorted(list(strides)): for p in sorted(list(paddings)): - for d in sorted(list(dilations)): - if n % sparsity == 0: - # Determine the output shape - oh = (ifm_shape[1] - 1) * s[0] - p[0] - p[1] + filter_shape[1] - ow = (ifm_shape[2] - 1) * s[1] - p[2] - p[3] + filter_shape[2] - os = [ifm_shape[0], oh, ow, filter_shape[0]] - arg_list.append( - ( - "st{}_pad{}_dilat{}_os{}".format( - "".join([str(x) for x in s]), - "".join([str(x) for x in p]), - "".join([str(x) for x in d]), - "x".join([str(x) for x in os]), - ), - [s, p, d, os], - ) + if n % sparsity == 0: + # Determine the output shape + oh = (ifm_shape[1] - 1) * s[0] - p[0] - p[1] + filter_shape[1] + ow = (ifm_shape[2] - 1) * s[1] - p[2] - p[3] + filter_shape[2] + os = [ifm_shape[0], oh, ow, filter_shape[0]] + arg_list.append( + ( + "st{}_pad{}_os{}".format( + "".join([str(x) for x in s]), + "".join([str(x) for x in p]), + "x".join([str(x) for x in os]), + ), + [s, p, os], ) - n += 1 + ) + n += 1 return arg_list |