diff options
Diffstat (limited to 'verif/generator/tosa_error_if.py')
-rw-r--r-- | verif/generator/tosa_error_if.py | 196 |
1 files changed, 179 insertions, 17 deletions
diff --git a/verif/generator/tosa_error_if.py b/verif/generator/tosa_error_if.py index caf63e3..e7e758f 100644 --- a/verif/generator/tosa_error_if.py +++ b/verif/generator/tosa_error_if.py @@ -42,6 +42,9 @@ class ErrorIf(object): PadSmallerZero = "PadSmallerZero" PadLargerEqualKernel = "PadLargerEqualKernel" PoolingOutputShapeMismatch = "PoolingOutputShapeMismatch" + PoolingOutputShapeNonInteger = "PoolingOutputShapeNonInteger" + ConvOutputShapeMismatch = "ConvOutputShapeMismatch" + ConvOutputShapeNonInteger = "ConvOutputShapeNonInteger" ScaleNotTrue = "ScaleNotTrue" ScaleTrue = "ScaleTrue" TensorSizeInputOutputMismatch = "TensorSizeInputOutputMismatch" @@ -1226,6 +1229,20 @@ class TosaErrorValidator: return info_dict @staticmethod + def checkPoolingParams(kernel, stride, pad): + return ( + min(kernel) >= 1 + and min(stride) >= 1 + and min(pad) >= 0 + and not ( + pad[0] >= kernel[0] + or pad[1] >= kernel[0] + or pad[2] >= kernel[1] + or pad[3] >= kernel[1] + ) + ) + + @staticmethod def evPoolingOutputShapeMismatch(check=False, **kwargs): error_name = ErrorIf.PoolingOutputShapeMismatch param_reqs = {"rank": None, "dtype": None, "shape": None} @@ -1252,25 +1269,11 @@ class TosaErrorValidator: # calculate correct height, width dimensions if stride_x != 0 and stride_y != 0: - y_correct = ( - IH + pad_top + pad_bottom + stride_y - kernel_y - ) // stride_y - x_correct = ( - IW + pad_left + pad_right + stride_x - kernel_x - ) // stride_x + y_correct = ((IH + pad_top + pad_bottom - kernel_y) // stride_y) + 1 + x_correct = ((IW + pad_left + pad_right - kernel_x) // stride_x) + 1 # ensure parameters are valid - params_valid = ( - min(kernel) >= 1 - and min(stride) >= 1 - and min(pad) >= 0 - and not ( - pad[0] >= kernel[0] - or pad[1] >= kernel[0] - or pad[2] >= kernel[1] - or pad[3] >= kernel[1] - ) - ) + params_valid = TosaErrorValidator.checkPoolingParams(kernel, stride, pad) if params_valid and (OH != y_correct or OW != x_correct): error_result = True @@ -1284,6 +1287,165 @@ class TosaErrorValidator: return info_dict @staticmethod + def evPoolingOutputShapeNonInteger(check=False, **kwargs): + error_name = ErrorIf.PoolingOutputShapeNonInteger + param_reqs = {"rank": None, "dtype": None, "shape": None} + error_result = False + error_reason = "Parameters do not yield exact integer output dimensions" + + if check: + pad = kwargs["pad"] + pad_top, pad_bottom, pad_left, pad_right = pad[0], pad[1], pad[2], pad[3] + + kernel = kwargs["kernel"] + kernel_y, kernel_x = kernel[0], kernel[1] + + input_shape = kwargs["input_shape"] + IH, IW = input_shape[1], input_shape[2] + + stride = kwargs["stride"] + stride_y, stride_x = stride[0], stride[1] + + # calculate remainder of height, width dimensions + if stride_x != 0 and stride_y != 0: + y_remainder = (IH + pad_top + pad_bottom - kernel_y) % stride_y + x_remainder = (IW + pad_left + pad_right - kernel_x) % stride_x + + # ensure parameters are valid + params_valid = TosaErrorValidator.checkPoolingParams(kernel, stride, pad) + if params_valid and (y_remainder != 0 or x_remainder != 0): + error_result = True + + info_dict = { + "error_name": error_name, + "error_result": error_result, + "error_reason": error_reason, + "param_reqs": param_reqs, + } + return info_dict + + @staticmethod + def checkConvParams(weight_shape, stride, pad, dilation): + return ( + # Check kernel sizes + min(weight_shape[1:-1]) >= 1 + and min(stride) >= 1 + and min(pad) >= 0 + and (dilation is None or min(dilation) >= 1) + ) + + @staticmethod + def evConvOutputShapeMismatch(check=False, **kwargs): + error_name = ErrorIf.ConvOutputShapeMismatch + param_reqs = {"rank": None, "dtype": None, "shape": None} + error_result = False + error_reason = ( + "Mismatch between output shape provided and expected output shape" + ) + + if check: + op = kwargs["op"] + pad = kwargs["pad"] + weight_shape = kwargs["weight_shape"] + input_shape = kwargs["input_shape"] + output_shape = kwargs["output_shape"] + dilation = kwargs["dilation"] if op["op"] != Op.TRANSPOSE_CONV2D else None + stride = kwargs["stride"] + + kernel_offset = 0 if op["op"] == Op.DEPTHWISE_CONV2D else 1 + + # calculate correct dimensions + dims_correct = [] + if min(stride) > 0: + for index in range(len(stride)): + pad_offset = index * 2 + if op["op"] == Op.TRANSPOSE_CONV2D: + dims_correct.append( + (input_shape[index + 1] - 1) * stride[index] + - pad[pad_offset] + - pad[pad_offset + 1] + + weight_shape[index + kernel_offset] + ) + else: + dims_correct.append( + ( + input_shape[index + 1] + - 1 + + pad[pad_offset] + + pad[pad_offset + 1] + - (weight_shape[index + kernel_offset] - 1) + * dilation[index] + ) + // stride[index] + + 1 + ) + + # ensure parameters are valid + params_valid = TosaErrorValidator.checkConvParams( + weight_shape, stride, pad, dilation + ) + + if params_valid and output_shape[1:-1] != dims_correct: + error_result = True + + info_dict = { + "error_name": error_name, + "error_result": error_result, + "error_reason": error_reason, + "param_reqs": param_reqs, + } + return info_dict + + @staticmethod + def evConvOutputShapeNonInteger(check=False, **kwargs): + error_name = ErrorIf.ConvOutputShapeNonInteger + param_reqs = {"rank": None, "dtype": None, "shape": None} + error_result = False + error_reason = "Parameters do not yield exact integer output dimensions" + + if check: + op = kwargs["op"] + pad = kwargs["pad"] + weight_shape = kwargs["weight_shape"] + input_shape = kwargs["input_shape"] + dilation = kwargs["dilation"] + stride = kwargs["stride"] + + kernel_offset = 0 if op["op"] == Op.DEPTHWISE_CONV2D else 1 + + # calculate correct height, width dimensions + remainders = [] + if min(stride) > 0: + for index in range(len(stride)): + pad_offset = index * 2 + remainders.append( + ( + input_shape[index + 1] + - 1 + + pad[pad_offset] + + pad[pad_offset + 1] + - (weight_shape[index + kernel_offset] - 1) + * dilation[index] + ) + % stride[index] + ) + + # ensure parameters are valid + params_valid = TosaErrorValidator.checkConvParams( + weight_shape, stride, pad, dilation + ) + if params_valid and max(remainders) > 0: + error_result = True + + info_dict = { + "error_name": error_name, + "error_result": error_result, + "error_reason": error_reason, + "param_reqs": param_reqs, + } + return info_dict + + @staticmethod def evArgmaxOutputShapeMismatch(check=False, **kwargs): error_name = ErrorIf.ArgmaxOutputShapeMismatch param_reqs = {"rank": [2, 4], "dtype": None, "shape": None} |