diff options
Diffstat (limited to 'serialization')
-rw-r--r-- | serialization/attribute.def | 6 | ||||
-rw-r--r-- | serialization/operator.def | 4 | ||||
-rw-r--r-- | serialization/tosa.fbs | 11 | ||||
-rw-r--r-- | serialization/tosa_generated.h | 126 |
4 files changed, 104 insertions, 43 deletions
diff --git a/serialization/attribute.def b/serialization/attribute.def index 88e8c81..d937395 100644 --- a/serialization/attribute.def +++ b/serialization/attribute.def @@ -81,6 +81,12 @@ DEF_ATTRIBUTE(Rescale, 7, bool, S, double_round, bool, S, per_channel) +DEF_ATTRIBUTE(Mul, 1, + int32_t, S, shift) + +DEF_ATTRIBUTE(ArithmeticRightShift, 1, + bool, S, round) + DEF_ATTRIBUTE(CondIf, 2, string, S, then_branch, string, S, else_branch) diff --git a/serialization/operator.def b/serialization/operator.def index 66d3784..267976c 100644 --- a/serialization/operator.def +++ b/serialization/operator.def @@ -45,7 +45,7 @@ DEF_OPERATOR(tanh, TANH, Tanh, /* elementwise - binary */ DEF_OPERATOR(add, ADD, Add, None, None) -DEF_OPERATOR(arithmetic_right_shift, ARITHMETIC_RIGHT_SHIFT, ArithmeticRightShift, None, None) +DEF_OPERATOR(arithmetic_right_shift, ARITHMETIC_RIGHT_SHIFT, ArithmeticRightShift, ArithmeticRightShift, None) DEF_OPERATOR(bitwise_and, BITWISE_AND, BitwiseAnd, None, None) DEF_OPERATOR(bitwise_or, BITWISE_OR, BitwiseOr, None, None) DEF_OPERATOR(bitwise_xor, BITWISE_XOR, BitwiseXor, None, None) @@ -56,7 +56,7 @@ DEF_OPERATOR(logical_or, LOGICAL_OR, LogicalOr, DEF_OPERATOR(logical_xor, LOGICAL_XOR, LogicalXor, None, None) DEF_OPERATOR(maximum, MAXIMUM, Maximum, None, None) DEF_OPERATOR(minimum, MINIMUM, Minimum, None, None) -DEF_OPERATOR(mul, MUL, Mul, None, None) +DEF_OPERATOR(mul, MUL, Mul, Mul, None) DEF_OPERATOR(pow, POW, Pow, None, None) DEF_OPERATOR(sub, SUB, Sub, None, None) DEF_OPERATOR(table, TABLE, Table, None, None) diff --git a/serialization/tosa.fbs b/serialization/tosa.fbs index 841cf3d..f57d9dc 100644 --- a/serialization/tosa.fbs +++ b/serialization/tosa.fbs @@ -167,7 +167,8 @@ union Attribute { ResizeAttribute, ClampAttribute, RescaleAttribute, - CustomAttribute, + MulAttribute, + ArithmeticRightShiftAttribute, CondIfAttribute, WhileLoopAttribute, } @@ -238,8 +239,12 @@ table RescaleAttribute { per_channel: bool; } -table CustomAttribute { - identifier: string; +table MulAttribute { + shift: int32; +} + +table ArithmeticRightShiftAttribute { + round: bool; } table CondIfAttribute { diff --git a/serialization/tosa_generated.h b/serialization/tosa_generated.h index 5bb21f3..5140f7b 100644 --- a/serialization/tosa_generated.h +++ b/serialization/tosa_generated.h @@ -45,7 +45,9 @@ struct ClampAttribute; struct RescaleAttribute; -struct CustomAttribute; +struct MulAttribute; + +struct ArithmeticRightShiftAttribute; struct CondIfAttribute; @@ -478,14 +480,15 @@ enum Attribute { Attribute_ResizeAttribute = 9, Attribute_ClampAttribute = 10, Attribute_RescaleAttribute = 11, - Attribute_CustomAttribute = 12, - Attribute_CondIfAttribute = 13, - Attribute_WhileLoopAttribute = 14, + Attribute_MulAttribute = 12, + Attribute_ArithmeticRightShiftAttribute = 13, + Attribute_CondIfAttribute = 14, + Attribute_WhileLoopAttribute = 15, Attribute_MIN = Attribute_NONE, Attribute_MAX = Attribute_WhileLoopAttribute }; -inline const Attribute (&EnumValuesAttribute())[15] { +inline const Attribute (&EnumValuesAttribute())[16] { static const Attribute values[] = { Attribute_NONE, Attribute_Pool2dAttribute, @@ -499,7 +502,8 @@ inline const Attribute (&EnumValuesAttribute())[15] { Attribute_ResizeAttribute, Attribute_ClampAttribute, Attribute_RescaleAttribute, - Attribute_CustomAttribute, + Attribute_MulAttribute, + Attribute_ArithmeticRightShiftAttribute, Attribute_CondIfAttribute, Attribute_WhileLoopAttribute }; @@ -520,7 +524,8 @@ inline const char * const *EnumNamesAttribute() { "ResizeAttribute", "ClampAttribute", "RescaleAttribute", - "CustomAttribute", + "MulAttribute", + "ArithmeticRightShiftAttribute", "CondIfAttribute", "WhileLoopAttribute", nullptr @@ -582,8 +587,12 @@ template<> struct AttributeTraits<RescaleAttribute> { static const Attribute enum_value = Attribute_RescaleAttribute; }; -template<> struct AttributeTraits<CustomAttribute> { - static const Attribute enum_value = Attribute_CustomAttribute; +template<> struct AttributeTraits<MulAttribute> { + static const Attribute enum_value = Attribute_MulAttribute; +}; + +template<> struct AttributeTraits<ArithmeticRightShiftAttribute> { + static const Attribute enum_value = Attribute_ArithmeticRightShiftAttribute; }; template<> struct AttributeTraits<CondIfAttribute> { @@ -1457,54 +1466,84 @@ inline flatbuffers::Offset<RescaleAttribute> CreateRescaleAttributeDirect( per_channel); } -struct CustomAttribute FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { +struct MulAttribute FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { - VT_IDENTIFIER = 4 + VT_SHIFT = 4 }; - const flatbuffers::String *identifier() const { - return GetPointer<const flatbuffers::String *>(VT_IDENTIFIER); + int32_t shift() const { + return GetField<int32_t>(VT_SHIFT, 0); } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && - VerifyOffset(verifier, VT_IDENTIFIER) && - verifier.VerifyString(identifier()) && + VerifyField<int32_t>(verifier, VT_SHIFT) && verifier.EndTable(); } }; -struct CustomAttributeBuilder { +struct MulAttributeBuilder { flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::uoffset_t start_; - void add_identifier(flatbuffers::Offset<flatbuffers::String> identifier) { - fbb_.AddOffset(CustomAttribute::VT_IDENTIFIER, identifier); + void add_shift(int32_t shift) { + fbb_.AddElement<int32_t>(MulAttribute::VT_SHIFT, shift, 0); } - explicit CustomAttributeBuilder(flatbuffers::FlatBufferBuilder &_fbb) + explicit MulAttributeBuilder(flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) { start_ = fbb_.StartTable(); } - CustomAttributeBuilder &operator=(const CustomAttributeBuilder &); - flatbuffers::Offset<CustomAttribute> Finish() { + MulAttributeBuilder &operator=(const MulAttributeBuilder &); + flatbuffers::Offset<MulAttribute> Finish() { const auto end = fbb_.EndTable(start_); - auto o = flatbuffers::Offset<CustomAttribute>(end); + auto o = flatbuffers::Offset<MulAttribute>(end); return o; } }; -inline flatbuffers::Offset<CustomAttribute> CreateCustomAttribute( +inline flatbuffers::Offset<MulAttribute> CreateMulAttribute( flatbuffers::FlatBufferBuilder &_fbb, - flatbuffers::Offset<flatbuffers::String> identifier = 0) { - CustomAttributeBuilder builder_(_fbb); - builder_.add_identifier(identifier); + int32_t shift = 0) { + MulAttributeBuilder builder_(_fbb); + builder_.add_shift(shift); return builder_.Finish(); } -inline flatbuffers::Offset<CustomAttribute> CreateCustomAttributeDirect( +struct ArithmeticRightShiftAttribute FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_ROUND = 4 + }; + bool round() const { + return GetField<uint8_t>(VT_ROUND, 0) != 0; + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField<uint8_t>(verifier, VT_ROUND) && + verifier.EndTable(); + } +}; + +struct ArithmeticRightShiftAttributeBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_round(bool round) { + fbb_.AddElement<uint8_t>(ArithmeticRightShiftAttribute::VT_ROUND, static_cast<uint8_t>(round), 0); + } + explicit ArithmeticRightShiftAttributeBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ArithmeticRightShiftAttributeBuilder &operator=(const ArithmeticRightShiftAttributeBuilder &); + flatbuffers::Offset<ArithmeticRightShiftAttribute> Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset<ArithmeticRightShiftAttribute>(end); + return o; + } +}; + +inline flatbuffers::Offset<ArithmeticRightShiftAttribute> CreateArithmeticRightShiftAttribute( flatbuffers::FlatBufferBuilder &_fbb, - const char *identifier = nullptr) { - auto identifier__ = identifier ? _fbb.CreateString(identifier) : 0; - return tosa::CreateCustomAttribute( - _fbb, - identifier__); + bool round = false) { + ArithmeticRightShiftAttributeBuilder builder_(_fbb); + builder_.add_round(round); + return builder_.Finish(); } struct CondIfAttribute FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { @@ -2066,8 +2105,11 @@ struct TosaOperator FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { const RescaleAttribute *attribute_as_RescaleAttribute() const { return attribute_type() == Attribute_RescaleAttribute ? static_cast<const RescaleAttribute *>(attribute()) : nullptr; } - const CustomAttribute *attribute_as_CustomAttribute() const { - return attribute_type() == Attribute_CustomAttribute ? static_cast<const CustomAttribute *>(attribute()) : nullptr; + const MulAttribute *attribute_as_MulAttribute() const { + return attribute_type() == Attribute_MulAttribute ? static_cast<const MulAttribute *>(attribute()) : nullptr; + } + const ArithmeticRightShiftAttribute *attribute_as_ArithmeticRightShiftAttribute() const { + return attribute_type() == Attribute_ArithmeticRightShiftAttribute ? static_cast<const ArithmeticRightShiftAttribute *>(attribute()) : nullptr; } const CondIfAttribute *attribute_as_CondIfAttribute() const { return attribute_type() == Attribute_CondIfAttribute ? static_cast<const CondIfAttribute *>(attribute()) : nullptr; @@ -2163,8 +2205,12 @@ template<> inline const RescaleAttribute *TosaOperator::attribute_as<RescaleAttr return attribute_as_RescaleAttribute(); } -template<> inline const CustomAttribute *TosaOperator::attribute_as<CustomAttribute>() const { - return attribute_as_CustomAttribute(); +template<> inline const MulAttribute *TosaOperator::attribute_as<MulAttribute>() const { + return attribute_as_MulAttribute(); +} + +template<> inline const ArithmeticRightShiftAttribute *TosaOperator::attribute_as<ArithmeticRightShiftAttribute>() const { + return attribute_as_ArithmeticRightShiftAttribute(); } template<> inline const CondIfAttribute *TosaOperator::attribute_as<CondIfAttribute>() const { @@ -2492,8 +2538,12 @@ inline bool VerifyAttribute(flatbuffers::Verifier &verifier, const void *obj, At auto ptr = reinterpret_cast<const RescaleAttribute *>(obj); return verifier.VerifyTable(ptr); } - case Attribute_CustomAttribute: { - auto ptr = reinterpret_cast<const CustomAttribute *>(obj); + case Attribute_MulAttribute: { + auto ptr = reinterpret_cast<const MulAttribute *>(obj); + return verifier.VerifyTable(ptr); + } + case Attribute_ArithmeticRightShiftAttribute: { + auto ptr = reinterpret_cast<const ArithmeticRightShiftAttribute *>(obj); return verifier.VerifyTable(ptr); } case Attribute_CondIfAttribute: { |