diff options
Diffstat (limited to 'python/serializer/tosa_serializer.py')
-rw-r--r-- | python/serializer/tosa_serializer.py | 782 |
1 files changed, 287 insertions, 495 deletions
diff --git a/python/serializer/tosa_serializer.py b/python/serializer/tosa_serializer.py index ec1c12d..9658edf 100644 --- a/python/serializer/tosa_serializer.py +++ b/python/serializer/tosa_serializer.py @@ -1,4 +1,4 @@ -# Copyright (c) 2020-2022, ARM Limited. +# Copyright (c) 2020-2024, ARM Limited. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,13 +13,14 @@ # limitations under the License. import os +import serializer.tosa_serializer as ts import json import flatbuffers import numpy as np -import struct from enum import IntEnum, unique from tosa import ( TosaGraph, + TosaRegion, TosaBasicBlock, TosaTensor, TosaOperator, @@ -30,7 +31,7 @@ import tosa.Op as TosaOp # Keep version number in sync with the version default value with schema/tosa.fbs TOSA_VERSION_MAJOR = 0 -TOSA_VERSION_MINOR = 31 +TOSA_VERSION_MINOR = 100 TOSA_VERSION_PATCH = 0 TOSA_VERSION_DRAFT = True TOSA_VERSION = [ @@ -56,8 +57,13 @@ DTypeNames = [ "INT16", "INT32", "INT48", - "FLOAT", + "FP32", "UINT16", + "FP16", + "BF16", + "SHAPE", + "FP8E4M3", + "FP8E5M2", ] ByteMask = np.uint64(0xFF) @@ -90,6 +96,7 @@ class TosaSerializerUnion: self.bools = [] self.floats = [] self.strings = [] + self.int16vecs = [] self.intvecs = [] self.fpvecs = [] @@ -106,6 +113,9 @@ class TosaSerializerUnion: for fcn, val in self.intvecs: intVecList.append((fcn, TosaSerializer.serializeInt32Vec(builder, val))) + for fcn, val in self.int16vecs: + intVecList.append((fcn, TosaSerializer.serializeInt16Vec(builder, val))) + for fcn, val in self.fpvecs: fpVecList.append((fcn, TosaSerializer.serializeFpVec(builder, val))) @@ -141,7 +151,15 @@ class TosaSerializerAttribute(TosaSerializerUnion): def __init__(self): super().__init__() - def PoolAttribute(self, kernel, stride, pad, input_zp, output_zp): + def PoolAttribute( + self, + kernel, + stride, + pad, + input_zp, + output_zp, + acc_type, + ): from tosa import PoolAttribute as a, Attribute self.utype = Attribute.Attribute().PoolAttribute @@ -152,8 +170,11 @@ class TosaSerializerAttribute(TosaSerializerUnion): self.intvecs.append((a.AddStride, stride)) self.ints.append((a.AddInputZp, input_zp)) self.ints.append((a.AddOutputZp, output_zp)) + self.ints.append((a.AddAccType, acc_type)) - def ConvAttribute(self, pad, stride, dilation, input_zp, weight_zp): + def ConvAttribute( + self, pad, stride, dilation, input_zp, weight_zp, local_bound, acc_type + ): from tosa import ConvAttribute as a, Attribute self.utype = Attribute.Attribute().ConvAttribute @@ -164,8 +185,12 @@ class TosaSerializerAttribute(TosaSerializerUnion): self.intvecs.append((a.AddDilation, dilation)) self.ints.append((a.AddInputZp, input_zp)) self.ints.append((a.AddWeightZp, weight_zp)) + self.bools.append((a.AddLocalBound, local_bound)) + self.ints.append((a.AddAccType, acc_type)) - def TransposeConvAttribute(self, outpad, stride, output_shape, input_zp, weight_zp): + def TransposeConvAttribute( + self, outpad, stride, output_shape, input_zp, weight_zp, local_bound, acc_type + ): from tosa import TransposeConvAttribute as a, Attribute self.utype = Attribute.Attribute().TransposeConvAttribute @@ -176,16 +201,21 @@ class TosaSerializerAttribute(TosaSerializerUnion): self.intvecs.append((a.AddOutputShape, output_shape)) self.ints.append((a.AddInputZp, input_zp)) self.ints.append((a.AddWeightZp, weight_zp)) + self.bools.append((a.AddLocalBound, local_bound)) + self.ints.append((a.AddAccType, acc_type)) - def PadAttribute(self, padding, pad_const_int, pad_const_fp): + def PadAttribute(self, serializer_builder, pad_const_val_as_bytes): from tosa import PadAttribute as a, Attribute self.utype = Attribute.Attribute().PadAttribute self.optFcns = (a.Start, a.End) - self.intvecs.append((a.AddPadding, padding)) - self.ints.append((a.AddPadConstInt, pad_const_int)) - self.floats.append((a.AddPadConstFp, pad_const_fp)) + # serialize pad_const_val_as_bytes as uint8 vector + serialized_pad_const_val = ts.TosaSerializer.serializeUint8Vec( + serializer_builder, pad_const_val_as_bytes + ) + + self.floats.append((a.AddPadConst, serialized_pad_const_val)) def AxisAttribute(self, axis): from tosa import AxisAttribute as a, Attribute @@ -195,61 +225,43 @@ class TosaSerializerAttribute(TosaSerializerUnion): self.ints.append((a.AddAxis, axis)) - def ReshapeAttribute(self, new_shape): - from tosa import ReshapeAttribute as a, Attribute - - self.utype = Attribute.Attribute().ReshapeAttribute - self.optFcns = (a.Start, a.End) - - self.intvecs.append((a.AddNewShape, new_shape)) - - def SliceAttribute(self, start, size): - from tosa import SliceAttribute as a, Attribute - - self.utype = Attribute.Attribute().SliceAttribute - self.optFcns = (a.Start, a.End) - - self.intvecs.append((a.AddStart, start)) - self.intvecs.append((a.AddSize, size)) - - def TileAttribute(self, multiples): - from tosa import TileAttribute as a, Attribute - - self.utype = Attribute.Attribute().TileAttribute - self.optFcns = (a.Start, a.End) - - self.intvecs.append((a.AddMultiples, multiples)) - - def ResizeAttribute( - self, output_size, stride, offset, shift, stride_fp, offset_fp, mode - ): + def ResizeAttribute(self, scale, offset, border, mode): from tosa import ResizeAttribute as a, Attribute self.utype = Attribute.Attribute().ResizeAttribute self.optFcns = (a.Start, a.End) - self.intvecs.append((a.AddOutputSize, output_size)) - self.intvecs.append((a.AddStride, stride)) - self.intvecs.append((a.AddOffset, offset)) - self.ints.append((a.AddShift, shift)) - self.fpvecs.append((a.AddStrideFp, stride_fp)) - self.fpvecs.append((a.AddOffsetFp, offset_fp)) + self.int16vecs.append((a.AddScale, scale)) + self.int16vecs.append((a.AddOffset, offset)) + self.int16vecs.append((a.AddBorder, border)) self.ints.append((a.AddMode, mode)) - def ClampAttribute(self, minint, maxint, minfp, maxfp): + def ClampAttribute(self, serializer_builder, min_val_as_bytes, max_val_as_bytes): from tosa import ClampAttribute as a, Attribute self.utype = Attribute.Attribute().ClampAttribute self.optFcns = (a.Start, a.End) - self.ints.append((a.AddMinInt, minint)) - self.ints.append((a.AddMaxInt, maxint)) + # min/max float attributes serialized as uint8 vectors + serialized_min_val = ts.TosaSerializer.serializeUint8Vec( + serializer_builder, min_val_as_bytes + ) + serialized_max_val = ts.TosaSerializer.serializeUint8Vec( + serializer_builder, max_val_as_bytes + ) - self.ints.append((a.AddMinFp, minfp)) - self.ints.append((a.AddMaxFp, maxfp)) + self.floats.append((a.AddMinVal, serialized_min_val)) + self.floats.append((a.AddMaxVal, serialized_max_val)) def RescaleAttribute( - self, input_zp, output_zp, multiplier, shift, scale32, double_round, per_channel + self, + input_zp, + output_zp, + scale32, + double_round, + per_channel, + input_unsigned, + output_unsigned, ): from tosa import RescaleAttribute as a, Attribute @@ -258,11 +270,11 @@ class TosaSerializerAttribute(TosaSerializerUnion): self.ints.append((a.AddInputZp, input_zp)) self.ints.append((a.AddOutputZp, output_zp)) - self.intvecs.append((a.AddMultiplier, multiplier)) - self.intvecs.append((a.AddShift, shift)) self.bools.append((a.AddScale32, scale32)) self.bools.append((a.AddDoubleRound, double_round)) self.bools.append((a.AddPerChannel, per_channel)) + self.bools.append((a.AddInputUnsigned, input_unsigned)) + self.bools.append((a.AddOutputUnsigned, output_unsigned)) def MulAttribute(self, shift): from tosa import MulAttribute as a, Attribute @@ -283,23 +295,23 @@ class TosaSerializerAttribute(TosaSerializerUnion): self.bools.append((a.AddRound, round)) - def CondIfAttribute(self, then_branch, else_branch): + def CondIfAttribute(self, then_graph, else_graph): from tosa import CondIfAttribute as a, Attribute self.utype = Attribute.Attribute().CondIfAttribute self.optFcns = (a.Start, a.End) - self.strings.append((a.AddThenBranch, then_branch)) - self.strings.append((a.AddElseBranch, else_branch)) + self.strings.append((a.AddThenGraph, then_graph)) + self.strings.append((a.AddElseGraph, else_graph)) - def WhileLoopAttribute(self, cond_branch, body_branch): + def WhileLoopAttribute(self, cond_graph, body_graph): from tosa import WhileLoopAttribute as a, Attribute self.utype = Attribute.Attribute().WhileLoopAttribute self.optFcns = (a.Start, a.End) - self.strings.append((a.AddCondBranch, cond_branch)) - self.strings.append((a.AddBodyBranch, body_branch)) + self.strings.append((a.AddCondGraph, cond_graph)) + self.strings.append((a.AddBodyGraph, body_graph)) def TransposeAttribute(self, perms): from tosa import TransposeAttribute as a, Attribute @@ -315,7 +327,7 @@ class TosaSerializerAttribute(TosaSerializerUnion): self.utype = Attribute.Attribute().TableAttribute self.optFcns = (a.Start, a.End) - self.intvecs.append((a.AddTable, table)) + self.int16vecs.append((a.AddTable, table)) def MatMulAttribute(self, A_zp, B_zp): from tosa import MatMulAttribute as a, Attribute @@ -344,6 +356,23 @@ class TosaSerializerAttribute(TosaSerializerUnion): self.ints.append((a.AddInput1Zp, input1_zp)) self.ints.append((a.AddOutputZp, output_zp)) + def FFTAttribute(self, inverse, local_bound): + from tosa import FFTAttribute as a, Attribute + + self.utype = Attribute.Attribute().FFTAttribute + self.optFcns = (a.Start, a.End) + + self.bools.append((a.AddInverse, inverse)) + self.bools.append((a.AddLocalBound, local_bound)) + + def RFFTAttribute(self, local_bound): + from tosa import RFFTAttribute as a, Attribute + + self.utype = Attribute.Attribute().RFFTAttribute + self.optFcns = (a.Start, a.End) + + self.bools.append((a.AddLocalBound, local_bound)) + class TosaSerializerTensor: def __init__( @@ -363,12 +392,24 @@ class TosaSerializerTensor: self.shape = shape self.dtype = dtype + if ( + dtype == DType.FP32 + or dtype == DType.BF16 + or dtype == DType.FP8E4M3 + or dtype == DType.FP8E5M2 + ): + fntype = np.float32 + elif dtype == DType.FP16: + fntype = np.float16 + else: + fntype = int + if isinstance(data, np.ndarray): - data = data.flatten().astype(int).tolist() - data = list(map(int, data)) + data = data.flatten().astype(fntype).tolist() + data = list(map(fntype, data)) self.data = data elif isinstance(data, list): - data = list(map(int, data)) + data = list(map(fntype, data)) self.data = data else: self.data = None @@ -381,12 +422,12 @@ class TosaSerializerTensor: self.placeholderFilename = placeholderFilename def __str__(self): - str = "TosaSerializerTensor name: {} shape: {} dtype: {}".format( + concatString = "TosaSerializerTensor name: {} shape: {} dtype: {}".format( self.name, self.shape, DTypeNames[self.dtype], ) - return str + return concatString def setDtype(self, dtype): self.dtype = dtype @@ -415,17 +456,17 @@ class TosaSerializerTensor: u8_data.append(val_u8) elif self.dtype == DType.INT8: for val in self.data: - val_u8 = np.uint8(val) + val_u8 = np.array(val).astype(dtype=np.uint8) u8_data.append(val_u8) elif self.dtype == DType.INT16: for val in self.data: - val_u16 = np.uint16(val) + val_u16 = np.array(val).astype(dtype=np.uint16) b0 = val_u16 & ByteMask b1 = (val_u16 >> np.uint16(8)) & ByteMask u8_data.extend([b0, b1]) elif self.dtype == DType.INT32: for val in self.data: - val_u32 = np.uint32(val) + val_u32 = np.array(val).astype(dtype=np.uint32) b0 = val_u32 & ByteMask b1 = (val_u32 >> np.uint32(8)) & ByteMask b2 = (val_u32 >> np.uint32(16)) & ByteMask @@ -441,10 +482,37 @@ class TosaSerializerTensor: b4 = (val_u64 >> np.uint64(32)) & ByteMask b5 = (val_u64 >> np.uint64(40)) & ByteMask u8_data.extend([b0, b1, b2, b3, b4, b5]) - elif self.dtype == DType.FLOAT: + elif self.dtype == DType.SHAPE: for val in self.data: - b = struct.pack("!f", val) - u8_data.extend([b[3], b[2], b[1], b[0]]) + val_u64 = np.uint64(val) + b0 = val_u64 & ByteMask + b1 = (val_u64 >> np.uint64(8)) & ByteMask + b2 = (val_u64 >> np.uint64(16)) & ByteMask + b3 = (val_u64 >> np.uint64(24)) & ByteMask + b4 = (val_u64 >> np.uint64(32)) & ByteMask + b5 = (val_u64 >> np.uint64(40)) & ByteMask + b6 = (val_u64 >> np.uint64(48)) & ByteMask + b7 = (val_u64 >> np.uint64(56)) & ByteMask + u8_data.extend([b0, b1, b2, b3, b4, b5, b6, b7]) + elif self.dtype == DType.FP16: + np_arr = np.array(self.data, dtype=np.float16) + u8_data.extend(np_arr.view(np.uint8)) + elif ( + self.dtype == DType.FP32 + or self.dtype == DType.BF16 + or self.dtype == DType.FP8E4M3 + or self.dtype == DType.FP8E5M2 + ): + # for val in self.data: + # b = struct.pack("!f", val) + # u8_data.extend([b[3], b[2], b[1], b[0]]) + np_arr = np.array(self.data, dtype=np.float32) + u8_data.extend(np_arr.view(np.uint8)) + elif self.dtype == TosaDType.DType: + # Serialize DType enum data as uint8 bytes + for val in self.data: + np_arr = np.array(self.data, dtype=np.uint32) + u8_data.extend(np_arr.view(np.uint8)) else: raise Exception( "unsupported data type {}".format(DTypeNames[self.dtype]) @@ -469,14 +537,14 @@ class TosaSerializerOperator: self.outputs = TosaSerializer.toList(outputs) def __str__(self): - str = "Op {}\n----\n".format(self.op) + concatString = "Op {}\n----\n".format(self.op) for i in self.inputs: - str = str + " Input: {}\n".format(i) + concatString = concatString + " Input: {}\n".format(i) for o in self.outputs: - str = str + " Output: {}\n".format(o) + concatString = concatString + " Output: {}\n".format(o) - return str + return concatString def serialize(self, builder): fb_inputs = TosaSerializer.serializeStrVec( @@ -561,41 +629,39 @@ class TosaSerializerBasicBlock: return TosaBasicBlock.End(builder) +# How CONSTs are treated in the flatbuffer @unique -class TensorDir(IntEnum): - PLACEHOLDER = 0 - CONST = 1 - INTERMEDIATE = 2 - RESULT = 3 - +class ConstMode(IntEnum): + EMBED = 0 + EMBED_DUMP = 1 + INPUTS = 2 -class TosaSerializer: - def __init__(self, pathPrefix): - self.add_compat_methods() - # Get the global TOSA version if not already defined - - self.builder = flatbuffers.Builder(0) +class TosaSerializerRegion: + def __init__(self, name, pathPrefix, constMode=ConstMode.EMBED): + self.name = name self.basicBlocks = [] - self.startBasicBlock("main") - self.pathPrefix = pathPrefix - - # Indicies used for adding/naming tensors self.currInputIdx = 0 self.currConstIdx = 0 self.currLayerIdx = 1 self.currResultIdx = 0 + self.pathPrefix = pathPrefix + self.constMode = constMode - # Is this an illegal test that is expected to fail? - self.expectedReturnCode = 0 - self.expectedFailure = False - self.expectedFailureDesc = "" + def addBasicBlock(self, name): + self.currBasicBlock = TosaSerializerBasicBlock(name) + self.basicBlocks.append(self.currBasicBlock) - def __str__(self): - str = "" - for bb in self.basicBlocks: - str = str + bb.__str__() - return str + def serialize(self, builder): + fb_name = builder.CreateString(self.name) + fbv_basicBlocks = TosaSerializer.serializeObjVec( + builder, self.basicBlocks, TosaRegion.StartBlocksVector + ) + + TosaRegion.Start(builder) + TosaRegion.AddName(builder, fb_name) + TosaRegion.AddBlocks(builder, fbv_basicBlocks) + return TosaRegion.End(builder) def addPlaceholder(self, shape, dtype, vals): if not self.currBasicBlock: @@ -614,21 +680,42 @@ class TosaSerializer: return tens - def addConst(self, shape, dtype, vals): + def addConst(self, shape, dtype, vals, name=None): if not self.currBasicBlock: raise Exception("addTensor called without valid basic block") - name = "const-{}".format(self.currInputIdx) - self.currInputIdx = self.currInputIdx + 1 + if name is None: + name = "const-{}".format(self.currInputIdx) + self.currInputIdx = self.currInputIdx + 1 + + if self.constMode == ConstMode.INPUTS: + # Save const as input file + filename = "{}.npy".format(name) + tensor_vals = None + self.currBasicBlock.addInput(name) + else: + # Embed const in flatbuffer + filename = None + tensor_vals = vals - tens = self.currBasicBlock.addTensor(name, shape, dtype, vals) + tens = self.currBasicBlock.addTensor(name, shape, dtype, tensor_vals, filename) # Add the operator now - self.currBasicBlock.addOperator(TosaOp.Op().CONST, [], name) + if dtype == DType.SHAPE: + self.currBasicBlock.addOperator(TosaOp.Op().CONST_SHAPE, [], name) + else: + self.currBasicBlock.addOperator(TosaOp.Op().CONST, [], name) + + # Save the const data to file for debug or as input files + if vals is not None and self.constMode in [ + ConstMode.EMBED_DUMP, + ConstMode.INPUTS, + ]: + filename = "{}.npy".format(name) + np.save(os.path.join(self.pathPrefix, filename), vals, False) return tens def addIntermediate(self, shape, dtype): - if not self.currBasicBlock: raise Exception("addTensor called without valid basic block") @@ -640,7 +727,13 @@ class TosaSerializer: return tens def addInputTensor(self, tensor): - self.currBasicBlock.addTensor(tensor.name, tensor.shape, tensor.dtype) + self.currBasicBlock.addTensor( + tensor.name, + tensor.shape, + tensor.dtype, + tensor.data, + tensor.placeholderFilename, + ) self.currBasicBlock.addInput(tensor.name) def addOutputTensor(self, tensor): @@ -658,7 +751,6 @@ class TosaSerializer: return tens def addOperator(self, op, inputs, outputs, attributes=None): - if op == TosaOp.Op().CONST: raise Exception("Use addConstTensor() to add CONST ops") @@ -669,6 +761,62 @@ class TosaSerializer: attributes, ) + +@unique +class TensorDir(IntEnum): + PLACEHOLDER = 0 + CONST = 1 + INTERMEDIATE = 2 + RESULT = 3 + + +class TosaSerializer: + def __init__(self, pathPrefix, constMode=ConstMode.EMBED): + self.builder = flatbuffers.Builder(0) + + # Enables inspection of constant data outside of graph + self.constMode = constMode + + self.regions = [] + self.startRegion("main", pathPrefix) + + self.currRegion.addBasicBlock("main") + + # Is this an illegal test that is expected to fail? + self.expectedReturnCode = 0 + self.expectedFailure = False + self.expectedFailureDesc = "" + + def __str__(self): + concatString = "" + for region in self.regions: + concatString = concatString + str(region) + return concatString + + def addPlaceholder(self, shape, dtype, vals): + return self.currRegion.addPlaceholder(shape, dtype, vals) + + def addConst(self, shape, dtype, vals, name=None): + return self.currRegion.addConst(shape, dtype, vals, name) + + def addIntermediate(self, shape, dtype): + return self.currRegion.addIntermediate(shape, dtype) + + def addInputTensor(self, tensor): + self.currRegion.addInputTensor(tensor) + + def addOutputTensor(self, tensor): + self.currRegion.addOutputTensor(tensor) + + def addOutput(self, shape, dtype): + return self.currRegion.addOutput(shape, dtype) + + def addOperator(self, op, inputs, outputs, attributes=None): + return self.currRegion.addOperator(op, inputs, outputs, attributes) + + def addBasicBlock(self, name): + self.currRegion.addBasicBlock(name) + def setExpectedReturnCode(self, val, fail, desc=""): self.expectedReturnCode = val @@ -680,19 +828,19 @@ class TosaSerializer: builder = self.builder Version.Start(builder) - Version.Add_major(builder, TOSA_VERSION[0]) - Version.Add_minor(builder, TOSA_VERSION[1]) - Version.Add_patch(builder, TOSA_VERSION[2]) - Version.Add_draft(builder, TOSA_VERSION[3]) + Version.Add_Major(builder, TOSA_VERSION[0]) + Version.Add_Minor(builder, TOSA_VERSION[1]) + Version.Add_Patch(builder, TOSA_VERSION[2]) + Version.Add_Draft(builder, TOSA_VERSION[3]) version = Version.End(builder) - fbv_bb = TosaSerializer.serializeObjVec( - builder, self.basicBlocks, TosaGraph.StartBlocksVector + fbv_region = TosaSerializer.serializeObjVec( + builder, self.regions, TosaGraph.StartRegionsVector ) TosaGraph.Start(builder) TosaGraph.AddVersion(builder, version) - TosaGraph.AddBlocks(builder, fbv_bb) + TosaGraph.AddRegions(builder, fbv_region) graph = TosaGraph.End(builder) self.builder.Finish(graph, TOSA_GRAPH_IDENTIFIER) @@ -709,16 +857,17 @@ class TosaSerializer: ofm_name = [] ofm_file = [] - for b in self.basicBlocks: - if b.name == "main": - for i in b.inputs: - ifm_name.append(i) - ifm_file.append(b.tensors[i].placeholderFilename) - for o in b.outputs: - ofm_name.append(o) - # Make up an OFM filename here. One isn't generated until the - # reference tool is run, so any name is a good name - ofm_file.append("ref-{}.npy".format(o)) + for region in self.regions: + for block in region.basicBlocks: + if block and block.name == "main": + for i in block.inputs: + ifm_name.append(i) + ifm_file.append(block.tensors[i].placeholderFilename) + for o in block.outputs: + ofm_name.append(o) + # Make up an OFM filename here. One isn't generated until the + # reference tool is run, so any name is a good name + ofm_file.append("ref-{}.npy".format(o)) test_desc["ifm_name"] = ifm_name test_desc["ifm_file"] = ifm_file @@ -731,9 +880,9 @@ class TosaSerializer: return json.dumps(test_desc, indent=" ") - def startBasicBlock(self, name): - self.currBasicBlock = TosaSerializerBasicBlock(name) - self.basicBlocks.append(self.currBasicBlock) + def startRegion(self, name, pathPrefix): + self.currRegion = TosaSerializerRegion(name, pathPrefix, self.constMode) + self.regions.append(self.currRegion) @staticmethod def serializeStrVec(builder, vec, start_fcn): @@ -757,6 +906,16 @@ class TosaSerializer: return builder.EndVector(len(vec)) @staticmethod + def serializeInt16Vec(builder, vec): + builder.StartVector(2, len(vec), 4) + for v in vec[::-1]: + builder.PrependInt16(v) + try: + return builder.EndVector() + except TypeError: + return builder.EndVector(len(vec)) + + @staticmethod def serializeInt32Vec(builder, vec): builder.StartVector(4, len(vec), 4) for v in vec[::-1]: @@ -796,370 +955,3 @@ class TosaSerializer: return val else: return [val] - - # Remove when switching to flatbuffers 2.0 - # contains a mapping of the deprecated 1.12 method to the 2.0 version - - def add_compat_methods(self): - - from tosa import ArithmeticRightShiftAttribute - - if not hasattr(ArithmeticRightShiftAttribute, "Start"): - ArithmeticRightShiftAttribute.Start = ( - ArithmeticRightShiftAttribute.ArithmeticRightShiftAttributeStart - ) - ArithmeticRightShiftAttribute.AddRound = ( - ArithmeticRightShiftAttribute.ArithmeticRightShiftAttributeAddRound - ) - ArithmeticRightShiftAttribute.End = ( - ArithmeticRightShiftAttribute.ArithmeticRightShiftAttributeEnd - ) - from tosa import AxisAttribute - - if not hasattr(AxisAttribute, "Start"): - AxisAttribute.Start = AxisAttribute.AxisAttributeStart - AxisAttribute.AddAxis = AxisAttribute.AxisAttributeAddAxis - AxisAttribute.End = AxisAttribute.AxisAttributeEnd - from tosa import ClampAttribute - - if not hasattr(ClampAttribute, "Start"): - ClampAttribute.Start = ClampAttribute.ClampAttributeStart - ClampAttribute.AddMinInt = ClampAttribute.ClampAttributeAddMinInt - ClampAttribute.AddMaxInt = ClampAttribute.ClampAttributeAddMaxInt - ClampAttribute.AddMinFp = ClampAttribute.ClampAttributeAddMinFp - ClampAttribute.AddMaxFp = ClampAttribute.ClampAttributeAddMaxFp - ClampAttribute.End = ClampAttribute.ClampAttributeEnd - from tosa import CondIfAttribute - - if not hasattr(CondIfAttribute, "Start"): - CondIfAttribute.Start = CondIfAttribute.CondIfAttributeStart - CondIfAttribute.AddThenBranch = CondIfAttribute.CondIfAttributeAddThenBranch - CondIfAttribute.AddElseBranch = CondIfAttribute.CondIfAttributeAddElseBranch - CondIfAttribute.End = CondIfAttribute.CondIfAttributeEnd - from tosa import ConvAttribute - - if not hasattr(ConvAttribute, "Start"): - ConvAttribute.Start = ConvAttribute.ConvAttributeStart - ConvAttribute.AddPad = ConvAttribute.ConvAttributeAddPad - ConvAttribute.StartPadVector = ConvAttribute.ConvAttributeStartPadVector - ConvAttribute.AddStride = ConvAttribute.ConvAttributeAddStride - ConvAttribute.StartStrideVector = ( - ConvAttribute.ConvAttributeStartStrideVector - ) - ConvAttribute.AddDilation = ConvAttribute.ConvAttributeAddDilation - ConvAttribute.StartDilationVector = ( - ConvAttribute.ConvAttributeStartDilationVector - ) - ConvAttribute.AddInputZp = ConvAttribute.ConvAttributeAddInputZp - ConvAttribute.AddWeightZp = ConvAttribute.ConvAttributeAddWeightZp - ConvAttribute.End = ConvAttribute.ConvAttributeEnd - from tosa import FullyConnectedAttribute - - if not hasattr(FullyConnectedAttribute, "Start"): - FullyConnectedAttribute.Start = ( - FullyConnectedAttribute.FullyConnectedAttributeStart - ) - FullyConnectedAttribute.AddInputZp = ( - FullyConnectedAttribute.FullyConnectedAttributeAddInputZp - ) - FullyConnectedAttribute.AddWeightZp = ( - FullyConnectedAttribute.FullyConnectedAttributeAddWeightZp - ) - FullyConnectedAttribute.End = ( - FullyConnectedAttribute.FullyConnectedAttributeEnd - ) - from tosa import MatMulAttribute - - if not hasattr(MatMulAttribute, "Start"): - MatMulAttribute.Start = MatMulAttribute.MatMulAttributeStart - MatMulAttribute.AddAZp = MatMulAttribute.MatMulAttributeAddAZp - MatMulAttribute.AddBZp = MatMulAttribute.MatMulAttributeAddBZp - MatMulAttribute.End = MatMulAttribute.MatMulAttributeEnd - from tosa import PoolAttribute - - if not hasattr(PoolAttribute, "Start"): - PoolAttribute.Start = PoolAttribute.PoolAttributeStart - PoolAttribute.AddPad = PoolAttribute.PoolAttributeAddPad - PoolAttribute.StartPadVector = PoolAttribute.PoolAttributeStartPadVector - PoolAttribute.AddKernel = PoolAttribute.PoolAttributeAddKernel - PoolAttribute.StartKernelVector = ( - PoolAttribute.PoolAttributeStartKernelVector - ) - PoolAttribute.AddStride = PoolAttribute.PoolAttributeAddStride - PoolAttribute.StartStrideVector = ( - PoolAttribute.PoolAttributeStartStrideVector - ) - PoolAttribute.AddInputZp = PoolAttribute.PoolAttributeAddInputZp - PoolAttribute.AddOutputZp = PoolAttribute.PoolAttributeAddOutputZp - PoolAttribute.End = PoolAttribute.PoolAttributeEnd - from tosa import MulAttribute - - if not hasattr(MulAttribute, "Start"): - MulAttribute.Start = MulAttribute.MulAttributeStart - MulAttribute.AddShift = MulAttribute.MulAttributeAddShift - MulAttribute.End = MulAttribute.MulAttributeEnd - from tosa import PadAttribute - - if not hasattr(PadAttribute, "Start"): - PadAttribute.Start = PadAttribute.PadAttributeStart - PadAttribute.AddPadding = PadAttribute.PadAttributeAddPadding - PadAttribute.StartPaddingVector = ( - PadAttribute.PadAttributeStartPaddingVector - ) - PadAttribute.AddPadConstInt = PadAttribute.PadAttributeAddPadConstInt - PadAttribute.AddPadConstFp = PadAttribute.PadAttributeAddPadConstFp - PadAttribute.End = PadAttribute.PadAttributeEnd - from tosa import PoolAttribute - - if not hasattr(PoolAttribute, "Start"): - PoolAttribute.Start = PoolAttribute.PoolAttributeStart - PoolAttribute.AddPad = PoolAttribute.PoolAttributeAddPad - PoolAttribute.StartPadVector = PoolAttribute.PoolAttributeStartPadVector - PoolAttribute.AddKernel = PoolAttribute.PoolAttributeAddKernel - PoolAttribute.StartKernelVector = ( - PoolAttribute.PoolAttributeStartKernelVector - ) - PoolAttribute.AddStride = PoolAttribute.PoolAttributeAddStride - PoolAttribute.StartStrideVector = ( - PoolAttribute.PoolAttributeStartStrideVector - ) - PoolAttribute.AddInputZp = PoolAttribute.PoolAttributeAddInputZp - PoolAttribute.AddOutputZp = PoolAttribute.PoolAttributeAddOutputZp - PoolAttribute.End = PoolAttribute.PoolAttributeEnd - from tosa import RescaleAttribute - - if not hasattr(RescaleAttribute, "Start"): - RescaleAttribute.Start = RescaleAttribute.RescaleAttributeStart - RescaleAttribute.AddInputZp = RescaleAttribute.RescaleAttributeAddInputZp - RescaleAttribute.AddOutputZp = RescaleAttribute.RescaleAttributeAddOutputZp - RescaleAttribute.AddMultiplier = ( - RescaleAttribute.RescaleAttributeAddMultiplier - ) - RescaleAttribute.StartMultiplierVector = ( - RescaleAttribute.RescaleAttributeStartMultiplierVector - ) - RescaleAttribute.AddShift = RescaleAttribute.RescaleAttributeAddShift - RescaleAttribute.StartShiftVector = ( - RescaleAttribute.RescaleAttributeStartShiftVector - ) - RescaleAttribute.AddScale32 = RescaleAttribute.RescaleAttributeAddScale32 - RescaleAttribute.AddDoubleRound = ( - RescaleAttribute.RescaleAttributeAddDoubleRound - ) - RescaleAttribute.AddPerChannel = ( - RescaleAttribute.RescaleAttributeAddPerChannel - ) - RescaleAttribute.End = RescaleAttribute.RescaleAttributeEnd - from tosa import ReshapeAttribute - - if not hasattr(ReshapeAttribute, "Start"): - ReshapeAttribute.Start = ReshapeAttribute.ReshapeAttributeStart - ReshapeAttribute.AddNewShape = ReshapeAttribute.ReshapeAttributeAddNewShape - ReshapeAttribute.StartNewShapeVector = ( - ReshapeAttribute.ReshapeAttributeStartNewShapeVector - ) - ReshapeAttribute.End = ReshapeAttribute.ReshapeAttributeEnd - from tosa import ResizeAttribute - - if not hasattr(ResizeAttribute, "Start"): - ResizeAttribute.Start = ResizeAttribute.ResizeAttributeStart - ResizeAttribute.AddOutputSize = ResizeAttribute.ResizeAttributeAddOutputSize - ResizeAttribute.StartOutputSizeVector = ( - ResizeAttribute.ResizeAttributeStartOutputSizeVector - ) - ResizeAttribute.AddStride = ResizeAttribute.ResizeAttributeAddStride - ResizeAttribute.StartStrideVector = ( - ResizeAttribute.ResizeAttributeStartStrideVector - ) - ResizeAttribute.AddOffset = ResizeAttribute.ResizeAttributeAddOffset - ResizeAttribute.StartOffsetVector = ( - ResizeAttribute.ResizeAttributeStartOffsetVector - ) - ResizeAttribute.AddShift = ResizeAttribute.ResizeAttributeAddShift - ResizeAttribute.AddStrideFp = ResizeAttribute.ResizeAttributeAddStrideFp - ResizeAttribute.StartStrideFpVector = ( - ResizeAttribute.ResizeAttributeStartStrideFpVector - ) - ResizeAttribute.AddOffsetFp = ResizeAttribute.ResizeAttributeAddOffsetFp - ResizeAttribute.StartOffsetFpVector = ( - ResizeAttribute.ResizeAttributeStartOffsetFpVector - ) - ResizeAttribute.AddMode = ResizeAttribute.ResizeAttributeAddMode - ResizeAttribute.End = ResizeAttribute.ResizeAttributeEnd - from tosa import SliceAttribute - - if not hasattr(SliceAttribute, "Start"): - SliceAttribute.Start = SliceAttribute.SliceAttributeStart - SliceAttribute.AddStart = SliceAttribute.SliceAttributeAddStart - SliceAttribute.StartStartVector = ( - SliceAttribute.SliceAttributeStartStartVector - ) - SliceAttribute.AddSize = SliceAttribute.SliceAttributeAddSize - SliceAttribute.StartSizeVector = ( - SliceAttribute.SliceAttributeStartSizeVector - ) - SliceAttribute.End = SliceAttribute.SliceAttributeEnd - from tosa import TableAttribute - - if not hasattr(TableAttribute, "Start"): - TableAttribute.Start = TableAttribute.TableAttributeStart - TableAttribute.AddTable = TableAttribute.TableAttributeAddTable - TableAttribute.StartTableVector = ( - TableAttribute.TableAttributeStartTableVector - ) - TableAttribute.End = TableAttribute.TableAttributeEnd - from tosa import TileAttribute - - if not hasattr(TileAttribute, "Start"): - TileAttribute.Start = TileAttribute.TileAttributeStart - TileAttribute.AddMultiples = TileAttribute.TileAttributeAddMultiples - TileAttribute.StartMultiplesVector = ( - TileAttribute.TileAttributeStartMultiplesVector - ) - TileAttribute.End = TileAttribute.TileAttributeEnd - from tosa import TosaBasicBlock - - if not hasattr(TosaBasicBlock, "Start"): - TosaBasicBlock.Start = TosaBasicBlock.TosaBasicBlockStart - TosaBasicBlock.AddName = TosaBasicBlock.TosaBasicBlockAddName - TosaBasicBlock.AddOperators = TosaBasicBlock.TosaBasicBlockAddOperators - TosaBasicBlock.StartOperatorsVector = ( - TosaBasicBlock.TosaBasicBlockStartOperatorsVector - ) - TosaBasicBlock.AddTensors = TosaBasicBlock.TosaBasicBlockAddTensors - TosaBasicBlock.StartTensorsVector = ( - TosaBasicBlock.TosaBasicBlockStartTensorsVector - ) - TosaBasicBlock.AddInputs = TosaBasicBlock.TosaBasicBlockAddInputs - TosaBasicBlock.StartInputsVector = ( - TosaBasicBlock.TosaBasicBlockStartInputsVector - ) - TosaBasicBlock.AddOutputs = TosaBasicBlock.TosaBasicBlockAddOutputs - TosaBasicBlock.StartOutputsVector = ( - TosaBasicBlock.TosaBasicBlockStartOutputsVector - ) - TosaBasicBlock.End = TosaBasicBlock.TosaBasicBlockEnd - from tosa import TosaGraph - - if not hasattr(TosaGraph, "Start"): - TosaGraph.Start = TosaGraph.TosaGraphStart - TosaGraph.AddVersion = TosaGraph.TosaGraphAddVersion - TosaGraph.AddBlocks = TosaGraph.TosaGraphAddBlocks - TosaGraph.StartBlocksVector = TosaGraph.TosaGraphStartBlocksVector - TosaGraph.End = TosaGraph.TosaGraphEnd - from tosa import TosaOperator - - if not hasattr(TosaOperator, "Start"): - TosaOperator.Start = TosaOperator.TosaOperatorStart - TosaOperator.AddOp = TosaOperator.TosaOperatorAddOp - TosaOperator.AddAttributeType = TosaOperator.TosaOperatorAddAttributeType - TosaOperator.AddAttribute = TosaOperator.TosaOperatorAddAttribute - TosaOperator.AddInputs = TosaOperator.TosaOperatorAddInputs - TosaOperator.StartInputsVector = TosaOperator.TosaOperatorStartInputsVector - TosaOperator.AddOutputs = TosaOperator.TosaOperatorAddOutputs - TosaOperator.StartOutputsVector = ( - TosaOperator.TosaOperatorStartOutputsVector - ) - TosaOperator.End = TosaOperator.TosaOperatorEnd - from tosa import TosaTensor - - if not hasattr(TosaTensor, "Start"): - TosaTensor.Start = TosaTensor.TosaTensorStart - TosaTensor.AddName = TosaTensor.TosaTensorAddName - TosaTensor.AddShape = TosaTensor.TosaTensorAddShape - TosaTensor.StartShapeVector = TosaTensor.TosaTensorStartShapeVector - TosaTensor.AddType = TosaTensor.TosaTensorAddType - TosaTensor.AddData = TosaTensor.TosaTensorAddData - TosaTensor.StartDataVector = TosaTensor.TosaTensorStartDataVector - TosaTensor.End = TosaTensor.TosaTensorEnd - from tosa import TransposeAttribute - - if not hasattr(TransposeAttribute, "Start"): - TransposeAttribute.Start = TransposeAttribute.TransposeAttributeStart - TransposeAttribute.AddPerms = TransposeAttribute.TransposeAttributeAddPerms - TransposeAttribute.StartPermsVector = ( - TransposeAttribute.TransposeAttributeStartPermsVector - ) - TransposeAttribute.End = TransposeAttribute.TransposeAttributeEnd - from tosa import TransposeConvAttribute - - if not hasattr(TransposeConvAttribute, "Start"): - TransposeConvAttribute.Start = ( - TransposeConvAttribute.TransposeConvAttributeStart - ) - TransposeConvAttribute.AddOutPad = ( - TransposeConvAttribute.TransposeConvAttributeAddOutPad - ) - TransposeConvAttribute.StartOutPadVector = ( - TransposeConvAttribute.TransposeConvAttributeStartOutPadVector - ) - TransposeConvAttribute.AddStride = ( - TransposeConvAttribute.TransposeConvAttributeAddStride - ) - TransposeConvAttribute.StartStrideVector = ( - TransposeConvAttribute.TransposeConvAttributeStartStrideVector - ) - TransposeConvAttribute.AddOutputShape = ( - TransposeConvAttribute.TransposeConvAttributeAddOutputShape - ) - TransposeConvAttribute.StartOutputShapeVector = ( - TransposeConvAttribute.TransposeConvAttributeStartOutputShapeVector - ) - TransposeConvAttribute.AddInputZp = ( - TransposeConvAttribute.TransposeConvAttributeAddInputZp - ) - TransposeConvAttribute.AddWeightZp = ( - TransposeConvAttribute.TransposeConvAttributeAddWeightZp - ) - TransposeConvAttribute.End = ( - TransposeConvAttribute.TransposeConvAttributeEnd - ) - from tosa import Version - - if not hasattr(Version, "Start"): - Version.Start = Version.VersionStart - Version.Add_major = Version.VersionAdd_major - Version.Add_minor = Version.VersionAdd_minor - Version.Add_patch = Version.VersionAdd_patch - Version.Add_draft = Version.VersionAdd_draft - Version.End = Version.VersionEnd - from tosa import MatMulAttribute - - if not hasattr(MatMulAttribute, "Start"): - MatMulAttribute.Start = MatMulAttribute.MatMulAttributeStart - MatMulAttribute.AddAZp = MatMulAttribute.MatMulAttributeAddAZp - MatMulAttribute.AddBZp = MatMulAttribute.MatMulAttributeAddBZp - MatMulAttribute.End = MatMulAttribute.MatMulAttributeEnd - from tosa import FullyConnectedAttribute - - if not hasattr(FullyConnectedAttribute, "Start"): - FullyConnectedAttribute.Start = ( - FullyConnectedAttribute.FullyConnectedAttributeStart - ) - FullyConnectedAttribute.AddInputZp = ( - FullyConnectedAttribute.FullyConnectedAttributeAddInputZp - ) - FullyConnectedAttribute.AddWeightZp = ( - FullyConnectedAttribute.FullyConnectedAttributeAddWeightZp - ) - FullyConnectedAttribute.End = ( - FullyConnectedAttribute.FullyConnectedAttributeEnd - ) - from tosa import NegateAttribute - - if not hasattr(NegateAttribute, "Start"): - NegateAttribute.Start = NegateAttribute.NegateAttributeStart - NegateAttribute.AddInput1Zp = NegateAttribute.NegateAttributeAddInput1Zp - NegateAttribute.AddOutputZp = NegateAttribute.NegateAttributeAddOutputZp - NegateAttribute.End = NegateAttribute.NegateAttributeEnd - from tosa import WhileLoopAttribute - - if not hasattr(WhileLoopAttribute, "Start"): - WhileLoopAttribute.Start = WhileLoopAttribute.WhileLoopAttributeStart - WhileLoopAttribute.AddCondBranch = ( - WhileLoopAttribute.WhileLoopAttributeAddCondBranch - ) - WhileLoopAttribute.AddBodyBranch = ( - WhileLoopAttribute.WhileLoopAttributeAddBodyBranch - ) - WhileLoopAttribute.End = WhileLoopAttribute.WhileLoopAttributeEnd |