From bdcc3fee1b8bf55aac50e060115b92a1ccf9741c Mon Sep 17 00:00:00 2001 From: Eric Kunze Date: Tue, 7 Jun 2022 05:17:37 +0000 Subject: Remove quantinfo types Any needed information has been moved into the attributes for each operator. This aligns with the structure of the attributes in the TOSA specification, and generally simplifies the code. Change-Id: I8243e91b09de1a9115f8af09c5e7def7e8f2866b Signed-off-by: Eric Kunze --- python/serializer/tosa_serializer.py | 192 +++++++++++++++++++++-------------- 1 file changed, 115 insertions(+), 77 deletions(-) (limited to 'python/serializer/tosa_serializer.py') diff --git a/python/serializer/tosa_serializer.py b/python/serializer/tosa_serializer.py index 4d7d7bf..10372e4 100644 --- a/python/serializer/tosa_serializer.py +++ b/python/serializer/tosa_serializer.py @@ -30,7 +30,7 @@ import tosa.Op as TosaOp # Keep version number in sync with the version default value with schema/tosa.fbs TOSA_VERSION_MAJOR = 0 -TOSA_VERSION_MINOR = 25 +TOSA_VERSION_MINOR = 30 TOSA_VERSION_PATCH = 0 TOSA_VERSION_DRAFT = True TOSA_VERSION = [ @@ -141,7 +141,7 @@ class TosaSerializerAttribute(TosaSerializerUnion): def __init__(self): super().__init__() - def PoolAttribute(self, kernel, stride, pad): + def PoolAttribute(self, kernel, stride, pad, input_zp, output_zp): from tosa import PoolAttribute as a, Attribute self.utype = Attribute.Attribute().PoolAttribute @@ -150,8 +150,10 @@ class TosaSerializerAttribute(TosaSerializerUnion): self.intvecs.append((a.AddPad, pad)) self.intvecs.append((a.AddKernel, kernel)) self.intvecs.append((a.AddStride, stride)) + self.ints.append((a.AddInputZp, input_zp)) + self.ints.append((a.AddOutputZp, output_zp)) - def ConvAttribute(self, pad, stride, dilation): + def ConvAttribute(self, pad, stride, dilation, input_zp, weight_zp): from tosa import ConvAttribute as a, Attribute self.utype = Attribute.Attribute().ConvAttribute @@ -160,8 +162,10 @@ class TosaSerializerAttribute(TosaSerializerUnion): self.intvecs.append((a.AddPad, pad)) self.intvecs.append((a.AddStride, stride)) self.intvecs.append((a.AddDilation, dilation)) + self.ints.append((a.AddInputZp, input_zp)) + self.ints.append((a.AddWeightZp, weight_zp)) - def TransposeConvAttribute(self, outpad, stride, output_shape): + def TransposeConvAttribute(self, outpad, stride, output_shape, input_zp, weight_zp): from tosa import TransposeConvAttribute as a, Attribute self.utype = Attribute.Attribute().TransposeConvAttribute @@ -170,6 +174,8 @@ class TosaSerializerAttribute(TosaSerializerUnion): self.intvecs.append((a.AddOutPad, outpad)) self.intvecs.append((a.AddStride, stride)) self.intvecs.append((a.AddOutputShape, output_shape)) + self.ints.append((a.AddInputZp, input_zp)) + self.ints.append((a.AddWeightZp, weight_zp)) def PadAttribute(self, padding, pad_const_int, pad_const_fp): from tosa import PadAttribute as a, Attribute @@ -311,43 +317,32 @@ class TosaSerializerAttribute(TosaSerializerUnion): self.intvecs.append((a.AddTable, table)) + def MatMulAttribute(self, A_zp, B_zp): + from tosa import MatMulAttribute as a, Attribute -class TosaSerializerQuantInfo(TosaSerializerUnion): - """This class handles encapsulating all of the enumerated types for quantinfo""" - - def __init__(self): - super().__init__() - - def ConvQuantInfo(self, input_zp, weight_zp): - from tosa import ConvQuantInfo as q, QuantInfo + self.utype = Attribute.Attribute().MatMulAttribute + self.optFcns = (a.Start, a.End) - self.utype = QuantInfo.QuantInfo().ConvQuantInfo - self.optFcns = (q.Start, q.End) - self.ints.append((q.AddInputZp, input_zp)) - self.ints.append((q.AddWeightZp, weight_zp)) + self.ints.append((a.AddAZp, A_zp)) + self.ints.append((a.AddBZp, B_zp)) - def UnaryQuantInfo(self, input_zp, output_zp): - from tosa import UnaryQuantInfo as q, QuantInfo + def FullyConnectedAttribute(self, input_zp, weight_zp): + from tosa import FullyConnectedAttribute as a, Attribute - self.utype = QuantInfo.QuantInfo().UnaryQuantInfo - self.optFcns = (q.Start, q.End) - self.ints.append((q.AddInputZp, input_zp)) - self.ints.append((q.AddOutputZp, output_zp)) + self.utype = Attribute.Attribute().FullyConnectedAttribute + self.optFcns = (a.Start, a.End) - def MatMulQuantInfo(self, a_zp, b_zp): - from tosa import MatMulQuantInfo as q, QuantInfo + self.ints.append((a.AddInputZp, input_zp)) + self.ints.append((a.AddWeightZp, weight_zp)) - self.utype = QuantInfo.QuantInfo().MatMulQuantInfo - self.optFcns = (q.Start, q.End) - self.ints.append((q.AddAZp, a_zp)) - self.ints.append((q.AddBZp, b_zp)) + def NegateAttribute(self, input1_zp, output_zp): + from tosa import NegateAttribute as a, Attribute - def PadQuantInfo(self, input_zp): - from tosa import PadQuantInfo as q, QuantInfo + self.utype = Attribute.Attribute().NegateAttribute + self.optFcns = (a.Start, a.End) - self.utype = QuantInfo.QuantInfo().PadQuantInfo - self.optFcns = (q.Start, q.End) - self.ints.append((q.AddInputZp, input_zp)) + self.ints.append((a.AddInput1Zp, input1_zp)) + self.ints.append((a.AddOutputZp, output_zp)) class TosaSerializerTensor: @@ -467,12 +462,11 @@ class TosaSerializerTensor: class TosaSerializerOperator: - def __init__(self, op, inputs, outputs, attributes=None, quantInfo=None): + def __init__(self, op, inputs, outputs, attributes=None): self.op = op self.attributes = attributes self.inputs = TosaSerializer.toList(inputs) self.outputs = TosaSerializer.toList(outputs) - self.quantInfo = quantInfo def __str__(self): str = "Op {}\n----\n".format(self.op) @@ -491,13 +485,10 @@ class TosaSerializerOperator: fb_outputs = TosaSerializer.serializeStrVec( builder, self.outputs, TosaOperator.StartOutputsVector ) - # Need to serialize quant_info and attributes enums still + # Need to serialize attributes enums still if self.attributes is not None: fb_attributes = self.attributes.serialize(builder) - if self.quantInfo is not None: - fb_qinfo = self.quantInfo.serialize(builder) - TosaOperator.Start(builder) TosaOperator.AddOp(builder, self.op) TosaOperator.AddInputs(builder, fb_inputs) @@ -505,9 +496,6 @@ class TosaSerializerOperator: if self.attributes is not None: TosaOperator.AddAttributeType(builder, self.attributes.utype) TosaOperator.AddAttribute(builder, fb_attributes) - if self.quantInfo is not None: - TosaOperator.AddQuantInfoType(builder, self.quantInfo.utype) - TosaOperator.AddQuantInfo(builder, fb_qinfo) return TosaOperator.End(builder) @@ -544,10 +532,8 @@ class TosaSerializerBasicBlock: def addOutput(self, name): self.outputs.append(name) - def addOperator(self, op, inputs, outputs, attributes=None, quant_info=None): - self.operators.append( - TosaSerializerOperator(op, inputs, outputs, attributes, quant_info) - ) + def addOperator(self, op, inputs, outputs, attributes=None): + self.operators.append(TosaSerializerOperator(op, inputs, outputs, attributes)) def serialize(self, builder): fb_name = builder.CreateString(self.name) @@ -671,13 +657,16 @@ class TosaSerializer: self.currBasicBlock.addOutput(name) return tens - def addOperator(self, op, inputs, outputs, attributes=None, quant_info=None): + def addOperator(self, op, inputs, outputs, attributes=None): if op == TosaOp.Op().CONST: raise Exception("Use addConstTensor() to add CONST ops") return self.currBasicBlock.addOperator( - op, inputs, outputs, attributes, quant_info + op, + inputs, + outputs, + attributes, ) def setExpectedReturnCode(self, val, fail, desc=""): @@ -861,21 +850,48 @@ class TosaSerializer: ConvAttribute.StartDilationVector = ( ConvAttribute.ConvAttributeStartDilationVector ) + ConvAttribute.AddInputZp = ConvAttribute.ConvAttributeAddInputZp + ConvAttribute.AddWeightZp = ConvAttribute.ConvAttributeAddWeightZp ConvAttribute.End = ConvAttribute.ConvAttributeEnd - from tosa import ConvQuantInfo - - if not hasattr(ConvQuantInfo, "Start"): - ConvQuantInfo.Start = ConvQuantInfo.ConvQuantInfoStart - ConvQuantInfo.AddInputZp = ConvQuantInfo.ConvQuantInfoAddInputZp - ConvQuantInfo.AddWeightZp = ConvQuantInfo.ConvQuantInfoAddWeightZp - ConvQuantInfo.End = ConvQuantInfo.ConvQuantInfoEnd - from tosa import MatMulQuantInfo - - if not hasattr(MatMulQuantInfo, "Start"): - MatMulQuantInfo.Start = MatMulQuantInfo.MatMulQuantInfoStart - MatMulQuantInfo.AddAZp = MatMulQuantInfo.MatMulQuantInfoAddAZp - MatMulQuantInfo.AddBZp = MatMulQuantInfo.MatMulQuantInfoAddBZp - MatMulQuantInfo.End = MatMulQuantInfo.MatMulQuantInfoEnd + from tosa import FullyConnectedAttribute + + if not hasattr(FullyConnectedAttribute, "Start"): + FullyConnectedAttribute.Start = ( + FullyConnectedAttribute.FullyConnectedAttributeStart + ) + FullyConnectedAttribute.AddInputZp = ( + FullyConnectedAttribute.FullyConnectedAttributeAddInputZp + ) + FullyConnectedAttribute.AddWeightZp = ( + FullyConnectedAttribute.FullyConnectedAttributeAddWeightZp + ) + FullyConnectedAttribute.End = ( + FullyConnectedAttribute.FullyConnectedAttributeEnd + ) + from tosa import MatMulAttribute + + if not hasattr(MatMulAttribute, "Start"): + MatMulAttribute.Start = MatMulAttribute.MatMulAttributeStart + MatMulAttribute.AddAZp = MatMulAttribute.MatMulAttributeAddAZp + MatMulAttribute.AddBZp = MatMulAttribute.MatMulAttributeAddBZp + MatMulAttribute.End = MatMulAttribute.MatMulAttributeEnd + from tosa import PoolAttribute + + if not hasattr(PoolAttribute, "Start"): + PoolAttribute.Start = PoolAttribute.PoolAttributeStart + PoolAttribute.AddPad = PoolAttribute.PoolAttributeAddPad + PoolAttribute.StartPadVector = PoolAttribute.PoolAttributeStartPadVector + PoolAttribute.AddKernel = PoolAttribute.PoolAttributeAddKernel + PoolAttribute.StartKernelVector = ( + PoolAttribute.PoolAttributeStartKernelVector + ) + PoolAttribute.AddStride = PoolAttribute.PoolAttributeAddStride + PoolAttribute.StartStrideVector = ( + PoolAttribute.PoolAttributeStartStrideVector + ) + PoolAttribute.AddInputZp = PoolAttribute.PoolAttributeAddInputZp + PoolAttribute.AddOutputZp = PoolAttribute.PoolAttributeAddOutputZp + PoolAttribute.End = PoolAttribute.PoolAttributeEnd from tosa import MulAttribute if not hasattr(MulAttribute, "Start"): @@ -893,12 +909,6 @@ class TosaSerializer: PadAttribute.AddPadConstInt = PadAttribute.PadAttributeAddPadConstInt PadAttribute.AddPadConstFp = PadAttribute.PadAttributeAddPadConstFp PadAttribute.End = PadAttribute.PadAttributeEnd - from tosa import PadQuantInfo - - if not hasattr(PadQuantInfo, "Start"): - PadQuantInfo.Start = PadQuantInfo.PadQuantInfoStart - PadQuantInfo.AddInputZp = PadQuantInfo.PadQuantInfoAddInputZp - PadQuantInfo.End = PadQuantInfo.PadQuantInfoEnd from tosa import PoolAttribute if not hasattr(PoolAttribute, "Start"): @@ -913,6 +923,8 @@ class TosaSerializer: PoolAttribute.StartStrideVector = ( PoolAttribute.PoolAttributeStartStrideVector ) + PoolAttribute.AddInputZp = PoolAttribute.PoolAttributeAddInputZp + PoolAttribute.AddOutputZp = PoolAttribute.PoolAttributeAddOutputZp PoolAttribute.End = PoolAttribute.PoolAttributeEnd from tosa import RescaleAttribute @@ -1048,8 +1060,6 @@ class TosaSerializer: TosaOperator.StartOutputsVector = ( TosaOperator.TosaOperatorStartOutputsVector ) - TosaOperator.AddQuantInfoType = TosaOperator.TosaOperatorAddQuantInfoType - TosaOperator.AddQuantInfo = TosaOperator.TosaOperatorAddQuantInfo TosaOperator.End = TosaOperator.TosaOperatorEnd from tosa import TosaTensor @@ -1095,16 +1105,15 @@ class TosaSerializer: TransposeConvAttribute.StartOutputShapeVector = ( TransposeConvAttribute.TransposeConvAttributeStartOutputShapeVector ) + TransposeConvAttribute.AddInputZp = ( + TransposeConvAttribute.TransposeConvAttributeAddInputZp + ) + TransposeConvAttribute.AddWeightZp = ( + TransposeConvAttribute.TransposeConvAttributeAddWeightZp + ) TransposeConvAttribute.End = ( TransposeConvAttribute.TransposeConvAttributeEnd ) - from tosa import UnaryQuantInfo - - if not hasattr(UnaryQuantInfo, "Start"): - UnaryQuantInfo.Start = UnaryQuantInfo.UnaryQuantInfoStart - UnaryQuantInfo.AddInputZp = UnaryQuantInfo.UnaryQuantInfoAddInputZp - UnaryQuantInfo.AddOutputZp = UnaryQuantInfo.UnaryQuantInfoAddOutputZp - UnaryQuantInfo.End = UnaryQuantInfo.UnaryQuantInfoEnd from tosa import Version if not hasattr(Version, "Start"): @@ -1114,6 +1123,35 @@ class TosaSerializer: Version.Add_patch = Version.VersionAdd_patch Version.Add_draft = Version.VersionAdd_draft Version.End = Version.VersionEnd + from tosa import MatMulAttribute + + if not hasattr(MatMulAttribute, "Start"): + MatMulAttribute.Start = MatMulAttribute.MatMulAttributeStart + MatMulAttribute.AddAZp = MatMulAttribute.MatMulAttributeAddAZp + MatMulAttribute.AddBZp = MatMulAttribute.MatMulAttributeAddBZp + MatMulAttribute.End = MatMulAttribute.MatMulAttributeEnd + from tosa import FullyConnectedAttribute + + if not hasattr(FullyConnectedAttribute, "Start"): + FullyConnectedAttribute.Start = ( + FullyConnectedAttribute.FullyConnectedAttributeStart + ) + FullyConnectedAttribute.AddInputZp = ( + FullyConnectedAttribute.FullyConnectedAttributeAddInputZp + ) + FullyConnectedAttribute.AddWeightZp = ( + FullyConnectedAttribute.FullyConnectedAttributeAddWeightZp + ) + FullyConnectedAttribute.End = ( + FullyConnectedAttribute.FullyConnectedAttributeEnd + ) + from tosa import NegateAttribute + + if not hasattr(NegateAttribute, "Start"): + NegateAttribute.Start = NegateAttribute.NegateAttributeStart + NegateAttribute.AddInput1Zp = NegateAttribute.NegateAttributeAddInput1Zp + NegateAttribute.AddOutputZp = NegateAttribute.NegateAttributeAddOutputZp + NegateAttribute.End = NegateAttribute.NegateAttributeEnd from tosa import WhileLoopAttribute if not hasattr(WhileLoopAttribute, "Start"): -- cgit v1.2.1