aboutsummaryrefslogtreecommitdiff
path: root/python/serializer/tosa_serializer.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/serializer/tosa_serializer.py')
-rw-r--r--python/serializer/tosa_serializer.py11
1 files changed, 6 insertions, 5 deletions
diff --git a/python/serializer/tosa_serializer.py b/python/serializer/tosa_serializer.py
index a1109d8..8330c3e 100644
--- a/python/serializer/tosa_serializer.py
+++ b/python/serializer/tosa_serializer.py
@@ -667,12 +667,13 @@ class TosaSerializerRegion:
return tens
- def addConst(self, shape, dtype, vals):
+ def addConst(self, shape, dtype, vals, name=None):
if not self.currBasicBlock:
raise Exception("addTensor called without valid basic block")
- name = "const-{}".format(self.currInputIdx)
- self.currInputIdx = self.currInputIdx + 1
+ if name is None:
+ name = "const-{}".format(self.currInputIdx)
+ self.currInputIdx = self.currInputIdx + 1
if self.constMode == ConstMode.INPUTS:
# Save const as input file
@@ -773,8 +774,8 @@ class TosaSerializer:
def addPlaceholder(self, shape, dtype, vals):
return self.currRegion.addPlaceholder(shape, dtype, vals)
- def addConst(self, shape, dtype, vals):
- return self.currRegion.addConst(shape, dtype, vals)
+ def addConst(self, shape, dtype, vals, name=None):
+ return self.currRegion.addConst(shape, dtype, vals, name)
def addIntermediate(self, shape, dtype):
return self.currRegion.addIntermediate(shape, dtype)