aboutsummaryrefslogtreecommitdiff
path: root/python
diff options
context:
space:
mode:
authorTai Ly <tai.ly@arm.com>2024-03-13 18:52:45 +0000
committerTai Ly <tai.ly@arm.com>2024-03-13 19:25:43 +0000
commitad78daaf0fa1e41742cbed314459c3dbbb483c20 (patch)
treef27e56f497b796d36676b42bb713deb1af883a31 /python
parent0b6d7c271af1e6593e6a2cf14b32acea765f4b64 (diff)
downloadserialization_lib-ad78daaf0fa1e41742cbed314459c3dbbb483c20.tar.gz
[serialization_lib] Add acc_type to Conv Attrs
This adds acc_type to ConvAttribute and TransposeConvAttribute Signed-off-by: Tai Ly <tai.ly@arm.com> Change-Id: I73bab71b2eb90f6451fadee21d5bed1811ecbfd7
Diffstat (limited to 'python')
-rw-r--r--python/serializer/tosa_serializer.py8
-rw-r--r--python/tosa/ConvAttribute.py15
-rw-r--r--python/tosa/TransposeConvAttribute.py15
3 files changed, 34 insertions, 4 deletions
diff --git a/python/serializer/tosa_serializer.py b/python/serializer/tosa_serializer.py
index 2c7996a..9658edf 100644
--- a/python/serializer/tosa_serializer.py
+++ b/python/serializer/tosa_serializer.py
@@ -172,7 +172,9 @@ class TosaSerializerAttribute(TosaSerializerUnion):
self.ints.append((a.AddOutputZp, output_zp))
self.ints.append((a.AddAccType, acc_type))
- def ConvAttribute(self, pad, stride, dilation, input_zp, weight_zp, local_bound):
+ def ConvAttribute(
+ self, pad, stride, dilation, input_zp, weight_zp, local_bound, acc_type
+ ):
from tosa import ConvAttribute as a, Attribute
self.utype = Attribute.Attribute().ConvAttribute
@@ -184,9 +186,10 @@ class TosaSerializerAttribute(TosaSerializerUnion):
self.ints.append((a.AddInputZp, input_zp))
self.ints.append((a.AddWeightZp, weight_zp))
self.bools.append((a.AddLocalBound, local_bound))
+ self.ints.append((a.AddAccType, acc_type))
def TransposeConvAttribute(
- self, outpad, stride, output_shape, input_zp, weight_zp, local_bound
+ self, outpad, stride, output_shape, input_zp, weight_zp, local_bound, acc_type
):
from tosa import TransposeConvAttribute as a, Attribute
@@ -199,6 +202,7 @@ class TosaSerializerAttribute(TosaSerializerUnion):
self.ints.append((a.AddInputZp, input_zp))
self.ints.append((a.AddWeightZp, weight_zp))
self.bools.append((a.AddLocalBound, local_bound))
+ self.ints.append((a.AddAccType, acc_type))
def PadAttribute(self, serializer_builder, pad_const_val_as_bytes):
from tosa import PadAttribute as a, Attribute
diff --git a/python/tosa/ConvAttribute.py b/python/tosa/ConvAttribute.py
index b35b67c..dfa75dc 100644
--- a/python/tosa/ConvAttribute.py
+++ b/python/tosa/ConvAttribute.py
@@ -130,8 +130,15 @@ class ConvAttribute(object):
return bool(self._tab.Get(flatbuffers.number_types.BoolFlags, o + self._tab.Pos))
return False
+ # ConvAttribute
+ def AccType(self):
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(16))
+ if o != 0:
+ return self._tab.Get(flatbuffers.number_types.Uint32Flags, o + self._tab.Pos)
+ return 0
+
def ConvAttributeStart(builder):
- builder.StartObject(6)
+ builder.StartObject(7)
def Start(builder):
ConvAttributeStart(builder)
@@ -190,6 +197,12 @@ def ConvAttributeAddLocalBound(builder, localBound):
def AddLocalBound(builder, localBound):
ConvAttributeAddLocalBound(builder, localBound)
+def ConvAttributeAddAccType(builder, accType):
+ builder.PrependUint32Slot(6, accType, 0)
+
+def AddAccType(builder, accType):
+ ConvAttributeAddAccType(builder, accType)
+
def ConvAttributeEnd(builder):
return builder.EndObject()
diff --git a/python/tosa/TransposeConvAttribute.py b/python/tosa/TransposeConvAttribute.py
index a74a433..e5397a8 100644
--- a/python/tosa/TransposeConvAttribute.py
+++ b/python/tosa/TransposeConvAttribute.py
@@ -130,8 +130,15 @@ class TransposeConvAttribute(object):
return bool(self._tab.Get(flatbuffers.number_types.BoolFlags, o + self._tab.Pos))
return False
+ # TransposeConvAttribute
+ def AccType(self):
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(16))
+ if o != 0:
+ return self._tab.Get(flatbuffers.number_types.Uint32Flags, o + self._tab.Pos)
+ return 0
+
def TransposeConvAttributeStart(builder):
- builder.StartObject(6)
+ builder.StartObject(7)
def Start(builder):
TransposeConvAttributeStart(builder)
@@ -190,6 +197,12 @@ def TransposeConvAttributeAddLocalBound(builder, localBound):
def AddLocalBound(builder, localBound):
TransposeConvAttributeAddLocalBound(builder, localBound)
+def TransposeConvAttributeAddAccType(builder, accType):
+ builder.PrependUint32Slot(6, accType, 0)
+
+def AddAccType(builder, accType):
+ TransposeConvAttributeAddAccType(builder, accType)
+
def TransposeConvAttributeEnd(builder):
return builder.EndObject()