aboutsummaryrefslogtreecommitdiff
path: root/verif/generator/tosa_arg_gen.py
diff options
context:
space:
mode:
Diffstat (limited to 'verif/generator/tosa_arg_gen.py')
-rw-r--r--verif/generator/tosa_arg_gen.py113
1 files changed, 69 insertions, 44 deletions
diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py
index e3492cd..f63a7df 100644
--- a/verif/generator/tosa_arg_gen.py
+++ b/verif/generator/tosa_arg_gen.py
@@ -1031,7 +1031,9 @@ class TosaArgGen:
# Can't use stride=0, as it is used to derive output shape, as a divisor
s_vals = [testGen.rng.choice(range(-5, 0))]
else:
- s_vals = [x for x in range(1, testGen.args.max_conv_stride + 1)]
+ # Stride must be greater than 1 to force non-integer error
+ startStride = 1 if error_name != ErrorIf.PoolingOutputShapeNonInteger else 2
+ s_vals = [x for x in range(startStride, testGen.args.max_conv_stride + 1)]
strides = {x for x in itertools.product(*([s_vals] * k_rank))}
if error_name == ErrorIf.DilationSmallerOne:
d_vals = [testGen.rng.choice(range(-5, 1))]
@@ -1055,7 +1057,7 @@ class TosaArgGen:
# There are too many parameter combinations, so generate them sparsely,
# very sparse for negative tests
- sparsity_factor = 2 if error_name else 100
+ sparsity_factor = 2 if error_name else 120
sparsity = len(paddings) * len(strides) * len(dilations) // sparsity_factor + 1
# If there are only a small number of tests, just select them all
if sparsity < 13:
@@ -1084,16 +1086,37 @@ class TosaArgGen:
and (ifm_shape[2] + p[2] + p[3]) > d[1]
and (k_rank < 3 or ((ifm_shape[3] + p[4] + p[5]) > d[2]))
):
- arg_list.append(
- (
- "st{}_pad{}_dilat{}".format(
- "".join([str(x) for x in s]),
- "".join([str(x) for x in p]),
- "".join([str(x) for x in d]),
- ),
- [s, p, d],
+ remainders = []
+ for index in range(k_rank):
+ pad_offset = index * 2
+ remainders.append(
+ (
+ ifm_shape[index + 1]
+ - 1
+ + p[pad_offset]
+ + p[pad_offset + 1]
+ - (k[index] - 1) * d[index]
+ )
+ % s[index]
+ )
+ if (
+ # the parameters must produce integer exact output
+ error_name != ErrorIf.ConvOutputShapeNonInteger
+ and max(remainders) == 0
+ ) or (
+ error_name == ErrorIf.ConvOutputShapeNonInteger
+ and max(remainders) > 0
+ ):
+ arg_list.append(
+ (
+ "st{}_pad{}_dilat{}".format(
+ "".join([str(x) for x in s]),
+ "".join([str(x) for x in p]),
+ "".join([str(x) for x in d]),
+ ),
+ [s, p, d],
+ )
)
- )
n += 1
return arg_list
@@ -1116,17 +1139,16 @@ class TosaArgGen:
p_vals = [testGen.rng.choice(range(-5, 0))]
else:
p_vals = [x for x in range(0, testGen.args.max_conv_padding + 1)]
- paddings = {x for x in itertools.product(*([p_vals] * 2))}
+ paddings = {x for x in itertools.product(*([p_vals] * 4))}
if error_name == ErrorIf.StrideSmallerOne:
# Can't use stride=0, as it is used to derive output shape, as a divisor
s_vals = [testGen.rng.choice(range(-5, 0))]
else:
s_vals = [x for x in range(1, testGen.args.max_conv_stride + 1)]
strides = {x for x in itertools.product(*([s_vals] * 2))}
- if error_name == ErrorIf.DilationSmallerOne:
- d_vals = [testGen.rng.choice(range(-5, 1))]
- else:
- d_vals = [x for x in range(1, testGen.args.max_conv_dilation + 1)]
+ # Dilation is not supported by the specification for transpose conv2d
+ # TODO: Remove this completely when schema has been updated
+ d_vals = [1]
dilations = {x for x in itertools.product(*([d_vals] * 2))}
if not error_name:
@@ -1134,16 +1156,14 @@ class TosaArgGen:
if max(ifm_shape) < 64:
bigPadding = 9
paddings.update(
- {x for x in itertools.product(*([[0, bigPadding]] * 2))}
+ {x for x in itertools.product(*([[0, bigPadding]] * 4))}
)
bigStride = 8
strides.update({x for x in itertools.product(*([[1, bigStride]] * 2))})
- bigDilation = 7
- dilations.update({x for x in itertools.product(*([[1, bigDilation]] * 2))})
# There are too many parameter combinations, so generate them sparsely,
# very sparse for negative tests
- sparsity_factor = 2 if error_name else 100
+ sparsity_factor = 2 if error_name else 10
sparsity = len(paddings) * len(strides) * len(dilations) // sparsity_factor + 1
# If there are only a small number of tests, just select them all
if sparsity < 13:
@@ -1159,18 +1179,8 @@ class TosaArgGen:
for d in sorted(list(dilations)):
if n % sparsity == 0:
# Determine the output shape
- oh = (
- ifm_shape[1]
- - filter_shape[1]
- - (filter_shape[1] - 1) * (d[0] - 1)
- + 2 * p[0]
- ) // s[0] + 1
- ow = (
- ifm_shape[2]
- - filter_shape[2]
- - (filter_shape[2] - 1) * (d[1] - 1)
- + 2 * p[1]
- ) // s[1] + 1
+ oh = (ifm_shape[1] - 1) * s[0] - p[0] - p[1] + filter_shape[1]
+ ow = (ifm_shape[2] - 1) * s[1] - p[2] - p[3] + filter_shape[2]
os = [ifm_shape[0], oh, ow, filter_shape[0]]
arg_list.append(
(
@@ -1231,7 +1241,9 @@ class TosaArgGen:
# Generate comprehensive argument lists
p_vals = [x for x in range(0, testGen.args.max_pooling_padding + 1)]
paddings = {x for x in itertools.product(*([p_vals] * 4))}
- s_vals = [x for x in range(1, testGen.args.max_pooling_stride + 1)]
+ # Stride must be greater than 1 to force non-integer error
+ startStride = 1 if error_name != ErrorIf.PoolingOutputShapeNonInteger else 2
+ s_vals = [x for x in range(startStride, testGen.args.max_pooling_stride + 1)]
strides = {x for x in itertools.product(*([s_vals] * 2))}
k_vals = [x for x in range(2, testGen.args.max_pooling_kernel + 1)]
kernels = {x for x in itertools.product(*([k_vals] * 2))}
@@ -1239,8 +1251,10 @@ class TosaArgGen:
if testGen.args.oversize:
# add some oversize argument values
bigStride = 7
- strides.update({x for x in itertools.product(*([[1, bigStride]] * 2))})
- bigKernel = 6
+ strides.update(
+ {x for x in itertools.product(*([[startStride, bigStride]] * 2))}
+ )
+ bigKernel = 9
kernels.update({x for x in itertools.product(*([[2, bigKernel]] * 2))})
if max(shape) < 64:
# padding must be less than the kernel size
@@ -1289,16 +1303,27 @@ class TosaArgGen:
and (shape[1] + p[0] + p[1]) > k[0]
and (shape[2] + p[2] + p[3]) > k[1]
):
- arg_list.append(
- (
- "st{}_kern{}_pad{}".format(
- "".join([str(x) for x in s]),
- "".join([str(x) for x in k]),
- "".join([str(x) for x in p]),
- ),
- [s, p, k],
+ remainder_h = (shape[1] + p[0] + p[1] - k[0]) % s[0]
+ remainder_w = (shape[2] + p[2] + p[3] - k[1]) % s[1]
+ if (
+ # the parameters must produce integer exact output
+ error_name != ErrorIf.PoolingOutputShapeNonInteger
+ and remainder_h == 0
+ and remainder_w == 0
+ ) or (
+ error_name == ErrorIf.PoolingOutputShapeNonInteger
+ and (remainder_h != 0 or remainder_w != 0)
+ ):
+ arg_list.append(
+ (
+ "st{}_kern{}_pad{}".format(
+ "".join([str(x) for x in s]),
+ "".join([str(x) for x in k]),
+ "".join([str(x) for x in p]),
+ ),
+ [s, p, k],
+ )
)
- )
n += 1
return arg_list