diff options
Diffstat (limited to 'verif/generator/tosa_arg_gen.py')
-rw-r--r-- | verif/generator/tosa_arg_gen.py | 113 |
1 files changed, 69 insertions, 44 deletions
diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py index e3492cd..f63a7df 100644 --- a/verif/generator/tosa_arg_gen.py +++ b/verif/generator/tosa_arg_gen.py @@ -1031,7 +1031,9 @@ class TosaArgGen: # Can't use stride=0, as it is used to derive output shape, as a divisor s_vals = [testGen.rng.choice(range(-5, 0))] else: - s_vals = [x for x in range(1, testGen.args.max_conv_stride + 1)] + # Stride must be greater than 1 to force non-integer error + startStride = 1 if error_name != ErrorIf.PoolingOutputShapeNonInteger else 2 + s_vals = [x for x in range(startStride, testGen.args.max_conv_stride + 1)] strides = {x for x in itertools.product(*([s_vals] * k_rank))} if error_name == ErrorIf.DilationSmallerOne: d_vals = [testGen.rng.choice(range(-5, 1))] @@ -1055,7 +1057,7 @@ class TosaArgGen: # There are too many parameter combinations, so generate them sparsely, # very sparse for negative tests - sparsity_factor = 2 if error_name else 100 + sparsity_factor = 2 if error_name else 120 sparsity = len(paddings) * len(strides) * len(dilations) // sparsity_factor + 1 # If there are only a small number of tests, just select them all if sparsity < 13: @@ -1084,16 +1086,37 @@ class TosaArgGen: and (ifm_shape[2] + p[2] + p[3]) > d[1] and (k_rank < 3 or ((ifm_shape[3] + p[4] + p[5]) > d[2])) ): - arg_list.append( - ( - "st{}_pad{}_dilat{}".format( - "".join([str(x) for x in s]), - "".join([str(x) for x in p]), - "".join([str(x) for x in d]), - ), - [s, p, d], + remainders = [] + for index in range(k_rank): + pad_offset = index * 2 + remainders.append( + ( + ifm_shape[index + 1] + - 1 + + p[pad_offset] + + p[pad_offset + 1] + - (k[index] - 1) * d[index] + ) + % s[index] + ) + if ( + # the parameters must produce integer exact output + error_name != ErrorIf.ConvOutputShapeNonInteger + and max(remainders) == 0 + ) or ( + error_name == ErrorIf.ConvOutputShapeNonInteger + and max(remainders) > 0 + ): + arg_list.append( + ( + "st{}_pad{}_dilat{}".format( + "".join([str(x) for x in s]), + "".join([str(x) for x in p]), + "".join([str(x) for x in d]), + ), + [s, p, d], + ) ) - ) n += 1 return arg_list @@ -1116,17 +1139,16 @@ class TosaArgGen: p_vals = [testGen.rng.choice(range(-5, 0))] else: p_vals = [x for x in range(0, testGen.args.max_conv_padding + 1)] - paddings = {x for x in itertools.product(*([p_vals] * 2))} + paddings = {x for x in itertools.product(*([p_vals] * 4))} if error_name == ErrorIf.StrideSmallerOne: # Can't use stride=0, as it is used to derive output shape, as a divisor s_vals = [testGen.rng.choice(range(-5, 0))] else: s_vals = [x for x in range(1, testGen.args.max_conv_stride + 1)] strides = {x for x in itertools.product(*([s_vals] * 2))} - if error_name == ErrorIf.DilationSmallerOne: - d_vals = [testGen.rng.choice(range(-5, 1))] - else: - d_vals = [x for x in range(1, testGen.args.max_conv_dilation + 1)] + # Dilation is not supported by the specification for transpose conv2d + # TODO: Remove this completely when schema has been updated + d_vals = [1] dilations = {x for x in itertools.product(*([d_vals] * 2))} if not error_name: @@ -1134,16 +1156,14 @@ class TosaArgGen: if max(ifm_shape) < 64: bigPadding = 9 paddings.update( - {x for x in itertools.product(*([[0, bigPadding]] * 2))} + {x for x in itertools.product(*([[0, bigPadding]] * 4))} ) bigStride = 8 strides.update({x for x in itertools.product(*([[1, bigStride]] * 2))}) - bigDilation = 7 - dilations.update({x for x in itertools.product(*([[1, bigDilation]] * 2))}) # There are too many parameter combinations, so generate them sparsely, # very sparse for negative tests - sparsity_factor = 2 if error_name else 100 + sparsity_factor = 2 if error_name else 10 sparsity = len(paddings) * len(strides) * len(dilations) // sparsity_factor + 1 # If there are only a small number of tests, just select them all if sparsity < 13: @@ -1159,18 +1179,8 @@ class TosaArgGen: for d in sorted(list(dilations)): if n % sparsity == 0: # Determine the output shape - oh = ( - ifm_shape[1] - - filter_shape[1] - - (filter_shape[1] - 1) * (d[0] - 1) - + 2 * p[0] - ) // s[0] + 1 - ow = ( - ifm_shape[2] - - filter_shape[2] - - (filter_shape[2] - 1) * (d[1] - 1) - + 2 * p[1] - ) // s[1] + 1 + oh = (ifm_shape[1] - 1) * s[0] - p[0] - p[1] + filter_shape[1] + ow = (ifm_shape[2] - 1) * s[1] - p[2] - p[3] + filter_shape[2] os = [ifm_shape[0], oh, ow, filter_shape[0]] arg_list.append( ( @@ -1231,7 +1241,9 @@ class TosaArgGen: # Generate comprehensive argument lists p_vals = [x for x in range(0, testGen.args.max_pooling_padding + 1)] paddings = {x for x in itertools.product(*([p_vals] * 4))} - s_vals = [x for x in range(1, testGen.args.max_pooling_stride + 1)] + # Stride must be greater than 1 to force non-integer error + startStride = 1 if error_name != ErrorIf.PoolingOutputShapeNonInteger else 2 + s_vals = [x for x in range(startStride, testGen.args.max_pooling_stride + 1)] strides = {x for x in itertools.product(*([s_vals] * 2))} k_vals = [x for x in range(2, testGen.args.max_pooling_kernel + 1)] kernels = {x for x in itertools.product(*([k_vals] * 2))} @@ -1239,8 +1251,10 @@ class TosaArgGen: if testGen.args.oversize: # add some oversize argument values bigStride = 7 - strides.update({x for x in itertools.product(*([[1, bigStride]] * 2))}) - bigKernel = 6 + strides.update( + {x for x in itertools.product(*([[startStride, bigStride]] * 2))} + ) + bigKernel = 9 kernels.update({x for x in itertools.product(*([[2, bigKernel]] * 2))}) if max(shape) < 64: # padding must be less than the kernel size @@ -1289,16 +1303,27 @@ class TosaArgGen: and (shape[1] + p[0] + p[1]) > k[0] and (shape[2] + p[2] + p[3]) > k[1] ): - arg_list.append( - ( - "st{}_kern{}_pad{}".format( - "".join([str(x) for x in s]), - "".join([str(x) for x in k]), - "".join([str(x) for x in p]), - ), - [s, p, k], + remainder_h = (shape[1] + p[0] + p[1] - k[0]) % s[0] + remainder_w = (shape[2] + p[2] + p[3] - k[1]) % s[1] + if ( + # the parameters must produce integer exact output + error_name != ErrorIf.PoolingOutputShapeNonInteger + and remainder_h == 0 + and remainder_w == 0 + ) or ( + error_name == ErrorIf.PoolingOutputShapeNonInteger + and (remainder_h != 0 or remainder_w != 0) + ): + arg_list.append( + ( + "st{}_kern{}_pad{}".format( + "".join([str(x) for x in s]), + "".join([str(x) for x in k]), + "".join([str(x) for x in p]), + ), + [s, p, k], + ) ) - ) n += 1 return arg_list |