aboutsummaryrefslogtreecommitdiff
path: root/verif/generator/tosa_arg_gen.py
diff options
context:
space:
mode:
Diffstat (limited to 'verif/generator/tosa_arg_gen.py')
-rw-r--r--verif/generator/tosa_arg_gen.py22
1 files changed, 16 insertions, 6 deletions
diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py
index 2596bec..ef84762 100644
--- a/verif/generator/tosa_arg_gen.py
+++ b/verif/generator/tosa_arg_gen.py
@@ -1111,10 +1111,15 @@ class TosaArgGen:
# Generate comprehensive argument lists
# - except for named errors, which use specific invalid value(s)
- if error_name == ErrorIf.PadSmallerZero:
- p_vals = [testGen.rng.choice(range(-5, 0))]
+ smallest_padding_size = -min(filter_shape[1], filter_shape[2]) + 1
+ if error_name == ErrorIf.PadLargerEqualKernel:
+ max_filter_size = -max(filter_shape[1], filter_shape[2])
+ p_vals = [testGen.rng.choice(range(max_filter_size - 10, max_filter_size))]
else:
- p_vals = [x for x in range(0, testGen.args.max_conv_padding + 1)]
+ p_vals = [
+ x
+ for x in range(smallest_padding_size, testGen.args.max_conv_padding + 1)
+ ]
paddings = {x for x in itertools.product(*([p_vals] * 4))}
if error_name == ErrorIf.StrideSmallerOne:
# Can't use stride=0, as it is used to derive output shape, as a divisor
@@ -1128,7 +1133,12 @@ class TosaArgGen:
if max(ifm_shape) < 64:
bigPadding = 9
paddings.update(
- {x for x in itertools.product(*([[0, bigPadding]] * 4))}
+ {
+ x
+ for x in itertools.product(
+ *([[smallest_padding_size, bigPadding]] * 4)
+ )
+ }
)
bigStride = 8
strides.update({x for x in itertools.product(*([[1, bigStride]] * 2))})
@@ -1150,8 +1160,8 @@ class TosaArgGen:
for p in sorted(list(paddings)):
if n % sparsity == 0:
# Determine the output shape
- oh = (ifm_shape[1] - 1) * s[0] - p[0] - p[1] + filter_shape[1]
- ow = (ifm_shape[2] - 1) * s[1] - p[2] - p[3] + filter_shape[2]
+ oh = (ifm_shape[1] - 1) * s[0] + p[0] + p[1] + filter_shape[1]
+ ow = (ifm_shape[2] - 1) * s[1] + p[2] + p[3] + filter_shape[2]
os = [ifm_shape[0], oh, ow, filter_shape[0]]
arg_list.append(
(