From ad78daaf0fa1e41742cbed314459c3dbbb483c20 Mon Sep 17 00:00:00 2001 From: Tai Ly Date: Wed, 13 Mar 2024 18:52:45 +0000 Subject: [serialization_lib] Add acc_type to Conv Attrs This adds acc_type to ConvAttribute and TransposeConvAttribute Signed-off-by: Tai Ly Change-Id: I73bab71b2eb90f6451fadee21d5bed1811ecbfd7 --- python/serializer/tosa_serializer.py | 8 ++++++-- python/tosa/ConvAttribute.py | 15 ++++++++++++++- python/tosa/TransposeConvAttribute.py | 15 ++++++++++++++- 3 files changed, 34 insertions(+), 4 deletions(-) (limited to 'python') diff --git a/python/serializer/tosa_serializer.py b/python/serializer/tosa_serializer.py index 2c7996a..9658edf 100644 --- a/python/serializer/tosa_serializer.py +++ b/python/serializer/tosa_serializer.py @@ -172,7 +172,9 @@ class TosaSerializerAttribute(TosaSerializerUnion): self.ints.append((a.AddOutputZp, output_zp)) self.ints.append((a.AddAccType, acc_type)) - def ConvAttribute(self, pad, stride, dilation, input_zp, weight_zp, local_bound): + def ConvAttribute( + self, pad, stride, dilation, input_zp, weight_zp, local_bound, acc_type + ): from tosa import ConvAttribute as a, Attribute self.utype = Attribute.Attribute().ConvAttribute @@ -184,9 +186,10 @@ class TosaSerializerAttribute(TosaSerializerUnion): self.ints.append((a.AddInputZp, input_zp)) self.ints.append((a.AddWeightZp, weight_zp)) self.bools.append((a.AddLocalBound, local_bound)) + self.ints.append((a.AddAccType, acc_type)) def TransposeConvAttribute( - self, outpad, stride, output_shape, input_zp, weight_zp, local_bound + self, outpad, stride, output_shape, input_zp, weight_zp, local_bound, acc_type ): from tosa import TransposeConvAttribute as a, Attribute @@ -199,6 +202,7 @@ class TosaSerializerAttribute(TosaSerializerUnion): self.ints.append((a.AddInputZp, input_zp)) self.ints.append((a.AddWeightZp, weight_zp)) self.bools.append((a.AddLocalBound, local_bound)) + self.ints.append((a.AddAccType, acc_type)) def PadAttribute(self, serializer_builder, pad_const_val_as_bytes): from tosa import PadAttribute as a, Attribute diff --git a/python/tosa/ConvAttribute.py b/python/tosa/ConvAttribute.py index b35b67c..dfa75dc 100644 --- a/python/tosa/ConvAttribute.py +++ b/python/tosa/ConvAttribute.py @@ -130,8 +130,15 @@ class ConvAttribute(object): return bool(self._tab.Get(flatbuffers.number_types.BoolFlags, o + self._tab.Pos)) return False + # ConvAttribute + def AccType(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(16)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Uint32Flags, o + self._tab.Pos) + return 0 + def ConvAttributeStart(builder): - builder.StartObject(6) + builder.StartObject(7) def Start(builder): ConvAttributeStart(builder) @@ -190,6 +197,12 @@ def ConvAttributeAddLocalBound(builder, localBound): def AddLocalBound(builder, localBound): ConvAttributeAddLocalBound(builder, localBound) +def ConvAttributeAddAccType(builder, accType): + builder.PrependUint32Slot(6, accType, 0) + +def AddAccType(builder, accType): + ConvAttributeAddAccType(builder, accType) + def ConvAttributeEnd(builder): return builder.EndObject() diff --git a/python/tosa/TransposeConvAttribute.py b/python/tosa/TransposeConvAttribute.py index a74a433..e5397a8 100644 --- a/python/tosa/TransposeConvAttribute.py +++ b/python/tosa/TransposeConvAttribute.py @@ -130,8 +130,15 @@ class TransposeConvAttribute(object): return bool(self._tab.Get(flatbuffers.number_types.BoolFlags, o + self._tab.Pos)) return False + # TransposeConvAttribute + def AccType(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(16)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Uint32Flags, o + self._tab.Pos) + return 0 + def TransposeConvAttributeStart(builder): - builder.StartObject(6) + builder.StartObject(7) def Start(builder): TransposeConvAttributeStart(builder) @@ -190,6 +197,12 @@ def TransposeConvAttributeAddLocalBound(builder, localBound): def AddLocalBound(builder, localBound): TransposeConvAttributeAddLocalBound(builder, localBound) +def TransposeConvAttributeAddAccType(builder, accType): + builder.PrependUint32Slot(6, accType, 0) + +def AddAccType(builder, accType): + TransposeConvAttributeAddAccType(builder, accType) + def TransposeConvAttributeEnd(builder): return builder.EndObject() -- cgit v1.2.1