diff options
author | Kevin Cheng <kevin.cheng@arm.com> | 2021-03-03 11:21:43 -0800 |
---|---|---|
committer | Kevin Cheng <kevin.cheng@arm.com> | 2021-04-27 16:01:59 -0700 |
commit | 550ccc52de231621c0bf0c05ae2a398eec37ff51 (patch) | |
tree | d4a5bd8d24560135784208c0fe35615b1d043249 /verif/tosa_serializer.py | |
parent | cf6224e6e8ba4fc2984de3e542538c38e27c9f57 (diff) | |
download | reference_model-550ccc52de231621c0bf0c05ae2a398eec37ff51.tar.gz |
Replace serialization/ and verif/ with MLPlatform's serialization_lib submodule
- Remove Usage and Format
- Run black on verif/*.py scripts
Signed-off-by: Kevin Cheng <kevin.cheng@arm.com>
Change-Id: Ie81515891eb0039540f614894f4b6b0e0e78ba74
Diffstat (limited to 'verif/tosa_serializer.py')
-rw-r--r-- | verif/tosa_serializer.py | 405 |
1 files changed, 204 insertions, 201 deletions
diff --git a/verif/tosa_serializer.py b/verif/tosa_serializer.py index 136f7aa..fa1fdcb 100644 --- a/verif/tosa_serializer.py +++ b/verif/tosa_serializer.py @@ -1,5 +1,3 @@ - - # Copyright (c) 2020-2021, ARM Limited. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -16,37 +14,57 @@ #!/usr/bin/env python3 +import os +import sys +import json import flatbuffers import numpy as np from enum import Enum, IntEnum, unique -from tosa import TosaGraph, TosaBasicBlock, TosaTensor, TosaOperator, DType, Format, Usage, Op, ResizeMode, Version +from tosa import ( + TosaGraph, + TosaBasicBlock, + TosaTensor, + TosaOperator, + DType, + Op, + ResizeMode, + Version, +) + +# Include the ../thirdparty/serialization_lib/python directory in PYTHONPATH +parent_dir = os.path.dirname(os.path.realpath(__file__)) +sys.path.append( + os.path.join(parent_dir, "..", "thirdparty", "serialization_lib", "python") +) import tosa -import os -import json # With the way flatc generates its python types, there is no programatic way # to get string names for the integer types. Manually maintain a string table # here. -DTypeNames = [ 'UNKNOWN', - 'BOOL', - 'UINT8', - 'INT4', - 'INT8', - 'INT16', - 'INT32', - 'INT48', - 'FLOAT' ] +DTypeNames = [ + "UNKNOWN", + "BOOL", + "UINT8", + "INT4", + "INT8", + "INT16", + "INT32", + "INT48", + "FLOAT", +] + def dtype_str_to_val(name): for i in range(len(DTypeNames)): if name.casefold() == DTypeNames[i].casefold(): return i - raise Exception('Unable to parse DType name {}'.format(name)) + raise Exception("Unable to parse DType name {}".format(name)) class TosaSerializerUnion: - '''This class handles encapsulating and serializing union types into flatbuffers''' + """This class handles encapsulating and serializing union types into flatbuffers""" + def __init__(self): # A tuple of the start and end functions. Set by the options constructors below @@ -105,8 +123,9 @@ class TosaSerializerUnion: return endFcn(builder) + class TosaSerializerAttribute(TosaSerializerUnion): - '''This class handles encapsulating all of the enumerated types for attributes''' + """This class handles encapsulating all of the enumerated types for attributes""" def __init__(self): super().__init__() @@ -117,12 +136,9 @@ class TosaSerializerAttribute(TosaSerializerUnion): self.utype = Attribute.Attribute().Pool2dAttribute self.optFcns = (a.Pool2dAttributeStart, a.Pool2dAttributeEnd) - self.intvecs.append((a.Pool2dAttributeAddPadding, - padding)) - self.intvecs.append((a.Pool2dAttributeAddKernel, - kernel)) - self.intvecs.append((a.Pool2dAttributeAddStride, - stride)) + self.intvecs.append((a.Pool2dAttributeAddPadding, padding)) + self.intvecs.append((a.Pool2dAttributeAddKernel, kernel)) + self.intvecs.append((a.Pool2dAttributeAddStride, stride)) def Conv2dAttribute(self, padding, stride, dilation): from tosa import Conv2dAttribute as a, Attribute @@ -130,12 +146,9 @@ class TosaSerializerAttribute(TosaSerializerUnion): self.utype = Attribute.Attribute().Conv2dAttribute self.optFcns = (a.Conv2dAttributeStart, a.Conv2dAttributeEnd) - self.intvecs.append((a.Conv2dAttributeAddPadding, - padding)) - self.intvecs.append((a.Conv2dAttributeAddStride, - stride)) - self.intvecs.append((a.Conv2dAttributeAddDilation, - dilation)) + self.intvecs.append((a.Conv2dAttributeAddPadding, padding)) + self.intvecs.append((a.Conv2dAttributeAddStride, stride)) + self.intvecs.append((a.Conv2dAttributeAddDilation, dilation)) def TransposeConv2DAttribute(self, outpad, stride, dilation, output_shape): from tosa import TransposeConv2dAttribute as a, Attribute @@ -143,14 +156,10 @@ class TosaSerializerAttribute(TosaSerializerUnion): self.utype = Attribute.Attribute().TransposeConv2dAttribute self.optFcns = (a.TransposeConv2dAttributeStart, a.TransposeConv2dAttributeEnd) - self.intvecs.append((a.TransposeConv2dAttributeAddOutpad, - outpad)) - self.intvecs.append((a.TransposeConv2dAttributeAddStride, - stride)) - self.intvecs.append((a.TransposeConv2dAttributeAddDilation, - dilation)) - self.intvecs.append((a.TransposeConv2dAttributeAddOutputShape, - output_shape)) + self.intvecs.append((a.TransposeConv2dAttributeAddOutpad, outpad)) + self.intvecs.append((a.TransposeConv2dAttributeAddStride, stride)) + self.intvecs.append((a.TransposeConv2dAttributeAddDilation, dilation)) + self.intvecs.append((a.TransposeConv2dAttributeAddOutputShape, output_shape)) def ReluNAttribute(self, maxint, maxfp): from tosa import ReluNAttribute as a, Attribute @@ -161,15 +170,13 @@ class TosaSerializerAttribute(TosaSerializerUnion): self.ints.append((a.ReluNAttributeAddMaxInt, maxint)) self.ints.append((a.ReluNAttributeAddMaxFp, maxfp)) - def AxisAttribute(self, axis): from tosa import AxisAttribute as a, Attribute self.utype = Attribute.Attribute().AxisAttribute self.optFcns = (a.AxisAttributeStart, a.AxisAttributeEnd) - self.ints.append((a.AxisAttributeAddAxis, - axis)) + self.ints.append((a.AxisAttributeAddAxis, axis)) def ReshapeAttribute(self, shape): from tosa import ReshapeAttribute as a, Attribute @@ -177,8 +184,7 @@ class TosaSerializerAttribute(TosaSerializerUnion): self.utype = Attribute.Attribute().ReshapeAttribute self.optFcns = (a.ReshapeAttributeStart, a.ReshapeAttributeEnd) - self.intvecs.append((a.ReshapeAttributeAddShape, - shape)) + self.intvecs.append((a.ReshapeAttributeAddShape, shape)) def SliceAttribute(self, begin, size): from tosa import SliceAttribute as a, Attribute @@ -186,10 +192,8 @@ class TosaSerializerAttribute(TosaSerializerUnion): self.utype = Attribute.Attribute().SliceAttribute self.optFcns = (a.SliceAttributeStart, a.SliceAttributeEnd) - self.intvecs.append((a.SliceAttributeAddBegin, - begin)) - self.intvecs.append((a.SliceAttributeAddSize, - size)) + self.intvecs.append((a.SliceAttributeAddBegin, begin)) + self.intvecs.append((a.SliceAttributeAddSize, size)) def TileAttribute(self, multiples): from tosa import TileAttribute as a, Attribute @@ -197,29 +201,23 @@ class TosaSerializerAttribute(TosaSerializerUnion): self.utype = Attribute.Attribute().TileAttribute self.optFcns = (a.TileAttributeStart, a.TileAttributeEnd) - self.intvecs.append((a.TileAttributeAddMultiples, - multiples)) + self.intvecs.append((a.TileAttributeAddMultiples, multiples)) - def ResizeAttribute(self, output_size, stride, offset, shift, stride_fp, offset_fp, mode): + def ResizeAttribute( + self, output_size, stride, offset, shift, stride_fp, offset_fp, mode + ): from tosa import ResizeAttribute as a, Attribute self.utype = Attribute.Attribute().ResizeAttribute self.optFcns = (a.ResizeAttributeStart, a.ResizeAttributeEnd) - self.intvecs.append((a.ResizeAttributeAddOutputSize, - output_size)) - self.intvecs.append((a.ResizeAttributeAddStride, - stride)) - self.intvecs.append((a.ResizeAttributeAddOffset, - offset)) - self.ints.append((a.ResizeAttributeAddShift, - shift)) - self.fpvecs.append((a.ResizeAttributeAddStrideFp, - stride_fp)) - self.fpvecs.append((a.ResizeAttributeAddOffsetFp, - offset_fp)) - self.ints.append((a.ResizeAttributeAddMode, - mode)) + self.intvecs.append((a.ResizeAttributeAddOutputSize, output_size)) + self.intvecs.append((a.ResizeAttributeAddStride, stride)) + self.intvecs.append((a.ResizeAttributeAddOffset, offset)) + self.ints.append((a.ResizeAttributeAddShift, shift)) + self.fpvecs.append((a.ResizeAttributeAddStrideFp, stride_fp)) + self.fpvecs.append((a.ResizeAttributeAddOffsetFp, offset_fp)) + self.ints.append((a.ResizeAttributeAddMode, mode)) def ClampAttribute(self, minint, maxint, minfp, maxfp): from tosa import ClampAttribute as a, Attribute @@ -227,36 +225,27 @@ class TosaSerializerAttribute(TosaSerializerUnion): self.utype = Attribute.Attribute().ClampAttribute self.optFcns = (a.ClampAttributeStart, a.ClampAttributeEnd) - self.ints.append((a.ClampAttributeAddMinInt, - minint)) - self.ints.append((a.ClampAttributeAddMaxInt, - maxint)) + self.ints.append((a.ClampAttributeAddMinInt, minint)) + self.ints.append((a.ClampAttributeAddMaxInt, maxint)) - self.ints.append((a.ClampAttributeAddMinFp, - minfp)) - self.ints.append((a.ClampAttributeAddMaxFp, - maxfp)) + self.ints.append((a.ClampAttributeAddMinFp, minfp)) + self.ints.append((a.ClampAttributeAddMaxFp, maxfp)) - def RescaleAttribute(self, input_zp, output_zp, multiplier, shift, scale32, double_round, per_channel): + def RescaleAttribute( + self, input_zp, output_zp, multiplier, shift, scale32, double_round, per_channel + ): from tosa import RescaleAttribute as a, Attribute self.utype = Attribute.Attribute().RescaleAttribute self.optFcns = (a.RescaleAttributeStart, a.RescaleAttributeEnd) - self.ints.append((a.RescaleAttributeAddInputZp, - input_zp)) - self.ints.append((a.RescaleAttributeAddOutputZp, - output_zp)) - self.intvecs.append((a.RescaleAttributeAddMultiplier, - multiplier)) - self.intvecs.append((a.RescaleAttributeAddShift, - shift)) - self.bools.append((a.RescaleAttributeAddScale32, - scale32)) - self.bools.append((a.RescaleAttributeAddDoubleRound, - double_round)) - self.bools.append((a.RescaleAttributeAddPerChannel, - per_channel)) + self.ints.append((a.RescaleAttributeAddInputZp, input_zp)) + self.ints.append((a.RescaleAttributeAddOutputZp, output_zp)) + self.intvecs.append((a.RescaleAttributeAddMultiplier, multiplier)) + self.intvecs.append((a.RescaleAttributeAddShift, shift)) + self.bools.append((a.RescaleAttributeAddScale32, scale32)) + self.bools.append((a.RescaleAttributeAddDoubleRound, double_round)) + self.bools.append((a.RescaleAttributeAddPerChannel, per_channel)) def MulAttribute(self, shift): from tosa import MulAttribute as a, Attribute @@ -264,17 +253,18 @@ class TosaSerializerAttribute(TosaSerializerUnion): self.utype = Attribute.Attribute().MulAttribute self.optFcns = (a.MulAttributeStart, a.MulAttributeEnd) - self.ints.append((a.MulAttributeAddShift, - shift)) + self.ints.append((a.MulAttributeAddShift, shift)) def ArithmeticRightShiftAttribute(self, round): from tosa import ArithmeticRightShiftAttribute as a, Attribute self.utype = Attribute.Attribute().ArithmeticRightShiftAttribute - self.optFcns = (a.ArithmeticRightShiftAttributeStart, a.ArithmeticRightShiftAttributeEnd) + self.optFcns = ( + a.ArithmeticRightShiftAttributeStart, + a.ArithmeticRightShiftAttributeEnd, + ) - self.bools.append((a.ArithmeticRightShiftAttributeAddRound, - round)) + self.bools.append((a.ArithmeticRightShiftAttributeAddRound, round)) def CustomAttribute(self, identifier): from tosa import CustomAttribute as a, Attribute @@ -282,8 +272,7 @@ class TosaSerializerAttribute(TosaSerializerUnion): self.utype = Attribute.Attribute().CustomAttribute self.optFcns = (a.CustomAttributeStart, a.CustomAttributeEnd) - self.strings.append((a.CustomAttributeAddIdentifier, - identifier)) + self.strings.append((a.CustomAttributeAddIdentifier, identifier)) def CondIfAttribute(self, then_branch, else_branch): from tosa import CondIfAttribute as a, Attribute @@ -291,10 +280,8 @@ class TosaSerializerAttribute(TosaSerializerUnion): self.utype = Attribute.Attribute().CondIfAttribute self.optFcns = (a.CondIfAttributeStart, a.CondIfAttributeEnd) - self.strings.append((a.CondIfAttributeAddThenBranch, - then_branch)) - self.strings.append((a.CondIfAttributeAddElseBranch, - else_branch)) + self.strings.append((a.CondIfAttributeAddThenBranch, then_branch)) + self.strings.append((a.CondIfAttributeAddElseBranch, else_branch)) def WhileLoopAttribute(self, cond_branch, body_branch): from tosa import WhileLoopAttribute as a, Attribute @@ -302,13 +289,13 @@ class TosaSerializerAttribute(TosaSerializerUnion): self.utype = Attribute.Attribute().WhileLoopAttribute self.optFcns = (a.WhileLoopAttributeStart, a.WhileLoopAttributeEnd) - self.strings.append((a.WhileLoopAttributeAddCondBranch, - cond_branch)) - self.strings.append((a.WhileLoopAttributeAddBodyBranch, - body_branch)) + self.strings.append((a.WhileLoopAttributeAddCondBranch, cond_branch)) + self.strings.append((a.WhileLoopAttributeAddBodyBranch, body_branch)) + class TosaSerializerQuantInfo(TosaSerializerUnion): - '''This class handles encapsulating all of the enumerated types for quantinfo types''' + """This class handles encapsulating all of the enumerated types for quantinfo types""" + def __init__(self): super().__init__() @@ -343,8 +330,16 @@ class TosaSerializerQuantInfo(TosaSerializerUnion): self.optFcns = (q.PadQuantInfoStart, q.PadQuantInfoEnd) self.ints.append((q.PadQuantInfoAddInputZp, input_zp)) + class TosaSerializerTensor: - def __init__(self, name, shape, dtype, usage, dformat, filename = None, placeholderFilename = None): + def __init__( + self, + name, + shape, + dtype, + filename=None, + placeholderFilename=None, + ): self.name = name if isinstance(shape, np.ndarray): @@ -353,8 +348,6 @@ class TosaSerializerTensor: self.shape = shape self.dtype = dtype - self.usage = TosaSerializer.toList(usage) - self.dformat = TosaSerializer.toList(dformat) # Filename for const tensors. This gets written to the .tosa serialization self.filename = filename @@ -366,58 +359,35 @@ class TosaSerializerTensor: self.placeholderFilename = placeholderFilename def __str__(self): - str = 'TosaSerializerTensor name: {} shape: {} dtype: {} Usage: {} format {} filename: {}'.format( - self.name, self.shape, DTypeNames[self.dtype], self.usage, self.dformat, self.filename) + str = "TosaSerializerTensor name: {} shape: {} dtype: {} filename: {}".format( + self.name, + self.shape, + DTypeNames[self.dtype], + self.filename, + ) return str - def addUsage(self, usage): - self.usage.append(usage) - - def addFormat(self, format): - self.dformat.append(format) - def setDtype(self, dtype): self.dtype = dtype - def merge(self, name, shape, dtype, usage, dformat, filename = None): - # Merge in additional usage/formats to the list - found = 0 - for i in self.usage: - if i == usage: - found = 1 - break - if not found: - self.usage.append(usage) - - found = 0 - for i in self.dformat: - if i == dformat: - found = 1 - break - if not found: - self.dformat.append(dformat) - def serialize(self, builder): fb_name = builder.CreateString(self.name) if self.filename: fb_filename = builder.CreateString(self.filename) fb_shapes = TosaSerializer.serializeInt32Vec(builder, self.shape) - fb_usage = TosaSerializer.serializeInt32Vec(builder, self.usage) - fb_dformat = TosaSerializer.serializeInt32Vec(builder, self.dformat) TosaTensor.TosaTensorStart(builder) TosaTensor.TosaTensorAddName(builder, fb_name) TosaTensor.TosaTensorAddShape(builder, fb_shapes) TosaTensor.TosaTensorAddType(builder, self.dtype) - TosaTensor.TosaTensorAddUsage(builder, fb_usage) - TosaTensor.TosaTensorAddFormat(builder, fb_dformat) if self.filename: TosaTensor.TosaTensorAddNpyFilename(builder, fb_filename) return TosaTensor.TosaTensorEnd(builder) + class TosaSerializerOperator: - def __init__(self, op, inputs, outputs, attributes = None, quantInfo = None): + def __init__(self, op, inputs, outputs, attributes=None, quantInfo=None): self.op = op self.attributes = attributes self.inputs = TosaSerializer.toList(inputs) @@ -425,18 +395,22 @@ class TosaSerializerOperator: self.quantInfo = quantInfo def __str__(self): - str = 'Op {}\n----\n'.format(self.op) + str = "Op {}\n----\n".format(self.op) for i in self.inputs: - str = str + ' Input: {}\n'.format(i) + str = str + " Input: {}\n".format(i) for o in self.outputs: - str = str + ' Output: {}\n'.format(o) + str = str + " Output: {}\n".format(o) return str def serialize(self, builder): - fb_inputs = TosaSerializer.serializeStrVec(builder, self.inputs, TosaOperator.TosaOperatorStartInputsVector) - fb_outputs = TosaSerializer.serializeStrVec(builder, self.outputs, TosaOperator.TosaOperatorStartOutputsVector) + fb_inputs = TosaSerializer.serializeStrVec( + builder, self.inputs, TosaOperator.TosaOperatorStartInputsVector + ) + fb_outputs = TosaSerializer.serializeStrVec( + builder, self.outputs, TosaOperator.TosaOperatorStartOutputsVector + ) # Need to serialize quant_info and attributes enums still if self.attributes is not None: fb_attributes = self.attributes.serialize(builder) @@ -457,6 +431,7 @@ class TosaSerializerOperator: return TosaOperator.TosaOperatorEnd(builder) + class TosaSerializerBasicBlock: def __init__(self, name): self.name = name @@ -468,14 +443,21 @@ class TosaSerializerBasicBlock: self.inputs = [] self.outputs = [] - def addTensor(self, name, shape, dtype, usage, dformat, filename = None, placeholderFilename = None): + def addTensor( + self, + name, + shape, + dtype, + filename=None, + placeholderFilename=None, + ): try: # Someone already added this tensor. - # We may have to add more usages and formats tens = self.tensors[name] - filename = tens.merge(name, shape, dtype, usage, dformat, filename) except KeyError: - self.tensors[name] = TosaSerializerTensor(name, shape, dtype, usage, dformat, filename, placeholderFilename) + self.tensors[name] = TosaSerializerTensor( + name, shape, dtype, filename, placeholderFilename + ) return self.tensors[name] @@ -485,15 +467,27 @@ class TosaSerializerBasicBlock: def addOutput(self, name): self.outputs.append(name) - def addOperator(self, op, inputs, outputs, attributes = None, quant_info = None): - self.operators.append(TosaSerializerOperator(op, inputs, outputs, attributes, quant_info)) + def addOperator(self, op, inputs, outputs, attributes=None, quant_info=None): + self.operators.append( + TosaSerializerOperator(op, inputs, outputs, attributes, quant_info) + ) def serialize(self, builder): fb_name = builder.CreateString(self.name) - fbv_inputs = TosaSerializer.serializeStrVec(builder, list(self.inputs), TosaBasicBlock.TosaBasicBlockStartInputsVector) - fbv_outputs = TosaSerializer.serializeStrVec(builder, list(self.outputs), TosaBasicBlock.TosaBasicBlockStartOutputsVector) - fbv_tensors = TosaSerializer.serializeObjVec(builder, list(self.tensors.values()), TosaBasicBlock.TosaBasicBlockStartTensorsVector) - fbv_operators = TosaSerializer.serializeObjVec(builder, self.operators, TosaBasicBlock.TosaBasicBlockStartOperatorsVector) + fbv_inputs = TosaSerializer.serializeStrVec( + builder, list(self.inputs), TosaBasicBlock.TosaBasicBlockStartInputsVector + ) + fbv_outputs = TosaSerializer.serializeStrVec( + builder, list(self.outputs), TosaBasicBlock.TosaBasicBlockStartOutputsVector + ) + fbv_tensors = TosaSerializer.serializeObjVec( + builder, + list(self.tensors.values()), + TosaBasicBlock.TosaBasicBlockStartTensorsVector, + ) + fbv_operators = TosaSerializer.serializeObjVec( + builder, self.operators, TosaBasicBlock.TosaBasicBlockStartOperatorsVector + ) TosaBasicBlock.TosaBasicBlockStart(builder) TosaBasicBlock.TosaBasicBlockAddName(builder, fb_name) @@ -503,6 +497,7 @@ class TosaSerializerBasicBlock: TosaBasicBlock.TosaBasicBlockAddOperators(builder, fbv_operators) return TosaBasicBlock.TosaBasicBlockEnd(builder) + @unique class TensorDir(IntEnum): PLACEHOLDER = 0 @@ -510,6 +505,7 @@ class TensorDir(IntEnum): INTERMEDIATE = 2 RESULT = 3 + class TosaSerializer: def __init__(self, pathPrefix): @@ -522,7 +518,7 @@ class TosaSerializer: self.builder = flatbuffers.Builder(0) self.basicBlocks = [] - self.startBasicBlock('main') + self.startBasicBlock("main") self.pathPrefix = pathPrefix # Indicies used for adding/naming tensors @@ -533,23 +529,23 @@ class TosaSerializer: # Is this an illegal test that is expected to fail? self.expectedFailure = False - self.expectedFailureDesc = '' + self.expectedFailureDesc = "" def __str__(self): - str = '' + str = "" for bb in self.basicBlocks: str = str + bb.__str__() return str - def addPlaceholder(self, shape, dtype, usage, dformat, vals): + def addPlaceholder(self, shape, dtype, vals): if not self.currBasicBlock: - raise Exception('addTensor called without valid basic block') + raise Exception("addTensor called without valid basic block") - name = 'input-{}'.format(self.currInputIdx) - filename = '{}.npy'.format(name) + name = "input-{}".format(self.currInputIdx) + filename = "{}.npy".format(name) self.currInputIdx = self.currInputIdx + 1 - tens = self.currBasicBlock.addTensor(name, shape, dtype, usage, dformat, None, filename) + tens = self.currBasicBlock.addTensor(name, shape, dtype, None, filename) # This is always an input to the block self.currBasicBlock.addInput(name) # Add the operator now @@ -560,15 +556,15 @@ class TosaSerializer: return tens - def addConst(self, shape, dtype, usage, dformat, vals): + def addConst(self, shape, dtype, vals): if not self.currBasicBlock: - raise Exception('addTensor called without valid basic block') + raise Exception("addTensor called without valid basic block") - name = 'const-{}'.format(self.currInputIdx) - filename = '{}.npy'.format(name) + name = "const-{}".format(self.currInputIdx) + filename = "{}.npy".format(name) self.currInputIdx = self.currInputIdx + 1 - tens = self.currBasicBlock.addTensor(name, shape, dtype, usage, dformat, filename) + tens = self.currBasicBlock.addTensor(name, shape, dtype, filename) # Add the operator now self.currBasicBlock.addOperator(tosa.Op.Op().CONST, [], name) @@ -576,51 +572,54 @@ class TosaSerializer: np.save(os.path.join(self.pathPrefix, filename), vals, False) return tens - def addIntermediate(self, shape, dtype, usage, dformat): + def addIntermediate(self, shape, dtype): if not self.currBasicBlock: - raise Exception('addTensor called without valid basic block') + raise Exception("addTensor called without valid basic block") - name = 'layer-{}'.format(self.currLayerIdx) - filename = None # No file, so no filename + name = "layer-{}".format(self.currLayerIdx) + filename = None # No file, so no filename self.currLayerIdx = self.currLayerIdx + 1 - tens = self.currBasicBlock.addTensor(name, shape, dtype, usage, dformat, filename) + tens = self.currBasicBlock.addTensor(name, shape, dtype, filename) return tens def addInputTensor(self, tensor): self.currBasicBlock.addOperator(tosa.Op.Op().PLACEHOLDER, [], tensor.name) - self.currBasicBlock.addTensor(tensor.name, tensor.shape, tensor.dtype, tensor.usage, tensor.dformat) + self.currBasicBlock.addTensor(tensor.name, tensor.shape, tensor.dtype) self.currBasicBlock.addInput(tensor.name) def addOutputTensor(self, tensor): self.currBasicBlock.addOutput(tensor.name) - def addOutput(self, shape, dtype, usage, dformat): + def addOutput(self, shape, dtype): if not self.currBasicBlock: - raise Exception('addTensor called without valid basic block') + raise Exception("addTensor called without valid basic block") - name = 'result-{}'.format(self.currResultIdx) + name = "result-{}".format(self.currResultIdx) self.currResultIdx = self.currResultIdx + 1 - tens = self.currBasicBlock.addTensor(name, shape, dtype, usage, dformat, None) + tens = self.currBasicBlock.addTensor(name, shape, dtype, None) self.currBasicBlock.addOutput(name) return tens - def addOperator(self, op, inputs, outputs, attributes = None, quant_info = None): + def addOperator(self, op, inputs, outputs, attributes=None, quant_info=None): - if op == tosa.Op.Op().PLACEHOLDER or \ - op == tosa.Op.Op().CONST: - raise Exception('Use addPlaceholderTensor() or addConstTensor() to add PLACEHOLDER and CONST ops') + if op == tosa.Op.Op().PLACEHOLDER or op == tosa.Op.Op().CONST: + raise Exception( + "Use addPlaceholderTensor() or addConstTensor() to add PLACEHOLDER and CONST ops" + ) - return self.currBasicBlock.addOperator(op, inputs, outputs, attributes, quant_info) + return self.currBasicBlock.addOperator( + op, inputs, outputs, attributes, quant_info + ) - def setExpectedFailure(self, desc='', val=True): + def setExpectedFailure(self, desc="", val=True): self.expectedFailure = val self.expectedFailureDesc = desc - def setExpectedFailure(self, desc='', val=True): + def setExpectedFailure(self, desc="", val=True): self.expectedFailure = val self.expectedFailureDesc = desc @@ -635,7 +634,9 @@ class TosaSerializer: Version.VersionAdd_experimental(builder, TOSA_VERSION[3]) version = Version.VersionEnd(builder) - fbv_bb = TosaSerializer.serializeObjVec(builder, self.basicBlocks, TosaGraph.TosaGraphStartBlocksVector) + fbv_bb = TosaSerializer.serializeObjVec( + builder, self.basicBlocks, TosaGraph.TosaGraphStartBlocksVector + ) TosaGraph.TosaGraphStart(builder) TosaGraph.TosaGraphAddVersion(builder, version) @@ -646,11 +647,11 @@ class TosaSerializer: return self.builder.Output() def writeJson(self, tosa_filename): - '''Write a json test file so that it is fairly easy to pick up the test - and generate commands for third party tool''' + """Write a json test file so that it is fairly easy to pick up the test + and generate commands for third party tool""" test_desc = dict() - test_desc['tosa_file'] = tosa_filename + test_desc["tosa_file"] = tosa_filename ifm_name = [] ifm_shape = [] ifm_file = [] @@ -659,7 +660,7 @@ class TosaSerializer: ofm_shape = [] for b in self.basicBlocks: - if b.name == 'main': + if b.name == "main": for i in b.inputs: ifm_name.append(i) ifm_shape.append(b.tensors[i].shape) @@ -669,19 +670,19 @@ class TosaSerializer: ofm_shape.append(b.tensors[o].shape) # Make up an OFM filename here. One isn't generated until the reference tool is # run, so any name is a good name - ofm_file.append('ref-{}.npy'.format(o)) - - test_desc['ifm_placeholder'] = ifm_name - test_desc['ifm_file'] = ifm_file - test_desc['ifm_shape'] = ifm_shape - test_desc['ofm_name'] = ofm_name - test_desc['ofm_shape'] = ofm_shape - test_desc['ofm_file'] = ofm_file - test_desc['expected_failure'] = self.expectedFailure + ofm_file.append("ref-{}.npy".format(o)) + + test_desc["ifm_placeholder"] = ifm_name + test_desc["ifm_file"] = ifm_file + test_desc["ifm_shape"] = ifm_shape + test_desc["ofm_name"] = ofm_name + test_desc["ofm_shape"] = ofm_shape + test_desc["ofm_file"] = ofm_file + test_desc["expected_failure"] = self.expectedFailure if self.expectedFailureDesc: - test_desc['expected_failure_desc'] = self.expectedFailureDesc + test_desc["expected_failure_desc"] = self.expectedFailureDesc - return json.dumps(test_desc, indent=' ') + return json.dumps(test_desc, indent=" ") def startBasicBlock(self, name): self.currBasicBlock = TosaSerializerBasicBlock(name) @@ -748,7 +749,9 @@ class TosaSerializer: # Store the version as a global variable so that it only needs to be # generated once per process. global TOSA_VERSION - TOSA_VERSION = [root.Version()._major(), - root.Version()._minor(), - root.Version()._patch(), - root.Version()._experimental() ] + TOSA_VERSION = [ + root.Version()._major(), + root.Version()._minor(), + root.Version()._patch(), + root.Version()._experimental(), + ] |