aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJerry Ge <jerry.ge@arm.com>2023-08-14 20:15:10 +0000
committerJerry Ge <jerry.ge@arm.com>2023-08-14 21:20:24 +0000
commit53ceb4838914d50be1537d2eb32f569aee9ff823 (patch)
treeb79f0abaedfaeda0bd23f934b3cfd6116b7c61d3
parent1adc5d05d9fd21591790678a3f1cdaa4c5b347c4 (diff)
downloadserialization_lib-53ceb4838914d50be1537d2eb32f569aee9ff823.tar.gz
Enable passing in custom names for addConst
Signed-off-by: Jerry Ge <jerry.ge@arm.com> Change-Id: I22e73e2aa9fbd54610fed776da9fbd09a4adae25
-rw-r--r--python/serializer/tosa_serializer.py11
1 files changed, 6 insertions, 5 deletions
diff --git a/python/serializer/tosa_serializer.py b/python/serializer/tosa_serializer.py
index a1109d8..8330c3e 100644
--- a/python/serializer/tosa_serializer.py
+++ b/python/serializer/tosa_serializer.py
@@ -667,12 +667,13 @@ class TosaSerializerRegion:
return tens
- def addConst(self, shape, dtype, vals):
+ def addConst(self, shape, dtype, vals, name=None):
if not self.currBasicBlock:
raise Exception("addTensor called without valid basic block")
- name = "const-{}".format(self.currInputIdx)
- self.currInputIdx = self.currInputIdx + 1
+ if name is None:
+ name = "const-{}".format(self.currInputIdx)
+ self.currInputIdx = self.currInputIdx + 1
if self.constMode == ConstMode.INPUTS:
# Save const as input file
@@ -773,8 +774,8 @@ class TosaSerializer:
def addPlaceholder(self, shape, dtype, vals):
return self.currRegion.addPlaceholder(shape, dtype, vals)
- def addConst(self, shape, dtype, vals):
- return self.currRegion.addConst(shape, dtype, vals)
+ def addConst(self, shape, dtype, vals, name=None):
+ return self.currRegion.addConst(shape, dtype, vals, name)
def addIntermediate(self, shape, dtype):
return self.currRegion.addIntermediate(shape, dtype)