From 24594f55ee3bf0e95c764e51b94c3ec7f9cfa54a Mon Sep 17 00:00:00 2001 From: TatWai Chong Date: Wed, 8 Jun 2022 00:48:04 -0700 Subject: Update transpose_conv2d to align with TOSA spec Rename outpad to out_pad, and also fix the dilation in the generator. Change-Id: I4c1599871f0d0b41856e819d8c644a85ca6d8267 Signed-off-by: TatWai Chong --- verif/generator/tosa_error_if.py | 18 ++++-------------- 1 file changed, 4 insertions(+), 14 deletions(-) (limited to 'verif/generator/tosa_error_if.py') diff --git a/verif/generator/tosa_error_if.py b/verif/generator/tosa_error_if.py index 1900d8a..1967d8a 100644 --- a/verif/generator/tosa_error_if.py +++ b/verif/generator/tosa_error_if.py @@ -2259,32 +2259,24 @@ class TosaInvalidValidator: if opName.startswith("transpose_conv2d"): # transpose_conv2d - dilations = args[2] - output_shape = args[3] + output_shape = args[2] filter_shape = inputShapes[1] kernel_shape = filter_shape[1:-1] - def get_out_size(in_size, stride, kernel_size, dilation, out_pad, in_pad): + def get_out_size(in_size, stride, kernel_size, out_pad, in_pad): """Calculate the transpose_conv2d output size for a dimension. - Based on the keras function deconv_output_length, in - https://github.com/keras-team/keras/blob/master/keras/utils/conv_utils.py - Args: in_size: the input size - int stride: the stride - int kernel_size: the kernel size - int - dilation: the kernel dilation - int out_pad: the output padding - int in_pad: the input padding - int Returns: the output size """ - dilated_kernel_size = kernel_size + (kernel_size - 1) * (dilation - 1) - return ( - (in_size - 1) * stride + dilated_kernel_size - 2 * in_pad + out_pad - ) + return (in_size - 1) * stride + kernel_size - in_pad - out_pad for pad_h, pad_w in ( (kernel_shape[0] - 1, kernel_shape[1] - 1), # FULL padding @@ -2295,7 +2287,6 @@ class TosaInvalidValidator: input_shape[1], strides[0], kernel_shape[0], - dilations[0], padding[0], pad_h, ) @@ -2303,7 +2294,6 @@ class TosaInvalidValidator: input_shape[2], strides[1], kernel_shape[1], - dilations[1], padding[1], pad_w, ) @@ -2341,7 +2331,7 @@ class TosaInvalidValidator: @staticmethod def ivNonPositiveOutputShape(**kwargs): args = kwargs["args"] - output_shape = args[3] + output_shape = args[2] if output_shape[1] <= 0 or output_shape[2] <= 0: # Negative output shape return True -- cgit v1.2.1