diff options
author | James Ward <james.ward@arm.com> | 2022-08-12 20:48:56 +0100 |
---|---|---|
committer | James Ward <james.ward@arm.com> | 2022-10-11 11:56:02 +0100 |
commit | 8b39043c70332e1e4c95ee6a9616aec40dd3baf1 (patch) | |
tree | fea519246b698eb944b9d58537fc90bc30481d11 /verif/generator/tosa_test_gen.py | |
parent | ba5fad356a926d5e1c6e0fe6b546a310230cc5a8 (diff) | |
download | reference_model-8b39043c70332e1e4c95ee6a9616aec40dd3baf1.tar.gz |
Reference model changes for fp16 support
Change-Id: I72f21fcfa153046274969d327313e3349981dbe6
Signed-off-by: James Ward <james.ward@arm.com>
Diffstat (limited to 'verif/generator/tosa_test_gen.py')
-rw-r--r-- | verif/generator/tosa_test_gen.py | 298 |
1 files changed, 176 insertions, 122 deletions
diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py index b76b656..9ff6ec5 100644 --- a/verif/generator/tosa_test_gen.py +++ b/verif/generator/tosa_test_gen.py @@ -81,6 +81,8 @@ class TosaTestGen: return np.int64( self.rng.integers(low=-(1 << 47), high=(1 << 47), size=shape) ) + elif dtype == DType.FP16: + return np.float16(self.rng.random(size=shape)) elif dtype == DType.FLOAT: return np.float32(self.rng.random(size=shape)) else: @@ -128,6 +130,9 @@ class TosaTestGen: def getRandNumberDType(self, dtype): if dtype == DType.FLOAT: return self.rng.random() + elif dtype == DType.FP16: + rand_f32 = self.rng.random() + return np.float16(rand_f32) elif dtype == DType.BOOL: return self.rng.choice([False, True]) # TOSA specific INT4 weight range from -7 to 7 @@ -178,13 +183,15 @@ class TosaTestGen: return "i32" elif t == DType.INT48: return "i48" + elif t == DType.FP16: + return "f16" elif t == DType.FLOAT: return "float" else: raise Exception("Unknown dtype, cannot convert to string: {}".format(t)) def typeWidth(self, t): - """Get the datatype width for integer types""" + """Get the datatype width for data types""" if t == DType.INT4: return 4 elif t == DType.INT8: @@ -199,6 +206,8 @@ class TosaTestGen: return 32 elif t == DType.INT48: return 48 + elif t == DType.FP16: + return 16 elif t == DType.FLOAT: return 32 elif t == DType.BOOL: @@ -346,7 +355,7 @@ class TosaTestGen: # Special for multiply: # Force the result to INT32 for INT types - if a.dtype != DType.FLOAT: + if a.dtype not in (DType.FP16, DType.FLOAT): result_tens.setDtype(DType.INT32) if error_name == ErrorIf.WrongOutputType: all_dtypes = [DType.INT8, DType.INT16, DType.INT48] @@ -533,6 +542,7 @@ class TosaTestGen: self, op, input, + accum_dtype, stride, pad, kernel, @@ -585,17 +595,43 @@ class TosaTestGen: qinfo = [0, 0] attr = ts.TosaSerializerAttribute() - attr.PoolAttribute(kernel, stride, pad, qinfo[0], qinfo[1]) + attr.PoolAttribute(kernel, stride, pad, qinfo[0], qinfo[1], accum_dtype) self.ser.addOperator(op["op"], input_list, output_list, attr) return result_tens + def build_maxpool2d( + self, + op, + input, + stride, + pad, + kernel, + validator_fcns=None, + error_name=None, + qinfo=None, + ): + # Same as build_pool2d but manually sets accum_dtype value + # (maxpool has no accum_dtype) + return self.build_pool2d( + op, + input, + DType.UNKNOWN, + stride, + pad, + kernel, + validator_fcns, + error_name, + qinfo, + ) + def build_conv2d( self, op, ifm, filter, bias, + accum_dtype, strides, padding, dilations, @@ -605,7 +641,15 @@ class TosaTestGen: ): assert len(padding) == 4 result_tens = OutputShaper.conv2dOp( - self.ser, self.rng, ifm, filter, strides, padding, dilations, error_name + self.ser, + self.rng, + ifm, + filter, + accum_dtype, + strides, + padding, + dilations, + error_name, ) # Ensure new output type has correct qinfo @@ -648,7 +692,7 @@ class TosaTestGen: return None attr = ts.TosaSerializerAttribute() - attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1]) + attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1], accum_dtype) self.ser.addOperator(op["op"], input_list, output_list, attr) return result_tens @@ -659,6 +703,7 @@ class TosaTestGen: ifm, filter, bias, + accum_dtype, strides, padding, dilations, @@ -668,7 +713,15 @@ class TosaTestGen: ): assert len(padding) == 6 result_tens = OutputShaper.conv3dOp( - self.ser, self.rng, ifm, filter, strides, padding, dilations, error_name + self.ser, + self.rng, + ifm, + filter, + accum_dtype, + strides, + padding, + dilations, + error_name, ) # Ensure new output type has correct qinfo @@ -711,7 +764,7 @@ class TosaTestGen: return None attr = ts.TosaSerializerAttribute() - attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1]) + attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1], accum_dtype) self.ser.addOperator(op["op"], input_list, output_list, attr) return result_tens @@ -722,6 +775,7 @@ class TosaTestGen: ifm, filter, bias, + accum_dtype, stride, out_pad, output_shape, @@ -731,7 +785,7 @@ class TosaTestGen: ): assert len(out_pad) == 4 result_tens = OutputShaper.transposeConv2DOp( - self.ser, self.rng, ifm, output_shape, error_name + self.ser, self.rng, ifm, output_shape, accum_dtype, error_name ) # Ensure new output type has correct qinfo @@ -773,7 +827,9 @@ class TosaTestGen: return None attr = ts.TosaSerializerAttribute() - attr.TransposeConvAttribute(out_pad, stride, output_shape, qinfo[0], qinfo[1]) + attr.TransposeConvAttribute( + out_pad, stride, output_shape, qinfo[0], qinfo[1], accum_dtype + ) self.ser.addOperator(op["op"], input_list, output_list, attr) return result_tens @@ -784,6 +840,7 @@ class TosaTestGen: ifm, filter, bias, + accum_dtype, strides, padding, dilations, @@ -792,7 +849,15 @@ class TosaTestGen: qinfo=None, ): result_tens = OutputShaper.depthwiseConv2dOp( - self.ser, self.rng, ifm, filter, strides, padding, dilations, error_name + self.ser, + self.rng, + ifm, + filter, + accum_dtype, + strides, + padding, + dilations, + error_name, ) # Ensure new output type has correct qinfo @@ -835,16 +900,24 @@ class TosaTestGen: return None attr = ts.TosaSerializerAttribute() - attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1]) + attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1], accum_dtype) self.ser.addOperator(op["op"], input_list, output_list, attr) return result_tens def build_fully_connected( - self, op, ifm, filter, bias, validator_fcns=None, error_name=None, qinfo=None + self, + op, + ifm, + filter, + bias, + accum_dtype, + validator_fcns=None, + error_name=None, + qinfo=None, ): result_tens = OutputShaper.fullyConnectedOp( - self.ser, self.rng, ifm, filter, error_name + self.ser, self.rng, ifm, filter, accum_dtype, error_name ) # Invalidate Input/Output list for error if checks. @@ -871,17 +944,22 @@ class TosaTestGen: input_list=input_list, output_list=output_list, num_operands=num_operands, + accum_dtype=accum_dtype, ): return None attr = ts.TosaSerializerAttribute() - attr.FullyConnectedAttribute(qinfo[0], qinfo[1]) + attr.FullyConnectedAttribute(qinfo[0], qinfo[1], accum_dtype) self.ser.addOperator(op["op"], input_list, output_list, attr) return result_tens - def build_matmul(self, op, a, b, validator_fcns=None, error_name=None, qinfo=None): - result_tens = OutputShaper.matmulOp(self.ser, self.rng, a, b, error_name) + def build_matmul( + self, op, a, b, accum_dtype, validator_fcns=None, error_name=None, qinfo=None + ): + result_tens = OutputShaper.matmulOp( + self.ser, self.rng, a, b, accum_dtype, error_name + ) # Invalidate Input/Output list for error if checks. input_list = [a.name, b.name] @@ -908,11 +986,12 @@ class TosaTestGen: input_list=input_list, output_list=output_list, num_operands=num_operands, + accum_dtype=accum_dtype, ): return None attr = ts.TosaSerializerAttribute() - attr.MatMulAttribute(qinfo[0], qinfo[1]) + attr.MatMulAttribute(qinfo[0], qinfo[1], accum_dtype) self.ser.addOperator(op["op"], input_list, output_list, attr) return result_tens @@ -995,7 +1074,7 @@ class TosaTestGen: return None attr = ts.TosaSerializerAttribute() - if a.dtype == DType.FLOAT: + if a.dtype in (DType.FP16, DType.FLOAT): attr.ClampAttribute(0, 0, min_val, max_val) else: attr.ClampAttribute(min_val, max_val, 0, 0) @@ -1811,7 +1890,7 @@ class TosaTestGen: op["op"], [cond_tens.name, a.name, b.name], [result_tens.name], attr ) - if a.dtype in (DType.FLOAT, DType.INT32): + if a.dtype in (DType.FLOAT, DType.FP16, DType.INT32): then_op, else_op = Op.ADD, Op.SUB elif a.dtype in (DType.INT8, DType.INT16): then_op, else_op = Op.LOGICAL_RIGHT_SHIFT, Op.LOGICAL_LEFT_SHIFT @@ -2350,22 +2429,37 @@ class TosaTestGen: # if not specified, defaults to (1, 4) # 'build_fcn': tuple of the function to (build_operator(), TensorGen function, ArgGen enum) # 'types': array of datatypes to be tested - TYPE_FP = [DType.FLOAT] + TYPE_FP = [DType.FLOAT, DType.FP16] TYPE_INT = [DType.INT8, DType.INT16, DType.INT32] # Excludes INT4 - TYPE_INT_FP = [DType.INT8, DType.INT16, DType.INT32, DType.FLOAT] # Excludes INT4 + TYPE_INT_FP = [ + DType.INT8, + DType.INT16, + DType.INT32, + DType.FP16, + DType.FLOAT, + ] # Excludes INT4 TYPE_BOOL = [DType.BOOL] - TYPE_FI32 = [DType.FLOAT, DType.INT32] - TYPE_FIB = [DType.FLOAT, DType.INT8, DType.INT16, DType.INT32, DType.BOOL] + TYPE_FI32 = [DType.FLOAT, DType.FP16, DType.INT32] # floating-types and INT32 + TYPE_FIB = [ + DType.FP16, + DType.FLOAT, + DType.INT8, + DType.INT16, + DType.INT32, + DType.BOOL, + ] TYPE_FI16 = [DType.FLOAT, DType.INT16] - TYPE_NARROW_INT_FP = [DType.INT8, DType.INT16, DType.FLOAT] + TYPE_NARROW_INT_FP = [DType.INT8, DType.INT16, DType.FP16, DType.FLOAT] TYPE_CONV = [ [DType.INT8, DType.INT4, DType.INT32], [DType.INT8, DType.INT8, DType.INT32], [DType.INT16, DType.INT8, DType.INT48], + [DType.FP16, DType.FP16, DType.FP16], + [DType.FP16, DType.FP16, DType.FLOAT], DType.FLOAT, ] @@ -2524,7 +2618,7 @@ class TosaTestGen: build_fully_connected, TosaTensorGen.tgFullyConnected, TosaTensorValuesGen.tvgDefault, - None, + TosaArgGen.agFullyConnected, ), "qgen": TosaQuantGen.qgConv, "types": TYPE_CONV, @@ -2546,7 +2640,7 @@ class TosaTestGen: build_matmul, TosaTensorGen.tgMatmul, TosaTensorValuesGen.tvgDefault, - None, + TosaArgGen.agMatMul, ), "qgen": TosaQuantGen.qgMatmul, "types": TYPE_NARROW_INT_FP, @@ -2564,7 +2658,7 @@ class TosaTestGen: "operands": (1, 0), "rank": (4, 4), "build_fcn": ( - build_pool2d, + build_maxpool2d, TosaTensorGen.tgNHWC, TosaTensorValuesGen.tvgDefault, TosaArgGen.agPooling, @@ -3384,7 +3478,7 @@ class TosaTestGen: TosaTensorValuesGen.tvgReduceSum, TosaArgGen.agAxis, ), - "types": TYPE_FI32, + "types": (DType.FP16, DType.FLOAT, DType.INT32), "error_if_validators": ( TosaErrorValidator.evAxisLargerRank, TosaErrorValidator.evAxisSmallerZero, @@ -3571,7 +3665,7 @@ class TosaTestGen: TosaTensorValuesGen.tvgDefault, None, ), - "types": TYPE_INT_FP, + "types": (DType.INT8, DType.INT16, DType.INT32, DType.FP16, DType.FLOAT), "error_if_validators": ( TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, @@ -3612,7 +3706,7 @@ class TosaTestGen: TosaTensorValuesGen.tvgDefault, TosaArgGen.agResize, ), - "types": [DType.INT8, DType.INT16, DType.FLOAT], + "types": (DType.INT8, DType.INT16, DType.FP16, DType.FLOAT), "invalid_test_validators": ( TosaInvalidValidator.ivWrongDataTypeOrModeResize, ), @@ -3646,7 +3740,14 @@ class TosaTestGen: TosaTensorValuesGen.tvgDefault, TosaArgGen.agCast, ), - "types": [DType.FLOAT, DType.INT8, DType.INT16, DType.INT32, DType.BOOL], + "types": ( + DType.FP16, + DType.FLOAT, + DType.INT8, + DType.INT16, + DType.INT32, + DType.BOOL, + ), "error_if_validators": ( TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, @@ -3925,7 +4026,9 @@ class OutputShaper: return ser.addOutput(shape, outputDType) @staticmethod - def conv2dOp(ser, rng, ifm, filter, strides, padding, dilations, error_name=None): + def conv2dOp( + ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None + ): # IFM: NHWC # Filter: OHWI @@ -3958,26 +4061,26 @@ class OutputShaper: ofm_shape = [ifm.shape[0], h, w, filter.shape[0]] - if ifm.dtype == DType.INT8: - out_dtype = DType.INT32 - elif ifm.dtype == DType.INT16: - out_dtype = DType.INT48 - elif ifm.dtype == DType.FLOAT: - out_dtype = DType.FLOAT - elif error_name == ErrorIf.WrongInputType: + if error_name == ErrorIf.WrongInputType: # Pick some potentially correct output dtype if input type is incorrect out_dtype = DType.INT32 else: - raise Exception(f"Unsupported input dtype: {ifm.dtype}") + out_dtype = accum_dtype if error_name == ErrorIf.WrongOutputType: - wrong_dtypes = list(usableDTypes(excludes=[out_dtype])) + if ifm.dtype == DType.FP16: + excludes = [DType.FP16, DType.FLOAT] + else: + excludes = [out_dtype] + wrong_dtypes = list(usableDTypes(excludes=excludes)) out_dtype = rng.choice(wrong_dtypes) return ser.addOutput(ofm_shape, out_dtype) @staticmethod - def conv3dOp(ser, rng, ifm, filter, strides, padding, dilations, error_name=None): + def conv3dOp( + ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None + ): # IFM: NDHWC # Filter: ODHWI @@ -4020,27 +4123,25 @@ class OutputShaper: ofm_shape = [ifm.shape[0], d, h, w, filter.shape[0]] - if ifm.dtype == DType.INT8: - out_dtype = DType.INT32 - elif ifm.dtype == DType.INT16: - out_dtype = DType.INT48 - elif ifm.dtype == DType.FLOAT: - out_dtype = DType.FLOAT - elif error_name == ErrorIf.WrongInputType: + if error_name == ErrorIf.WrongInputType: # Pick some potentially correct output dtype if input type is incorrect out_dtype = DType.INT32 else: - raise Exception(f"Unsupported input dtype: {ifm.dtype}") + out_dtype = accum_dtype if error_name == ErrorIf.WrongOutputType: - wrong_dtypes = list(usableDTypes(excludes=[out_dtype])) + if ifm.dtype == DType.FP16: + excludes = [DType.FP16, DType.FLOAT] + else: + excludes = [out_dtype] + wrong_dtypes = list(usableDTypes(excludes=excludes)) out_dtype = rng.choice(wrong_dtypes) return ser.addOutput(ofm_shape, out_dtype) @staticmethod def depthwiseConv2dOp( - ser, rng, ifm, filter, strides, padding, dilations, error_name=None + ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None ): # IFM: NHWC # Filter: HWCM @@ -4073,20 +4174,18 @@ class OutputShaper: ofm_shape = [ifm.shape[0], h, w, filter.shape[2] * filter.shape[3]] - if ifm.dtype == DType.INT8: - out_dtype = DType.INT32 - elif ifm.dtype == DType.INT16: - out_dtype = DType.INT48 - elif ifm.dtype == DType.FLOAT: - out_dtype = DType.FLOAT - elif error_name == ErrorIf.WrongInputType: + if error_name == ErrorIf.WrongInputType: # Pick some potentially correct output dtype if input type is incorrect out_dtype = DType.INT32 else: - raise Exception(f"Unsupported input dtype: {ifm.dtype}") + out_dtype = accum_dtype if error_name == ErrorIf.WrongOutputType: - wrong_dtypes = list(usableDTypes(excludes=[out_dtype])) + if ifm.dtype == DType.FP16: + excludes = [DType.FP16, DType.FLOAT] + else: + excludes = [out_dtype] + wrong_dtypes = list(usableDTypes(excludes=excludes)) out_dtype = rng.choice(wrong_dtypes) return ser.addOutput(ofm_shape, out_dtype) @@ -4119,6 +4218,7 @@ class OutputShaper: DType.INT32, DType.INT48, DType.FLOAT, + DType.FP16, ] wrong_dtypes = list(set(all_dtypes) - set([ifm.dtype])) outputDType = rng.choice(wrong_dtypes) @@ -4128,55 +4228,20 @@ class OutputShaper: return ser.addOutput(ofm_shape, outputDType) @staticmethod - def fullyConnectedOp(ser, rng, input, filter, error_name=None): + def fullyConnectedOp(ser, rng, input, filter, accum_dtype, error_name=None): # input: N, IC # filter: OC, IC # output: N, OC output_shape = [input.shape[0], filter.shape[0]] - if error_name == ErrorIf.WrongOutputType: - if input.dtype == DType.INT8: - incorrect_types = ( - DType.INT4, - DType.INT8, - DType.INT16, - DType.INT48, - DType.FLOAT, - ) - elif input.dtype == DType.INT16: - incorrect_types = ( - DType.INT4, - DType.INT8, - DType.INT16, - DType.INT32, - DType.FLOAT, - ) - elif input.dtype == DType.FLOAT: - incorrect_types = ( - DType.INT4, - DType.INT8, - DType.INT16, - DType.INT32, - DType.INT48, - ) - out_dtype = rng.choice(a=incorrect_types) - elif input.dtype == DType.INT8: - out_dtype = DType.INT32 - elif input.dtype == DType.INT16: - out_dtype = DType.INT48 - elif input.dtype == DType.FLOAT: - out_dtype = DType.FLOAT - elif error_name == ErrorIf.WrongInputType: - # Pick some potentially correct output dtype if input type is incorrect - out_dtype = DType.INT32 - else: - raise Exception("Unsupported input dtype: {}".format(input.dtype)) + # Validated in arg_gen (also invalidated for ErrorIf) + out_dtype = accum_dtype return ser.addOutput(output_shape, out_dtype) @staticmethod - def matmulOp(ser, rng, a, b, error_name=None): + def matmulOp(ser, rng, a, b, accum_dtype, error_name=None): # a: N, H, C # b: N, C, W # out: N, H, W @@ -4200,7 +4265,7 @@ class OutputShaper: DType.INT32, DType.FLOAT, ) - elif a.dtype == DType.FLOAT: + elif a.dtype == DType.FLOAT or a.dtype == DType.FP16: incorrect_types = ( DType.INT4, DType.INT8, @@ -4209,17 +4274,11 @@ class OutputShaper: DType.INT48, ) out_dtype = rng.choice(a=incorrect_types) - elif a.dtype == DType.INT8: - out_dtype = DType.INT32 - elif a.dtype == DType.INT16: - out_dtype = DType.INT48 - elif a.dtype == DType.FLOAT: - out_dtype = DType.FLOAT elif error_name == ErrorIf.WrongInputType: # Pick some potentially correct output dtype if input type is incorrect out_dtype = DType.INT32 else: - raise Exception("Unsupported input dtype for matmul: {}".format(a.dtype)) + out_dtype = accum_dtype # Validated in arg_gen return ser.addOutput(output_shape, out_dtype) @@ -4269,10 +4328,6 @@ class OutputShaper: bad_dim = rng.choice(range(len(output_shape))) output_shape[bad_dim] -= rng.choice([1, 2]) - # Fix negative output shape if error_if test causes it - if error_name == ErrorIf.PadSmallerZero and min(output_shape) < 1: - output_shape = [i if i >= 1 else 1 for i in output_shape] - if error_name == ErrorIf.WrongOutputType: all_dtypes = [ DType.INT8, @@ -4280,6 +4335,7 @@ class OutputShaper: DType.INT32, DType.INT48, DType.FLOAT, + DType.FP16, ] wrong_dtypes = list(set(all_dtypes) - set([a.dtype])) outputDType = rng.choice(wrong_dtypes) @@ -4546,7 +4602,7 @@ class OutputShaper: return ser.addOutput(val.shape, out_dtype) @staticmethod - def transposeConv2DOp(ser, rng, ifm, output_shape, error_name=None): + def transposeConv2DOp(ser, rng, ifm, output_shape, accum_dtype, error_name=None): if error_name == ErrorIf.ConvOutputShapeMismatch: choices = [1, 2, 3] change = rng.choice(choices) @@ -4555,20 +4611,18 @@ class OutputShaper: if change in [2, 3]: output_shape[2] = output_shape[2] + rng.choice(choices) - if ifm.dtype == DType.INT8: - out_dtype = DType.INT32 - elif ifm.dtype == DType.INT16: - out_dtype = DType.INT48 - elif ifm.dtype == DType.FLOAT: - out_dtype = DType.FLOAT - elif error_name == ErrorIf.WrongInputType: + if error_name == ErrorIf.WrongInputType: # Pick some potentially correct output dtype if input type is incorrect out_dtype = DType.INT32 else: - raise Exception(f"Unsupported input dtype: {ifm.dtype}") + out_dtype = accum_dtype if error_name == ErrorIf.WrongOutputType: - wrong_dtypes = list(usableDTypes(excludes=[out_dtype])) + if ifm.dtype == DType.FP16: + excludes = [DType.FP16, DType.FLOAT] + else: + excludes = [out_dtype] + wrong_dtypes = list(usableDTypes(excludes=excludes)) out_dtype = rng.choice(wrong_dtypes) return ser.addOutput(output_shape, out_dtype) |