aboutsummaryrefslogtreecommitdiff
path: root/python/serializer/tosa_serializer.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/serializer/tosa_serializer.py')
-rw-r--r--python/serializer/tosa_serializer.py782
1 files changed, 287 insertions, 495 deletions
diff --git a/python/serializer/tosa_serializer.py b/python/serializer/tosa_serializer.py
index ec1c12d..9658edf 100644
--- a/python/serializer/tosa_serializer.py
+++ b/python/serializer/tosa_serializer.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2020-2022, ARM Limited.
+# Copyright (c) 2020-2024, ARM Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,13 +13,14 @@
# limitations under the License.
import os
+import serializer.tosa_serializer as ts
import json
import flatbuffers
import numpy as np
-import struct
from enum import IntEnum, unique
from tosa import (
TosaGraph,
+ TosaRegion,
TosaBasicBlock,
TosaTensor,
TosaOperator,
@@ -30,7 +31,7 @@ import tosa.Op as TosaOp
# Keep version number in sync with the version default value with schema/tosa.fbs
TOSA_VERSION_MAJOR = 0
-TOSA_VERSION_MINOR = 31
+TOSA_VERSION_MINOR = 100
TOSA_VERSION_PATCH = 0
TOSA_VERSION_DRAFT = True
TOSA_VERSION = [
@@ -56,8 +57,13 @@ DTypeNames = [
"INT16",
"INT32",
"INT48",
- "FLOAT",
+ "FP32",
"UINT16",
+ "FP16",
+ "BF16",
+ "SHAPE",
+ "FP8E4M3",
+ "FP8E5M2",
]
ByteMask = np.uint64(0xFF)
@@ -90,6 +96,7 @@ class TosaSerializerUnion:
self.bools = []
self.floats = []
self.strings = []
+ self.int16vecs = []
self.intvecs = []
self.fpvecs = []
@@ -106,6 +113,9 @@ class TosaSerializerUnion:
for fcn, val in self.intvecs:
intVecList.append((fcn, TosaSerializer.serializeInt32Vec(builder, val)))
+ for fcn, val in self.int16vecs:
+ intVecList.append((fcn, TosaSerializer.serializeInt16Vec(builder, val)))
+
for fcn, val in self.fpvecs:
fpVecList.append((fcn, TosaSerializer.serializeFpVec(builder, val)))
@@ -141,7 +151,15 @@ class TosaSerializerAttribute(TosaSerializerUnion):
def __init__(self):
super().__init__()
- def PoolAttribute(self, kernel, stride, pad, input_zp, output_zp):
+ def PoolAttribute(
+ self,
+ kernel,
+ stride,
+ pad,
+ input_zp,
+ output_zp,
+ acc_type,
+ ):
from tosa import PoolAttribute as a, Attribute
self.utype = Attribute.Attribute().PoolAttribute
@@ -152,8 +170,11 @@ class TosaSerializerAttribute(TosaSerializerUnion):
self.intvecs.append((a.AddStride, stride))
self.ints.append((a.AddInputZp, input_zp))
self.ints.append((a.AddOutputZp, output_zp))
+ self.ints.append((a.AddAccType, acc_type))
- def ConvAttribute(self, pad, stride, dilation, input_zp, weight_zp):
+ def ConvAttribute(
+ self, pad, stride, dilation, input_zp, weight_zp, local_bound, acc_type
+ ):
from tosa import ConvAttribute as a, Attribute
self.utype = Attribute.Attribute().ConvAttribute
@@ -164,8 +185,12 @@ class TosaSerializerAttribute(TosaSerializerUnion):
self.intvecs.append((a.AddDilation, dilation))
self.ints.append((a.AddInputZp, input_zp))
self.ints.append((a.AddWeightZp, weight_zp))
+ self.bools.append((a.AddLocalBound, local_bound))
+ self.ints.append((a.AddAccType, acc_type))
- def TransposeConvAttribute(self, outpad, stride, output_shape, input_zp, weight_zp):
+ def TransposeConvAttribute(
+ self, outpad, stride, output_shape, input_zp, weight_zp, local_bound, acc_type
+ ):
from tosa import TransposeConvAttribute as a, Attribute
self.utype = Attribute.Attribute().TransposeConvAttribute
@@ -176,16 +201,21 @@ class TosaSerializerAttribute(TosaSerializerUnion):
self.intvecs.append((a.AddOutputShape, output_shape))
self.ints.append((a.AddInputZp, input_zp))
self.ints.append((a.AddWeightZp, weight_zp))
+ self.bools.append((a.AddLocalBound, local_bound))
+ self.ints.append((a.AddAccType, acc_type))
- def PadAttribute(self, padding, pad_const_int, pad_const_fp):
+ def PadAttribute(self, serializer_builder, pad_const_val_as_bytes):
from tosa import PadAttribute as a, Attribute
self.utype = Attribute.Attribute().PadAttribute
self.optFcns = (a.Start, a.End)
- self.intvecs.append((a.AddPadding, padding))
- self.ints.append((a.AddPadConstInt, pad_const_int))
- self.floats.append((a.AddPadConstFp, pad_const_fp))
+ # serialize pad_const_val_as_bytes as uint8 vector
+ serialized_pad_const_val = ts.TosaSerializer.serializeUint8Vec(
+ serializer_builder, pad_const_val_as_bytes
+ )
+
+ self.floats.append((a.AddPadConst, serialized_pad_const_val))
def AxisAttribute(self, axis):
from tosa import AxisAttribute as a, Attribute
@@ -195,61 +225,43 @@ class TosaSerializerAttribute(TosaSerializerUnion):
self.ints.append((a.AddAxis, axis))
- def ReshapeAttribute(self, new_shape):
- from tosa import ReshapeAttribute as a, Attribute
-
- self.utype = Attribute.Attribute().ReshapeAttribute
- self.optFcns = (a.Start, a.End)
-
- self.intvecs.append((a.AddNewShape, new_shape))
-
- def SliceAttribute(self, start, size):
- from tosa import SliceAttribute as a, Attribute
-
- self.utype = Attribute.Attribute().SliceAttribute
- self.optFcns = (a.Start, a.End)
-
- self.intvecs.append((a.AddStart, start))
- self.intvecs.append((a.AddSize, size))
-
- def TileAttribute(self, multiples):
- from tosa import TileAttribute as a, Attribute
-
- self.utype = Attribute.Attribute().TileAttribute
- self.optFcns = (a.Start, a.End)
-
- self.intvecs.append((a.AddMultiples, multiples))
-
- def ResizeAttribute(
- self, output_size, stride, offset, shift, stride_fp, offset_fp, mode
- ):
+ def ResizeAttribute(self, scale, offset, border, mode):
from tosa import ResizeAttribute as a, Attribute
self.utype = Attribute.Attribute().ResizeAttribute
self.optFcns = (a.Start, a.End)
- self.intvecs.append((a.AddOutputSize, output_size))
- self.intvecs.append((a.AddStride, stride))
- self.intvecs.append((a.AddOffset, offset))
- self.ints.append((a.AddShift, shift))
- self.fpvecs.append((a.AddStrideFp, stride_fp))
- self.fpvecs.append((a.AddOffsetFp, offset_fp))
+ self.int16vecs.append((a.AddScale, scale))
+ self.int16vecs.append((a.AddOffset, offset))
+ self.int16vecs.append((a.AddBorder, border))
self.ints.append((a.AddMode, mode))
- def ClampAttribute(self, minint, maxint, minfp, maxfp):
+ def ClampAttribute(self, serializer_builder, min_val_as_bytes, max_val_as_bytes):
from tosa import ClampAttribute as a, Attribute
self.utype = Attribute.Attribute().ClampAttribute
self.optFcns = (a.Start, a.End)
- self.ints.append((a.AddMinInt, minint))
- self.ints.append((a.AddMaxInt, maxint))
+ # min/max float attributes serialized as uint8 vectors
+ serialized_min_val = ts.TosaSerializer.serializeUint8Vec(
+ serializer_builder, min_val_as_bytes
+ )
+ serialized_max_val = ts.TosaSerializer.serializeUint8Vec(
+ serializer_builder, max_val_as_bytes
+ )
- self.ints.append((a.AddMinFp, minfp))
- self.ints.append((a.AddMaxFp, maxfp))
+ self.floats.append((a.AddMinVal, serialized_min_val))
+ self.floats.append((a.AddMaxVal, serialized_max_val))
def RescaleAttribute(
- self, input_zp, output_zp, multiplier, shift, scale32, double_round, per_channel
+ self,
+ input_zp,
+ output_zp,
+ scale32,
+ double_round,
+ per_channel,
+ input_unsigned,
+ output_unsigned,
):
from tosa import RescaleAttribute as a, Attribute
@@ -258,11 +270,11 @@ class TosaSerializerAttribute(TosaSerializerUnion):
self.ints.append((a.AddInputZp, input_zp))
self.ints.append((a.AddOutputZp, output_zp))
- self.intvecs.append((a.AddMultiplier, multiplier))
- self.intvecs.append((a.AddShift, shift))
self.bools.append((a.AddScale32, scale32))
self.bools.append((a.AddDoubleRound, double_round))
self.bools.append((a.AddPerChannel, per_channel))
+ self.bools.append((a.AddInputUnsigned, input_unsigned))
+ self.bools.append((a.AddOutputUnsigned, output_unsigned))
def MulAttribute(self, shift):
from tosa import MulAttribute as a, Attribute
@@ -283,23 +295,23 @@ class TosaSerializerAttribute(TosaSerializerUnion):
self.bools.append((a.AddRound, round))
- def CondIfAttribute(self, then_branch, else_branch):
+ def CondIfAttribute(self, then_graph, else_graph):
from tosa import CondIfAttribute as a, Attribute
self.utype = Attribute.Attribute().CondIfAttribute
self.optFcns = (a.Start, a.End)
- self.strings.append((a.AddThenBranch, then_branch))
- self.strings.append((a.AddElseBranch, else_branch))
+ self.strings.append((a.AddThenGraph, then_graph))
+ self.strings.append((a.AddElseGraph, else_graph))
- def WhileLoopAttribute(self, cond_branch, body_branch):
+ def WhileLoopAttribute(self, cond_graph, body_graph):
from tosa import WhileLoopAttribute as a, Attribute
self.utype = Attribute.Attribute().WhileLoopAttribute
self.optFcns = (a.Start, a.End)
- self.strings.append((a.AddCondBranch, cond_branch))
- self.strings.append((a.AddBodyBranch, body_branch))
+ self.strings.append((a.AddCondGraph, cond_graph))
+ self.strings.append((a.AddBodyGraph, body_graph))
def TransposeAttribute(self, perms):
from tosa import TransposeAttribute as a, Attribute
@@ -315,7 +327,7 @@ class TosaSerializerAttribute(TosaSerializerUnion):
self.utype = Attribute.Attribute().TableAttribute
self.optFcns = (a.Start, a.End)
- self.intvecs.append((a.AddTable, table))
+ self.int16vecs.append((a.AddTable, table))
def MatMulAttribute(self, A_zp, B_zp):
from tosa import MatMulAttribute as a, Attribute
@@ -344,6 +356,23 @@ class TosaSerializerAttribute(TosaSerializerUnion):
self.ints.append((a.AddInput1Zp, input1_zp))
self.ints.append((a.AddOutputZp, output_zp))
+ def FFTAttribute(self, inverse, local_bound):
+ from tosa import FFTAttribute as a, Attribute
+
+ self.utype = Attribute.Attribute().FFTAttribute
+ self.optFcns = (a.Start, a.End)
+
+ self.bools.append((a.AddInverse, inverse))
+ self.bools.append((a.AddLocalBound, local_bound))
+
+ def RFFTAttribute(self, local_bound):
+ from tosa import RFFTAttribute as a, Attribute
+
+ self.utype = Attribute.Attribute().RFFTAttribute
+ self.optFcns = (a.Start, a.End)
+
+ self.bools.append((a.AddLocalBound, local_bound))
+
class TosaSerializerTensor:
def __init__(
@@ -363,12 +392,24 @@ class TosaSerializerTensor:
self.shape = shape
self.dtype = dtype
+ if (
+ dtype == DType.FP32
+ or dtype == DType.BF16
+ or dtype == DType.FP8E4M3
+ or dtype == DType.FP8E5M2
+ ):
+ fntype = np.float32
+ elif dtype == DType.FP16:
+ fntype = np.float16
+ else:
+ fntype = int
+
if isinstance(data, np.ndarray):
- data = data.flatten().astype(int).tolist()
- data = list(map(int, data))
+ data = data.flatten().astype(fntype).tolist()
+ data = list(map(fntype, data))
self.data = data
elif isinstance(data, list):
- data = list(map(int, data))
+ data = list(map(fntype, data))
self.data = data
else:
self.data = None
@@ -381,12 +422,12 @@ class TosaSerializerTensor:
self.placeholderFilename = placeholderFilename
def __str__(self):
- str = "TosaSerializerTensor name: {} shape: {} dtype: {}".format(
+ concatString = "TosaSerializerTensor name: {} shape: {} dtype: {}".format(
self.name,
self.shape,
DTypeNames[self.dtype],
)
- return str
+ return concatString
def setDtype(self, dtype):
self.dtype = dtype
@@ -415,17 +456,17 @@ class TosaSerializerTensor:
u8_data.append(val_u8)
elif self.dtype == DType.INT8:
for val in self.data:
- val_u8 = np.uint8(val)
+ val_u8 = np.array(val).astype(dtype=np.uint8)
u8_data.append(val_u8)
elif self.dtype == DType.INT16:
for val in self.data:
- val_u16 = np.uint16(val)
+ val_u16 = np.array(val).astype(dtype=np.uint16)
b0 = val_u16 & ByteMask
b1 = (val_u16 >> np.uint16(8)) & ByteMask
u8_data.extend([b0, b1])
elif self.dtype == DType.INT32:
for val in self.data:
- val_u32 = np.uint32(val)
+ val_u32 = np.array(val).astype(dtype=np.uint32)
b0 = val_u32 & ByteMask
b1 = (val_u32 >> np.uint32(8)) & ByteMask
b2 = (val_u32 >> np.uint32(16)) & ByteMask
@@ -441,10 +482,37 @@ class TosaSerializerTensor:
b4 = (val_u64 >> np.uint64(32)) & ByteMask
b5 = (val_u64 >> np.uint64(40)) & ByteMask
u8_data.extend([b0, b1, b2, b3, b4, b5])
- elif self.dtype == DType.FLOAT:
+ elif self.dtype == DType.SHAPE:
for val in self.data:
- b = struct.pack("!f", val)
- u8_data.extend([b[3], b[2], b[1], b[0]])
+ val_u64 = np.uint64(val)
+ b0 = val_u64 & ByteMask
+ b1 = (val_u64 >> np.uint64(8)) & ByteMask
+ b2 = (val_u64 >> np.uint64(16)) & ByteMask
+ b3 = (val_u64 >> np.uint64(24)) & ByteMask
+ b4 = (val_u64 >> np.uint64(32)) & ByteMask
+ b5 = (val_u64 >> np.uint64(40)) & ByteMask
+ b6 = (val_u64 >> np.uint64(48)) & ByteMask
+ b7 = (val_u64 >> np.uint64(56)) & ByteMask
+ u8_data.extend([b0, b1, b2, b3, b4, b5, b6, b7])
+ elif self.dtype == DType.FP16:
+ np_arr = np.array(self.data, dtype=np.float16)
+ u8_data.extend(np_arr.view(np.uint8))
+ elif (
+ self.dtype == DType.FP32
+ or self.dtype == DType.BF16
+ or self.dtype == DType.FP8E4M3
+ or self.dtype == DType.FP8E5M2
+ ):
+ # for val in self.data:
+ # b = struct.pack("!f", val)
+ # u8_data.extend([b[3], b[2], b[1], b[0]])
+ np_arr = np.array(self.data, dtype=np.float32)
+ u8_data.extend(np_arr.view(np.uint8))
+ elif self.dtype == TosaDType.DType:
+ # Serialize DType enum data as uint8 bytes
+ for val in self.data:
+ np_arr = np.array(self.data, dtype=np.uint32)
+ u8_data.extend(np_arr.view(np.uint8))
else:
raise Exception(
"unsupported data type {}".format(DTypeNames[self.dtype])
@@ -469,14 +537,14 @@ class TosaSerializerOperator:
self.outputs = TosaSerializer.toList(outputs)
def __str__(self):
- str = "Op {}\n----\n".format(self.op)
+ concatString = "Op {}\n----\n".format(self.op)
for i in self.inputs:
- str = str + " Input: {}\n".format(i)
+ concatString = concatString + " Input: {}\n".format(i)
for o in self.outputs:
- str = str + " Output: {}\n".format(o)
+ concatString = concatString + " Output: {}\n".format(o)
- return str
+ return concatString
def serialize(self, builder):
fb_inputs = TosaSerializer.serializeStrVec(
@@ -561,41 +629,39 @@ class TosaSerializerBasicBlock:
return TosaBasicBlock.End(builder)
+# How CONSTs are treated in the flatbuffer
@unique
-class TensorDir(IntEnum):
- PLACEHOLDER = 0
- CONST = 1
- INTERMEDIATE = 2
- RESULT = 3
-
+class ConstMode(IntEnum):
+ EMBED = 0
+ EMBED_DUMP = 1
+ INPUTS = 2
-class TosaSerializer:
- def __init__(self, pathPrefix):
- self.add_compat_methods()
- # Get the global TOSA version if not already defined
-
- self.builder = flatbuffers.Builder(0)
+class TosaSerializerRegion:
+ def __init__(self, name, pathPrefix, constMode=ConstMode.EMBED):
+ self.name = name
self.basicBlocks = []
- self.startBasicBlock("main")
- self.pathPrefix = pathPrefix
-
- # Indicies used for adding/naming tensors
self.currInputIdx = 0
self.currConstIdx = 0
self.currLayerIdx = 1
self.currResultIdx = 0
+ self.pathPrefix = pathPrefix
+ self.constMode = constMode
- # Is this an illegal test that is expected to fail?
- self.expectedReturnCode = 0
- self.expectedFailure = False
- self.expectedFailureDesc = ""
+ def addBasicBlock(self, name):
+ self.currBasicBlock = TosaSerializerBasicBlock(name)
+ self.basicBlocks.append(self.currBasicBlock)
- def __str__(self):
- str = ""
- for bb in self.basicBlocks:
- str = str + bb.__str__()
- return str
+ def serialize(self, builder):
+ fb_name = builder.CreateString(self.name)
+ fbv_basicBlocks = TosaSerializer.serializeObjVec(
+ builder, self.basicBlocks, TosaRegion.StartBlocksVector
+ )
+
+ TosaRegion.Start(builder)
+ TosaRegion.AddName(builder, fb_name)
+ TosaRegion.AddBlocks(builder, fbv_basicBlocks)
+ return TosaRegion.End(builder)
def addPlaceholder(self, shape, dtype, vals):
if not self.currBasicBlock:
@@ -614,21 +680,42 @@ class TosaSerializer:
return tens
- def addConst(self, shape, dtype, vals):
+ def addConst(self, shape, dtype, vals, name=None):
if not self.currBasicBlock:
raise Exception("addTensor called without valid basic block")
- name = "const-{}".format(self.currInputIdx)
- self.currInputIdx = self.currInputIdx + 1
+ if name is None:
+ name = "const-{}".format(self.currInputIdx)
+ self.currInputIdx = self.currInputIdx + 1
+
+ if self.constMode == ConstMode.INPUTS:
+ # Save const as input file
+ filename = "{}.npy".format(name)
+ tensor_vals = None
+ self.currBasicBlock.addInput(name)
+ else:
+ # Embed const in flatbuffer
+ filename = None
+ tensor_vals = vals
- tens = self.currBasicBlock.addTensor(name, shape, dtype, vals)
+ tens = self.currBasicBlock.addTensor(name, shape, dtype, tensor_vals, filename)
# Add the operator now
- self.currBasicBlock.addOperator(TosaOp.Op().CONST, [], name)
+ if dtype == DType.SHAPE:
+ self.currBasicBlock.addOperator(TosaOp.Op().CONST_SHAPE, [], name)
+ else:
+ self.currBasicBlock.addOperator(TosaOp.Op().CONST, [], name)
+
+ # Save the const data to file for debug or as input files
+ if vals is not None and self.constMode in [
+ ConstMode.EMBED_DUMP,
+ ConstMode.INPUTS,
+ ]:
+ filename = "{}.npy".format(name)
+ np.save(os.path.join(self.pathPrefix, filename), vals, False)
return tens
def addIntermediate(self, shape, dtype):
-
if not self.currBasicBlock:
raise Exception("addTensor called without valid basic block")
@@ -640,7 +727,13 @@ class TosaSerializer:
return tens
def addInputTensor(self, tensor):
- self.currBasicBlock.addTensor(tensor.name, tensor.shape, tensor.dtype)
+ self.currBasicBlock.addTensor(
+ tensor.name,
+ tensor.shape,
+ tensor.dtype,
+ tensor.data,
+ tensor.placeholderFilename,
+ )
self.currBasicBlock.addInput(tensor.name)
def addOutputTensor(self, tensor):
@@ -658,7 +751,6 @@ class TosaSerializer:
return tens
def addOperator(self, op, inputs, outputs, attributes=None):
-
if op == TosaOp.Op().CONST:
raise Exception("Use addConstTensor() to add CONST ops")
@@ -669,6 +761,62 @@ class TosaSerializer:
attributes,
)
+
+@unique
+class TensorDir(IntEnum):
+ PLACEHOLDER = 0
+ CONST = 1
+ INTERMEDIATE = 2
+ RESULT = 3
+
+
+class TosaSerializer:
+ def __init__(self, pathPrefix, constMode=ConstMode.EMBED):
+ self.builder = flatbuffers.Builder(0)
+
+ # Enables inspection of constant data outside of graph
+ self.constMode = constMode
+
+ self.regions = []
+ self.startRegion("main", pathPrefix)
+
+ self.currRegion.addBasicBlock("main")
+
+ # Is this an illegal test that is expected to fail?
+ self.expectedReturnCode = 0
+ self.expectedFailure = False
+ self.expectedFailureDesc = ""
+
+ def __str__(self):
+ concatString = ""
+ for region in self.regions:
+ concatString = concatString + str(region)
+ return concatString
+
+ def addPlaceholder(self, shape, dtype, vals):
+ return self.currRegion.addPlaceholder(shape, dtype, vals)
+
+ def addConst(self, shape, dtype, vals, name=None):
+ return self.currRegion.addConst(shape, dtype, vals, name)
+
+ def addIntermediate(self, shape, dtype):
+ return self.currRegion.addIntermediate(shape, dtype)
+
+ def addInputTensor(self, tensor):
+ self.currRegion.addInputTensor(tensor)
+
+ def addOutputTensor(self, tensor):
+ self.currRegion.addOutputTensor(tensor)
+
+ def addOutput(self, shape, dtype):
+ return self.currRegion.addOutput(shape, dtype)
+
+ def addOperator(self, op, inputs, outputs, attributes=None):
+ return self.currRegion.addOperator(op, inputs, outputs, attributes)
+
+ def addBasicBlock(self, name):
+ self.currRegion.addBasicBlock(name)
+
def setExpectedReturnCode(self, val, fail, desc=""):
self.expectedReturnCode = val
@@ -680,19 +828,19 @@ class TosaSerializer:
builder = self.builder
Version.Start(builder)
- Version.Add_major(builder, TOSA_VERSION[0])
- Version.Add_minor(builder, TOSA_VERSION[1])
- Version.Add_patch(builder, TOSA_VERSION[2])
- Version.Add_draft(builder, TOSA_VERSION[3])
+ Version.Add_Major(builder, TOSA_VERSION[0])
+ Version.Add_Minor(builder, TOSA_VERSION[1])
+ Version.Add_Patch(builder, TOSA_VERSION[2])
+ Version.Add_Draft(builder, TOSA_VERSION[3])
version = Version.End(builder)
- fbv_bb = TosaSerializer.serializeObjVec(
- builder, self.basicBlocks, TosaGraph.StartBlocksVector
+ fbv_region = TosaSerializer.serializeObjVec(
+ builder, self.regions, TosaGraph.StartRegionsVector
)
TosaGraph.Start(builder)
TosaGraph.AddVersion(builder, version)
- TosaGraph.AddBlocks(builder, fbv_bb)
+ TosaGraph.AddRegions(builder, fbv_region)
graph = TosaGraph.End(builder)
self.builder.Finish(graph, TOSA_GRAPH_IDENTIFIER)
@@ -709,16 +857,17 @@ class TosaSerializer:
ofm_name = []
ofm_file = []
- for b in self.basicBlocks:
- if b.name == "main":
- for i in b.inputs:
- ifm_name.append(i)
- ifm_file.append(b.tensors[i].placeholderFilename)
- for o in b.outputs:
- ofm_name.append(o)
- # Make up an OFM filename here. One isn't generated until the
- # reference tool is run, so any name is a good name
- ofm_file.append("ref-{}.npy".format(o))
+ for region in self.regions:
+ for block in region.basicBlocks:
+ if block and block.name == "main":
+ for i in block.inputs:
+ ifm_name.append(i)
+ ifm_file.append(block.tensors[i].placeholderFilename)
+ for o in block.outputs:
+ ofm_name.append(o)
+ # Make up an OFM filename here. One isn't generated until the
+ # reference tool is run, so any name is a good name
+ ofm_file.append("ref-{}.npy".format(o))
test_desc["ifm_name"] = ifm_name
test_desc["ifm_file"] = ifm_file
@@ -731,9 +880,9 @@ class TosaSerializer:
return json.dumps(test_desc, indent=" ")
- def startBasicBlock(self, name):
- self.currBasicBlock = TosaSerializerBasicBlock(name)
- self.basicBlocks.append(self.currBasicBlock)
+ def startRegion(self, name, pathPrefix):
+ self.currRegion = TosaSerializerRegion(name, pathPrefix, self.constMode)
+ self.regions.append(self.currRegion)
@staticmethod
def serializeStrVec(builder, vec, start_fcn):
@@ -757,6 +906,16 @@ class TosaSerializer:
return builder.EndVector(len(vec))
@staticmethod
+ def serializeInt16Vec(builder, vec):
+ builder.StartVector(2, len(vec), 4)
+ for v in vec[::-1]:
+ builder.PrependInt16(v)
+ try:
+ return builder.EndVector()
+ except TypeError:
+ return builder.EndVector(len(vec))
+
+ @staticmethod
def serializeInt32Vec(builder, vec):
builder.StartVector(4, len(vec), 4)
for v in vec[::-1]:
@@ -796,370 +955,3 @@ class TosaSerializer:
return val
else:
return [val]
-
- # Remove when switching to flatbuffers 2.0
- # contains a mapping of the deprecated 1.12 method to the 2.0 version
-
- def add_compat_methods(self):
-
- from tosa import ArithmeticRightShiftAttribute
-
- if not hasattr(ArithmeticRightShiftAttribute, "Start"):
- ArithmeticRightShiftAttribute.Start = (
- ArithmeticRightShiftAttribute.ArithmeticRightShiftAttributeStart
- )
- ArithmeticRightShiftAttribute.AddRound = (
- ArithmeticRightShiftAttribute.ArithmeticRightShiftAttributeAddRound
- )
- ArithmeticRightShiftAttribute.End = (
- ArithmeticRightShiftAttribute.ArithmeticRightShiftAttributeEnd
- )
- from tosa import AxisAttribute
-
- if not hasattr(AxisAttribute, "Start"):
- AxisAttribute.Start = AxisAttribute.AxisAttributeStart
- AxisAttribute.AddAxis = AxisAttribute.AxisAttributeAddAxis
- AxisAttribute.End = AxisAttribute.AxisAttributeEnd
- from tosa import ClampAttribute
-
- if not hasattr(ClampAttribute, "Start"):
- ClampAttribute.Start = ClampAttribute.ClampAttributeStart
- ClampAttribute.AddMinInt = ClampAttribute.ClampAttributeAddMinInt
- ClampAttribute.AddMaxInt = ClampAttribute.ClampAttributeAddMaxInt
- ClampAttribute.AddMinFp = ClampAttribute.ClampAttributeAddMinFp
- ClampAttribute.AddMaxFp = ClampAttribute.ClampAttributeAddMaxFp
- ClampAttribute.End = ClampAttribute.ClampAttributeEnd
- from tosa import CondIfAttribute
-
- if not hasattr(CondIfAttribute, "Start"):
- CondIfAttribute.Start = CondIfAttribute.CondIfAttributeStart
- CondIfAttribute.AddThenBranch = CondIfAttribute.CondIfAttributeAddThenBranch
- CondIfAttribute.AddElseBranch = CondIfAttribute.CondIfAttributeAddElseBranch
- CondIfAttribute.End = CondIfAttribute.CondIfAttributeEnd
- from tosa import ConvAttribute
-
- if not hasattr(ConvAttribute, "Start"):
- ConvAttribute.Start = ConvAttribute.ConvAttributeStart
- ConvAttribute.AddPad = ConvAttribute.ConvAttributeAddPad
- ConvAttribute.StartPadVector = ConvAttribute.ConvAttributeStartPadVector
- ConvAttribute.AddStride = ConvAttribute.ConvAttributeAddStride
- ConvAttribute.StartStrideVector = (
- ConvAttribute.ConvAttributeStartStrideVector
- )
- ConvAttribute.AddDilation = ConvAttribute.ConvAttributeAddDilation
- ConvAttribute.StartDilationVector = (
- ConvAttribute.ConvAttributeStartDilationVector
- )
- ConvAttribute.AddInputZp = ConvAttribute.ConvAttributeAddInputZp
- ConvAttribute.AddWeightZp = ConvAttribute.ConvAttributeAddWeightZp
- ConvAttribute.End = ConvAttribute.ConvAttributeEnd
- from tosa import FullyConnectedAttribute
-
- if not hasattr(FullyConnectedAttribute, "Start"):
- FullyConnectedAttribute.Start = (
- FullyConnectedAttribute.FullyConnectedAttributeStart
- )
- FullyConnectedAttribute.AddInputZp = (
- FullyConnectedAttribute.FullyConnectedAttributeAddInputZp
- )
- FullyConnectedAttribute.AddWeightZp = (
- FullyConnectedAttribute.FullyConnectedAttributeAddWeightZp
- )
- FullyConnectedAttribute.End = (
- FullyConnectedAttribute.FullyConnectedAttributeEnd
- )
- from tosa import MatMulAttribute
-
- if not hasattr(MatMulAttribute, "Start"):
- MatMulAttribute.Start = MatMulAttribute.MatMulAttributeStart
- MatMulAttribute.AddAZp = MatMulAttribute.MatMulAttributeAddAZp
- MatMulAttribute.AddBZp = MatMulAttribute.MatMulAttributeAddBZp
- MatMulAttribute.End = MatMulAttribute.MatMulAttributeEnd
- from tosa import PoolAttribute
-
- if not hasattr(PoolAttribute, "Start"):
- PoolAttribute.Start = PoolAttribute.PoolAttributeStart
- PoolAttribute.AddPad = PoolAttribute.PoolAttributeAddPad
- PoolAttribute.StartPadVector = PoolAttribute.PoolAttributeStartPadVector
- PoolAttribute.AddKernel = PoolAttribute.PoolAttributeAddKernel
- PoolAttribute.StartKernelVector = (
- PoolAttribute.PoolAttributeStartKernelVector
- )
- PoolAttribute.AddStride = PoolAttribute.PoolAttributeAddStride
- PoolAttribute.StartStrideVector = (
- PoolAttribute.PoolAttributeStartStrideVector
- )
- PoolAttribute.AddInputZp = PoolAttribute.PoolAttributeAddInputZp
- PoolAttribute.AddOutputZp = PoolAttribute.PoolAttributeAddOutputZp
- PoolAttribute.End = PoolAttribute.PoolAttributeEnd
- from tosa import MulAttribute
-
- if not hasattr(MulAttribute, "Start"):
- MulAttribute.Start = MulAttribute.MulAttributeStart
- MulAttribute.AddShift = MulAttribute.MulAttributeAddShift
- MulAttribute.End = MulAttribute.MulAttributeEnd
- from tosa import PadAttribute
-
- if not hasattr(PadAttribute, "Start"):
- PadAttribute.Start = PadAttribute.PadAttributeStart
- PadAttribute.AddPadding = PadAttribute.PadAttributeAddPadding
- PadAttribute.StartPaddingVector = (
- PadAttribute.PadAttributeStartPaddingVector
- )
- PadAttribute.AddPadConstInt = PadAttribute.PadAttributeAddPadConstInt
- PadAttribute.AddPadConstFp = PadAttribute.PadAttributeAddPadConstFp
- PadAttribute.End = PadAttribute.PadAttributeEnd
- from tosa import PoolAttribute
-
- if not hasattr(PoolAttribute, "Start"):
- PoolAttribute.Start = PoolAttribute.PoolAttributeStart
- PoolAttribute.AddPad = PoolAttribute.PoolAttributeAddPad
- PoolAttribute.StartPadVector = PoolAttribute.PoolAttributeStartPadVector
- PoolAttribute.AddKernel = PoolAttribute.PoolAttributeAddKernel
- PoolAttribute.StartKernelVector = (
- PoolAttribute.PoolAttributeStartKernelVector
- )
- PoolAttribute.AddStride = PoolAttribute.PoolAttributeAddStride
- PoolAttribute.StartStrideVector = (
- PoolAttribute.PoolAttributeStartStrideVector
- )
- PoolAttribute.AddInputZp = PoolAttribute.PoolAttributeAddInputZp
- PoolAttribute.AddOutputZp = PoolAttribute.PoolAttributeAddOutputZp
- PoolAttribute.End = PoolAttribute.PoolAttributeEnd
- from tosa import RescaleAttribute
-
- if not hasattr(RescaleAttribute, "Start"):
- RescaleAttribute.Start = RescaleAttribute.RescaleAttributeStart
- RescaleAttribute.AddInputZp = RescaleAttribute.RescaleAttributeAddInputZp
- RescaleAttribute.AddOutputZp = RescaleAttribute.RescaleAttributeAddOutputZp
- RescaleAttribute.AddMultiplier = (
- RescaleAttribute.RescaleAttributeAddMultiplier
- )
- RescaleAttribute.StartMultiplierVector = (
- RescaleAttribute.RescaleAttributeStartMultiplierVector
- )
- RescaleAttribute.AddShift = RescaleAttribute.RescaleAttributeAddShift
- RescaleAttribute.StartShiftVector = (
- RescaleAttribute.RescaleAttributeStartShiftVector
- )
- RescaleAttribute.AddScale32 = RescaleAttribute.RescaleAttributeAddScale32
- RescaleAttribute.AddDoubleRound = (
- RescaleAttribute.RescaleAttributeAddDoubleRound
- )
- RescaleAttribute.AddPerChannel = (
- RescaleAttribute.RescaleAttributeAddPerChannel
- )
- RescaleAttribute.End = RescaleAttribute.RescaleAttributeEnd
- from tosa import ReshapeAttribute
-
- if not hasattr(ReshapeAttribute, "Start"):
- ReshapeAttribute.Start = ReshapeAttribute.ReshapeAttributeStart
- ReshapeAttribute.AddNewShape = ReshapeAttribute.ReshapeAttributeAddNewShape
- ReshapeAttribute.StartNewShapeVector = (
- ReshapeAttribute.ReshapeAttributeStartNewShapeVector
- )
- ReshapeAttribute.End = ReshapeAttribute.ReshapeAttributeEnd
- from tosa import ResizeAttribute
-
- if not hasattr(ResizeAttribute, "Start"):
- ResizeAttribute.Start = ResizeAttribute.ResizeAttributeStart
- ResizeAttribute.AddOutputSize = ResizeAttribute.ResizeAttributeAddOutputSize
- ResizeAttribute.StartOutputSizeVector = (
- ResizeAttribute.ResizeAttributeStartOutputSizeVector
- )
- ResizeAttribute.AddStride = ResizeAttribute.ResizeAttributeAddStride
- ResizeAttribute.StartStrideVector = (
- ResizeAttribute.ResizeAttributeStartStrideVector
- )
- ResizeAttribute.AddOffset = ResizeAttribute.ResizeAttributeAddOffset
- ResizeAttribute.StartOffsetVector = (
- ResizeAttribute.ResizeAttributeStartOffsetVector
- )
- ResizeAttribute.AddShift = ResizeAttribute.ResizeAttributeAddShift
- ResizeAttribute.AddStrideFp = ResizeAttribute.ResizeAttributeAddStrideFp
- ResizeAttribute.StartStrideFpVector = (
- ResizeAttribute.ResizeAttributeStartStrideFpVector
- )
- ResizeAttribute.AddOffsetFp = ResizeAttribute.ResizeAttributeAddOffsetFp
- ResizeAttribute.StartOffsetFpVector = (
- ResizeAttribute.ResizeAttributeStartOffsetFpVector
- )
- ResizeAttribute.AddMode = ResizeAttribute.ResizeAttributeAddMode
- ResizeAttribute.End = ResizeAttribute.ResizeAttributeEnd
- from tosa import SliceAttribute
-
- if not hasattr(SliceAttribute, "Start"):
- SliceAttribute.Start = SliceAttribute.SliceAttributeStart
- SliceAttribute.AddStart = SliceAttribute.SliceAttributeAddStart
- SliceAttribute.StartStartVector = (
- SliceAttribute.SliceAttributeStartStartVector
- )
- SliceAttribute.AddSize = SliceAttribute.SliceAttributeAddSize
- SliceAttribute.StartSizeVector = (
- SliceAttribute.SliceAttributeStartSizeVector
- )
- SliceAttribute.End = SliceAttribute.SliceAttributeEnd
- from tosa import TableAttribute
-
- if not hasattr(TableAttribute, "Start"):
- TableAttribute.Start = TableAttribute.TableAttributeStart
- TableAttribute.AddTable = TableAttribute.TableAttributeAddTable
- TableAttribute.StartTableVector = (
- TableAttribute.TableAttributeStartTableVector
- )
- TableAttribute.End = TableAttribute.TableAttributeEnd
- from tosa import TileAttribute
-
- if not hasattr(TileAttribute, "Start"):
- TileAttribute.Start = TileAttribute.TileAttributeStart
- TileAttribute.AddMultiples = TileAttribute.TileAttributeAddMultiples
- TileAttribute.StartMultiplesVector = (
- TileAttribute.TileAttributeStartMultiplesVector
- )
- TileAttribute.End = TileAttribute.TileAttributeEnd
- from tosa import TosaBasicBlock
-
- if not hasattr(TosaBasicBlock, "Start"):
- TosaBasicBlock.Start = TosaBasicBlock.TosaBasicBlockStart
- TosaBasicBlock.AddName = TosaBasicBlock.TosaBasicBlockAddName
- TosaBasicBlock.AddOperators = TosaBasicBlock.TosaBasicBlockAddOperators
- TosaBasicBlock.StartOperatorsVector = (
- TosaBasicBlock.TosaBasicBlockStartOperatorsVector
- )
- TosaBasicBlock.AddTensors = TosaBasicBlock.TosaBasicBlockAddTensors
- TosaBasicBlock.StartTensorsVector = (
- TosaBasicBlock.TosaBasicBlockStartTensorsVector
- )
- TosaBasicBlock.AddInputs = TosaBasicBlock.TosaBasicBlockAddInputs
- TosaBasicBlock.StartInputsVector = (
- TosaBasicBlock.TosaBasicBlockStartInputsVector
- )
- TosaBasicBlock.AddOutputs = TosaBasicBlock.TosaBasicBlockAddOutputs
- TosaBasicBlock.StartOutputsVector = (
- TosaBasicBlock.TosaBasicBlockStartOutputsVector
- )
- TosaBasicBlock.End = TosaBasicBlock.TosaBasicBlockEnd
- from tosa import TosaGraph
-
- if not hasattr(TosaGraph, "Start"):
- TosaGraph.Start = TosaGraph.TosaGraphStart
- TosaGraph.AddVersion = TosaGraph.TosaGraphAddVersion
- TosaGraph.AddBlocks = TosaGraph.TosaGraphAddBlocks
- TosaGraph.StartBlocksVector = TosaGraph.TosaGraphStartBlocksVector
- TosaGraph.End = TosaGraph.TosaGraphEnd
- from tosa import TosaOperator
-
- if not hasattr(TosaOperator, "Start"):
- TosaOperator.Start = TosaOperator.TosaOperatorStart
- TosaOperator.AddOp = TosaOperator.TosaOperatorAddOp
- TosaOperator.AddAttributeType = TosaOperator.TosaOperatorAddAttributeType
- TosaOperator.AddAttribute = TosaOperator.TosaOperatorAddAttribute
- TosaOperator.AddInputs = TosaOperator.TosaOperatorAddInputs
- TosaOperator.StartInputsVector = TosaOperator.TosaOperatorStartInputsVector
- TosaOperator.AddOutputs = TosaOperator.TosaOperatorAddOutputs
- TosaOperator.StartOutputsVector = (
- TosaOperator.TosaOperatorStartOutputsVector
- )
- TosaOperator.End = TosaOperator.TosaOperatorEnd
- from tosa import TosaTensor
-
- if not hasattr(TosaTensor, "Start"):
- TosaTensor.Start = TosaTensor.TosaTensorStart
- TosaTensor.AddName = TosaTensor.TosaTensorAddName
- TosaTensor.AddShape = TosaTensor.TosaTensorAddShape
- TosaTensor.StartShapeVector = TosaTensor.TosaTensorStartShapeVector
- TosaTensor.AddType = TosaTensor.TosaTensorAddType
- TosaTensor.AddData = TosaTensor.TosaTensorAddData
- TosaTensor.StartDataVector = TosaTensor.TosaTensorStartDataVector
- TosaTensor.End = TosaTensor.TosaTensorEnd
- from tosa import TransposeAttribute
-
- if not hasattr(TransposeAttribute, "Start"):
- TransposeAttribute.Start = TransposeAttribute.TransposeAttributeStart
- TransposeAttribute.AddPerms = TransposeAttribute.TransposeAttributeAddPerms
- TransposeAttribute.StartPermsVector = (
- TransposeAttribute.TransposeAttributeStartPermsVector
- )
- TransposeAttribute.End = TransposeAttribute.TransposeAttributeEnd
- from tosa import TransposeConvAttribute
-
- if not hasattr(TransposeConvAttribute, "Start"):
- TransposeConvAttribute.Start = (
- TransposeConvAttribute.TransposeConvAttributeStart
- )
- TransposeConvAttribute.AddOutPad = (
- TransposeConvAttribute.TransposeConvAttributeAddOutPad
- )
- TransposeConvAttribute.StartOutPadVector = (
- TransposeConvAttribute.TransposeConvAttributeStartOutPadVector
- )
- TransposeConvAttribute.AddStride = (
- TransposeConvAttribute.TransposeConvAttributeAddStride
- )
- TransposeConvAttribute.StartStrideVector = (
- TransposeConvAttribute.TransposeConvAttributeStartStrideVector
- )
- TransposeConvAttribute.AddOutputShape = (
- TransposeConvAttribute.TransposeConvAttributeAddOutputShape
- )
- TransposeConvAttribute.StartOutputShapeVector = (
- TransposeConvAttribute.TransposeConvAttributeStartOutputShapeVector
- )
- TransposeConvAttribute.AddInputZp = (
- TransposeConvAttribute.TransposeConvAttributeAddInputZp
- )
- TransposeConvAttribute.AddWeightZp = (
- TransposeConvAttribute.TransposeConvAttributeAddWeightZp
- )
- TransposeConvAttribute.End = (
- TransposeConvAttribute.TransposeConvAttributeEnd
- )
- from tosa import Version
-
- if not hasattr(Version, "Start"):
- Version.Start = Version.VersionStart
- Version.Add_major = Version.VersionAdd_major
- Version.Add_minor = Version.VersionAdd_minor
- Version.Add_patch = Version.VersionAdd_patch
- Version.Add_draft = Version.VersionAdd_draft
- Version.End = Version.VersionEnd
- from tosa import MatMulAttribute
-
- if not hasattr(MatMulAttribute, "Start"):
- MatMulAttribute.Start = MatMulAttribute.MatMulAttributeStart
- MatMulAttribute.AddAZp = MatMulAttribute.MatMulAttributeAddAZp
- MatMulAttribute.AddBZp = MatMulAttribute.MatMulAttributeAddBZp
- MatMulAttribute.End = MatMulAttribute.MatMulAttributeEnd
- from tosa import FullyConnectedAttribute
-
- if not hasattr(FullyConnectedAttribute, "Start"):
- FullyConnectedAttribute.Start = (
- FullyConnectedAttribute.FullyConnectedAttributeStart
- )
- FullyConnectedAttribute.AddInputZp = (
- FullyConnectedAttribute.FullyConnectedAttributeAddInputZp
- )
- FullyConnectedAttribute.AddWeightZp = (
- FullyConnectedAttribute.FullyConnectedAttributeAddWeightZp
- )
- FullyConnectedAttribute.End = (
- FullyConnectedAttribute.FullyConnectedAttributeEnd
- )
- from tosa import NegateAttribute
-
- if not hasattr(NegateAttribute, "Start"):
- NegateAttribute.Start = NegateAttribute.NegateAttributeStart
- NegateAttribute.AddInput1Zp = NegateAttribute.NegateAttributeAddInput1Zp
- NegateAttribute.AddOutputZp = NegateAttribute.NegateAttributeAddOutputZp
- NegateAttribute.End = NegateAttribute.NegateAttributeEnd
- from tosa import WhileLoopAttribute
-
- if not hasattr(WhileLoopAttribute, "Start"):
- WhileLoopAttribute.Start = WhileLoopAttribute.WhileLoopAttributeStart
- WhileLoopAttribute.AddCondBranch = (
- WhileLoopAttribute.WhileLoopAttributeAddCondBranch
- )
- WhileLoopAttribute.AddBodyBranch = (
- WhileLoopAttribute.WhileLoopAttributeAddBodyBranch
- )
- WhileLoopAttribute.End = WhileLoopAttribute.WhileLoopAttributeEnd