diff options
author | Jerry Ge <jerry.ge@arm.com> | 2023-08-14 20:15:10 +0000 |
---|---|---|
committer | Jerry Ge <jerry.ge@arm.com> | 2023-08-14 21:20:24 +0000 |
commit | 53ceb4838914d50be1537d2eb32f569aee9ff823 (patch) | |
tree | b79f0abaedfaeda0bd23f934b3cfd6116b7c61d3 | |
parent | 1adc5d05d9fd21591790678a3f1cdaa4c5b347c4 (diff) | |
download | serialization_lib-53ceb4838914d50be1537d2eb32f569aee9ff823.tar.gz |
Enable passing in custom names for addConst
Signed-off-by: Jerry Ge <jerry.ge@arm.com>
Change-Id: I22e73e2aa9fbd54610fed776da9fbd09a4adae25
-rw-r--r-- | python/serializer/tosa_serializer.py | 11 |
1 files changed, 6 insertions, 5 deletions
diff --git a/python/serializer/tosa_serializer.py b/python/serializer/tosa_serializer.py index a1109d8..8330c3e 100644 --- a/python/serializer/tosa_serializer.py +++ b/python/serializer/tosa_serializer.py @@ -667,12 +667,13 @@ class TosaSerializerRegion: return tens - def addConst(self, shape, dtype, vals): + def addConst(self, shape, dtype, vals, name=None): if not self.currBasicBlock: raise Exception("addTensor called without valid basic block") - name = "const-{}".format(self.currInputIdx) - self.currInputIdx = self.currInputIdx + 1 + if name is None: + name = "const-{}".format(self.currInputIdx) + self.currInputIdx = self.currInputIdx + 1 if self.constMode == ConstMode.INPUTS: # Save const as input file @@ -773,8 +774,8 @@ class TosaSerializer: def addPlaceholder(self, shape, dtype, vals): return self.currRegion.addPlaceholder(shape, dtype, vals) - def addConst(self, shape, dtype, vals): - return self.currRegion.addConst(shape, dtype, vals) + def addConst(self, shape, dtype, vals, name=None): + return self.currRegion.addConst(shape, dtype, vals, name) def addIntermediate(self, shape, dtype): return self.currRegion.addIntermediate(shape, dtype) |