From b5fabec33abeca2d92c20c7b094fa3f113d0ddd8 Mon Sep 17 00:00:00 2001 From: Eric Kunze Date: Tue, 7 Jun 2022 05:20:44 +0000 Subject: Remove quantization info from serialization attributes Any needed information moves into the attributes for each operator. New serialization library version removes teh quantization information attributes from the schema Signed-off-by: Eric Kunze Change-Id: Icf6165687ab1fd34a01f64c01b0b92b2820e72fa --- verif/generator/tosa_test_gen.py | 103 +++++++++++++++++++++------------------ 1 file changed, 55 insertions(+), 48 deletions(-) (limited to 'verif/generator/tosa_test_gen.py') diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py index 262a652..b0e7c8c 100644 --- a/verif/generator/tosa_test_gen.py +++ b/verif/generator/tosa_test_gen.py @@ -216,20 +216,19 @@ class TosaTestGen: # build_placeholder returns an int, ABS/other ops does not if isinstance(op, int): - self.ser.addOperator(op, a.name, result_tens.name, None, qinfo) + self.ser.addOperator(op, a.name, result_tens.name, None) return result_tens elif op["op"] == Op.IDENTITY: - self.ser.addOperator(op["op"], a.name, result_tens.name, None, qinfo) + self.ser.addOperator(op["op"], a.name, result_tens.name, None) return result_tens # Ensure new output type has correct qinfo if error_name == ErrorIf.WrongOutputType: if result_tens.dtype not in [DType.INT8, DType.UINT8]: - qinfo = ts.TosaSerializerQuantInfo() - qinfo.UnaryQuantInfo( - TosaQuantGen.getQinfo(self, a.dtype), - TosaQuantGen.getQinfo(self, result_tens.dtype), - ) + qinfo = [ + TosaQuantGen.getZeroPoint(self, a.dtype), + TosaQuantGen.getZeroPoint(self, result_tens.dtype), + ] # Invalidate Input/Output list for error if checks. input_list = [a.name] @@ -255,7 +254,12 @@ class TosaTestGen: ): return None - self.ser.addOperator(op["op"], input_list, output_list, None, qinfo) + attr = None + if op["op"] == Op.NEGATE: + attr = ts.TosaSerializerAttribute() + attr.NegateAttribute(qinfo[0], qinfo[1]) + + self.ser.addOperator(op["op"], input_list, output_list, attr) return result_tens def build_binary_broadcast(self, op, a, b, validator_fcns, error_name=None): @@ -542,11 +546,10 @@ class TosaTestGen: # Ensure new output type has correct qinfo if error_name == ErrorIf.WrongInputType: if input.dtype not in [DType.INT8, DType.UINT8]: - qinfo = ts.TosaSerializerQuantInfo() - qinfo.UnaryQuantInfo( - TosaQuantGen.getQinfo(self, input.dtype), - TosaQuantGen.getQinfo(self, result_tens.dtype), - ) + qinfo = [ + TosaQuantGen.getZeroPoint(self, input.dtype), + TosaQuantGen.getZeroPoint(self, result_tens.dtype), + ] # Invalidate Input/Output list for error if checks. input_list = [input.name] @@ -577,10 +580,13 @@ class TosaTestGen: ): return None + if qinfo is None: + qinfo = [0, 0] + attr = ts.TosaSerializerAttribute() - attr.PoolAttribute(kernel, stride, pad) + attr.PoolAttribute(kernel, stride, pad, qinfo[0], qinfo[1]) - self.ser.addOperator(op["op"], input_list, output_list, attr, qinfo) + self.ser.addOperator(op["op"], input_list, output_list, attr) return result_tens def build_conv2d( @@ -606,11 +612,10 @@ class TosaTestGen: DType.INT8, DType.UINT8, ): - qinfo = ts.TosaSerializerQuantInfo() - qinfo.ConvQuantInfo( - TosaQuantGen.getQinfo(self, ifm.dtype), - TosaQuantGen.getQinfo(self, result_tens.dtype), - ) + qinfo = [ + TosaQuantGen.getZeroPoint(self, ifm.dtype), + TosaQuantGen.getZeroPoint(self, result_tens.dtype), + ] # Invalidate Input/Output list for error_if checks. input_list = [ifm.name, filter.name, bias.name] @@ -642,9 +647,9 @@ class TosaTestGen: return None attr = ts.TosaSerializerAttribute() - attr.ConvAttribute(padding, strides, dilations) + attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1]) - self.ser.addOperator(op["op"], input_list, output_list, attr, qinfo) + self.ser.addOperator(op["op"], input_list, output_list, attr) return result_tens def build_conv3d( @@ -670,11 +675,10 @@ class TosaTestGen: DType.INT8, DType.UINT8, ): - qinfo = ts.TosaSerializerQuantInfo() - qinfo.ConvQuantInfo( - TosaQuantGen.getQinfo(self, ifm.dtype), - TosaQuantGen.getQinfo(self, result_tens.dtype), - ) + qinfo = [ + TosaQuantGen.getZeroPoint(self, ifm.dtype), + TosaQuantGen.getZeroPoint(self, result_tens.dtype), + ] # Invalidate Input/Output list for error_if checks. input_list = [ifm.name, filter.name, bias.name] @@ -706,9 +710,9 @@ class TosaTestGen: return None attr = ts.TosaSerializerAttribute() - attr.ConvAttribute(padding, strides, dilations) + attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1]) - self.ser.addOperator(op["op"], input_list, output_list, attr, qinfo) + self.ser.addOperator(op["op"], input_list, output_list, attr) return result_tens def build_transpose_conv2d( @@ -734,11 +738,10 @@ class TosaTestGen: DType.INT8, DType.UINT8, ): - qinfo = ts.TosaSerializerQuantInfo() - qinfo.ConvQuantInfo( - TosaQuantGen.getQinfo(self, ifm.dtype), - TosaQuantGen.getQinfo(self, result_tens.dtype), - ) + qinfo = [ + TosaQuantGen.getZeroPoint(self, ifm.dtype), + TosaQuantGen.getZeroPoint(self, result_tens.dtype), + ] # Invalidate Input/Output list for error_if checks. input_list = [ifm.name, filter.name, bias.name] @@ -769,9 +772,9 @@ class TosaTestGen: return None attr = ts.TosaSerializerAttribute() - attr.TransposeConvAttribute(out_pad, stride, output_shape) + attr.TransposeConvAttribute(out_pad, stride, output_shape, qinfo[0], qinfo[1]) - self.ser.addOperator(op["op"], input_list, output_list, attr, qinfo) + self.ser.addOperator(op["op"], input_list, output_list, attr) return result_tens def build_depthwise_conv2d( @@ -796,11 +799,10 @@ class TosaTestGen: DType.INT8, DType.UINT8, ): - qinfo = ts.TosaSerializerQuantInfo() - qinfo.ConvQuantInfo( - TosaQuantGen.getQinfo(self, ifm.dtype), - TosaQuantGen.getQinfo(self, result_tens.dtype), - ) + qinfo = [ + TosaQuantGen.getZeroPoint(self, ifm.dtype), + TosaQuantGen.getZeroPoint(self, result_tens.dtype), + ] # Invalidate Input/Output list for error_if checks. input_list = [ifm.name, filter.name, bias.name] @@ -832,9 +834,9 @@ class TosaTestGen: return None attr = ts.TosaSerializerAttribute() - attr.ConvAttribute(padding, strides, dilations) + attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1]) - self.ser.addOperator(op["op"], input_list, output_list, attr, qinfo) + self.ser.addOperator(op["op"], input_list, output_list, attr) return result_tens def build_fully_connected( @@ -871,7 +873,10 @@ class TosaTestGen: ): return None - self.ser.addOperator(op["op"], input_list, output_list, None, qinfo) + attr = ts.TosaSerializerAttribute() + attr.FullyConnectedAttribute(qinfo[0], qinfo[1]) + + self.ser.addOperator(op["op"], input_list, output_list, attr) return result_tens def build_matmul(self, op, a, b, validator_fcns=None, error_name=None, qinfo=None): @@ -905,7 +910,10 @@ class TosaTestGen: ): return None - self.ser.addOperator(op["op"], input_list, output_list, None, qinfo) + attr = ts.TosaSerializerAttribute() + attr.MatMulAttribute(qinfo[0], qinfo[1]) + + self.ser.addOperator(op["op"], input_list, output_list, attr) return result_tens def build_reduce(self, op, a, axis, validator_fcns, error_name=None): @@ -1164,7 +1172,7 @@ class TosaTestGen: ): return None - self.ser.addOperator(op["op"], input_list, output_list, attr, qinfo) + self.ser.addOperator(op["op"], input_list, output_list, attr) return result_tens def build_reshape(self, op, a, newShape, validator_fcns=None, error_name=None): @@ -2212,7 +2220,7 @@ class TosaTestGen: else: qinfo = None - tens = tvgen_fcn(self, op, dtypeList, shapeList, testArgs, qinfo, error_name) + tens = tvgen_fcn(self, op, dtypeList, shapeList, testArgs, error_name) try: if error_if_validators is None: @@ -3425,7 +3433,6 @@ class TosaTestGen: TosaTensorValuesGen.tvgDefault, TosaArgGen.agPad, ), - "qgen": TosaQuantGen.qgPad, "types": TYPE_FIB, "error_if_validators": ( TosaErrorValidator.evWrongInputType, -- cgit v1.2.1