aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJeremy Johnson <jeremy.johnson@arm.com>2023-07-25 13:56:34 +0100
committerJeremy Johnson <jeremy.johnson@arm.com>2023-07-25 15:08:00 +0100
commit005c46d51660f23eace39c98e11e8d12709caab5 (patch)
tree4d1978ed5a3cc6fb15676bae82da52ee0695fa91
parent89963aa8fad822ab7a6e1ff92f6b7b4ee0b9350c (diff)
downloadserialization_lib-005c46d51660f23eace39c98e11e8d12709caab5.tar.gz
Enable const data to be saved as input files
Signed-off-by: Jeremy Johnson <jeremy.johnson@arm.com> Change-Id: I2c4cb229356f874bf78cf635f6d69c79278f01f6
-rw-r--r--python/serializer/tosa_serializer.py44
1 files changed, 33 insertions, 11 deletions
diff --git a/python/serializer/tosa_serializer.py b/python/serializer/tosa_serializer.py
index c286f5f..286a067 100644
--- a/python/serializer/tosa_serializer.py
+++ b/python/serializer/tosa_serializer.py
@@ -615,8 +615,16 @@ class TosaSerializerBasicBlock:
return TosaBasicBlock.End(builder)
+# How CONSTs are treated in the flatbuffer
+@unique
+class ConstMode(IntEnum):
+ EMBED = 0
+ EMBED_DUMP = 1
+ INPUTS = 2
+
+
class TosaSerializerRegion:
- def __init__(self, name, pathPrefix, saveConstsToFile=False):
+ def __init__(self, name, pathPrefix, constMode=ConstMode.EMBED):
self.name = name
self.basicBlocks = []
self.currInputIdx = 0
@@ -624,7 +632,7 @@ class TosaSerializerRegion:
self.currLayerIdx = 1
self.currResultIdx = 0
self.pathPrefix = pathPrefix
- self.saveConstsToFile = saveConstsToFile
+ self.constMode = constMode
def addBasicBlock(self, name):
self.currBasicBlock = TosaSerializerBasicBlock(name)
@@ -665,11 +673,25 @@ class TosaSerializerRegion:
name = "const-{}".format(self.currInputIdx)
self.currInputIdx = self.currInputIdx + 1
- tens = self.currBasicBlock.addTensor(name, shape, dtype, vals)
+ if self.constMode == ConstMode.INPUTS:
+ # Save const as input file
+ filename = "{}.npy".format(name)
+ tensor_vals = None
+ self.currBasicBlock.addInput(name)
+ else:
+ # Embed const in flatbuffer
+ filename = None
+ tensor_vals = vals
+
+ tens = self.currBasicBlock.addTensor(name, shape, dtype, tensor_vals, filename)
# Add the operator now
self.currBasicBlock.addOperator(TosaOp.Op().CONST, [], name)
- if self.saveConstsToFile:
+ # Save the const data to file for debug or as input files
+ if vals is not None and self.constMode in [
+ ConstMode.EMBED_DUMP,
+ ConstMode.INPUTS,
+ ]:
filename = "{}.npy".format(name)
np.save(os.path.join(self.pathPrefix, filename), vals, False)
@@ -725,14 +747,14 @@ class TensorDir(IntEnum):
class TosaSerializer:
- def __init__(self, pathPrefix, saveConstsToFile=False):
+ def __init__(self, pathPrefix, constMode=ConstMode.EMBED):
self.builder = flatbuffers.Builder(0)
- self.regions = []
- self.startRegion("main", pathPrefix, saveConstsToFile)
-
# Enables inspection of constant data outside of graph
- self.saveConstsToFile = saveConstsToFile
+ self.constMode = constMode
+
+ self.regions = []
+ self.startRegion("main", pathPrefix)
self.currRegion.addBasicBlock("main")
@@ -834,8 +856,8 @@ class TosaSerializer:
return json.dumps(test_desc, indent=" ")
- def startRegion(self, name, pathPrefix, saveConstsToFile):
- self.currRegion = TosaSerializerRegion(name, pathPrefix, saveConstsToFile)
+ def startRegion(self, name, pathPrefix):
+ self.currRegion = TosaSerializerRegion(name, pathPrefix, self.constMode)
self.regions.append(self.currRegion)
@staticmethod