aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorEric Kunze <eric.kunze@arm.com>2023-11-01 16:12:07 -0700
committerEric Kunze <eric.kunze@arm.com>2023-11-01 16:14:25 -0700
commit4881c29247d4b411de446b13d9bd58ea93737aac (patch)
treec1a55620543fd0fff382a9f8c7ceeab14b6aa5db
parent5917fc7a9392da8fd1e8c68b2d00b89709a31584 (diff)
downloadserialization_lib-4881c29247d4b411de446b13d9bd58ea93737aac.tar.gz
Add support for local_bound attribute
local_bound is used to determine when fast convolution algorithms can be used in implementing the operation. Signed-off-by: Eric Kunze <eric.kunze@arm.com> Change-Id: I9970a2544e90a620f46ac4d3d01cec90a15710a9
-rw-r--r--include/attribute.def18
-rw-r--r--include/tosa_generated.h126
-rw-r--r--python/tosa/Attribute.py1
-rw-r--r--python/tosa/ConvAttribute.py15
-rw-r--r--python/tosa/FFTAttribute.py15
-rw-r--r--python/tosa/RFFTAttribute.py54
-rw-r--r--python/tosa/TransposeConvAttribute.py15
-rw-r--r--schema/tosa.fbs8
8 files changed, 228 insertions, 24 deletions
diff --git a/include/attribute.def b/include/attribute.def
index b72f65a..dc7a569 100644
--- a/include/attribute.def
+++ b/include/attribute.def
@@ -34,19 +34,21 @@ DEF_ATTRIBUTE(Pool, 6,
int32_t, S, output_zp,
DType, S, accum_dtype)
-DEF_ATTRIBUTE(Conv, 5,
+DEF_ATTRIBUTE(Conv, 6,
int32_t, V, pad,
int32_t, V, stride,
int32_t, V, dilation,
int32_t, S, input_zp,
- int32_t, S, weight_zp)
+ int32_t, S, weight_zp,
+ bool, S, local_bound)
-DEF_ATTRIBUTE(TransposeConv, 5,
+DEF_ATTRIBUTE(TransposeConv, 6,
int32_t, V, out_pad,
int32_t, V, stride,
int32_t, V, output_shape,
int32_t, S, input_zp,
- int32_t, S, weight_zp)
+ int32_t, S, weight_zp,
+ bool, S, local_bound)
DEF_ATTRIBUTE(Pad, 3,
int32_t, V, padding,
@@ -126,6 +128,10 @@ DEF_ATTRIBUTE(Custom, 3,
string, S, config,
uint8_t, V, implementation_attrs)
-DEF_ATTRIBUTE(FFT, 1,
- bool, S, inverse)
+DEF_ATTRIBUTE(FFT, 2,
+ bool, S, inverse,
+ bool, S, local_bound)
+
+DEF_ATTRIBUTE(RFFT, 1,
+ bool, S, local_bound)
diff --git a/include/tosa_generated.h b/include/tosa_generated.h
index a81ff9c..6cc0d42 100644
--- a/include/tosa_generated.h
+++ b/include/tosa_generated.h
@@ -81,6 +81,9 @@ struct CustomAttributeBuilder;
struct FFTAttribute;
struct FFTAttributeBuilder;
+struct RFFTAttribute;
+struct RFFTAttributeBuilder;
+
struct Version;
struct VersionBuilder;
@@ -462,11 +465,12 @@ enum Attribute : uint8_t {
Attribute_NegateAttribute = 20,
Attribute_CustomAttribute = 21,
Attribute_FFTAttribute = 22,
+ Attribute_RFFTAttribute = 23,
Attribute_MIN = Attribute_NONE,
- Attribute_MAX = Attribute_FFTAttribute
+ Attribute_MAX = Attribute_RFFTAttribute
};
-inline const Attribute (&EnumValuesAttribute())[23] {
+inline const Attribute (&EnumValuesAttribute())[24] {
static const Attribute values[] = {
Attribute_NONE,
Attribute_PoolAttribute,
@@ -490,13 +494,14 @@ inline const Attribute (&EnumValuesAttribute())[23] {
Attribute_FullyConnectedAttribute,
Attribute_NegateAttribute,
Attribute_CustomAttribute,
- Attribute_FFTAttribute
+ Attribute_FFTAttribute,
+ Attribute_RFFTAttribute
};
return values;
}
inline const char * const *EnumNamesAttribute() {
- static const char * const names[24] = {
+ static const char * const names[25] = {
"NONE",
"PoolAttribute",
"ConvAttribute",
@@ -520,13 +525,14 @@ inline const char * const *EnumNamesAttribute() {
"NegateAttribute",
"CustomAttribute",
"FFTAttribute",
+ "RFFTAttribute",
nullptr
};
return names;
}
inline const char *EnumNameAttribute(Attribute e) {
- if (::flatbuffers::IsOutRange(e, Attribute_NONE, Attribute_FFTAttribute)) return "";
+ if (::flatbuffers::IsOutRange(e, Attribute_NONE, Attribute_RFFTAttribute)) return "";
const size_t index = static_cast<size_t>(e);
return EnumNamesAttribute()[index];
}
@@ -623,6 +629,10 @@ template<> struct AttributeTraits<tosa::FFTAttribute> {
static const Attribute enum_value = Attribute_FFTAttribute;
};
+template<> struct AttributeTraits<tosa::RFFTAttribute> {
+ static const Attribute enum_value = Attribute_RFFTAttribute;
+};
+
bool VerifyAttribute(::flatbuffers::Verifier &verifier, const void *obj, Attribute type);
bool VerifyAttributeVector(::flatbuffers::Verifier &verifier, const ::flatbuffers::Vector<::flatbuffers::Offset<void>> *values, const ::flatbuffers::Vector<uint8_t> *types);
@@ -748,7 +758,8 @@ struct ConvAttribute FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table {
VT_STRIDE = 6,
VT_DILATION = 8,
VT_INPUT_ZP = 10,
- VT_WEIGHT_ZP = 12
+ VT_WEIGHT_ZP = 12,
+ VT_LOCAL_BOUND = 14
};
const ::flatbuffers::Vector<int32_t> *pad() const {
return GetPointer<const ::flatbuffers::Vector<int32_t> *>(VT_PAD);
@@ -765,6 +776,9 @@ struct ConvAttribute FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table {
int32_t weight_zp() const {
return GetField<int32_t>(VT_WEIGHT_ZP, 0);
}
+ bool local_bound() const {
+ return GetField<uint8_t>(VT_LOCAL_BOUND, 0) != 0;
+ }
bool Verify(::flatbuffers::Verifier &verifier) const {
return VerifyTableStart(verifier) &&
VerifyOffset(verifier, VT_PAD) &&
@@ -775,6 +789,7 @@ struct ConvAttribute FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table {
verifier.VerifyVector(dilation()) &&
VerifyField<int32_t>(verifier, VT_INPUT_ZP, 4) &&
VerifyField<int32_t>(verifier, VT_WEIGHT_ZP, 4) &&
+ VerifyField<uint8_t>(verifier, VT_LOCAL_BOUND, 1) &&
verifier.EndTable();
}
};
@@ -798,6 +813,9 @@ struct ConvAttributeBuilder {
void add_weight_zp(int32_t weight_zp) {
fbb_.AddElement<int32_t>(ConvAttribute::VT_WEIGHT_ZP, weight_zp, 0);
}
+ void add_local_bound(bool local_bound) {
+ fbb_.AddElement<uint8_t>(ConvAttribute::VT_LOCAL_BOUND, static_cast<uint8_t>(local_bound), 0);
+ }
explicit ConvAttributeBuilder(::flatbuffers::FlatBufferBuilder &_fbb)
: fbb_(_fbb) {
start_ = fbb_.StartTable();
@@ -815,13 +833,15 @@ inline ::flatbuffers::Offset<ConvAttribute> CreateConvAttribute(
::flatbuffers::Offset<::flatbuffers::Vector<int32_t>> stride = 0,
::flatbuffers::Offset<::flatbuffers::Vector<int32_t>> dilation = 0,
int32_t input_zp = 0,
- int32_t weight_zp = 0) {
+ int32_t weight_zp = 0,
+ bool local_bound = false) {
ConvAttributeBuilder builder_(_fbb);
builder_.add_weight_zp(weight_zp);
builder_.add_input_zp(input_zp);
builder_.add_dilation(dilation);
builder_.add_stride(stride);
builder_.add_pad(pad);
+ builder_.add_local_bound(local_bound);
return builder_.Finish();
}
@@ -831,7 +851,8 @@ inline ::flatbuffers::Offset<ConvAttribute> CreateConvAttributeDirect(
const std::vector<int32_t> *stride = nullptr,
const std::vector<int32_t> *dilation = nullptr,
int32_t input_zp = 0,
- int32_t weight_zp = 0) {
+ int32_t weight_zp = 0,
+ bool local_bound = false) {
auto pad__ = pad ? _fbb.CreateVector<int32_t>(*pad) : 0;
auto stride__ = stride ? _fbb.CreateVector<int32_t>(*stride) : 0;
auto dilation__ = dilation ? _fbb.CreateVector<int32_t>(*dilation) : 0;
@@ -841,7 +862,8 @@ inline ::flatbuffers::Offset<ConvAttribute> CreateConvAttributeDirect(
stride__,
dilation__,
input_zp,
- weight_zp);
+ weight_zp,
+ local_bound);
}
struct TransposeConvAttribute FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table {
@@ -851,7 +873,8 @@ struct TransposeConvAttribute FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::T
VT_STRIDE = 6,
VT_OUTPUT_SHAPE = 8,
VT_INPUT_ZP = 10,
- VT_WEIGHT_ZP = 12
+ VT_WEIGHT_ZP = 12,
+ VT_LOCAL_BOUND = 14
};
const ::flatbuffers::Vector<int32_t> *out_pad() const {
return GetPointer<const ::flatbuffers::Vector<int32_t> *>(VT_OUT_PAD);
@@ -868,6 +891,9 @@ struct TransposeConvAttribute FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::T
int32_t weight_zp() const {
return GetField<int32_t>(VT_WEIGHT_ZP, 0);
}
+ bool local_bound() const {
+ return GetField<uint8_t>(VT_LOCAL_BOUND, 0) != 0;
+ }
bool Verify(::flatbuffers::Verifier &verifier) const {
return VerifyTableStart(verifier) &&
VerifyOffset(verifier, VT_OUT_PAD) &&
@@ -878,6 +904,7 @@ struct TransposeConvAttribute FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::T
verifier.VerifyVector(output_shape()) &&
VerifyField<int32_t>(verifier, VT_INPUT_ZP, 4) &&
VerifyField<int32_t>(verifier, VT_WEIGHT_ZP, 4) &&
+ VerifyField<uint8_t>(verifier, VT_LOCAL_BOUND, 1) &&
verifier.EndTable();
}
};
@@ -901,6 +928,9 @@ struct TransposeConvAttributeBuilder {
void add_weight_zp(int32_t weight_zp) {
fbb_.AddElement<int32_t>(TransposeConvAttribute::VT_WEIGHT_ZP, weight_zp, 0);
}
+ void add_local_bound(bool local_bound) {
+ fbb_.AddElement<uint8_t>(TransposeConvAttribute::VT_LOCAL_BOUND, static_cast<uint8_t>(local_bound), 0);
+ }
explicit TransposeConvAttributeBuilder(::flatbuffers::FlatBufferBuilder &_fbb)
: fbb_(_fbb) {
start_ = fbb_.StartTable();
@@ -918,13 +948,15 @@ inline ::flatbuffers::Offset<TransposeConvAttribute> CreateTransposeConvAttribut
::flatbuffers::Offset<::flatbuffers::Vector<int32_t>> stride = 0,
::flatbuffers::Offset<::flatbuffers::Vector<int32_t>> output_shape = 0,
int32_t input_zp = 0,
- int32_t weight_zp = 0) {
+ int32_t weight_zp = 0,
+ bool local_bound = false) {
TransposeConvAttributeBuilder builder_(_fbb);
builder_.add_weight_zp(weight_zp);
builder_.add_input_zp(input_zp);
builder_.add_output_shape(output_shape);
builder_.add_stride(stride);
builder_.add_out_pad(out_pad);
+ builder_.add_local_bound(local_bound);
return builder_.Finish();
}
@@ -934,7 +966,8 @@ inline ::flatbuffers::Offset<TransposeConvAttribute> CreateTransposeConvAttribut
const std::vector<int32_t> *stride = nullptr,
const std::vector<int32_t> *output_shape = nullptr,
int32_t input_zp = 0,
- int32_t weight_zp = 0) {
+ int32_t weight_zp = 0,
+ bool local_bound = false) {
auto out_pad__ = out_pad ? _fbb.CreateVector<int32_t>(*out_pad) : 0;
auto stride__ = stride ? _fbb.CreateVector<int32_t>(*stride) : 0;
auto output_shape__ = output_shape ? _fbb.CreateVector<int32_t>(*output_shape) : 0;
@@ -944,7 +977,8 @@ inline ::flatbuffers::Offset<TransposeConvAttribute> CreateTransposeConvAttribut
stride__,
output_shape__,
input_zp,
- weight_zp);
+ weight_zp,
+ local_bound);
}
struct PadAttribute FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table {
@@ -2113,14 +2147,19 @@ inline ::flatbuffers::Offset<CustomAttribute> CreateCustomAttributeDirect(
struct FFTAttribute FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table {
typedef FFTAttributeBuilder Builder;
enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
- VT_INVERSE = 4
+ VT_INVERSE = 4,
+ VT_LOCAL_BOUND = 6
};
bool inverse() const {
return GetField<uint8_t>(VT_INVERSE, 0) != 0;
}
+ bool local_bound() const {
+ return GetField<uint8_t>(VT_LOCAL_BOUND, 0) != 0;
+ }
bool Verify(::flatbuffers::Verifier &verifier) const {
return VerifyTableStart(verifier) &&
VerifyField<uint8_t>(verifier, VT_INVERSE, 1) &&
+ VerifyField<uint8_t>(verifier, VT_LOCAL_BOUND, 1) &&
verifier.EndTable();
}
};
@@ -2132,6 +2171,9 @@ struct FFTAttributeBuilder {
void add_inverse(bool inverse) {
fbb_.AddElement<uint8_t>(FFTAttribute::VT_INVERSE, static_cast<uint8_t>(inverse), 0);
}
+ void add_local_bound(bool local_bound) {
+ fbb_.AddElement<uint8_t>(FFTAttribute::VT_LOCAL_BOUND, static_cast<uint8_t>(local_bound), 0);
+ }
explicit FFTAttributeBuilder(::flatbuffers::FlatBufferBuilder &_fbb)
: fbb_(_fbb) {
start_ = fbb_.StartTable();
@@ -2145,12 +2187,55 @@ struct FFTAttributeBuilder {
inline ::flatbuffers::Offset<FFTAttribute> CreateFFTAttribute(
::flatbuffers::FlatBufferBuilder &_fbb,
- bool inverse = false) {
+ bool inverse = false,
+ bool local_bound = false) {
FFTAttributeBuilder builder_(_fbb);
+ builder_.add_local_bound(local_bound);
builder_.add_inverse(inverse);
return builder_.Finish();
}
+struct RFFTAttribute FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table {
+ typedef RFFTAttributeBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_LOCAL_BOUND = 4
+ };
+ bool local_bound() const {
+ return GetField<uint8_t>(VT_LOCAL_BOUND, 0) != 0;
+ }
+ bool Verify(::flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyField<uint8_t>(verifier, VT_LOCAL_BOUND, 1) &&
+ verifier.EndTable();
+ }
+};
+
+struct RFFTAttributeBuilder {
+ typedef RFFTAttribute Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ void add_local_bound(bool local_bound) {
+ fbb_.AddElement<uint8_t>(RFFTAttribute::VT_LOCAL_BOUND, static_cast<uint8_t>(local_bound), 0);
+ }
+ explicit RFFTAttributeBuilder(::flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<RFFTAttribute> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<RFFTAttribute>(end);
+ return o;
+ }
+};
+
+inline ::flatbuffers::Offset<RFFTAttribute> CreateRFFTAttribute(
+ ::flatbuffers::FlatBufferBuilder &_fbb,
+ bool local_bound = false) {
+ RFFTAttributeBuilder builder_(_fbb);
+ builder_.add_local_bound(local_bound);
+ return builder_.Finish();
+}
+
struct Version FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table {
typedef VersionBuilder Builder;
enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
@@ -2437,6 +2522,9 @@ struct TosaOperator FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table {
const tosa::FFTAttribute *attribute_as_FFTAttribute() const {
return attribute_type() == tosa::Attribute_FFTAttribute ? static_cast<const tosa::FFTAttribute *>(attribute()) : nullptr;
}
+ const tosa::RFFTAttribute *attribute_as_RFFTAttribute() const {
+ return attribute_type() == tosa::Attribute_RFFTAttribute ? static_cast<const tosa::RFFTAttribute *>(attribute()) : nullptr;
+ }
const ::flatbuffers::Vector<::flatbuffers::Offset<::flatbuffers::String>> *inputs() const {
return GetPointer<const ::flatbuffers::Vector<::flatbuffers::Offset<::flatbuffers::String>> *>(VT_INPUTS);
}
@@ -2547,6 +2635,10 @@ template<> inline const tosa::FFTAttribute *TosaOperator::attribute_as<tosa::FFT
return attribute_as_FFTAttribute();
}
+template<> inline const tosa::RFFTAttribute *TosaOperator::attribute_as<tosa::RFFTAttribute>() const {
+ return attribute_as_RFFTAttribute();
+}
+
struct TosaOperatorBuilder {
typedef TosaOperator Table;
::flatbuffers::FlatBufferBuilder &fbb_;
@@ -2947,6 +3039,10 @@ inline bool VerifyAttribute(::flatbuffers::Verifier &verifier, const void *obj,
auto ptr = reinterpret_cast<const tosa::FFTAttribute *>(obj);
return verifier.VerifyTable(ptr);
}
+ case Attribute_RFFTAttribute: {
+ auto ptr = reinterpret_cast<const tosa::RFFTAttribute *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
default: return true;
}
}
diff --git a/python/tosa/Attribute.py b/python/tosa/Attribute.py
index 25ade44..2589f71 100644
--- a/python/tosa/Attribute.py
+++ b/python/tosa/Attribute.py
@@ -26,3 +26,4 @@ class Attribute(object):
NegateAttribute = 20
CustomAttribute = 21
FFTAttribute = 22
+ RFFTAttribute = 23
diff --git a/python/tosa/ConvAttribute.py b/python/tosa/ConvAttribute.py
index 6db82d7..b35b67c 100644
--- a/python/tosa/ConvAttribute.py
+++ b/python/tosa/ConvAttribute.py
@@ -123,8 +123,15 @@ class ConvAttribute(object):
return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos)
return 0
+ # ConvAttribute
+ def LocalBound(self):
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(14))
+ if o != 0:
+ return bool(self._tab.Get(flatbuffers.number_types.BoolFlags, o + self._tab.Pos))
+ return False
+
def ConvAttributeStart(builder):
- builder.StartObject(5)
+ builder.StartObject(6)
def Start(builder):
ConvAttributeStart(builder)
@@ -177,6 +184,12 @@ def ConvAttributeAddWeightZp(builder, weightZp):
def AddWeightZp(builder, weightZp):
ConvAttributeAddWeightZp(builder, weightZp)
+def ConvAttributeAddLocalBound(builder, localBound):
+ builder.PrependBoolSlot(5, localBound, 0)
+
+def AddLocalBound(builder, localBound):
+ ConvAttributeAddLocalBound(builder, localBound)
+
def ConvAttributeEnd(builder):
return builder.EndObject()
diff --git a/python/tosa/FFTAttribute.py b/python/tosa/FFTAttribute.py
index 0f22aa7..d1624c2 100644
--- a/python/tosa/FFTAttribute.py
+++ b/python/tosa/FFTAttribute.py
@@ -35,8 +35,15 @@ class FFTAttribute(object):
return bool(self._tab.Get(flatbuffers.number_types.BoolFlags, o + self._tab.Pos))
return False
+ # FFTAttribute
+ def LocalBound(self):
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
+ if o != 0:
+ return bool(self._tab.Get(flatbuffers.number_types.BoolFlags, o + self._tab.Pos))
+ return False
+
def FFTAttributeStart(builder):
- builder.StartObject(1)
+ builder.StartObject(2)
def Start(builder):
FFTAttributeStart(builder)
@@ -47,6 +54,12 @@ def FFTAttributeAddInverse(builder, inverse):
def AddInverse(builder, inverse):
FFTAttributeAddInverse(builder, inverse)
+def FFTAttributeAddLocalBound(builder, localBound):
+ builder.PrependBoolSlot(1, localBound, 0)
+
+def AddLocalBound(builder, localBound):
+ FFTAttributeAddLocalBound(builder, localBound)
+
def FFTAttributeEnd(builder):
return builder.EndObject()
diff --git a/python/tosa/RFFTAttribute.py b/python/tosa/RFFTAttribute.py
new file mode 100644
index 0000000..7f76024
--- /dev/null
+++ b/python/tosa/RFFTAttribute.py
@@ -0,0 +1,54 @@
+# automatically generated by the FlatBuffers compiler, do not modify
+
+# namespace: tosa
+
+import flatbuffers
+from flatbuffers.compat import import_numpy
+np = import_numpy()
+
+class RFFTAttribute(object):
+ __slots__ = ['_tab']
+
+ @classmethod
+ def GetRootAs(cls, buf, offset=0):
+ n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
+ x = RFFTAttribute()
+ x.Init(buf, n + offset)
+ return x
+
+ @classmethod
+ def GetRootAsRFFTAttribute(cls, buf, offset=0):
+ """This method is deprecated. Please switch to GetRootAs."""
+ return cls.GetRootAs(buf, offset)
+ @classmethod
+ def RFFTAttributeBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
+ return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x54\x4F\x53\x41", size_prefixed=size_prefixed)
+
+ # RFFTAttribute
+ def Init(self, buf, pos):
+ self._tab = flatbuffers.table.Table(buf, pos)
+
+ # RFFTAttribute
+ def LocalBound(self):
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
+ if o != 0:
+ return bool(self._tab.Get(flatbuffers.number_types.BoolFlags, o + self._tab.Pos))
+ return False
+
+def RFFTAttributeStart(builder):
+ builder.StartObject(1)
+
+def Start(builder):
+ RFFTAttributeStart(builder)
+
+def RFFTAttributeAddLocalBound(builder, localBound):
+ builder.PrependBoolSlot(0, localBound, 0)
+
+def AddLocalBound(builder, localBound):
+ RFFTAttributeAddLocalBound(builder, localBound)
+
+def RFFTAttributeEnd(builder):
+ return builder.EndObject()
+
+def End(builder):
+ return RFFTAttributeEnd(builder)
diff --git a/python/tosa/TransposeConvAttribute.py b/python/tosa/TransposeConvAttribute.py
index def507e..a74a433 100644
--- a/python/tosa/TransposeConvAttribute.py
+++ b/python/tosa/TransposeConvAttribute.py
@@ -123,8 +123,15 @@ class TransposeConvAttribute(object):
return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos)
return 0
+ # TransposeConvAttribute
+ def LocalBound(self):
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(14))
+ if o != 0:
+ return bool(self._tab.Get(flatbuffers.number_types.BoolFlags, o + self._tab.Pos))
+ return False
+
def TransposeConvAttributeStart(builder):
- builder.StartObject(5)
+ builder.StartObject(6)
def Start(builder):
TransposeConvAttributeStart(builder)
@@ -177,6 +184,12 @@ def TransposeConvAttributeAddWeightZp(builder, weightZp):
def AddWeightZp(builder, weightZp):
TransposeConvAttributeAddWeightZp(builder, weightZp)
+def TransposeConvAttributeAddLocalBound(builder, localBound):
+ builder.PrependBoolSlot(5, localBound, 0)
+
+def AddLocalBound(builder, localBound):
+ TransposeConvAttributeAddLocalBound(builder, localBound)
+
def TransposeConvAttributeEnd(builder):
return builder.EndObject()
diff --git a/schema/tosa.fbs b/schema/tosa.fbs
index 431efb4..21d3dad 100644
--- a/schema/tosa.fbs
+++ b/schema/tosa.fbs
@@ -144,6 +144,7 @@ union Attribute {
NegateAttribute,
CustomAttribute,
FFTAttribute,
+ RFFTAttribute,
}
table PoolAttribute {
@@ -161,6 +162,7 @@ table ConvAttribute {
dilation: [int32];
input_zp: int32;
weight_zp: int32;
+ local_bound: bool;
}
table TransposeConvAttribute {
@@ -169,6 +171,7 @@ table TransposeConvAttribute {
output_shape: [int32];
input_zp: int32;
weight_zp: int32;
+ local_bound: bool;
}
table PadAttribute {
@@ -269,6 +272,11 @@ table CustomAttribute {
table FFTAttribute {
inverse: bool;
+ local_bound: bool;
+}
+
+table RFFTAttribute {
+ local_bound: bool;
}
table Version {