From f5dfad14f0cdc9556785b610674350c2e5a33553 Mon Sep 17 00:00:00 2001 From: Tai Ly Date: Wed, 15 Nov 2023 21:09:58 +0000 Subject: [serialization_lib] Add local_bound to tosa_serializer.py This adds local_bound to python attribute constructors for: ConvAttribute TransposeConvAttribute FFTAttribute RFFTAttribute Signed-off-by: Tai Ly Change-Id: I26ae8d718425e5563e281c44490348361c7d44b4 --- python/serializer/tosa_serializer.py | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/python/serializer/tosa_serializer.py b/python/serializer/tosa_serializer.py index bd3a500..563bc00 100644 --- a/python/serializer/tosa_serializer.py +++ b/python/serializer/tosa_serializer.py @@ -171,7 +171,7 @@ class TosaSerializerAttribute(TosaSerializerUnion): self.ints.append((a.AddOutputZp, output_zp)) self.ints.append((a.AddAccumDtype, accum_dtype)) - def ConvAttribute(self, pad, stride, dilation, input_zp, weight_zp): + def ConvAttribute(self, pad, stride, dilation, input_zp, weight_zp, local_bound): from tosa import ConvAttribute as a, Attribute self.utype = Attribute.Attribute().ConvAttribute @@ -182,8 +182,11 @@ class TosaSerializerAttribute(TosaSerializerUnion): self.intvecs.append((a.AddDilation, dilation)) self.ints.append((a.AddInputZp, input_zp)) self.ints.append((a.AddWeightZp, weight_zp)) + self.bools.append((a.AddLocalBound, local_bound)) - def TransposeConvAttribute(self, outpad, stride, output_shape, input_zp, weight_zp): + def TransposeConvAttribute( + self, outpad, stride, output_shape, input_zp, weight_zp, local_bound + ): from tosa import TransposeConvAttribute as a, Attribute self.utype = Attribute.Attribute().TransposeConvAttribute @@ -194,6 +197,7 @@ class TosaSerializerAttribute(TosaSerializerUnion): self.intvecs.append((a.AddOutputShape, output_shape)) self.ints.append((a.AddInputZp, input_zp)) self.ints.append((a.AddWeightZp, weight_zp)) + self.bools.append((a.AddLocalBound, local_bound)) def PadAttribute(self, serializer_builder, padding, pad_const_int, pad_const_fp): from tosa import PadAttribute as a, Attribute @@ -374,13 +378,22 @@ class TosaSerializerAttribute(TosaSerializerUnion): self.ints.append((a.AddInput1Zp, input1_zp)) self.ints.append((a.AddOutputZp, output_zp)) - def FFTAttribute(self, inverse): + def FFTAttribute(self, inverse, local_bound): from tosa import FFTAttribute as a, Attribute self.utype = Attribute.Attribute().FFTAttribute self.optFcns = (a.Start, a.End) self.bools.append((a.AddInverse, inverse)) + self.bools.append((a.AddLocalBound, local_bound)) + + def RFFTAttribute(self, local_bound): + from tosa import RFFTAttribute as a, Attribute + + self.utype = Attribute.Attribute().RFFTAttribute + self.optFcns = (a.Start, a.End) + + self.bools.append((a.AddLocalBound, local_bound)) class TosaSerializerTensor: -- cgit v1.2.1