From 4881c29247d4b411de446b13d9bd58ea93737aac Mon Sep 17 00:00:00 2001 From: Eric Kunze Date: Wed, 1 Nov 2023 16:12:07 -0700 Subject: Add support for local_bound attribute local_bound is used to determine when fast convolution algorithms can be used in implementing the operation. Signed-off-by: Eric Kunze Change-Id: I9970a2544e90a620f46ac4d3d01cec90a15710a9 --- python/tosa/Attribute.py | 1 + python/tosa/ConvAttribute.py | 15 +++++++++- python/tosa/FFTAttribute.py | 15 +++++++++- python/tosa/RFFTAttribute.py | 54 +++++++++++++++++++++++++++++++++++ python/tosa/TransposeConvAttribute.py | 15 +++++++++- 5 files changed, 97 insertions(+), 3 deletions(-) create mode 100644 python/tosa/RFFTAttribute.py (limited to 'python/tosa') diff --git a/python/tosa/Attribute.py b/python/tosa/Attribute.py index 25ade44..2589f71 100644 --- a/python/tosa/Attribute.py +++ b/python/tosa/Attribute.py @@ -26,3 +26,4 @@ class Attribute(object): NegateAttribute = 20 CustomAttribute = 21 FFTAttribute = 22 + RFFTAttribute = 23 diff --git a/python/tosa/ConvAttribute.py b/python/tosa/ConvAttribute.py index 6db82d7..b35b67c 100644 --- a/python/tosa/ConvAttribute.py +++ b/python/tosa/ConvAttribute.py @@ -123,8 +123,15 @@ class ConvAttribute(object): return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos) return 0 + # ConvAttribute + def LocalBound(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(14)) + if o != 0: + return bool(self._tab.Get(flatbuffers.number_types.BoolFlags, o + self._tab.Pos)) + return False + def ConvAttributeStart(builder): - builder.StartObject(5) + builder.StartObject(6) def Start(builder): ConvAttributeStart(builder) @@ -177,6 +184,12 @@ def ConvAttributeAddWeightZp(builder, weightZp): def AddWeightZp(builder, weightZp): ConvAttributeAddWeightZp(builder, weightZp) +def ConvAttributeAddLocalBound(builder, localBound): + builder.PrependBoolSlot(5, localBound, 0) + +def AddLocalBound(builder, localBound): + ConvAttributeAddLocalBound(builder, localBound) + def ConvAttributeEnd(builder): return builder.EndObject() diff --git a/python/tosa/FFTAttribute.py b/python/tosa/FFTAttribute.py index 0f22aa7..d1624c2 100644 --- a/python/tosa/FFTAttribute.py +++ b/python/tosa/FFTAttribute.py @@ -35,8 +35,15 @@ class FFTAttribute(object): return bool(self._tab.Get(flatbuffers.number_types.BoolFlags, o + self._tab.Pos)) return False + # FFTAttribute + def LocalBound(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + if o != 0: + return bool(self._tab.Get(flatbuffers.number_types.BoolFlags, o + self._tab.Pos)) + return False + def FFTAttributeStart(builder): - builder.StartObject(1) + builder.StartObject(2) def Start(builder): FFTAttributeStart(builder) @@ -47,6 +54,12 @@ def FFTAttributeAddInverse(builder, inverse): def AddInverse(builder, inverse): FFTAttributeAddInverse(builder, inverse) +def FFTAttributeAddLocalBound(builder, localBound): + builder.PrependBoolSlot(1, localBound, 0) + +def AddLocalBound(builder, localBound): + FFTAttributeAddLocalBound(builder, localBound) + def FFTAttributeEnd(builder): return builder.EndObject() diff --git a/python/tosa/RFFTAttribute.py b/python/tosa/RFFTAttribute.py new file mode 100644 index 0000000..7f76024 --- /dev/null +++ b/python/tosa/RFFTAttribute.py @@ -0,0 +1,54 @@ +# automatically generated by the FlatBuffers compiler, do not modify + +# namespace: tosa + +import flatbuffers +from flatbuffers.compat import import_numpy +np = import_numpy() + +class RFFTAttribute(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = RFFTAttribute() + x.Init(buf, n + offset) + return x + + @classmethod + def GetRootAsRFFTAttribute(cls, buf, offset=0): + """This method is deprecated. Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + @classmethod + def RFFTAttributeBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x54\x4F\x53\x41", size_prefixed=size_prefixed) + + # RFFTAttribute + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + # RFFTAttribute + def LocalBound(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return bool(self._tab.Get(flatbuffers.number_types.BoolFlags, o + self._tab.Pos)) + return False + +def RFFTAttributeStart(builder): + builder.StartObject(1) + +def Start(builder): + RFFTAttributeStart(builder) + +def RFFTAttributeAddLocalBound(builder, localBound): + builder.PrependBoolSlot(0, localBound, 0) + +def AddLocalBound(builder, localBound): + RFFTAttributeAddLocalBound(builder, localBound) + +def RFFTAttributeEnd(builder): + return builder.EndObject() + +def End(builder): + return RFFTAttributeEnd(builder) diff --git a/python/tosa/TransposeConvAttribute.py b/python/tosa/TransposeConvAttribute.py index def507e..a74a433 100644 --- a/python/tosa/TransposeConvAttribute.py +++ b/python/tosa/TransposeConvAttribute.py @@ -123,8 +123,15 @@ class TransposeConvAttribute(object): return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos) return 0 + # TransposeConvAttribute + def LocalBound(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(14)) + if o != 0: + return bool(self._tab.Get(flatbuffers.number_types.BoolFlags, o + self._tab.Pos)) + return False + def TransposeConvAttributeStart(builder): - builder.StartObject(5) + builder.StartObject(6) def Start(builder): TransposeConvAttributeStart(builder) @@ -177,6 +184,12 @@ def TransposeConvAttributeAddWeightZp(builder, weightZp): def AddWeightZp(builder, weightZp): TransposeConvAttributeAddWeightZp(builder, weightZp) +def TransposeConvAttributeAddLocalBound(builder, localBound): + builder.PrependBoolSlot(5, localBound, 0) + +def AddLocalBound(builder, localBound): + TransposeConvAttributeAddLocalBound(builder, localBound) + def TransposeConvAttributeEnd(builder): return builder.EndObject() -- cgit v1.2.1