aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTai Ly <tai.ly@arm.com>2023-11-15 21:09:58 +0000
committerTai Ly <tai.ly@arm.com>2023-11-16 01:29:06 +0000
commitf5dfad14f0cdc9556785b610674350c2e5a33553 (patch)
tree83c89dacea85164e63ea299f54167987f6a32e86
parent9bb9c4e67b42a1c6b3f3cc93c4f3a345856c010f (diff)
downloadserialization_lib-f5dfad14f0cdc9556785b610674350c2e5a33553.tar.gz
[serialization_lib] Add local_bound to tosa_serializer.py
This adds local_bound to python attribute constructors for: ConvAttribute TransposeConvAttribute FFTAttribute RFFTAttribute Signed-off-by: Tai Ly <tai.ly@arm.com> Change-Id: I26ae8d718425e5563e281c44490348361c7d44b4
-rw-r--r--python/serializer/tosa_serializer.py19
1 files changed, 16 insertions, 3 deletions
diff --git a/python/serializer/tosa_serializer.py b/python/serializer/tosa_serializer.py
index bd3a500..563bc00 100644
--- a/python/serializer/tosa_serializer.py
+++ b/python/serializer/tosa_serializer.py
@@ -171,7 +171,7 @@ class TosaSerializerAttribute(TosaSerializerUnion):
self.ints.append((a.AddOutputZp, output_zp))
self.ints.append((a.AddAccumDtype, accum_dtype))
- def ConvAttribute(self, pad, stride, dilation, input_zp, weight_zp):
+ def ConvAttribute(self, pad, stride, dilation, input_zp, weight_zp, local_bound):
from tosa import ConvAttribute as a, Attribute
self.utype = Attribute.Attribute().ConvAttribute
@@ -182,8 +182,11 @@ class TosaSerializerAttribute(TosaSerializerUnion):
self.intvecs.append((a.AddDilation, dilation))
self.ints.append((a.AddInputZp, input_zp))
self.ints.append((a.AddWeightZp, weight_zp))
+ self.bools.append((a.AddLocalBound, local_bound))
- def TransposeConvAttribute(self, outpad, stride, output_shape, input_zp, weight_zp):
+ def TransposeConvAttribute(
+ self, outpad, stride, output_shape, input_zp, weight_zp, local_bound
+ ):
from tosa import TransposeConvAttribute as a, Attribute
self.utype = Attribute.Attribute().TransposeConvAttribute
@@ -194,6 +197,7 @@ class TosaSerializerAttribute(TosaSerializerUnion):
self.intvecs.append((a.AddOutputShape, output_shape))
self.ints.append((a.AddInputZp, input_zp))
self.ints.append((a.AddWeightZp, weight_zp))
+ self.bools.append((a.AddLocalBound, local_bound))
def PadAttribute(self, serializer_builder, padding, pad_const_int, pad_const_fp):
from tosa import PadAttribute as a, Attribute
@@ -374,13 +378,22 @@ class TosaSerializerAttribute(TosaSerializerUnion):
self.ints.append((a.AddInput1Zp, input1_zp))
self.ints.append((a.AddOutputZp, output_zp))
- def FFTAttribute(self, inverse):
+ def FFTAttribute(self, inverse, local_bound):
from tosa import FFTAttribute as a, Attribute
self.utype = Attribute.Attribute().FFTAttribute
self.optFcns = (a.Start, a.End)
self.bools.append((a.AddInverse, inverse))
+ self.bools.append((a.AddLocalBound, local_bound))
+
+ def RFFTAttribute(self, local_bound):
+ from tosa import RFFTAttribute as a, Attribute
+
+ self.utype = Attribute.Attribute().RFFTAttribute
+ self.optFcns = (a.Start, a.End)
+
+ self.bools.append((a.AddLocalBound, local_bound))
class TosaSerializerTensor: