aboutsummaryrefslogtreecommitdiff
path: root/verif/generator/tosa_error_if.py
diff options
context:
space:
mode:
Diffstat (limited to 'verif/generator/tosa_error_if.py')
-rw-r--r--verif/generator/tosa_error_if.py138
1 files changed, 91 insertions, 47 deletions
diff --git a/verif/generator/tosa_error_if.py b/verif/generator/tosa_error_if.py
index e7e758f..1900d8a 100644
--- a/verif/generator/tosa_error_if.py
+++ b/verif/generator/tosa_error_if.py
@@ -68,6 +68,8 @@ class ErrorIf(object):
InputListBodyGraphInputMismatch = "InputListBodyGraphInputMismatch"
InputListBodyGraphOutputMismatch = "InputListBodyGraphOutputMismatch"
CondGraphOutputNotMatchingBool = "CondGraphOutputNotMatchingBool"
+ U16InputZeroPointNotValid = "U16InputZeroPointNotValid"
+ U16OutputZeroPointNotValid = "U16OutputZeroPointNotValid"
class TosaErrorIfArgGen:
@@ -227,14 +229,26 @@ class TosaErrorIfArgGen:
if input_dtype == DType.INT8:
if output_dtype not in [DType.UINT8, DType.INT8, DType.INT16, DType.INT32]:
return True
- if input_dtype in [DType.INT16, DType.INT32]:
+ elif input_dtype == DType.INT16:
+ if output_dtype not in [
+ DType.UINT8,
+ DType.INT8,
+ DType.UINT16,
+ DType.INT16,
+ DType.INT32,
+ ]:
+ return True
+ elif input_dtype == DType.INT32:
if output_dtype not in [DType.INT8, DType.INT16, DType.INT32]:
return True
elif input_dtype == DType.INT48:
if output_dtype not in [DType.INT8, DType.INT16, DType.INT32]:
return True
elif input_dtype == DType.UINT8:
- if output_dtype != DType.INT8:
+ if output_dtype not in [DType.INT8, DType.INT16]:
+ return True
+ elif input_dtype == DType.UINT16:
+ if output_dtype != DType.INT16:
return True
return False
@@ -418,23 +432,9 @@ class TosaErrorValidator:
error_result = True
elif op["op"] == Op.RESCALE:
- if input_dtype == DType.INT8:
- if output_dtype not in [
- DType.UINT8,
- DType.INT8,
- DType.INT16,
- DType.INT32,
- ]:
- error_result = True
- if input_dtype in [DType.INT16, DType.INT32]:
- if output_dtype not in [DType.INT8, DType.INT16, DType.INT32]:
- error_result = True
- elif input_dtype == DType.INT48:
- if output_dtype not in [DType.INT8, DType.INT16, DType.INT32]:
- error_result = True
- elif input_dtype == DType.UINT8:
- if output_dtype != DType.INT8:
- error_result = True
+ error_result = TosaErrorIfArgGen.eiRescaleWrongOutputType(
+ input_dtype, output_dtype
+ )
elif op["op"] in [Op.FULLY_CONNECTED, Op.MATMUL]:
if (
@@ -998,12 +998,25 @@ class TosaErrorValidator:
return info_dict
@staticmethod
+ def _getZeroPoint(qinfo, index):
+ """Return zero point value from quantization info.
+
+ Generally input_zp is index 0, output_zp is index 1
+ """
+ if isinstance(qinfo, tuple):
+ zero_point = qinfo[index]
+ else:
+ # For use: qinfo.ints[0][1] = input_zp, qinfo.ints[1][1] = output_zp
+ zero_point = qinfo.ints[index][1]
+ return zero_point
+
+ @staticmethod
def evInputZeroPointNotZero(check=False, **kwargs):
op = kwargs["op"]
error_result = False
# Quantizable types
- qTypes = (DType.INT8, DType.UINT8)
+ qTypes = (DType.INT8, DType.UINT8, DType.UINT16)
# This does not apply to quantizable types
inputDtypes = [
@@ -1015,19 +1028,12 @@ class TosaErrorValidator:
if check:
input_dtype = kwargs["input_dtype"]
- if isinstance(kwargs["qinfo"], tuple):
- qinfo = kwargs["qinfo"]
- input_zero_point = qinfo[0]
- else:
- # For use: qinfo.ints[0][1] = input_zp, qinfo.ints[1][1] = output_zp
- qinfo = kwargs["qinfo"].ints
- input_zero_point = qinfo[0][1]
-
+ input_zero_point = TosaErrorValidator._getZeroPoint(kwargs["qinfo"], 0)
if op["op"] == Op.MATMUL:
- qinfo = kwargs["qinfo"].ints
+ input2_zero_point = TosaErrorValidator._getZeroPoint(kwargs["qinfo"], 1)
for dtype, zp in (
- (kwargs["input_dtype"], qinfo[0][1]),
- (kwargs["input2_dtype"], qinfo[1][1]),
+ (kwargs["input_dtype"], input_zero_point),
+ (kwargs["input2_dtype"], input2_zero_point),
):
if dtype not in qTypes and zp != 0:
error_result = True
@@ -1059,9 +1065,7 @@ class TosaErrorValidator:
if check:
weight_dtype = kwargs["weight_dtype"]
- # For use: qinfo.ints[0][1] = input_zp, qinfo.ints[1][1] = weight_zp
- qinfo = kwargs["qinfo"].ints
- weight_zero_point = qinfo[1][1]
+ weight_zero_point = TosaErrorValidator._getZeroPoint(kwargs["qinfo"], 1)
if weight_dtype != DType.INT8 and weight_zero_point != 0:
error_result = True
@@ -1076,11 +1080,9 @@ class TosaErrorValidator:
@staticmethod
def evOutputZeroPointNotZero(check=False, **kwargs):
op = kwargs["op"]
- inputDtypes = op["types"].copy()
- if DType.INT8 in inputDtypes:
- inputDtypes.remove(DType.INT8)
- if DType.UINT8 in inputDtypes:
- inputDtypes.remove(DType.UINT8)
+ inputDtypes = [
+ t for t in op["types"] if t not in [DType.INT8, DType.UINT8, DType.UINT16]
+ ]
error_name = ErrorIf.OutputZeroPointNotZero
param_reqs = {"rank": None, "dtype": inputDtypes, "shape": None}
@@ -1090,18 +1092,13 @@ class TosaErrorValidator:
if check:
input_dtype = kwargs["input_dtype"]
output_dtype = kwargs["output_dtype"]
- if isinstance(kwargs["qinfo"], tuple):
- qinfo = kwargs["qinfo"]
- output_zero_point = qinfo[1]
- else:
- # For use: qinfo.ints[0][1] = input_zp, qinfo.ints[1][1] = output_zp
- qinfo = kwargs["qinfo"].ints
- output_zero_point = qinfo[1][1]
+ output_zero_point = TosaErrorValidator._getZeroPoint(kwargs["qinfo"], 1)
if op["op"] == Op.AVG_POOL2D:
if input_dtype != DType.INT8 and output_zero_point != 0:
error_result = True
elif (
- output_dtype not in [DType.INT8, DType.UINT8] and output_zero_point != 0
+ output_dtype not in [DType.INT8, DType.UINT8, DType.UINT16]
+ and output_zero_point != 0
):
error_result = True
@@ -1114,6 +1111,53 @@ class TosaErrorValidator:
return info_dict
@staticmethod
+ def evU16InputZeroPointNotValid(check=False, **kwargs):
+ error_name = ErrorIf.U16InputZeroPointNotValid
+ param_reqs = {"rank": None, "dtype": [DType.UINT16], "shape": None}
+ error_result = False
+ error_reason = "Input DType is UINT16 and zero point not 0 or 32678"
+
+ if check:
+ input_dtype = kwargs["input_dtype"]
+ input_zero_point = TosaErrorValidator._getZeroPoint(kwargs["qinfo"], 0)
+ error_result = input_dtype == DType.UINT16 and input_zero_point not in [
+ 0,
+ 32768,
+ ]
+
+ info_dict = {
+ "error_name": error_name,
+ "error_result": error_result,
+ "error_reason": error_reason,
+ "param_reqs": param_reqs,
+ }
+ return info_dict
+
+ @staticmethod
+ def evU16OutputZeroPointNotValid(check=False, **kwargs):
+ error_name = ErrorIf.U16OutputZeroPointNotValid
+ param_reqs = {"rank": None, "dtype": None, "shape": None}
+ error_result = False
+ error_reason = "Output DType is UINT16 and zero point not 0 or 32678"
+
+ if check:
+ output_dtype = kwargs["output_dtype"]
+ output_zero_point = TosaErrorValidator._getZeroPoint(kwargs["qinfo"], 1)
+
+ error_result = output_dtype == DType.UINT16 and output_zero_point not in [
+ 0,
+ 32768,
+ ]
+
+ info_dict = {
+ "error_name": error_name,
+ "error_result": error_result,
+ "error_reason": error_reason,
+ "param_reqs": param_reqs,
+ }
+ return info_dict
+
+ @staticmethod
def evAxisSmallerZero(check=False, **kwargs):
error_name = ErrorIf.AxisSmallerZero
param_reqs = {"rank": None, "dtype": None, "shape": None}