path: root/verif/tosa_serializer.py
author    Kevin Cheng <kevin.cheng@arm.com>    2021-03-03 11:21:43 -0800
committer Kevin Cheng <kevin.cheng@arm.com>    2021-04-27 16:01:59 -0700
commit    550ccc52de231621c0bf0c05ae2a398eec37ff51 (patch)
tree      d4a5bd8d24560135784208c0fe35615b1d043249 /verif/tosa_serializer.py
parent    cf6224e6e8ba4fc2984de3e542538c38e27c9f57 (diff)
download  reference_model-550ccc52de231621c0bf0c05ae2a398eec37ff51.tar.gz
Replace serialization/ and verif/ with MLPlatform's serialization_lib submodule
- Remove Usage and Format
- Run black on verif/*.py scripts

Signed-off-by: Kevin Cheng <kevin.cheng@arm.com>
Change-Id: Ie81515891eb0039540f614894f4b6b0e0e78ba74
Diffstat (limited to 'verif/tosa_serializer.py')
-rw-r--r--    verif/tosa_serializer.py    405
1 file changed, 204 insertions, 201 deletions
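For callers of the serializer, the visible effect of removing Usage and Format is that the tensor-creation helpers (addPlaceholder, addConst, addIntermediate, addOutput, addTensor) lose their usage/dformat arguments, as the signatures in the diff below show. A minimal sketch of the updated call pattern follows; it is illustrative only and assumes verif/ and the flatc-generated tosa package are importable and that an output directory ./out already exists.

    # Illustrative sketch, not part of this commit.
    import numpy as np
    import tosa_serializer as ts
    from tosa import DType

    ser = ts.TosaSerializer("./out")  # .npy files are written under this path
    vals = np.zeros([1, 8, 8, 4], dtype=np.int8)

    # Before this commit: ser.addPlaceholder(shape, dtype, usage, dformat, vals)
    # After it, the usage/format arguments are gone:
    t = ser.addPlaceholder([1, 8, 8, 4], DType.DType().INT8, vals)
    out = ser.addOutput(t.shape, t.dtype)  # addOutput likewise drops usage/dformat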
diff --git a/verif/tosa_serializer.py b/verif/tosa_serializer.py
index 136f7aa..fa1fdcb 100644
--- a/verif/tosa_serializer.py
+++ b/verif/tosa_serializer.py
@@ -1,5 +1,3 @@
-
-
# Copyright (c) 2020-2021, ARM Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -16,37 +14,57 @@
#!/usr/bin/env python3
+import os
+import sys
+import json
import flatbuffers
import numpy as np
from enum import Enum, IntEnum, unique
-from tosa import TosaGraph, TosaBasicBlock, TosaTensor, TosaOperator, DType, Format, Usage, Op, ResizeMode, Version
+from tosa import (
+ TosaGraph,
+ TosaBasicBlock,
+ TosaTensor,
+ TosaOperator,
+ DType,
+ Op,
+ ResizeMode,
+ Version,
+)
+
+# Include the ../thirdparty/serialization_lib/python directory in PYTHONPATH
+parent_dir = os.path.dirname(os.path.realpath(__file__))
+sys.path.append(
+ os.path.join(parent_dir, "..", "thirdparty", "serialization_lib", "python")
+)
import tosa
-import os
-import json
# With the way flatc generates its python types, there is no programmatic way
# to get string names for the integer types. Manually maintain a string table
# here.
-DTypeNames = [ 'UNKNOWN',
- 'BOOL',
- 'UINT8',
- 'INT4',
- 'INT8',
- 'INT16',
- 'INT32',
- 'INT48',
- 'FLOAT' ]
+DTypeNames = [
+ "UNKNOWN",
+ "BOOL",
+ "UINT8",
+ "INT4",
+ "INT8",
+ "INT16",
+ "INT32",
+ "INT48",
+ "FLOAT",
+]
+
def dtype_str_to_val(name):
for i in range(len(DTypeNames)):
if name.casefold() == DTypeNames[i].casefold():
return i
- raise Exception('Unable to parse DType name {}'.format(name))
+ raise Exception("Unable to parse DType name {}".format(name))
class TosaSerializerUnion:
- '''This class handles encapsulating and serializing union types into flatbuffers'''
+ """This class handles encapsulating and serializing union types into flatbuffers"""
+
def __init__(self):
# A tuple of the start and end functions. Set by the options constructors below
@@ -105,8 +123,9 @@ class TosaSerializerUnion:
return endFcn(builder)
+
class TosaSerializerAttribute(TosaSerializerUnion):
- '''This class handles encapsulating all of the enumerated types for attributes'''
+ """This class handles encapsulating all of the enumerated types for attributes"""
def __init__(self):
super().__init__()
@@ -117,12 +136,9 @@ class TosaSerializerAttribute(TosaSerializerUnion):
self.utype = Attribute.Attribute().Pool2dAttribute
self.optFcns = (a.Pool2dAttributeStart, a.Pool2dAttributeEnd)
- self.intvecs.append((a.Pool2dAttributeAddPadding,
- padding))
- self.intvecs.append((a.Pool2dAttributeAddKernel,
- kernel))
- self.intvecs.append((a.Pool2dAttributeAddStride,
- stride))
+ self.intvecs.append((a.Pool2dAttributeAddPadding, padding))
+ self.intvecs.append((a.Pool2dAttributeAddKernel, kernel))
+ self.intvecs.append((a.Pool2dAttributeAddStride, stride))
def Conv2dAttribute(self, padding, stride, dilation):
from tosa import Conv2dAttribute as a, Attribute
@@ -130,12 +146,9 @@ class TosaSerializerAttribute(TosaSerializerUnion):
self.utype = Attribute.Attribute().Conv2dAttribute
self.optFcns = (a.Conv2dAttributeStart, a.Conv2dAttributeEnd)
- self.intvecs.append((a.Conv2dAttributeAddPadding,
- padding))
- self.intvecs.append((a.Conv2dAttributeAddStride,
- stride))
- self.intvecs.append((a.Conv2dAttributeAddDilation,
- dilation))
+ self.intvecs.append((a.Conv2dAttributeAddPadding, padding))
+ self.intvecs.append((a.Conv2dAttributeAddStride, stride))
+ self.intvecs.append((a.Conv2dAttributeAddDilation, dilation))
def TransposeConv2DAttribute(self, outpad, stride, dilation, output_shape):
from tosa import TransposeConv2dAttribute as a, Attribute
@@ -143,14 +156,10 @@ class TosaSerializerAttribute(TosaSerializerUnion):
self.utype = Attribute.Attribute().TransposeConv2dAttribute
self.optFcns = (a.TransposeConv2dAttributeStart, a.TransposeConv2dAttributeEnd)
- self.intvecs.append((a.TransposeConv2dAttributeAddOutpad,
- outpad))
- self.intvecs.append((a.TransposeConv2dAttributeAddStride,
- stride))
- self.intvecs.append((a.TransposeConv2dAttributeAddDilation,
- dilation))
- self.intvecs.append((a.TransposeConv2dAttributeAddOutputShape,
- output_shape))
+ self.intvecs.append((a.TransposeConv2dAttributeAddOutpad, outpad))
+ self.intvecs.append((a.TransposeConv2dAttributeAddStride, stride))
+ self.intvecs.append((a.TransposeConv2dAttributeAddDilation, dilation))
+ self.intvecs.append((a.TransposeConv2dAttributeAddOutputShape, output_shape))
def ReluNAttribute(self, maxint, maxfp):
from tosa import ReluNAttribute as a, Attribute
@@ -161,15 +170,13 @@ class TosaSerializerAttribute(TosaSerializerUnion):
self.ints.append((a.ReluNAttributeAddMaxInt, maxint))
self.ints.append((a.ReluNAttributeAddMaxFp, maxfp))
-
def AxisAttribute(self, axis):
from tosa import AxisAttribute as a, Attribute
self.utype = Attribute.Attribute().AxisAttribute
self.optFcns = (a.AxisAttributeStart, a.AxisAttributeEnd)
- self.ints.append((a.AxisAttributeAddAxis,
- axis))
+ self.ints.append((a.AxisAttributeAddAxis, axis))
def ReshapeAttribute(self, shape):
from tosa import ReshapeAttribute as a, Attribute
@@ -177,8 +184,7 @@ class TosaSerializerAttribute(TosaSerializerUnion):
self.utype = Attribute.Attribute().ReshapeAttribute
self.optFcns = (a.ReshapeAttributeStart, a.ReshapeAttributeEnd)
- self.intvecs.append((a.ReshapeAttributeAddShape,
- shape))
+ self.intvecs.append((a.ReshapeAttributeAddShape, shape))
def SliceAttribute(self, begin, size):
from tosa import SliceAttribute as a, Attribute
@@ -186,10 +192,8 @@ class TosaSerializerAttribute(TosaSerializerUnion):
self.utype = Attribute.Attribute().SliceAttribute
self.optFcns = (a.SliceAttributeStart, a.SliceAttributeEnd)
- self.intvecs.append((a.SliceAttributeAddBegin,
- begin))
- self.intvecs.append((a.SliceAttributeAddSize,
- size))
+ self.intvecs.append((a.SliceAttributeAddBegin, begin))
+ self.intvecs.append((a.SliceAttributeAddSize, size))
def TileAttribute(self, multiples):
from tosa import TileAttribute as a, Attribute
@@ -197,29 +201,23 @@ class TosaSerializerAttribute(TosaSerializerUnion):
self.utype = Attribute.Attribute().TileAttribute
self.optFcns = (a.TileAttributeStart, a.TileAttributeEnd)
- self.intvecs.append((a.TileAttributeAddMultiples,
- multiples))
+ self.intvecs.append((a.TileAttributeAddMultiples, multiples))
- def ResizeAttribute(self, output_size, stride, offset, shift, stride_fp, offset_fp, mode):
+ def ResizeAttribute(
+ self, output_size, stride, offset, shift, stride_fp, offset_fp, mode
+ ):
from tosa import ResizeAttribute as a, Attribute
self.utype = Attribute.Attribute().ResizeAttribute
self.optFcns = (a.ResizeAttributeStart, a.ResizeAttributeEnd)
- self.intvecs.append((a.ResizeAttributeAddOutputSize,
- output_size))
- self.intvecs.append((a.ResizeAttributeAddStride,
- stride))
- self.intvecs.append((a.ResizeAttributeAddOffset,
- offset))
- self.ints.append((a.ResizeAttributeAddShift,
- shift))
- self.fpvecs.append((a.ResizeAttributeAddStrideFp,
- stride_fp))
- self.fpvecs.append((a.ResizeAttributeAddOffsetFp,
- offset_fp))
- self.ints.append((a.ResizeAttributeAddMode,
- mode))
+ self.intvecs.append((a.ResizeAttributeAddOutputSize, output_size))
+ self.intvecs.append((a.ResizeAttributeAddStride, stride))
+ self.intvecs.append((a.ResizeAttributeAddOffset, offset))
+ self.ints.append((a.ResizeAttributeAddShift, shift))
+ self.fpvecs.append((a.ResizeAttributeAddStrideFp, stride_fp))
+ self.fpvecs.append((a.ResizeAttributeAddOffsetFp, offset_fp))
+ self.ints.append((a.ResizeAttributeAddMode, mode))
def ClampAttribute(self, minint, maxint, minfp, maxfp):
from tosa import ClampAttribute as a, Attribute
@@ -227,36 +225,27 @@ class TosaSerializerAttribute(TosaSerializerUnion):
self.utype = Attribute.Attribute().ClampAttribute
self.optFcns = (a.ClampAttributeStart, a.ClampAttributeEnd)
- self.ints.append((a.ClampAttributeAddMinInt,
- minint))
- self.ints.append((a.ClampAttributeAddMaxInt,
- maxint))
+ self.ints.append((a.ClampAttributeAddMinInt, minint))
+ self.ints.append((a.ClampAttributeAddMaxInt, maxint))
- self.ints.append((a.ClampAttributeAddMinFp,
- minfp))
- self.ints.append((a.ClampAttributeAddMaxFp,
- maxfp))
+ self.ints.append((a.ClampAttributeAddMinFp, minfp))
+ self.ints.append((a.ClampAttributeAddMaxFp, maxfp))
- def RescaleAttribute(self, input_zp, output_zp, multiplier, shift, scale32, double_round, per_channel):
+ def RescaleAttribute(
+ self, input_zp, output_zp, multiplier, shift, scale32, double_round, per_channel
+ ):
from tosa import RescaleAttribute as a, Attribute
self.utype = Attribute.Attribute().RescaleAttribute
self.optFcns = (a.RescaleAttributeStart, a.RescaleAttributeEnd)
- self.ints.append((a.RescaleAttributeAddInputZp,
- input_zp))
- self.ints.append((a.RescaleAttributeAddOutputZp,
- output_zp))
- self.intvecs.append((a.RescaleAttributeAddMultiplier,
- multiplier))
- self.intvecs.append((a.RescaleAttributeAddShift,
- shift))
- self.bools.append((a.RescaleAttributeAddScale32,
- scale32))
- self.bools.append((a.RescaleAttributeAddDoubleRound,
- double_round))
- self.bools.append((a.RescaleAttributeAddPerChannel,
- per_channel))
+ self.ints.append((a.RescaleAttributeAddInputZp, input_zp))
+ self.ints.append((a.RescaleAttributeAddOutputZp, output_zp))
+ self.intvecs.append((a.RescaleAttributeAddMultiplier, multiplier))
+ self.intvecs.append((a.RescaleAttributeAddShift, shift))
+ self.bools.append((a.RescaleAttributeAddScale32, scale32))
+ self.bools.append((a.RescaleAttributeAddDoubleRound, double_round))
+ self.bools.append((a.RescaleAttributeAddPerChannel, per_channel))
def MulAttribute(self, shift):
from tosa import MulAttribute as a, Attribute
@@ -264,17 +253,18 @@ class TosaSerializerAttribute(TosaSerializerUnion):
self.utype = Attribute.Attribute().MulAttribute
self.optFcns = (a.MulAttributeStart, a.MulAttributeEnd)
- self.ints.append((a.MulAttributeAddShift,
- shift))
+ self.ints.append((a.MulAttributeAddShift, shift))
def ArithmeticRightShiftAttribute(self, round):
from tosa import ArithmeticRightShiftAttribute as a, Attribute
self.utype = Attribute.Attribute().ArithmeticRightShiftAttribute
- self.optFcns = (a.ArithmeticRightShiftAttributeStart, a.ArithmeticRightShiftAttributeEnd)
+ self.optFcns = (
+ a.ArithmeticRightShiftAttributeStart,
+ a.ArithmeticRightShiftAttributeEnd,
+ )
- self.bools.append((a.ArithmeticRightShiftAttributeAddRound,
- round))
+ self.bools.append((a.ArithmeticRightShiftAttributeAddRound, round))
def CustomAttribute(self, identifier):
from tosa import CustomAttribute as a, Attribute
@@ -282,8 +272,7 @@ class TosaSerializerAttribute(TosaSerializerUnion):
self.utype = Attribute.Attribute().CustomAttribute
self.optFcns = (a.CustomAttributeStart, a.CustomAttributeEnd)
- self.strings.append((a.CustomAttributeAddIdentifier,
- identifier))
+ self.strings.append((a.CustomAttributeAddIdentifier, identifier))
def CondIfAttribute(self, then_branch, else_branch):
from tosa import CondIfAttribute as a, Attribute
@@ -291,10 +280,8 @@ class TosaSerializerAttribute(TosaSerializerUnion):
self.utype = Attribute.Attribute().CondIfAttribute
self.optFcns = (a.CondIfAttributeStart, a.CondIfAttributeEnd)
- self.strings.append((a.CondIfAttributeAddThenBranch,
- then_branch))
- self.strings.append((a.CondIfAttributeAddElseBranch,
- else_branch))
+ self.strings.append((a.CondIfAttributeAddThenBranch, then_branch))
+ self.strings.append((a.CondIfAttributeAddElseBranch, else_branch))
def WhileLoopAttribute(self, cond_branch, body_branch):
from tosa import WhileLoopAttribute as a, Attribute
@@ -302,13 +289,13 @@ class TosaSerializerAttribute(TosaSerializerUnion):
self.utype = Attribute.Attribute().WhileLoopAttribute
self.optFcns = (a.WhileLoopAttributeStart, a.WhileLoopAttributeEnd)
- self.strings.append((a.WhileLoopAttributeAddCondBranch,
- cond_branch))
- self.strings.append((a.WhileLoopAttributeAddBodyBranch,
- body_branch))
+ self.strings.append((a.WhileLoopAttributeAddCondBranch, cond_branch))
+ self.strings.append((a.WhileLoopAttributeAddBodyBranch, body_branch))
+
class TosaSerializerQuantInfo(TosaSerializerUnion):
- '''This class handles encapsulating all of the enumerated types for quantinfo types'''
+ """This class handles encapsulating all of the enumerated types for quantinfo types"""
+
def __init__(self):
super().__init__()
@@ -343,8 +330,16 @@ class TosaSerializerQuantInfo(TosaSerializerUnion):
self.optFcns = (q.PadQuantInfoStart, q.PadQuantInfoEnd)
self.ints.append((q.PadQuantInfoAddInputZp, input_zp))
+
class TosaSerializerTensor:
- def __init__(self, name, shape, dtype, usage, dformat, filename = None, placeholderFilename = None):
+ def __init__(
+ self,
+ name,
+ shape,
+ dtype,
+ filename=None,
+ placeholderFilename=None,
+ ):
self.name = name
if isinstance(shape, np.ndarray):
@@ -353,8 +348,6 @@ class TosaSerializerTensor:
self.shape = shape
self.dtype = dtype
- self.usage = TosaSerializer.toList(usage)
- self.dformat = TosaSerializer.toList(dformat)
# Filename for const tensors. This gets written to the .tosa serialization
self.filename = filename
@@ -366,58 +359,35 @@ class TosaSerializerTensor:
self.placeholderFilename = placeholderFilename
def __str__(self):
- str = 'TosaSerializerTensor name: {} shape: {} dtype: {} Usage: {} format {} filename: {}'.format(
- self.name, self.shape, DTypeNames[self.dtype], self.usage, self.dformat, self.filename)
+ str = "TosaSerializerTensor name: {} shape: {} dtype: {} filename: {}".format(
+ self.name,
+ self.shape,
+ DTypeNames[self.dtype],
+ self.filename,
+ )
return str
- def addUsage(self, usage):
- self.usage.append(usage)
-
- def addFormat(self, format):
- self.dformat.append(format)
-
def setDtype(self, dtype):
self.dtype = dtype
- def merge(self, name, shape, dtype, usage, dformat, filename = None):
- # Merge in additional usage/formats to the list
- found = 0
- for i in self.usage:
- if i == usage:
- found = 1
- break
- if not found:
- self.usage.append(usage)
-
- found = 0
- for i in self.dformat:
- if i == dformat:
- found = 1
- break
- if not found:
- self.dformat.append(dformat)
-
def serialize(self, builder):
fb_name = builder.CreateString(self.name)
if self.filename:
fb_filename = builder.CreateString(self.filename)
fb_shapes = TosaSerializer.serializeInt32Vec(builder, self.shape)
- fb_usage = TosaSerializer.serializeInt32Vec(builder, self.usage)
- fb_dformat = TosaSerializer.serializeInt32Vec(builder, self.dformat)
TosaTensor.TosaTensorStart(builder)
TosaTensor.TosaTensorAddName(builder, fb_name)
TosaTensor.TosaTensorAddShape(builder, fb_shapes)
TosaTensor.TosaTensorAddType(builder, self.dtype)
- TosaTensor.TosaTensorAddUsage(builder, fb_usage)
- TosaTensor.TosaTensorAddFormat(builder, fb_dformat)
if self.filename:
TosaTensor.TosaTensorAddNpyFilename(builder, fb_filename)
return TosaTensor.TosaTensorEnd(builder)
+
class TosaSerializerOperator:
- def __init__(self, op, inputs, outputs, attributes = None, quantInfo = None):
+ def __init__(self, op, inputs, outputs, attributes=None, quantInfo=None):
self.op = op
self.attributes = attributes
self.inputs = TosaSerializer.toList(inputs)
@@ -425,18 +395,22 @@ class TosaSerializerOperator:
self.quantInfo = quantInfo
def __str__(self):
- str = 'Op {}\n----\n'.format(self.op)
+ str = "Op {}\n----\n".format(self.op)
for i in self.inputs:
- str = str + ' Input: {}\n'.format(i)
+ str = str + " Input: {}\n".format(i)
for o in self.outputs:
- str = str + ' Output: {}\n'.format(o)
+ str = str + " Output: {}\n".format(o)
return str
def serialize(self, builder):
- fb_inputs = TosaSerializer.serializeStrVec(builder, self.inputs, TosaOperator.TosaOperatorStartInputsVector)
- fb_outputs = TosaSerializer.serializeStrVec(builder, self.outputs, TosaOperator.TosaOperatorStartOutputsVector)
+ fb_inputs = TosaSerializer.serializeStrVec(
+ builder, self.inputs, TosaOperator.TosaOperatorStartInputsVector
+ )
+ fb_outputs = TosaSerializer.serializeStrVec(
+ builder, self.outputs, TosaOperator.TosaOperatorStartOutputsVector
+ )
# Need to serialize quant_info and attributes enums still
if self.attributes is not None:
fb_attributes = self.attributes.serialize(builder)
@@ -457,6 +431,7 @@ class TosaSerializerOperator:
return TosaOperator.TosaOperatorEnd(builder)
+
class TosaSerializerBasicBlock:
def __init__(self, name):
self.name = name
@@ -468,14 +443,21 @@ class TosaSerializerBasicBlock:
self.inputs = []
self.outputs = []
- def addTensor(self, name, shape, dtype, usage, dformat, filename = None, placeholderFilename = None):
+ def addTensor(
+ self,
+ name,
+ shape,
+ dtype,
+ filename=None,
+ placeholderFilename=None,
+ ):
try:
# Someone already added this tensor.
- # We may have to add more usages and formats
tens = self.tensors[name]
- filename = tens.merge(name, shape, dtype, usage, dformat, filename)
except KeyError:
- self.tensors[name] = TosaSerializerTensor(name, shape, dtype, usage, dformat, filename, placeholderFilename)
+ self.tensors[name] = TosaSerializerTensor(
+ name, shape, dtype, filename, placeholderFilename
+ )
return self.tensors[name]
@@ -485,15 +467,27 @@ class TosaSerializerBasicBlock:
def addOutput(self, name):
self.outputs.append(name)
- def addOperator(self, op, inputs, outputs, attributes = None, quant_info = None):
- self.operators.append(TosaSerializerOperator(op, inputs, outputs, attributes, quant_info))
+ def addOperator(self, op, inputs, outputs, attributes=None, quant_info=None):
+ self.operators.append(
+ TosaSerializerOperator(op, inputs, outputs, attributes, quant_info)
+ )
def serialize(self, builder):
fb_name = builder.CreateString(self.name)
- fbv_inputs = TosaSerializer.serializeStrVec(builder, list(self.inputs), TosaBasicBlock.TosaBasicBlockStartInputsVector)
- fbv_outputs = TosaSerializer.serializeStrVec(builder, list(self.outputs), TosaBasicBlock.TosaBasicBlockStartOutputsVector)
- fbv_tensors = TosaSerializer.serializeObjVec(builder, list(self.tensors.values()), TosaBasicBlock.TosaBasicBlockStartTensorsVector)
- fbv_operators = TosaSerializer.serializeObjVec(builder, self.operators, TosaBasicBlock.TosaBasicBlockStartOperatorsVector)
+ fbv_inputs = TosaSerializer.serializeStrVec(
+ builder, list(self.inputs), TosaBasicBlock.TosaBasicBlockStartInputsVector
+ )
+ fbv_outputs = TosaSerializer.serializeStrVec(
+ builder, list(self.outputs), TosaBasicBlock.TosaBasicBlockStartOutputsVector
+ )
+ fbv_tensors = TosaSerializer.serializeObjVec(
+ builder,
+ list(self.tensors.values()),
+ TosaBasicBlock.TosaBasicBlockStartTensorsVector,
+ )
+ fbv_operators = TosaSerializer.serializeObjVec(
+ builder, self.operators, TosaBasicBlock.TosaBasicBlockStartOperatorsVector
+ )
TosaBasicBlock.TosaBasicBlockStart(builder)
TosaBasicBlock.TosaBasicBlockAddName(builder, fb_name)
@@ -503,6 +497,7 @@ class TosaSerializerBasicBlock:
TosaBasicBlock.TosaBasicBlockAddOperators(builder, fbv_operators)
return TosaBasicBlock.TosaBasicBlockEnd(builder)
+
@unique
class TensorDir(IntEnum):
PLACEHOLDER = 0
@@ -510,6 +505,7 @@ class TensorDir(IntEnum):
INTERMEDIATE = 2
RESULT = 3
+
class TosaSerializer:
def __init__(self, pathPrefix):
@@ -522,7 +518,7 @@ class TosaSerializer:
self.builder = flatbuffers.Builder(0)
self.basicBlocks = []
- self.startBasicBlock('main')
+ self.startBasicBlock("main")
self.pathPrefix = pathPrefix
# Indicies used for adding/naming tensors
@@ -533,23 +529,23 @@ class TosaSerializer:
# Is this an illegal test that is expected to fail?
self.expectedFailure = False
- self.expectedFailureDesc = ''
+ self.expectedFailureDesc = ""
def __str__(self):
- str = ''
+ str = ""
for bb in self.basicBlocks:
str = str + bb.__str__()
return str
- def addPlaceholder(self, shape, dtype, usage, dformat, vals):
+ def addPlaceholder(self, shape, dtype, vals):
if not self.currBasicBlock:
- raise Exception('addTensor called without valid basic block')
+ raise Exception("addTensor called without valid basic block")
- name = 'input-{}'.format(self.currInputIdx)
- filename = '{}.npy'.format(name)
+ name = "input-{}".format(self.currInputIdx)
+ filename = "{}.npy".format(name)
self.currInputIdx = self.currInputIdx + 1
- tens = self.currBasicBlock.addTensor(name, shape, dtype, usage, dformat, None, filename)
+ tens = self.currBasicBlock.addTensor(name, shape, dtype, None, filename)
# This is always an input to the block
self.currBasicBlock.addInput(name)
# Add the operator now
@@ -560,15 +556,15 @@ class TosaSerializer:
return tens
- def addConst(self, shape, dtype, usage, dformat, vals):
+ def addConst(self, shape, dtype, vals):
if not self.currBasicBlock:
- raise Exception('addTensor called without valid basic block')
+ raise Exception("addTensor called without valid basic block")
- name = 'const-{}'.format(self.currInputIdx)
- filename = '{}.npy'.format(name)
+ name = "const-{}".format(self.currInputIdx)
+ filename = "{}.npy".format(name)
self.currInputIdx = self.currInputIdx + 1
- tens = self.currBasicBlock.addTensor(name, shape, dtype, usage, dformat, filename)
+ tens = self.currBasicBlock.addTensor(name, shape, dtype, filename)
# Add the operator now
self.currBasicBlock.addOperator(tosa.Op.Op().CONST, [], name)
@@ -576,51 +572,54 @@ class TosaSerializer:
np.save(os.path.join(self.pathPrefix, filename), vals, False)
return tens
- def addIntermediate(self, shape, dtype, usage, dformat):
+ def addIntermediate(self, shape, dtype):
if not self.currBasicBlock:
- raise Exception('addTensor called without valid basic block')
+ raise Exception("addTensor called without valid basic block")
- name = 'layer-{}'.format(self.currLayerIdx)
- filename = None # No file, so no filename
+ name = "layer-{}".format(self.currLayerIdx)
+ filename = None # No file, so no filename
self.currLayerIdx = self.currLayerIdx + 1
- tens = self.currBasicBlock.addTensor(name, shape, dtype, usage, dformat, filename)
+ tens = self.currBasicBlock.addTensor(name, shape, dtype, filename)
return tens
def addInputTensor(self, tensor):
self.currBasicBlock.addOperator(tosa.Op.Op().PLACEHOLDER, [], tensor.name)
- self.currBasicBlock.addTensor(tensor.name, tensor.shape, tensor.dtype, tensor.usage, tensor.dformat)
+ self.currBasicBlock.addTensor(tensor.name, tensor.shape, tensor.dtype)
self.currBasicBlock.addInput(tensor.name)
def addOutputTensor(self, tensor):
self.currBasicBlock.addOutput(tensor.name)
- def addOutput(self, shape, dtype, usage, dformat):
+ def addOutput(self, shape, dtype):
if not self.currBasicBlock:
- raise Exception('addTensor called without valid basic block')
+ raise Exception("addTensor called without valid basic block")
- name = 'result-{}'.format(self.currResultIdx)
+ name = "result-{}".format(self.currResultIdx)
self.currResultIdx = self.currResultIdx + 1
- tens = self.currBasicBlock.addTensor(name, shape, dtype, usage, dformat, None)
+ tens = self.currBasicBlock.addTensor(name, shape, dtype, None)
self.currBasicBlock.addOutput(name)
return tens
- def addOperator(self, op, inputs, outputs, attributes = None, quant_info = None):
+ def addOperator(self, op, inputs, outputs, attributes=None, quant_info=None):
- if op == tosa.Op.Op().PLACEHOLDER or \
- op == tosa.Op.Op().CONST:
- raise Exception('Use addPlaceholderTensor() or addConstTensor() to add PLACEHOLDER and CONST ops')
+ if op == tosa.Op.Op().PLACEHOLDER or op == tosa.Op.Op().CONST:
+ raise Exception(
+ "Use addPlaceholderTensor() or addConstTensor() to add PLACEHOLDER and CONST ops"
+ )
- return self.currBasicBlock.addOperator(op, inputs, outputs, attributes, quant_info)
+ return self.currBasicBlock.addOperator(
+ op, inputs, outputs, attributes, quant_info
+ )
- def setExpectedFailure(self, desc='', val=True):
+ def setExpectedFailure(self, desc="", val=True):
self.expectedFailure = val
self.expectedFailureDesc = desc
- def setExpectedFailure(self, desc='', val=True):
+ def setExpectedFailure(self, desc="", val=True):
self.expectedFailure = val
self.expectedFailureDesc = desc
@@ -635,7 +634,9 @@ class TosaSerializer:
Version.VersionAdd_experimental(builder, TOSA_VERSION[3])
version = Version.VersionEnd(builder)
- fbv_bb = TosaSerializer.serializeObjVec(builder, self.basicBlocks, TosaGraph.TosaGraphStartBlocksVector)
+ fbv_bb = TosaSerializer.serializeObjVec(
+ builder, self.basicBlocks, TosaGraph.TosaGraphStartBlocksVector
+ )
TosaGraph.TosaGraphStart(builder)
TosaGraph.TosaGraphAddVersion(builder, version)
@@ -646,11 +647,11 @@ class TosaSerializer:
return self.builder.Output()
def writeJson(self, tosa_filename):
- '''Write a json test file so that it is fairly easy to pick up the test
- and generate commands for third party tool'''
+ """Write a json test file so that it is fairly easy to pick up the test
+ and generate commands for third party tool"""
test_desc = dict()
- test_desc['tosa_file'] = tosa_filename
+ test_desc["tosa_file"] = tosa_filename
ifm_name = []
ifm_shape = []
ifm_file = []
@@ -659,7 +660,7 @@ class TosaSerializer:
ofm_shape = []
for b in self.basicBlocks:
- if b.name == 'main':
+ if b.name == "main":
for i in b.inputs:
ifm_name.append(i)
ifm_shape.append(b.tensors[i].shape)
@@ -669,19 +670,19 @@ class TosaSerializer:
ofm_shape.append(b.tensors[o].shape)
# Make up an OFM filename here. One isn't generated until the reference tool is
# run, so any name is a good name
- ofm_file.append('ref-{}.npy'.format(o))
-
- test_desc['ifm_placeholder'] = ifm_name
- test_desc['ifm_file'] = ifm_file
- test_desc['ifm_shape'] = ifm_shape
- test_desc['ofm_name'] = ofm_name
- test_desc['ofm_shape'] = ofm_shape
- test_desc['ofm_file'] = ofm_file
- test_desc['expected_failure'] = self.expectedFailure
+ ofm_file.append("ref-{}.npy".format(o))
+
+ test_desc["ifm_placeholder"] = ifm_name
+ test_desc["ifm_file"] = ifm_file
+ test_desc["ifm_shape"] = ifm_shape
+ test_desc["ofm_name"] = ofm_name
+ test_desc["ofm_shape"] = ofm_shape
+ test_desc["ofm_file"] = ofm_file
+ test_desc["expected_failure"] = self.expectedFailure
if self.expectedFailureDesc:
- test_desc['expected_failure_desc'] = self.expectedFailureDesc
+ test_desc["expected_failure_desc"] = self.expectedFailureDesc
- return json.dumps(test_desc, indent=' ')
+ return json.dumps(test_desc, indent=" ")
def startBasicBlock(self, name):
self.currBasicBlock = TosaSerializerBasicBlock(name)
@@ -748,7 +749,9 @@ class TosaSerializer:
# Store the version as a global variable so that it only needs to be
# generated once per process.
global TOSA_VERSION
- TOSA_VERSION = [root.Version()._major(),
- root.Version()._minor(),
- root.Version()._patch(),
- root.Version()._experimental() ]
+ TOSA_VERSION = [
+ root.Version()._major(),
+ root.Version()._minor(),
+ root.Version()._patch(),
+ root.Version()._experimental(),
+ ]