diff options
Diffstat (limited to 'verif/generator/tosa_error_if.py')
-rw-r--r-- | verif/generator/tosa_error_if.py | 47 |
1 files changed, 35 insertions, 12 deletions
diff --git a/verif/generator/tosa_error_if.py b/verif/generator/tosa_error_if.py index e4e60b7..f9a00f9 100644 --- a/verif/generator/tosa_error_if.py +++ b/verif/generator/tosa_error_if.py @@ -1261,15 +1261,28 @@ class TosaErrorValidator: if check: pad = kwargs["pad"] - kernel = kwargs["kernel"] - if min(pad) > 0 and min(kernel) > 1: + op = kwargs["op"] + if op["op"] == Op.TRANSPOSE_CONV2D: + # transpose_conv2d + kernel = kwargs["weight_shape"][1:-1] if ( - pad[0] >= kernel[0] - or pad[1] >= kernel[0] - or pad[2] >= kernel[1] - or pad[3] >= kernel[1] + pad[0] <= -kernel[0] + or pad[1] <= -kernel[0] + or pad[2] <= -kernel[1] + or pad[3] <= -kernel[1] ): error_result = True + else: + # pooling op + kernel = kwargs["kernel"] + if min(pad) > 0 and min(kernel) > 1: + if ( + pad[0] >= kernel[0] + or pad[1] >= kernel[0] + or pad[2] >= kernel[1] + or pad[3] >= kernel[1] + ): + error_result = True info_dict = { "error_name": error_name, @@ -1400,12 +1413,22 @@ class TosaErrorValidator: return info_dict @staticmethod - def checkConvParams(weight_shape, stride, pad, dilation): + def checkConvParams(op, weight_shape, stride, pad, dilation): + if op == Op.TRANSPOSE_CONV2D: + pad_ok = ( + pad[0] > -weight_shape[1] + and pad[1] > -weight_shape[1] + and pad[2] > -weight_shape[2] + and pad[3] > -weight_shape[2] + ) + else: + pad_ok = min(pad) >= 0 + return ( # Check kernel sizes min(weight_shape[1:-1]) >= 1 and min(stride) >= 1 - and min(pad) >= 0 + and pad_ok and (dilation is None or min(dilation) >= 1) ) @@ -1437,8 +1460,8 @@ class TosaErrorValidator: if op["op"] == Op.TRANSPOSE_CONV2D: dims_correct.append( (input_shape[index + 1] - 1) * stride[index] - - pad[pad_offset] - - pad[pad_offset + 1] + + pad[pad_offset] + + pad[pad_offset + 1] + weight_shape[index + kernel_offset] ) else: @@ -1457,7 +1480,7 @@ class TosaErrorValidator: # ensure parameters are valid params_valid = TosaErrorValidator.checkConvParams( - weight_shape, stride, pad, dilation + op["op"], weight_shape, stride, pad, dilation ) if params_valid and output_shape[1:-1] != dims_correct: @@ -1507,7 +1530,7 @@ class TosaErrorValidator: # ensure parameters are valid params_valid = TosaErrorValidator.checkConvParams( - weight_shape, stride, pad, dilation + op["op"], weight_shape, stride, pad, dilation ) if params_valid and max(remainders) > 0: error_result = True |