diff options
Diffstat (limited to 'verif/generator/tosa_arg_gen.py')
-rw-r--r-- | verif/generator/tosa_arg_gen.py | 22 |
1 files changed, 9 insertions, 13 deletions
diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py index 69968d3..e0c6cf0 100644 --- a/verif/generator/tosa_arg_gen.py +++ b/verif/generator/tosa_arg_gen.py @@ -1017,7 +1017,7 @@ class TosaArgGen: s_vals = [testGen.rng.choice(range(-5, 0))] else: # Stride must be greater than 1 to force non-integer error - startStride = 1 if error_name != ErrorIf.PoolingOutputShapeNonInteger else 2 + startStride = 1 if error_name != ErrorIf.ConvOutputShapeNonInteger else 2 s_vals = [x for x in range(startStride, testGen.args.max_conv_stride + 1)] strides = {x for x in itertools.product(*([s_vals] * k_rank))} if error_name == ErrorIf.DilationSmallerOne: @@ -1058,18 +1058,14 @@ class TosaArgGen: for d in sorted(list(dilations)): if ( n % sparsity == 0 - # padding must not exceed the kernel size ? - # and p[0] < k[0] and p[1] < k[0] - # and p[2] < k[1] and p[3] < k[1] - # and (k_rank < 3 or (p[4] < k[2] and p[5] < k[2])) - # the padded shape must exceed the kernel size - and (ifm_shape[1] + p[0] + p[1]) > k[0] - and (ifm_shape[2] + p[2] + p[3]) > k[1] - and (k_rank < 3 or ((ifm_shape[3] + p[4] + p[5]) > k[2])) - # the padded shape must exceed the dilation - and (ifm_shape[1] + p[0] + p[1]) > d[0] - and (ifm_shape[2] + p[2] + p[3]) > d[1] - and (k_rank < 3 or ((ifm_shape[3] + p[4] + p[5]) > d[2])) + # the padded shape must exceed the dilation * kernel to get a positive + # sized output shape + and (ifm_shape[1] - 1 + p[0] + p[1]) > d[0] * (k[0] - 1) + and (ifm_shape[2] - 1 + p[2] + p[3]) > d[1] * (k[1] - 1) + and ( + k_rank < 3 + or ((ifm_shape[3] - 1 + p[4] + p[5]) > d[2] * (k[2] - 1)) + ) ): remainders = [] for index in range(k_rank): |