aboutsummaryrefslogtreecommitdiff
path: root/python/tosa_serializer.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/tosa_serializer.py')
-rw-r--r--python/tosa_serializer.py232
1 files changed, 113 insertions, 119 deletions
diff --git a/python/tosa_serializer.py b/python/tosa_serializer.py
index 6915c83..f294ba3 100644
--- a/python/tosa_serializer.py
+++ b/python/tosa_serializer.py
@@ -144,74 +144,74 @@ class TosaSerializerAttribute(TosaSerializerUnion):
self.utype = Attribute.Attribute().PoolAttribute
- self.optFcns = (a.PoolAttributeStart, a.PoolAttributeEnd)
- self.intvecs.append((a.PoolAttributeAddPadding, padding))
- self.intvecs.append((a.PoolAttributeAddKernel, kernel))
- self.intvecs.append((a.PoolAttributeAddStride, stride))
+ self.optFcns = (a.Start, a.End)
+ self.intvecs.append((a.AddPadding, padding))
+ self.intvecs.append((a.AddKernel, kernel))
+ self.intvecs.append((a.AddStride, stride))
def ConvAttribute(self, padding, stride, dilation):
from tosa import ConvAttribute as a, Attribute
self.utype = Attribute.Attribute().ConvAttribute
- self.optFcns = (a.ConvAttributeStart, a.ConvAttributeEnd)
+ self.optFcns = (a.Start, a.End)
- self.intvecs.append((a.ConvAttributeAddPadding, padding))
- self.intvecs.append((a.ConvAttributeAddStride, stride))
- self.intvecs.append((a.ConvAttributeAddDilation, dilation))
+ self.intvecs.append((a.AddPadding, padding))
+ self.intvecs.append((a.AddStride, stride))
+ self.intvecs.append((a.AddDilation, dilation))
def TransposeConvAttribute(self, outpad, stride, dilation, output_shape):
from tosa import TransposeConvAttribute as a, Attribute
self.utype = Attribute.Attribute().TransposeConvAttribute
- self.optFcns = (a.TransposeConvAttributeStart, a.TransposeConvAttributeEnd)
+ self.optFcns = (a.Start, a.End)
- self.intvecs.append((a.TransposeConvAttributeAddOutpad, outpad))
- self.intvecs.append((a.TransposeConvAttributeAddStride, stride))
- self.intvecs.append((a.TransposeConvAttributeAddDilation, dilation))
- self.intvecs.append((a.TransposeConvAttributeAddOutputShape, output_shape))
+ self.intvecs.append((a.AddOutpad, outpad))
+ self.intvecs.append((a.AddStride, stride))
+ self.intvecs.append((a.AddDilation, dilation))
+ self.intvecs.append((a.AddOutputShape, output_shape))
def PadAttribute(self, padding, pad_const_int, pad_const_fp):
from tosa import PadAttribute as a, Attribute
self.utype = Attribute.Attribute().PadAttribute
- self.optFcns = (a.PadAttributeStart, a.PadAttributeEnd)
+ self.optFcns = (a.Start, a.End)
- self.intvecs.append((a.PadAttributeAddPadding, padding))
- self.ints.append((a.PadAttributeAddPadConstInt, pad_const_int))
- self.floats.append((a.PadAttributeAddPadConstFp, pad_const_fp))
+ self.intvecs.append((a.AddPadding, padding))
+ self.ints.append((a.AddPadConstInt, pad_const_int))
+ self.floats.append((a.AddPadConstFp, pad_const_fp))
def AxisAttribute(self, axis):
from tosa import AxisAttribute as a, Attribute
self.utype = Attribute.Attribute().AxisAttribute
- self.optFcns = (a.AxisAttributeStart, a.AxisAttributeEnd)
+ self.optFcns = (a.Start, a.End)
- self.ints.append((a.AxisAttributeAddAxis, axis))
+ self.ints.append((a.AddAxis, axis))
def ReshapeAttribute(self, shape):
from tosa import ReshapeAttribute as a, Attribute
self.utype = Attribute.Attribute().ReshapeAttribute
- self.optFcns = (a.ReshapeAttributeStart, a.ReshapeAttributeEnd)
+ self.optFcns = (a.Start, a.End)
- self.intvecs.append((a.ReshapeAttributeAddShape, shape))
+ self.intvecs.append((a.AddShape, shape))
def SliceAttribute(self, begin, size):
from tosa import SliceAttribute as a, Attribute
self.utype = Attribute.Attribute().SliceAttribute
- self.optFcns = (a.SliceAttributeStart, a.SliceAttributeEnd)
+ self.optFcns = (a.Start, a.End)
- self.intvecs.append((a.SliceAttributeAddBegin, begin))
- self.intvecs.append((a.SliceAttributeAddSize, size))
+ self.intvecs.append((a.AddBegin, begin))
+ self.intvecs.append((a.AddSize, size))
def TileAttribute(self, multiples):
from tosa import TileAttribute as a, Attribute
self.utype = Attribute.Attribute().TileAttribute
- self.optFcns = (a.TileAttributeStart, a.TileAttributeEnd)
+ self.optFcns = (a.Start, a.End)
- self.intvecs.append((a.TileAttributeAddMultiples, multiples))
+ self.intvecs.append((a.AddMultiples, multiples))
def ResizeAttribute(
self, output_size, stride, offset, shift, stride_fp, offset_fp, mode
@@ -219,27 +219,27 @@ class TosaSerializerAttribute(TosaSerializerUnion):
from tosa import ResizeAttribute as a, Attribute
self.utype = Attribute.Attribute().ResizeAttribute
- self.optFcns = (a.ResizeAttributeStart, a.ResizeAttributeEnd)
+ self.optFcns = (a.Start, a.End)
- self.intvecs.append((a.ResizeAttributeAddOutputSize, output_size))
- self.intvecs.append((a.ResizeAttributeAddStride, stride))
- self.intvecs.append((a.ResizeAttributeAddOffset, offset))
- self.ints.append((a.ResizeAttributeAddShift, shift))
- self.fpvecs.append((a.ResizeAttributeAddStrideFp, stride_fp))
- self.fpvecs.append((a.ResizeAttributeAddOffsetFp, offset_fp))
- self.ints.append((a.ResizeAttributeAddMode, mode))
+ self.intvecs.append((a.AddOutputSize, output_size))
+ self.intvecs.append((a.AddStride, stride))
+ self.intvecs.append((a.AddOffset, offset))
+ self.ints.append((a.AddShift, shift))
+ self.fpvecs.append((a.AddStrideFp, stride_fp))
+ self.fpvecs.append((a.AddOffsetFp, offset_fp))
+ self.ints.append((a.AddMode, mode))
def ClampAttribute(self, minint, maxint, minfp, maxfp):
from tosa import ClampAttribute as a, Attribute
self.utype = Attribute.Attribute().ClampAttribute
- self.optFcns = (a.ClampAttributeStart, a.ClampAttributeEnd)
+ self.optFcns = (a.Start, a.End)
- self.ints.append((a.ClampAttributeAddMinInt, minint))
- self.ints.append((a.ClampAttributeAddMaxInt, maxint))
+ self.ints.append((a.AddMinInt, minint))
+ self.ints.append((a.AddMaxInt, maxint))
- self.ints.append((a.ClampAttributeAddMinFp, minfp))
- self.ints.append((a.ClampAttributeAddMaxFp, maxfp))
+ self.ints.append((a.AddMinFp, minfp))
+ self.ints.append((a.AddMaxFp, maxfp))
def RescaleAttribute(
self, input_zp, output_zp, multiplier, shift, scale32, double_round, per_channel
@@ -247,68 +247,68 @@ class TosaSerializerAttribute(TosaSerializerUnion):
from tosa import RescaleAttribute as a, Attribute
self.utype = Attribute.Attribute().RescaleAttribute
- self.optFcns = (a.RescaleAttributeStart, a.RescaleAttributeEnd)
+ self.optFcns = (a.Start, a.End)
- self.ints.append((a.RescaleAttributeAddInputZp, input_zp))
- self.ints.append((a.RescaleAttributeAddOutputZp, output_zp))
- self.intvecs.append((a.RescaleAttributeAddMultiplier, multiplier))
- self.intvecs.append((a.RescaleAttributeAddShift, shift))
- self.bools.append((a.RescaleAttributeAddScale32, scale32))
- self.bools.append((a.RescaleAttributeAddDoubleRound, double_round))
- self.bools.append((a.RescaleAttributeAddPerChannel, per_channel))
+ self.ints.append((a.AddInputZp, input_zp))
+ self.ints.append((a.AddOutputZp, output_zp))
+ self.intvecs.append((a.AddMultiplier, multiplier))
+ self.intvecs.append((a.AddShift, shift))
+ self.bools.append((a.AddScale32, scale32))
+ self.bools.append((a.AddDoubleRound, double_round))
+ self.bools.append((a.AddPerChannel, per_channel))
def MulAttribute(self, shift):
from tosa import MulAttribute as a, Attribute
self.utype = Attribute.Attribute().MulAttribute
- self.optFcns = (a.MulAttributeStart, a.MulAttributeEnd)
+ self.optFcns = (a.Start, a.End)
- self.ints.append((a.MulAttributeAddShift, shift))
+ self.ints.append((a.AddShift, shift))
def ArithmeticRightShiftAttribute(self, round):
from tosa import ArithmeticRightShiftAttribute as a, Attribute
self.utype = Attribute.Attribute().ArithmeticRightShiftAttribute
self.optFcns = (
- a.ArithmeticRightShiftAttributeStart,
- a.ArithmeticRightShiftAttributeEnd,
+ a.Start,
+ a.End,
)
- self.bools.append((a.ArithmeticRightShiftAttributeAddRound, round))
+ self.bools.append((a.AddRound, round))
def CondIfAttribute(self, then_branch, else_branch):
from tosa import CondIfAttribute as a, Attribute
self.utype = Attribute.Attribute().CondIfAttribute
- self.optFcns = (a.CondIfAttributeStart, a.CondIfAttributeEnd)
+ self.optFcns = (a.Start, a.End)
- self.strings.append((a.CondIfAttributeAddThenBranch, then_branch))
- self.strings.append((a.CondIfAttributeAddElseBranch, else_branch))
+ self.strings.append((a.AddThenBranch, then_branch))
+ self.strings.append((a.AddElseBranch, else_branch))
def WhileLoopAttribute(self, cond_branch, body_branch):
from tosa import WhileLoopAttribute as a, Attribute
self.utype = Attribute.Attribute().WhileLoopAttribute
- self.optFcns = (a.WhileLoopAttributeStart, a.WhileLoopAttributeEnd)
+ self.optFcns = (a.Start, a.End)
- self.strings.append((a.WhileLoopAttributeAddCondBranch, cond_branch))
- self.strings.append((a.WhileLoopAttributeAddBodyBranch, body_branch))
+ self.strings.append((a.AddCondBranch, cond_branch))
+ self.strings.append((a.AddBodyBranch, body_branch))
def TransposeAttribute(self, perm):
from tosa import TransposeAttribute as a, Attribute
self.utype = Attribute.Attribute().TransposeAttribute
- self.optFcns = (a.TransposeAttributeStart, a.TransposeAttributeEnd)
+ self.optFcns = (a.Start, a.End)
- self.intvecs.append((a.TransposeAttributeAddPerm, perm))
+ self.intvecs.append((a.AddPerm, perm))
def TableAttribute(self, table):
from tosa import TableAttribute as a, Attribute
self.utype = Attribute.Attribute().TableAttribute
- self.optFcns = (a.TableAttributeStart, a.TableAttributeEnd)
+ self.optFcns = (a.Start, a.End)
- self.intvecs.append((a.TableAttributeAddTable, table))
+ self.intvecs.append((a.AddTable, table))
class TosaSerializerQuantInfo(TosaSerializerUnion):
"""This class handles encapsulating all of the enumerated types for quantinfo types"""
@@ -320,32 +320,32 @@ class TosaSerializerQuantInfo(TosaSerializerUnion):
from tosa import ConvQuantInfo as q, QuantInfo
self.utype = QuantInfo.QuantInfo().ConvQuantInfo
- self.optFcns = (q.ConvQuantInfoStart, q.ConvQuantInfoEnd)
- self.ints.append((q.ConvQuantInfoAddInputZp, input_zp))
- self.ints.append((q.ConvQuantInfoAddWeightZp, weight_zp))
+ self.optFcns = (q.Start, q.End)
+ self.ints.append((q.AddInputZp, input_zp))
+ self.ints.append((q.AddWeightZp, weight_zp))
def UnaryQuantInfo(self, input_zp, output_zp):
from tosa import UnaryQuantInfo as q, QuantInfo
self.utype = QuantInfo.QuantInfo().UnaryQuantInfo
- self.optFcns = (q.UnaryQuantInfoStart, q.UnaryQuantInfoEnd)
- self.ints.append((q.UnaryQuantInfoAddInputZp, input_zp))
- self.ints.append((q.UnaryQuantInfoAddOutputZp, output_zp))
+ self.optFcns = (q.Start, q.End)
+ self.ints.append((q.AddInputZp, input_zp))
+ self.ints.append((q.AddOutputZp, output_zp))
def MatMulQuantInfo(self, a_zp, b_zp):
from tosa import MatMulQuantInfo as q, QuantInfo
self.utype = QuantInfo.QuantInfo().MatMulQuantInfo
- self.optFcns = (q.MatMulQuantInfoStart, q.MatMulQuantInfoEnd)
- self.ints.append((q.MatMulQuantInfoAddAZp, a_zp))
- self.ints.append((q.MatMulQuantInfoAddBZp, b_zp))
+ self.optFcns = (q.Start, q.End)
+ self.ints.append((q.AddAZp, a_zp))
+ self.ints.append((q.AddBZp, b_zp))
def PadQuantInfo(self, input_zp):
from tosa import PadQuantInfo as q, QuantInfo
self.utype = QuantInfo.QuantInfo().PadQuantInfo
- self.optFcns = (q.PadQuantInfoStart, q.PadQuantInfoEnd)
- self.ints.append((q.PadQuantInfoAddInputZp, input_zp))
+ self.optFcns = (q.Start, q.End)
+ self.ints.append((q.AddInputZp, input_zp))
class TosaSerializerTensor:
@@ -453,14 +453,14 @@ class TosaSerializerTensor:
)
fb_data = TosaSerializer.serializeUint8Vec(builder, u8_data)
- TosaTensor.TosaTensorStart(builder)
- TosaTensor.TosaTensorAddName(builder, fb_name)
- TosaTensor.TosaTensorAddShape(builder, fb_shapes)
- TosaTensor.TosaTensorAddType(builder, self.dtype)
+ TosaTensor.Start(builder)
+ TosaTensor.AddName(builder, fb_name)
+ TosaTensor.AddShape(builder, fb_shapes)
+ TosaTensor.AddType(builder, self.dtype)
if self.data:
- TosaTensor.TosaTensorAddData(builder, fb_data)
+ TosaTensor.AddData(builder, fb_data)
- return TosaTensor.TosaTensorEnd(builder)
+ return TosaTensor.End(builder)
class TosaSerializerOperator:
@@ -483,10 +483,10 @@ class TosaSerializerOperator:
def serialize(self, builder):
fb_inputs = TosaSerializer.serializeStrVec(
- builder, self.inputs, TosaOperator.TosaOperatorStartInputsVector
+ builder, self.inputs, TosaOperator.StartInputsVector
)
fb_outputs = TosaSerializer.serializeStrVec(
- builder, self.outputs, TosaOperator.TosaOperatorStartOutputsVector
+ builder, self.outputs, TosaOperator.StartOutputsVector
)
# Need to serialize quant_info and attributes enums still
if self.attributes is not None:
@@ -495,18 +495,18 @@ class TosaSerializerOperator:
if self.quantInfo is not None:
fb_qinfo = self.quantInfo.serialize(builder)
- TosaOperator.TosaOperatorStart(builder)
- TosaOperator.TosaOperatorAddOp(builder, self.op)
- TosaOperator.TosaOperatorAddInputs(builder, fb_inputs)
- TosaOperator.TosaOperatorAddOutputs(builder, fb_outputs)
+ TosaOperator.Start(builder)
+ TosaOperator.AddOp(builder, self.op)
+ TosaOperator.AddInputs(builder, fb_inputs)
+ TosaOperator.AddOutputs(builder, fb_outputs)
if self.attributes is not None:
- TosaOperator.TosaOperatorAddAttributeType(builder, self.attributes.utype)
- TosaOperator.TosaOperatorAddAttribute(builder, fb_attributes)
+ TosaOperator.AddAttributeType(builder, self.attributes.utype)
+ TosaOperator.AddAttribute(builder, fb_attributes)
if self.quantInfo is not None:
- TosaOperator.TosaOperatorAddQuantInfoType(builder, self.quantInfo.utype)
- TosaOperator.TosaOperatorAddQuantInfo(builder, fb_qinfo)
+ TosaOperator.AddQuantInfoType(builder, self.quantInfo.utype)
+ TosaOperator.AddQuantInfo(builder, fb_qinfo)
- return TosaOperator.TosaOperatorEnd(builder)
+ return TosaOperator.End(builder)
class TosaSerializerBasicBlock:
@@ -552,27 +552,27 @@ class TosaSerializerBasicBlock:
def serialize(self, builder):
fb_name = builder.CreateString(self.name)
fbv_inputs = TosaSerializer.serializeStrVec(
- builder, list(self.inputs), TosaBasicBlock.TosaBasicBlockStartInputsVector
+ builder, list(self.inputs), TosaBasicBlock.StartInputsVector
)
fbv_outputs = TosaSerializer.serializeStrVec(
- builder, list(self.outputs), TosaBasicBlock.TosaBasicBlockStartOutputsVector
+ builder, list(self.outputs), TosaBasicBlock.StartOutputsVector
)
fbv_tensors = TosaSerializer.serializeObjVec(
builder,
list(self.tensors.values()),
- TosaBasicBlock.TosaBasicBlockStartTensorsVector,
+ TosaBasicBlock.StartTensorsVector,
)
fbv_operators = TosaSerializer.serializeObjVec(
- builder, self.operators, TosaBasicBlock.TosaBasicBlockStartOperatorsVector
+ builder, self.operators, TosaBasicBlock.StartOperatorsVector
)
- TosaBasicBlock.TosaBasicBlockStart(builder)
- TosaBasicBlock.TosaBasicBlockAddName(builder, fb_name)
- TosaBasicBlock.TosaBasicBlockAddInputs(builder, fbv_inputs)
- TosaBasicBlock.TosaBasicBlockAddOutputs(builder, fbv_outputs)
- TosaBasicBlock.TosaBasicBlockAddTensors(builder, fbv_tensors)
- TosaBasicBlock.TosaBasicBlockAddOperators(builder, fbv_operators)
- return TosaBasicBlock.TosaBasicBlockEnd(builder)
+ TosaBasicBlock.Start(builder)
+ TosaBasicBlock.AddName(builder, fb_name)
+ TosaBasicBlock.AddInputs(builder, fbv_inputs)
+ TosaBasicBlock.AddOutputs(builder, fbv_outputs)
+ TosaBasicBlock.AddTensors(builder, fbv_tensors)
+ TosaBasicBlock.AddOperators(builder, fbv_operators)
+ return TosaBasicBlock.End(builder)
@unique
@@ -697,21 +697,21 @@ class TosaSerializer:
builder = self.builder
- Version.VersionStart(builder)
- Version.VersionAdd_major(builder, TOSA_VERSION[0])
- Version.VersionAdd_minor(builder, TOSA_VERSION[1])
- Version.VersionAdd_patch(builder, TOSA_VERSION[2])
- Version.VersionAdd_draft(builder, TOSA_VERSION[3])
- version = Version.VersionEnd(builder)
+ Version.Start(builder)
+ Version.Add_major(builder, TOSA_VERSION[0])
+ Version.Add_minor(builder, TOSA_VERSION[1])
+ Version.Add_patch(builder, TOSA_VERSION[2])
+ Version.Add_draft(builder, TOSA_VERSION[3])
+ version = Version.End(builder)
fbv_bb = TosaSerializer.serializeObjVec(
- builder, self.basicBlocks, TosaGraph.TosaGraphStartBlocksVector
+ builder, self.basicBlocks, TosaGraph.StartBlocksVector
)
- TosaGraph.TosaGraphStart(builder)
- TosaGraph.TosaGraphAddVersion(builder, version)
- TosaGraph.TosaGraphAddBlocks(builder, fbv_bb)
- graph = TosaGraph.TosaGraphEnd(builder)
+ TosaGraph.Start(builder)
+ TosaGraph.AddVersion(builder, version)
+ TosaGraph.AddBlocks(builder, fbv_bb)
+ graph = TosaGraph.End(builder)
self.builder.Finish(graph)
return self.builder.Output()
@@ -759,13 +759,7 @@ class TosaSerializer:
start_fcn(builder, len(fb_strs))
for s in fb_strs[::-1]:
builder.PrependUOffsetTRelative(s)
- # This try/except block supports both the Flatbuffers 2.x and 1.x APIs,
- # defaulting to 2.x. If/when Flatbuffers 1.x support is deprecated, the
- # try block and builder.EndVector(len) function calls can be removed.
- try:
- return builder.EndVector()
- except TypeError:
- return builder.EndVector(len(fb_strs))
+ return builder.EndVector()
@staticmethod
def serializeUint8Vec(builder, vec):