diff options
-rw-r--r-- | include/attribute.def | 29 | ||||
-rw-r--r-- | include/quant_info.def | 43 | ||||
-rw-r--r-- | include/quant_info.h | 163 | ||||
-rw-r--r-- | include/tosa_generated.h | 486 | ||||
-rw-r--r-- | include/tosa_serialization_handler.h | 22 | ||||
-rw-r--r-- | python/serializer/tosa_serializer.py | 192 | ||||
-rw-r--r-- | python/tosa/Attribute.py | 3 | ||||
-rw-r--r-- | python/tosa/ConvAttribute.py | 18 | ||||
-rw-r--r-- | python/tosa/FullyConnectedAttribute.py (renamed from python/tosa/ConvQuantInfo.py) | 22 | ||||
-rw-r--r-- | python/tosa/MatMulAttribute.py (renamed from python/tosa/MatMulQuantInfo.py) | 22 | ||||
-rw-r--r-- | python/tosa/NegateAttribute.py (renamed from python/tosa/UnaryQuantInfo.py) | 24 | ||||
-rw-r--r-- | python/tosa/PadQuantInfo.py | 36 | ||||
-rw-r--r-- | python/tosa/PoolAttribute.py | 18 | ||||
-rw-r--r-- | python/tosa/QuantInfo.py | 11 | ||||
-rw-r--r-- | python/tosa/TosaOperator.py | 21 | ||||
-rw-r--r-- | python/tosa/TransposeConvAttribute.py | 18 | ||||
-rw-r--r-- | python/tosa/Version.py | 4 | ||||
-rw-r--r-- | schema/tosa.fbs | 37 | ||||
-rw-r--r-- | src/tosa_serialization_handler.cpp | 139 |
19 files changed, 466 insertions, 842 deletions
diff --git a/include/attribute.def b/include/attribute.def index 8ac5d8d..ea91869 100644 --- a/include/attribute.def +++ b/include/attribute.def @@ -26,20 +26,26 @@ ...: variadic variables for more arguments, depending on NUM_ARGS_IN_ATTRIBUTES */ -DEF_ATTRIBUTE(Pool, 3, +DEF_ATTRIBUTE(Pool, 5, int32_t, V, pad, int32_t, V, kernel, - int32_t, V, stride) + int32_t, V, stride, + int32_t, S, input_zp, + int32_t, S, output_zp) -DEF_ATTRIBUTE(Conv, 3, +DEF_ATTRIBUTE(Conv, 5, int32_t, V, pad, int32_t, V, stride, - int32_t, V, dilation) + int32_t, V, dilation, + int32_t, S, input_zp, + int32_t, S, weight_zp) -DEF_ATTRIBUTE(TransposeConv, 3, +DEF_ATTRIBUTE(TransposeConv, 5, int32_t, V, out_pad, int32_t, V, stride, - int32_t, V, output_shape) + int32_t, V, output_shape, + int32_t, S, input_zp, + int32_t, S, weight_zp) DEF_ATTRIBUTE(Pad, 3, int32_t, V, padding, @@ -103,3 +109,14 @@ DEF_ATTRIBUTE(Transpose, 1, DEF_ATTRIBUTE(Table, 1, int16_t, V, table) +DEF_ATTRIBUTE(MatMul, 2, + int32_t, S, a_zp, + int32_t, S, b_zp) + +DEF_ATTRIBUTE(FullyConnected, 2, + int32_t, S, input_zp, + int32_t, S, weight_zp) + +DEF_ATTRIBUTE(Negate, 2, + int32_t, S, input1_zp, + int32_t, S, output_zp) diff --git a/include/quant_info.def b/include/quant_info.def deleted file mode 100644 index 888c183..0000000 --- a/include/quant_info.def +++ /dev/null @@ -1,43 +0,0 @@ - -// Copyright (c) 2020-2021, ARM Limited. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -/* - Syntax: - DEF_QUANTIZATION_INFO(NAME, NUM_ARGS_IN_OPTIONS, ARG0_TYPE, ARG0_SCALAR_OR_VECTOR, ARGS0_NAME, ...) - - Description: - NAME: corresponding quantization info name, must match corresponding "table XXXQuantInfo" in tosa.fbs - NUM_ARGS_IN_QINFO: number of arguments in this quantization info - ARG0_TYPE: data type of arg0 - ARG0_SCALAR_OR_VECTOR: is arg0 a scalar (S) or a vector (V) - ARG0_NAME: name of arg0 - ...: variadic variables for more arguments, depending on NUM_ARGS_IN_QINFO -*/ - - -DEF_QUANTIZATION_INFO(Unary, 2, - int32_t, S, input_zp, - int32_t, S, output_zp) - -DEF_QUANTIZATION_INFO(Conv, 2, - int32_t, S, input_zp, - int32_t, S, weight_zp) - -DEF_QUANTIZATION_INFO(MatMul, 2, - int32_t, S, a_zp, - int32_t, S, b_zp) - -DEF_QUANTIZATION_INFO(Pad, 1, - int32_t, S, input_zp) diff --git a/include/quant_info.h b/include/quant_info.h deleted file mode 100644 index c7daeb2..0000000 --- a/include/quant_info.h +++ /dev/null @@ -1,163 +0,0 @@ - -// Copyright (c) 2020-2021, ARM Limited. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef _TOSA_SERIALIZATION_QUANT_INFO_H -#define _TOSA_SERIALIZATION_QUANT_INFO_H -#include "flatbuffers/idl.h" -#include "flatbuffers/util.h" -#include "tosa_generated.h" - -namespace tosa -{ - -class TosaQuantInfoBase -{ -public: - virtual ~TosaQuantInfoBase() - {} -}; - -class TosaNoneQuantInfo : public TosaQuantInfoBase -{ -public: - TosaNoneQuantInfo() - {} - TosaNoneQuantInfo(TosaNoneQuantInfo* p) - {} -}; - -#define DEF_ARGS_VER0_S(T, V) _##V = p->V(); -#define DEF_ARGS_VER0_V(T, V) _##V = std::vector<T>(p->V()->begin(), p->V()->end()); -#define DEF_ARGS_VER1_S(T, V) const T& V -#define DEF_ARGS_VER1_V(T, V) const std::vector<T>& V -#define DEF_ARGS_VER2_S(T, V) _##V = V; -#define DEF_ARGS_VER2_V(T, V) _##V = V; -#define DEF_ARGS_VER3_S(T, V) \ - T V() const \ - { \ - return _##V; \ - } -#define DEF_ARGS_VER3_V(T, V) \ - std::vector<T> V() const \ - { \ - return _##V; \ - } -#define DEF_ARGS_VER4_S(T, V) T _##V; -#define DEF_ARGS_VER4_V(T, V) std::vector<T> _##V; - -// another level of preprocessor indirection to handle ", " as function's input argument -#define DEF_ARGS_VER1_TRUE(T, F, V) DEF_ARGS_VER1_##F(T, V) -#define DEF_ARGS_VER1_FALSE(T, F, V) , DEF_ARGS_VER1_##F(T, V) - -#define DEF_ARGS_VER0(FIRST, T, F, V) DEF_ARGS_VER0_##F(T, V) -#define DEF_ARGS_VER1(FIRST, T, F, V) DEF_ARGS_VER1_##FIRST(T, F, V) -#define DEF_ARGS_VER2(FIRST, T, F, V) DEF_ARGS_VER2_##F(T, V) -#define DEF_ARGS_VER3(FIRST, T, F, V) DEF_ARGS_VER3_##F(T, V) -#define DEF_ARGS_VER4(FIRST, T, F, V) DEF_ARGS_VER4_##F(T, V) - -#define DEF_ARGS_1(VER, T0, F0, V0) DEF_ARGS_##VER(TRUE, T0, F0, V0) -#define DEF_ARGS_2(VER, T0, F0, V0, T1, F1, V1) DEF_ARGS_##VER(TRUE, T0, F0, V0) DEF_ARGS_##VER(FALSE, T1, F1, V1) -#define DEF_ARGS_3(VER, T0, F0, V0, T1, F1, V1, T2, F2, V2) \ - DEF_ARGS_##VER(TRUE, T0, F0, V0) DEF_ARGS_##VER(FALSE, T1, F1, V1) DEF_ARGS_##VER(FALSE, T2, F2, V2) -#define DEF_ARGS_4(VER, T0, F0, V0, T1, F1, V1, T2, F2, V2, T3, F3, V3) \ - DEF_ARGS_##VER(TRUE, T0, F0, V0) DEF_ARGS_##VER(FALSE, T1, F1, V1) DEF_ARGS_##VER(FALSE, T2, F2, V2) \ - DEF_ARGS_##VER(FALSE, T3, F3, V3) -#define DEF_ARGS_5(VER, T0, F0, V0, T1, F1, V1, T2, F2, V2, T3, F3, V3, T4, F4, V4) \ - DEF_ARGS_##VER(TRUE, T0, F0, V0) DEF_ARGS_##VER(FALSE, T1, F1, V1) DEF_ARGS_##VER(FALSE, T2, F2, V2) \ - DEF_ARGS_##VER(FALSE, T3, F3, V3) DEF_ARGS_##VER(FALSE, T4, F4, V4) -#define DEF_ARGS_6(VER, T0, F0, V0, T1, F1, V1, T2, F2, V2, T3, F3, V3, T4, F4, V4, T5, F5, V5) \ - DEF_ARGS_##VER(TRUE, T0, F0, V0) DEF_ARGS_##VER(FALSE, T1, F1, V1) DEF_ARGS_##VER(FALSE, T2, F2, V2) \ - DEF_ARGS_##VER(FALSE, T3, F3, V3) DEF_ARGS_##VER(FALSE, T4, F4, V4) DEF_ARGS_##VER(FALSE, T5, F5, V5) -#define DEF_ARGS_7(VER, T0, F0, V0, T1, F1, V1, T2, F2, V2, T3, F3, V3, T4, F4, V4, T5, F5, V5, T6, F6, V6) \ - DEF_ARGS_##VER(TRUE, T0, F0, V0) DEF_ARGS_##VER(FALSE, T1, F1, V1) DEF_ARGS_##VER(FALSE, T2, F2, V2) \ - DEF_ARGS_##VER(FALSE, T3, F3, V3) DEF_ARGS_##VER(FALSE, T4, F4, V4) DEF_ARGS_##VER(FALSE, T5, F5, V5) \ - DEF_ARGS_##VER(FALSE, T6, F6, V6) -#define DEF_ARGS_8(VER, T0, F0, V0, T1, F1, V1, T2, F2, V2, T3, F3, V3, T4, F4, V4, T5, F5, V5, T6, F6, V6, T7, F7, \ - V7) \ - DEF_ARGS_##VER(TRUE, T0, F0, V0) DEF_ARGS_##VER(FALSE, T1, F1, V1) DEF_ARGS_##VER(FALSE, T2, F2, V2) \ - DEF_ARGS_##VER(FALSE, T3, F3, V3) DEF_ARGS_##VER(FALSE, T4, F4, V4) DEF_ARGS_##VER(FALSE, T5, F5, V5) \ - DEF_ARGS_##VER(FALSE, T6, F6, V6) DEF_ARGS_##VER(FALSE, T7, F7, V7) -#define DEF_ARGS_9(VER, T0, F0, V0, T1, F1, V1, T2, F2, V2, T3, F3, V3, T4, F4, V4, T5, F5, V5, T6, F6, V6, T7, F7, \ - V7, T8, F8, V8) \ - DEF_ARGS_##VER(TRUE, T0, F0, V0) DEF_ARGS_##VER(FALSE, T1, F1, V1) DEF_ARGS_##VER(FALSE, T2, F2, V2) \ - DEF_ARGS_##VER(FALSE, T3, F3, V3) DEF_ARGS_##VER(FALSE, T4, F4, V4) DEF_ARGS_##VER(FALSE, T5, F5, V5) \ - DEF_ARGS_##VER(FALSE, T6, F6, V6) DEF_ARGS_##VER(FALSE, T7, F7, V7) DEF_ARGS_##VER(FALSE, T8, F8, V8) -#define DEF_ARGS_10(VER, T0, F0, V0, T1, F1, V1, T2, F2, V2, T3, F3, V3, T4, F4, V4, T5, F5, V5, T6, F6, V6, T7, F7, \ - V7, T8, F8, V8, T9, F9, V9) \ - DEF_ARGS_##VER(TRUE, T0, F0, V0) DEF_ARGS_##VER(FALSE, T1, F1, V1) DEF_ARGS_##VER(FALSE, T2, F2, V2) \ - DEF_ARGS_##VER(FALSE, T3, F3, V3) DEF_ARGS_##VER(FALSE, T4, F4, V4) DEF_ARGS_##VER(FALSE, T5, F5, V5) \ - DEF_ARGS_##VER(FALSE, T6, F6, V6) DEF_ARGS_##VER(FALSE, T7, F7, V7) DEF_ARGS_##VER(FALSE, T8, F8, V8) \ - DEF_ARGS_##VER(FALSE, T9, F9, V9) - -#define DEF_QUANTIZATION_INFO(NAME, NUM_ARGS, ...) \ - class Tosa##NAME##QuantInfo : public TosaQuantInfoBase \ - { \ - public: \ - Tosa##NAME##QuantInfo(const TosaQuantInfoBase* qinfo) \ - { \ - const Tosa##NAME##QuantInfo* p = static_cast<const Tosa##NAME##QuantInfo*>(qinfo); \ - *this = *p; \ - } \ - Tosa##NAME##QuantInfo(const Tosa##NAME##QuantInfo* p) \ - { \ - *this = *p; \ - } \ - Tosa##NAME##QuantInfo(const void* qinfo) \ - { \ - const NAME##QuantInfo* p = static_cast<const NAME##QuantInfo*>(qinfo); \ - DEF_ARGS_##NUM_ARGS(VER0, __VA_ARGS__) \ - } \ - Tosa##NAME##QuantInfo(DEF_ARGS_##NUM_ARGS(VER1, __VA_ARGS__)) \ - { \ - DEF_ARGS_##NUM_ARGS(VER2, __VA_ARGS__) \ - } \ - virtual ~Tosa##NAME##QuantInfo() \ - {} \ - DEF_ARGS_##NUM_ARGS(VER3, __VA_ARGS__) private : DEF_ARGS_##NUM_ARGS(VER4, __VA_ARGS__) \ - }; - -#include "quant_info.def" -#undef DEF_QUANTIZATION_INFO -#undef DEF_ARGS_1 -#undef DEF_ARGS_2 -#undef DEF_ARGS_3 -#undef DEF_ARGS_4 -#undef DEF_ARGS_5 -#undef DEF_ARGS_6 -#undef DEF_ARGS_7 -#undef DEF_ARGS_8 -#undef DEF_ARGS_9 -#undef DEF_ARGS_10 -#undef DEF_ARGS_VER0 -#undef DEF_ARGS_VER1 -#undef DEF_ARGS_VER2 -#undef DEF_ARGS_VER3 -#undef DEF_ARGS_VER4 -#undef DEF_ARGS_VER1_TRUE -#undef DEF_ARGS_VER1_FALSE -#undef DEF_ARGS_VER0_S -#undef DEF_ARGS_VER0_V -#undef DEF_ARGS_VER1_S -#undef DEF_ARGS_VER1_V -#undef DEF_ARGS_VER2_S -#undef DEF_ARGS_VER2_V -#undef DEF_ARGS_VER3_S -#undef DEF_ARGS_VER3_V -#undef DEF_ARGS_VER4_S -#undef DEF_ARGS_VER4_V - -} // namespace tosa - -#endif diff --git a/include/tosa_generated.h b/include/tosa_generated.h index 58fe4a7..bd56e4e 100644 --- a/include/tosa_generated.h +++ b/include/tosa_generated.h @@ -59,17 +59,14 @@ struct TransposeAttributeBuilder; struct TableAttribute; struct TableAttributeBuilder; -struct UnaryQuantInfo; -struct UnaryQuantInfoBuilder; +struct MatMulAttribute; +struct MatMulAttributeBuilder; -struct ConvQuantInfo; -struct ConvQuantInfoBuilder; +struct FullyConnectedAttribute; +struct FullyConnectedAttributeBuilder; -struct MatMulQuantInfo; -struct MatMulQuantInfoBuilder; - -struct PadQuantInfo; -struct PadQuantInfoBuilder; +struct NegateAttribute; +struct NegateAttributeBuilder; struct Version; struct VersionBuilder; @@ -423,11 +420,14 @@ enum Attribute { Attribute_WhileLoopAttribute = 15, Attribute_TransposeAttribute = 16, Attribute_TableAttribute = 17, + Attribute_MatMulAttribute = 18, + Attribute_FullyConnectedAttribute = 19, + Attribute_NegateAttribute = 20, Attribute_MIN = Attribute_NONE, - Attribute_MAX = Attribute_TableAttribute + Attribute_MAX = Attribute_NegateAttribute }; -inline const Attribute (&EnumValuesAttribute())[18] { +inline const Attribute (&EnumValuesAttribute())[21] { static const Attribute values[] = { Attribute_NONE, Attribute_PoolAttribute, @@ -446,13 +446,16 @@ inline const Attribute (&EnumValuesAttribute())[18] { Attribute_CondIfAttribute, Attribute_WhileLoopAttribute, Attribute_TransposeAttribute, - Attribute_TableAttribute + Attribute_TableAttribute, + Attribute_MatMulAttribute, + Attribute_FullyConnectedAttribute, + Attribute_NegateAttribute }; return values; } inline const char * const *EnumNamesAttribute() { - static const char * const names[19] = { + static const char * const names[22] = { "NONE", "PoolAttribute", "ConvAttribute", @@ -471,13 +474,16 @@ inline const char * const *EnumNamesAttribute() { "WhileLoopAttribute", "TransposeAttribute", "TableAttribute", + "MatMulAttribute", + "FullyConnectedAttribute", + "NegateAttribute", nullptr }; return names; } inline const char *EnumNameAttribute(Attribute e) { - if (flatbuffers::IsOutRange(e, Attribute_NONE, Attribute_TableAttribute)) return ""; + if (flatbuffers::IsOutRange(e, Attribute_NONE, Attribute_NegateAttribute)) return ""; const size_t index = static_cast<size_t>(e); return EnumNamesAttribute()[index]; } @@ -554,77 +560,29 @@ template<> struct AttributeTraits<tosa::TableAttribute> { static const Attribute enum_value = Attribute_TableAttribute; }; -bool VerifyAttribute(flatbuffers::Verifier &verifier, const void *obj, Attribute type); -bool VerifyAttributeVector(flatbuffers::Verifier &verifier, const flatbuffers::Vector<flatbuffers::Offset<void>> *values, const flatbuffers::Vector<uint8_t> *types); - -enum QuantInfo { - QuantInfo_NONE = 0, - QuantInfo_UnaryQuantInfo = 1, - QuantInfo_ConvQuantInfo = 2, - QuantInfo_MatMulQuantInfo = 3, - QuantInfo_PadQuantInfo = 4, - QuantInfo_MIN = QuantInfo_NONE, - QuantInfo_MAX = QuantInfo_PadQuantInfo -}; - -inline const QuantInfo (&EnumValuesQuantInfo())[5] { - static const QuantInfo values[] = { - QuantInfo_NONE, - QuantInfo_UnaryQuantInfo, - QuantInfo_ConvQuantInfo, - QuantInfo_MatMulQuantInfo, - QuantInfo_PadQuantInfo - }; - return values; -} - -inline const char * const *EnumNamesQuantInfo() { - static const char * const names[6] = { - "NONE", - "UnaryQuantInfo", - "ConvQuantInfo", - "MatMulQuantInfo", - "PadQuantInfo", - nullptr - }; - return names; -} - -inline const char *EnumNameQuantInfo(QuantInfo e) { - if (flatbuffers::IsOutRange(e, QuantInfo_NONE, QuantInfo_PadQuantInfo)) return ""; - const size_t index = static_cast<size_t>(e); - return EnumNamesQuantInfo()[index]; -} - -template<typename T> struct QuantInfoTraits { - static const QuantInfo enum_value = QuantInfo_NONE; -}; - -template<> struct QuantInfoTraits<tosa::UnaryQuantInfo> { - static const QuantInfo enum_value = QuantInfo_UnaryQuantInfo; +template<> struct AttributeTraits<tosa::MatMulAttribute> { + static const Attribute enum_value = Attribute_MatMulAttribute; }; -template<> struct QuantInfoTraits<tosa::ConvQuantInfo> { - static const QuantInfo enum_value = QuantInfo_ConvQuantInfo; +template<> struct AttributeTraits<tosa::FullyConnectedAttribute> { + static const Attribute enum_value = Attribute_FullyConnectedAttribute; }; -template<> struct QuantInfoTraits<tosa::MatMulQuantInfo> { - static const QuantInfo enum_value = QuantInfo_MatMulQuantInfo; +template<> struct AttributeTraits<tosa::NegateAttribute> { + static const Attribute enum_value = Attribute_NegateAttribute; }; -template<> struct QuantInfoTraits<tosa::PadQuantInfo> { - static const QuantInfo enum_value = QuantInfo_PadQuantInfo; -}; - -bool VerifyQuantInfo(flatbuffers::Verifier &verifier, const void *obj, QuantInfo type); -bool VerifyQuantInfoVector(flatbuffers::Verifier &verifier, const flatbuffers::Vector<flatbuffers::Offset<void>> *values, const flatbuffers::Vector<uint8_t> *types); +bool VerifyAttribute(flatbuffers::Verifier &verifier, const void *obj, Attribute type); +bool VerifyAttributeVector(flatbuffers::Verifier &verifier, const flatbuffers::Vector<flatbuffers::Offset<void>> *values, const flatbuffers::Vector<uint8_t> *types); struct PoolAttribute FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { typedef PoolAttributeBuilder Builder; enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { VT_PAD = 4, VT_KERNEL = 6, - VT_STRIDE = 8 + VT_STRIDE = 8, + VT_INPUT_ZP = 10, + VT_OUTPUT_ZP = 12 }; const flatbuffers::Vector<int32_t> *pad() const { return GetPointer<const flatbuffers::Vector<int32_t> *>(VT_PAD); @@ -635,6 +593,12 @@ struct PoolAttribute FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { const flatbuffers::Vector<int32_t> *stride() const { return GetPointer<const flatbuffers::Vector<int32_t> *>(VT_STRIDE); } + int32_t input_zp() const { + return GetField<int32_t>(VT_INPUT_ZP, 0); + } + int32_t output_zp() const { + return GetField<int32_t>(VT_OUTPUT_ZP, 0); + } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && VerifyOffset(verifier, VT_PAD) && @@ -643,6 +607,8 @@ struct PoolAttribute FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { verifier.VerifyVector(kernel()) && VerifyOffset(verifier, VT_STRIDE) && verifier.VerifyVector(stride()) && + VerifyField<int32_t>(verifier, VT_INPUT_ZP) && + VerifyField<int32_t>(verifier, VT_OUTPUT_ZP) && verifier.EndTable(); } }; @@ -660,6 +626,12 @@ struct PoolAttributeBuilder { void add_stride(flatbuffers::Offset<flatbuffers::Vector<int32_t>> stride) { fbb_.AddOffset(PoolAttribute::VT_STRIDE, stride); } + void add_input_zp(int32_t input_zp) { + fbb_.AddElement<int32_t>(PoolAttribute::VT_INPUT_ZP, input_zp, 0); + } + void add_output_zp(int32_t output_zp) { + fbb_.AddElement<int32_t>(PoolAttribute::VT_OUTPUT_ZP, output_zp, 0); + } explicit PoolAttributeBuilder(flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) { start_ = fbb_.StartTable(); @@ -676,8 +648,12 @@ inline flatbuffers::Offset<PoolAttribute> CreatePoolAttribute( flatbuffers::FlatBufferBuilder &_fbb, flatbuffers::Offset<flatbuffers::Vector<int32_t>> pad = 0, flatbuffers::Offset<flatbuffers::Vector<int32_t>> kernel = 0, - flatbuffers::Offset<flatbuffers::Vector<int32_t>> stride = 0) { + flatbuffers::Offset<flatbuffers::Vector<int32_t>> stride = 0, + int32_t input_zp = 0, + int32_t output_zp = 0) { PoolAttributeBuilder builder_(_fbb); + builder_.add_output_zp(output_zp); + builder_.add_input_zp(input_zp); builder_.add_stride(stride); builder_.add_kernel(kernel); builder_.add_pad(pad); @@ -688,7 +664,9 @@ inline flatbuffers::Offset<PoolAttribute> CreatePoolAttributeDirect( flatbuffers::FlatBufferBuilder &_fbb, const std::vector<int32_t> *pad = nullptr, const std::vector<int32_t> *kernel = nullptr, - const std::vector<int32_t> *stride = nullptr) { + const std::vector<int32_t> *stride = nullptr, + int32_t input_zp = 0, + int32_t output_zp = 0) { auto pad__ = pad ? _fbb.CreateVector<int32_t>(*pad) : 0; auto kernel__ = kernel ? _fbb.CreateVector<int32_t>(*kernel) : 0; auto stride__ = stride ? _fbb.CreateVector<int32_t>(*stride) : 0; @@ -696,7 +674,9 @@ inline flatbuffers::Offset<PoolAttribute> CreatePoolAttributeDirect( _fbb, pad__, kernel__, - stride__); + stride__, + input_zp, + output_zp); } struct ConvAttribute FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { @@ -704,7 +684,9 @@ struct ConvAttribute FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { VT_PAD = 4, VT_STRIDE = 6, - VT_DILATION = 8 + VT_DILATION = 8, + VT_INPUT_ZP = 10, + VT_WEIGHT_ZP = 12 }; const flatbuffers::Vector<int32_t> *pad() const { return GetPointer<const flatbuffers::Vector<int32_t> *>(VT_PAD); @@ -715,6 +697,12 @@ struct ConvAttribute FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { const flatbuffers::Vector<int32_t> *dilation() const { return GetPointer<const flatbuffers::Vector<int32_t> *>(VT_DILATION); } + int32_t input_zp() const { + return GetField<int32_t>(VT_INPUT_ZP, 0); + } + int32_t weight_zp() const { + return GetField<int32_t>(VT_WEIGHT_ZP, 0); + } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && VerifyOffset(verifier, VT_PAD) && @@ -723,6 +711,8 @@ struct ConvAttribute FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { verifier.VerifyVector(stride()) && VerifyOffset(verifier, VT_DILATION) && verifier.VerifyVector(dilation()) && + VerifyField<int32_t>(verifier, VT_INPUT_ZP) && + VerifyField<int32_t>(verifier, VT_WEIGHT_ZP) && verifier.EndTable(); } }; @@ -740,6 +730,12 @@ struct ConvAttributeBuilder { void add_dilation(flatbuffers::Offset<flatbuffers::Vector<int32_t>> dilation) { fbb_.AddOffset(ConvAttribute::VT_DILATION, dilation); } + void add_input_zp(int32_t input_zp) { + fbb_.AddElement<int32_t>(ConvAttribute::VT_INPUT_ZP, input_zp, 0); + } + void add_weight_zp(int32_t weight_zp) { + fbb_.AddElement<int32_t>(ConvAttribute::VT_WEIGHT_ZP, weight_zp, 0); + } explicit ConvAttributeBuilder(flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) { start_ = fbb_.StartTable(); @@ -756,8 +752,12 @@ inline flatbuffers::Offset<ConvAttribute> CreateConvAttribute( flatbuffers::FlatBufferBuilder &_fbb, flatbuffers::Offset<flatbuffers::Vector<int32_t>> pad = 0, flatbuffers::Offset<flatbuffers::Vector<int32_t>> stride = 0, - flatbuffers::Offset<flatbuffers::Vector<int32_t>> dilation = 0) { + flatbuffers::Offset<flatbuffers::Vector<int32_t>> dilation = 0, + int32_t input_zp = 0, + int32_t weight_zp = 0) { ConvAttributeBuilder builder_(_fbb); + builder_.add_weight_zp(weight_zp); + builder_.add_input_zp(input_zp); builder_.add_dilation(dilation); builder_.add_stride(stride); builder_.add_pad(pad); @@ -768,7 +768,9 @@ inline flatbuffers::Offset<ConvAttribute> CreateConvAttributeDirect( flatbuffers::FlatBufferBuilder &_fbb, const std::vector<int32_t> *pad = nullptr, const std::vector<int32_t> *stride = nullptr, - const std::vector<int32_t> *dilation = nullptr) { + const std::vector<int32_t> *dilation = nullptr, + int32_t input_zp = 0, + int32_t weight_zp = 0) { auto pad__ = pad ? _fbb.CreateVector<int32_t>(*pad) : 0; auto stride__ = stride ? _fbb.CreateVector<int32_t>(*stride) : 0; auto dilation__ = dilation ? _fbb.CreateVector<int32_t>(*dilation) : 0; @@ -776,7 +778,9 @@ inline flatbuffers::Offset<ConvAttribute> CreateConvAttributeDirect( _fbb, pad__, stride__, - dilation__); + dilation__, + input_zp, + weight_zp); } struct TransposeConvAttribute FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { @@ -784,7 +788,9 @@ struct TransposeConvAttribute FLATBUFFERS_FINAL_CLASS : private flatbuffers::Tab enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { VT_OUT_PAD = 4, VT_STRIDE = 6, - VT_OUTPUT_SHAPE = 8 + VT_OUTPUT_SHAPE = 8, + VT_INPUT_ZP = 10, + VT_WEIGHT_ZP = 12 }; const flatbuffers::Vector<int32_t> *out_pad() const { return GetPointer<const flatbuffers::Vector<int32_t> *>(VT_OUT_PAD); @@ -795,6 +801,12 @@ struct TransposeConvAttribute FLATBUFFERS_FINAL_CLASS : private flatbuffers::Tab const flatbuffers::Vector<int32_t> *output_shape() const { return GetPointer<const flatbuffers::Vector<int32_t> *>(VT_OUTPUT_SHAPE); } + int32_t input_zp() const { + return GetField<int32_t>(VT_INPUT_ZP, 0); + } + int32_t weight_zp() const { + return GetField<int32_t>(VT_WEIGHT_ZP, 0); + } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && VerifyOffset(verifier, VT_OUT_PAD) && @@ -803,6 +815,8 @@ struct TransposeConvAttribute FLATBUFFERS_FINAL_CLASS : private flatbuffers::Tab verifier.VerifyVector(stride()) && VerifyOffset(verifier, VT_OUTPUT_SHAPE) && verifier.VerifyVector(output_shape()) && + VerifyField<int32_t>(verifier, VT_INPUT_ZP) && + VerifyField<int32_t>(verifier, VT_WEIGHT_ZP) && verifier.EndTable(); } }; @@ -820,6 +834,12 @@ struct TransposeConvAttributeBuilder { void add_output_shape(flatbuffers::Offset<flatbuffers::Vector<int32_t>> output_shape) { fbb_.AddOffset(TransposeConvAttribute::VT_OUTPUT_SHAPE, output_shape); } + void add_input_zp(int32_t input_zp) { + fbb_.AddElement<int32_t>(TransposeConvAttribute::VT_INPUT_ZP, input_zp, 0); + } + void add_weight_zp(int32_t weight_zp) { + fbb_.AddElement<int32_t>(TransposeConvAttribute::VT_WEIGHT_ZP, weight_zp, 0); + } explicit TransposeConvAttributeBuilder(flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) { start_ = fbb_.StartTable(); @@ -836,8 +856,12 @@ inline flatbuffers::Offset<TransposeConvAttribute> CreateTransposeConvAttribute( flatbuffers::FlatBufferBuilder &_fbb, flatbuffers::Offset<flatbuffers::Vector<int32_t>> out_pad = 0, flatbuffers::Offset<flatbuffers::Vector<int32_t>> stride = 0, - flatbuffers::Offset<flatbuffers::Vector<int32_t>> output_shape = 0) { + flatbuffers::Offset<flatbuffers::Vector<int32_t>> output_shape = 0, + int32_t input_zp = 0, + int32_t weight_zp = 0) { TransposeConvAttributeBuilder builder_(_fbb); + builder_.add_weight_zp(weight_zp); + builder_.add_input_zp(input_zp); builder_.add_output_shape(output_shape); builder_.add_stride(stride); builder_.add_out_pad(out_pad); @@ -848,7 +872,9 @@ inline flatbuffers::Offset<TransposeConvAttribute> CreateTransposeConvAttributeD flatbuffers::FlatBufferBuilder &_fbb, const std::vector<int32_t> *out_pad = nullptr, const std::vector<int32_t> *stride = nullptr, - const std::vector<int32_t> *output_shape = nullptr) { + const std::vector<int32_t> *output_shape = nullptr, + int32_t input_zp = 0, + int32_t weight_zp = 0) { auto out_pad__ = out_pad ? _fbb.CreateVector<int32_t>(*out_pad) : 0; auto stride__ = stride ? _fbb.CreateVector<int32_t>(*stride) : 0; auto output_shape__ = output_shape ? _fbb.CreateVector<int32_t>(*output_shape) : 0; @@ -856,7 +882,9 @@ inline flatbuffers::Offset<TransposeConvAttribute> CreateTransposeConvAttributeD _fbb, out_pad__, stride__, - output_shape__); + output_shape__, + input_zp, + weight_zp); } struct PadAttribute FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { @@ -1797,60 +1825,60 @@ inline flatbuffers::Offset<TableAttribute> CreateTableAttributeDirect( table__); } -struct UnaryQuantInfo FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { - typedef UnaryQuantInfoBuilder Builder; +struct MatMulAttribute FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef MatMulAttributeBuilder Builder; enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { - VT_INPUT_ZP = 4, - VT_OUTPUT_ZP = 6 + VT_A_ZP = 4, + VT_B_ZP = 6 }; - int32_t input_zp() const { - return GetField<int32_t>(VT_INPUT_ZP, 0); + int32_t a_zp() const { + return GetField<int32_t>(VT_A_ZP, 0); } - int32_t output_zp() const { - return GetField<int32_t>(VT_OUTPUT_ZP, 0); + int32_t b_zp() const { + return GetField<int32_t>(VT_B_ZP, 0); } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && - VerifyField<int32_t>(verifier, VT_INPUT_ZP) && - VerifyField<int32_t>(verifier, VT_OUTPUT_ZP) && + VerifyField<int32_t>(verifier, VT_A_ZP) && + VerifyField<int32_t>(verifier, VT_B_ZP) && verifier.EndTable(); } }; -struct UnaryQuantInfoBuilder { - typedef UnaryQuantInfo Table; +struct MatMulAttributeBuilder { + typedef MatMulAttribute Table; flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::uoffset_t start_; - void add_input_zp(int32_t input_zp) { - fbb_.AddElement<int32_t>(UnaryQuantInfo::VT_INPUT_ZP, input_zp, 0); + void add_a_zp(int32_t a_zp) { + fbb_.AddElement<int32_t>(MatMulAttribute::VT_A_ZP, a_zp, 0); } - void add_output_zp(int32_t output_zp) { - fbb_.AddElement<int32_t>(UnaryQuantInfo::VT_OUTPUT_ZP, output_zp, 0); + void add_b_zp(int32_t b_zp) { + fbb_.AddElement<int32_t>(MatMulAttribute::VT_B_ZP, b_zp, 0); } - explicit UnaryQuantInfoBuilder(flatbuffers::FlatBufferBuilder &_fbb) + explicit MatMulAttributeBuilder(flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) { start_ = fbb_.StartTable(); } - UnaryQuantInfoBuilder &operator=(const UnaryQuantInfoBuilder &); - flatbuffers::Offset<UnaryQuantInfo> Finish() { + MatMulAttributeBuilder &operator=(const MatMulAttributeBuilder &); + flatbuffers::Offset<MatMulAttribute> Finish() { const auto end = fbb_.EndTable(start_); - auto o = flatbuffers::Offset<UnaryQuantInfo>(end); + auto o = flatbuffers::Offset<MatMulAttribute>(end); return o; } }; -inline flatbuffers::Offset<UnaryQuantInfo> CreateUnaryQuantInfo( +inline flatbuffers::Offset<MatMulAttribute> CreateMatMulAttribute( flatbuffers::FlatBufferBuilder &_fbb, - int32_t input_zp = 0, - int32_t output_zp = 0) { - UnaryQuantInfoBuilder builder_(_fbb); - builder_.add_output_zp(output_zp); - builder_.add_input_zp(input_zp); + int32_t a_zp = 0, + int32_t b_zp = 0) { + MatMulAttributeBuilder builder_(_fbb); + builder_.add_b_zp(b_zp); + builder_.add_a_zp(a_zp); return builder_.Finish(); } -struct ConvQuantInfo FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { - typedef ConvQuantInfoBuilder Builder; +struct FullyConnectedAttribute FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef FullyConnectedAttributeBuilder Builder; enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { VT_INPUT_ZP = 4, VT_WEIGHT_ZP = 6 @@ -1869,129 +1897,87 @@ struct ConvQuantInfo FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { } }; -struct ConvQuantInfoBuilder { - typedef ConvQuantInfo Table; +struct FullyConnectedAttributeBuilder { + typedef FullyConnectedAttribute Table; flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::uoffset_t start_; void add_input_zp(int32_t input_zp) { - fbb_.AddElement<int32_t>(ConvQuantInfo::VT_INPUT_ZP, input_zp, 0); + fbb_.AddElement<int32_t>(FullyConnectedAttribute::VT_INPUT_ZP, input_zp, 0); } void add_weight_zp(int32_t weight_zp) { - fbb_.AddElement<int32_t>(ConvQuantInfo::VT_WEIGHT_ZP, weight_zp, 0); + fbb_.AddElement<int32_t>(FullyConnectedAttribute::VT_WEIGHT_ZP, weight_zp, 0); } - explicit ConvQuantInfoBuilder(flatbuffers::FlatBufferBuilder &_fbb) + explicit FullyConnectedAttributeBuilder(flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) { start_ = fbb_.StartTable(); } - ConvQuantInfoBuilder &operator=(const ConvQuantInfoBuilder &); - flatbuffers::Offset<ConvQuantInfo> Finish() { + FullyConnectedAttributeBuilder &operator=(const FullyConnectedAttributeBuilder &); + flatbuffers::Offset<FullyConnectedAttribute> Finish() { const auto end = fbb_.EndTable(start_); - auto o = flatbuffers::Offset<ConvQuantInfo>(end); + auto o = flatbuffers::Offset<FullyConnectedAttribute>(end); return o; } }; -inline flatbuffers::Offset<ConvQuantInfo> CreateConvQuantInfo( +inline flatbuffers::Offset<FullyConnectedAttribute> CreateFullyConnectedAttribute( flatbuffers::FlatBufferBuilder &_fbb, int32_t input_zp = 0, int32_t weight_zp = 0) { - ConvQuantInfoBuilder builder_(_fbb); + FullyConnectedAttributeBuilder builder_(_fbb); builder_.add_weight_zp(weight_zp); builder_.add_input_zp(input_zp); return builder_.Finish(); } -struct MatMulQuantInfo FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { - typedef MatMulQuantInfoBuilder Builder; +struct NegateAttribute FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef NegateAttributeBuilder Builder; enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { - VT_A_ZP = 4, - VT_B_ZP = 6 + VT_INPUT1_ZP = 4, + VT_OUTPUT_ZP = 6 }; - int32_t a_zp() const { - return GetField<int32_t>(VT_A_ZP, 0); + int32_t input1_zp() const { + return GetField<int32_t>(VT_INPUT1_ZP, 0); } - int32_t b_zp() const { - return GetField<int32_t>(VT_B_ZP, 0); + int32_t output_zp() const { + return GetField<int32_t>(VT_OUTPUT_ZP, 0); } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && - VerifyField<int32_t>(verifier, VT_A_ZP) && - VerifyField<int32_t>(verifier, VT_B_ZP) && + VerifyField<int32_t>(verifier, VT_INPUT1_ZP) && + VerifyField<int32_t>(verifier, VT_OUTPUT_ZP) && verifier.EndTable(); } }; -struct MatMulQuantInfoBuilder { - typedef MatMulQuantInfo Table; +struct NegateAttributeBuilder { + typedef NegateAttribute Table; flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::uoffset_t start_; - void add_a_zp(int32_t a_zp) { - fbb_.AddElement<int32_t>(MatMulQuantInfo::VT_A_ZP, a_zp, 0); - } - void add_b_zp(int32_t b_zp) { - fbb_.AddElement<int32_t>(MatMulQuantInfo::VT_B_ZP, b_zp, 0); - } - explicit MatMulQuantInfoBuilder(flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { - start_ = fbb_.StartTable(); - } - MatMulQuantInfoBuilder &operator=(const MatMulQuantInfoBuilder &); - flatbuffers::Offset<MatMulQuantInfo> Finish() { - const auto end = fbb_.EndTable(start_); - auto o = flatbuffers::Offset<MatMulQuantInfo>(end); - return o; - } -}; - -inline flatbuffers::Offset<MatMulQuantInfo> CreateMatMulQuantInfo( - flatbuffers::FlatBufferBuilder &_fbb, - int32_t a_zp = 0, - int32_t b_zp = 0) { - MatMulQuantInfoBuilder builder_(_fbb); - builder_.add_b_zp(b_zp); - builder_.add_a_zp(a_zp); - return builder_.Finish(); -} - -struct PadQuantInfo FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { - typedef PadQuantInfoBuilder Builder; - enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { - VT_INPUT_ZP = 4 - }; - int32_t input_zp() const { - return GetField<int32_t>(VT_INPUT_ZP, 0); - } - bool Verify(flatbuffers::Verifier &verifier) const { - return VerifyTableStart(verifier) && - VerifyField<int32_t>(verifier, VT_INPUT_ZP) && - verifier.EndTable(); + void add_input1_zp(int32_t input1_zp) { + fbb_.AddElement<int32_t>(NegateAttribute::VT_INPUT1_ZP, input1_zp, 0); } -}; - -struct PadQuantInfoBuilder { - typedef PadQuantInfo Table; - flatbuffers::FlatBufferBuilder &fbb_; - flatbuffers::uoffset_t start_; - void add_input_zp(int32_t input_zp) { - fbb_.AddElement<int32_t>(PadQuantInfo::VT_INPUT_ZP, input_zp, 0); + void add_output_zp(int32_t output_zp) { + fbb_.AddElement<int32_t>(NegateAttribute::VT_OUTPUT_ZP, output_zp, 0); } - explicit PadQuantInfoBuilder(flatbuffers::FlatBufferBuilder &_fbb) + explicit NegateAttributeBuilder(flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) { start_ = fbb_.StartTable(); } - PadQuantInfoBuilder &operator=(const PadQuantInfoBuilder &); - flatbuffers::Offset<PadQuantInfo> Finish() { + NegateAttributeBuilder &operator=(const NegateAttributeBuilder &); + flatbuffers::Offset<NegateAttribute> Finish() { const auto end = fbb_.EndTable(start_); - auto o = flatbuffers::Offset<PadQuantInfo>(end); + auto o = flatbuffers::Offset<NegateAttribute>(end); return o; } }; -inline flatbuffers::Offset<PadQuantInfo> CreatePadQuantInfo( +inline flatbuffers::Offset<NegateAttribute> CreateNegateAttribute( flatbuffers::FlatBufferBuilder &_fbb, - int32_t input_zp = 0) { - PadQuantInfoBuilder builder_(_fbb); - builder_.add_input_zp(input_zp); + int32_t input1_zp = 0, + int32_t output_zp = 0) { + NegateAttributeBuilder builder_(_fbb); + builder_.add_output_zp(output_zp); + builder_.add_input1_zp(input1_zp); return builder_.Finish(); } @@ -2007,7 +1993,7 @@ struct Version FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { return GetField<int32_t>(VT__MAJOR, 0); } int32_t _minor() const { - return GetField<int32_t>(VT__MINOR, 25); + return GetField<int32_t>(VT__MINOR, 30); } int32_t _patch() const { return GetField<int32_t>(VT__PATCH, 0); @@ -2033,7 +2019,7 @@ struct VersionBuilder { fbb_.AddElement<int32_t>(Version::VT__MAJOR, _major, 0); } void add__minor(int32_t _minor) { - fbb_.AddElement<int32_t>(Version::VT__MINOR, _minor, 25); + fbb_.AddElement<int32_t>(Version::VT__MINOR, _minor, 30); } void add__patch(int32_t _patch) { fbb_.AddElement<int32_t>(Version::VT__PATCH, _patch, 0); @@ -2056,7 +2042,7 @@ struct VersionBuilder { inline flatbuffers::Offset<Version> CreateVersion( flatbuffers::FlatBufferBuilder &_fbb, int32_t _major = 0, - int32_t _minor = 25, + int32_t _minor = 30, int32_t _patch = 0, bool _draft = true) { VersionBuilder builder_(_fbb); @@ -2167,9 +2153,7 @@ struct TosaOperator FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { VT_ATTRIBUTE_TYPE = 6, VT_ATTRIBUTE = 8, VT_INPUTS = 10, - VT_OUTPUTS = 12, - VT_QUANT_INFO_TYPE = 14, - VT_QUANT_INFO = 16 + VT_OUTPUTS = 12 }; tosa::Op op() const { return static_cast<tosa::Op>(GetField<uint32_t>(VT_OP, 0)); @@ -2232,31 +2216,21 @@ struct TosaOperator FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { const tosa::TableAttribute *attribute_as_TableAttribute() const { return attribute_type() == tosa::Attribute_TableAttribute ? static_cast<const tosa::TableAttribute *>(attribute()) : nullptr; } + const tosa::MatMulAttribute *attribute_as_MatMulAttribute() const { + return attribute_type() == tosa::Attribute_MatMulAttribute ? static_cast<const tosa::MatMulAttribute *>(attribute()) : nullptr; + } + const tosa::FullyConnectedAttribute *attribute_as_FullyConnectedAttribute() const { + return attribute_type() == tosa::Attribute_FullyConnectedAttribute ? static_cast<const tosa::FullyConnectedAttribute *>(attribute()) : nullptr; + } + const tosa::NegateAttribute *attribute_as_NegateAttribute() const { + return attribute_type() == tosa::Attribute_NegateAttribute ? static_cast<const tosa::NegateAttribute *>(attribute()) : nullptr; + } const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *inputs() const { return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *>(VT_INPUTS); } const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *outputs() const { return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *>(VT_OUTPUTS); } - tosa::QuantInfo quant_info_type() const { - return static_cast<tosa::QuantInfo>(GetField<uint8_t>(VT_QUANT_INFO_TYPE, 0)); - } - const void *quant_info() const { - return GetPointer<const void *>(VT_QUANT_INFO); - } - template<typename T> const T *quant_info_as() const; - const tosa::UnaryQuantInfo *quant_info_as_UnaryQuantInfo() const { - return quant_info_type() == tosa::QuantInfo_UnaryQuantInfo ? static_cast<const tosa::UnaryQuantInfo *>(quant_info()) : nullptr; - } - const tosa::ConvQuantInfo *quant_info_as_ConvQuantInfo() const { - return quant_info_type() == tosa::QuantInfo_ConvQuantInfo ? static_cast<const tosa::ConvQuantInfo *>(quant_info()) : nullptr; - } - const tosa::MatMulQuantInfo *quant_info_as_MatMulQuantInfo() const { - return quant_info_type() == tosa::QuantInfo_MatMulQuantInfo ? static_cast<const tosa::MatMulQuantInfo *>(quant_info()) : nullptr; - } - const tosa::PadQuantInfo *quant_info_as_PadQuantInfo() const { - return quant_info_type() == tosa::QuantInfo_PadQuantInfo ? static_cast<const tosa::PadQuantInfo *>(quant_info()) : nullptr; - } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && VerifyField<uint32_t>(verifier, VT_OP) && @@ -2269,9 +2243,6 @@ struct TosaOperator FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { VerifyOffset(verifier, VT_OUTPUTS) && verifier.VerifyVector(outputs()) && verifier.VerifyVectorOfStrings(outputs()) && - VerifyField<uint8_t>(verifier, VT_QUANT_INFO_TYPE) && - VerifyOffset(verifier, VT_QUANT_INFO) && - VerifyQuantInfo(verifier, quant_info(), quant_info_type()) && verifier.EndTable(); } }; @@ -2344,20 +2315,16 @@ template<> inline const tosa::TableAttribute *TosaOperator::attribute_as<tosa::T return attribute_as_TableAttribute(); } -template<> inline const tosa::UnaryQuantInfo *TosaOperator::quant_info_as<tosa::UnaryQuantInfo>() const { - return quant_info_as_UnaryQuantInfo(); +template<> inline const tosa::MatMulAttribute *TosaOperator::attribute_as<tosa::MatMulAttribute>() const { + return attribute_as_MatMulAttribute(); } -template<> inline const tosa::ConvQuantInfo *TosaOperator::quant_info_as<tosa::ConvQuantInfo>() const { - return quant_info_as_ConvQuantInfo(); +template<> inline const tosa::FullyConnectedAttribute *TosaOperator::attribute_as<tosa::FullyConnectedAttribute>() const { + return attribute_as_FullyConnectedAttribute(); } -template<> inline const tosa::MatMulQuantInfo *TosaOperator::quant_info_as<tosa::MatMulQuantInfo>() const { - return quant_info_as_MatMulQuantInfo(); -} - -template<> inline const tosa::PadQuantInfo *TosaOperator::quant_info_as<tosa::PadQuantInfo>() const { - return quant_info_as_PadQuantInfo(); +template<> inline const tosa::NegateAttribute *TosaOperator::attribute_as<tosa::NegateAttribute>() const { + return attribute_as_NegateAttribute(); } struct TosaOperatorBuilder { @@ -2379,12 +2346,6 @@ struct TosaOperatorBuilder { void add_outputs(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> outputs) { fbb_.AddOffset(TosaOperator::VT_OUTPUTS, outputs); } - void add_quant_info_type(tosa::QuantInfo quant_info_type) { - fbb_.AddElement<uint8_t>(TosaOperator::VT_QUANT_INFO_TYPE, static_cast<uint8_t>(quant_info_type), 0); - } - void add_quant_info(flatbuffers::Offset<void> quant_info) { - fbb_.AddOffset(TosaOperator::VT_QUANT_INFO, quant_info); - } explicit TosaOperatorBuilder(flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) { start_ = fbb_.StartTable(); @@ -2403,16 +2364,12 @@ inline flatbuffers::Offset<TosaOperator> CreateTosaOperator( tosa::Attribute attribute_type = tosa::Attribute_NONE, flatbuffers::Offset<void> attribute = 0, flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> inputs = 0, - flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> outputs = 0, - tosa::QuantInfo quant_info_type = tosa::QuantInfo_NONE, - flatbuffers::Offset<void> quant_info = 0) { + flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> outputs = 0) { TosaOperatorBuilder builder_(_fbb); - builder_.add_quant_info(quant_info); builder_.add_outputs(outputs); builder_.add_inputs(inputs); builder_.add_attribute(attribute); builder_.add_op(op); - builder_.add_quant_info_type(quant_info_type); builder_.add_attribute_type(attribute_type); return builder_.Finish(); } @@ -2423,9 +2380,7 @@ inline flatbuffers::Offset<TosaOperator> CreateTosaOperatorDirect( tosa::Attribute attribute_type = tosa::Attribute_NONE, flatbuffers::Offset<void> attribute = 0, const std::vector<flatbuffers::Offset<flatbuffers::String>> *inputs = nullptr, - const std::vector<flatbuffers::Offset<flatbuffers::String>> *outputs = nullptr, - tosa::QuantInfo quant_info_type = tosa::QuantInfo_NONE, - flatbuffers::Offset<void> quant_info = 0) { + const std::vector<flatbuffers::Offset<flatbuffers::String>> *outputs = nullptr) { auto inputs__ = inputs ? _fbb.CreateVector<flatbuffers::Offset<flatbuffers::String>>(*inputs) : 0; auto outputs__ = outputs ? _fbb.CreateVector<flatbuffers::Offset<flatbuffers::String>>(*outputs) : 0; return tosa::CreateTosaOperator( @@ -2434,9 +2389,7 @@ inline flatbuffers::Offset<TosaOperator> CreateTosaOperatorDirect( attribute_type, attribute, inputs__, - outputs__, - quant_info_type, - quant_info); + outputs__); } struct TosaBasicBlock FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { @@ -2690,53 +2643,28 @@ inline bool VerifyAttribute(flatbuffers::Verifier &verifier, const void *obj, At auto ptr = reinterpret_cast<const tosa::TableAttribute *>(obj); return verifier.VerifyTable(ptr); } - default: return true; - } -} - -inline bool VerifyAttributeVector(flatbuffers::Verifier &verifier, const flatbuffers::Vector<flatbuffers::Offset<void>> *values, const flatbuffers::Vector<uint8_t> *types) { - if (!values || !types) return !values && !types; - if (values->size() != types->size()) return false; - for (flatbuffers::uoffset_t i = 0; i < values->size(); ++i) { - if (!VerifyAttribute( - verifier, values->Get(i), types->GetEnum<Attribute>(i))) { - return false; - } - } - return true; -} - -inline bool VerifyQuantInfo(flatbuffers::Verifier &verifier, const void *obj, QuantInfo type) { - switch (type) { - case QuantInfo_NONE: { - return true; - } - case QuantInfo_UnaryQuantInfo: { - auto ptr = reinterpret_cast<const tosa::UnaryQuantInfo *>(obj); - return verifier.VerifyTable(ptr); - } - case QuantInfo_ConvQuantInfo: { - auto ptr = reinterpret_cast<const tosa::ConvQuantInfo *>(obj); + case Attribute_MatMulAttribute: { + auto ptr = reinterpret_cast<const tosa::MatMulAttribute *>(obj); return verifier.VerifyTable(ptr); } - case QuantInfo_MatMulQuantInfo: { - auto ptr = reinterpret_cast<const tosa::MatMulQuantInfo *>(obj); + case Attribute_FullyConnectedAttribute: { + auto ptr = reinterpret_cast<const tosa::FullyConnectedAttribute *>(obj); return verifier.VerifyTable(ptr); } - case QuantInfo_PadQuantInfo: { - auto ptr = reinterpret_cast<const tosa::PadQuantInfo *>(obj); + case Attribute_NegateAttribute: { + auto ptr = reinterpret_cast<const tosa::NegateAttribute *>(obj); return verifier.VerifyTable(ptr); } default: return true; } } -inline bool VerifyQuantInfoVector(flatbuffers::Verifier &verifier, const flatbuffers::Vector<flatbuffers::Offset<void>> *values, const flatbuffers::Vector<uint8_t> *types) { +inline bool VerifyAttributeVector(flatbuffers::Verifier &verifier, const flatbuffers::Vector<flatbuffers::Offset<void>> *values, const flatbuffers::Vector<uint8_t> *types) { if (!values || !types) return !values && !types; if (values->size() != types->size()) return false; for (flatbuffers::uoffset_t i = 0; i < values->size(); ++i) { - if (!VerifyQuantInfo( - verifier, values->Get(i), types->GetEnum<QuantInfo>(i))) { + if (!VerifyAttribute( + verifier, values->Get(i), types->GetEnum<Attribute>(i))) { return false; } } diff --git a/include/tosa_serialization_handler.h b/include/tosa_serialization_handler.h index 33c8047..695c530 100644 --- a/include/tosa_serialization_handler.h +++ b/include/tosa_serialization_handler.h @@ -19,7 +19,6 @@ #include "flatbuffers/idl.h" #include "flatbuffers/util.h" #include "numpy_utils.h" -#include "quant_info.h" #include "tosa_generated.h" #include <cstdint> #include <memory> @@ -28,7 +27,7 @@ // Keep version number in sync with the version default value with schema/tosa.fbs #define TOSA_VERSION_MAJOR 0 -#define TOSA_VERSION_MINOR 25 +#define TOSA_VERSION_MINOR 30 #define TOSA_VERSION_PATCH 0 #define TOSA_VERSION_DRAFT true #define TENSOR_BUFFER_FORCE_ALIGNMENT 8 @@ -172,15 +171,11 @@ public: TosaSerializationOperator(Op op, Attribute attribute_type, const TosaAttributeBase* attribute, - QuantInfo qinfo_type, - const TosaQuantInfoBase* qinfo, const std::vector<std::string>& input_tensor_names, const std::vector<std::string>& output_tensor_names); TosaSerializationOperator(Op op, Attribute attribute_type, const TosaAttributeBase* attribute, - QuantInfo qinfo_type, - const TosaQuantInfoBase* qinfo, std::vector<std::string>&& input_tensor_names, std::vector<std::string>&& output_tensor_names); ~TosaSerializationOperator(); @@ -198,14 +193,6 @@ public: { return _attribute; } - QuantInfo GetQInfoType() const - { - return _qinfo_type; - } - TosaQuantInfoBase* GetQInfo() const - { - return _qinfo; - } std::vector<std::string>& GetInputTensorNames() { return _input_tensor_names; @@ -216,15 +203,10 @@ public: } private: - void InitializeAttributeQinfo(Attribute attribute_type, - const TosaAttributeBase* attribute, - QuantInfo qinfo_type, - const TosaQuantInfoBase* qinfo); + void InitializeAttribute(Attribute attribute_type, const TosaAttributeBase* attribute); Op _op; /* operator enum, see tosa_isa_generated.h for enumeration table */ Attribute _attribute_type; /* operator attribute enum, used for dynamic casting TosaAttributeBase class */ TosaAttributeBase* _attribute; /* real attribute class goes here */ - QuantInfo _qinfo_type; /* QuantInfo enum */ - TosaQuantInfoBase* _qinfo; /* base class pointer of QuantInfo */ std::vector<std::string> _input_tensor_names; /* array of input tensor names */ std::vector<std::string> _output_tensor_names; /* array of output tensor names */ }; diff --git a/python/serializer/tosa_serializer.py b/python/serializer/tosa_serializer.py index 4d7d7bf..10372e4 100644 --- a/python/serializer/tosa_serializer.py +++ b/python/serializer/tosa_serializer.py @@ -30,7 +30,7 @@ import tosa.Op as TosaOp # Keep version number in sync with the version default value with schema/tosa.fbs TOSA_VERSION_MAJOR = 0 -TOSA_VERSION_MINOR = 25 +TOSA_VERSION_MINOR = 30 TOSA_VERSION_PATCH = 0 TOSA_VERSION_DRAFT = True TOSA_VERSION = [ @@ -141,7 +141,7 @@ class TosaSerializerAttribute(TosaSerializerUnion): def __init__(self): super().__init__() - def PoolAttribute(self, kernel, stride, pad): + def PoolAttribute(self, kernel, stride, pad, input_zp, output_zp): from tosa import PoolAttribute as a, Attribute self.utype = Attribute.Attribute().PoolAttribute @@ -150,8 +150,10 @@ class TosaSerializerAttribute(TosaSerializerUnion): self.intvecs.append((a.AddPad, pad)) self.intvecs.append((a.AddKernel, kernel)) self.intvecs.append((a.AddStride, stride)) + self.ints.append((a.AddInputZp, input_zp)) + self.ints.append((a.AddOutputZp, output_zp)) - def ConvAttribute(self, pad, stride, dilation): + def ConvAttribute(self, pad, stride, dilation, input_zp, weight_zp): from tosa import ConvAttribute as a, Attribute self.utype = Attribute.Attribute().ConvAttribute @@ -160,8 +162,10 @@ class TosaSerializerAttribute(TosaSerializerUnion): self.intvecs.append((a.AddPad, pad)) self.intvecs.append((a.AddStride, stride)) self.intvecs.append((a.AddDilation, dilation)) + self.ints.append((a.AddInputZp, input_zp)) + self.ints.append((a.AddWeightZp, weight_zp)) - def TransposeConvAttribute(self, outpad, stride, output_shape): + def TransposeConvAttribute(self, outpad, stride, output_shape, input_zp, weight_zp): from tosa import TransposeConvAttribute as a, Attribute self.utype = Attribute.Attribute().TransposeConvAttribute @@ -170,6 +174,8 @@ class TosaSerializerAttribute(TosaSerializerUnion): self.intvecs.append((a.AddOutPad, outpad)) self.intvecs.append((a.AddStride, stride)) self.intvecs.append((a.AddOutputShape, output_shape)) + self.ints.append((a.AddInputZp, input_zp)) + self.ints.append((a.AddWeightZp, weight_zp)) def PadAttribute(self, padding, pad_const_int, pad_const_fp): from tosa import PadAttribute as a, Attribute @@ -311,43 +317,32 @@ class TosaSerializerAttribute(TosaSerializerUnion): self.intvecs.append((a.AddTable, table)) + def MatMulAttribute(self, A_zp, B_zp): + from tosa import MatMulAttribute as a, Attribute -class TosaSerializerQuantInfo(TosaSerializerUnion): - """This class handles encapsulating all of the enumerated types for quantinfo""" - - def __init__(self): - super().__init__() - - def ConvQuantInfo(self, input_zp, weight_zp): - from tosa import ConvQuantInfo as q, QuantInfo + self.utype = Attribute.Attribute().MatMulAttribute + self.optFcns = (a.Start, a.End) - self.utype = QuantInfo.QuantInfo().ConvQuantInfo - self.optFcns = (q.Start, q.End) - self.ints.append((q.AddInputZp, input_zp)) - self.ints.append((q.AddWeightZp, weight_zp)) + self.ints.append((a.AddAZp, A_zp)) + self.ints.append((a.AddBZp, B_zp)) - def UnaryQuantInfo(self, input_zp, output_zp): - from tosa import UnaryQuantInfo as q, QuantInfo + def FullyConnectedAttribute(self, input_zp, weight_zp): + from tosa import FullyConnectedAttribute as a, Attribute - self.utype = QuantInfo.QuantInfo().UnaryQuantInfo - self.optFcns = (q.Start, q.End) - self.ints.append((q.AddInputZp, input_zp)) - self.ints.append((q.AddOutputZp, output_zp)) + self.utype = Attribute.Attribute().FullyConnectedAttribute + self.optFcns = (a.Start, a.End) - def MatMulQuantInfo(self, a_zp, b_zp): - from tosa import MatMulQuantInfo as q, QuantInfo + self.ints.append((a.AddInputZp, input_zp)) + self.ints.append((a.AddWeightZp, weight_zp)) - self.utype = QuantInfo.QuantInfo().MatMulQuantInfo - self.optFcns = (q.Start, q.End) - self.ints.append((q.AddAZp, a_zp)) - self.ints.append((q.AddBZp, b_zp)) + def NegateAttribute(self, input1_zp, output_zp): + from tosa import NegateAttribute as a, Attribute - def PadQuantInfo(self, input_zp): - from tosa import PadQuantInfo as q, QuantInfo + self.utype = Attribute.Attribute().NegateAttribute + self.optFcns = (a.Start, a.End) - self.utype = QuantInfo.QuantInfo().PadQuantInfo - self.optFcns = (q.Start, q.End) - self.ints.append((q.AddInputZp, input_zp)) + self.ints.append((a.AddInput1Zp, input1_zp)) + self.ints.append((a.AddOutputZp, output_zp)) class TosaSerializerTensor: @@ -467,12 +462,11 @@ class TosaSerializerTensor: class TosaSerializerOperator: - def __init__(self, op, inputs, outputs, attributes=None, quantInfo=None): + def __init__(self, op, inputs, outputs, attributes=None): self.op = op self.attributes = attributes self.inputs = TosaSerializer.toList(inputs) self.outputs = TosaSerializer.toList(outputs) - self.quantInfo = quantInfo def __str__(self): str = "Op {}\n----\n".format(self.op) @@ -491,13 +485,10 @@ class TosaSerializerOperator: fb_outputs = TosaSerializer.serializeStrVec( builder, self.outputs, TosaOperator.StartOutputsVector ) - # Need to serialize quant_info and attributes enums still + # Need to serialize attributes enums still if self.attributes is not None: fb_attributes = self.attributes.serialize(builder) - if self.quantInfo is not None: - fb_qinfo = self.quantInfo.serialize(builder) - TosaOperator.Start(builder) TosaOperator.AddOp(builder, self.op) TosaOperator.AddInputs(builder, fb_inputs) @@ -505,9 +496,6 @@ class TosaSerializerOperator: if self.attributes is not None: TosaOperator.AddAttributeType(builder, self.attributes.utype) TosaOperator.AddAttribute(builder, fb_attributes) - if self.quantInfo is not None: - TosaOperator.AddQuantInfoType(builder, self.quantInfo.utype) - TosaOperator.AddQuantInfo(builder, fb_qinfo) return TosaOperator.End(builder) @@ -544,10 +532,8 @@ class TosaSerializerBasicBlock: def addOutput(self, name): self.outputs.append(name) - def addOperator(self, op, inputs, outputs, attributes=None, quant_info=None): - self.operators.append( - TosaSerializerOperator(op, inputs, outputs, attributes, quant_info) - ) + def addOperator(self, op, inputs, outputs, attributes=None): + self.operators.append(TosaSerializerOperator(op, inputs, outputs, attributes)) def serialize(self, builder): fb_name = builder.CreateString(self.name) @@ -671,13 +657,16 @@ class TosaSerializer: self.currBasicBlock.addOutput(name) return tens - def addOperator(self, op, inputs, outputs, attributes=None, quant_info=None): + def addOperator(self, op, inputs, outputs, attributes=None): if op == TosaOp.Op().CONST: raise Exception("Use addConstTensor() to add CONST ops") return self.currBasicBlock.addOperator( - op, inputs, outputs, attributes, quant_info + op, + inputs, + outputs, + attributes, ) def setExpectedReturnCode(self, val, fail, desc=""): @@ -861,21 +850,48 @@ class TosaSerializer: ConvAttribute.StartDilationVector = ( ConvAttribute.ConvAttributeStartDilationVector ) + ConvAttribute.AddInputZp = ConvAttribute.ConvAttributeAddInputZp + ConvAttribute.AddWeightZp = ConvAttribute.ConvAttributeAddWeightZp ConvAttribute.End = ConvAttribute.ConvAttributeEnd - from tosa import ConvQuantInfo - - if not hasattr(ConvQuantInfo, "Start"): - ConvQuantInfo.Start = ConvQuantInfo.ConvQuantInfoStart - ConvQuantInfo.AddInputZp = ConvQuantInfo.ConvQuantInfoAddInputZp - ConvQuantInfo.AddWeightZp = ConvQuantInfo.ConvQuantInfoAddWeightZp - ConvQuantInfo.End = ConvQuantInfo.ConvQuantInfoEnd - from tosa import MatMulQuantInfo - - if not hasattr(MatMulQuantInfo, "Start"): - MatMulQuantInfo.Start = MatMulQuantInfo.MatMulQuantInfoStart - MatMulQuantInfo.AddAZp = MatMulQuantInfo.MatMulQuantInfoAddAZp - MatMulQuantInfo.AddBZp = MatMulQuantInfo.MatMulQuantInfoAddBZp - MatMulQuantInfo.End = MatMulQuantInfo.MatMulQuantInfoEnd + from tosa import FullyConnectedAttribute + + if not hasattr(FullyConnectedAttribute, "Start"): + FullyConnectedAttribute.Start = ( + FullyConnectedAttribute.FullyConnectedAttributeStart + ) + FullyConnectedAttribute.AddInputZp = ( + FullyConnectedAttribute.FullyConnectedAttributeAddInputZp + ) + FullyConnectedAttribute.AddWeightZp = ( + FullyConnectedAttribute.FullyConnectedAttributeAddWeightZp + ) + FullyConnectedAttribute.End = ( + FullyConnectedAttribute.FullyConnectedAttributeEnd + ) + from tosa import MatMulAttribute + + if not hasattr(MatMulAttribute, "Start"): + MatMulAttribute.Start = MatMulAttribute.MatMulAttributeStart + MatMulAttribute.AddAZp = MatMulAttribute.MatMulAttributeAddAZp + MatMulAttribute.AddBZp = MatMulAttribute.MatMulAttributeAddBZp + MatMulAttribute.End = MatMulAttribute.MatMulAttributeEnd + from tosa import PoolAttribute + + if not hasattr(PoolAttribute, "Start"): + PoolAttribute.Start = PoolAttribute.PoolAttributeStart + PoolAttribute.AddPad = PoolAttribute.PoolAttributeAddPad + PoolAttribute.StartPadVector = PoolAttribute.PoolAttributeStartPadVector + PoolAttribute.AddKernel = PoolAttribute.PoolAttributeAddKernel + PoolAttribute.StartKernelVector = ( + PoolAttribute.PoolAttributeStartKernelVector + ) + PoolAttribute.AddStride = PoolAttribute.PoolAttributeAddStride + PoolAttribute.StartStrideVector = ( + PoolAttribute.PoolAttributeStartStrideVector + ) + PoolAttribute.AddInputZp = PoolAttribute.PoolAttributeAddInputZp + PoolAttribute.AddOutputZp = PoolAttribute.PoolAttributeAddOutputZp + PoolAttribute.End = PoolAttribute.PoolAttributeEnd from tosa import MulAttribute if not hasattr(MulAttribute, "Start"): @@ -893,12 +909,6 @@ class TosaSerializer: PadAttribute.AddPadConstInt = PadAttribute.PadAttributeAddPadConstInt PadAttribute.AddPadConstFp = PadAttribute.PadAttributeAddPadConstFp PadAttribute.End = PadAttribute.PadAttributeEnd - from tosa import PadQuantInfo - - if not hasattr(PadQuantInfo, "Start"): - PadQuantInfo.Start = PadQuantInfo.PadQuantInfoStart - PadQuantInfo.AddInputZp = PadQuantInfo.PadQuantInfoAddInputZp - PadQuantInfo.End = PadQuantInfo.PadQuantInfoEnd from tosa import PoolAttribute if not hasattr(PoolAttribute, "Start"): @@ -913,6 +923,8 @@ class TosaSerializer: PoolAttribute.StartStrideVector = ( PoolAttribute.PoolAttributeStartStrideVector ) + PoolAttribute.AddInputZp = PoolAttribute.PoolAttributeAddInputZp + PoolAttribute.AddOutputZp = PoolAttribute.PoolAttributeAddOutputZp PoolAttribute.End = PoolAttribute.PoolAttributeEnd from tosa import RescaleAttribute @@ -1048,8 +1060,6 @@ class TosaSerializer: TosaOperator.StartOutputsVector = ( TosaOperator.TosaOperatorStartOutputsVector ) - TosaOperator.AddQuantInfoType = TosaOperator.TosaOperatorAddQuantInfoType - TosaOperator.AddQuantInfo = TosaOperator.TosaOperatorAddQuantInfo TosaOperator.End = TosaOperator.TosaOperatorEnd from tosa import TosaTensor @@ -1095,16 +1105,15 @@ class TosaSerializer: TransposeConvAttribute.StartOutputShapeVector = ( TransposeConvAttribute.TransposeConvAttributeStartOutputShapeVector ) + TransposeConvAttribute.AddInputZp = ( + TransposeConvAttribute.TransposeConvAttributeAddInputZp + ) + TransposeConvAttribute.AddWeightZp = ( + TransposeConvAttribute.TransposeConvAttributeAddWeightZp + ) TransposeConvAttribute.End = ( TransposeConvAttribute.TransposeConvAttributeEnd ) - from tosa import UnaryQuantInfo - - if not hasattr(UnaryQuantInfo, "Start"): - UnaryQuantInfo.Start = UnaryQuantInfo.UnaryQuantInfoStart - UnaryQuantInfo.AddInputZp = UnaryQuantInfo.UnaryQuantInfoAddInputZp - UnaryQuantInfo.AddOutputZp = UnaryQuantInfo.UnaryQuantInfoAddOutputZp - UnaryQuantInfo.End = UnaryQuantInfo.UnaryQuantInfoEnd from tosa import Version if not hasattr(Version, "Start"): @@ -1114,6 +1123,35 @@ class TosaSerializer: Version.Add_patch = Version.VersionAdd_patch Version.Add_draft = Version.VersionAdd_draft Version.End = Version.VersionEnd + from tosa import MatMulAttribute + + if not hasattr(MatMulAttribute, "Start"): + MatMulAttribute.Start = MatMulAttribute.MatMulAttributeStart + MatMulAttribute.AddAZp = MatMulAttribute.MatMulAttributeAddAZp + MatMulAttribute.AddBZp = MatMulAttribute.MatMulAttributeAddBZp + MatMulAttribute.End = MatMulAttribute.MatMulAttributeEnd + from tosa import FullyConnectedAttribute + + if not hasattr(FullyConnectedAttribute, "Start"): + FullyConnectedAttribute.Start = ( + FullyConnectedAttribute.FullyConnectedAttributeStart + ) + FullyConnectedAttribute.AddInputZp = ( + FullyConnectedAttribute.FullyConnectedAttributeAddInputZp + ) + FullyConnectedAttribute.AddWeightZp = ( + FullyConnectedAttribute.FullyConnectedAttributeAddWeightZp + ) + FullyConnectedAttribute.End = ( + FullyConnectedAttribute.FullyConnectedAttributeEnd + ) + from tosa import NegateAttribute + + if not hasattr(NegateAttribute, "Start"): + NegateAttribute.Start = NegateAttribute.NegateAttributeStart + NegateAttribute.AddInput1Zp = NegateAttribute.NegateAttributeAddInput1Zp + NegateAttribute.AddOutputZp = NegateAttribute.NegateAttributeAddOutputZp + NegateAttribute.End = NegateAttribute.NegateAttributeEnd from tosa import WhileLoopAttribute if not hasattr(WhileLoopAttribute, "Start"): diff --git a/python/tosa/Attribute.py b/python/tosa/Attribute.py index f0307af..166de8e 100644 --- a/python/tosa/Attribute.py +++ b/python/tosa/Attribute.py @@ -21,4 +21,7 @@ class Attribute(object): WhileLoopAttribute = 15 TransposeAttribute = 16 TableAttribute = 17 + MatMulAttribute = 18 + FullyConnectedAttribute = 19 + NegateAttribute = 20 diff --git a/python/tosa/ConvAttribute.py b/python/tosa/ConvAttribute.py index 72a24ce..8244ea5 100644 --- a/python/tosa/ConvAttribute.py +++ b/python/tosa/ConvAttribute.py @@ -105,11 +105,27 @@ class ConvAttribute(object): o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) return o == 0 -def ConvAttributeStart(builder): builder.StartObject(3) + # ConvAttribute + def InputZp(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos) + return 0 + + # ConvAttribute + def WeightZp(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(12)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos) + return 0 + +def ConvAttributeStart(builder): builder.StartObject(5) def ConvAttributeAddPad(builder, pad): builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(pad), 0) def ConvAttributeStartPadVector(builder, numElems): return builder.StartVector(4, numElems, 4) def ConvAttributeAddStride(builder, stride): builder.PrependUOffsetTRelativeSlot(1, flatbuffers.number_types.UOffsetTFlags.py_type(stride), 0) def ConvAttributeStartStrideVector(builder, numElems): return builder.StartVector(4, numElems, 4) def ConvAttributeAddDilation(builder, dilation): builder.PrependUOffsetTRelativeSlot(2, flatbuffers.number_types.UOffsetTFlags.py_type(dilation), 0) def ConvAttributeStartDilationVector(builder, numElems): return builder.StartVector(4, numElems, 4) +def ConvAttributeAddInputZp(builder, inputZp): builder.PrependInt32Slot(3, inputZp, 0) +def ConvAttributeAddWeightZp(builder, weightZp): builder.PrependInt32Slot(4, weightZp, 0) def ConvAttributeEnd(builder): return builder.EndObject() diff --git a/python/tosa/ConvQuantInfo.py b/python/tosa/FullyConnectedAttribute.py index 7902c67..62b480d 100644 --- a/python/tosa/ConvQuantInfo.py +++ b/python/tosa/FullyConnectedAttribute.py @@ -6,39 +6,39 @@ import flatbuffers from flatbuffers.compat import import_numpy np = import_numpy() -class ConvQuantInfo(object): +class FullyConnectedAttribute(object): __slots__ = ['_tab'] @classmethod - def GetRootAsConvQuantInfo(cls, buf, offset): + def GetRootAsFullyConnectedAttribute(cls, buf, offset): n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) - x = ConvQuantInfo() + x = FullyConnectedAttribute() x.Init(buf, n + offset) return x @classmethod - def ConvQuantInfoBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + def FullyConnectedAttributeBufferHasIdentifier(cls, buf, offset, size_prefixed=False): return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x54\x4F\x53\x41", size_prefixed=size_prefixed) - # ConvQuantInfo + # FullyConnectedAttribute def Init(self, buf, pos): self._tab = flatbuffers.table.Table(buf, pos) - # ConvQuantInfo + # FullyConnectedAttribute def InputZp(self): o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) if o != 0: return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos) return 0 - # ConvQuantInfo + # FullyConnectedAttribute def WeightZp(self): o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) if o != 0: return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos) return 0 -def ConvQuantInfoStart(builder): builder.StartObject(2) -def ConvQuantInfoAddInputZp(builder, inputZp): builder.PrependInt32Slot(0, inputZp, 0) -def ConvQuantInfoAddWeightZp(builder, weightZp): builder.PrependInt32Slot(1, weightZp, 0) -def ConvQuantInfoEnd(builder): return builder.EndObject() +def FullyConnectedAttributeStart(builder): builder.StartObject(2) +def FullyConnectedAttributeAddInputZp(builder, inputZp): builder.PrependInt32Slot(0, inputZp, 0) +def FullyConnectedAttributeAddWeightZp(builder, weightZp): builder.PrependInt32Slot(1, weightZp, 0) +def FullyConnectedAttributeEnd(builder): return builder.EndObject() diff --git a/python/tosa/MatMulQuantInfo.py b/python/tosa/MatMulAttribute.py index 457da66..601f13f 100644 --- a/python/tosa/MatMulQuantInfo.py +++ b/python/tosa/MatMulAttribute.py @@ -6,39 +6,39 @@ import flatbuffers from flatbuffers.compat import import_numpy np = import_numpy() -class MatMulQuantInfo(object): +class MatMulAttribute(object): __slots__ = ['_tab'] @classmethod - def GetRootAsMatMulQuantInfo(cls, buf, offset): + def GetRootAsMatMulAttribute(cls, buf, offset): n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) - x = MatMulQuantInfo() + x = MatMulAttribute() x.Init(buf, n + offset) return x @classmethod - def MatMulQuantInfoBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + def MatMulAttributeBufferHasIdentifier(cls, buf, offset, size_prefixed=False): return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x54\x4F\x53\x41", size_prefixed=size_prefixed) - # MatMulQuantInfo + # MatMulAttribute def Init(self, buf, pos): self._tab = flatbuffers.table.Table(buf, pos) - # MatMulQuantInfo + # MatMulAttribute def AZp(self): o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) if o != 0: return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos) return 0 - # MatMulQuantInfo + # MatMulAttribute def BZp(self): o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) if o != 0: return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos) return 0 -def MatMulQuantInfoStart(builder): builder.StartObject(2) -def MatMulQuantInfoAddAZp(builder, aZp): builder.PrependInt32Slot(0, aZp, 0) -def MatMulQuantInfoAddBZp(builder, bZp): builder.PrependInt32Slot(1, bZp, 0) -def MatMulQuantInfoEnd(builder): return builder.EndObject() +def MatMulAttributeStart(builder): builder.StartObject(2) +def MatMulAttributeAddAZp(builder, aZp): builder.PrependInt32Slot(0, aZp, 0) +def MatMulAttributeAddBZp(builder, bZp): builder.PrependInt32Slot(1, bZp, 0) +def MatMulAttributeEnd(builder): return builder.EndObject() diff --git a/python/tosa/UnaryQuantInfo.py b/python/tosa/NegateAttribute.py index 648fc34..24a57dc 100644 --- a/python/tosa/UnaryQuantInfo.py +++ b/python/tosa/NegateAttribute.py @@ -6,39 +6,39 @@ import flatbuffers from flatbuffers.compat import import_numpy np = import_numpy() -class UnaryQuantInfo(object): +class NegateAttribute(object): __slots__ = ['_tab'] @classmethod - def GetRootAsUnaryQuantInfo(cls, buf, offset): + def GetRootAsNegateAttribute(cls, buf, offset): n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) - x = UnaryQuantInfo() + x = NegateAttribute() x.Init(buf, n + offset) return x @classmethod - def UnaryQuantInfoBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + def NegateAttributeBufferHasIdentifier(cls, buf, offset, size_prefixed=False): return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x54\x4F\x53\x41", size_prefixed=size_prefixed) - # UnaryQuantInfo + # NegateAttribute def Init(self, buf, pos): self._tab = flatbuffers.table.Table(buf, pos) - # UnaryQuantInfo - def InputZp(self): + # NegateAttribute + def Input1Zp(self): o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) if o != 0: return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos) return 0 - # UnaryQuantInfo + # NegateAttribute def OutputZp(self): o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) if o != 0: return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos) return 0 -def UnaryQuantInfoStart(builder): builder.StartObject(2) -def UnaryQuantInfoAddInputZp(builder, inputZp): builder.PrependInt32Slot(0, inputZp, 0) -def UnaryQuantInfoAddOutputZp(builder, outputZp): builder.PrependInt32Slot(1, outputZp, 0) -def UnaryQuantInfoEnd(builder): return builder.EndObject() +def NegateAttributeStart(builder): builder.StartObject(2) +def NegateAttributeAddInput1Zp(builder, input1Zp): builder.PrependInt32Slot(0, input1Zp, 0) +def NegateAttributeAddOutputZp(builder, outputZp): builder.PrependInt32Slot(1, outputZp, 0) +def NegateAttributeEnd(builder): return builder.EndObject() diff --git a/python/tosa/PadQuantInfo.py b/python/tosa/PadQuantInfo.py deleted file mode 100644 index c07db07..0000000 --- a/python/tosa/PadQuantInfo.py +++ /dev/null @@ -1,36 +0,0 @@ -# automatically generated by the FlatBuffers compiler, do not modify - -# namespace: tosa - -import flatbuffers -from flatbuffers.compat import import_numpy -np = import_numpy() - -class PadQuantInfo(object): - __slots__ = ['_tab'] - - @classmethod - def GetRootAsPadQuantInfo(cls, buf, offset): - n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) - x = PadQuantInfo() - x.Init(buf, n + offset) - return x - - @classmethod - def PadQuantInfoBufferHasIdentifier(cls, buf, offset, size_prefixed=False): - return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x54\x4F\x53\x41", size_prefixed=size_prefixed) - - # PadQuantInfo - def Init(self, buf, pos): - self._tab = flatbuffers.table.Table(buf, pos) - - # PadQuantInfo - def InputZp(self): - o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) - if o != 0: - return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos) - return 0 - -def PadQuantInfoStart(builder): builder.StartObject(1) -def PadQuantInfoAddInputZp(builder, inputZp): builder.PrependInt32Slot(0, inputZp, 0) -def PadQuantInfoEnd(builder): return builder.EndObject() diff --git a/python/tosa/PoolAttribute.py b/python/tosa/PoolAttribute.py index f2c12cf..8b6903e 100644 --- a/python/tosa/PoolAttribute.py +++ b/python/tosa/PoolAttribute.py @@ -105,11 +105,27 @@ class PoolAttribute(object): o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) return o == 0 -def PoolAttributeStart(builder): builder.StartObject(3) + # PoolAttribute + def InputZp(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos) + return 0 + + # PoolAttribute + def OutputZp(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(12)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos) + return 0 + +def PoolAttributeStart(builder): builder.StartObject(5) def PoolAttributeAddPad(builder, pad): builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(pad), 0) def PoolAttributeStartPadVector(builder, numElems): return builder.StartVector(4, numElems, 4) def PoolAttributeAddKernel(builder, kernel): builder.PrependUOffsetTRelativeSlot(1, flatbuffers.number_types.UOffsetTFlags.py_type(kernel), 0) def PoolAttributeStartKernelVector(builder, numElems): return builder.StartVector(4, numElems, 4) def PoolAttributeAddStride(builder, stride): builder.PrependUOffsetTRelativeSlot(2, flatbuffers.number_types.UOffsetTFlags.py_type(stride), 0) def PoolAttributeStartStrideVector(builder, numElems): return builder.StartVector(4, numElems, 4) +def PoolAttributeAddInputZp(builder, inputZp): builder.PrependInt32Slot(3, inputZp, 0) +def PoolAttributeAddOutputZp(builder, outputZp): builder.PrependInt32Slot(4, outputZp, 0) def PoolAttributeEnd(builder): return builder.EndObject() diff --git a/python/tosa/QuantInfo.py b/python/tosa/QuantInfo.py deleted file mode 100644 index ffdfd32..0000000 --- a/python/tosa/QuantInfo.py +++ /dev/null @@ -1,11 +0,0 @@ -# automatically generated by the FlatBuffers compiler, do not modify - -# namespace: tosa - -class QuantInfo(object): - NONE = 0 - UnaryQuantInfo = 1 - ConvQuantInfo = 2 - MatMulQuantInfo = 3 - PadQuantInfo = 4 - diff --git a/python/tosa/TosaOperator.py b/python/tosa/TosaOperator.py index 040c2dc..fd11f76 100644 --- a/python/tosa/TosaOperator.py +++ b/python/tosa/TosaOperator.py @@ -88,24 +88,7 @@ class TosaOperator(object): o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(12)) return o == 0 - # TosaOperator - def QuantInfoType(self): - o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(14)) - if o != 0: - return self._tab.Get(flatbuffers.number_types.Uint8Flags, o + self._tab.Pos) - return 0 - - # TosaOperator - def QuantInfo(self): - o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(16)) - if o != 0: - from flatbuffers.table import Table - obj = Table(bytearray(), 0) - self._tab.Union(obj, o) - return obj - return None - -def TosaOperatorStart(builder): builder.StartObject(7) +def TosaOperatorStart(builder): builder.StartObject(5) def TosaOperatorAddOp(builder, op): builder.PrependUint32Slot(0, op, 0) def TosaOperatorAddAttributeType(builder, attributeType): builder.PrependUint8Slot(1, attributeType, 0) def TosaOperatorAddAttribute(builder, attribute): builder.PrependUOffsetTRelativeSlot(2, flatbuffers.number_types.UOffsetTFlags.py_type(attribute), 0) @@ -113,6 +96,4 @@ def TosaOperatorAddInputs(builder, inputs): builder.PrependUOffsetTRelativeSlot( def TosaOperatorStartInputsVector(builder, numElems): return builder.StartVector(4, numElems, 4) def TosaOperatorAddOutputs(builder, outputs): builder.PrependUOffsetTRelativeSlot(4, flatbuffers.number_types.UOffsetTFlags.py_type(outputs), 0) def TosaOperatorStartOutputsVector(builder, numElems): return builder.StartVector(4, numElems, 4) -def TosaOperatorAddQuantInfoType(builder, quantInfoType): builder.PrependUint8Slot(5, quantInfoType, 0) -def TosaOperatorAddQuantInfo(builder, quantInfo): builder.PrependUOffsetTRelativeSlot(6, flatbuffers.number_types.UOffsetTFlags.py_type(quantInfo), 0) def TosaOperatorEnd(builder): return builder.EndObject() diff --git a/python/tosa/TransposeConvAttribute.py b/python/tosa/TransposeConvAttribute.py index 2a675a0..8ca5ba7 100644 --- a/python/tosa/TransposeConvAttribute.py +++ b/python/tosa/TransposeConvAttribute.py @@ -105,11 +105,27 @@ class TransposeConvAttribute(object): o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) return o == 0 -def TransposeConvAttributeStart(builder): builder.StartObject(3) + # TransposeConvAttribute + def InputZp(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos) + return 0 + + # TransposeConvAttribute + def WeightZp(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(12)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos) + return 0 + +def TransposeConvAttributeStart(builder): builder.StartObject(5) def TransposeConvAttributeAddOutPad(builder, outPad): builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(outPad), 0) def TransposeConvAttributeStartOutPadVector(builder, numElems): return builder.StartVector(4, numElems, 4) def TransposeConvAttributeAddStride(builder, stride): builder.PrependUOffsetTRelativeSlot(1, flatbuffers.number_types.UOffsetTFlags.py_type(stride), 0) def TransposeConvAttributeStartStrideVector(builder, numElems): return builder.StartVector(4, numElems, 4) def TransposeConvAttributeAddOutputShape(builder, outputShape): builder.PrependUOffsetTRelativeSlot(2, flatbuffers.number_types.UOffsetTFlags.py_type(outputShape), 0) def TransposeConvAttributeStartOutputShapeVector(builder, numElems): return builder.StartVector(4, numElems, 4) +def TransposeConvAttributeAddInputZp(builder, inputZp): builder.PrependInt32Slot(3, inputZp, 0) +def TransposeConvAttributeAddWeightZp(builder, weightZp): builder.PrependInt32Slot(4, weightZp, 0) def TransposeConvAttributeEnd(builder): return builder.EndObject() diff --git a/python/tosa/Version.py b/python/tosa/Version.py index 27dea53..bdac948 100644 --- a/python/tosa/Version.py +++ b/python/tosa/Version.py @@ -36,7 +36,7 @@ class Version(object): o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) if o != 0: return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos) - return 25 + return 30 # Version def _patch(self): @@ -54,7 +54,7 @@ class Version(object): def VersionStart(builder): builder.StartObject(4) def VersionAdd_major(builder, Major): builder.PrependInt32Slot(0, Major, 0) -def VersionAdd_minor(builder, Minor): builder.PrependInt32Slot(1, Minor, 25) +def VersionAdd_minor(builder, Minor): builder.PrependInt32Slot(1, Minor, 30) def VersionAdd_patch(builder, Patch): builder.PrependInt32Slot(2, Patch, 0) def VersionAdd_draft(builder, Draft): builder.PrependBoolSlot(3, Draft, 1) def VersionEnd(builder): return builder.EndObject() diff --git a/schema/tosa.fbs b/schema/tosa.fbs index 7830bd2..b342103 100644 --- a/schema/tosa.fbs +++ b/schema/tosa.fbs @@ -157,24 +157,33 @@ union Attribute { WhileLoopAttribute, TransposeAttribute, TableAttribute, + MatMulAttribute, + FullyConnectedAttribute, + NegateAttribute } table PoolAttribute { pad: [int32]; kernel: [int32]; stride: [int32]; + input_zp: int32; + output_zp: int32; } table ConvAttribute { pad: [int32]; stride: [int32]; dilation: [int32]; + input_zp: int32; + weight_zp: int32; } table TransposeConvAttribute { out_pad: [int32]; stride: [int32]; output_shape: [int32]; + input_zp: int32; + weight_zp: int32; } table PadAttribute { @@ -253,35 +262,24 @@ table TableAttribute { table: [int16]; } -union QuantInfo { - UnaryQuantInfo, - ConvQuantInfo, - MatMulQuantInfo, - PadQuantInfo, -} - -table UnaryQuantInfo { - input_zp: int32; - output_zp: int32; +table MatMulAttribute { + a_zp: int32; + b_zp: int32; } -table ConvQuantInfo { +table FullyConnectedAttribute { input_zp: int32; weight_zp: int32; } -table MatMulQuantInfo { - a_zp: int32; - b_zp: int32; -} - -table PadQuantInfo { - input_zp: int32; +table NegateAttribute { + input1_zp: int32; + output_zp: int32; } table Version { _major: int32 = 0; - _minor: int32 = 25; + _minor: int32 = 30; _patch: int32 = 0; _draft: bool = true; } @@ -298,7 +296,6 @@ table TosaOperator { attribute: Attribute; // union structure. operator attribute inputs:[string]; // list of input tensor names outputs:[string]; // list of output tensor names - quant_info: QuantInfo; // op-based quantization information } table TosaBasicBlock { diff --git a/src/tosa_serialization_handler.cpp b/src/tosa_serialization_handler.cpp index 32725a0..99da2e7 100644 --- a/src/tosa_serialization_handler.cpp +++ b/src/tosa_serialization_handler.cpp @@ -56,10 +56,7 @@ TosaSerializationTensor::TosaSerializationTensor() TosaSerializationTensor::~TosaSerializationTensor() {} -void TosaSerializationOperator::InitializeAttributeQinfo(Attribute attribute_type, - const TosaAttributeBase* attribute, - QuantInfo qinfo_type, - const TosaQuantInfoBase* qinfo) +void TosaSerializationOperator::InitializeAttribute(Attribute attribute_type, const TosaAttributeBase* attribute) { _attribute_type = attribute_type; @@ -80,32 +77,12 @@ void TosaSerializationOperator::InitializeAttributeQinfo(Attribute attribute_typ assert(0); } - _qinfo_type = qinfo_type; - switch (qinfo_type) - { - case QuantInfo_NONE: - _qinfo = new TosaNoneQuantInfo(); - break; -#define DEF_QUANTIZATION_INFO(NAME, ...) \ - case QuantInfo_##NAME##QuantInfo: \ - _qinfo = new Tosa##NAME##QuantInfo(qinfo); \ - break; -#include "quant_info.def" -#undef DEF_QUANTIZATION_INFO - default: - printf("TosaSerializationOperator::TosaSerializationOperator(): QuantInfo %s not implemented yet\n", - EnumNamesQuantInfo()[qinfo_type]); - assert(0); - } - - assert(_attribute && _qinfo); + assert(_attribute); } TosaSerializationOperator::TosaSerializationOperator(Op op, Attribute attribute_type, const TosaAttributeBase* attribute, - QuantInfo qinfo_type, - const TosaQuantInfoBase* qinfo, const std::vector<std::string>& input_tensor_names, const std::vector<std::string>& output_tensor_names) { @@ -113,14 +90,12 @@ TosaSerializationOperator::TosaSerializationOperator(Op op, _input_tensor_names = input_tensor_names; _output_tensor_names = output_tensor_names; - InitializeAttributeQinfo(attribute_type, attribute, qinfo_type, qinfo); + InitializeAttribute(attribute_type, attribute); } TosaSerializationOperator::TosaSerializationOperator(Op op, Attribute attribute_type, const TosaAttributeBase* attribute, - QuantInfo qinfo_type, - const TosaQuantInfoBase* qinfo, std::vector<std::string>&& input_tensor_names, std::vector<std::string>&& output_tensor_names) { @@ -128,13 +103,12 @@ TosaSerializationOperator::TosaSerializationOperator(Op op, _input_tensor_names = std::move(input_tensor_names); _output_tensor_names = std::move(output_tensor_names); - InitializeAttributeQinfo(attribute_type, attribute, qinfo_type, qinfo); + InitializeAttribute(attribute_type, attribute); } TosaSerializationOperator::~TosaSerializationOperator() { delete _attribute; - delete _qinfo; } TosaSerializationBasicBlock::TosaSerializationBasicBlock(const std::string& name, @@ -453,7 +427,6 @@ tosa_err_t TosaSerializationHandler::Deserialize(const uint8_t* buf) std::vector<std::string> block_outputs_container; TosaAttributeBase* typed_attribute = NULL; - TosaQuantInfoBase* typed_qinfo = NULL; TosaSerializationOperator* new_operator = NULL; TosaSerializationBasicBlock* new_block = NULL; TosaSerializationTensor* new_tensor = NULL; @@ -491,11 +464,9 @@ tosa_err_t TosaSerializationHandler::Deserialize(const uint8_t* buf) { auto curr_operator = fb_tosa_operators->Get(j); - auto operator_op = curr_operator->op(); - auto attribute_type = curr_operator->attribute_type(); - auto attribute = curr_operator->attribute(); - auto operator_qinfo_type = curr_operator->quant_info_type(); - auto operator_qinfo = curr_operator->quant_info(); + auto operator_op = curr_operator->op(); + auto attribute_type = curr_operator->attribute_type(); + auto attribute = curr_operator->attribute(); // input tensors auto operator_inputs = curr_operator->inputs(); @@ -538,27 +509,8 @@ tosa_err_t TosaSerializationHandler::Deserialize(const uint8_t* buf) return TOSA_INTERNAL_ERROR; } - switch (operator_qinfo_type) - { - case QuantInfo_NONE: - typed_qinfo = new TosaNoneQuantInfo(); - break; -#define DEF_QUANTIZATION_INFO(NAME, ...) \ - case QuantInfo_##NAME##QuantInfo: \ - typed_qinfo = new Tosa##NAME##QuantInfo(operator_qinfo); \ - break; - -#include "quant_info.def" -#undef DEF_QUANTIZATION_INFO - default: - printf("TosaSerializationHandler::Deserialize(): QuantInfo %s not implemented yet\n", - EnumNamesQuantInfo()[operator_qinfo_type]); - return TOSA_INTERNAL_ERROR; - } - - new_operator = - new TosaSerializationOperator(operator_op, attribute_type, typed_attribute, operator_qinfo_type, - typed_qinfo, operator_inputs_container, operator_outputs_container); + new_operator = new TosaSerializationOperator(operator_op, attribute_type, typed_attribute, + operator_inputs_container, operator_outputs_container); if (new_operator) { block_operators_container.push_back(new_operator); @@ -570,8 +522,6 @@ tosa_err_t TosaSerializationHandler::Deserialize(const uint8_t* buf) if (typed_attribute) delete typed_attribute; - if (typed_qinfo) - delete typed_qinfo; } auto fb_tosa_tensors = curr_block->tensors(); @@ -751,75 +701,8 @@ tosa_err_t TosaSerializationHandler::Serialize() return TOSA_INTERNAL_ERROR; } - auto qinfo_type = op->GetQInfoType(); - flatbuffers::Offset<void> fb_operator_qinfo; - switch (qinfo_type) - { - case QuantInfo_NONE: - fb_operator_qinfo = 0; - break; -#define DEF_ARGS_S(NAME, T, V) , reinterpret_cast<Tosa##NAME*>(op->GetQInfo())->V() -#define DEF_ARGS_V(NAME, T, V) , _builder.CreateVector<T>(reinterpret_cast<Tosa##NAME*>(op->GetQInfo())->V()) - -#define DEF_ARGS_1(NAME, T0, F0, V0) DEF_ARGS_##F0(NAME, T0, V0) -#define DEF_ARGS_2(NAME, T0, F0, V0, T1, F1, V1) DEF_ARGS_##F0(NAME, T0, V0) DEF_ARGS_##F1(NAME, T1, V1) -#define DEF_ARGS_3(NAME, T0, F0, V0, T1, F1, V1, T2, F2, V2) \ - DEF_ARGS_##F0(NAME, T0, V0) DEF_ARGS_##F1(NAME, T1, V1) DEF_ARGS_##F2(NAME, T2, V2) -#define DEF_ARGS_4(NAME, T0, F0, V0, T1, F1, V1, T2, F2, V2, T3, F3, V3) \ - DEF_ARGS_##F0(NAME, T0, V0) DEF_ARGS_##F1(NAME, T1, V1) DEF_ARGS_##F2(NAME, T2, V2) DEF_ARGS_##F3(NAME, T3, V3) -#define DEF_ARGS_5(NAME, T0, F0, V0, T1, F1, V1, T2, F2, V2, T3, F3, V3, T4, F4, V4) \ - DEF_ARGS_##F0(NAME, T0, V0) DEF_ARGS_##F1(NAME, T1, V1) DEF_ARGS_##F2(NAME, T2, V2) DEF_ARGS_##F3(NAME, T3, V3) \ - DEF_ARGS_##F4(NAME, T4, V4) -#define DEF_ARGS_6(NAME, T0, F0, V0, T1, F1, V1, T2, F2, V2, T3, F3, V3, T4, F4, V4, T5, F5, V5) \ - DEF_ARGS_##F0(NAME, T0, V0) DEF_ARGS_##F1(NAME, T1, V1) DEF_ARGS_##F2(NAME, T2, V2) DEF_ARGS_##F3(NAME, T3, V3) \ - DEF_ARGS_##F4(NAME, T4, V4) DEF_ARGS_##F5(NAME, T5, V5) -#define DEF_ARGS_7(NAME, T0, F0, V0, T1, F1, V1, T2, F2, V2, T3, F3, V3, T4, F4, V4, T5, F5, V5, T6, F6, V6) \ - DEF_ARGS_##F0(NAME, T0, V0) DEF_ARGS_##F1(NAME, T1, V1) DEF_ARGS_##F2(NAME, T2, V2) DEF_ARGS_##F3(NAME, T3, V3) \ - DEF_ARGS_##F4(NAME, T4, V4) DEF_ARGS_##F5(NAME, T5, V5) DEF_ARGS_##F6(NAME, T6, V6) -#define DEF_ARGS_8(NAME, T0, F0, V0, T1, F1, V1, T2, F2, V2, T3, F3, V3, T4, F4, V4, T5, F5, V5, T6, F6, V6, T7, F7, \ - V7) \ - DEF_ARGS_##F0(NAME, T0, V0) DEF_ARGS_##F1(NAME, T1, V1) DEF_ARGS_##F2(NAME, T2, V2) DEF_ARGS_##F3(NAME, T3, V3) \ - DEF_ARGS_##F4(NAME, T4, V4) DEF_ARGS_##F5(NAME, T5, V5) DEF_ARGS_##F6(NAME, T6, V6) \ - DEF_ARGS_##F7(NAME, T7, V7) -#define DEF_ARGS_9(NAME, T0, F0, V0, T1, F1, V1, T2, F2, V2, T3, F3, V3, T4, F4, V4, T5, F5, V5, T6, F6, V6, T7, F7, \ - V7, T8, F8, V8) \ - DEF_ARGS_##F0(NAME, T0, V0) DEF_ARGS_##F1(NAME, T1, V1) DEF_ARGS_##F2(NAME, T2, V2) DEF_ARGS_##F3(NAME, T3, V3) \ - DEF_ARGS_##F4(NAME, T4, V4) DEF_ARGS_##F5(NAME, T5, V5) DEF_ARGS_##F6(NAME, T6, V6) \ - DEF_ARGS_##F7(NAME, T7, V7) DEF_ARGS_##F8(NAME, T8, V8) -#define DEF_ARGS_10(NAME, T0, F0, V0, T1, F1, V1, T2, F2, V2, T3, F3, V3, T4, F4, V4, T5, F5, V5, T6, F6, V6, T7, F7, \ - V7, T8, F8, V8, T9, F9, V9) \ - DEF_ARGS_##F0(NAME, T0, V0) DEF_ARGS_##F1(NAME, T1, V1) DEF_ARGS_##F2(NAME, T2, V2) DEF_ARGS_##F3(NAME, T3, V3) \ - DEF_ARGS_##F4(NAME, T4, V4) DEF_ARGS_##F5(NAME, T5, V5) DEF_ARGS_##F6(NAME, T6, V6) \ - DEF_ARGS_##F7(NAME, T7, V7) DEF_ARGS_##F8(NAME, T8, V8) DEF_ARGS_##F9(NAME, T9, V9) -#define DEF_QUANTIZATION_INFO(NAME, NUM_ARGS, ...) \ - case QuantInfo_##NAME##QuantInfo: \ - fb_operator_qinfo = \ - Create##NAME##QuantInfo(_builder DEF_ARGS_##NUM_ARGS(NAME##QuantInfo, __VA_ARGS__)).Union(); \ - break; - -#include "quant_info.def" -#undef DEF_QUANTIZATION_INFO -#undef DEF_ARGS_1 -#undef DEF_ARGS_2 -#undef DEF_ARGS_3 -#undef DEF_ARGS_4 -#undef DEF_ARGS_5 -#undef DEF_ARGS_6 -#undef DEF_ARGS_7 -#undef DEF_ARGS_8 -#undef DEF_ARGS_9 -#undef DEF_ARGS_10 -#undef DEF_ARGS_S -#undef DEF_ARGS_V - default: - printf("TosaSerializationHandler::Serialize(): Attribute %s not implemented yet\n", - EnumNamesAttribute()[attribute_type]); - return TOSA_INTERNAL_ERROR; - } - - auto fboffset_operator = - CreateTosaOperator(_builder, operator_op, attribute_type, fb_attribute, fb_operator_inputs, - fb_operator_outputs, qinfo_type, fb_operator_qinfo); + auto fboffset_operator = CreateTosaOperator(_builder, operator_op, attribute_type, fb_attribute, + fb_operator_inputs, fb_operator_outputs); fboffset_block_operators.push_back(fboffset_operator); } |