aboutsummaryrefslogtreecommitdiff
path: root/verif/generator/tosa_error_if.py
diff options
context:
space:
mode:
Diffstat (limited to 'verif/generator/tosa_error_if.py')
-rw-r--r--verif/generator/tosa_error_if.py47
1 files changed, 35 insertions, 12 deletions
diff --git a/verif/generator/tosa_error_if.py b/verif/generator/tosa_error_if.py
index e4e60b7..f9a00f9 100644
--- a/verif/generator/tosa_error_if.py
+++ b/verif/generator/tosa_error_if.py
@@ -1261,15 +1261,28 @@ class TosaErrorValidator:
if check:
pad = kwargs["pad"]
- kernel = kwargs["kernel"]
- if min(pad) > 0 and min(kernel) > 1:
+ op = kwargs["op"]
+ if op["op"] == Op.TRANSPOSE_CONV2D:
+ # transpose_conv2d
+ kernel = kwargs["weight_shape"][1:-1]
if (
- pad[0] >= kernel[0]
- or pad[1] >= kernel[0]
- or pad[2] >= kernel[1]
- or pad[3] >= kernel[1]
+ pad[0] <= -kernel[0]
+ or pad[1] <= -kernel[0]
+ or pad[2] <= -kernel[1]
+ or pad[3] <= -kernel[1]
):
error_result = True
+ else:
+ # pooling op
+ kernel = kwargs["kernel"]
+ if min(pad) > 0 and min(kernel) > 1:
+ if (
+ pad[0] >= kernel[0]
+ or pad[1] >= kernel[0]
+ or pad[2] >= kernel[1]
+ or pad[3] >= kernel[1]
+ ):
+ error_result = True
info_dict = {
"error_name": error_name,
@@ -1400,12 +1413,22 @@ class TosaErrorValidator:
return info_dict
@staticmethod
- def checkConvParams(weight_shape, stride, pad, dilation):
+ def checkConvParams(op, weight_shape, stride, pad, dilation):
+ if op == Op.TRANSPOSE_CONV2D:
+ pad_ok = (
+ pad[0] > -weight_shape[1]
+ and pad[1] > -weight_shape[1]
+ and pad[2] > -weight_shape[2]
+ and pad[3] > -weight_shape[2]
+ )
+ else:
+ pad_ok = min(pad) >= 0
+
return (
# Check kernel sizes
min(weight_shape[1:-1]) >= 1
and min(stride) >= 1
- and min(pad) >= 0
+ and pad_ok
and (dilation is None or min(dilation) >= 1)
)
@@ -1437,8 +1460,8 @@ class TosaErrorValidator:
if op["op"] == Op.TRANSPOSE_CONV2D:
dims_correct.append(
(input_shape[index + 1] - 1) * stride[index]
- - pad[pad_offset]
- - pad[pad_offset + 1]
+ + pad[pad_offset]
+ + pad[pad_offset + 1]
+ weight_shape[index + kernel_offset]
)
else:
@@ -1457,7 +1480,7 @@ class TosaErrorValidator:
# ensure parameters are valid
params_valid = TosaErrorValidator.checkConvParams(
- weight_shape, stride, pad, dilation
+ op["op"], weight_shape, stride, pad, dilation
)
if params_valid and output_shape[1:-1] != dims_correct:
@@ -1507,7 +1530,7 @@ class TosaErrorValidator:
# ensure parameters are valid
params_valid = TosaErrorValidator.checkConvParams(
- weight_shape, stride, pad, dilation
+ op["op"], weight_shape, stride, pad, dilation
)
if params_valid and max(remainders) > 0:
error_result = True