From 005c46d51660f23eace39c98e11e8d12709caab5 Mon Sep 17 00:00:00 2001 From: Jeremy Johnson Date: Tue, 25 Jul 2023 13:56:34 +0100 Subject: Enable const data to be saved as input files Signed-off-by: Jeremy Johnson Change-Id: I2c4cb229356f874bf78cf635f6d69c79278f01f6 --- python/serializer/tosa_serializer.py | 44 +++++++++++++++++++++++++++--------- 1 file changed, 33 insertions(+), 11 deletions(-) diff --git a/python/serializer/tosa_serializer.py b/python/serializer/tosa_serializer.py index c286f5f..286a067 100644 --- a/python/serializer/tosa_serializer.py +++ b/python/serializer/tosa_serializer.py @@ -615,8 +615,16 @@ class TosaSerializerBasicBlock: return TosaBasicBlock.End(builder) +# How CONSTs are treated in the flatbuffer +@unique +class ConstMode(IntEnum): + EMBED = 0 + EMBED_DUMP = 1 + INPUTS = 2 + + class TosaSerializerRegion: - def __init__(self, name, pathPrefix, saveConstsToFile=False): + def __init__(self, name, pathPrefix, constMode=ConstMode.EMBED): self.name = name self.basicBlocks = [] self.currInputIdx = 0 @@ -624,7 +632,7 @@ class TosaSerializerRegion: self.currLayerIdx = 1 self.currResultIdx = 0 self.pathPrefix = pathPrefix - self.saveConstsToFile = saveConstsToFile + self.constMode = constMode def addBasicBlock(self, name): self.currBasicBlock = TosaSerializerBasicBlock(name) @@ -665,11 +673,25 @@ class TosaSerializerRegion: name = "const-{}".format(self.currInputIdx) self.currInputIdx = self.currInputIdx + 1 - tens = self.currBasicBlock.addTensor(name, shape, dtype, vals) + if self.constMode == ConstMode.INPUTS: + # Save const as input file + filename = "{}.npy".format(name) + tensor_vals = None + self.currBasicBlock.addInput(name) + else: + # Embed const in flatbuffer + filename = None + tensor_vals = vals + + tens = self.currBasicBlock.addTensor(name, shape, dtype, tensor_vals, filename) # Add the operator now self.currBasicBlock.addOperator(TosaOp.Op().CONST, [], name) - if self.saveConstsToFile: + # Save the const data to file for debug or as input files + if vals is not None and self.constMode in [ + ConstMode.EMBED_DUMP, + ConstMode.INPUTS, + ]: filename = "{}.npy".format(name) np.save(os.path.join(self.pathPrefix, filename), vals, False) @@ -725,14 +747,14 @@ class TensorDir(IntEnum): class TosaSerializer: - def __init__(self, pathPrefix, saveConstsToFile=False): + def __init__(self, pathPrefix, constMode=ConstMode.EMBED): self.builder = flatbuffers.Builder(0) - self.regions = [] - self.startRegion("main", pathPrefix, saveConstsToFile) - # Enables inspection of constant data outside of graph - self.saveConstsToFile = saveConstsToFile + self.constMode = constMode + + self.regions = [] + self.startRegion("main", pathPrefix) self.currRegion.addBasicBlock("main") @@ -834,8 +856,8 @@ class TosaSerializer: return json.dumps(test_desc, indent=" ") - def startRegion(self, name, pathPrefix, saveConstsToFile): - self.currRegion = TosaSerializerRegion(name, pathPrefix, saveConstsToFile) + def startRegion(self, name, pathPrefix): + self.currRegion = TosaSerializerRegion(name, pathPrefix, self.constMode) self.regions.append(self.currRegion) @staticmethod -- cgit v1.2.1