From 4881c29247d4b411de446b13d9bd58ea93737aac Mon Sep 17 00:00:00 2001 From: Eric Kunze Date: Wed, 1 Nov 2023 16:12:07 -0700 Subject: Add support for local_bound attribute local_bound is used to determine when fast convolution algorithms can be used in implementing the operation. Signed-off-by: Eric Kunze Change-Id: I9970a2544e90a620f46ac4d3d01cec90a15710a9 --- include/attribute.def | 18 +++-- include/tosa_generated.h | 126 ++++++++++++++++++++++++++++++---- python/tosa/Attribute.py | 1 + python/tosa/ConvAttribute.py | 15 +++- python/tosa/FFTAttribute.py | 15 +++- python/tosa/RFFTAttribute.py | 54 +++++++++++++++ python/tosa/TransposeConvAttribute.py | 15 +++- schema/tosa.fbs | 8 +++ 8 files changed, 228 insertions(+), 24 deletions(-) create mode 100644 python/tosa/RFFTAttribute.py diff --git a/include/attribute.def b/include/attribute.def index b72f65a..dc7a569 100644 --- a/include/attribute.def +++ b/include/attribute.def @@ -34,19 +34,21 @@ DEF_ATTRIBUTE(Pool, 6, int32_t, S, output_zp, DType, S, accum_dtype) -DEF_ATTRIBUTE(Conv, 5, +DEF_ATTRIBUTE(Conv, 6, int32_t, V, pad, int32_t, V, stride, int32_t, V, dilation, int32_t, S, input_zp, - int32_t, S, weight_zp) + int32_t, S, weight_zp, + bool, S, local_bound) -DEF_ATTRIBUTE(TransposeConv, 5, +DEF_ATTRIBUTE(TransposeConv, 6, int32_t, V, out_pad, int32_t, V, stride, int32_t, V, output_shape, int32_t, S, input_zp, - int32_t, S, weight_zp) + int32_t, S, weight_zp, + bool, S, local_bound) DEF_ATTRIBUTE(Pad, 3, int32_t, V, padding, @@ -126,6 +128,10 @@ DEF_ATTRIBUTE(Custom, 3, string, S, config, uint8_t, V, implementation_attrs) -DEF_ATTRIBUTE(FFT, 1, - bool, S, inverse) +DEF_ATTRIBUTE(FFT, 2, + bool, S, inverse, + bool, S, local_bound) + +DEF_ATTRIBUTE(RFFT, 1, + bool, S, local_bound) diff --git a/include/tosa_generated.h b/include/tosa_generated.h index a81ff9c..6cc0d42 100644 --- a/include/tosa_generated.h +++ b/include/tosa_generated.h @@ -81,6 +81,9 @@ struct CustomAttributeBuilder; struct FFTAttribute; struct FFTAttributeBuilder; +struct RFFTAttribute; +struct RFFTAttributeBuilder; + struct Version; struct VersionBuilder; @@ -462,11 +465,12 @@ enum Attribute : uint8_t { Attribute_NegateAttribute = 20, Attribute_CustomAttribute = 21, Attribute_FFTAttribute = 22, + Attribute_RFFTAttribute = 23, Attribute_MIN = Attribute_NONE, - Attribute_MAX = Attribute_FFTAttribute + Attribute_MAX = Attribute_RFFTAttribute }; -inline const Attribute (&EnumValuesAttribute())[23] { +inline const Attribute (&EnumValuesAttribute())[24] { static const Attribute values[] = { Attribute_NONE, Attribute_PoolAttribute, @@ -490,13 +494,14 @@ inline const Attribute (&EnumValuesAttribute())[23] { Attribute_FullyConnectedAttribute, Attribute_NegateAttribute, Attribute_CustomAttribute, - Attribute_FFTAttribute + Attribute_FFTAttribute, + Attribute_RFFTAttribute }; return values; } inline const char * const *EnumNamesAttribute() { - static const char * const names[24] = { + static const char * const names[25] = { "NONE", "PoolAttribute", "ConvAttribute", @@ -520,13 +525,14 @@ inline const char * const *EnumNamesAttribute() { "NegateAttribute", "CustomAttribute", "FFTAttribute", + "RFFTAttribute", nullptr }; return names; } inline const char *EnumNameAttribute(Attribute e) { - if (::flatbuffers::IsOutRange(e, Attribute_NONE, Attribute_FFTAttribute)) return ""; + if (::flatbuffers::IsOutRange(e, Attribute_NONE, Attribute_RFFTAttribute)) return ""; const size_t index = static_cast(e); return EnumNamesAttribute()[index]; } @@ -623,6 +629,10 @@ template<> struct AttributeTraits { static const Attribute enum_value = Attribute_FFTAttribute; }; +template<> struct AttributeTraits { + static const Attribute enum_value = Attribute_RFFTAttribute; +}; + bool VerifyAttribute(::flatbuffers::Verifier &verifier, const void *obj, Attribute type); bool VerifyAttributeVector(::flatbuffers::Verifier &verifier, const ::flatbuffers::Vector<::flatbuffers::Offset> *values, const ::flatbuffers::Vector *types); @@ -748,7 +758,8 @@ struct ConvAttribute FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { VT_STRIDE = 6, VT_DILATION = 8, VT_INPUT_ZP = 10, - VT_WEIGHT_ZP = 12 + VT_WEIGHT_ZP = 12, + VT_LOCAL_BOUND = 14 }; const ::flatbuffers::Vector *pad() const { return GetPointer *>(VT_PAD); @@ -765,6 +776,9 @@ struct ConvAttribute FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { int32_t weight_zp() const { return GetField(VT_WEIGHT_ZP, 0); } + bool local_bound() const { + return GetField(VT_LOCAL_BOUND, 0) != 0; + } bool Verify(::flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && VerifyOffset(verifier, VT_PAD) && @@ -775,6 +789,7 @@ struct ConvAttribute FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { verifier.VerifyVector(dilation()) && VerifyField(verifier, VT_INPUT_ZP, 4) && VerifyField(verifier, VT_WEIGHT_ZP, 4) && + VerifyField(verifier, VT_LOCAL_BOUND, 1) && verifier.EndTable(); } }; @@ -798,6 +813,9 @@ struct ConvAttributeBuilder { void add_weight_zp(int32_t weight_zp) { fbb_.AddElement(ConvAttribute::VT_WEIGHT_ZP, weight_zp, 0); } + void add_local_bound(bool local_bound) { + fbb_.AddElement(ConvAttribute::VT_LOCAL_BOUND, static_cast(local_bound), 0); + } explicit ConvAttributeBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) { start_ = fbb_.StartTable(); @@ -815,13 +833,15 @@ inline ::flatbuffers::Offset CreateConvAttribute( ::flatbuffers::Offset<::flatbuffers::Vector> stride = 0, ::flatbuffers::Offset<::flatbuffers::Vector> dilation = 0, int32_t input_zp = 0, - int32_t weight_zp = 0) { + int32_t weight_zp = 0, + bool local_bound = false) { ConvAttributeBuilder builder_(_fbb); builder_.add_weight_zp(weight_zp); builder_.add_input_zp(input_zp); builder_.add_dilation(dilation); builder_.add_stride(stride); builder_.add_pad(pad); + builder_.add_local_bound(local_bound); return builder_.Finish(); } @@ -831,7 +851,8 @@ inline ::flatbuffers::Offset CreateConvAttributeDirect( const std::vector *stride = nullptr, const std::vector *dilation = nullptr, int32_t input_zp = 0, - int32_t weight_zp = 0) { + int32_t weight_zp = 0, + bool local_bound = false) { auto pad__ = pad ? _fbb.CreateVector(*pad) : 0; auto stride__ = stride ? _fbb.CreateVector(*stride) : 0; auto dilation__ = dilation ? _fbb.CreateVector(*dilation) : 0; @@ -841,7 +862,8 @@ inline ::flatbuffers::Offset CreateConvAttributeDirect( stride__, dilation__, input_zp, - weight_zp); + weight_zp, + local_bound); } struct TransposeConvAttribute FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { @@ -851,7 +873,8 @@ struct TransposeConvAttribute FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::T VT_STRIDE = 6, VT_OUTPUT_SHAPE = 8, VT_INPUT_ZP = 10, - VT_WEIGHT_ZP = 12 + VT_WEIGHT_ZP = 12, + VT_LOCAL_BOUND = 14 }; const ::flatbuffers::Vector *out_pad() const { return GetPointer *>(VT_OUT_PAD); @@ -868,6 +891,9 @@ struct TransposeConvAttribute FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::T int32_t weight_zp() const { return GetField(VT_WEIGHT_ZP, 0); } + bool local_bound() const { + return GetField(VT_LOCAL_BOUND, 0) != 0; + } bool Verify(::flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && VerifyOffset(verifier, VT_OUT_PAD) && @@ -878,6 +904,7 @@ struct TransposeConvAttribute FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::T verifier.VerifyVector(output_shape()) && VerifyField(verifier, VT_INPUT_ZP, 4) && VerifyField(verifier, VT_WEIGHT_ZP, 4) && + VerifyField(verifier, VT_LOCAL_BOUND, 1) && verifier.EndTable(); } }; @@ -901,6 +928,9 @@ struct TransposeConvAttributeBuilder { void add_weight_zp(int32_t weight_zp) { fbb_.AddElement(TransposeConvAttribute::VT_WEIGHT_ZP, weight_zp, 0); } + void add_local_bound(bool local_bound) { + fbb_.AddElement(TransposeConvAttribute::VT_LOCAL_BOUND, static_cast(local_bound), 0); + } explicit TransposeConvAttributeBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) { start_ = fbb_.StartTable(); @@ -918,13 +948,15 @@ inline ::flatbuffers::Offset CreateTransposeConvAttribut ::flatbuffers::Offset<::flatbuffers::Vector> stride = 0, ::flatbuffers::Offset<::flatbuffers::Vector> output_shape = 0, int32_t input_zp = 0, - int32_t weight_zp = 0) { + int32_t weight_zp = 0, + bool local_bound = false) { TransposeConvAttributeBuilder builder_(_fbb); builder_.add_weight_zp(weight_zp); builder_.add_input_zp(input_zp); builder_.add_output_shape(output_shape); builder_.add_stride(stride); builder_.add_out_pad(out_pad); + builder_.add_local_bound(local_bound); return builder_.Finish(); } @@ -934,7 +966,8 @@ inline ::flatbuffers::Offset CreateTransposeConvAttribut const std::vector *stride = nullptr, const std::vector *output_shape = nullptr, int32_t input_zp = 0, - int32_t weight_zp = 0) { + int32_t weight_zp = 0, + bool local_bound = false) { auto out_pad__ = out_pad ? _fbb.CreateVector(*out_pad) : 0; auto stride__ = stride ? _fbb.CreateVector(*stride) : 0; auto output_shape__ = output_shape ? _fbb.CreateVector(*output_shape) : 0; @@ -944,7 +977,8 @@ inline ::flatbuffers::Offset CreateTransposeConvAttribut stride__, output_shape__, input_zp, - weight_zp); + weight_zp, + local_bound); } struct PadAttribute FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { @@ -2113,14 +2147,19 @@ inline ::flatbuffers::Offset CreateCustomAttributeDirect( struct FFTAttribute FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { typedef FFTAttributeBuilder Builder; enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { - VT_INVERSE = 4 + VT_INVERSE = 4, + VT_LOCAL_BOUND = 6 }; bool inverse() const { return GetField(VT_INVERSE, 0) != 0; } + bool local_bound() const { + return GetField(VT_LOCAL_BOUND, 0) != 0; + } bool Verify(::flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && VerifyField(verifier, VT_INVERSE, 1) && + VerifyField(verifier, VT_LOCAL_BOUND, 1) && verifier.EndTable(); } }; @@ -2132,6 +2171,9 @@ struct FFTAttributeBuilder { void add_inverse(bool inverse) { fbb_.AddElement(FFTAttribute::VT_INVERSE, static_cast(inverse), 0); } + void add_local_bound(bool local_bound) { + fbb_.AddElement(FFTAttribute::VT_LOCAL_BOUND, static_cast(local_bound), 0); + } explicit FFTAttributeBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) { start_ = fbb_.StartTable(); @@ -2145,12 +2187,55 @@ struct FFTAttributeBuilder { inline ::flatbuffers::Offset CreateFFTAttribute( ::flatbuffers::FlatBufferBuilder &_fbb, - bool inverse = false) { + bool inverse = false, + bool local_bound = false) { FFTAttributeBuilder builder_(_fbb); + builder_.add_local_bound(local_bound); builder_.add_inverse(inverse); return builder_.Finish(); } +struct RFFTAttribute FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef RFFTAttributeBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_LOCAL_BOUND = 4 + }; + bool local_bound() const { + return GetField(VT_LOCAL_BOUND, 0) != 0; + } + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_LOCAL_BOUND, 1) && + verifier.EndTable(); + } +}; + +struct RFFTAttributeBuilder { + typedef RFFTAttribute Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_local_bound(bool local_bound) { + fbb_.AddElement(RFFTAttribute::VT_LOCAL_BOUND, static_cast(local_bound), 0); + } + explicit RFFTAttributeBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateRFFTAttribute( + ::flatbuffers::FlatBufferBuilder &_fbb, + bool local_bound = false) { + RFFTAttributeBuilder builder_(_fbb); + builder_.add_local_bound(local_bound); + return builder_.Finish(); +} + struct Version FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { typedef VersionBuilder Builder; enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { @@ -2437,6 +2522,9 @@ struct TosaOperator FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { const tosa::FFTAttribute *attribute_as_FFTAttribute() const { return attribute_type() == tosa::Attribute_FFTAttribute ? static_cast(attribute()) : nullptr; } + const tosa::RFFTAttribute *attribute_as_RFFTAttribute() const { + return attribute_type() == tosa::Attribute_RFFTAttribute ? static_cast(attribute()) : nullptr; + } const ::flatbuffers::Vector<::flatbuffers::Offset<::flatbuffers::String>> *inputs() const { return GetPointer> *>(VT_INPUTS); } @@ -2547,6 +2635,10 @@ template<> inline const tosa::FFTAttribute *TosaOperator::attribute_as inline const tosa::RFFTAttribute *TosaOperator::attribute_as() const { + return attribute_as_RFFTAttribute(); +} + struct TosaOperatorBuilder { typedef TosaOperator Table; ::flatbuffers::FlatBufferBuilder &fbb_; @@ -2947,6 +3039,10 @@ inline bool VerifyAttribute(::flatbuffers::Verifier &verifier, const void *obj, auto ptr = reinterpret_cast(obj); return verifier.VerifyTable(ptr); } + case Attribute_RFFTAttribute: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } default: return true; } } diff --git a/python/tosa/Attribute.py b/python/tosa/Attribute.py index 25ade44..2589f71 100644 --- a/python/tosa/Attribute.py +++ b/python/tosa/Attribute.py @@ -26,3 +26,4 @@ class Attribute(object): NegateAttribute = 20 CustomAttribute = 21 FFTAttribute = 22 + RFFTAttribute = 23 diff --git a/python/tosa/ConvAttribute.py b/python/tosa/ConvAttribute.py index 6db82d7..b35b67c 100644 --- a/python/tosa/ConvAttribute.py +++ b/python/tosa/ConvAttribute.py @@ -123,8 +123,15 @@ class ConvAttribute(object): return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos) return 0 + # ConvAttribute + def LocalBound(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(14)) + if o != 0: + return bool(self._tab.Get(flatbuffers.number_types.BoolFlags, o + self._tab.Pos)) + return False + def ConvAttributeStart(builder): - builder.StartObject(5) + builder.StartObject(6) def Start(builder): ConvAttributeStart(builder) @@ -177,6 +184,12 @@ def ConvAttributeAddWeightZp(builder, weightZp): def AddWeightZp(builder, weightZp): ConvAttributeAddWeightZp(builder, weightZp) +def ConvAttributeAddLocalBound(builder, localBound): + builder.PrependBoolSlot(5, localBound, 0) + +def AddLocalBound(builder, localBound): + ConvAttributeAddLocalBound(builder, localBound) + def ConvAttributeEnd(builder): return builder.EndObject() diff --git a/python/tosa/FFTAttribute.py b/python/tosa/FFTAttribute.py index 0f22aa7..d1624c2 100644 --- a/python/tosa/FFTAttribute.py +++ b/python/tosa/FFTAttribute.py @@ -35,8 +35,15 @@ class FFTAttribute(object): return bool(self._tab.Get(flatbuffers.number_types.BoolFlags, o + self._tab.Pos)) return False + # FFTAttribute + def LocalBound(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + if o != 0: + return bool(self._tab.Get(flatbuffers.number_types.BoolFlags, o + self._tab.Pos)) + return False + def FFTAttributeStart(builder): - builder.StartObject(1) + builder.StartObject(2) def Start(builder): FFTAttributeStart(builder) @@ -47,6 +54,12 @@ def FFTAttributeAddInverse(builder, inverse): def AddInverse(builder, inverse): FFTAttributeAddInverse(builder, inverse) +def FFTAttributeAddLocalBound(builder, localBound): + builder.PrependBoolSlot(1, localBound, 0) + +def AddLocalBound(builder, localBound): + FFTAttributeAddLocalBound(builder, localBound) + def FFTAttributeEnd(builder): return builder.EndObject() diff --git a/python/tosa/RFFTAttribute.py b/python/tosa/RFFTAttribute.py new file mode 100644 index 0000000..7f76024 --- /dev/null +++ b/python/tosa/RFFTAttribute.py @@ -0,0 +1,54 @@ +# automatically generated by the FlatBuffers compiler, do not modify + +# namespace: tosa + +import flatbuffers +from flatbuffers.compat import import_numpy +np = import_numpy() + +class RFFTAttribute(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = RFFTAttribute() + x.Init(buf, n + offset) + return x + + @classmethod + def GetRootAsRFFTAttribute(cls, buf, offset=0): + """This method is deprecated. Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + @classmethod + def RFFTAttributeBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x54\x4F\x53\x41", size_prefixed=size_prefixed) + + # RFFTAttribute + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + # RFFTAttribute + def LocalBound(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return bool(self._tab.Get(flatbuffers.number_types.BoolFlags, o + self._tab.Pos)) + return False + +def RFFTAttributeStart(builder): + builder.StartObject(1) + +def Start(builder): + RFFTAttributeStart(builder) + +def RFFTAttributeAddLocalBound(builder, localBound): + builder.PrependBoolSlot(0, localBound, 0) + +def AddLocalBound(builder, localBound): + RFFTAttributeAddLocalBound(builder, localBound) + +def RFFTAttributeEnd(builder): + return builder.EndObject() + +def End(builder): + return RFFTAttributeEnd(builder) diff --git a/python/tosa/TransposeConvAttribute.py b/python/tosa/TransposeConvAttribute.py index def507e..a74a433 100644 --- a/python/tosa/TransposeConvAttribute.py +++ b/python/tosa/TransposeConvAttribute.py @@ -123,8 +123,15 @@ class TransposeConvAttribute(object): return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos) return 0 + # TransposeConvAttribute + def LocalBound(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(14)) + if o != 0: + return bool(self._tab.Get(flatbuffers.number_types.BoolFlags, o + self._tab.Pos)) + return False + def TransposeConvAttributeStart(builder): - builder.StartObject(5) + builder.StartObject(6) def Start(builder): TransposeConvAttributeStart(builder) @@ -177,6 +184,12 @@ def TransposeConvAttributeAddWeightZp(builder, weightZp): def AddWeightZp(builder, weightZp): TransposeConvAttributeAddWeightZp(builder, weightZp) +def TransposeConvAttributeAddLocalBound(builder, localBound): + builder.PrependBoolSlot(5, localBound, 0) + +def AddLocalBound(builder, localBound): + TransposeConvAttributeAddLocalBound(builder, localBound) + def TransposeConvAttributeEnd(builder): return builder.EndObject() diff --git a/schema/tosa.fbs b/schema/tosa.fbs index 431efb4..21d3dad 100644 --- a/schema/tosa.fbs +++ b/schema/tosa.fbs @@ -144,6 +144,7 @@ union Attribute { NegateAttribute, CustomAttribute, FFTAttribute, + RFFTAttribute, } table PoolAttribute { @@ -161,6 +162,7 @@ table ConvAttribute { dilation: [int32]; input_zp: int32; weight_zp: int32; + local_bound: bool; } table TransposeConvAttribute { @@ -169,6 +171,7 @@ table TransposeConvAttribute { output_shape: [int32]; input_zp: int32; weight_zp: int32; + local_bound: bool; } table PadAttribute { @@ -269,6 +272,11 @@ table CustomAttribute { table FFTAttribute { inverse: bool; + local_bound: bool; +} + +table RFFTAttribute { + local_bound: bool; } table Version { -- cgit v1.2.1