aboutsummaryrefslogtreecommitdiff
path: root/verif/generator/tosa_test_gen.py
diff options
context:
space:
mode:
Diffstat (limited to 'verif/generator/tosa_test_gen.py')
-rw-r--r--verif/generator/tosa_test_gen.py49
1 files changed, 39 insertions, 10 deletions
diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py
index 7c2b9de..c9c6d7e 100644
--- a/verif/generator/tosa_test_gen.py
+++ b/verif/generator/tosa_test_gen.py
@@ -70,6 +70,8 @@ class TosaTestGen:
return np.int32(self.rng.integers(low=0, high=256, size=shape))
elif dtype == DType.INT16:
return np.int32(self.rng.integers(low=-32768, high=32768, size=shape))
+ elif dtype == DType.UINT16:
+ return np.int32(self.rng.integers(low=0, high=65536, size=shape))
elif dtype == DType.INT32:
return np.int32(
self.rng.integers(low=-(1 << 31), high=(1 << 31), size=shape)
@@ -169,6 +171,8 @@ class TosaTestGen:
return "u8"
elif t == DType.INT16:
return "i16"
+ elif t == DType.UINT16:
+ return "u16"
elif t == DType.INT32:
return "i32"
elif t == DType.INT48:
@@ -188,6 +192,8 @@ class TosaTestGen:
return 8
elif t == DType.INT16:
return 16
+ elif t == DType.UINT16:
+ return 16
elif t == DType.INT32:
return 32
elif t == DType.INT48:
@@ -1575,29 +1581,43 @@ class TosaTestGen:
if val.dtype == DType.INT8:
input_zp = self.randInt(-128, 128)
- in_type_width = in_type_width + 1
+ in_type_width += 1
elif val.dtype == DType.UINT8:
input_zp = self.randInt(0, 256)
- in_type_width = in_type_width + 1
- elif error_name == ErrorIf.InputZeroPointNotZero:
+ in_type_width += 1
+ elif error_name in [
+ ErrorIf.InputZeroPointNotZero,
+ ErrorIf.U16InputZeroPointNotValid,
+ ]:
input_zp = self.randInt(-128, 128)
if input_zp == 0:
input_zp = input_zp + self.rng.integers(1, 10)
- in_type_width = in_type_width + 1
+ in_type_width += 1
+ elif val.dtype == DType.UINT16:
+ # Must come after ErrorIf.U16InputZeroPointNotValid check
+ input_zp = self.rng.choice([0, 32768])
+ in_type_width += 1
else:
input_zp = 0
if out_dtype == DType.INT8:
output_zp = self.randInt(-128, 128)
- out_type_width = out_type_width + 1
+ out_type_width += 1
elif out_dtype == DType.UINT8:
output_zp = self.randInt(0, 256)
- out_type_width = out_type_width + 1
- elif error_name == ErrorIf.OutputZeroPointNotZero:
+ out_type_width += 1
+ elif error_name in [
+ ErrorIf.OutputZeroPointNotZero,
+ ErrorIf.U16OutputZeroPointNotValid,
+ ]:
output_zp = self.randInt(-128, 128)
if output_zp == 0:
output_zp = output_zp + self.rng.integers(1, 10)
- out_type_width = out_type_width + 1
+ out_type_width += 1
+ elif out_dtype == DType.UINT16:
+ # Must come after ErrorIf.U16OutputZeroPointNotValid check
+ output_zp = self.rng.choice([0, 32768])
+ out_type_width += 1
else:
output_zp = 0
@@ -1631,7 +1651,7 @@ class TosaTestGen:
# print('multiplier {} shift {} inzp {} outzp {}'.format(multiplier_arr, shift_arr, input_zp, output_zp))
if scale32 and error_name is None:
- # Make sure random values are within apply_scale_32 speicification
+ # Make sure random values are within apply_scale_32 specification
# REQUIRES(value >= (-1<<(shift-2)) && value < (1<<(shift-2))
assert val.placeholderFilename
values = np.load(
@@ -3642,10 +3662,19 @@ class TosaTestGen:
TosaTensorValuesGen.tvgDefault,
TosaArgGen.agRescale,
),
- "types": [DType.UINT8, DType.INT8, DType.INT16, DType.INT32, DType.INT48],
+ "types": [
+ DType.UINT8,
+ DType.INT8,
+ DType.INT16,
+ DType.INT32,
+ DType.INT48,
+ DType.UINT16,
+ ],
"error_if_validators": (
TosaErrorValidator.evInputZeroPointNotZero,
TosaErrorValidator.evOutputZeroPointNotZero,
+ TosaErrorValidator.evU16InputZeroPointNotValid,
+ TosaErrorValidator.evU16OutputZeroPointNotValid,
TosaErrorValidator.evScaleTrue,
TosaErrorValidator.evScaleNotTrue,
TosaErrorValidator.evWrongInputType,