aboutsummaryrefslogtreecommitdiff
path: root/verif/generator/tosa_error_if.py
diff options
context:
space:
mode:
Diffstat (limited to 'verif/generator/tosa_error_if.py')
-rw-r--r--verif/generator/tosa_error_if.py18
1 files changed, 4 insertions, 14 deletions
diff --git a/verif/generator/tosa_error_if.py b/verif/generator/tosa_error_if.py
index 1900d8a..1967d8a 100644
--- a/verif/generator/tosa_error_if.py
+++ b/verif/generator/tosa_error_if.py
@@ -2259,32 +2259,24 @@ class TosaInvalidValidator:
if opName.startswith("transpose_conv2d"):
# transpose_conv2d
- dilations = args[2]
- output_shape = args[3]
+ output_shape = args[2]
filter_shape = inputShapes[1]
kernel_shape = filter_shape[1:-1]
- def get_out_size(in_size, stride, kernel_size, dilation, out_pad, in_pad):
+ def get_out_size(in_size, stride, kernel_size, out_pad, in_pad):
"""Calculate the transpose_conv2d output size for a dimension.
- Based on the keras function deconv_output_length, in
- https://github.com/keras-team/keras/blob/master/keras/utils/conv_utils.py
-
Args:
in_size: the input size - int
stride: the stride - int
kernel_size: the kernel size - int
- dilation: the kernel dilation - int
out_pad: the output padding - int
in_pad: the input padding - int
Returns:
the output size
"""
- dilated_kernel_size = kernel_size + (kernel_size - 1) * (dilation - 1)
- return (
- (in_size - 1) * stride + dilated_kernel_size - 2 * in_pad + out_pad
- )
+ return (in_size - 1) * stride + kernel_size - in_pad - out_pad
for pad_h, pad_w in (
(kernel_shape[0] - 1, kernel_shape[1] - 1), # FULL padding
@@ -2295,7 +2287,6 @@ class TosaInvalidValidator:
input_shape[1],
strides[0],
kernel_shape[0],
- dilations[0],
padding[0],
pad_h,
)
@@ -2303,7 +2294,6 @@ class TosaInvalidValidator:
input_shape[2],
strides[1],
kernel_shape[1],
- dilations[1],
padding[1],
pad_w,
)
@@ -2341,7 +2331,7 @@ class TosaInvalidValidator:
@staticmethod
def ivNonPositiveOutputShape(**kwargs):
args = kwargs["args"]
- output_shape = args[3]
+ output_shape = args[2]
if output_shape[1] <= 0 or output_shape[2] <= 0:
# Negative output shape
return True