diff options
Diffstat (limited to 'python/serializer/tosa_serializer.py')
-rw-r--r-- | python/serializer/tosa_serializer.py | 11 |
1 files changed, 6 insertions, 5 deletions
diff --git a/python/serializer/tosa_serializer.py b/python/serializer/tosa_serializer.py index a1109d8..8330c3e 100644 --- a/python/serializer/tosa_serializer.py +++ b/python/serializer/tosa_serializer.py @@ -667,12 +667,13 @@ class TosaSerializerRegion: return tens - def addConst(self, shape, dtype, vals): + def addConst(self, shape, dtype, vals, name=None): if not self.currBasicBlock: raise Exception("addTensor called without valid basic block") - name = "const-{}".format(self.currInputIdx) - self.currInputIdx = self.currInputIdx + 1 + if name is None: + name = "const-{}".format(self.currInputIdx) + self.currInputIdx = self.currInputIdx + 1 if self.constMode == ConstMode.INPUTS: # Save const as input file @@ -773,8 +774,8 @@ class TosaSerializer: def addPlaceholder(self, shape, dtype, vals): return self.currRegion.addPlaceholder(shape, dtype, vals) - def addConst(self, shape, dtype, vals): - return self.currRegion.addConst(shape, dtype, vals) + def addConst(self, shape, dtype, vals, name=None): + return self.currRegion.addConst(shape, dtype, vals, name) def addIntermediate(self, shape, dtype): return self.currRegion.addIntermediate(shape, dtype) |