aboutsummaryrefslogtreecommitdiff
path: root/verif/generator/tosa_error_if.py
diff options
context:
space:
mode:
authorJeremy Johnson <jeremy.johnson@arm.com>2023-04-13 17:18:19 +0100
committerJeremy Johnson <jeremy.johnson@arm.com>2023-04-26 14:32:01 +0100
commit0c71686875618b2e11290273b7a05b88ef8a8aae (patch)
tree533051a0c5befee3640639fdbd3fe122da21dd40 /verif/generator/tosa_error_if.py
parentb2099706b3db022e8c4d85c4ae863086630e0678 (diff)
downloadreference_model-0c71686875618b2e11290273b7a05b88ef8a8aae.tar.gz
8K levels: Tensor op tests kernel/stride at 8192 maximums
Operators updated: AVG_POOL2D, MAX_POOL2D, CONV2D, CONV3D, DEPTHWISE_CONV2D & TRANSPOSE_CONV2D tosa_verif_build_tests argument --level-8k-sizes used to allow kernel/stride maximum boundary testing Fixed bugs in height/width validator function meaning some esixting avg_pool2d float tests need regening. Signed-off-by: Jeremy Johnson <jeremy.johnson@arm.com> Change-Id: I7aeab82d3bd3c49d02d54708f2c9d995cd3cf2df
Diffstat (limited to 'verif/generator/tosa_error_if.py')
-rw-r--r--verif/generator/tosa_error_if.py78
1 files changed, 32 insertions, 46 deletions
diff --git a/verif/generator/tosa_error_if.py b/verif/generator/tosa_error_if.py
index b19d5e9..8c40371 100644
--- a/verif/generator/tosa_error_if.py
+++ b/verif/generator/tosa_error_if.py
@@ -2547,14 +2547,16 @@ class TosaInvalidValidator:
args = kwargs["args"]
- # MaxPool2D has no accum_dtype arg
- stride_idx, pad_idx = (0, 1) if opName == "max_pool2d" else (1, 2)
+ # Skip accum_dtype arg (apart from MaxPool2D that doesn't have one)
+ stride_idx, pad_idx = (1, 2) if opName != "max_pool2d" else (0, 1)
+
+ # Common info for all ops
strides = args[stride_idx]
padding = args[pad_idx]
if opName.endswith("pool2d"):
# avg_pool2d, max_pool2d
- kernel_shape = args[2]
+ kernel_shape = args[pad_idx + 1]
h = (
input_shape[1] + padding[0] + padding[1] + strides[0] - kernel_shape[0]
) // strides[0]
@@ -2566,53 +2568,36 @@ class TosaInvalidValidator:
if opName.startswith("transpose_conv2d"):
# transpose_conv2d
- output_shape = args[2]
+ output_shape = args[pad_idx + 1]
filter_shape = inputShapes[1]
kernel_shape = filter_shape[1:-1]
def get_out_size(in_size, stride, kernel_size, out_pad, in_pad):
- """Calculate the transpose_conv2d output size for a dimension.
-
- Args:
- in_size: the input size - int
- stride: the stride - int
- kernel_size: the kernel size - int
- out_pad: the output padding - int
- in_pad: the input padding - int
-
- Returns:
- the output size
- """
- return (in_size - 1) * stride + kernel_size - in_pad - out_pad
-
- for pad_h, pad_w in (
- (kernel_shape[0] - 1, kernel_shape[1] - 1), # FULL padding
- (kernel_shape[0] // 2, kernel_shape[1] // 2), # SAME padding
- (0, 0), # VALID padding
- ):
- h = get_out_size(
- input_shape[1],
- strides[0],
- kernel_shape[0],
- padding[0],
- pad_h,
- )
- w = get_out_size(
- input_shape[2],
- strides[1],
- kernel_shape[1],
- padding[1],
- pad_w,
- )
- if output_shape[1] == h and output_shape[2] == w:
- return False
-
- # output shape does not match the expected shape for any padding option
+ """Calculate the transpose_conv2d output size for a dimension."""
+ return (in_size - 1) * stride + kernel_size + in_pad + out_pad
+
+ h = get_out_size(
+ input_shape[1],
+ strides[0],
+ kernel_shape[0],
+ padding[0],
+ padding[1],
+ )
+ w = get_out_size(
+ input_shape[2],
+ strides[1],
+ kernel_shape[1],
+ padding[2],
+ padding[3],
+ )
+ if output_shape[1] == h and output_shape[2] == w:
+ return False
+ # output shape does not match the expected shape
return True
if "conv2d" in opName or "conv3d" in opName:
# conv2d, conv3d, depthwise_conv2d
- dilations = args[2]
+ dilations = args[pad_idx + 1]
filter_shape = inputShapes[1]
kernel_shape = (
filter_shape[0:2]
@@ -2621,12 +2606,13 @@ class TosaInvalidValidator:
)
for i in range(len(kernel_shape)):
+ pad_offset = i * 2
dim = (
input_shape[i + 1]
- - kernel_shape[i]
- - (kernel_shape[i] - 1) * (dilations[i] - 1)
- + padding[i * 2 + 0]
- + padding[i * 2 + 1]
+ - 1
+ + padding[pad_offset]
+ + padding[pad_offset + 1]
+ - (kernel_shape[i] - 1) * dilations[i]
) // strides[i] + 1
# return True if any dimension is < 1
if dim < 1: