aboutsummaryrefslogtreecommitdiff
path: root/python/serializer/tosa_serializer.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/serializer/tosa_serializer.py')
-rw-r--r--python/serializer/tosa_serializer.py331
1 files changed, 329 insertions, 2 deletions
diff --git a/python/serializer/tosa_serializer.py b/python/serializer/tosa_serializer.py
index fec676e..cd86777 100644
--- a/python/serializer/tosa_serializer.py
+++ b/python/serializer/tosa_serializer.py
@@ -582,7 +582,7 @@ class TensorDir(IntEnum):
class TosaSerializer:
def __init__(self, pathPrefix):
-
+ self.add_compat_methods()
# Get the global TOSA version if not already defined
self.builder = flatbuffers.Builder(0)
@@ -749,7 +749,10 @@ class TosaSerializer:
start_fcn(builder, len(fb_strs))
for s in fb_strs[::-1]:
builder.PrependUOffsetTRelative(s)
- return builder.EndVector()
+ try:
+ return builder.EndVector()
+ except TypeError:
+ return builder.EndVector(len(vec))
@staticmethod
def serializeUint8Vec(builder, vec):
@@ -801,3 +804,327 @@ class TosaSerializer:
return val
else:
return [val]
+
+ # Remove when switching to flatbuffers 2.0
+ # contains a mapping of the deprecated 1.12 method to the 2.0 version
+
+ def add_compat_methods(self):
+
+ from tosa import ArithmeticRightShiftAttribute
+
+ if not hasattr(ArithmeticRightShiftAttribute, "Start"):
+ ArithmeticRightShiftAttribute.Start = (
+ ArithmeticRightShiftAttribute.ArithmeticRightShiftAttributeStart
+ )
+ ArithmeticRightShiftAttribute.AddRound = (
+ ArithmeticRightShiftAttribute.ArithmeticRightShiftAttributeAddRound
+ )
+ ArithmeticRightShiftAttribute.End = (
+ ArithmeticRightShiftAttribute.ArithmeticRightShiftAttributeEnd
+ )
+ from tosa import AxisAttribute
+
+ if not hasattr(AxisAttribute, "Start"):
+ AxisAttribute.Start = AxisAttribute.AxisAttributeStart
+ AxisAttribute.AddAxis = AxisAttribute.AxisAttributeAddAxis
+ AxisAttribute.End = AxisAttribute.AxisAttributeEnd
+ from tosa import ClampAttribute
+
+ if not hasattr(ClampAttribute, "Start"):
+ ClampAttribute.Start = ClampAttribute.ClampAttributeStart
+ ClampAttribute.AddMinInt = ClampAttribute.ClampAttributeAddMinInt
+ ClampAttribute.AddMaxInt = ClampAttribute.ClampAttributeAddMaxInt
+ ClampAttribute.AddMinFp = ClampAttribute.ClampAttributeAddMinFp
+ ClampAttribute.AddMaxFp = ClampAttribute.ClampAttributeAddMaxFp
+ ClampAttribute.End = ClampAttribute.ClampAttributeEnd
+ from tosa import CondIfAttribute
+
+ if not hasattr(CondIfAttribute, "Start"):
+ CondIfAttribute.Start = CondIfAttribute.CondIfAttributeStart
+ CondIfAttribute.AddThenBranch = CondIfAttribute.CondIfAttributeAddThenBranch
+ CondIfAttribute.AddElseBranch = CondIfAttribute.CondIfAttributeAddElseBranch
+ CondIfAttribute.End = CondIfAttribute.CondIfAttributeEnd
+ from tosa import ConvAttribute
+
+ if not hasattr(ConvAttribute, "Start"):
+ ConvAttribute.Start = ConvAttribute.ConvAttributeStart
+ ConvAttribute.AddPad = ConvAttribute.ConvAttributeAddPad
+ ConvAttribute.StartPadVector = ConvAttribute.ConvAttributeStartPadVector
+ ConvAttribute.AddStride = ConvAttribute.ConvAttributeAddStride
+ ConvAttribute.StartStrideVector = (
+ ConvAttribute.ConvAttributeStartStrideVector
+ )
+ ConvAttribute.AddDilation = ConvAttribute.ConvAttributeAddDilation
+ ConvAttribute.StartDilationVector = (
+ ConvAttribute.ConvAttributeStartDilationVector
+ )
+ ConvAttribute.End = ConvAttribute.ConvAttributeEnd
+ from tosa import ConvQuantInfo
+
+ if not hasattr(ConvQuantInfo, "Start"):
+ ConvQuantInfo.Start = ConvQuantInfo.ConvQuantInfoStart
+ ConvQuantInfo.AddInputZp = ConvQuantInfo.ConvQuantInfoAddInputZp
+ ConvQuantInfo.AddWeightZp = ConvQuantInfo.ConvQuantInfoAddWeightZp
+ ConvQuantInfo.End = ConvQuantInfo.ConvQuantInfoEnd
+ from tosa import MatMulQuantInfo
+
+ if not hasattr(MatMulQuantInfo, "Start"):
+ MatMulQuantInfo.Start = MatMulQuantInfo.MatMulQuantInfoStart
+ MatMulQuantInfo.AddAZp = MatMulQuantInfo.MatMulQuantInfoAddAZp
+ MatMulQuantInfo.AddBZp = MatMulQuantInfo.MatMulQuantInfoAddBZp
+ MatMulQuantInfo.End = MatMulQuantInfo.MatMulQuantInfoEnd
+ from tosa import MulAttribute
+
+ if not hasattr(MulAttribute, "Start"):
+ MulAttribute.Start = MulAttribute.MulAttributeStart
+ MulAttribute.AddShift = MulAttribute.MulAttributeAddShift
+ MulAttribute.End = MulAttribute.MulAttributeEnd
+ from tosa import PadAttribute
+
+ if not hasattr(PadAttribute, "Start"):
+ PadAttribute.Start = PadAttribute.PadAttributeStart
+ PadAttribute.AddPadding = PadAttribute.PadAttributeAddPadding
+ PadAttribute.StartPaddingVector = (
+ PadAttribute.PadAttributeStartPaddingVector
+ )
+ PadAttribute.AddPadConstInt = PadAttribute.PadAttributeAddPadConstInt
+ PadAttribute.AddPadConstFp = PadAttribute.PadAttributeAddPadConstFp
+ PadAttribute.End = PadAttribute.PadAttributeEnd
+ from tosa import PadQuantInfo
+
+ if not hasattr(PadQuantInfo, "Start"):
+ PadQuantInfo.Start = PadQuantInfo.PadQuantInfoStart
+ PadQuantInfo.AddInputZp = PadQuantInfo.PadQuantInfoAddInputZp
+ PadQuantInfo.End = PadQuantInfo.PadQuantInfoEnd
+ from tosa import PoolAttribute
+
+ if not hasattr(PoolAttribute, "Start"):
+ PoolAttribute.Start = PoolAttribute.PoolAttributeStart
+ PoolAttribute.AddPad = PoolAttribute.PoolAttributeAddPad
+ PoolAttribute.StartPadVector = PoolAttribute.PoolAttributeStartPadVector
+ PoolAttribute.AddKernel = PoolAttribute.PoolAttributeAddKernel
+ PoolAttribute.StartKernelVector = (
+ PoolAttribute.PoolAttributeStartKernelVector
+ )
+ PoolAttribute.AddStride = PoolAttribute.PoolAttributeAddStride
+ PoolAttribute.StartStrideVector = (
+ PoolAttribute.PoolAttributeStartStrideVector
+ )
+ PoolAttribute.End = PoolAttribute.PoolAttributeEnd
+ from tosa import RescaleAttribute
+
+ if not hasattr(RescaleAttribute, "Start"):
+ RescaleAttribute.Start = RescaleAttribute.RescaleAttributeStart
+ RescaleAttribute.AddInputZp = RescaleAttribute.RescaleAttributeAddInputZp
+ RescaleAttribute.AddOutputZp = RescaleAttribute.RescaleAttributeAddOutputZp
+ RescaleAttribute.AddMultiplier = (
+ RescaleAttribute.RescaleAttributeAddMultiplier
+ )
+ RescaleAttribute.StartMultiplierVector = (
+ RescaleAttribute.RescaleAttributeStartMultiplierVector
+ )
+ RescaleAttribute.AddShift = RescaleAttribute.RescaleAttributeAddShift
+ RescaleAttribute.StartShiftVector = (
+ RescaleAttribute.RescaleAttributeStartShiftVector
+ )
+ RescaleAttribute.AddScale32 = RescaleAttribute.RescaleAttributeAddScale32
+ RescaleAttribute.AddDoubleRound = (
+ RescaleAttribute.RescaleAttributeAddDoubleRound
+ )
+ RescaleAttribute.AddPerChannel = (
+ RescaleAttribute.RescaleAttributeAddPerChannel
+ )
+ RescaleAttribute.End = RescaleAttribute.RescaleAttributeEnd
+ from tosa import ReshapeAttribute
+
+ if not hasattr(ReshapeAttribute, "Start"):
+ ReshapeAttribute.Start = ReshapeAttribute.ReshapeAttributeStart
+ ReshapeAttribute.AddNewShape = ReshapeAttribute.ReshapeAttributeAddNewShape
+ ReshapeAttribute.StartNewShapeVector = (
+ ReshapeAttribute.ReshapeAttributeStartNewShapeVector
+ )
+ ReshapeAttribute.End = ReshapeAttribute.ReshapeAttributeEnd
+ from tosa import ResizeAttribute
+
+ if not hasattr(ResizeAttribute, "Start"):
+ ResizeAttribute.Start = ResizeAttribute.ResizeAttributeStart
+ ResizeAttribute.AddOutputSize = ResizeAttribute.ResizeAttributeAddOutputSize
+ ResizeAttribute.StartOutputSizeVector = (
+ ResizeAttribute.ResizeAttributeStartOutputSizeVector
+ )
+ ResizeAttribute.AddStride = ResizeAttribute.ResizeAttributeAddStride
+ ResizeAttribute.StartStrideVector = (
+ ResizeAttribute.ResizeAttributeStartStrideVector
+ )
+ ResizeAttribute.AddOffset = ResizeAttribute.ResizeAttributeAddOffset
+ ResizeAttribute.StartOffsetVector = (
+ ResizeAttribute.ResizeAttributeStartOffsetVector
+ )
+ ResizeAttribute.AddShift = ResizeAttribute.ResizeAttributeAddShift
+ ResizeAttribute.AddStrideFp = ResizeAttribute.ResizeAttributeAddStrideFp
+ ResizeAttribute.StartStrideFpVector = (
+ ResizeAttribute.ResizeAttributeStartStrideFpVector
+ )
+ ResizeAttribute.AddOffsetFp = ResizeAttribute.ResizeAttributeAddOffsetFp
+ ResizeAttribute.StartOffsetFpVector = (
+ ResizeAttribute.ResizeAttributeStartOffsetFpVector
+ )
+ ResizeAttribute.AddMode = ResizeAttribute.ResizeAttributeAddMode
+ ResizeAttribute.End = ResizeAttribute.ResizeAttributeEnd
+ from tosa import SliceAttribute
+
+ if not hasattr(SliceAttribute, "Start"):
+ SliceAttribute.Start = SliceAttribute.SliceAttributeStart
+ SliceAttribute.AddStart = SliceAttribute.SliceAttributeAddStart
+ SliceAttribute.StartStartVector = (
+ SliceAttribute.SliceAttributeStartStartVector
+ )
+ SliceAttribute.AddSize = SliceAttribute.SliceAttributeAddSize
+ SliceAttribute.StartSizeVector = (
+ SliceAttribute.SliceAttributeStartSizeVector
+ )
+ SliceAttribute.End = SliceAttribute.SliceAttributeEnd
+ from tosa import TableAttribute
+
+ if not hasattr(TableAttribute, "Start"):
+ TableAttribute.Start = TableAttribute.TableAttributeStart
+ TableAttribute.AddTable = TableAttribute.TableAttributeAddTable
+ TableAttribute.StartTableVector = (
+ TableAttribute.TableAttributeStartTableVector
+ )
+ TableAttribute.End = TableAttribute.TableAttributeEnd
+ from tosa import TileAttribute
+
+ if not hasattr(TileAttribute, "Start"):
+ TileAttribute.Start = TileAttribute.TileAttributeStart
+ TileAttribute.AddMultiples = TileAttribute.TileAttributeAddMultiples
+ TileAttribute.StartMultiplesVector = (
+ TileAttribute.TileAttributeStartMultiplesVector
+ )
+ TileAttribute.End = TileAttribute.TileAttributeEnd
+ from tosa import TosaBasicBlock
+
+ if not hasattr(TosaBasicBlock, "Start"):
+ TosaBasicBlock.Start = TosaBasicBlock.TosaBasicBlockStart
+ TosaBasicBlock.AddName = TosaBasicBlock.TosaBasicBlockAddName
+ TosaBasicBlock.AddOperators = TosaBasicBlock.TosaBasicBlockAddOperators
+ TosaBasicBlock.StartOperatorsVector = (
+ TosaBasicBlock.TosaBasicBlockStartOperatorsVector
+ )
+ TosaBasicBlock.AddTensors = TosaBasicBlock.TosaBasicBlockAddTensors
+ TosaBasicBlock.StartTensorsVector = (
+ TosaBasicBlock.TosaBasicBlockStartTensorsVector
+ )
+ TosaBasicBlock.AddInputs = TosaBasicBlock.TosaBasicBlockAddInputs
+ TosaBasicBlock.StartInputsVector = (
+ TosaBasicBlock.TosaBasicBlockStartInputsVector
+ )
+ TosaBasicBlock.AddOutputs = TosaBasicBlock.TosaBasicBlockAddOutputs
+ TosaBasicBlock.StartOutputsVector = (
+ TosaBasicBlock.TosaBasicBlockStartOutputsVector
+ )
+ TosaBasicBlock.End = TosaBasicBlock.TosaBasicBlockEnd
+ from tosa import TosaGraph
+
+ if not hasattr(TosaGraph, "Start"):
+ TosaGraph.Start = TosaGraph.TosaGraphStart
+ TosaGraph.AddVersion = TosaGraph.TosaGraphAddVersion
+ TosaGraph.AddBlocks = TosaGraph.TosaGraphAddBlocks
+ TosaGraph.StartBlocksVector = TosaGraph.TosaGraphStartBlocksVector
+ TosaGraph.End = TosaGraph.TosaGraphEnd
+ from tosa import TosaOperator
+
+ if not hasattr(TosaOperator, "Start"):
+ TosaOperator.Start = TosaOperator.TosaOperatorStart
+ TosaOperator.AddOp = TosaOperator.TosaOperatorAddOp
+ TosaOperator.AddAttributeType = TosaOperator.TosaOperatorAddAttributeType
+ TosaOperator.AddAttribute = TosaOperator.TosaOperatorAddAttribute
+ TosaOperator.AddInputs = TosaOperator.TosaOperatorAddInputs
+ TosaOperator.StartInputsVector = TosaOperator.TosaOperatorStartInputsVector
+ TosaOperator.AddOutputs = TosaOperator.TosaOperatorAddOutputs
+ TosaOperator.StartOutputsVector = (
+ TosaOperator.TosaOperatorStartOutputsVector
+ )
+ TosaOperator.AddQuantInfoType = TosaOperator.TosaOperatorAddQuantInfoType
+ TosaOperator.AddQuantInfo = TosaOperator.TosaOperatorAddQuantInfo
+ TosaOperator.End = TosaOperator.TosaOperatorEnd
+ from tosa import TosaTensor
+
+ if not hasattr(TosaTensor, "Start"):
+ TosaTensor.Start = TosaTensor.TosaTensorStart
+ TosaTensor.AddName = TosaTensor.TosaTensorAddName
+ TosaTensor.AddShape = TosaTensor.TosaTensorAddShape
+ TosaTensor.StartShapeVector = TosaTensor.TosaTensorStartShapeVector
+ TosaTensor.AddType = TosaTensor.TosaTensorAddType
+ TosaTensor.AddData = TosaTensor.TosaTensorAddData
+ TosaTensor.StartDataVector = TosaTensor.TosaTensorStartDataVector
+ TosaTensor.End = TosaTensor.TosaTensorEnd
+ from tosa import TransposeAttribute
+
+ if not hasattr(TransposeAttribute, "Start"):
+ TransposeAttribute.Start = TransposeAttribute.TransposeAttributeStart
+ TransposeAttribute.AddPerms = TransposeAttribute.TransposeAttributeAddPerms
+ TransposeAttribute.StartPermsVector = (
+ TransposeAttribute.TransposeAttributeStartPermsVector
+ )
+ TransposeAttribute.End = TransposeAttribute.TransposeAttributeEnd
+ from tosa import TransposeConvAttribute
+
+ if not hasattr(TransposeConvAttribute, "Start"):
+ TransposeConvAttribute.Start = (
+ TransposeConvAttribute.TransposeConvAttributeStart
+ )
+ TransposeConvAttribute.AddOutpad = (
+ TransposeConvAttribute.TransposeConvAttributeAddOutpad
+ )
+ TransposeConvAttribute.StartOutpadVector = (
+ TransposeConvAttribute.TransposeConvAttributeStartOutpadVector
+ )
+ TransposeConvAttribute.AddStride = (
+ TransposeConvAttribute.TransposeConvAttributeAddStride
+ )
+ TransposeConvAttribute.StartStrideVector = (
+ TransposeConvAttribute.TransposeConvAttributeStartStrideVector
+ )
+ TransposeConvAttribute.AddDilation = (
+ TransposeConvAttribute.TransposeConvAttributeAddDilation
+ )
+ TransposeConvAttribute.StartDilationVector = (
+ TransposeConvAttribute.TransposeConvAttributeStartDilationVector
+ )
+ TransposeConvAttribute.AddOutputShape = (
+ TransposeConvAttribute.TransposeConvAttributeAddOutputShape
+ )
+ TransposeConvAttribute.StartOutputShapeVector = (
+ TransposeConvAttribute.TransposeConvAttributeStartOutputShapeVector
+ )
+ TransposeConvAttribute.End = (
+ TransposeConvAttribute.TransposeConvAttributeEnd
+ )
+ from tosa import UnaryQuantInfo
+
+ if not hasattr(UnaryQuantInfo, "Start"):
+ UnaryQuantInfo.Start = UnaryQuantInfo.UnaryQuantInfoStart
+ UnaryQuantInfo.AddInputZp = UnaryQuantInfo.UnaryQuantInfoAddInputZp
+ UnaryQuantInfo.AddOutputZp = UnaryQuantInfo.UnaryQuantInfoAddOutputZp
+ UnaryQuantInfo.End = UnaryQuantInfo.UnaryQuantInfoEnd
+ from tosa import Version
+
+ if not hasattr(Version, "Start"):
+ Version.Start = Version.VersionStart
+ Version.Add_major = Version.VersionAdd_major
+ Version.Add_minor = Version.VersionAdd_minor
+ Version.Add_patch = Version.VersionAdd_patch
+ Version.Add_draft = Version.VersionAdd_draft
+ Version.End = Version.VersionEnd
+ from tosa import WhileLoopAttribute
+
+ if not hasattr(WhileLoopAttribute, "Start"):
+ WhileLoopAttribute.Start = WhileLoopAttribute.WhileLoopAttributeStart
+ WhileLoopAttribute.AddCondBranch = (
+ WhileLoopAttribute.WhileLoopAttributeAddCondBranch
+ )
+ WhileLoopAttribute.AddBodyBranch = (
+ WhileLoopAttribute.WhileLoopAttributeAddBodyBranch
+ )
+ WhileLoopAttribute.End = WhileLoopAttribute.WhileLoopAttributeEnd