aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorKevin Cheng <kevin.cheng@arm.com>2021-10-15 15:49:19 -0700
committerKevin Cheng <kevin.cheng@arm.com>2021-10-27 10:43:44 -0700
commit38d214cfa4491f1956f28d7eff428a8ed07d824c (patch)
tree871d86c6bc442149159d021724c392a0b9256b51
parente6563f52231c603b409638b22530d016757542c8 (diff)
downloadserialization_lib-38d214cfa4491f1956f28d7eff428a8ed07d824c.tar.gz
Changes for 0.23.0 release
- Remove RELUN op - Add pad_const to PAD op - Make padding as an attribute of PAD op - Make perm as an attribute of TRANSPOSE op - Make table as attribute in Table op - Fix typo in operator.def Signed-off-by: Kevin Cheng <kevin.cheng@arm.com> Change-Id: Ifcaa4ad686578cf814345ede8e7f37f0a04fd8ea
-rw-r--r--include/attribute.def14
-rw-r--r--include/operator.def21
-rw-r--r--include/tosa_generated.h264
-rw-r--r--python/tosa/Attribute.py4
-rw-r--r--python/tosa/Op.py2
-rw-r--r--python/tosa/PadAttribute.py88
-rw-r--r--python/tosa/TableAttribute.py (renamed from python/tosa/ReluNAttribute.py)47
-rw-r--r--python/tosa/TransposeAttribute.py72
-rw-r--r--python/tosa_serializer.py36
-rw-r--r--schema/tosa.fbs21
10 files changed, 470 insertions, 99 deletions
diff --git a/include/attribute.def b/include/attribute.def
index d77a687..2534cc4 100644
--- a/include/attribute.def
+++ b/include/attribute.def
@@ -42,9 +42,10 @@ DEF_ATTRIBUTE(TransposeConv, 4,
int32_t, V, dilation,
int32_t, V, output_shape)
-DEF_ATTRIBUTE(ReluN, 2,
- int32_t, S, max_int,
- float, S, max_fp)
+DEF_ATTRIBUTE(Pad, 3,
+ int32_t, V, padding,
+ int32_t, S, pad_const_int,
+ float, S, pad_const_fp)
DEF_ATTRIBUTE(Axis, 1,
int32_t, S, axis)
@@ -96,3 +97,10 @@ DEF_ATTRIBUTE(CondIf, 2,
DEF_ATTRIBUTE(WhileLoop, 2,
string, S, cond_branch,
string, S, body_branch)
+
+DEF_ATTRIBUTE(Transpose, 1,
+ int32_t, V, perm)
+
+DEF_ATTRIBUTE(Table, 1,
+ int32_t, V, table)
+
diff --git a/include/operator.def b/include/operator.def
index 386e72a..0b186c5 100644
--- a/include/operator.def
+++ b/include/operator.def
@@ -39,7 +39,6 @@ DEF_OPERATOR(transpose_conv2d, TRANSPOSE_CONV2D, TransposeConv2d,
/* activation */
DEF_OPERATOR(clamp, CLAMP, Clamp, Clamp, None)
-DEF_OPERATOR(reluN, RELUN, ReluN, ReluN, None)
DEF_OPERATOR(sigmoid, SIGMOID, Sigmoid, None, None)
DEF_OPERATOR(tanh, TANH, Tanh, None, None)
@@ -60,7 +59,7 @@ DEF_OPERATOR(minimum, MINIMUM, Minimum,
DEF_OPERATOR(mul, MUL, Mul, Mul, None)
DEF_OPERATOR(pow, POW, Pow, None, None)
DEF_OPERATOR(sub, SUB, Sub, None, None)
-DEF_OPERATOR(table, TABLE, Table, None, None)
+DEF_OPERATOR(table, TABLE, Table, Table, None)
/* elementwise - unary */
DEF_OPERATOR(abs, ABS, Abs, None, None)
@@ -84,21 +83,21 @@ DEF_OPERATOR(greater, GREATER, Greater,
DEF_OPERATOR(greater_equal, GREATER_EQUAL, GreaterEqual, None, None)
/* reduction */
-DEF_OPERATOR(reduce_any, REDUCE_ANY, ReduceAny, Reduce, None)
-DEF_OPERATOR(reduce_all, REDUCE_ALL, ReduceAll, Reduce, None)
-DEF_OPERATOR(reduce_max, REDUCE_MAX, ReduceMax, Reduce, None)
-DEF_OPERATOR(reduce_min, REDUCE_MIN, ReduceMin, Reduce, None)
-DEF_OPERATOR(reduce_prod, REDUCE_PRODUCT, ReduceProduct, Reduce, None)
-DEF_OPERATOR(reduce_sum, REDUCE_SUM, ReduceSum, Reduce, None)
+DEF_OPERATOR(reduce_any, REDUCE_ANY, ReduceAny, Axis, None)
+DEF_OPERATOR(reduce_all, REDUCE_ALL, ReduceAll, Axis, None)
+DEF_OPERATOR(reduce_max, REDUCE_MAX, ReduceMax, Axis, None)
+DEF_OPERATOR(reduce_min, REDUCE_MIN, ReduceMin, Axis, None)
+DEF_OPERATOR(reduce_prod, REDUCE_PRODUCT, ReduceProduct, Axis, None)
+DEF_OPERATOR(reduce_sum, REDUCE_SUM, ReduceSum, Axis, None)
/* memory operation */
DEF_OPERATOR(concat, CONCAT, Concat, Axis, None)
-DEF_OPERATOR(pad, PAD, Pad, None, Pad)
+DEF_OPERATOR(pad, PAD, Pad, Pad, Pad)
DEF_OPERATOR(reshape, RESHAPE, Reshape, Reshape, None)
-DEF_OPERATOR(reverse, REVERSE, Reverse, Reverse, None)
+DEF_OPERATOR(reverse, REVERSE, Reverse, Axis, None)
DEF_OPERATOR(slice, SLICE, Slice, Slice, None)
DEF_OPERATOR(tile, TILE, Tile, Tile, None)
-DEF_OPERATOR(transpose, TRANSPOSE, Transpose, None, None)
+DEF_OPERATOR(transpose, TRANSPOSE, Transpose, Transpose, None)
/* gather/scatter */
DEF_OPERATOR(gather, GATHER, Gather, None, None)
diff --git a/include/tosa_generated.h b/include/tosa_generated.h
index 51e33ce..223760e 100644
--- a/include/tosa_generated.h
+++ b/include/tosa_generated.h
@@ -32,8 +32,8 @@ struct ConvAttributeBuilder;
struct TransposeConvAttribute;
struct TransposeConvAttributeBuilder;
-struct ReluNAttribute;
-struct ReluNAttributeBuilder;
+struct PadAttribute;
+struct PadAttributeBuilder;
struct AxisAttribute;
struct AxisAttributeBuilder;
@@ -68,6 +68,12 @@ struct CondIfAttributeBuilder;
struct WhileLoopAttribute;
struct WhileLoopAttributeBuilder;
+struct TransposeAttribute;
+struct TransposeAttributeBuilder;
+
+struct TableAttribute;
+struct TableAttributeBuilder;
+
struct UnaryQuantInfo;
struct UnaryQuantInfoBuilder;
@@ -191,7 +197,7 @@ enum Op {
Op_MAX_POOL2D = 8,
Op_TRANSPOSE_CONV2D = 9,
Op_CLAMP = 10,
- Op_RELUN = 11,
+ Op_RESERVED = 11,
Op_SIGMOID = 12,
Op_TANH = 13,
Op_ADD = 14,
@@ -266,7 +272,7 @@ inline const Op (&EnumValuesOp())[69] {
Op_MAX_POOL2D,
Op_TRANSPOSE_CONV2D,
Op_CLAMP,
- Op_RELUN,
+ Op_RESERVED,
Op_SIGMOID,
Op_TANH,
Op_ADD,
@@ -341,7 +347,7 @@ inline const char * const *EnumNamesOp() {
"MAX_POOL2D",
"TRANSPOSE_CONV2D",
"CLAMP",
- "RELUN",
+ "RESERVED",
"SIGMOID",
"TANH",
"ADD",
@@ -415,7 +421,7 @@ enum Attribute {
Attribute_PoolAttribute = 1,
Attribute_ConvAttribute = 2,
Attribute_TransposeConvAttribute = 3,
- Attribute_ReluNAttribute = 4,
+ Attribute_PadAttribute = 4,
Attribute_AxisAttribute = 5,
Attribute_ReshapeAttribute = 6,
Attribute_SliceAttribute = 7,
@@ -427,17 +433,19 @@ enum Attribute {
Attribute_ArithmeticRightShiftAttribute = 13,
Attribute_CondIfAttribute = 14,
Attribute_WhileLoopAttribute = 15,
+ Attribute_TransposeAttribute = 16,
+ Attribute_TableAttribute = 17,
Attribute_MIN = Attribute_NONE,
- Attribute_MAX = Attribute_WhileLoopAttribute
+ Attribute_MAX = Attribute_TableAttribute
};
-inline const Attribute (&EnumValuesAttribute())[16] {
+inline const Attribute (&EnumValuesAttribute())[18] {
static const Attribute values[] = {
Attribute_NONE,
Attribute_PoolAttribute,
Attribute_ConvAttribute,
Attribute_TransposeConvAttribute,
- Attribute_ReluNAttribute,
+ Attribute_PadAttribute,
Attribute_AxisAttribute,
Attribute_ReshapeAttribute,
Attribute_SliceAttribute,
@@ -448,18 +456,20 @@ inline const Attribute (&EnumValuesAttribute())[16] {
Attribute_MulAttribute,
Attribute_ArithmeticRightShiftAttribute,
Attribute_CondIfAttribute,
- Attribute_WhileLoopAttribute
+ Attribute_WhileLoopAttribute,
+ Attribute_TransposeAttribute,
+ Attribute_TableAttribute
};
return values;
}
inline const char * const *EnumNamesAttribute() {
- static const char * const names[17] = {
+ static const char * const names[19] = {
"NONE",
"PoolAttribute",
"ConvAttribute",
"TransposeConvAttribute",
- "ReluNAttribute",
+ "PadAttribute",
"AxisAttribute",
"ReshapeAttribute",
"SliceAttribute",
@@ -471,13 +481,15 @@ inline const char * const *EnumNamesAttribute() {
"ArithmeticRightShiftAttribute",
"CondIfAttribute",
"WhileLoopAttribute",
+ "TransposeAttribute",
+ "TableAttribute",
nullptr
};
return names;
}
inline const char *EnumNameAttribute(Attribute e) {
- if (flatbuffers::IsOutRange(e, Attribute_NONE, Attribute_WhileLoopAttribute)) return "";
+ if (flatbuffers::IsOutRange(e, Attribute_NONE, Attribute_TableAttribute)) return "";
const size_t index = static_cast<size_t>(e);
return EnumNamesAttribute()[index];
}
@@ -498,8 +510,8 @@ template<> struct AttributeTraits<tosa::TransposeConvAttribute> {
static const Attribute enum_value = Attribute_TransposeConvAttribute;
};
-template<> struct AttributeTraits<tosa::ReluNAttribute> {
- static const Attribute enum_value = Attribute_ReluNAttribute;
+template<> struct AttributeTraits<tosa::PadAttribute> {
+ static const Attribute enum_value = Attribute_PadAttribute;
};
template<> struct AttributeTraits<tosa::AxisAttribute> {
@@ -546,6 +558,14 @@ template<> struct AttributeTraits<tosa::WhileLoopAttribute> {
static const Attribute enum_value = Attribute_WhileLoopAttribute;
};
+template<> struct AttributeTraits<tosa::TransposeAttribute> {
+ static const Attribute enum_value = Attribute_TransposeAttribute;
+};
+
+template<> struct AttributeTraits<tosa::TableAttribute> {
+ static const Attribute enum_value = Attribute_TableAttribute;
+};
+
bool VerifyAttribute(flatbuffers::Verifier &verifier, const void *obj, Attribute type);
bool VerifyAttributeVector(flatbuffers::Verifier &verifier, const flatbuffers::Vector<flatbuffers::Offset<void>> *values, const flatbuffers::Vector<uint8_t> *types);
@@ -865,58 +885,82 @@ inline flatbuffers::Offset<TransposeConvAttribute> CreateTransposeConvAttributeD
output_shape__);
}
-struct ReluNAttribute FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
- typedef ReluNAttributeBuilder Builder;
+struct PadAttribute FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef PadAttributeBuilder Builder;
enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
- VT_MAX_INT = 4,
- VT_MAX_FP = 6
+ VT_PADDING = 4,
+ VT_PAD_CONST_INT = 6,
+ VT_PAD_CONST_FP = 8
};
- int32_t max_int() const {
- return GetField<int32_t>(VT_MAX_INT, 0);
+ const flatbuffers::Vector<int32_t> *padding() const {
+ return GetPointer<const flatbuffers::Vector<int32_t> *>(VT_PADDING);
}
- float max_fp() const {
- return GetField<float>(VT_MAX_FP, 0.0f);
+ int32_t pad_const_int() const {
+ return GetField<int32_t>(VT_PAD_CONST_INT, 0);
+ }
+ float pad_const_fp() const {
+ return GetField<float>(VT_PAD_CONST_FP, 0.0f);
}
bool Verify(flatbuffers::Verifier &verifier) const {
return VerifyTableStart(verifier) &&
- VerifyField<int32_t>(verifier, VT_MAX_INT) &&
- VerifyField<float>(verifier, VT_MAX_FP) &&
+ VerifyOffset(verifier, VT_PADDING) &&
+ verifier.VerifyVector(padding()) &&
+ VerifyField<int32_t>(verifier, VT_PAD_CONST_INT) &&
+ VerifyField<float>(verifier, VT_PAD_CONST_FP) &&
verifier.EndTable();
}
};
-struct ReluNAttributeBuilder {
- typedef ReluNAttribute Table;
+struct PadAttributeBuilder {
+ typedef PadAttribute Table;
flatbuffers::FlatBufferBuilder &fbb_;
flatbuffers::uoffset_t start_;
- void add_max_int(int32_t max_int) {
- fbb_.AddElement<int32_t>(ReluNAttribute::VT_MAX_INT, max_int, 0);
+ void add_padding(flatbuffers::Offset<flatbuffers::Vector<int32_t>> padding) {
+ fbb_.AddOffset(PadAttribute::VT_PADDING, padding);
}
- void add_max_fp(float max_fp) {
- fbb_.AddElement<float>(ReluNAttribute::VT_MAX_FP, max_fp, 0.0f);
+ void add_pad_const_int(int32_t pad_const_int) {
+ fbb_.AddElement<int32_t>(PadAttribute::VT_PAD_CONST_INT, pad_const_int, 0);
+ }
+ void add_pad_const_fp(float pad_const_fp) {
+ fbb_.AddElement<float>(PadAttribute::VT_PAD_CONST_FP, pad_const_fp, 0.0f);
}
- explicit ReluNAttributeBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ explicit PadAttributeBuilder(flatbuffers::FlatBufferBuilder &_fbb)
: fbb_(_fbb) {
start_ = fbb_.StartTable();
}
- ReluNAttributeBuilder &operator=(const ReluNAttributeBuilder &);
- flatbuffers::Offset<ReluNAttribute> Finish() {
+ PadAttributeBuilder &operator=(const PadAttributeBuilder &);
+ flatbuffers::Offset<PadAttribute> Finish() {
const auto end = fbb_.EndTable(start_);
- auto o = flatbuffers::Offset<ReluNAttribute>(end);
+ auto o = flatbuffers::Offset<PadAttribute>(end);
return o;
}
};
-inline flatbuffers::Offset<ReluNAttribute> CreateReluNAttribute(
+inline flatbuffers::Offset<PadAttribute> CreatePadAttribute(
flatbuffers::FlatBufferBuilder &_fbb,
- int32_t max_int = 0,
- float max_fp = 0.0f) {
- ReluNAttributeBuilder builder_(_fbb);
- builder_.add_max_fp(max_fp);
- builder_.add_max_int(max_int);
+ flatbuffers::Offset<flatbuffers::Vector<int32_t>> padding = 0,
+ int32_t pad_const_int = 0,
+ float pad_const_fp = 0.0f) {
+ PadAttributeBuilder builder_(_fbb);
+ builder_.add_pad_const_fp(pad_const_fp);
+ builder_.add_pad_const_int(pad_const_int);
+ builder_.add_padding(padding);
return builder_.Finish();
}
+inline flatbuffers::Offset<PadAttribute> CreatePadAttributeDirect(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ const std::vector<int32_t> *padding = nullptr,
+ int32_t pad_const_int = 0,
+ float pad_const_fp = 0.0f) {
+ auto padding__ = padding ? _fbb.CreateVector<int32_t>(*padding) : 0;
+ return tosa::CreatePadAttribute(
+ _fbb,
+ padding__,
+ pad_const_int,
+ pad_const_fp);
+}
+
struct AxisAttribute FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
typedef AxisAttributeBuilder Builder;
enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
@@ -1675,6 +1719,110 @@ inline flatbuffers::Offset<WhileLoopAttribute> CreateWhileLoopAttributeDirect(
body_branch__);
}
+struct TransposeAttribute FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef TransposeAttributeBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_PERM = 4
+ };
+ const flatbuffers::Vector<int32_t> *perm() const {
+ return GetPointer<const flatbuffers::Vector<int32_t> *>(VT_PERM);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyOffset(verifier, VT_PERM) &&
+ verifier.VerifyVector(perm()) &&
+ verifier.EndTable();
+ }
+};
+
+struct TransposeAttributeBuilder {
+ typedef TransposeAttribute Table;
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_perm(flatbuffers::Offset<flatbuffers::Vector<int32_t>> perm) {
+ fbb_.AddOffset(TransposeAttribute::VT_PERM, perm);
+ }
+ explicit TransposeAttributeBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ TransposeAttributeBuilder &operator=(const TransposeAttributeBuilder &);
+ flatbuffers::Offset<TransposeAttribute> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<TransposeAttribute>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<TransposeAttribute> CreateTransposeAttribute(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ flatbuffers::Offset<flatbuffers::Vector<int32_t>> perm = 0) {
+ TransposeAttributeBuilder builder_(_fbb);
+ builder_.add_perm(perm);
+ return builder_.Finish();
+}
+
+inline flatbuffers::Offset<TransposeAttribute> CreateTransposeAttributeDirect(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ const std::vector<int32_t> *perm = nullptr) {
+ auto perm__ = perm ? _fbb.CreateVector<int32_t>(*perm) : 0;
+ return tosa::CreateTransposeAttribute(
+ _fbb,
+ perm__);
+}
+
+struct TableAttribute FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef TableAttributeBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_TABLE = 4
+ };
+ const flatbuffers::Vector<int32_t> *table() const {
+ return GetPointer<const flatbuffers::Vector<int32_t> *>(VT_TABLE);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyOffset(verifier, VT_TABLE) &&
+ verifier.VerifyVector(table()) &&
+ verifier.EndTable();
+ }
+};
+
+struct TableAttributeBuilder {
+ typedef TableAttribute Table;
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_table(flatbuffers::Offset<flatbuffers::Vector<int32_t>> table) {
+ fbb_.AddOffset(TableAttribute::VT_TABLE, table);
+ }
+ explicit TableAttributeBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ TableAttributeBuilder &operator=(const TableAttributeBuilder &);
+ flatbuffers::Offset<TableAttribute> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<TableAttribute>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<TableAttribute> CreateTableAttribute(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ flatbuffers::Offset<flatbuffers::Vector<int32_t>> table = 0) {
+ TableAttributeBuilder builder_(_fbb);
+ builder_.add_table(table);
+ return builder_.Finish();
+}
+
+inline flatbuffers::Offset<TableAttribute> CreateTableAttributeDirect(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ const std::vector<int32_t> *table = nullptr) {
+ auto table__ = table ? _fbb.CreateVector<int32_t>(*table) : 0;
+ return tosa::CreateTableAttribute(
+ _fbb,
+ table__);
+}
+
struct UnaryQuantInfo FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
typedef UnaryQuantInfoBuilder Builder;
enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
@@ -2068,8 +2216,8 @@ struct TosaOperator FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
const tosa::TransposeConvAttribute *attribute_as_TransposeConvAttribute() const {
return attribute_type() == tosa::Attribute_TransposeConvAttribute ? static_cast<const tosa::TransposeConvAttribute *>(attribute()) : nullptr;
}
- const tosa::ReluNAttribute *attribute_as_ReluNAttribute() const {
- return attribute_type() == tosa::Attribute_ReluNAttribute ? static_cast<const tosa::ReluNAttribute *>(attribute()) : nullptr;
+ const tosa::PadAttribute *attribute_as_PadAttribute() const {
+ return attribute_type() == tosa::Attribute_PadAttribute ? static_cast<const tosa::PadAttribute *>(attribute()) : nullptr;
}
const tosa::AxisAttribute *attribute_as_AxisAttribute() const {
return attribute_type() == tosa::Attribute_AxisAttribute ? static_cast<const tosa::AxisAttribute *>(attribute()) : nullptr;
@@ -2104,6 +2252,12 @@ struct TosaOperator FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
const tosa::WhileLoopAttribute *attribute_as_WhileLoopAttribute() const {
return attribute_type() == tosa::Attribute_WhileLoopAttribute ? static_cast<const tosa::WhileLoopAttribute *>(attribute()) : nullptr;
}
+ const tosa::TransposeAttribute *attribute_as_TransposeAttribute() const {
+ return attribute_type() == tosa::Attribute_TransposeAttribute ? static_cast<const tosa::TransposeAttribute *>(attribute()) : nullptr;
+ }
+ const tosa::TableAttribute *attribute_as_TableAttribute() const {
+ return attribute_type() == tosa::Attribute_TableAttribute ? static_cast<const tosa::TableAttribute *>(attribute()) : nullptr;
+ }
const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *inputs() const {
return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *>(VT_INPUTS);
}
@@ -2160,8 +2314,8 @@ template<> inline const tosa::TransposeConvAttribute *TosaOperator::attribute_as
return attribute_as_TransposeConvAttribute();
}
-template<> inline const tosa::ReluNAttribute *TosaOperator::attribute_as<tosa::ReluNAttribute>() const {
- return attribute_as_ReluNAttribute();
+template<> inline const tosa::PadAttribute *TosaOperator::attribute_as<tosa::PadAttribute>() const {
+ return attribute_as_PadAttribute();
}
template<> inline const tosa::AxisAttribute *TosaOperator::attribute_as<tosa::AxisAttribute>() const {
@@ -2208,6 +2362,14 @@ template<> inline const tosa::WhileLoopAttribute *TosaOperator::attribute_as<tos
return attribute_as_WhileLoopAttribute();
}
+template<> inline const tosa::TransposeAttribute *TosaOperator::attribute_as<tosa::TransposeAttribute>() const {
+ return attribute_as_TransposeAttribute();
+}
+
+template<> inline const tosa::TableAttribute *TosaOperator::attribute_as<tosa::TableAttribute>() const {
+ return attribute_as_TableAttribute();
+}
+
template<> inline const tosa::UnaryQuantInfo *TosaOperator::quant_info_as<tosa::UnaryQuantInfo>() const {
return quant_info_as_UnaryQuantInfo();
}
@@ -2498,8 +2660,8 @@ inline bool VerifyAttribute(flatbuffers::Verifier &verifier, const void *obj, At
auto ptr = reinterpret_cast<const tosa::TransposeConvAttribute *>(obj);
return verifier.VerifyTable(ptr);
}
- case Attribute_ReluNAttribute: {
- auto ptr = reinterpret_cast<const tosa::ReluNAttribute *>(obj);
+ case Attribute_PadAttribute: {
+ auto ptr = reinterpret_cast<const tosa::PadAttribute *>(obj);
return verifier.VerifyTable(ptr);
}
case Attribute_AxisAttribute: {
@@ -2546,6 +2708,14 @@ inline bool VerifyAttribute(flatbuffers::Verifier &verifier, const void *obj, At
auto ptr = reinterpret_cast<const tosa::WhileLoopAttribute *>(obj);
return verifier.VerifyTable(ptr);
}
+ case Attribute_TransposeAttribute: {
+ auto ptr = reinterpret_cast<const tosa::TransposeAttribute *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case Attribute_TableAttribute: {
+ auto ptr = reinterpret_cast<const tosa::TableAttribute *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
default: return true;
}
}
diff --git a/python/tosa/Attribute.py b/python/tosa/Attribute.py
index 334b4ae..478796f 100644
--- a/python/tosa/Attribute.py
+++ b/python/tosa/Attribute.py
@@ -22,7 +22,7 @@ class Attribute(object):
PoolAttribute = 1
ConvAttribute = 2
TransposeConvAttribute = 3
- ReluNAttribute = 4
+ PadAttribute = 4
AxisAttribute = 5
ReshapeAttribute = 6
SliceAttribute = 7
@@ -34,4 +34,6 @@ class Attribute(object):
ArithmeticRightShiftAttribute = 13
CondIfAttribute = 14
WhileLoopAttribute = 15
+ TransposeAttribute = 16
+ TableAttribute = 17
diff --git a/python/tosa/Op.py b/python/tosa/Op.py
index 0c29224..a3de87d 100644
--- a/python/tosa/Op.py
+++ b/python/tosa/Op.py
@@ -29,7 +29,7 @@ class Op(object):
MAX_POOL2D = 8
TRANSPOSE_CONV2D = 9
CLAMP = 10
- RELUN = 11
+ RESERVED = 11
SIGMOID = 12
TANH = 13
ADD = 14
diff --git a/python/tosa/PadAttribute.py b/python/tosa/PadAttribute.py
new file mode 100644
index 0000000..64dc4c7
--- /dev/null
+++ b/python/tosa/PadAttribute.py
@@ -0,0 +1,88 @@
+# automatically generated by the FlatBuffers compiler, do not modify
+
+# Copyright (c) 2020-2021, ARM Limited.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+# namespace: tosa
+
+import flatbuffers
+from flatbuffers.compat import import_numpy
+np = import_numpy()
+
+class PadAttribute(object):
+ __slots__ = ['_tab']
+
+ @classmethod
+ def GetRootAsPadAttribute(cls, buf, offset):
+ n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
+ x = PadAttribute()
+ x.Init(buf, n + offset)
+ return x
+
+ @classmethod
+ def PadAttributeBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
+ return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x54\x4F\x53\x41", size_prefixed=size_prefixed)
+
+ # PadAttribute
+ def Init(self, buf, pos):
+ self._tab = flatbuffers.table.Table(buf, pos)
+
+ # PadAttribute
+ def Padding(self, j):
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
+ if o != 0:
+ a = self._tab.Vector(o)
+ return self._tab.Get(flatbuffers.number_types.Int32Flags, a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4))
+ return 0
+
+ # PadAttribute
+ def PaddingAsNumpy(self):
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
+ if o != 0:
+ return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Int32Flags, o)
+ return 0
+
+ # PadAttribute
+ def PaddingLength(self):
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
+ if o != 0:
+ return self._tab.VectorLen(o)
+ return 0
+
+ # PadAttribute
+ def PaddingIsNone(self):
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
+ return o == 0
+
+ # PadAttribute
+ def PadConstInt(self):
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
+ if o != 0:
+ return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos)
+ return 0
+
+ # PadAttribute
+ def PadConstFp(self):
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8))
+ if o != 0:
+ return self._tab.Get(flatbuffers.number_types.Float32Flags, o + self._tab.Pos)
+ return 0.0
+
+def PadAttributeStart(builder): builder.StartObject(3)
+def PadAttributeAddPadding(builder, padding): builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(padding), 0)
+def PadAttributeStartPaddingVector(builder, numElems): return builder.StartVector(4, numElems, 4)
+def PadAttributeAddPadConstInt(builder, padConstInt): builder.PrependInt32Slot(1, padConstInt, 0)
+def PadAttributeAddPadConstFp(builder, padConstFp): builder.PrependFloat32Slot(2, padConstFp, 0.0)
+def PadAttributeEnd(builder): return builder.EndObject()
diff --git a/python/tosa/ReluNAttribute.py b/python/tosa/TableAttribute.py
index b96d701..8cdbb65 100644
--- a/python/tosa/ReluNAttribute.py
+++ b/python/tosa/TableAttribute.py
@@ -21,39 +21,52 @@ import flatbuffers
from flatbuffers.compat import import_numpy
np = import_numpy()
-class ReluNAttribute(object):
+class TableAttribute(object):
__slots__ = ['_tab']
@classmethod
- def GetRootAsReluNAttribute(cls, buf, offset):
+ def GetRootAsTableAttribute(cls, buf, offset):
n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
- x = ReluNAttribute()
+ x = TableAttribute()
x.Init(buf, n + offset)
return x
@classmethod
- def ReluNAttributeBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
+ def TableAttributeBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x54\x4F\x53\x41", size_prefixed=size_prefixed)
- # ReluNAttribute
+ # TableAttribute
def Init(self, buf, pos):
self._tab = flatbuffers.table.Table(buf, pos)
- # ReluNAttribute
- def MaxInt(self):
+ # TableAttribute
+ def Table(self, j):
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
if o != 0:
- return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos)
+ a = self._tab.Vector(o)
+ return self._tab.Get(flatbuffers.number_types.Int32Flags, a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4))
return 0
- # ReluNAttribute
- def MaxFp(self):
- o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
+ # TableAttribute
+ def TableAsNumpy(self):
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
+ if o != 0:
+ return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Int32Flags, o)
+ return 0
+
+ # TableAttribute
+ def TableLength(self):
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
if o != 0:
- return self._tab.Get(flatbuffers.number_types.Float32Flags, o + self._tab.Pos)
- return 0.0
+ return self._tab.VectorLen(o)
+ return 0
+
+ # TableAttribute
+ def TableIsNone(self):
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
+ return o == 0
-def ReluNAttributeStart(builder): builder.StartObject(2)
-def ReluNAttributeAddMaxInt(builder, maxInt): builder.PrependInt32Slot(0, maxInt, 0)
-def ReluNAttributeAddMaxFp(builder, maxFp): builder.PrependFloat32Slot(1, maxFp, 0.0)
-def ReluNAttributeEnd(builder): return builder.EndObject()
+def TableAttributeStart(builder): builder.StartObject(1)
+def TableAttributeAddTable(builder, table): builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(table), 0)
+def TableAttributeStartTableVector(builder, numElems): return builder.StartVector(4, numElems, 4)
+def TableAttributeEnd(builder): return builder.EndObject()
diff --git a/python/tosa/TransposeAttribute.py b/python/tosa/TransposeAttribute.py
new file mode 100644
index 0000000..399a308
--- /dev/null
+++ b/python/tosa/TransposeAttribute.py
@@ -0,0 +1,72 @@
+# automatically generated by the FlatBuffers compiler, do not modify
+
+# Copyright (c) 2020-2021, ARM Limited.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+# namespace: tosa
+
+import flatbuffers
+from flatbuffers.compat import import_numpy
+np = import_numpy()
+
+class TransposeAttribute(object):
+ __slots__ = ['_tab']
+
+ @classmethod
+ def GetRootAsTransposeAttribute(cls, buf, offset):
+ n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
+ x = TransposeAttribute()
+ x.Init(buf, n + offset)
+ return x
+
+ @classmethod
+ def TransposeAttributeBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
+ return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x54\x4F\x53\x41", size_prefixed=size_prefixed)
+
+ # TransposeAttribute
+ def Init(self, buf, pos):
+ self._tab = flatbuffers.table.Table(buf, pos)
+
+ # TransposeAttribute
+ def Perm(self, j):
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
+ if o != 0:
+ a = self._tab.Vector(o)
+ return self._tab.Get(flatbuffers.number_types.Int32Flags, a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4))
+ return 0
+
+ # TransposeAttribute
+ def PermAsNumpy(self):
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
+ if o != 0:
+ return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Int32Flags, o)
+ return 0
+
+ # TransposeAttribute
+ def PermLength(self):
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
+ if o != 0:
+ return self._tab.VectorLen(o)
+ return 0
+
+ # TransposeAttribute
+ def PermIsNone(self):
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
+ return o == 0
+
+def TransposeAttributeStart(builder): builder.StartObject(1)
+def TransposeAttributeAddPerm(builder, perm): builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(perm), 0)
+def TransposeAttributeStartPermVector(builder, numElems): return builder.StartVector(4, numElems, 4)
+def TransposeAttributeEnd(builder): return builder.EndObject()
diff --git a/python/tosa_serializer.py b/python/tosa_serializer.py
index d85494d..3d0019e 100644
--- a/python/tosa_serializer.py
+++ b/python/tosa_serializer.py
@@ -170,14 +170,15 @@ class TosaSerializerAttribute(TosaSerializerUnion):
self.intvecs.append((a.TransposeConvAttributeAddDilation, dilation))
self.intvecs.append((a.TransposeConvAttributeAddOutputShape, output_shape))
- def ReluNAttribute(self, maxint, maxfp):
- from tosa import ReluNAttribute as a, Attribute
+ def PadAttribute(self, padding, pad_const_int, pad_const_fp):
+ from tosa import PadAttribute as a, Attribute
- self.utype = Attribute.Attribute().ReluNAttribute
- self.optFcns = (a.ReluNAttributeStart, a.ReluNAttributeEnd)
+ self.utype = Attribute.Attribute().PadAttribute
+ self.optFcns = (a.PadAttributeStart, a.PadAttributeEnd)
- self.ints.append((a.ReluNAttributeAddMaxInt, maxint))
- self.ints.append((a.ReluNAttributeAddMaxFp, maxfp))
+ self.intvecs.append((a.PadAttributeAddPadding, padding))
+ self.ints.append((a.PadAttributeAddPadConstInt, pad_const_int))
+ self.floats.append((a.PadAttributeAddPadConstFp, pad_const_fp))
def AxisAttribute(self, axis):
from tosa import AxisAttribute as a, Attribute
@@ -275,14 +276,6 @@ class TosaSerializerAttribute(TosaSerializerUnion):
self.bools.append((a.ArithmeticRightShiftAttributeAddRound, round))
- def CustomAttribute(self, identifier):
- from tosa import CustomAttribute as a, Attribute
-
- self.utype = Attribute.Attribute().CustomAttribute
- self.optFcns = (a.CustomAttributeStart, a.CustomAttributeEnd)
-
- self.strings.append((a.CustomAttributeAddIdentifier, identifier))
-
def CondIfAttribute(self, then_branch, else_branch):
from tosa import CondIfAttribute as a, Attribute
@@ -301,6 +294,21 @@ class TosaSerializerAttribute(TosaSerializerUnion):
self.strings.append((a.WhileLoopAttributeAddCondBranch, cond_branch))
self.strings.append((a.WhileLoopAttributeAddBodyBranch, body_branch))
+ def TransposeAttribute(self, perm):
+ from tosa import TransposeAttribute as a, Attribute
+
+ self.utype = Attribute.Attribute().TransposeAttribute
+ self.optFcns = (a.TransposeAttributeStart, a.TransposeAttributeEnd)
+
+ self.intvecs.append((a.TransposeAttributeAddPerm, perm))
+
+ def TableAttribute(self, table):
+ from tosa import TableAttribute as a, Attribute
+
+ self.utype = Attribute.Attribute().TableAttribute
+ self.optFcns = (a.TableAttributeStart, a.TableAttributeEnd)
+
+ self.intvecs.append((a.TableAttributeAddTable, table))
class TosaSerializerQuantInfo(TosaSerializerUnion):
"""This class handles encapsulating all of the enumerated types for quantinfo types"""
diff --git a/schema/tosa.fbs b/schema/tosa.fbs
index 2e77fe5..8977f59 100644
--- a/schema/tosa.fbs
+++ b/schema/tosa.fbs
@@ -54,7 +54,7 @@ enum Op:uint32 {
// Activation
CLAMP,
- RELUN,
+ RESERVED,
SIGMOID,
TANH,
@@ -142,7 +142,7 @@ union Attribute {
PoolAttribute,
ConvAttribute,
TransposeConvAttribute,
- ReluNAttribute,
+ PadAttribute,
AxisAttribute,
ReshapeAttribute,
SliceAttribute,
@@ -154,6 +154,8 @@ union Attribute {
ArithmeticRightShiftAttribute,
CondIfAttribute,
WhileLoopAttribute,
+ TransposeAttribute,
+ TableAttribute,
}
table PoolAttribute {
@@ -175,9 +177,10 @@ table TransposeConvAttribute {
output_shape: [int32];
}
-table ReluNAttribute {
- max_int: int32;
- max_fp: float;
+table PadAttribute {
+ padding: [int32];
+ pad_const_int: int32;
+ pad_const_fp: float;
}
table AxisAttribute {
@@ -242,6 +245,14 @@ table WhileLoopAttribute {
body_branch: string;
}
+table TransposeAttribute {
+ perm: [int32];
+}
+
+table TableAttribute {
+ table: [int32];
+}
+
union QuantInfo {
UnaryQuantInfo,
ConvQuantInfo,