aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJerry Ge <jerry.ge@arm.com>2023-01-06 14:19:14 -0800
committerEric Kunze <eric.kunze@arm.com>2023-01-09 23:26:31 +0000
commit1eb8504e43d987cb9584149b8fdb0d37eb82964e (patch)
treed0a5226ef9942709df2d5d1269ba5705628a7ea5
parent497ab5d8cfce56eaa15db8853c903ef4dcf13e42 (diff)
downloadserialization_lib-1eb8504e43d987cb9584149b8fdb0d37eb82964e.tar.gz
Add TosaSerializerRegion to python version of serialization_lib
Signed-off-by: Jerry Ge <jerry.ge@arm.com> Change-Id: Ibd15f21aa24168730c904224f08fd55e27aae41f
-rw-r--r--python/serializer/tosa_serializer.py165
1 files changed, 106 insertions, 59 deletions
diff --git a/python/serializer/tosa_serializer.py b/python/serializer/tosa_serializer.py
index 85955aa..2d03d49 100644
--- a/python/serializer/tosa_serializer.py
+++ b/python/serializer/tosa_serializer.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2020-2022, ARM Limited.
+# Copyright (c) 2020-2023, ARM Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -20,6 +20,7 @@ import struct
from enum import IntEnum, unique
from tosa import (
TosaGraph,
+ TosaRegion,
TosaBasicBlock,
TosaTensor,
TosaOperator,
@@ -404,12 +405,12 @@ class TosaSerializerTensor:
self.placeholderFilename = placeholderFilename
def __str__(self):
- str = "TosaSerializerTensor name: {} shape: {} dtype: {}".format(
+ concatString = "TosaSerializerTensor name: {} shape: {} dtype: {}".format(
self.name,
self.shape,
DTypeNames[self.dtype],
)
- return str
+ return concatString
def setDtype(self, dtype):
self.dtype = dtype
@@ -500,14 +501,14 @@ class TosaSerializerOperator:
self.outputs = TosaSerializer.toList(outputs)
def __str__(self):
- str = "Op {}\n----\n".format(self.op)
+ concatString = "Op {}\n----\n".format(self.op)
for i in self.inputs:
- str = str + " Input: {}\n".format(i)
+ concatString = concatString + " Input: {}\n".format(i)
for o in self.outputs:
- str = str + " Output: {}\n".format(o)
+ concatString = concatString + " Output: {}\n".format(o)
- return str
+ return concatString
def serialize(self, builder):
fb_inputs = TosaSerializer.serializeStrVec(
@@ -532,9 +533,11 @@ class TosaSerializerOperator:
class TosaSerializerBasicBlock:
- def __init__(self, name):
+ def __init__(self, name, pathPrefix, saveConstsToFile=False):
self.name = name
+ self.pathPrefix = pathPrefix
self.operators = []
+ self.saveConstsToFile = saveConstsToFile
# Dict assures uniqueness, but allows us to look up by name
self.tensors = dict()
@@ -592,44 +595,33 @@ class TosaSerializerBasicBlock:
return TosaBasicBlock.End(builder)
-@unique
-class TensorDir(IntEnum):
- PLACEHOLDER = 0
- CONST = 1
- INTERMEDIATE = 2
- RESULT = 3
-
-
-class TosaSerializer:
- def __init__(self, pathPrefix, saveConstsToFile=False):
- self.add_compat_methods()
- # Get the global TOSA version if not already defined
-
- self.builder = flatbuffers.Builder(0)
-
+class TosaSerializerRegion:
+ def __init__(self, name, pathPrefix, saveConstsToFile=False):
+ self.name = name
self.basicBlocks = []
- self.startBasicBlock("main")
- self.pathPrefix = pathPrefix
-
- # Enables inspection of constant data outside of graph
- self.saveConstsToFile = saveConstsToFile
-
- # Indicies used for adding/naming tensors
self.currInputIdx = 0
self.currConstIdx = 0
self.currLayerIdx = 1
self.currResultIdx = 0
+ self.pathPrefix = pathPrefix
+ self.saveConstsToFile = saveConstsToFile
- # Is this an illegal test that is expected to fail?
- self.expectedReturnCode = 0
- self.expectedFailure = False
- self.expectedFailureDesc = ""
+ def addBasicBlock(self, name, pathPrefix, saveConstsToFile):
+ self.currBasicBlock = TosaSerializerBasicBlock(
+ name, pathPrefix, saveConstsToFile
+ )
+ self.basicBlocks.append(self.currBasicBlock)
- def __str__(self):
- str = ""
- for bb in self.basicBlocks:
- str = str + bb.__str__()
- return str
+ def serialize(self, builder):
+ fb_name = builder.CreateString(self.name)
+ fbv_basicBlocks = TosaSerializer.serializeObjVec(
+ builder, self.basicBlocks, TosaRegion.StartBlocksVector
+ )
+
+ TosaRegion.Start(builder)
+ TosaRegion.AddName(builder, fb_name)
+ TosaRegion.AddBlocks(builder, fbv_basicBlocks)
+ return TosaRegion.End(builder)
def addPlaceholder(self, shape, dtype, vals):
if not self.currBasicBlock:
@@ -666,7 +658,6 @@ class TosaSerializer:
return tens
def addIntermediate(self, shape, dtype):
-
if not self.currBasicBlock:
raise Exception("addTensor called without valid basic block")
@@ -696,7 +687,6 @@ class TosaSerializer:
return tens
def addOperator(self, op, inputs, outputs, attributes=None):
-
if op == TosaOp.Op().CONST:
raise Exception("Use addConstTensor() to add CONST ops")
@@ -707,6 +697,62 @@ class TosaSerializer:
attributes,
)
+
+@unique
+class TensorDir(IntEnum):
+ PLACEHOLDER = 0
+ CONST = 1
+ INTERMEDIATE = 2
+ RESULT = 3
+
+
+class TosaSerializer:
+ def __init__(self, pathPrefix, saveConstsToFile=False):
+ self.add_compat_methods()
+ # Get the global TOSA version if not already defined
+
+ self.builder = flatbuffers.Builder(0)
+
+ self.regions = []
+ self.startRegion("main", pathPrefix, saveConstsToFile)
+
+ # Enables inspection of constant data outside of graph
+ self.saveConstsToFile = saveConstsToFile
+
+ self.currRegion.addBasicBlock("main", pathPrefix, self.saveConstsToFile)
+
+ # Is this an illegal test that is expected to fail?
+ self.expectedReturnCode = 0
+ self.expectedFailure = False
+ self.expectedFailureDesc = ""
+
+ def __str__(self):
+ concatString = ""
+ for region in self.regions:
+ concatString = concatString + str(region)
+ return concatString
+
+ def addPlaceholder(self, shape, dtype, vals):
+ return self.currRegion.addPlaceholder(shape, dtype, vals)
+
+ def addConst(self, shape, dtype, vals):
+ return self.currRegion.addConst(shape, dtype, vals)
+
+ def addIntermediate(self, shape, dtype):
+ return self.currRegion.addIntermediate(shape, dtype)
+
+ def addInputTensor(self, tensor):
+ self.currRegion.addInputTensor(tensor)
+
+ def addOutputTensor(self, tensor):
+ self.currRegion.addOutputTensor(tensor)
+
+ def addOutput(self, shape, dtype):
+ return self.currRegion.addOutput(shape, dtype)
+
+ def addOperator(self, op, inputs, outputs, attributes=None):
+ return self.currRegion.addOperator(op, inputs, outputs, attributes)
+
def setExpectedReturnCode(self, val, fail, desc=""):
self.expectedReturnCode = val
@@ -724,13 +770,13 @@ class TosaSerializer:
Version.Add_draft(builder, TOSA_VERSION[3])
version = Version.End(builder)
- fbv_bb = TosaSerializer.serializeObjVec(
- builder, self.basicBlocks, TosaGraph.StartBlocksVector
+ fbv_region = TosaSerializer.serializeObjVec(
+ builder, self.regions, TosaGraph.StartRegionsVector
)
TosaGraph.Start(builder)
TosaGraph.AddVersion(builder, version)
- TosaGraph.AddBlocks(builder, fbv_bb)
+ TosaGraph.AddRegions(builder, fbv_region)
graph = TosaGraph.End(builder)
self.builder.Finish(graph, TOSA_GRAPH_IDENTIFIER)
@@ -747,16 +793,17 @@ class TosaSerializer:
ofm_name = []
ofm_file = []
- for b in self.basicBlocks:
- if b.name == "main":
- for i in b.inputs:
- ifm_name.append(i)
- ifm_file.append(b.tensors[i].placeholderFilename)
- for o in b.outputs:
- ofm_name.append(o)
- # Make up an OFM filename here. One isn't generated until the
- # reference tool is run, so any name is a good name
- ofm_file.append("ref-{}.npy".format(o))
+ for region in self.regions:
+ for block in region.basicBlocks:
+ if block:
+ for i in block.inputs:
+ ifm_name.append(i)
+ ifm_file.append(block.tensors[i].placeholderFilename)
+ for o in block.outputs:
+ ofm_name.append(o)
+ # Make up an OFM filename here. One isn't generated until the
+ # reference tool is run, so any name is a good name
+ ofm_file.append("ref-{}.npy".format(o))
test_desc["ifm_name"] = ifm_name
test_desc["ifm_file"] = ifm_file
@@ -769,9 +816,9 @@ class TosaSerializer:
return json.dumps(test_desc, indent=" ")
- def startBasicBlock(self, name):
- self.currBasicBlock = TosaSerializerBasicBlock(name)
- self.basicBlocks.append(self.currBasicBlock)
+ def startRegion(self, name, pathPrefix, saveConstsToFile):
+ self.currRegion = TosaSerializerRegion(name, pathPrefix, saveConstsToFile)
+ self.regions.append(self.currRegion)
@staticmethod
def serializeStrVec(builder, vec, start_fcn):
@@ -1090,8 +1137,8 @@ class TosaSerializer:
if not hasattr(TosaGraph, "Start"):
TosaGraph.Start = TosaGraph.TosaGraphStart
TosaGraph.AddVersion = TosaGraph.TosaGraphAddVersion
- TosaGraph.AddBlocks = TosaGraph.TosaGraphAddBlocks
- TosaGraph.StartBlocksVector = TosaGraph.TosaGraphStartBlocksVector
+ TosaGraph.AddRegions = TosaGraph.TosaGraphAddRegions
+ TosaGraph.StartRegionsVector = TosaGraph.TosaGraphStartRegionsVector
TosaGraph.End = TosaGraph.TosaGraphEnd
from tosa import TosaOperator