aboutsummaryrefslogtreecommitdiff
path: root/verif/generator/tosa_error_if.py
diff options
context:
space:
mode:
Diffstat (limited to 'verif/generator/tosa_error_if.py')
-rw-r--r--verif/generator/tosa_error_if.py196
1 files changed, 179 insertions, 17 deletions
diff --git a/verif/generator/tosa_error_if.py b/verif/generator/tosa_error_if.py
index caf63e3..e7e758f 100644
--- a/verif/generator/tosa_error_if.py
+++ b/verif/generator/tosa_error_if.py
@@ -42,6 +42,9 @@ class ErrorIf(object):
PadSmallerZero = "PadSmallerZero"
PadLargerEqualKernel = "PadLargerEqualKernel"
PoolingOutputShapeMismatch = "PoolingOutputShapeMismatch"
+ PoolingOutputShapeNonInteger = "PoolingOutputShapeNonInteger"
+ ConvOutputShapeMismatch = "ConvOutputShapeMismatch"
+ ConvOutputShapeNonInteger = "ConvOutputShapeNonInteger"
ScaleNotTrue = "ScaleNotTrue"
ScaleTrue = "ScaleTrue"
TensorSizeInputOutputMismatch = "TensorSizeInputOutputMismatch"
@@ -1226,6 +1229,20 @@ class TosaErrorValidator:
return info_dict
@staticmethod
+ def checkPoolingParams(kernel, stride, pad):
+ return (
+ min(kernel) >= 1
+ and min(stride) >= 1
+ and min(pad) >= 0
+ and not (
+ pad[0] >= kernel[0]
+ or pad[1] >= kernel[0]
+ or pad[2] >= kernel[1]
+ or pad[3] >= kernel[1]
+ )
+ )
+
+ @staticmethod
def evPoolingOutputShapeMismatch(check=False, **kwargs):
error_name = ErrorIf.PoolingOutputShapeMismatch
param_reqs = {"rank": None, "dtype": None, "shape": None}
@@ -1252,25 +1269,11 @@ class TosaErrorValidator:
# calculate correct height, width dimensions
if stride_x != 0 and stride_y != 0:
- y_correct = (
- IH + pad_top + pad_bottom + stride_y - kernel_y
- ) // stride_y
- x_correct = (
- IW + pad_left + pad_right + stride_x - kernel_x
- ) // stride_x
+ y_correct = ((IH + pad_top + pad_bottom - kernel_y) // stride_y) + 1
+ x_correct = ((IW + pad_left + pad_right - kernel_x) // stride_x) + 1
# ensure parameters are valid
- params_valid = (
- min(kernel) >= 1
- and min(stride) >= 1
- and min(pad) >= 0
- and not (
- pad[0] >= kernel[0]
- or pad[1] >= kernel[0]
- or pad[2] >= kernel[1]
- or pad[3] >= kernel[1]
- )
- )
+ params_valid = TosaErrorValidator.checkPoolingParams(kernel, stride, pad)
if params_valid and (OH != y_correct or OW != x_correct):
error_result = True
@@ -1284,6 +1287,165 @@ class TosaErrorValidator:
return info_dict
@staticmethod
+ def evPoolingOutputShapeNonInteger(check=False, **kwargs):
+ error_name = ErrorIf.PoolingOutputShapeNonInteger
+ param_reqs = {"rank": None, "dtype": None, "shape": None}
+ error_result = False
+ error_reason = "Parameters do not yield exact integer output dimensions"
+
+ if check:
+ pad = kwargs["pad"]
+ pad_top, pad_bottom, pad_left, pad_right = pad[0], pad[1], pad[2], pad[3]
+
+ kernel = kwargs["kernel"]
+ kernel_y, kernel_x = kernel[0], kernel[1]
+
+ input_shape = kwargs["input_shape"]
+ IH, IW = input_shape[1], input_shape[2]
+
+ stride = kwargs["stride"]
+ stride_y, stride_x = stride[0], stride[1]
+
+ # calculate remainder of height, width dimensions
+ if stride_x != 0 and stride_y != 0:
+ y_remainder = (IH + pad_top + pad_bottom - kernel_y) % stride_y
+ x_remainder = (IW + pad_left + pad_right - kernel_x) % stride_x
+
+ # ensure parameters are valid
+ params_valid = TosaErrorValidator.checkPoolingParams(kernel, stride, pad)
+ if params_valid and (y_remainder != 0 or x_remainder != 0):
+ error_result = True
+
+ info_dict = {
+ "error_name": error_name,
+ "error_result": error_result,
+ "error_reason": error_reason,
+ "param_reqs": param_reqs,
+ }
+ return info_dict
+
+ @staticmethod
+ def checkConvParams(weight_shape, stride, pad, dilation):
+ return (
+ # Check kernel sizes
+ min(weight_shape[1:-1]) >= 1
+ and min(stride) >= 1
+ and min(pad) >= 0
+ and (dilation is None or min(dilation) >= 1)
+ )
+
+ @staticmethod
+ def evConvOutputShapeMismatch(check=False, **kwargs):
+ error_name = ErrorIf.ConvOutputShapeMismatch
+ param_reqs = {"rank": None, "dtype": None, "shape": None}
+ error_result = False
+ error_reason = (
+ "Mismatch between output shape provided and expected output shape"
+ )
+
+ if check:
+ op = kwargs["op"]
+ pad = kwargs["pad"]
+ weight_shape = kwargs["weight_shape"]
+ input_shape = kwargs["input_shape"]
+ output_shape = kwargs["output_shape"]
+ dilation = kwargs["dilation"] if op["op"] != Op.TRANSPOSE_CONV2D else None
+ stride = kwargs["stride"]
+
+ kernel_offset = 0 if op["op"] == Op.DEPTHWISE_CONV2D else 1
+
+ # calculate correct dimensions
+ dims_correct = []
+ if min(stride) > 0:
+ for index in range(len(stride)):
+ pad_offset = index * 2
+ if op["op"] == Op.TRANSPOSE_CONV2D:
+ dims_correct.append(
+ (input_shape[index + 1] - 1) * stride[index]
+ - pad[pad_offset]
+ - pad[pad_offset + 1]
+ + weight_shape[index + kernel_offset]
+ )
+ else:
+ dims_correct.append(
+ (
+ input_shape[index + 1]
+ - 1
+ + pad[pad_offset]
+ + pad[pad_offset + 1]
+ - (weight_shape[index + kernel_offset] - 1)
+ * dilation[index]
+ )
+ // stride[index]
+ + 1
+ )
+
+ # ensure parameters are valid
+ params_valid = TosaErrorValidator.checkConvParams(
+ weight_shape, stride, pad, dilation
+ )
+
+ if params_valid and output_shape[1:-1] != dims_correct:
+ error_result = True
+
+ info_dict = {
+ "error_name": error_name,
+ "error_result": error_result,
+ "error_reason": error_reason,
+ "param_reqs": param_reqs,
+ }
+ return info_dict
+
+ @staticmethod
+ def evConvOutputShapeNonInteger(check=False, **kwargs):
+ error_name = ErrorIf.ConvOutputShapeNonInteger
+ param_reqs = {"rank": None, "dtype": None, "shape": None}
+ error_result = False
+ error_reason = "Parameters do not yield exact integer output dimensions"
+
+ if check:
+ op = kwargs["op"]
+ pad = kwargs["pad"]
+ weight_shape = kwargs["weight_shape"]
+ input_shape = kwargs["input_shape"]
+ dilation = kwargs["dilation"]
+ stride = kwargs["stride"]
+
+ kernel_offset = 0 if op["op"] == Op.DEPTHWISE_CONV2D else 1
+
+ # calculate correct height, width dimensions
+ remainders = []
+ if min(stride) > 0:
+ for index in range(len(stride)):
+ pad_offset = index * 2
+ remainders.append(
+ (
+ input_shape[index + 1]
+ - 1
+ + pad[pad_offset]
+ + pad[pad_offset + 1]
+ - (weight_shape[index + kernel_offset] - 1)
+ * dilation[index]
+ )
+ % stride[index]
+ )
+
+ # ensure parameters are valid
+ params_valid = TosaErrorValidator.checkConvParams(
+ weight_shape, stride, pad, dilation
+ )
+ if params_valid and max(remainders) > 0:
+ error_result = True
+
+ info_dict = {
+ "error_name": error_name,
+ "error_result": error_result,
+ "error_reason": error_reason,
+ "param_reqs": param_reqs,
+ }
+ return info_dict
+
+ @staticmethod
def evArgmaxOutputShapeMismatch(check=False, **kwargs):
error_name = ErrorIf.ArgmaxOutputShapeMismatch
param_reqs = {"rank": [2, 4], "dtype": None, "shape": None}