From b5fabec33abeca2d92c20c7b094fa3f113d0ddd8 Mon Sep 17 00:00:00 2001 From: Eric Kunze Date: Tue, 7 Jun 2022 05:20:44 +0000 Subject: Remove quantization info from serialization attributes Any needed information moves into the attributes for each operator. New serialization library version removes teh quantization information attributes from the schema Signed-off-by: Eric Kunze Change-Id: Icf6165687ab1fd34a01f64c01b0b92b2820e72fa --- verif/generator/tosa_arg_gen.py | 119 +++++++++++++++++---------------------- verif/generator/tosa_error_if.py | 7 +-- verif/generator/tosa_test_gen.py | 103 +++++++++++++++++---------------- 3 files changed, 109 insertions(+), 120 deletions(-) (limited to 'verif') diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py index a27d849..8e00fab 100644 --- a/verif/generator/tosa_arg_gen.py +++ b/verif/generator/tosa_arg_gen.py @@ -4,7 +4,6 @@ import itertools import math import numpy as np -import serializer.tosa_serializer as ts from generator.tosa_error_if import ErrorIf from generator.tosa_error_if import TosaErrorIfArgGen from serializer.tosa_serializer import DTypeNames @@ -26,7 +25,7 @@ class TosaQuantGen: pass @staticmethod - def getQinfo(testGen, dtype, error_name=None): + def getZeroPoint(testGen, dtype, error_name=None): if dtype == DType.INT8: return testGen.randInt(-128, 128) @@ -45,27 +44,25 @@ class TosaQuantGen: @staticmethod def qgUnary(testGen, op, dtype, error_name=None): - qinfo = ts.TosaSerializerQuantInfo() if error_name == ErrorIf.InputZeroPointNotZero: - qinfo.UnaryQuantInfo( - TosaQuantGen.getQinfo(testGen, dtype, error_name), - TosaQuantGen.getQinfo(testGen, dtype), - ) + qinfo = [ + TosaQuantGen.getZeroPoint(testGen, dtype, error_name), + TosaQuantGen.getZeroPoint(testGen, dtype), + ] elif error_name == ErrorIf.OutputZeroPointNotZero: - qinfo.UnaryQuantInfo( - TosaQuantGen.getQinfo(testGen, dtype), - TosaQuantGen.getQinfo(testGen, dtype, error_name), - ) + qinfo = [ + TosaQuantGen.getZeroPoint(testGen, dtype), + TosaQuantGen.getZeroPoint(testGen, dtype, error_name), + ] else: - qinfo.UnaryQuantInfo( - TosaQuantGen.getQinfo(testGen, dtype), - TosaQuantGen.getQinfo(testGen, dtype), - ) + qinfo = [ + TosaQuantGen.getZeroPoint(testGen, dtype), + TosaQuantGen.getZeroPoint(testGen, dtype), + ] return qinfo @staticmethod def qgConv(testGen, op, dtype_or_dtypeList, error_name=None): - qinfo = ts.TosaSerializerQuantInfo() if isinstance(dtype_or_dtypeList, list): # a list of [input, weights, accumulator] dtypes dtypeList = dtype_or_dtypeList @@ -74,40 +71,34 @@ class TosaQuantGen: dtypeList = [dtype_or_dtypeList] * 3 if error_name == ErrorIf.InputZeroPointNotZero: - input_zp = TosaQuantGen.getQinfo(testGen, dtypeList[0], error_name) - weights_zp = TosaQuantGen.getQinfo(testGen, dtypeList[1]) + qinfo = [ + TosaQuantGen.getZeroPoint(testGen, dtypeList[0], error_name), + TosaQuantGen.getZeroPoint(testGen, dtypeList[1]), + ] elif error_name == ErrorIf.WeightZeroPointNotZero: - input_zp = TosaQuantGen.getQinfo(testGen, dtypeList[0]) - weights_zp = TosaQuantGen.getQinfo(testGen, dtypeList[1], error_name) + qinfo = [ + TosaQuantGen.getZeroPoint(testGen, dtypeList[0]), + TosaQuantGen.getZeroPoint(testGen, dtypeList[1], error_name), + ] else: - input_zp = TosaQuantGen.getQinfo(testGen, dtypeList[0]) - weights_zp = TosaQuantGen.getQinfo(testGen, dtypeList[1]) - - qinfo.ConvQuantInfo(input_zp, weights_zp) + qinfo = [ + TosaQuantGen.getZeroPoint(testGen, dtypeList[0]), + TosaQuantGen.getZeroPoint(testGen, dtypeList[1]), + ] return qinfo @staticmethod def qgMatmul(testGen, op, dtype, error_name=None): - qinfo = ts.TosaSerializerQuantInfo() if error_name == ErrorIf.InputZeroPointNotZero: - qinfo.MatMulQuantInfo( - TosaQuantGen.getQinfo(testGen, dtype, error_name), - TosaQuantGen.getQinfo(testGen, dtype, error_name), - ) + qinfo = [ + TosaQuantGen.getZeroPoint(testGen, dtype, error_name), + TosaQuantGen.getZeroPoint(testGen, dtype, error_name), + ] else: - qinfo.MatMulQuantInfo( - TosaQuantGen.getQinfo(testGen, dtype), - TosaQuantGen.getQinfo(testGen, dtype), - ) - return qinfo - - @staticmethod - def qgPad(testGen, op, dtype, error_name=None): - qinfo = ts.TosaSerializerQuantInfo() - if error_name == ErrorIf.InputZeroPointNotZero: - qinfo.PadQuantInfo(TosaQuantGen.getQinfo(testGen, dtype, error_name)) - else: - qinfo.PadQuantInfo(TosaQuantGen.getQinfo(testGen, dtype)) + qinfo = [ + TosaQuantGen.getZeroPoint(testGen, dtype), + TosaQuantGen.getZeroPoint(testGen, dtype), + ] return qinfo @staticmethod @@ -550,7 +541,7 @@ class TosaTensorValuesGen: pass @staticmethod - def tvgDefault(testGen, op, dtypeList, shapeList, testArgs, qinfo, error_name=None): + def tvgDefault(testGen, op, dtypeList, shapeList, testArgs, error_name=None): pCount, cCount = op["operands"] tens = [] @@ -562,7 +553,7 @@ class TosaTensorValuesGen: return tens @staticmethod - def tvgNegate(testGen, op, dtypeList, shapeList, testArgs, qinfo, error_name=None): + def tvgNegate(testGen, op, dtypeList, shapeList, testArgs, error_name=None): if dtypeList[0] == DType.INT32 and error_name is None: pCount, cCount = op["operands"] assert ( @@ -582,11 +573,11 @@ class TosaTensorValuesGen: return placeholders else: return TosaTensorValuesGen.tvgDefault( - testGen, op, dtypeList, shapeList, testArgs, qinfo, error_name + testGen, op, dtypeList, shapeList, testArgs, error_name ) @staticmethod - def tvgAddSub(testGen, op, dtypeList, shapeList, testArgs, qinfo, error_name=None): + def tvgAddSub(testGen, op, dtypeList, shapeList, testArgs, error_name=None): if dtypeList[0] == DType.INT32 and error_name is None: # Make sure the operation does not cause value saturation - where # the number wraps due to limited number of bits to store the answer @@ -651,12 +642,12 @@ class TosaTensorValuesGen: return placeholders else: return TosaTensorValuesGen.tvgDefault( - testGen, op, dtypeList, shapeList, testArgs, qinfo, error_name + testGen, op, dtypeList, shapeList, testArgs, error_name ) @staticmethod def tvgCondIfWhileLoop( - testGen, op, dtypeList, shapeList, testArgs, qinfo, error_name=None + testGen, op, dtypeList, shapeList, testArgs, error_name=None ): if dtypeList[0] in ( DType.INT32, @@ -689,12 +680,12 @@ class TosaTensorValuesGen: return placeholders else: return TosaTensorValuesGen.tvgDefault( - testGen, op, dtypeList, shapeList, testArgs, qinfo, error_name + testGen, op, dtypeList, shapeList, testArgs, error_name ) @staticmethod def tvgArithmeticRightShift( - testGen, op, dtypeList, shapeList, testArgs, qinfo, error_name=None + testGen, op, dtypeList, shapeList, testArgs, error_name=None ): pCount, cCount = op["operands"] # Force value of operand[1] to be within [0, num_bits] @@ -722,16 +713,16 @@ class TosaTensorValuesGen: return placeholders @staticmethod - def tvgSelect(testGen, op, dtypeList, shapeList, testArgs, qinfo, error_name=None): + def tvgSelect(testGen, op, dtypeList, shapeList, testArgs, error_name=None): # Set datatype of condition tensor to boolean dtypeList[0] = DType.BOOL return TosaTensorValuesGen.tvgDefault( - testGen, op, dtypeList, shapeList, testArgs, qinfo, error_name + testGen, op, dtypeList, shapeList, testArgs, error_name ) @staticmethod - def tvgIntDiv(testGen, op, dtypeList, shapeList, testArgs, qinfo, error_name=None): + def tvgIntDiv(testGen, op, dtypeList, shapeList, testArgs, error_name=None): if error_name is None: pCount, cCount = op["operands"] assert ( @@ -765,11 +756,11 @@ class TosaTensorValuesGen: return placeholders else: return TosaTensorValuesGen.tvgDefault( - testGen, op, dtypeList, shapeList, testArgs, qinfo, error_name + testGen, op, dtypeList, shapeList, testArgs, error_name ) @staticmethod - def tvgMul(testGen, op, dtypeList, shapeList, testArgs, qinfo, error_name=None): + def tvgMul(testGen, op, dtypeList, shapeList, testArgs, error_name=None): if error_name is None: pCount, cCount = op["operands"] assert ( @@ -839,11 +830,11 @@ class TosaTensorValuesGen: return tens else: return TosaTensorValuesGen.tvgDefault( - testGen, op, dtypeList, shapeList, testArgs, qinfo, error_name + testGen, op, dtypeList, shapeList, testArgs, error_name ) @staticmethod - def tvgConcat(testGen, op, dtypeList, shapeList, testArgs, qinfo, error_name=None): + def tvgConcat(testGen, op, dtypeList, shapeList, testArgs, error_name=None): count = len(shapeList) - testGen.args.num_const_inputs_concat if count < 1: count = 1 @@ -866,9 +857,7 @@ class TosaTensorValuesGen: return tens @staticmethod - def tvgLogicalShift( - testGen, op, dtypeList, shapeList, testArgs, qinfo, error_name=None - ): + def tvgLogicalShift(testGen, op, dtypeList, shapeList, testArgs, error_name=None): pCount, cCount = op["operands"] assert ( pCount == 2 and cCount == 0 @@ -886,7 +875,7 @@ class TosaTensorValuesGen: return placeholders @staticmethod - def tvgEqual(testGen, op, dtypeList, shapeList, testArgs, qinfo, error_name=None): + def tvgEqual(testGen, op, dtypeList, shapeList, testArgs, error_name=None): if error_name is None: pCount, cCount = op["operands"] assert ( @@ -924,13 +913,11 @@ class TosaTensorValuesGen: return placeholders else: return TosaTensorValuesGen.tvgDefault( - testGen, op, dtypeList, shapeList, testArgs, qinfo, error_name + testGen, op, dtypeList, shapeList, testArgs, error_name ) @staticmethod - def tvgReduceSum( - testGen, op, dtypeList, shapeList, testArgs, qinfo, error_name=None - ): + def tvgReduceSum(testGen, op, dtypeList, shapeList, testArgs, error_name=None): if dtypeList[0] == DType.INT32: pCount, cCount = op["operands"] assert ( @@ -949,7 +936,7 @@ class TosaTensorValuesGen: return placeholders else: return TosaTensorValuesGen.tvgDefault( - testGen, op, dtypeList, shapeList, testArgs, qinfo, error_name + testGen, op, dtypeList, shapeList, testArgs, error_name ) diff --git a/verif/generator/tosa_error_if.py b/verif/generator/tosa_error_if.py index 1967d8a..b331a42 100644 --- a/verif/generator/tosa_error_if.py +++ b/verif/generator/tosa_error_if.py @@ -1003,12 +1003,7 @@ class TosaErrorValidator: Generally input_zp is index 0, output_zp is index 1 """ - if isinstance(qinfo, tuple): - zero_point = qinfo[index] - else: - # For use: qinfo.ints[0][1] = input_zp, qinfo.ints[1][1] = output_zp - zero_point = qinfo.ints[index][1] - return zero_point + return qinfo[index] @staticmethod def evInputZeroPointNotZero(check=False, **kwargs): diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py index 262a652..b0e7c8c 100644 --- a/verif/generator/tosa_test_gen.py +++ b/verif/generator/tosa_test_gen.py @@ -216,20 +216,19 @@ class TosaTestGen: # build_placeholder returns an int, ABS/other ops does not if isinstance(op, int): - self.ser.addOperator(op, a.name, result_tens.name, None, qinfo) + self.ser.addOperator(op, a.name, result_tens.name, None) return result_tens elif op["op"] == Op.IDENTITY: - self.ser.addOperator(op["op"], a.name, result_tens.name, None, qinfo) + self.ser.addOperator(op["op"], a.name, result_tens.name, None) return result_tens # Ensure new output type has correct qinfo if error_name == ErrorIf.WrongOutputType: if result_tens.dtype not in [DType.INT8, DType.UINT8]: - qinfo = ts.TosaSerializerQuantInfo() - qinfo.UnaryQuantInfo( - TosaQuantGen.getQinfo(self, a.dtype), - TosaQuantGen.getQinfo(self, result_tens.dtype), - ) + qinfo = [ + TosaQuantGen.getZeroPoint(self, a.dtype), + TosaQuantGen.getZeroPoint(self, result_tens.dtype), + ] # Invalidate Input/Output list for error if checks. input_list = [a.name] @@ -255,7 +254,12 @@ class TosaTestGen: ): return None - self.ser.addOperator(op["op"], input_list, output_list, None, qinfo) + attr = None + if op["op"] == Op.NEGATE: + attr = ts.TosaSerializerAttribute() + attr.NegateAttribute(qinfo[0], qinfo[1]) + + self.ser.addOperator(op["op"], input_list, output_list, attr) return result_tens def build_binary_broadcast(self, op, a, b, validator_fcns, error_name=None): @@ -542,11 +546,10 @@ class TosaTestGen: # Ensure new output type has correct qinfo if error_name == ErrorIf.WrongInputType: if input.dtype not in [DType.INT8, DType.UINT8]: - qinfo = ts.TosaSerializerQuantInfo() - qinfo.UnaryQuantInfo( - TosaQuantGen.getQinfo(self, input.dtype), - TosaQuantGen.getQinfo(self, result_tens.dtype), - ) + qinfo = [ + TosaQuantGen.getZeroPoint(self, input.dtype), + TosaQuantGen.getZeroPoint(self, result_tens.dtype), + ] # Invalidate Input/Output list for error if checks. input_list = [input.name] @@ -577,10 +580,13 @@ class TosaTestGen: ): return None + if qinfo is None: + qinfo = [0, 0] + attr = ts.TosaSerializerAttribute() - attr.PoolAttribute(kernel, stride, pad) + attr.PoolAttribute(kernel, stride, pad, qinfo[0], qinfo[1]) - self.ser.addOperator(op["op"], input_list, output_list, attr, qinfo) + self.ser.addOperator(op["op"], input_list, output_list, attr) return result_tens def build_conv2d( @@ -606,11 +612,10 @@ class TosaTestGen: DType.INT8, DType.UINT8, ): - qinfo = ts.TosaSerializerQuantInfo() - qinfo.ConvQuantInfo( - TosaQuantGen.getQinfo(self, ifm.dtype), - TosaQuantGen.getQinfo(self, result_tens.dtype), - ) + qinfo = [ + TosaQuantGen.getZeroPoint(self, ifm.dtype), + TosaQuantGen.getZeroPoint(self, result_tens.dtype), + ] # Invalidate Input/Output list for error_if checks. input_list = [ifm.name, filter.name, bias.name] @@ -642,9 +647,9 @@ class TosaTestGen: return None attr = ts.TosaSerializerAttribute() - attr.ConvAttribute(padding, strides, dilations) + attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1]) - self.ser.addOperator(op["op"], input_list, output_list, attr, qinfo) + self.ser.addOperator(op["op"], input_list, output_list, attr) return result_tens def build_conv3d( @@ -670,11 +675,10 @@ class TosaTestGen: DType.INT8, DType.UINT8, ): - qinfo = ts.TosaSerializerQuantInfo() - qinfo.ConvQuantInfo( - TosaQuantGen.getQinfo(self, ifm.dtype), - TosaQuantGen.getQinfo(self, result_tens.dtype), - ) + qinfo = [ + TosaQuantGen.getZeroPoint(self, ifm.dtype), + TosaQuantGen.getZeroPoint(self, result_tens.dtype), + ] # Invalidate Input/Output list for error_if checks. input_list = [ifm.name, filter.name, bias.name] @@ -706,9 +710,9 @@ class TosaTestGen: return None attr = ts.TosaSerializerAttribute() - attr.ConvAttribute(padding, strides, dilations) + attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1]) - self.ser.addOperator(op["op"], input_list, output_list, attr, qinfo) + self.ser.addOperator(op["op"], input_list, output_list, attr) return result_tens def build_transpose_conv2d( @@ -734,11 +738,10 @@ class TosaTestGen: DType.INT8, DType.UINT8, ): - qinfo = ts.TosaSerializerQuantInfo() - qinfo.ConvQuantInfo( - TosaQuantGen.getQinfo(self, ifm.dtype), - TosaQuantGen.getQinfo(self, result_tens.dtype), - ) + qinfo = [ + TosaQuantGen.getZeroPoint(self, ifm.dtype), + TosaQuantGen.getZeroPoint(self, result_tens.dtype), + ] # Invalidate Input/Output list for error_if checks. input_list = [ifm.name, filter.name, bias.name] @@ -769,9 +772,9 @@ class TosaTestGen: return None attr = ts.TosaSerializerAttribute() - attr.TransposeConvAttribute(out_pad, stride, output_shape) + attr.TransposeConvAttribute(out_pad, stride, output_shape, qinfo[0], qinfo[1]) - self.ser.addOperator(op["op"], input_list, output_list, attr, qinfo) + self.ser.addOperator(op["op"], input_list, output_list, attr) return result_tens def build_depthwise_conv2d( @@ -796,11 +799,10 @@ class TosaTestGen: DType.INT8, DType.UINT8, ): - qinfo = ts.TosaSerializerQuantInfo() - qinfo.ConvQuantInfo( - TosaQuantGen.getQinfo(self, ifm.dtype), - TosaQuantGen.getQinfo(self, result_tens.dtype), - ) + qinfo = [ + TosaQuantGen.getZeroPoint(self, ifm.dtype), + TosaQuantGen.getZeroPoint(self, result_tens.dtype), + ] # Invalidate Input/Output list for error_if checks. input_list = [ifm.name, filter.name, bias.name] @@ -832,9 +834,9 @@ class TosaTestGen: return None attr = ts.TosaSerializerAttribute() - attr.ConvAttribute(padding, strides, dilations) + attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1]) - self.ser.addOperator(op["op"], input_list, output_list, attr, qinfo) + self.ser.addOperator(op["op"], input_list, output_list, attr) return result_tens def build_fully_connected( @@ -871,7 +873,10 @@ class TosaTestGen: ): return None - self.ser.addOperator(op["op"], input_list, output_list, None, qinfo) + attr = ts.TosaSerializerAttribute() + attr.FullyConnectedAttribute(qinfo[0], qinfo[1]) + + self.ser.addOperator(op["op"], input_list, output_list, attr) return result_tens def build_matmul(self, op, a, b, validator_fcns=None, error_name=None, qinfo=None): @@ -905,7 +910,10 @@ class TosaTestGen: ): return None - self.ser.addOperator(op["op"], input_list, output_list, None, qinfo) + attr = ts.TosaSerializerAttribute() + attr.MatMulAttribute(qinfo[0], qinfo[1]) + + self.ser.addOperator(op["op"], input_list, output_list, attr) return result_tens def build_reduce(self, op, a, axis, validator_fcns, error_name=None): @@ -1164,7 +1172,7 @@ class TosaTestGen: ): return None - self.ser.addOperator(op["op"], input_list, output_list, attr, qinfo) + self.ser.addOperator(op["op"], input_list, output_list, attr) return result_tens def build_reshape(self, op, a, newShape, validator_fcns=None, error_name=None): @@ -2212,7 +2220,7 @@ class TosaTestGen: else: qinfo = None - tens = tvgen_fcn(self, op, dtypeList, shapeList, testArgs, qinfo, error_name) + tens = tvgen_fcn(self, op, dtypeList, shapeList, testArgs, error_name) try: if error_if_validators is None: @@ -3425,7 +3433,6 @@ class TosaTestGen: TosaTensorValuesGen.tvgDefault, TosaArgGen.agPad, ), - "qgen": TosaQuantGen.qgPad, "types": TYPE_FIB, "error_if_validators": ( TosaErrorValidator.evWrongInputType, -- cgit v1.2.1