From b5fabec33abeca2d92c20c7b094fa3f113d0ddd8 Mon Sep 17 00:00:00 2001 From: Eric Kunze Date: Tue, 7 Jun 2022 05:20:44 +0000 Subject: Remove quantization info from serialization attributes Any needed information moves into the attributes for each operator. New serialization library version removes teh quantization information attributes from the schema Signed-off-by: Eric Kunze Change-Id: Icf6165687ab1fd34a01f64c01b0b92b2820e72fa --- verif/generator/tosa_arg_gen.py | 119 ++++++++++++++++++---------------------- 1 file changed, 53 insertions(+), 66 deletions(-) (limited to 'verif/generator/tosa_arg_gen.py') diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py index a27d849..8e00fab 100644 --- a/verif/generator/tosa_arg_gen.py +++ b/verif/generator/tosa_arg_gen.py @@ -4,7 +4,6 @@ import itertools import math import numpy as np -import serializer.tosa_serializer as ts from generator.tosa_error_if import ErrorIf from generator.tosa_error_if import TosaErrorIfArgGen from serializer.tosa_serializer import DTypeNames @@ -26,7 +25,7 @@ class TosaQuantGen: pass @staticmethod - def getQinfo(testGen, dtype, error_name=None): + def getZeroPoint(testGen, dtype, error_name=None): if dtype == DType.INT8: return testGen.randInt(-128, 128) @@ -45,27 +44,25 @@ class TosaQuantGen: @staticmethod def qgUnary(testGen, op, dtype, error_name=None): - qinfo = ts.TosaSerializerQuantInfo() if error_name == ErrorIf.InputZeroPointNotZero: - qinfo.UnaryQuantInfo( - TosaQuantGen.getQinfo(testGen, dtype, error_name), - TosaQuantGen.getQinfo(testGen, dtype), - ) + qinfo = [ + TosaQuantGen.getZeroPoint(testGen, dtype, error_name), + TosaQuantGen.getZeroPoint(testGen, dtype), + ] elif error_name == ErrorIf.OutputZeroPointNotZero: - qinfo.UnaryQuantInfo( - TosaQuantGen.getQinfo(testGen, dtype), - TosaQuantGen.getQinfo(testGen, dtype, error_name), - ) + qinfo = [ + TosaQuantGen.getZeroPoint(testGen, dtype), + TosaQuantGen.getZeroPoint(testGen, dtype, error_name), + ] else: - qinfo.UnaryQuantInfo( - TosaQuantGen.getQinfo(testGen, dtype), - TosaQuantGen.getQinfo(testGen, dtype), - ) + qinfo = [ + TosaQuantGen.getZeroPoint(testGen, dtype), + TosaQuantGen.getZeroPoint(testGen, dtype), + ] return qinfo @staticmethod def qgConv(testGen, op, dtype_or_dtypeList, error_name=None): - qinfo = ts.TosaSerializerQuantInfo() if isinstance(dtype_or_dtypeList, list): # a list of [input, weights, accumulator] dtypes dtypeList = dtype_or_dtypeList @@ -74,40 +71,34 @@ class TosaQuantGen: dtypeList = [dtype_or_dtypeList] * 3 if error_name == ErrorIf.InputZeroPointNotZero: - input_zp = TosaQuantGen.getQinfo(testGen, dtypeList[0], error_name) - weights_zp = TosaQuantGen.getQinfo(testGen, dtypeList[1]) + qinfo = [ + TosaQuantGen.getZeroPoint(testGen, dtypeList[0], error_name), + TosaQuantGen.getZeroPoint(testGen, dtypeList[1]), + ] elif error_name == ErrorIf.WeightZeroPointNotZero: - input_zp = TosaQuantGen.getQinfo(testGen, dtypeList[0]) - weights_zp = TosaQuantGen.getQinfo(testGen, dtypeList[1], error_name) + qinfo = [ + TosaQuantGen.getZeroPoint(testGen, dtypeList[0]), + TosaQuantGen.getZeroPoint(testGen, dtypeList[1], error_name), + ] else: - input_zp = TosaQuantGen.getQinfo(testGen, dtypeList[0]) - weights_zp = TosaQuantGen.getQinfo(testGen, dtypeList[1]) - - qinfo.ConvQuantInfo(input_zp, weights_zp) + qinfo = [ + TosaQuantGen.getZeroPoint(testGen, dtypeList[0]), + TosaQuantGen.getZeroPoint(testGen, dtypeList[1]), + ] return qinfo @staticmethod def qgMatmul(testGen, op, dtype, error_name=None): - qinfo = ts.TosaSerializerQuantInfo() if error_name == ErrorIf.InputZeroPointNotZero: - qinfo.MatMulQuantInfo( - TosaQuantGen.getQinfo(testGen, dtype, error_name), - TosaQuantGen.getQinfo(testGen, dtype, error_name), - ) + qinfo = [ + TosaQuantGen.getZeroPoint(testGen, dtype, error_name), + TosaQuantGen.getZeroPoint(testGen, dtype, error_name), + ] else: - qinfo.MatMulQuantInfo( - TosaQuantGen.getQinfo(testGen, dtype), - TosaQuantGen.getQinfo(testGen, dtype), - ) - return qinfo - - @staticmethod - def qgPad(testGen, op, dtype, error_name=None): - qinfo = ts.TosaSerializerQuantInfo() - if error_name == ErrorIf.InputZeroPointNotZero: - qinfo.PadQuantInfo(TosaQuantGen.getQinfo(testGen, dtype, error_name)) - else: - qinfo.PadQuantInfo(TosaQuantGen.getQinfo(testGen, dtype)) + qinfo = [ + TosaQuantGen.getZeroPoint(testGen, dtype), + TosaQuantGen.getZeroPoint(testGen, dtype), + ] return qinfo @staticmethod @@ -550,7 +541,7 @@ class TosaTensorValuesGen: pass @staticmethod - def tvgDefault(testGen, op, dtypeList, shapeList, testArgs, qinfo, error_name=None): + def tvgDefault(testGen, op, dtypeList, shapeList, testArgs, error_name=None): pCount, cCount = op["operands"] tens = [] @@ -562,7 +553,7 @@ class TosaTensorValuesGen: return tens @staticmethod - def tvgNegate(testGen, op, dtypeList, shapeList, testArgs, qinfo, error_name=None): + def tvgNegate(testGen, op, dtypeList, shapeList, testArgs, error_name=None): if dtypeList[0] == DType.INT32 and error_name is None: pCount, cCount = op["operands"] assert ( @@ -582,11 +573,11 @@ class TosaTensorValuesGen: return placeholders else: return TosaTensorValuesGen.tvgDefault( - testGen, op, dtypeList, shapeList, testArgs, qinfo, error_name + testGen, op, dtypeList, shapeList, testArgs, error_name ) @staticmethod - def tvgAddSub(testGen, op, dtypeList, shapeList, testArgs, qinfo, error_name=None): + def tvgAddSub(testGen, op, dtypeList, shapeList, testArgs, error_name=None): if dtypeList[0] == DType.INT32 and error_name is None: # Make sure the operation does not cause value saturation - where # the number wraps due to limited number of bits to store the answer @@ -651,12 +642,12 @@ class TosaTensorValuesGen: return placeholders else: return TosaTensorValuesGen.tvgDefault( - testGen, op, dtypeList, shapeList, testArgs, qinfo, error_name + testGen, op, dtypeList, shapeList, testArgs, error_name ) @staticmethod def tvgCondIfWhileLoop( - testGen, op, dtypeList, shapeList, testArgs, qinfo, error_name=None + testGen, op, dtypeList, shapeList, testArgs, error_name=None ): if dtypeList[0] in ( DType.INT32, @@ -689,12 +680,12 @@ class TosaTensorValuesGen: return placeholders else: return TosaTensorValuesGen.tvgDefault( - testGen, op, dtypeList, shapeList, testArgs, qinfo, error_name + testGen, op, dtypeList, shapeList, testArgs, error_name ) @staticmethod def tvgArithmeticRightShift( - testGen, op, dtypeList, shapeList, testArgs, qinfo, error_name=None + testGen, op, dtypeList, shapeList, testArgs, error_name=None ): pCount, cCount = op["operands"] # Force value of operand[1] to be within [0, num_bits] @@ -722,16 +713,16 @@ class TosaTensorValuesGen: return placeholders @staticmethod - def tvgSelect(testGen, op, dtypeList, shapeList, testArgs, qinfo, error_name=None): + def tvgSelect(testGen, op, dtypeList, shapeList, testArgs, error_name=None): # Set datatype of condition tensor to boolean dtypeList[0] = DType.BOOL return TosaTensorValuesGen.tvgDefault( - testGen, op, dtypeList, shapeList, testArgs, qinfo, error_name + testGen, op, dtypeList, shapeList, testArgs, error_name ) @staticmethod - def tvgIntDiv(testGen, op, dtypeList, shapeList, testArgs, qinfo, error_name=None): + def tvgIntDiv(testGen, op, dtypeList, shapeList, testArgs, error_name=None): if error_name is None: pCount, cCount = op["operands"] assert ( @@ -765,11 +756,11 @@ class TosaTensorValuesGen: return placeholders else: return TosaTensorValuesGen.tvgDefault( - testGen, op, dtypeList, shapeList, testArgs, qinfo, error_name + testGen, op, dtypeList, shapeList, testArgs, error_name ) @staticmethod - def tvgMul(testGen, op, dtypeList, shapeList, testArgs, qinfo, error_name=None): + def tvgMul(testGen, op, dtypeList, shapeList, testArgs, error_name=None): if error_name is None: pCount, cCount = op["operands"] assert ( @@ -839,11 +830,11 @@ class TosaTensorValuesGen: return tens else: return TosaTensorValuesGen.tvgDefault( - testGen, op, dtypeList, shapeList, testArgs, qinfo, error_name + testGen, op, dtypeList, shapeList, testArgs, error_name ) @staticmethod - def tvgConcat(testGen, op, dtypeList, shapeList, testArgs, qinfo, error_name=None): + def tvgConcat(testGen, op, dtypeList, shapeList, testArgs, error_name=None): count = len(shapeList) - testGen.args.num_const_inputs_concat if count < 1: count = 1 @@ -866,9 +857,7 @@ class TosaTensorValuesGen: return tens @staticmethod - def tvgLogicalShift( - testGen, op, dtypeList, shapeList, testArgs, qinfo, error_name=None - ): + def tvgLogicalShift(testGen, op, dtypeList, shapeList, testArgs, error_name=None): pCount, cCount = op["operands"] assert ( pCount == 2 and cCount == 0 @@ -886,7 +875,7 @@ class TosaTensorValuesGen: return placeholders @staticmethod - def tvgEqual(testGen, op, dtypeList, shapeList, testArgs, qinfo, error_name=None): + def tvgEqual(testGen, op, dtypeList, shapeList, testArgs, error_name=None): if error_name is None: pCount, cCount = op["operands"] assert ( @@ -924,13 +913,11 @@ class TosaTensorValuesGen: return placeholders else: return TosaTensorValuesGen.tvgDefault( - testGen, op, dtypeList, shapeList, testArgs, qinfo, error_name + testGen, op, dtypeList, shapeList, testArgs, error_name ) @staticmethod - def tvgReduceSum( - testGen, op, dtypeList, shapeList, testArgs, qinfo, error_name=None - ): + def tvgReduceSum(testGen, op, dtypeList, shapeList, testArgs, error_name=None): if dtypeList[0] == DType.INT32: pCount, cCount = op["operands"] assert ( @@ -949,7 +936,7 @@ class TosaTensorValuesGen: return placeholders else: return TosaTensorValuesGen.tvgDefault( - testGen, op, dtypeList, shapeList, testArgs, qinfo, error_name + testGen, op, dtypeList, shapeList, testArgs, error_name ) -- cgit v1.2.1