From ca7ce0e94b3ee7339f31b47baa3a3fb4522243a2 Mon Sep 17 00:00:00 2001 From: Jerry Ge Date: Tue, 10 Jan 2023 17:24:38 +0000 Subject: Allow test generators to directly add basicBlocks through the serializer + Fixed a writeJson bug, only add input/outputs tensors from main block Signed-off-by: Jerry Ge Change-Id: I2790c2ee47b2ca2a1d8730f846061e31fc0c39f6 --- python/serializer/tosa_serializer.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/python/serializer/tosa_serializer.py b/python/serializer/tosa_serializer.py index 5ec45d1..8f70fb0 100644 --- a/python/serializer/tosa_serializer.py +++ b/python/serializer/tosa_serializer.py @@ -533,11 +533,9 @@ class TosaSerializerOperator: class TosaSerializerBasicBlock: - def __init__(self, name, pathPrefix, saveConstsToFile=False): + def __init__(self, name): self.name = name - self.pathPrefix = pathPrefix self.operators = [] - self.saveConstsToFile = saveConstsToFile # Dict assures uniqueness, but allows us to look up by name self.tensors = dict() @@ -606,10 +604,8 @@ class TosaSerializerRegion: self.pathPrefix = pathPrefix self.saveConstsToFile = saveConstsToFile - def addBasicBlock(self, name, pathPrefix, saveConstsToFile): - self.currBasicBlock = TosaSerializerBasicBlock( - name, pathPrefix, saveConstsToFile - ) + def addBasicBlock(self, name): + self.currBasicBlock = TosaSerializerBasicBlock(name) self.basicBlocks.append(self.currBasicBlock) def serialize(self, builder): @@ -716,7 +712,7 @@ class TosaSerializer: # Enables inspection of constant data outside of graph self.saveConstsToFile = saveConstsToFile - self.currRegion.addBasicBlock("main", pathPrefix, self.saveConstsToFile) + self.currRegion.addBasicBlock("main") # Is this an illegal test that is expected to fail? self.expectedReturnCode = 0 @@ -750,6 +746,9 @@ class TosaSerializer: def addOperator(self, op, inputs, outputs, attributes=None): return self.currRegion.addOperator(op, inputs, outputs, attributes) + def addBasicBlock(self, name): + self.currRegion.addBasicBlock(name) + def setExpectedReturnCode(self, val, fail, desc=""): self.expectedReturnCode = val @@ -792,7 +791,7 @@ class TosaSerializer: for region in self.regions: for block in region.basicBlocks: - if block: + if block and block.name == "main": for i in block.inputs: ifm_name.append(i) ifm_file.append(block.tensors[i].placeholderFilename) -- cgit v1.2.1