From f7f78ae236e623a57919f9450e8b2043e681ddb3 Mon Sep 17 00:00:00 2001 From: Jeremy Johnson Date: Wed, 25 May 2022 15:26:38 +0100 Subject: Add support for uint16_t to RESCALE Update ref-model RESCALE op to support UINT16 conversions Add testing for RESCALE UINT16 and ERROR_IFs Signed-off-by: Jeremy Johnson Change-Id: Ic6e6e53de1f0b054bedb9e6ba3856e7475498aba --- verif/generator/tosa_error_if.py | 138 ++++++++++++++++++++++++++------------- 1 file changed, 91 insertions(+), 47 deletions(-) (limited to 'verif/generator/tosa_error_if.py') diff --git a/verif/generator/tosa_error_if.py b/verif/generator/tosa_error_if.py index e7e758f..1900d8a 100644 --- a/verif/generator/tosa_error_if.py +++ b/verif/generator/tosa_error_if.py @@ -68,6 +68,8 @@ class ErrorIf(object): InputListBodyGraphInputMismatch = "InputListBodyGraphInputMismatch" InputListBodyGraphOutputMismatch = "InputListBodyGraphOutputMismatch" CondGraphOutputNotMatchingBool = "CondGraphOutputNotMatchingBool" + U16InputZeroPointNotValid = "U16InputZeroPointNotValid" + U16OutputZeroPointNotValid = "U16OutputZeroPointNotValid" class TosaErrorIfArgGen: @@ -227,14 +229,26 @@ class TosaErrorIfArgGen: if input_dtype == DType.INT8: if output_dtype not in [DType.UINT8, DType.INT8, DType.INT16, DType.INT32]: return True - if input_dtype in [DType.INT16, DType.INT32]: + elif input_dtype == DType.INT16: + if output_dtype not in [ + DType.UINT8, + DType.INT8, + DType.UINT16, + DType.INT16, + DType.INT32, + ]: + return True + elif input_dtype == DType.INT32: if output_dtype not in [DType.INT8, DType.INT16, DType.INT32]: return True elif input_dtype == DType.INT48: if output_dtype not in [DType.INT8, DType.INT16, DType.INT32]: return True elif input_dtype == DType.UINT8: - if output_dtype != DType.INT8: + if output_dtype not in [DType.INT8, DType.INT16]: + return True + elif input_dtype == DType.UINT16: + if output_dtype != DType.INT16: return True return False @@ -418,23 +432,9 @@ class TosaErrorValidator: error_result = True elif op["op"] == Op.RESCALE: - if input_dtype == DType.INT8: - if output_dtype not in [ - DType.UINT8, - DType.INT8, - DType.INT16, - DType.INT32, - ]: - error_result = True - if input_dtype in [DType.INT16, DType.INT32]: - if output_dtype not in [DType.INT8, DType.INT16, DType.INT32]: - error_result = True - elif input_dtype == DType.INT48: - if output_dtype not in [DType.INT8, DType.INT16, DType.INT32]: - error_result = True - elif input_dtype == DType.UINT8: - if output_dtype != DType.INT8: - error_result = True + error_result = TosaErrorIfArgGen.eiRescaleWrongOutputType( + input_dtype, output_dtype + ) elif op["op"] in [Op.FULLY_CONNECTED, Op.MATMUL]: if ( @@ -997,13 +997,26 @@ class TosaErrorValidator: } return info_dict + @staticmethod + def _getZeroPoint(qinfo, index): + """Return zero point value from quantization info. + + Generally input_zp is index 0, output_zp is index 1 + """ + if isinstance(qinfo, tuple): + zero_point = qinfo[index] + else: + # For use: qinfo.ints[0][1] = input_zp, qinfo.ints[1][1] = output_zp + zero_point = qinfo.ints[index][1] + return zero_point + @staticmethod def evInputZeroPointNotZero(check=False, **kwargs): op = kwargs["op"] error_result = False # Quantizable types - qTypes = (DType.INT8, DType.UINT8) + qTypes = (DType.INT8, DType.UINT8, DType.UINT16) # This does not apply to quantizable types inputDtypes = [ @@ -1015,19 +1028,12 @@ class TosaErrorValidator: if check: input_dtype = kwargs["input_dtype"] - if isinstance(kwargs["qinfo"], tuple): - qinfo = kwargs["qinfo"] - input_zero_point = qinfo[0] - else: - # For use: qinfo.ints[0][1] = input_zp, qinfo.ints[1][1] = output_zp - qinfo = kwargs["qinfo"].ints - input_zero_point = qinfo[0][1] - + input_zero_point = TosaErrorValidator._getZeroPoint(kwargs["qinfo"], 0) if op["op"] == Op.MATMUL: - qinfo = kwargs["qinfo"].ints + input2_zero_point = TosaErrorValidator._getZeroPoint(kwargs["qinfo"], 1) for dtype, zp in ( - (kwargs["input_dtype"], qinfo[0][1]), - (kwargs["input2_dtype"], qinfo[1][1]), + (kwargs["input_dtype"], input_zero_point), + (kwargs["input2_dtype"], input2_zero_point), ): if dtype not in qTypes and zp != 0: error_result = True @@ -1059,9 +1065,7 @@ class TosaErrorValidator: if check: weight_dtype = kwargs["weight_dtype"] - # For use: qinfo.ints[0][1] = input_zp, qinfo.ints[1][1] = weight_zp - qinfo = kwargs["qinfo"].ints - weight_zero_point = qinfo[1][1] + weight_zero_point = TosaErrorValidator._getZeroPoint(kwargs["qinfo"], 1) if weight_dtype != DType.INT8 and weight_zero_point != 0: error_result = True @@ -1076,11 +1080,9 @@ class TosaErrorValidator: @staticmethod def evOutputZeroPointNotZero(check=False, **kwargs): op = kwargs["op"] - inputDtypes = op["types"].copy() - if DType.INT8 in inputDtypes: - inputDtypes.remove(DType.INT8) - if DType.UINT8 in inputDtypes: - inputDtypes.remove(DType.UINT8) + inputDtypes = [ + t for t in op["types"] if t not in [DType.INT8, DType.UINT8, DType.UINT16] + ] error_name = ErrorIf.OutputZeroPointNotZero param_reqs = {"rank": None, "dtype": inputDtypes, "shape": None} @@ -1090,18 +1092,13 @@ class TosaErrorValidator: if check: input_dtype = kwargs["input_dtype"] output_dtype = kwargs["output_dtype"] - if isinstance(kwargs["qinfo"], tuple): - qinfo = kwargs["qinfo"] - output_zero_point = qinfo[1] - else: - # For use: qinfo.ints[0][1] = input_zp, qinfo.ints[1][1] = output_zp - qinfo = kwargs["qinfo"].ints - output_zero_point = qinfo[1][1] + output_zero_point = TosaErrorValidator._getZeroPoint(kwargs["qinfo"], 1) if op["op"] == Op.AVG_POOL2D: if input_dtype != DType.INT8 and output_zero_point != 0: error_result = True elif ( - output_dtype not in [DType.INT8, DType.UINT8] and output_zero_point != 0 + output_dtype not in [DType.INT8, DType.UINT8, DType.UINT16] + and output_zero_point != 0 ): error_result = True @@ -1113,6 +1110,53 @@ class TosaErrorValidator: } return info_dict + @staticmethod + def evU16InputZeroPointNotValid(check=False, **kwargs): + error_name = ErrorIf.U16InputZeroPointNotValid + param_reqs = {"rank": None, "dtype": [DType.UINT16], "shape": None} + error_result = False + error_reason = "Input DType is UINT16 and zero point not 0 or 32678" + + if check: + input_dtype = kwargs["input_dtype"] + input_zero_point = TosaErrorValidator._getZeroPoint(kwargs["qinfo"], 0) + error_result = input_dtype == DType.UINT16 and input_zero_point not in [ + 0, + 32768, + ] + + info_dict = { + "error_name": error_name, + "error_result": error_result, + "error_reason": error_reason, + "param_reqs": param_reqs, + } + return info_dict + + @staticmethod + def evU16OutputZeroPointNotValid(check=False, **kwargs): + error_name = ErrorIf.U16OutputZeroPointNotValid + param_reqs = {"rank": None, "dtype": None, "shape": None} + error_result = False + error_reason = "Output DType is UINT16 and zero point not 0 or 32678" + + if check: + output_dtype = kwargs["output_dtype"] + output_zero_point = TosaErrorValidator._getZeroPoint(kwargs["qinfo"], 1) + + error_result = output_dtype == DType.UINT16 and output_zero_point not in [ + 0, + 32768, + ] + + info_dict = { + "error_name": error_name, + "error_result": error_result, + "error_reason": error_reason, + "param_reqs": param_reqs, + } + return info_dict + @staticmethod def evAxisSmallerZero(check=False, **kwargs): error_name = ErrorIf.AxisSmallerZero -- cgit v1.2.1