aboutsummaryrefslogtreecommitdiff
path: root/python/tosa_serializer.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/tosa_serializer.py')
-rw-r--r--python/tosa_serializer.py36
1 files changed, 22 insertions, 14 deletions
diff --git a/python/tosa_serializer.py b/python/tosa_serializer.py
index d85494d..3d0019e 100644
--- a/python/tosa_serializer.py
+++ b/python/tosa_serializer.py
@@ -170,14 +170,15 @@ class TosaSerializerAttribute(TosaSerializerUnion):
self.intvecs.append((a.TransposeConvAttributeAddDilation, dilation))
self.intvecs.append((a.TransposeConvAttributeAddOutputShape, output_shape))
- def ReluNAttribute(self, maxint, maxfp):
- from tosa import ReluNAttribute as a, Attribute
+ def PadAttribute(self, padding, pad_const_int, pad_const_fp):
+ from tosa import PadAttribute as a, Attribute
- self.utype = Attribute.Attribute().ReluNAttribute
- self.optFcns = (a.ReluNAttributeStart, a.ReluNAttributeEnd)
+ self.utype = Attribute.Attribute().PadAttribute
+ self.optFcns = (a.PadAttributeStart, a.PadAttributeEnd)
- self.ints.append((a.ReluNAttributeAddMaxInt, maxint))
- self.ints.append((a.ReluNAttributeAddMaxFp, maxfp))
+ self.intvecs.append((a.PadAttributeAddPadding, padding))
+ self.ints.append((a.PadAttributeAddPadConstInt, pad_const_int))
+ self.floats.append((a.PadAttributeAddPadConstFp, pad_const_fp))
def AxisAttribute(self, axis):
from tosa import AxisAttribute as a, Attribute
@@ -275,14 +276,6 @@ class TosaSerializerAttribute(TosaSerializerUnion):
self.bools.append((a.ArithmeticRightShiftAttributeAddRound, round))
- def CustomAttribute(self, identifier):
- from tosa import CustomAttribute as a, Attribute
-
- self.utype = Attribute.Attribute().CustomAttribute
- self.optFcns = (a.CustomAttributeStart, a.CustomAttributeEnd)
-
- self.strings.append((a.CustomAttributeAddIdentifier, identifier))
-
def CondIfAttribute(self, then_branch, else_branch):
from tosa import CondIfAttribute as a, Attribute
@@ -301,6 +294,21 @@ class TosaSerializerAttribute(TosaSerializerUnion):
self.strings.append((a.WhileLoopAttributeAddCondBranch, cond_branch))
self.strings.append((a.WhileLoopAttributeAddBodyBranch, body_branch))
+ def TransposeAttribute(self, perm):
+ from tosa import TransposeAttribute as a, Attribute
+
+ self.utype = Attribute.Attribute().TransposeAttribute
+ self.optFcns = (a.TransposeAttributeStart, a.TransposeAttributeEnd)
+
+ self.intvecs.append((a.TransposeAttributeAddPerm, perm))
+
+ def TableAttribute(self, table):
+ from tosa import TableAttribute as a, Attribute
+
+ self.utype = Attribute.Attribute().TableAttribute
+ self.optFcns = (a.TableAttributeStart, a.TableAttributeEnd)
+
+ self.intvecs.append((a.TableAttributeAddTable, table))
class TosaSerializerQuantInfo(TosaSerializerUnion):
"""This class handles encapsulating all of the enumerated types for quantinfo types"""