diff options
Diffstat (limited to 'verif/generator/tosa_test_gen.py')
-rw-r--r-- | verif/generator/tosa_test_gen.py | 49 |
1 files changed, 39 insertions, 10 deletions
diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py index 7c2b9de..c9c6d7e 100644 --- a/verif/generator/tosa_test_gen.py +++ b/verif/generator/tosa_test_gen.py @@ -70,6 +70,8 @@ class TosaTestGen: return np.int32(self.rng.integers(low=0, high=256, size=shape)) elif dtype == DType.INT16: return np.int32(self.rng.integers(low=-32768, high=32768, size=shape)) + elif dtype == DType.UINT16: + return np.int32(self.rng.integers(low=0, high=65536, size=shape)) elif dtype == DType.INT32: return np.int32( self.rng.integers(low=-(1 << 31), high=(1 << 31), size=shape) @@ -169,6 +171,8 @@ class TosaTestGen: return "u8" elif t == DType.INT16: return "i16" + elif t == DType.UINT16: + return "u16" elif t == DType.INT32: return "i32" elif t == DType.INT48: @@ -188,6 +192,8 @@ class TosaTestGen: return 8 elif t == DType.INT16: return 16 + elif t == DType.UINT16: + return 16 elif t == DType.INT32: return 32 elif t == DType.INT48: @@ -1575,29 +1581,43 @@ class TosaTestGen: if val.dtype == DType.INT8: input_zp = self.randInt(-128, 128) - in_type_width = in_type_width + 1 + in_type_width += 1 elif val.dtype == DType.UINT8: input_zp = self.randInt(0, 256) - in_type_width = in_type_width + 1 - elif error_name == ErrorIf.InputZeroPointNotZero: + in_type_width += 1 + elif error_name in [ + ErrorIf.InputZeroPointNotZero, + ErrorIf.U16InputZeroPointNotValid, + ]: input_zp = self.randInt(-128, 128) if input_zp == 0: input_zp = input_zp + self.rng.integers(1, 10) - in_type_width = in_type_width + 1 + in_type_width += 1 + elif val.dtype == DType.UINT16: + # Must come after ErrorIf.U16InputZeroPointNotValid check + input_zp = self.rng.choice([0, 32768]) + in_type_width += 1 else: input_zp = 0 if out_dtype == DType.INT8: output_zp = self.randInt(-128, 128) - out_type_width = out_type_width + 1 + out_type_width += 1 elif out_dtype == DType.UINT8: output_zp = self.randInt(0, 256) - out_type_width = out_type_width + 1 - elif error_name == ErrorIf.OutputZeroPointNotZero: + out_type_width += 1 + elif error_name in [ + ErrorIf.OutputZeroPointNotZero, + ErrorIf.U16OutputZeroPointNotValid, + ]: output_zp = self.randInt(-128, 128) if output_zp == 0: output_zp = output_zp + self.rng.integers(1, 10) - out_type_width = out_type_width + 1 + out_type_width += 1 + elif out_dtype == DType.UINT16: + # Must come after ErrorIf.U16OutputZeroPointNotValid check + output_zp = self.rng.choice([0, 32768]) + out_type_width += 1 else: output_zp = 0 @@ -1631,7 +1651,7 @@ class TosaTestGen: # print('multiplier {} shift {} inzp {} outzp {}'.format(multiplier_arr, shift_arr, input_zp, output_zp)) if scale32 and error_name is None: - # Make sure random values are within apply_scale_32 speicification + # Make sure random values are within apply_scale_32 specification # REQUIRES(value >= (-1<<(shift-2)) && value < (1<<(shift-2)) assert val.placeholderFilename values = np.load( @@ -3642,10 +3662,19 @@ class TosaTestGen: TosaTensorValuesGen.tvgDefault, TosaArgGen.agRescale, ), - "types": [DType.UINT8, DType.INT8, DType.INT16, DType.INT32, DType.INT48], + "types": [ + DType.UINT8, + DType.INT8, + DType.INT16, + DType.INT32, + DType.INT48, + DType.UINT16, + ], "error_if_validators": ( TosaErrorValidator.evInputZeroPointNotZero, TosaErrorValidator.evOutputZeroPointNotZero, + TosaErrorValidator.evU16InputZeroPointNotValid, + TosaErrorValidator.evU16OutputZeroPointNotValid, TosaErrorValidator.evScaleTrue, TosaErrorValidator.evScaleNotTrue, TosaErrorValidator.evWrongInputType, |