aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--python/serializer/tosa_serializer.py19
1 files changed, 16 insertions, 3 deletions
diff --git a/python/serializer/tosa_serializer.py b/python/serializer/tosa_serializer.py
index bd3a500..563bc00 100644
--- a/python/serializer/tosa_serializer.py
+++ b/python/serializer/tosa_serializer.py
@@ -171,7 +171,7 @@ class TosaSerializerAttribute(TosaSerializerUnion):
self.ints.append((a.AddOutputZp, output_zp))
self.ints.append((a.AddAccumDtype, accum_dtype))
- def ConvAttribute(self, pad, stride, dilation, input_zp, weight_zp):
+ def ConvAttribute(self, pad, stride, dilation, input_zp, weight_zp, local_bound):
from tosa import ConvAttribute as a, Attribute
self.utype = Attribute.Attribute().ConvAttribute
@@ -182,8 +182,11 @@ class TosaSerializerAttribute(TosaSerializerUnion):
self.intvecs.append((a.AddDilation, dilation))
self.ints.append((a.AddInputZp, input_zp))
self.ints.append((a.AddWeightZp, weight_zp))
+ self.bools.append((a.AddLocalBound, local_bound))
- def TransposeConvAttribute(self, outpad, stride, output_shape, input_zp, weight_zp):
+ def TransposeConvAttribute(
+ self, outpad, stride, output_shape, input_zp, weight_zp, local_bound
+ ):
from tosa import TransposeConvAttribute as a, Attribute
self.utype = Attribute.Attribute().TransposeConvAttribute
@@ -194,6 +197,7 @@ class TosaSerializerAttribute(TosaSerializerUnion):
self.intvecs.append((a.AddOutputShape, output_shape))
self.ints.append((a.AddInputZp, input_zp))
self.ints.append((a.AddWeightZp, weight_zp))
+ self.bools.append((a.AddLocalBound, local_bound))
def PadAttribute(self, serializer_builder, padding, pad_const_int, pad_const_fp):
from tosa import PadAttribute as a, Attribute
@@ -374,13 +378,22 @@ class TosaSerializerAttribute(TosaSerializerUnion):
self.ints.append((a.AddInput1Zp, input1_zp))
self.ints.append((a.AddOutputZp, output_zp))
- def FFTAttribute(self, inverse):
+ def FFTAttribute(self, inverse, local_bound):
from tosa import FFTAttribute as a, Attribute
self.utype = Attribute.Attribute().FFTAttribute
self.optFcns = (a.Start, a.End)
self.bools.append((a.AddInverse, inverse))
+ self.bools.append((a.AddLocalBound, local_bound))
+
+ def RFFTAttribute(self, local_bound):
+ from tosa import RFFTAttribute as a, Attribute
+
+ self.utype = Attribute.Attribute().RFFTAttribute
+ self.optFcns = (a.Start, a.End)
+
+ self.bools.append((a.AddLocalBound, local_bound))
class TosaSerializerTensor: