aboutsummaryrefslogtreecommitdiff
path: root/verif/generator
diff options
context:
space:
mode:
Diffstat (limited to 'verif/generator')
-rw-r--r--verif/generator/tosa_arg_gen.py49
-rw-r--r--verif/generator/tosa_error_if.py138
-rw-r--r--verif/generator/tosa_test_gen.py49
-rw-r--r--verif/generator/tosa_utils.py10
4 files changed, 175 insertions, 71 deletions
diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py
index b1f8942..a741efb 100644
--- a/verif/generator/tosa_arg_gen.py
+++ b/verif/generator/tosa_arg_gen.py
@@ -1349,29 +1349,58 @@ class TosaArgGen:
arg_list = []
# Enumerate the output types here
- for dtype in [DType.UINT8, DType.INT8, DType.INT16, DType.INT32]:
+ for outDtype in [
+ DType.UINT8,
+ DType.INT8,
+ DType.INT16,
+ DType.INT32,
+ DType.UINT16,
+ ]:
if (
- dtype in [DType.UINT8, DType.INT8]
+ outDtype in [DType.UINT8, DType.INT8, DType.UINT16]
and error_name == ErrorIf.OutputZeroPointNotZero
):
continue
if (
+ outDtype != DType.UINT16
+ and error_name == ErrorIf.U16OutputZeroPointNotValid
+ ) or (
+ inDtype != DType.UINT16
+ and error_name == ErrorIf.U16InputZeroPointNotValid
+ ):
+ # ErrorIfs only valid with UINT16
+ continue
+ if (
inDtype == DType.UINT8
- and dtype != DType.INT8
+ and outDtype not in [DType.INT8, DType.INT16]
+ and error_name != ErrorIf.WrongOutputType
+ ):
+ # The only output dtypes for UINT8 are INT8/INT16, skip all others
+ continue
+ if (
+ inDtype not in [DType.INT8, DType.INT16]
+ and outDtype == DType.UINT8
+ and error_name != ErrorIf.WrongOutputType
+ ):
+ # The only input dtypes for UINT8 are INT8/INT16, skip all others
+ continue
+ if (
+ inDtype == DType.UINT16
+ and outDtype != DType.INT16
and error_name != ErrorIf.WrongOutputType
):
- # The only output dtype for UINT8 is INT8, skip all other combinations
+ # The only output dtype for UINT16 is INT16, skip all others
continue
if (
- inDtype != DType.INT8
- and dtype == DType.UINT8
+ inDtype != DType.INT16
+ and outDtype == DType.UINT16
and error_name != ErrorIf.WrongOutputType
):
- # The only input dtype for UINT8 is INT8, skip all other combinations
+ # The only input dtype for UINT16 is INT16, skip all others
continue
if (
error_name == ErrorIf.WrongOutputType
- and not TosaErrorIfArgGen.eiRescaleWrongOutputType(inDtype, dtype)
+ and not TosaErrorIfArgGen.eiRescaleWrongOutputType(inDtype, outDtype)
):
continue
@@ -1403,12 +1432,12 @@ class TosaArgGen:
arg_list.append(
(
"out{}_sc{}_dr{}_pc{}".format(
- DTypeNames[dtype],
+ DTypeNames[outDtype],
int(scale32),
int(double_round),
int(per_channel),
),
- [dtype, scale32, double_round, per_channel],
+ [outDtype, scale32, double_round, per_channel],
)
)
diff --git a/verif/generator/tosa_error_if.py b/verif/generator/tosa_error_if.py
index e7e758f..1900d8a 100644
--- a/verif/generator/tosa_error_if.py
+++ b/verif/generator/tosa_error_if.py
@@ -68,6 +68,8 @@ class ErrorIf(object):
InputListBodyGraphInputMismatch = "InputListBodyGraphInputMismatch"
InputListBodyGraphOutputMismatch = "InputListBodyGraphOutputMismatch"
CondGraphOutputNotMatchingBool = "CondGraphOutputNotMatchingBool"
+ U16InputZeroPointNotValid = "U16InputZeroPointNotValid"
+ U16OutputZeroPointNotValid = "U16OutputZeroPointNotValid"
class TosaErrorIfArgGen:
@@ -227,14 +229,26 @@ class TosaErrorIfArgGen:
if input_dtype == DType.INT8:
if output_dtype not in [DType.UINT8, DType.INT8, DType.INT16, DType.INT32]:
return True
- if input_dtype in [DType.INT16, DType.INT32]:
+ elif input_dtype == DType.INT16:
+ if output_dtype not in [
+ DType.UINT8,
+ DType.INT8,
+ DType.UINT16,
+ DType.INT16,
+ DType.INT32,
+ ]:
+ return True
+ elif input_dtype == DType.INT32:
if output_dtype not in [DType.INT8, DType.INT16, DType.INT32]:
return True
elif input_dtype == DType.INT48:
if output_dtype not in [DType.INT8, DType.INT16, DType.INT32]:
return True
elif input_dtype == DType.UINT8:
- if output_dtype != DType.INT8:
+ if output_dtype not in [DType.INT8, DType.INT16]:
+ return True
+ elif input_dtype == DType.UINT16:
+ if output_dtype != DType.INT16:
return True
return False
@@ -418,23 +432,9 @@ class TosaErrorValidator:
error_result = True
elif op["op"] == Op.RESCALE:
- if input_dtype == DType.INT8:
- if output_dtype not in [
- DType.UINT8,
- DType.INT8,
- DType.INT16,
- DType.INT32,
- ]:
- error_result = True
- if input_dtype in [DType.INT16, DType.INT32]:
- if output_dtype not in [DType.INT8, DType.INT16, DType.INT32]:
- error_result = True
- elif input_dtype == DType.INT48:
- if output_dtype not in [DType.INT8, DType.INT16, DType.INT32]:
- error_result = True
- elif input_dtype == DType.UINT8:
- if output_dtype != DType.INT8:
- error_result = True
+ error_result = TosaErrorIfArgGen.eiRescaleWrongOutputType(
+ input_dtype, output_dtype
+ )
elif op["op"] in [Op.FULLY_CONNECTED, Op.MATMUL]:
if (
@@ -998,12 +998,25 @@ class TosaErrorValidator:
return info_dict
@staticmethod
+ def _getZeroPoint(qinfo, index):
+ """Return zero point value from quantization info.
+
+ Generally input_zp is index 0, output_zp is index 1
+ """
+ if isinstance(qinfo, tuple):
+ zero_point = qinfo[index]
+ else:
+ # For use: qinfo.ints[0][1] = input_zp, qinfo.ints[1][1] = output_zp
+ zero_point = qinfo.ints[index][1]
+ return zero_point
+
+ @staticmethod
def evInputZeroPointNotZero(check=False, **kwargs):
op = kwargs["op"]
error_result = False
# Quantizable types
- qTypes = (DType.INT8, DType.UINT8)
+ qTypes = (DType.INT8, DType.UINT8, DType.UINT16)
# This does not apply to quantizable types
inputDtypes = [
@@ -1015,19 +1028,12 @@ class TosaErrorValidator:
if check:
input_dtype = kwargs["input_dtype"]
- if isinstance(kwargs["qinfo"], tuple):
- qinfo = kwargs["qinfo"]
- input_zero_point = qinfo[0]
- else:
- # For use: qinfo.ints[0][1] = input_zp, qinfo.ints[1][1] = output_zp
- qinfo = kwargs["qinfo"].ints
- input_zero_point = qinfo[0][1]
-
+ input_zero_point = TosaErrorValidator._getZeroPoint(kwargs["qinfo"], 0)
if op["op"] == Op.MATMUL:
- qinfo = kwargs["qinfo"].ints
+ input2_zero_point = TosaErrorValidator._getZeroPoint(kwargs["qinfo"], 1)
for dtype, zp in (
- (kwargs["input_dtype"], qinfo[0][1]),
- (kwargs["input2_dtype"], qinfo[1][1]),
+ (kwargs["input_dtype"], input_zero_point),
+ (kwargs["input2_dtype"], input2_zero_point),
):
if dtype not in qTypes and zp != 0:
error_result = True
@@ -1059,9 +1065,7 @@ class TosaErrorValidator:
if check:
weight_dtype = kwargs["weight_dtype"]
- # For use: qinfo.ints[0][1] = input_zp, qinfo.ints[1][1] = weight_zp
- qinfo = kwargs["qinfo"].ints
- weight_zero_point = qinfo[1][1]
+ weight_zero_point = TosaErrorValidator._getZeroPoint(kwargs["qinfo"], 1)
if weight_dtype != DType.INT8 and weight_zero_point != 0:
error_result = True
@@ -1076,11 +1080,9 @@ class TosaErrorValidator:
@staticmethod
def evOutputZeroPointNotZero(check=False, **kwargs):
op = kwargs["op"]
- inputDtypes = op["types"].copy()
- if DType.INT8 in inputDtypes:
- inputDtypes.remove(DType.INT8)
- if DType.UINT8 in inputDtypes:
- inputDtypes.remove(DType.UINT8)
+ inputDtypes = [
+ t for t in op["types"] if t not in [DType.INT8, DType.UINT8, DType.UINT16]
+ ]
error_name = ErrorIf.OutputZeroPointNotZero
param_reqs = {"rank": None, "dtype": inputDtypes, "shape": None}
@@ -1090,18 +1092,13 @@ class TosaErrorValidator:
if check:
input_dtype = kwargs["input_dtype"]
output_dtype = kwargs["output_dtype"]
- if isinstance(kwargs["qinfo"], tuple):
- qinfo = kwargs["qinfo"]
- output_zero_point = qinfo[1]
- else:
- # For use: qinfo.ints[0][1] = input_zp, qinfo.ints[1][1] = output_zp
- qinfo = kwargs["qinfo"].ints
- output_zero_point = qinfo[1][1]
+ output_zero_point = TosaErrorValidator._getZeroPoint(kwargs["qinfo"], 1)
if op["op"] == Op.AVG_POOL2D:
if input_dtype != DType.INT8 and output_zero_point != 0:
error_result = True
elif (
- output_dtype not in [DType.INT8, DType.UINT8] and output_zero_point != 0
+ output_dtype not in [DType.INT8, DType.UINT8, DType.UINT16]
+ and output_zero_point != 0
):
error_result = True
@@ -1114,6 +1111,53 @@ class TosaErrorValidator:
return info_dict
@staticmethod
+ def evU16InputZeroPointNotValid(check=False, **kwargs):
+ error_name = ErrorIf.U16InputZeroPointNotValid
+ param_reqs = {"rank": None, "dtype": [DType.UINT16], "shape": None}
+ error_result = False
+ error_reason = "Input DType is UINT16 and zero point not 0 or 32678"
+
+ if check:
+ input_dtype = kwargs["input_dtype"]
+ input_zero_point = TosaErrorValidator._getZeroPoint(kwargs["qinfo"], 0)
+ error_result = input_dtype == DType.UINT16 and input_zero_point not in [
+ 0,
+ 32768,
+ ]
+
+ info_dict = {
+ "error_name": error_name,
+ "error_result": error_result,
+ "error_reason": error_reason,
+ "param_reqs": param_reqs,
+ }
+ return info_dict
+
+ @staticmethod
+ def evU16OutputZeroPointNotValid(check=False, **kwargs):
+ error_name = ErrorIf.U16OutputZeroPointNotValid
+ param_reqs = {"rank": None, "dtype": None, "shape": None}
+ error_result = False
+ error_reason = "Output DType is UINT16 and zero point not 0 or 32678"
+
+ if check:
+ output_dtype = kwargs["output_dtype"]
+ output_zero_point = TosaErrorValidator._getZeroPoint(kwargs["qinfo"], 1)
+
+ error_result = output_dtype == DType.UINT16 and output_zero_point not in [
+ 0,
+ 32768,
+ ]
+
+ info_dict = {
+ "error_name": error_name,
+ "error_result": error_result,
+ "error_reason": error_reason,
+ "param_reqs": param_reqs,
+ }
+ return info_dict
+
+ @staticmethod
def evAxisSmallerZero(check=False, **kwargs):
error_name = ErrorIf.AxisSmallerZero
param_reqs = {"rank": None, "dtype": None, "shape": None}
diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py
index 7c2b9de..c9c6d7e 100644
--- a/verif/generator/tosa_test_gen.py
+++ b/verif/generator/tosa_test_gen.py
@@ -70,6 +70,8 @@ class TosaTestGen:
return np.int32(self.rng.integers(low=0, high=256, size=shape))
elif dtype == DType.INT16:
return np.int32(self.rng.integers(low=-32768, high=32768, size=shape))
+ elif dtype == DType.UINT16:
+ return np.int32(self.rng.integers(low=0, high=65536, size=shape))
elif dtype == DType.INT32:
return np.int32(
self.rng.integers(low=-(1 << 31), high=(1 << 31), size=shape)
@@ -169,6 +171,8 @@ class TosaTestGen:
return "u8"
elif t == DType.INT16:
return "i16"
+ elif t == DType.UINT16:
+ return "u16"
elif t == DType.INT32:
return "i32"
elif t == DType.INT48:
@@ -188,6 +192,8 @@ class TosaTestGen:
return 8
elif t == DType.INT16:
return 16
+ elif t == DType.UINT16:
+ return 16
elif t == DType.INT32:
return 32
elif t == DType.INT48:
@@ -1575,29 +1581,43 @@ class TosaTestGen:
if val.dtype == DType.INT8:
input_zp = self.randInt(-128, 128)
- in_type_width = in_type_width + 1
+ in_type_width += 1
elif val.dtype == DType.UINT8:
input_zp = self.randInt(0, 256)
- in_type_width = in_type_width + 1
- elif error_name == ErrorIf.InputZeroPointNotZero:
+ in_type_width += 1
+ elif error_name in [
+ ErrorIf.InputZeroPointNotZero,
+ ErrorIf.U16InputZeroPointNotValid,
+ ]:
input_zp = self.randInt(-128, 128)
if input_zp == 0:
input_zp = input_zp + self.rng.integers(1, 10)
- in_type_width = in_type_width + 1
+ in_type_width += 1
+ elif val.dtype == DType.UINT16:
+ # Must come after ErrorIf.U16InputZeroPointNotValid check
+ input_zp = self.rng.choice([0, 32768])
+ in_type_width += 1
else:
input_zp = 0
if out_dtype == DType.INT8:
output_zp = self.randInt(-128, 128)
- out_type_width = out_type_width + 1
+ out_type_width += 1
elif out_dtype == DType.UINT8:
output_zp = self.randInt(0, 256)
- out_type_width = out_type_width + 1
- elif error_name == ErrorIf.OutputZeroPointNotZero:
+ out_type_width += 1
+ elif error_name in [
+ ErrorIf.OutputZeroPointNotZero,
+ ErrorIf.U16OutputZeroPointNotValid,
+ ]:
output_zp = self.randInt(-128, 128)
if output_zp == 0:
output_zp = output_zp + self.rng.integers(1, 10)
- out_type_width = out_type_width + 1
+ out_type_width += 1
+ elif out_dtype == DType.UINT16:
+ # Must come after ErrorIf.U16OutputZeroPointNotValid check
+ output_zp = self.rng.choice([0, 32768])
+ out_type_width += 1
else:
output_zp = 0
@@ -1631,7 +1651,7 @@ class TosaTestGen:
# print('multiplier {} shift {} inzp {} outzp {}'.format(multiplier_arr, shift_arr, input_zp, output_zp))
if scale32 and error_name is None:
- # Make sure random values are within apply_scale_32 speicification
+ # Make sure random values are within apply_scale_32 specification
# REQUIRES(value >= (-1<<(shift-2)) && value < (1<<(shift-2))
assert val.placeholderFilename
values = np.load(
@@ -3642,10 +3662,19 @@ class TosaTestGen:
TosaTensorValuesGen.tvgDefault,
TosaArgGen.agRescale,
),
- "types": [DType.UINT8, DType.INT8, DType.INT16, DType.INT32, DType.INT48],
+ "types": [
+ DType.UINT8,
+ DType.INT8,
+ DType.INT16,
+ DType.INT32,
+ DType.INT48,
+ DType.UINT16,
+ ],
"error_if_validators": (
TosaErrorValidator.evInputZeroPointNotZero,
TosaErrorValidator.evOutputZeroPointNotZero,
+ TosaErrorValidator.evU16InputZeroPointNotValid,
+ TosaErrorValidator.evU16OutputZeroPointNotValid,
TosaErrorValidator.evScaleTrue,
TosaErrorValidator.evScaleNotTrue,
TosaErrorValidator.evWrongInputType,
diff --git a/verif/generator/tosa_utils.py b/verif/generator/tosa_utils.py
index ca115a2..a4ef31a 100644
--- a/verif/generator/tosa_utils.py
+++ b/verif/generator/tosa_utils.py
@@ -59,9 +59,11 @@ def allDTypes(*, excludes=None):
def usableDTypes(*, excludes=None):
"""Get a set of usable DType values, optionally excluding some values.
- Excludes (DType.UNKNOWN, DType.UINT8) in addition to the excludes
- specified by the caller, as the serializer lib does not support them.
- If you wish to include 'UNKNOWN' or 'UINT8' use allDTypes instead.
+ Excludes uncommon types (DType.UNKNOWN, DType.UINT16, DType.UINT8) in
+ addition to the excludes specified by the caller, as the serializer lib
+ does not support them.
+ If you wish to include 'UNKNOWN', 'UINT8' or 'UINT16' use allDTypes
+ instead.
Args:
excludes: iterable of DType values (e.g. [DType.INT8, DType.BOOL])
@@ -69,7 +71,7 @@ def usableDTypes(*, excludes=None):
Returns:
A set of DType values
"""
- omit = {DType.UNKNOWN, DType.UINT8}
+ omit = {DType.UNKNOWN, DType.UINT8, DType.UINT16}
omit.update(excludes if excludes else ())
return allDTypes(excludes=omit)