aboutsummaryrefslogtreecommitdiff
path: root/verif/generator/tosa_arg_gen.py
diff options
context:
space:
mode:
Diffstat (limited to 'verif/generator/tosa_arg_gen.py')
-rw-r--r--verif/generator/tosa_arg_gen.py22
1 files changed, 9 insertions, 13 deletions
diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py
index 69968d3..e0c6cf0 100644
--- a/verif/generator/tosa_arg_gen.py
+++ b/verif/generator/tosa_arg_gen.py
@@ -1017,7 +1017,7 @@ class TosaArgGen:
s_vals = [testGen.rng.choice(range(-5, 0))]
else:
# Stride must be greater than 1 to force non-integer error
- startStride = 1 if error_name != ErrorIf.PoolingOutputShapeNonInteger else 2
+ startStride = 1 if error_name != ErrorIf.ConvOutputShapeNonInteger else 2
s_vals = [x for x in range(startStride, testGen.args.max_conv_stride + 1)]
strides = {x for x in itertools.product(*([s_vals] * k_rank))}
if error_name == ErrorIf.DilationSmallerOne:
@@ -1058,18 +1058,14 @@ class TosaArgGen:
for d in sorted(list(dilations)):
if (
n % sparsity == 0
- # padding must not exceed the kernel size ?
- # and p[0] < k[0] and p[1] < k[0]
- # and p[2] < k[1] and p[3] < k[1]
- # and (k_rank < 3 or (p[4] < k[2] and p[5] < k[2]))
- # the padded shape must exceed the kernel size
- and (ifm_shape[1] + p[0] + p[1]) > k[0]
- and (ifm_shape[2] + p[2] + p[3]) > k[1]
- and (k_rank < 3 or ((ifm_shape[3] + p[4] + p[5]) > k[2]))
- # the padded shape must exceed the dilation
- and (ifm_shape[1] + p[0] + p[1]) > d[0]
- and (ifm_shape[2] + p[2] + p[3]) > d[1]
- and (k_rank < 3 or ((ifm_shape[3] + p[4] + p[5]) > d[2]))
+ # the padded shape must exceed the dilation * kernel to get a positive
+ # sized output shape
+ and (ifm_shape[1] - 1 + p[0] + p[1]) > d[0] * (k[0] - 1)
+ and (ifm_shape[2] - 1 + p[2] + p[3]) > d[1] * (k[1] - 1)
+ and (
+ k_rank < 3
+ or ((ifm_shape[3] - 1 + p[4] + p[5]) > d[2] * (k[2] - 1))
+ )
):
remainders = []
for index in range(k_rank):