diff options
author | TatWai Chong <tatwai.chong@arm.com> | 2022-06-08 00:48:04 -0700 |
---|---|---|
committer | Eric Kunze <eric.kunze@arm.com> | 2022-06-13 21:56:27 -0700 |
commit | 24594f55ee3bf0e95c764e51b94c3ec7f9cfa54a (patch) | |
tree | 2c4b4a1062b3ab2d204b306b2bd4017d9803e122 /verif | |
parent | 61f6622945d8ef339c99c4b437f985c62aa81bcf (diff) | |
download | reference_model-24594f55ee3bf0e95c764e51b94c3ec7f9cfa54a.tar.gz |
Update transpose_conv2d to align with TOSA spec
Rename outpad to out_pad, and also fix the dilation in the generator.
Change-Id: I4c1599871f0d0b41856e819d8c644a85ca6d8267
Signed-off-by: TatWai Chong <tatwai.chong@arm.com>
Diffstat (limited to 'verif')
-rw-r--r-- | verif/generator/tosa_arg_gen.py | 38 | ||||
-rw-r--r-- | verif/generator/tosa_error_if.py | 18 | ||||
-rw-r--r-- | verif/generator/tosa_test_gen.py | 9 |
3 files changed, 24 insertions, 41 deletions
diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py index b5e68dd..a27d849 100644 --- a/verif/generator/tosa_arg_gen.py +++ b/verif/generator/tosa_arg_gen.py @@ -1134,10 +1134,6 @@ class TosaArgGen: else: s_vals = [x for x in range(1, testGen.args.max_conv_stride + 1)] strides = {x for x in itertools.product(*([s_vals] * 2))} - # Dilation is not supported by the specification for transpose conv2d - # TODO: Remove this completely when schema has been updated - d_vals = [1] - dilations = {x for x in itertools.product(*([d_vals] * 2))} if not error_name and testGen.args.oversize: # add some oversize argument values @@ -1152,7 +1148,7 @@ class TosaArgGen: # There are too many parameter combinations, so generate them sparsely, # very sparse for negative tests sparsity_factor = 2 if error_name else 10 - sparsity = len(paddings) * len(strides) * len(dilations) // sparsity_factor + 1 + sparsity = len(paddings) * len(strides) // sparsity_factor + 1 # If there are only a small number of tests, just select them all if sparsity < 13: sparsity = 1 @@ -1164,24 +1160,22 @@ class TosaArgGen: n = 0 for s in sorted(list(strides)): for p in sorted(list(paddings)): - for d in sorted(list(dilations)): - if n % sparsity == 0: - # Determine the output shape - oh = (ifm_shape[1] - 1) * s[0] - p[0] - p[1] + filter_shape[1] - ow = (ifm_shape[2] - 1) * s[1] - p[2] - p[3] + filter_shape[2] - os = [ifm_shape[0], oh, ow, filter_shape[0]] - arg_list.append( - ( - "st{}_pad{}_dilat{}_os{}".format( - "".join([str(x) for x in s]), - "".join([str(x) for x in p]), - "".join([str(x) for x in d]), - "x".join([str(x) for x in os]), - ), - [s, p, d, os], - ) + if n % sparsity == 0: + # Determine the output shape + oh = (ifm_shape[1] - 1) * s[0] - p[0] - p[1] + filter_shape[1] + ow = (ifm_shape[2] - 1) * s[1] - p[2] - p[3] + filter_shape[2] + os = [ifm_shape[0], oh, ow, filter_shape[0]] + arg_list.append( + ( + "st{}_pad{}_os{}".format( + "".join([str(x) for x in s]), + "".join([str(x) for x in p]), + "x".join([str(x) for x in os]), + ), + [s, p, os], ) - n += 1 + ) + n += 1 return arg_list diff --git a/verif/generator/tosa_error_if.py b/verif/generator/tosa_error_if.py index 1900d8a..1967d8a 100644 --- a/verif/generator/tosa_error_if.py +++ b/verif/generator/tosa_error_if.py @@ -2259,32 +2259,24 @@ class TosaInvalidValidator: if opName.startswith("transpose_conv2d"): # transpose_conv2d - dilations = args[2] - output_shape = args[3] + output_shape = args[2] filter_shape = inputShapes[1] kernel_shape = filter_shape[1:-1] - def get_out_size(in_size, stride, kernel_size, dilation, out_pad, in_pad): + def get_out_size(in_size, stride, kernel_size, out_pad, in_pad): """Calculate the transpose_conv2d output size for a dimension. - Based on the keras function deconv_output_length, in - https://github.com/keras-team/keras/blob/master/keras/utils/conv_utils.py - Args: in_size: the input size - int stride: the stride - int kernel_size: the kernel size - int - dilation: the kernel dilation - int out_pad: the output padding - int in_pad: the input padding - int Returns: the output size """ - dilated_kernel_size = kernel_size + (kernel_size - 1) * (dilation - 1) - return ( - (in_size - 1) * stride + dilated_kernel_size - 2 * in_pad + out_pad - ) + return (in_size - 1) * stride + kernel_size - in_pad - out_pad for pad_h, pad_w in ( (kernel_shape[0] - 1, kernel_shape[1] - 1), # FULL padding @@ -2295,7 +2287,6 @@ class TosaInvalidValidator: input_shape[1], strides[0], kernel_shape[0], - dilations[0], padding[0], pad_h, ) @@ -2303,7 +2294,6 @@ class TosaInvalidValidator: input_shape[2], strides[1], kernel_shape[1], - dilations[1], padding[1], pad_w, ) @@ -2341,7 +2331,7 @@ class TosaInvalidValidator: @staticmethod def ivNonPositiveOutputShape(**kwargs): args = kwargs["args"] - output_shape = args[3] + output_shape = args[2] if output_shape[1] <= 0 or output_shape[2] <= 0: # Negative output shape return True diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py index fc2e476..262a652 100644 --- a/verif/generator/tosa_test_gen.py +++ b/verif/generator/tosa_test_gen.py @@ -718,14 +718,13 @@ class TosaTestGen: filter, bias, stride, - outpad, - dilation, + out_pad, output_shape, validator_fcns=None, error_name=None, qinfo=None, ): - assert len(outpad) == 4 + assert len(out_pad) == 4 result_tens = OutputShaper.transposeConv2DOp( self.ser, self.rng, ifm, output_shape, error_name ) @@ -761,7 +760,7 @@ class TosaTestGen: input_list=input_list, num_operands=num_operands, output_list=output_list, - pad=outpad, + pad=out_pad, stride=stride, input_shape=ifm.shape, weight_shape=filter.shape, @@ -770,7 +769,7 @@ class TosaTestGen: return None attr = ts.TosaSerializerAttribute() - attr.TransposeConvAttribute(outpad, stride, dilation, output_shape) + attr.TransposeConvAttribute(out_pad, stride, output_shape) self.ser.addOperator(op["op"], input_list, output_list, attr, qinfo) return result_tens |