aboutsummaryrefslogtreecommitdiff
path: root/verif/generator/tosa_error_if.py
diff options
context:
space:
mode:
Diffstat (limited to 'verif/generator/tosa_error_if.py')
-rw-r--r--verif/generator/tosa_error_if.py78
1 files changed, 32 insertions, 46 deletions
diff --git a/verif/generator/tosa_error_if.py b/verif/generator/tosa_error_if.py
index b19d5e9..8c40371 100644
--- a/verif/generator/tosa_error_if.py
+++ b/verif/generator/tosa_error_if.py
@@ -2547,14 +2547,16 @@ class TosaInvalidValidator:
args = kwargs["args"]
- # MaxPool2D has no accum_dtype arg
- stride_idx, pad_idx = (0, 1) if opName == "max_pool2d" else (1, 2)
+ # Skip accum_dtype arg (apart from MaxPool2D that doesn't have one)
+ stride_idx, pad_idx = (1, 2) if opName != "max_pool2d" else (0, 1)
+
+ # Common info for all ops
strides = args[stride_idx]
padding = args[pad_idx]
if opName.endswith("pool2d"):
# avg_pool2d, max_pool2d
- kernel_shape = args[2]
+ kernel_shape = args[pad_idx + 1]
h = (
input_shape[1] + padding[0] + padding[1] + strides[0] - kernel_shape[0]
) // strides[0]
@@ -2566,53 +2568,36 @@ class TosaInvalidValidator:
if opName.startswith("transpose_conv2d"):
# transpose_conv2d
- output_shape = args[2]
+ output_shape = args[pad_idx + 1]
filter_shape = inputShapes[1]
kernel_shape = filter_shape[1:-1]
def get_out_size(in_size, stride, kernel_size, out_pad, in_pad):
- """Calculate the transpose_conv2d output size for a dimension.
-
- Args:
- in_size: the input size - int
- stride: the stride - int
- kernel_size: the kernel size - int
- out_pad: the output padding - int
- in_pad: the input padding - int
-
- Returns:
- the output size
- """
- return (in_size - 1) * stride + kernel_size - in_pad - out_pad
-
- for pad_h, pad_w in (
- (kernel_shape[0] - 1, kernel_shape[1] - 1), # FULL padding
- (kernel_shape[0] // 2, kernel_shape[1] // 2), # SAME padding
- (0, 0), # VALID padding
- ):
- h = get_out_size(
- input_shape[1],
- strides[0],
- kernel_shape[0],
- padding[0],
- pad_h,
- )
- w = get_out_size(
- input_shape[2],
- strides[1],
- kernel_shape[1],
- padding[1],
- pad_w,
- )
- if output_shape[1] == h and output_shape[2] == w:
- return False
-
- # output shape does not match the expected shape for any padding option
+ """Calculate the transpose_conv2d output size for a dimension."""
+ return (in_size - 1) * stride + kernel_size + in_pad + out_pad
+
+ h = get_out_size(
+ input_shape[1],
+ strides[0],
+ kernel_shape[0],
+ padding[0],
+ padding[1],
+ )
+ w = get_out_size(
+ input_shape[2],
+ strides[1],
+ kernel_shape[1],
+ padding[2],
+ padding[3],
+ )
+ if output_shape[1] == h and output_shape[2] == w:
+ return False
+ # output shape does not match the expected shape
return True
if "conv2d" in opName or "conv3d" in opName:
# conv2d, conv3d, depthwise_conv2d
- dilations = args[2]
+ dilations = args[pad_idx + 1]
filter_shape = inputShapes[1]
kernel_shape = (
filter_shape[0:2]
@@ -2621,12 +2606,13 @@ class TosaInvalidValidator:
)
for i in range(len(kernel_shape)):
+ pad_offset = i * 2
dim = (
input_shape[i + 1]
- - kernel_shape[i]
- - (kernel_shape[i] - 1) * (dilations[i] - 1)
- + padding[i * 2 + 0]
- + padding[i * 2 + 1]
+ - 1
+ + padding[pad_offset]
+ + padding[pad_offset + 1]
+ - (kernel_shape[i] - 1) * dilations[i]
) // strides[i] + 1
# return True if any dimension is < 1
if dim < 1: