From f7f78ae236e623a57919f9450e8b2043e681ddb3 Mon Sep 17 00:00:00 2001 From: Jeremy Johnson Date: Wed, 25 May 2022 15:26:38 +0100 Subject: Add support for uint16_t to RESCALE Update ref-model RESCALE op to support UINT16 conversions Add testing for RESCALE UINT16 and ERROR_IFs Signed-off-by: Jeremy Johnson Change-Id: Ic6e6e53de1f0b054bedb9e6ba3856e7475498aba --- verif/generator/tosa_test_gen.py | 49 ++++++++++++++++++++++++++++++++-------- 1 file changed, 39 insertions(+), 10 deletions(-) (limited to 'verif/generator/tosa_test_gen.py') diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py index 7c2b9de..c9c6d7e 100644 --- a/verif/generator/tosa_test_gen.py +++ b/verif/generator/tosa_test_gen.py @@ -70,6 +70,8 @@ class TosaTestGen: return np.int32(self.rng.integers(low=0, high=256, size=shape)) elif dtype == DType.INT16: return np.int32(self.rng.integers(low=-32768, high=32768, size=shape)) + elif dtype == DType.UINT16: + return np.int32(self.rng.integers(low=0, high=65536, size=shape)) elif dtype == DType.INT32: return np.int32( self.rng.integers(low=-(1 << 31), high=(1 << 31), size=shape) @@ -169,6 +171,8 @@ class TosaTestGen: return "u8" elif t == DType.INT16: return "i16" + elif t == DType.UINT16: + return "u16" elif t == DType.INT32: return "i32" elif t == DType.INT48: @@ -188,6 +192,8 @@ class TosaTestGen: return 8 elif t == DType.INT16: return 16 + elif t == DType.UINT16: + return 16 elif t == DType.INT32: return 32 elif t == DType.INT48: @@ -1575,29 +1581,43 @@ class TosaTestGen: if val.dtype == DType.INT8: input_zp = self.randInt(-128, 128) - in_type_width = in_type_width + 1 + in_type_width += 1 elif val.dtype == DType.UINT8: input_zp = self.randInt(0, 256) - in_type_width = in_type_width + 1 - elif error_name == ErrorIf.InputZeroPointNotZero: + in_type_width += 1 + elif error_name in [ + ErrorIf.InputZeroPointNotZero, + ErrorIf.U16InputZeroPointNotValid, + ]: input_zp = self.randInt(-128, 128) if input_zp == 0: input_zp = input_zp + self.rng.integers(1, 10) - in_type_width = in_type_width + 1 + in_type_width += 1 + elif val.dtype == DType.UINT16: + # Must come after ErrorIf.U16InputZeroPointNotValid check + input_zp = self.rng.choice([0, 32768]) + in_type_width += 1 else: input_zp = 0 if out_dtype == DType.INT8: output_zp = self.randInt(-128, 128) - out_type_width = out_type_width + 1 + out_type_width += 1 elif out_dtype == DType.UINT8: output_zp = self.randInt(0, 256) - out_type_width = out_type_width + 1 - elif error_name == ErrorIf.OutputZeroPointNotZero: + out_type_width += 1 + elif error_name in [ + ErrorIf.OutputZeroPointNotZero, + ErrorIf.U16OutputZeroPointNotValid, + ]: output_zp = self.randInt(-128, 128) if output_zp == 0: output_zp = output_zp + self.rng.integers(1, 10) - out_type_width = out_type_width + 1 + out_type_width += 1 + elif out_dtype == DType.UINT16: + # Must come after ErrorIf.U16OutputZeroPointNotValid check + output_zp = self.rng.choice([0, 32768]) + out_type_width += 1 else: output_zp = 0 @@ -1631,7 +1651,7 @@ class TosaTestGen: # print('multiplier {} shift {} inzp {} outzp {}'.format(multiplier_arr, shift_arr, input_zp, output_zp)) if scale32 and error_name is None: - # Make sure random values are within apply_scale_32 speicification + # Make sure random values are within apply_scale_32 specification # REQUIRES(value >= (-1<<(shift-2)) && value < (1<<(shift-2)) assert val.placeholderFilename values = np.load( @@ -3642,10 +3662,19 @@ class TosaTestGen: TosaTensorValuesGen.tvgDefault, TosaArgGen.agRescale, ), - "types": [DType.UINT8, DType.INT8, DType.INT16, DType.INT32, DType.INT48], + "types": [ + DType.UINT8, + DType.INT8, + DType.INT16, + DType.INT32, + DType.INT48, + DType.UINT16, + ], "error_if_validators": ( TosaErrorValidator.evInputZeroPointNotZero, TosaErrorValidator.evOutputZeroPointNotZero, + TosaErrorValidator.evU16InputZeroPointNotValid, + TosaErrorValidator.evU16OutputZeroPointNotValid, TosaErrorValidator.evScaleTrue, TosaErrorValidator.evScaleNotTrue, TosaErrorValidator.evWrongInputType, -- cgit v1.2.1