From bdcc3fee1b8bf55aac50e060115b92a1ccf9741c Mon Sep 17 00:00:00 2001 From: Eric Kunze Date: Tue, 7 Jun 2022 05:17:37 +0000 Subject: Remove quantinfo types Any needed information has been moved into the attributes for each operator. This aligns with the structure of the attributes in the TOSA specification, and generally simplifies the code. Change-Id: I8243e91b09de1a9115f8af09c5e7def7e8f2866b Signed-off-by: Eric Kunze --- include/attribute.def | 29 +- include/quant_info.def | 43 --- include/quant_info.h | 163 ----------- include/tosa_generated.h | 486 ++++++++++++++------------------- include/tosa_serialization_handler.h | 22 +- python/serializer/tosa_serializer.py | 192 +++++++------ python/tosa/Attribute.py | 3 + python/tosa/ConvAttribute.py | 18 +- python/tosa/ConvQuantInfo.py | 44 --- python/tosa/FullyConnectedAttribute.py | 44 +++ python/tosa/MatMulAttribute.py | 44 +++ python/tosa/MatMulQuantInfo.py | 44 --- python/tosa/NegateAttribute.py | 44 +++ python/tosa/PadQuantInfo.py | 36 --- python/tosa/PoolAttribute.py | 18 +- python/tosa/QuantInfo.py | 11 - python/tosa/TosaOperator.py | 21 +- python/tosa/TransposeConvAttribute.py | 18 +- python/tosa/UnaryQuantInfo.py | 44 --- python/tosa/Version.py | 4 +- schema/tosa.fbs | 37 ++- src/tosa_serialization_handler.cpp | 139 +--------- 22 files changed, 564 insertions(+), 940 deletions(-) delete mode 100644 include/quant_info.def delete mode 100644 include/quant_info.h delete mode 100644 python/tosa/ConvQuantInfo.py create mode 100644 python/tosa/FullyConnectedAttribute.py create mode 100644 python/tosa/MatMulAttribute.py delete mode 100644 python/tosa/MatMulQuantInfo.py create mode 100644 python/tosa/NegateAttribute.py delete mode 100644 python/tosa/PadQuantInfo.py delete mode 100644 python/tosa/QuantInfo.py delete mode 100644 python/tosa/UnaryQuantInfo.py diff --git a/include/attribute.def b/include/attribute.def index 8ac5d8d..ea91869 100644 --- a/include/attribute.def +++ b/include/attribute.def @@ -26,20 +26,26 @@ ...: variadic variables for more arguments, depending on NUM_ARGS_IN_ATTRIBUTES */ -DEF_ATTRIBUTE(Pool, 3, +DEF_ATTRIBUTE(Pool, 5, int32_t, V, pad, int32_t, V, kernel, - int32_t, V, stride) + int32_t, V, stride, + int32_t, S, input_zp, + int32_t, S, output_zp) -DEF_ATTRIBUTE(Conv, 3, +DEF_ATTRIBUTE(Conv, 5, int32_t, V, pad, int32_t, V, stride, - int32_t, V, dilation) + int32_t, V, dilation, + int32_t, S, input_zp, + int32_t, S, weight_zp) -DEF_ATTRIBUTE(TransposeConv, 3, +DEF_ATTRIBUTE(TransposeConv, 5, int32_t, V, out_pad, int32_t, V, stride, - int32_t, V, output_shape) + int32_t, V, output_shape, + int32_t, S, input_zp, + int32_t, S, weight_zp) DEF_ATTRIBUTE(Pad, 3, int32_t, V, padding, @@ -103,3 +109,14 @@ DEF_ATTRIBUTE(Transpose, 1, DEF_ATTRIBUTE(Table, 1, int16_t, V, table) +DEF_ATTRIBUTE(MatMul, 2, + int32_t, S, a_zp, + int32_t, S, b_zp) + +DEF_ATTRIBUTE(FullyConnected, 2, + int32_t, S, input_zp, + int32_t, S, weight_zp) + +DEF_ATTRIBUTE(Negate, 2, + int32_t, S, input1_zp, + int32_t, S, output_zp) diff --git a/include/quant_info.def b/include/quant_info.def deleted file mode 100644 index 888c183..0000000 --- a/include/quant_info.def +++ /dev/null @@ -1,43 +0,0 @@ - -// Copyright (c) 2020-2021, ARM Limited. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -/* - Syntax: - DEF_QUANTIZATION_INFO(NAME, NUM_ARGS_IN_OPTIONS, ARG0_TYPE, ARG0_SCALAR_OR_VECTOR, ARGS0_NAME, ...) - - Description: - NAME: corresponding quantization info name, must match corresponding "table XXXQuantInfo" in tosa.fbs - NUM_ARGS_IN_QINFO: number of arguments in this quantization info - ARG0_TYPE: data type of arg0 - ARG0_SCALAR_OR_VECTOR: is arg0 a scalar (S) or a vector (V) - ARG0_NAME: name of arg0 - ...: variadic variables for more arguments, depending on NUM_ARGS_IN_QINFO -*/ - - -DEF_QUANTIZATION_INFO(Unary, 2, - int32_t, S, input_zp, - int32_t, S, output_zp) - -DEF_QUANTIZATION_INFO(Conv, 2, - int32_t, S, input_zp, - int32_t, S, weight_zp) - -DEF_QUANTIZATION_INFO(MatMul, 2, - int32_t, S, a_zp, - int32_t, S, b_zp) - -DEF_QUANTIZATION_INFO(Pad, 1, - int32_t, S, input_zp) diff --git a/include/quant_info.h b/include/quant_info.h deleted file mode 100644 index c7daeb2..0000000 --- a/include/quant_info.h +++ /dev/null @@ -1,163 +0,0 @@ - -// Copyright (c) 2020-2021, ARM Limited. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef _TOSA_SERIALIZATION_QUANT_INFO_H -#define _TOSA_SERIALIZATION_QUANT_INFO_H -#include "flatbuffers/idl.h" -#include "flatbuffers/util.h" -#include "tosa_generated.h" - -namespace tosa -{ - -class TosaQuantInfoBase -{ -public: - virtual ~TosaQuantInfoBase() - {} -}; - -class TosaNoneQuantInfo : public TosaQuantInfoBase -{ -public: - TosaNoneQuantInfo() - {} - TosaNoneQuantInfo(TosaNoneQuantInfo* p) - {} -}; - -#define DEF_ARGS_VER0_S(T, V) _##V = p->V(); -#define DEF_ARGS_VER0_V(T, V) _##V = std::vector(p->V()->begin(), p->V()->end()); -#define DEF_ARGS_VER1_S(T, V) const T& V -#define DEF_ARGS_VER1_V(T, V) const std::vector& V -#define DEF_ARGS_VER2_S(T, V) _##V = V; -#define DEF_ARGS_VER2_V(T, V) _##V = V; -#define DEF_ARGS_VER3_S(T, V) \ - T V() const \ - { \ - return _##V; \ - } -#define DEF_ARGS_VER3_V(T, V) \ - std::vector V() const \ - { \ - return _##V; \ - } -#define DEF_ARGS_VER4_S(T, V) T _##V; -#define DEF_ARGS_VER4_V(T, V) std::vector _##V; - -// another level of preprocessor indirection to handle ", " as function's input argument -#define DEF_ARGS_VER1_TRUE(T, F, V) DEF_ARGS_VER1_##F(T, V) -#define DEF_ARGS_VER1_FALSE(T, F, V) , DEF_ARGS_VER1_##F(T, V) - -#define DEF_ARGS_VER0(FIRST, T, F, V) DEF_ARGS_VER0_##F(T, V) -#define DEF_ARGS_VER1(FIRST, T, F, V) DEF_ARGS_VER1_##FIRST(T, F, V) -#define DEF_ARGS_VER2(FIRST, T, F, V) DEF_ARGS_VER2_##F(T, V) -#define DEF_ARGS_VER3(FIRST, T, F, V) DEF_ARGS_VER3_##F(T, V) -#define DEF_ARGS_VER4(FIRST, T, F, V) DEF_ARGS_VER4_##F(T, V) - -#define DEF_ARGS_1(VER, T0, F0, V0) DEF_ARGS_##VER(TRUE, T0, F0, V0) -#define DEF_ARGS_2(VER, T0, F0, V0, T1, F1, V1) DEF_ARGS_##VER(TRUE, T0, F0, V0) DEF_ARGS_##VER(FALSE, T1, F1, V1) -#define DEF_ARGS_3(VER, T0, F0, V0, T1, F1, V1, T2, F2, V2) \ - DEF_ARGS_##VER(TRUE, T0, F0, V0) DEF_ARGS_##VER(FALSE, T1, F1, V1) DEF_ARGS_##VER(FALSE, T2, F2, V2) -#define DEF_ARGS_4(VER, T0, F0, V0, T1, F1, V1, T2, F2, V2, T3, F3, V3) \ - DEF_ARGS_##VER(TRUE, T0, F0, V0) DEF_ARGS_##VER(FALSE, T1, F1, V1) DEF_ARGS_##VER(FALSE, T2, F2, V2) \ - DEF_ARGS_##VER(FALSE, T3, F3, V3) -#define DEF_ARGS_5(VER, T0, F0, V0, T1, F1, V1, T2, F2, V2, T3, F3, V3, T4, F4, V4) \ - DEF_ARGS_##VER(TRUE, T0, F0, V0) DEF_ARGS_##VER(FALSE, T1, F1, V1) DEF_ARGS_##VER(FALSE, T2, F2, V2) \ - DEF_ARGS_##VER(FALSE, T3, F3, V3) DEF_ARGS_##VER(FALSE, T4, F4, V4) -#define DEF_ARGS_6(VER, T0, F0, V0, T1, F1, V1, T2, F2, V2, T3, F3, V3, T4, F4, V4, T5, F5, V5) \ - DEF_ARGS_##VER(TRUE, T0, F0, V0) DEF_ARGS_##VER(FALSE, T1, F1, V1) DEF_ARGS_##VER(FALSE, T2, F2, V2) \ - DEF_ARGS_##VER(FALSE, T3, F3, V3) DEF_ARGS_##VER(FALSE, T4, F4, V4) DEF_ARGS_##VER(FALSE, T5, F5, V5) -#define DEF_ARGS_7(VER, T0, F0, V0, T1, F1, V1, T2, F2, V2, T3, F3, V3, T4, F4, V4, T5, F5, V5, T6, F6, V6) \ - DEF_ARGS_##VER(TRUE, T0, F0, V0) DEF_ARGS_##VER(FALSE, T1, F1, V1) DEF_ARGS_##VER(FALSE, T2, F2, V2) \ - DEF_ARGS_##VER(FALSE, T3, F3, V3) DEF_ARGS_##VER(FALSE, T4, F4, V4) DEF_ARGS_##VER(FALSE, T5, F5, V5) \ - DEF_ARGS_##VER(FALSE, T6, F6, V6) -#define DEF_ARGS_8(VER, T0, F0, V0, T1, F1, V1, T2, F2, V2, T3, F3, V3, T4, F4, V4, T5, F5, V5, T6, F6, V6, T7, F7, \ - V7) \ - DEF_ARGS_##VER(TRUE, T0, F0, V0) DEF_ARGS_##VER(FALSE, T1, F1, V1) DEF_ARGS_##VER(FALSE, T2, F2, V2) \ - DEF_ARGS_##VER(FALSE, T3, F3, V3) DEF_ARGS_##VER(FALSE, T4, F4, V4) DEF_ARGS_##VER(FALSE, T5, F5, V5) \ - DEF_ARGS_##VER(FALSE, T6, F6, V6) DEF_ARGS_##VER(FALSE, T7, F7, V7) -#define DEF_ARGS_9(VER, T0, F0, V0, T1, F1, V1, T2, F2, V2, T3, F3, V3, T4, F4, V4, T5, F5, V5, T6, F6, V6, T7, F7, \ - V7, T8, F8, V8) \ - DEF_ARGS_##VER(TRUE, T0, F0, V0) DEF_ARGS_##VER(FALSE, T1, F1, V1) DEF_ARGS_##VER(FALSE, T2, F2, V2) \ - DEF_ARGS_##VER(FALSE, T3, F3, V3) DEF_ARGS_##VER(FALSE, T4, F4, V4) DEF_ARGS_##VER(FALSE, T5, F5, V5) \ - DEF_ARGS_##VER(FALSE, T6, F6, V6) DEF_ARGS_##VER(FALSE, T7, F7, V7) DEF_ARGS_##VER(FALSE, T8, F8, V8) -#define DEF_ARGS_10(VER, T0, F0, V0, T1, F1, V1, T2, F2, V2, T3, F3, V3, T4, F4, V4, T5, F5, V5, T6, F6, V6, T7, F7, \ - V7, T8, F8, V8, T9, F9, V9) \ - DEF_ARGS_##VER(TRUE, T0, F0, V0) DEF_ARGS_##VER(FALSE, T1, F1, V1) DEF_ARGS_##VER(FALSE, T2, F2, V2) \ - DEF_ARGS_##VER(FALSE, T3, F3, V3) DEF_ARGS_##VER(FALSE, T4, F4, V4) DEF_ARGS_##VER(FALSE, T5, F5, V5) \ - DEF_ARGS_##VER(FALSE, T6, F6, V6) DEF_ARGS_##VER(FALSE, T7, F7, V7) DEF_ARGS_##VER(FALSE, T8, F8, V8) \ - DEF_ARGS_##VER(FALSE, T9, F9, V9) - -#define DEF_QUANTIZATION_INFO(NAME, NUM_ARGS, ...) \ - class Tosa##NAME##QuantInfo : public TosaQuantInfoBase \ - { \ - public: \ - Tosa##NAME##QuantInfo(const TosaQuantInfoBase* qinfo) \ - { \ - const Tosa##NAME##QuantInfo* p = static_cast(qinfo); \ - *this = *p; \ - } \ - Tosa##NAME##QuantInfo(const Tosa##NAME##QuantInfo* p) \ - { \ - *this = *p; \ - } \ - Tosa##NAME##QuantInfo(const void* qinfo) \ - { \ - const NAME##QuantInfo* p = static_cast(qinfo); \ - DEF_ARGS_##NUM_ARGS(VER0, __VA_ARGS__) \ - } \ - Tosa##NAME##QuantInfo(DEF_ARGS_##NUM_ARGS(VER1, __VA_ARGS__)) \ - { \ - DEF_ARGS_##NUM_ARGS(VER2, __VA_ARGS__) \ - } \ - virtual ~Tosa##NAME##QuantInfo() \ - {} \ - DEF_ARGS_##NUM_ARGS(VER3, __VA_ARGS__) private : DEF_ARGS_##NUM_ARGS(VER4, __VA_ARGS__) \ - }; - -#include "quant_info.def" -#undef DEF_QUANTIZATION_INFO -#undef DEF_ARGS_1 -#undef DEF_ARGS_2 -#undef DEF_ARGS_3 -#undef DEF_ARGS_4 -#undef DEF_ARGS_5 -#undef DEF_ARGS_6 -#undef DEF_ARGS_7 -#undef DEF_ARGS_8 -#undef DEF_ARGS_9 -#undef DEF_ARGS_10 -#undef DEF_ARGS_VER0 -#undef DEF_ARGS_VER1 -#undef DEF_ARGS_VER2 -#undef DEF_ARGS_VER3 -#undef DEF_ARGS_VER4 -#undef DEF_ARGS_VER1_TRUE -#undef DEF_ARGS_VER1_FALSE -#undef DEF_ARGS_VER0_S -#undef DEF_ARGS_VER0_V -#undef DEF_ARGS_VER1_S -#undef DEF_ARGS_VER1_V -#undef DEF_ARGS_VER2_S -#undef DEF_ARGS_VER2_V -#undef DEF_ARGS_VER3_S -#undef DEF_ARGS_VER3_V -#undef DEF_ARGS_VER4_S -#undef DEF_ARGS_VER4_V - -} // namespace tosa - -#endif diff --git a/include/tosa_generated.h b/include/tosa_generated.h index 58fe4a7..bd56e4e 100644 --- a/include/tosa_generated.h +++ b/include/tosa_generated.h @@ -59,17 +59,14 @@ struct TransposeAttributeBuilder; struct TableAttribute; struct TableAttributeBuilder; -struct UnaryQuantInfo; -struct UnaryQuantInfoBuilder; +struct MatMulAttribute; +struct MatMulAttributeBuilder; -struct ConvQuantInfo; -struct ConvQuantInfoBuilder; +struct FullyConnectedAttribute; +struct FullyConnectedAttributeBuilder; -struct MatMulQuantInfo; -struct MatMulQuantInfoBuilder; - -struct PadQuantInfo; -struct PadQuantInfoBuilder; +struct NegateAttribute; +struct NegateAttributeBuilder; struct Version; struct VersionBuilder; @@ -423,11 +420,14 @@ enum Attribute { Attribute_WhileLoopAttribute = 15, Attribute_TransposeAttribute = 16, Attribute_TableAttribute = 17, + Attribute_MatMulAttribute = 18, + Attribute_FullyConnectedAttribute = 19, + Attribute_NegateAttribute = 20, Attribute_MIN = Attribute_NONE, - Attribute_MAX = Attribute_TableAttribute + Attribute_MAX = Attribute_NegateAttribute }; -inline const Attribute (&EnumValuesAttribute())[18] { +inline const Attribute (&EnumValuesAttribute())[21] { static const Attribute values[] = { Attribute_NONE, Attribute_PoolAttribute, @@ -446,13 +446,16 @@ inline const Attribute (&EnumValuesAttribute())[18] { Attribute_CondIfAttribute, Attribute_WhileLoopAttribute, Attribute_TransposeAttribute, - Attribute_TableAttribute + Attribute_TableAttribute, + Attribute_MatMulAttribute, + Attribute_FullyConnectedAttribute, + Attribute_NegateAttribute }; return values; } inline const char * const *EnumNamesAttribute() { - static const char * const names[19] = { + static const char * const names[22] = { "NONE", "PoolAttribute", "ConvAttribute", @@ -471,13 +474,16 @@ inline const char * const *EnumNamesAttribute() { "WhileLoopAttribute", "TransposeAttribute", "TableAttribute", + "MatMulAttribute", + "FullyConnectedAttribute", + "NegateAttribute", nullptr }; return names; } inline const char *EnumNameAttribute(Attribute e) { - if (flatbuffers::IsOutRange(e, Attribute_NONE, Attribute_TableAttribute)) return ""; + if (flatbuffers::IsOutRange(e, Attribute_NONE, Attribute_NegateAttribute)) return ""; const size_t index = static_cast(e); return EnumNamesAttribute()[index]; } @@ -554,77 +560,29 @@ template<> struct AttributeTraits { static const Attribute enum_value = Attribute_TableAttribute; }; -bool VerifyAttribute(flatbuffers::Verifier &verifier, const void *obj, Attribute type); -bool VerifyAttributeVector(flatbuffers::Verifier &verifier, const flatbuffers::Vector> *values, const flatbuffers::Vector *types); - -enum QuantInfo { - QuantInfo_NONE = 0, - QuantInfo_UnaryQuantInfo = 1, - QuantInfo_ConvQuantInfo = 2, - QuantInfo_MatMulQuantInfo = 3, - QuantInfo_PadQuantInfo = 4, - QuantInfo_MIN = QuantInfo_NONE, - QuantInfo_MAX = QuantInfo_PadQuantInfo -}; - -inline const QuantInfo (&EnumValuesQuantInfo())[5] { - static const QuantInfo values[] = { - QuantInfo_NONE, - QuantInfo_UnaryQuantInfo, - QuantInfo_ConvQuantInfo, - QuantInfo_MatMulQuantInfo, - QuantInfo_PadQuantInfo - }; - return values; -} - -inline const char * const *EnumNamesQuantInfo() { - static const char * const names[6] = { - "NONE", - "UnaryQuantInfo", - "ConvQuantInfo", - "MatMulQuantInfo", - "PadQuantInfo", - nullptr - }; - return names; -} - -inline const char *EnumNameQuantInfo(QuantInfo e) { - if (flatbuffers::IsOutRange(e, QuantInfo_NONE, QuantInfo_PadQuantInfo)) return ""; - const size_t index = static_cast(e); - return EnumNamesQuantInfo()[index]; -} - -template struct QuantInfoTraits { - static const QuantInfo enum_value = QuantInfo_NONE; -}; - -template<> struct QuantInfoTraits { - static const QuantInfo enum_value = QuantInfo_UnaryQuantInfo; +template<> struct AttributeTraits { + static const Attribute enum_value = Attribute_MatMulAttribute; }; -template<> struct QuantInfoTraits { - static const QuantInfo enum_value = QuantInfo_ConvQuantInfo; +template<> struct AttributeTraits { + static const Attribute enum_value = Attribute_FullyConnectedAttribute; }; -template<> struct QuantInfoTraits { - static const QuantInfo enum_value = QuantInfo_MatMulQuantInfo; +template<> struct AttributeTraits { + static const Attribute enum_value = Attribute_NegateAttribute; }; -template<> struct QuantInfoTraits { - static const QuantInfo enum_value = QuantInfo_PadQuantInfo; -}; - -bool VerifyQuantInfo(flatbuffers::Verifier &verifier, const void *obj, QuantInfo type); -bool VerifyQuantInfoVector(flatbuffers::Verifier &verifier, const flatbuffers::Vector> *values, const flatbuffers::Vector *types); +bool VerifyAttribute(flatbuffers::Verifier &verifier, const void *obj, Attribute type); +bool VerifyAttributeVector(flatbuffers::Verifier &verifier, const flatbuffers::Vector> *values, const flatbuffers::Vector *types); struct PoolAttribute FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { typedef PoolAttributeBuilder Builder; enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { VT_PAD = 4, VT_KERNEL = 6, - VT_STRIDE = 8 + VT_STRIDE = 8, + VT_INPUT_ZP = 10, + VT_OUTPUT_ZP = 12 }; const flatbuffers::Vector *pad() const { return GetPointer *>(VT_PAD); @@ -635,6 +593,12 @@ struct PoolAttribute FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { const flatbuffers::Vector *stride() const { return GetPointer *>(VT_STRIDE); } + int32_t input_zp() const { + return GetField(VT_INPUT_ZP, 0); + } + int32_t output_zp() const { + return GetField(VT_OUTPUT_ZP, 0); + } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && VerifyOffset(verifier, VT_PAD) && @@ -643,6 +607,8 @@ struct PoolAttribute FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { verifier.VerifyVector(kernel()) && VerifyOffset(verifier, VT_STRIDE) && verifier.VerifyVector(stride()) && + VerifyField(verifier, VT_INPUT_ZP) && + VerifyField(verifier, VT_OUTPUT_ZP) && verifier.EndTable(); } }; @@ -660,6 +626,12 @@ struct PoolAttributeBuilder { void add_stride(flatbuffers::Offset> stride) { fbb_.AddOffset(PoolAttribute::VT_STRIDE, stride); } + void add_input_zp(int32_t input_zp) { + fbb_.AddElement(PoolAttribute::VT_INPUT_ZP, input_zp, 0); + } + void add_output_zp(int32_t output_zp) { + fbb_.AddElement(PoolAttribute::VT_OUTPUT_ZP, output_zp, 0); + } explicit PoolAttributeBuilder(flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) { start_ = fbb_.StartTable(); @@ -676,8 +648,12 @@ inline flatbuffers::Offset CreatePoolAttribute( flatbuffers::FlatBufferBuilder &_fbb, flatbuffers::Offset> pad = 0, flatbuffers::Offset> kernel = 0, - flatbuffers::Offset> stride = 0) { + flatbuffers::Offset> stride = 0, + int32_t input_zp = 0, + int32_t output_zp = 0) { PoolAttributeBuilder builder_(_fbb); + builder_.add_output_zp(output_zp); + builder_.add_input_zp(input_zp); builder_.add_stride(stride); builder_.add_kernel(kernel); builder_.add_pad(pad); @@ -688,7 +664,9 @@ inline flatbuffers::Offset CreatePoolAttributeDirect( flatbuffers::FlatBufferBuilder &_fbb, const std::vector *pad = nullptr, const std::vector *kernel = nullptr, - const std::vector *stride = nullptr) { + const std::vector *stride = nullptr, + int32_t input_zp = 0, + int32_t output_zp = 0) { auto pad__ = pad ? _fbb.CreateVector(*pad) : 0; auto kernel__ = kernel ? _fbb.CreateVector(*kernel) : 0; auto stride__ = stride ? _fbb.CreateVector(*stride) : 0; @@ -696,7 +674,9 @@ inline flatbuffers::Offset CreatePoolAttributeDirect( _fbb, pad__, kernel__, - stride__); + stride__, + input_zp, + output_zp); } struct ConvAttribute FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { @@ -704,7 +684,9 @@ struct ConvAttribute FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { VT_PAD = 4, VT_STRIDE = 6, - VT_DILATION = 8 + VT_DILATION = 8, + VT_INPUT_ZP = 10, + VT_WEIGHT_ZP = 12 }; const flatbuffers::Vector *pad() const { return GetPointer *>(VT_PAD); @@ -715,6 +697,12 @@ struct ConvAttribute FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { const flatbuffers::Vector *dilation() const { return GetPointer *>(VT_DILATION); } + int32_t input_zp() const { + return GetField(VT_INPUT_ZP, 0); + } + int32_t weight_zp() const { + return GetField(VT_WEIGHT_ZP, 0); + } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && VerifyOffset(verifier, VT_PAD) && @@ -723,6 +711,8 @@ struct ConvAttribute FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { verifier.VerifyVector(stride()) && VerifyOffset(verifier, VT_DILATION) && verifier.VerifyVector(dilation()) && + VerifyField(verifier, VT_INPUT_ZP) && + VerifyField(verifier, VT_WEIGHT_ZP) && verifier.EndTable(); } }; @@ -740,6 +730,12 @@ struct ConvAttributeBuilder { void add_dilation(flatbuffers::Offset> dilation) { fbb_.AddOffset(ConvAttribute::VT_DILATION, dilation); } + void add_input_zp(int32_t input_zp) { + fbb_.AddElement(ConvAttribute::VT_INPUT_ZP, input_zp, 0); + } + void add_weight_zp(int32_t weight_zp) { + fbb_.AddElement(ConvAttribute::VT_WEIGHT_ZP, weight_zp, 0); + } explicit ConvAttributeBuilder(flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) { start_ = fbb_.StartTable(); @@ -756,8 +752,12 @@ inline flatbuffers::Offset CreateConvAttribute( flatbuffers::FlatBufferBuilder &_fbb, flatbuffers::Offset> pad = 0, flatbuffers::Offset> stride = 0, - flatbuffers::Offset> dilation = 0) { + flatbuffers::Offset> dilation = 0, + int32_t input_zp = 0, + int32_t weight_zp = 0) { ConvAttributeBuilder builder_(_fbb); + builder_.add_weight_zp(weight_zp); + builder_.add_input_zp(input_zp); builder_.add_dilation(dilation); builder_.add_stride(stride); builder_.add_pad(pad); @@ -768,7 +768,9 @@ inline flatbuffers::Offset CreateConvAttributeDirect( flatbuffers::FlatBufferBuilder &_fbb, const std::vector *pad = nullptr, const std::vector *stride = nullptr, - const std::vector *dilation = nullptr) { + const std::vector *dilation = nullptr, + int32_t input_zp = 0, + int32_t weight_zp = 0) { auto pad__ = pad ? _fbb.CreateVector(*pad) : 0; auto stride__ = stride ? _fbb.CreateVector(*stride) : 0; auto dilation__ = dilation ? _fbb.CreateVector(*dilation) : 0; @@ -776,7 +778,9 @@ inline flatbuffers::Offset CreateConvAttributeDirect( _fbb, pad__, stride__, - dilation__); + dilation__, + input_zp, + weight_zp); } struct TransposeConvAttribute FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { @@ -784,7 +788,9 @@ struct TransposeConvAttribute FLATBUFFERS_FINAL_CLASS : private flatbuffers::Tab enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { VT_OUT_PAD = 4, VT_STRIDE = 6, - VT_OUTPUT_SHAPE = 8 + VT_OUTPUT_SHAPE = 8, + VT_INPUT_ZP = 10, + VT_WEIGHT_ZP = 12 }; const flatbuffers::Vector *out_pad() const { return GetPointer *>(VT_OUT_PAD); @@ -795,6 +801,12 @@ struct TransposeConvAttribute FLATBUFFERS_FINAL_CLASS : private flatbuffers::Tab const flatbuffers::Vector *output_shape() const { return GetPointer *>(VT_OUTPUT_SHAPE); } + int32_t input_zp() const { + return GetField(VT_INPUT_ZP, 0); + } + int32_t weight_zp() const { + return GetField(VT_WEIGHT_ZP, 0); + } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && VerifyOffset(verifier, VT_OUT_PAD) && @@ -803,6 +815,8 @@ struct TransposeConvAttribute FLATBUFFERS_FINAL_CLASS : private flatbuffers::Tab verifier.VerifyVector(stride()) && VerifyOffset(verifier, VT_OUTPUT_SHAPE) && verifier.VerifyVector(output_shape()) && + VerifyField(verifier, VT_INPUT_ZP) && + VerifyField(verifier, VT_WEIGHT_ZP) && verifier.EndTable(); } }; @@ -820,6 +834,12 @@ struct TransposeConvAttributeBuilder { void add_output_shape(flatbuffers::Offset> output_shape) { fbb_.AddOffset(TransposeConvAttribute::VT_OUTPUT_SHAPE, output_shape); } + void add_input_zp(int32_t input_zp) { + fbb_.AddElement(TransposeConvAttribute::VT_INPUT_ZP, input_zp, 0); + } + void add_weight_zp(int32_t weight_zp) { + fbb_.AddElement(TransposeConvAttribute::VT_WEIGHT_ZP, weight_zp, 0); + } explicit TransposeConvAttributeBuilder(flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) { start_ = fbb_.StartTable(); @@ -836,8 +856,12 @@ inline flatbuffers::Offset CreateTransposeConvAttribute( flatbuffers::FlatBufferBuilder &_fbb, flatbuffers::Offset> out_pad = 0, flatbuffers::Offset> stride = 0, - flatbuffers::Offset> output_shape = 0) { + flatbuffers::Offset> output_shape = 0, + int32_t input_zp = 0, + int32_t weight_zp = 0) { TransposeConvAttributeBuilder builder_(_fbb); + builder_.add_weight_zp(weight_zp); + builder_.add_input_zp(input_zp); builder_.add_output_shape(output_shape); builder_.add_stride(stride); builder_.add_out_pad(out_pad); @@ -848,7 +872,9 @@ inline flatbuffers::Offset CreateTransposeConvAttributeD flatbuffers::FlatBufferBuilder &_fbb, const std::vector *out_pad = nullptr, const std::vector *stride = nullptr, - const std::vector *output_shape = nullptr) { + const std::vector *output_shape = nullptr, + int32_t input_zp = 0, + int32_t weight_zp = 0) { auto out_pad__ = out_pad ? _fbb.CreateVector(*out_pad) : 0; auto stride__ = stride ? _fbb.CreateVector(*stride) : 0; auto output_shape__ = output_shape ? _fbb.CreateVector(*output_shape) : 0; @@ -856,7 +882,9 @@ inline flatbuffers::Offset CreateTransposeConvAttributeD _fbb, out_pad__, stride__, - output_shape__); + output_shape__, + input_zp, + weight_zp); } struct PadAttribute FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { @@ -1797,60 +1825,60 @@ inline flatbuffers::Offset CreateTableAttributeDirect( table__); } -struct UnaryQuantInfo FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { - typedef UnaryQuantInfoBuilder Builder; +struct MatMulAttribute FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef MatMulAttributeBuilder Builder; enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { - VT_INPUT_ZP = 4, - VT_OUTPUT_ZP = 6 + VT_A_ZP = 4, + VT_B_ZP = 6 }; - int32_t input_zp() const { - return GetField(VT_INPUT_ZP, 0); + int32_t a_zp() const { + return GetField(VT_A_ZP, 0); } - int32_t output_zp() const { - return GetField(VT_OUTPUT_ZP, 0); + int32_t b_zp() const { + return GetField(VT_B_ZP, 0); } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && - VerifyField(verifier, VT_INPUT_ZP) && - VerifyField(verifier, VT_OUTPUT_ZP) && + VerifyField(verifier, VT_A_ZP) && + VerifyField(verifier, VT_B_ZP) && verifier.EndTable(); } }; -struct UnaryQuantInfoBuilder { - typedef UnaryQuantInfo Table; +struct MatMulAttributeBuilder { + typedef MatMulAttribute Table; flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::uoffset_t start_; - void add_input_zp(int32_t input_zp) { - fbb_.AddElement(UnaryQuantInfo::VT_INPUT_ZP, input_zp, 0); + void add_a_zp(int32_t a_zp) { + fbb_.AddElement(MatMulAttribute::VT_A_ZP, a_zp, 0); } - void add_output_zp(int32_t output_zp) { - fbb_.AddElement(UnaryQuantInfo::VT_OUTPUT_ZP, output_zp, 0); + void add_b_zp(int32_t b_zp) { + fbb_.AddElement(MatMulAttribute::VT_B_ZP, b_zp, 0); } - explicit UnaryQuantInfoBuilder(flatbuffers::FlatBufferBuilder &_fbb) + explicit MatMulAttributeBuilder(flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) { start_ = fbb_.StartTable(); } - UnaryQuantInfoBuilder &operator=(const UnaryQuantInfoBuilder &); - flatbuffers::Offset Finish() { + MatMulAttributeBuilder &operator=(const MatMulAttributeBuilder &); + flatbuffers::Offset Finish() { const auto end = fbb_.EndTable(start_); - auto o = flatbuffers::Offset(end); + auto o = flatbuffers::Offset(end); return o; } }; -inline flatbuffers::Offset CreateUnaryQuantInfo( +inline flatbuffers::Offset CreateMatMulAttribute( flatbuffers::FlatBufferBuilder &_fbb, - int32_t input_zp = 0, - int32_t output_zp = 0) { - UnaryQuantInfoBuilder builder_(_fbb); - builder_.add_output_zp(output_zp); - builder_.add_input_zp(input_zp); + int32_t a_zp = 0, + int32_t b_zp = 0) { + MatMulAttributeBuilder builder_(_fbb); + builder_.add_b_zp(b_zp); + builder_.add_a_zp(a_zp); return builder_.Finish(); } -struct ConvQuantInfo FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { - typedef ConvQuantInfoBuilder Builder; +struct FullyConnectedAttribute FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef FullyConnectedAttributeBuilder Builder; enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { VT_INPUT_ZP = 4, VT_WEIGHT_ZP = 6 @@ -1869,129 +1897,87 @@ struct ConvQuantInfo FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { } }; -struct ConvQuantInfoBuilder { - typedef ConvQuantInfo Table; +struct FullyConnectedAttributeBuilder { + typedef FullyConnectedAttribute Table; flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::uoffset_t start_; void add_input_zp(int32_t input_zp) { - fbb_.AddElement(ConvQuantInfo::VT_INPUT_ZP, input_zp, 0); + fbb_.AddElement(FullyConnectedAttribute::VT_INPUT_ZP, input_zp, 0); } void add_weight_zp(int32_t weight_zp) { - fbb_.AddElement(ConvQuantInfo::VT_WEIGHT_ZP, weight_zp, 0); + fbb_.AddElement(FullyConnectedAttribute::VT_WEIGHT_ZP, weight_zp, 0); } - explicit ConvQuantInfoBuilder(flatbuffers::FlatBufferBuilder &_fbb) + explicit FullyConnectedAttributeBuilder(flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) { start_ = fbb_.StartTable(); } - ConvQuantInfoBuilder &operator=(const ConvQuantInfoBuilder &); - flatbuffers::Offset Finish() { + FullyConnectedAttributeBuilder &operator=(const FullyConnectedAttributeBuilder &); + flatbuffers::Offset Finish() { const auto end = fbb_.EndTable(start_); - auto o = flatbuffers::Offset(end); + auto o = flatbuffers::Offset(end); return o; } }; -inline flatbuffers::Offset CreateConvQuantInfo( +inline flatbuffers::Offset CreateFullyConnectedAttribute( flatbuffers::FlatBufferBuilder &_fbb, int32_t input_zp = 0, int32_t weight_zp = 0) { - ConvQuantInfoBuilder builder_(_fbb); + FullyConnectedAttributeBuilder builder_(_fbb); builder_.add_weight_zp(weight_zp); builder_.add_input_zp(input_zp); return builder_.Finish(); } -struct MatMulQuantInfo FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { - typedef MatMulQuantInfoBuilder Builder; +struct NegateAttribute FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef NegateAttributeBuilder Builder; enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { - VT_A_ZP = 4, - VT_B_ZP = 6 + VT_INPUT1_ZP = 4, + VT_OUTPUT_ZP = 6 }; - int32_t a_zp() const { - return GetField(VT_A_ZP, 0); + int32_t input1_zp() const { + return GetField(VT_INPUT1_ZP, 0); } - int32_t b_zp() const { - return GetField(VT_B_ZP, 0); + int32_t output_zp() const { + return GetField(VT_OUTPUT_ZP, 0); } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && - VerifyField(verifier, VT_A_ZP) && - VerifyField(verifier, VT_B_ZP) && + VerifyField(verifier, VT_INPUT1_ZP) && + VerifyField(verifier, VT_OUTPUT_ZP) && verifier.EndTable(); } }; -struct MatMulQuantInfoBuilder { - typedef MatMulQuantInfo Table; +struct NegateAttributeBuilder { + typedef NegateAttribute Table; flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::uoffset_t start_; - void add_a_zp(int32_t a_zp) { - fbb_.AddElement(MatMulQuantInfo::VT_A_ZP, a_zp, 0); - } - void add_b_zp(int32_t b_zp) { - fbb_.AddElement(MatMulQuantInfo::VT_B_ZP, b_zp, 0); - } - explicit MatMulQuantInfoBuilder(flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { - start_ = fbb_.StartTable(); - } - MatMulQuantInfoBuilder &operator=(const MatMulQuantInfoBuilder &); - flatbuffers::Offset Finish() { - const auto end = fbb_.EndTable(start_); - auto o = flatbuffers::Offset(end); - return o; - } -}; - -inline flatbuffers::Offset CreateMatMulQuantInfo( - flatbuffers::FlatBufferBuilder &_fbb, - int32_t a_zp = 0, - int32_t b_zp = 0) { - MatMulQuantInfoBuilder builder_(_fbb); - builder_.add_b_zp(b_zp); - builder_.add_a_zp(a_zp); - return builder_.Finish(); -} - -struct PadQuantInfo FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { - typedef PadQuantInfoBuilder Builder; - enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { - VT_INPUT_ZP = 4 - }; - int32_t input_zp() const { - return GetField(VT_INPUT_ZP, 0); - } - bool Verify(flatbuffers::Verifier &verifier) const { - return VerifyTableStart(verifier) && - VerifyField(verifier, VT_INPUT_ZP) && - verifier.EndTable(); + void add_input1_zp(int32_t input1_zp) { + fbb_.AddElement(NegateAttribute::VT_INPUT1_ZP, input1_zp, 0); } -}; - -struct PadQuantInfoBuilder { - typedef PadQuantInfo Table; - flatbuffers::FlatBufferBuilder &fbb_; - flatbuffers::uoffset_t start_; - void add_input_zp(int32_t input_zp) { - fbb_.AddElement(PadQuantInfo::VT_INPUT_ZP, input_zp, 0); + void add_output_zp(int32_t output_zp) { + fbb_.AddElement(NegateAttribute::VT_OUTPUT_ZP, output_zp, 0); } - explicit PadQuantInfoBuilder(flatbuffers::FlatBufferBuilder &_fbb) + explicit NegateAttributeBuilder(flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) { start_ = fbb_.StartTable(); } - PadQuantInfoBuilder &operator=(const PadQuantInfoBuilder &); - flatbuffers::Offset Finish() { + NegateAttributeBuilder &operator=(const NegateAttributeBuilder &); + flatbuffers::Offset Finish() { const auto end = fbb_.EndTable(start_); - auto o = flatbuffers::Offset(end); + auto o = flatbuffers::Offset(end); return o; } }; -inline flatbuffers::Offset CreatePadQuantInfo( +inline flatbuffers::Offset CreateNegateAttribute( flatbuffers::FlatBufferBuilder &_fbb, - int32_t input_zp = 0) { - PadQuantInfoBuilder builder_(_fbb); - builder_.add_input_zp(input_zp); + int32_t input1_zp = 0, + int32_t output_zp = 0) { + NegateAttributeBuilder builder_(_fbb); + builder_.add_output_zp(output_zp); + builder_.add_input1_zp(input1_zp); return builder_.Finish(); } @@ -2007,7 +1993,7 @@ struct Version FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { return GetField(VT__MAJOR, 0); } int32_t _minor() const { - return GetField(VT__MINOR, 25); + return GetField(VT__MINOR, 30); } int32_t _patch() const { return GetField(VT__PATCH, 0); @@ -2033,7 +2019,7 @@ struct VersionBuilder { fbb_.AddElement(Version::VT__MAJOR, _major, 0); } void add__minor(int32_t _minor) { - fbb_.AddElement(Version::VT__MINOR, _minor, 25); + fbb_.AddElement(Version::VT__MINOR, _minor, 30); } void add__patch(int32_t _patch) { fbb_.AddElement(Version::VT__PATCH, _patch, 0); @@ -2056,7 +2042,7 @@ struct VersionBuilder { inline flatbuffers::Offset CreateVersion( flatbuffers::FlatBufferBuilder &_fbb, int32_t _major = 0, - int32_t _minor = 25, + int32_t _minor = 30, int32_t _patch = 0, bool _draft = true) { VersionBuilder builder_(_fbb); @@ -2167,9 +2153,7 @@ struct TosaOperator FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { VT_ATTRIBUTE_TYPE = 6, VT_ATTRIBUTE = 8, VT_INPUTS = 10, - VT_OUTPUTS = 12, - VT_QUANT_INFO_TYPE = 14, - VT_QUANT_INFO = 16 + VT_OUTPUTS = 12 }; tosa::Op op() const { return static_cast(GetField(VT_OP, 0)); @@ -2232,31 +2216,21 @@ struct TosaOperator FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { const tosa::TableAttribute *attribute_as_TableAttribute() const { return attribute_type() == tosa::Attribute_TableAttribute ? static_cast(attribute()) : nullptr; } + const tosa::MatMulAttribute *attribute_as_MatMulAttribute() const { + return attribute_type() == tosa::Attribute_MatMulAttribute ? static_cast(attribute()) : nullptr; + } + const tosa::FullyConnectedAttribute *attribute_as_FullyConnectedAttribute() const { + return attribute_type() == tosa::Attribute_FullyConnectedAttribute ? static_cast(attribute()) : nullptr; + } + const tosa::NegateAttribute *attribute_as_NegateAttribute() const { + return attribute_type() == tosa::Attribute_NegateAttribute ? static_cast(attribute()) : nullptr; + } const flatbuffers::Vector> *inputs() const { return GetPointer> *>(VT_INPUTS); } const flatbuffers::Vector> *outputs() const { return GetPointer> *>(VT_OUTPUTS); } - tosa::QuantInfo quant_info_type() const { - return static_cast(GetField(VT_QUANT_INFO_TYPE, 0)); - } - const void *quant_info() const { - return GetPointer(VT_QUANT_INFO); - } - template const T *quant_info_as() const; - const tosa::UnaryQuantInfo *quant_info_as_UnaryQuantInfo() const { - return quant_info_type() == tosa::QuantInfo_UnaryQuantInfo ? static_cast(quant_info()) : nullptr; - } - const tosa::ConvQuantInfo *quant_info_as_ConvQuantInfo() const { - return quant_info_type() == tosa::QuantInfo_ConvQuantInfo ? static_cast(quant_info()) : nullptr; - } - const tosa::MatMulQuantInfo *quant_info_as_MatMulQuantInfo() const { - return quant_info_type() == tosa::QuantInfo_MatMulQuantInfo ? static_cast(quant_info()) : nullptr; - } - const tosa::PadQuantInfo *quant_info_as_PadQuantInfo() const { - return quant_info_type() == tosa::QuantInfo_PadQuantInfo ? static_cast(quant_info()) : nullptr; - } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && VerifyField(verifier, VT_OP) && @@ -2269,9 +2243,6 @@ struct TosaOperator FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { VerifyOffset(verifier, VT_OUTPUTS) && verifier.VerifyVector(outputs()) && verifier.VerifyVectorOfStrings(outputs()) && - VerifyField(verifier, VT_QUANT_INFO_TYPE) && - VerifyOffset(verifier, VT_QUANT_INFO) && - VerifyQuantInfo(verifier, quant_info(), quant_info_type()) && verifier.EndTable(); } }; @@ -2344,20 +2315,16 @@ template<> inline const tosa::TableAttribute *TosaOperator::attribute_as inline const tosa::UnaryQuantInfo *TosaOperator::quant_info_as() const { - return quant_info_as_UnaryQuantInfo(); +template<> inline const tosa::MatMulAttribute *TosaOperator::attribute_as() const { + return attribute_as_MatMulAttribute(); } -template<> inline const tosa::ConvQuantInfo *TosaOperator::quant_info_as() const { - return quant_info_as_ConvQuantInfo(); +template<> inline const tosa::FullyConnectedAttribute *TosaOperator::attribute_as() const { + return attribute_as_FullyConnectedAttribute(); } -template<> inline const tosa::MatMulQuantInfo *TosaOperator::quant_info_as() const { - return quant_info_as_MatMulQuantInfo(); -} - -template<> inline const tosa::PadQuantInfo *TosaOperator::quant_info_as() const { - return quant_info_as_PadQuantInfo(); +template<> inline const tosa::NegateAttribute *TosaOperator::attribute_as() const { + return attribute_as_NegateAttribute(); } struct TosaOperatorBuilder { @@ -2379,12 +2346,6 @@ struct TosaOperatorBuilder { void add_outputs(flatbuffers::Offset>> outputs) { fbb_.AddOffset(TosaOperator::VT_OUTPUTS, outputs); } - void add_quant_info_type(tosa::QuantInfo quant_info_type) { - fbb_.AddElement(TosaOperator::VT_QUANT_INFO_TYPE, static_cast(quant_info_type), 0); - } - void add_quant_info(flatbuffers::Offset quant_info) { - fbb_.AddOffset(TosaOperator::VT_QUANT_INFO, quant_info); - } explicit TosaOperatorBuilder(flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) { start_ = fbb_.StartTable(); @@ -2403,16 +2364,12 @@ inline flatbuffers::Offset CreateTosaOperator( tosa::Attribute attribute_type = tosa::Attribute_NONE, flatbuffers::Offset attribute = 0, flatbuffers::Offset>> inputs = 0, - flatbuffers::Offset>> outputs = 0, - tosa::QuantInfo quant_info_type = tosa::QuantInfo_NONE, - flatbuffers::Offset quant_info = 0) { + flatbuffers::Offset>> outputs = 0) { TosaOperatorBuilder builder_(_fbb); - builder_.add_quant_info(quant_info); builder_.add_outputs(outputs); builder_.add_inputs(inputs); builder_.add_attribute(attribute); builder_.add_op(op); - builder_.add_quant_info_type(quant_info_type); builder_.add_attribute_type(attribute_type); return builder_.Finish(); } @@ -2423,9 +2380,7 @@ inline flatbuffers::Offset CreateTosaOperatorDirect( tosa::Attribute attribute_type = tosa::Attribute_NONE, flatbuffers::Offset attribute = 0, const std::vector> *inputs = nullptr, - const std::vector> *outputs = nullptr, - tosa::QuantInfo quant_info_type = tosa::QuantInfo_NONE, - flatbuffers::Offset quant_info = 0) { + const std::vector> *outputs = nullptr) { auto inputs__ = inputs ? _fbb.CreateVector>(*inputs) : 0; auto outputs__ = outputs ? _fbb.CreateVector>(*outputs) : 0; return tosa::CreateTosaOperator( @@ -2434,9 +2389,7 @@ inline flatbuffers::Offset CreateTosaOperatorDirect( attribute_type, attribute, inputs__, - outputs__, - quant_info_type, - quant_info); + outputs__); } struct TosaBasicBlock FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { @@ -2690,53 +2643,28 @@ inline bool VerifyAttribute(flatbuffers::Verifier &verifier, const void *obj, At auto ptr = reinterpret_cast(obj); return verifier.VerifyTable(ptr); } - default: return true; - } -} - -inline bool VerifyAttributeVector(flatbuffers::Verifier &verifier, const flatbuffers::Vector> *values, const flatbuffers::Vector *types) { - if (!values || !types) return !values && !types; - if (values->size() != types->size()) return false; - for (flatbuffers::uoffset_t i = 0; i < values->size(); ++i) { - if (!VerifyAttribute( - verifier, values->Get(i), types->GetEnum(i))) { - return false; - } - } - return true; -} - -inline bool VerifyQuantInfo(flatbuffers::Verifier &verifier, const void *obj, QuantInfo type) { - switch (type) { - case QuantInfo_NONE: { - return true; - } - case QuantInfo_UnaryQuantInfo: { - auto ptr = reinterpret_cast(obj); - return verifier.VerifyTable(ptr); - } - case QuantInfo_ConvQuantInfo: { - auto ptr = reinterpret_cast(obj); + case Attribute_MatMulAttribute: { + auto ptr = reinterpret_cast(obj); return verifier.VerifyTable(ptr); } - case QuantInfo_MatMulQuantInfo: { - auto ptr = reinterpret_cast(obj); + case Attribute_FullyConnectedAttribute: { + auto ptr = reinterpret_cast(obj); return verifier.VerifyTable(ptr); } - case QuantInfo_PadQuantInfo: { - auto ptr = reinterpret_cast(obj); + case Attribute_NegateAttribute: { + auto ptr = reinterpret_cast(obj); return verifier.VerifyTable(ptr); } default: return true; } } -inline bool VerifyQuantInfoVector(flatbuffers::Verifier &verifier, const flatbuffers::Vector> *values, const flatbuffers::Vector *types) { +inline bool VerifyAttributeVector(flatbuffers::Verifier &verifier, const flatbuffers::Vector> *values, const flatbuffers::Vector *types) { if (!values || !types) return !values && !types; if (values->size() != types->size()) return false; for (flatbuffers::uoffset_t i = 0; i < values->size(); ++i) { - if (!VerifyQuantInfo( - verifier, values->Get(i), types->GetEnum(i))) { + if (!VerifyAttribute( + verifier, values->Get(i), types->GetEnum(i))) { return false; } } diff --git a/include/tosa_serialization_handler.h b/include/tosa_serialization_handler.h index 33c8047..695c530 100644 --- a/include/tosa_serialization_handler.h +++ b/include/tosa_serialization_handler.h @@ -19,7 +19,6 @@ #include "flatbuffers/idl.h" #include "flatbuffers/util.h" #include "numpy_utils.h" -#include "quant_info.h" #include "tosa_generated.h" #include #include @@ -28,7 +27,7 @@ // Keep version number in sync with the version default value with schema/tosa.fbs #define TOSA_VERSION_MAJOR 0 -#define TOSA_VERSION_MINOR 25 +#define TOSA_VERSION_MINOR 30 #define TOSA_VERSION_PATCH 0 #define TOSA_VERSION_DRAFT true #define TENSOR_BUFFER_FORCE_ALIGNMENT 8 @@ -172,15 +171,11 @@ public: TosaSerializationOperator(Op op, Attribute attribute_type, const TosaAttributeBase* attribute, - QuantInfo qinfo_type, - const TosaQuantInfoBase* qinfo, const std::vector& input_tensor_names, const std::vector& output_tensor_names); TosaSerializationOperator(Op op, Attribute attribute_type, const TosaAttributeBase* attribute, - QuantInfo qinfo_type, - const TosaQuantInfoBase* qinfo, std::vector&& input_tensor_names, std::vector&& output_tensor_names); ~TosaSerializationOperator(); @@ -198,14 +193,6 @@ public: { return _attribute; } - QuantInfo GetQInfoType() const - { - return _qinfo_type; - } - TosaQuantInfoBase* GetQInfo() const - { - return _qinfo; - } std::vector& GetInputTensorNames() { return _input_tensor_names; @@ -216,15 +203,10 @@ public: } private: - void InitializeAttributeQinfo(Attribute attribute_type, - const TosaAttributeBase* attribute, - QuantInfo qinfo_type, - const TosaQuantInfoBase* qinfo); + void InitializeAttribute(Attribute attribute_type, const TosaAttributeBase* attribute); Op _op; /* operator enum, see tosa_isa_generated.h for enumeration table */ Attribute _attribute_type; /* operator attribute enum, used for dynamic casting TosaAttributeBase class */ TosaAttributeBase* _attribute; /* real attribute class goes here */ - QuantInfo _qinfo_type; /* QuantInfo enum */ - TosaQuantInfoBase* _qinfo; /* base class pointer of QuantInfo */ std::vector _input_tensor_names; /* array of input tensor names */ std::vector _output_tensor_names; /* array of output tensor names */ }; diff --git a/python/serializer/tosa_serializer.py b/python/serializer/tosa_serializer.py index 4d7d7bf..10372e4 100644 --- a/python/serializer/tosa_serializer.py +++ b/python/serializer/tosa_serializer.py @@ -30,7 +30,7 @@ import tosa.Op as TosaOp # Keep version number in sync with the version default value with schema/tosa.fbs TOSA_VERSION_MAJOR = 0 -TOSA_VERSION_MINOR = 25 +TOSA_VERSION_MINOR = 30 TOSA_VERSION_PATCH = 0 TOSA_VERSION_DRAFT = True TOSA_VERSION = [ @@ -141,7 +141,7 @@ class TosaSerializerAttribute(TosaSerializerUnion): def __init__(self): super().__init__() - def PoolAttribute(self, kernel, stride, pad): + def PoolAttribute(self, kernel, stride, pad, input_zp, output_zp): from tosa import PoolAttribute as a, Attribute self.utype = Attribute.Attribute().PoolAttribute @@ -150,8 +150,10 @@ class TosaSerializerAttribute(TosaSerializerUnion): self.intvecs.append((a.AddPad, pad)) self.intvecs.append((a.AddKernel, kernel)) self.intvecs.append((a.AddStride, stride)) + self.ints.append((a.AddInputZp, input_zp)) + self.ints.append((a.AddOutputZp, output_zp)) - def ConvAttribute(self, pad, stride, dilation): + def ConvAttribute(self, pad, stride, dilation, input_zp, weight_zp): from tosa import ConvAttribute as a, Attribute self.utype = Attribute.Attribute().ConvAttribute @@ -160,8 +162,10 @@ class TosaSerializerAttribute(TosaSerializerUnion): self.intvecs.append((a.AddPad, pad)) self.intvecs.append((a.AddStride, stride)) self.intvecs.append((a.AddDilation, dilation)) + self.ints.append((a.AddInputZp, input_zp)) + self.ints.append((a.AddWeightZp, weight_zp)) - def TransposeConvAttribute(self, outpad, stride, output_shape): + def TransposeConvAttribute(self, outpad, stride, output_shape, input_zp, weight_zp): from tosa import TransposeConvAttribute as a, Attribute self.utype = Attribute.Attribute().TransposeConvAttribute @@ -170,6 +174,8 @@ class TosaSerializerAttribute(TosaSerializerUnion): self.intvecs.append((a.AddOutPad, outpad)) self.intvecs.append((a.AddStride, stride)) self.intvecs.append((a.AddOutputShape, output_shape)) + self.ints.append((a.AddInputZp, input_zp)) + self.ints.append((a.AddWeightZp, weight_zp)) def PadAttribute(self, padding, pad_const_int, pad_const_fp): from tosa import PadAttribute as a, Attribute @@ -311,43 +317,32 @@ class TosaSerializerAttribute(TosaSerializerUnion): self.intvecs.append((a.AddTable, table)) + def MatMulAttribute(self, A_zp, B_zp): + from tosa import MatMulAttribute as a, Attribute -class TosaSerializerQuantInfo(TosaSerializerUnion): - """This class handles encapsulating all of the enumerated types for quantinfo""" - - def __init__(self): - super().__init__() - - def ConvQuantInfo(self, input_zp, weight_zp): - from tosa import ConvQuantInfo as q, QuantInfo + self.utype = Attribute.Attribute().MatMulAttribute + self.optFcns = (a.Start, a.End) - self.utype = QuantInfo.QuantInfo().ConvQuantInfo - self.optFcns = (q.Start, q.End) - self.ints.append((q.AddInputZp, input_zp)) - self.ints.append((q.AddWeightZp, weight_zp)) + self.ints.append((a.AddAZp, A_zp)) + self.ints.append((a.AddBZp, B_zp)) - def UnaryQuantInfo(self, input_zp, output_zp): - from tosa import UnaryQuantInfo as q, QuantInfo + def FullyConnectedAttribute(self, input_zp, weight_zp): + from tosa import FullyConnectedAttribute as a, Attribute - self.utype = QuantInfo.QuantInfo().UnaryQuantInfo - self.optFcns = (q.Start, q.End) - self.ints.append((q.AddInputZp, input_zp)) - self.ints.append((q.AddOutputZp, output_zp)) + self.utype = Attribute.Attribute().FullyConnectedAttribute + self.optFcns = (a.Start, a.End) - def MatMulQuantInfo(self, a_zp, b_zp): - from tosa import MatMulQuantInfo as q, QuantInfo + self.ints.append((a.AddInputZp, input_zp)) + self.ints.append((a.AddWeightZp, weight_zp)) - self.utype = QuantInfo.QuantInfo().MatMulQuantInfo - self.optFcns = (q.Start, q.End) - self.ints.append((q.AddAZp, a_zp)) - self.ints.append((q.AddBZp, b_zp)) + def NegateAttribute(self, input1_zp, output_zp): + from tosa import NegateAttribute as a, Attribute - def PadQuantInfo(self, input_zp): - from tosa import PadQuantInfo as q, QuantInfo + self.utype = Attribute.Attribute().NegateAttribute + self.optFcns = (a.Start, a.End) - self.utype = QuantInfo.QuantInfo().PadQuantInfo - self.optFcns = (q.Start, q.End) - self.ints.append((q.AddInputZp, input_zp)) + self.ints.append((a.AddInput1Zp, input1_zp)) + self.ints.append((a.AddOutputZp, output_zp)) class TosaSerializerTensor: @@ -467,12 +462,11 @@ class TosaSerializerTensor: class TosaSerializerOperator: - def __init__(self, op, inputs, outputs, attributes=None, quantInfo=None): + def __init__(self, op, inputs, outputs, attributes=None): self.op = op self.attributes = attributes self.inputs = TosaSerializer.toList(inputs) self.outputs = TosaSerializer.toList(outputs) - self.quantInfo = quantInfo def __str__(self): str = "Op {}\n----\n".format(self.op) @@ -491,13 +485,10 @@ class TosaSerializerOperator: fb_outputs = TosaSerializer.serializeStrVec( builder, self.outputs, TosaOperator.StartOutputsVector ) - # Need to serialize quant_info and attributes enums still + # Need to serialize attributes enums still if self.attributes is not None: fb_attributes = self.attributes.serialize(builder) - if self.quantInfo is not None: - fb_qinfo = self.quantInfo.serialize(builder) - TosaOperator.Start(builder) TosaOperator.AddOp(builder, self.op) TosaOperator.AddInputs(builder, fb_inputs) @@ -505,9 +496,6 @@ class TosaSerializerOperator: if self.attributes is not None: TosaOperator.AddAttributeType(builder, self.attributes.utype) TosaOperator.AddAttribute(builder, fb_attributes) - if self.quantInfo is not None: - TosaOperator.AddQuantInfoType(builder, self.quantInfo.utype) - TosaOperator.AddQuantInfo(builder, fb_qinfo) return TosaOperator.End(builder) @@ -544,10 +532,8 @@ class TosaSerializerBasicBlock: def addOutput(self, name): self.outputs.append(name) - def addOperator(self, op, inputs, outputs, attributes=None, quant_info=None): - self.operators.append( - TosaSerializerOperator(op, inputs, outputs, attributes, quant_info) - ) + def addOperator(self, op, inputs, outputs, attributes=None): + self.operators.append(TosaSerializerOperator(op, inputs, outputs, attributes)) def serialize(self, builder): fb_name = builder.CreateString(self.name) @@ -671,13 +657,16 @@ class TosaSerializer: self.currBasicBlock.addOutput(name) return tens - def addOperator(self, op, inputs, outputs, attributes=None, quant_info=None): + def addOperator(self, op, inputs, outputs, attributes=None): if op == TosaOp.Op().CONST: raise Exception("Use addConstTensor() to add CONST ops") return self.currBasicBlock.addOperator( - op, inputs, outputs, attributes, quant_info + op, + inputs, + outputs, + attributes, ) def setExpectedReturnCode(self, val, fail, desc=""): @@ -861,21 +850,48 @@ class TosaSerializer: ConvAttribute.StartDilationVector = ( ConvAttribute.ConvAttributeStartDilationVector ) + ConvAttribute.AddInputZp = ConvAttribute.ConvAttributeAddInputZp + ConvAttribute.AddWeightZp = ConvAttribute.ConvAttributeAddWeightZp ConvAttribute.End = ConvAttribute.ConvAttributeEnd - from tosa import ConvQuantInfo - - if not hasattr(ConvQuantInfo, "Start"): - ConvQuantInfo.Start = ConvQuantInfo.ConvQuantInfoStart - ConvQuantInfo.AddInputZp = ConvQuantInfo.ConvQuantInfoAddInputZp - ConvQuantInfo.AddWeightZp = ConvQuantInfo.ConvQuantInfoAddWeightZp - ConvQuantInfo.End = ConvQuantInfo.ConvQuantInfoEnd - from tosa import MatMulQuantInfo - - if not hasattr(MatMulQuantInfo, "Start"): - MatMulQuantInfo.Start = MatMulQuantInfo.MatMulQuantInfoStart - MatMulQuantInfo.AddAZp = MatMulQuantInfo.MatMulQuantInfoAddAZp - MatMulQuantInfo.AddBZp = MatMulQuantInfo.MatMulQuantInfoAddBZp - MatMulQuantInfo.End = MatMulQuantInfo.MatMulQuantInfoEnd + from tosa import FullyConnectedAttribute + + if not hasattr(FullyConnectedAttribute, "Start"): + FullyConnectedAttribute.Start = ( + FullyConnectedAttribute.FullyConnectedAttributeStart + ) + FullyConnectedAttribute.AddInputZp = ( + FullyConnectedAttribute.FullyConnectedAttributeAddInputZp + ) + FullyConnectedAttribute.AddWeightZp = ( + FullyConnectedAttribute.FullyConnectedAttributeAddWeightZp + ) + FullyConnectedAttribute.End = ( + FullyConnectedAttribute.FullyConnectedAttributeEnd + ) + from tosa import MatMulAttribute + + if not hasattr(MatMulAttribute, "Start"): + MatMulAttribute.Start = MatMulAttribute.MatMulAttributeStart + MatMulAttribute.AddAZp = MatMulAttribute.MatMulAttributeAddAZp + MatMulAttribute.AddBZp = MatMulAttribute.MatMulAttributeAddBZp + MatMulAttribute.End = MatMulAttribute.MatMulAttributeEnd + from tosa import PoolAttribute + + if not hasattr(PoolAttribute, "Start"): + PoolAttribute.Start = PoolAttribute.PoolAttributeStart + PoolAttribute.AddPad = PoolAttribute.PoolAttributeAddPad + PoolAttribute.StartPadVector = PoolAttribute.PoolAttributeStartPadVector + PoolAttribute.AddKernel = PoolAttribute.PoolAttributeAddKernel + PoolAttribute.StartKernelVector = ( + PoolAttribute.PoolAttributeStartKernelVector + ) + PoolAttribute.AddStride = PoolAttribute.PoolAttributeAddStride + PoolAttribute.StartStrideVector = ( + PoolAttribute.PoolAttributeStartStrideVector + ) + PoolAttribute.AddInputZp = PoolAttribute.PoolAttributeAddInputZp + PoolAttribute.AddOutputZp = PoolAttribute.PoolAttributeAddOutputZp + PoolAttribute.End = PoolAttribute.PoolAttributeEnd from tosa import MulAttribute if not hasattr(MulAttribute, "Start"): @@ -893,12 +909,6 @@ class TosaSerializer: PadAttribute.AddPadConstInt = PadAttribute.PadAttributeAddPadConstInt PadAttribute.AddPadConstFp = PadAttribute.PadAttributeAddPadConstFp PadAttribute.End = PadAttribute.PadAttributeEnd - from tosa import PadQuantInfo - - if not hasattr(PadQuantInfo, "Start"): - PadQuantInfo.Start = PadQuantInfo.PadQuantInfoStart - PadQuantInfo.AddInputZp = PadQuantInfo.PadQuantInfoAddInputZp - PadQuantInfo.End = PadQuantInfo.PadQuantInfoEnd from tosa import PoolAttribute if not hasattr(PoolAttribute, "Start"): @@ -913,6 +923,8 @@ class TosaSerializer: PoolAttribute.StartStrideVector = ( PoolAttribute.PoolAttributeStartStrideVector ) + PoolAttribute.AddInputZp = PoolAttribute.PoolAttributeAddInputZp + PoolAttribute.AddOutputZp = PoolAttribute.PoolAttributeAddOutputZp PoolAttribute.End = PoolAttribute.PoolAttributeEnd from tosa import RescaleAttribute @@ -1048,8 +1060,6 @@ class TosaSerializer: TosaOperator.StartOutputsVector = ( TosaOperator.TosaOperatorStartOutputsVector ) - TosaOperator.AddQuantInfoType = TosaOperator.TosaOperatorAddQuantInfoType - TosaOperator.AddQuantInfo = TosaOperator.TosaOperatorAddQuantInfo TosaOperator.End = TosaOperator.TosaOperatorEnd from tosa import TosaTensor @@ -1095,16 +1105,15 @@ class TosaSerializer: TransposeConvAttribute.StartOutputShapeVector = ( TransposeConvAttribute.TransposeConvAttributeStartOutputShapeVector ) + TransposeConvAttribute.AddInputZp = ( + TransposeConvAttribute.TransposeConvAttributeAddInputZp + ) + TransposeConvAttribute.AddWeightZp = ( + TransposeConvAttribute.TransposeConvAttributeAddWeightZp + ) TransposeConvAttribute.End = ( TransposeConvAttribute.TransposeConvAttributeEnd ) - from tosa import UnaryQuantInfo - - if not hasattr(UnaryQuantInfo, "Start"): - UnaryQuantInfo.Start = UnaryQuantInfo.UnaryQuantInfoStart - UnaryQuantInfo.AddInputZp = UnaryQuantInfo.UnaryQuantInfoAddInputZp - UnaryQuantInfo.AddOutputZp = UnaryQuantInfo.UnaryQuantInfoAddOutputZp - UnaryQuantInfo.End = UnaryQuantInfo.UnaryQuantInfoEnd from tosa import Version if not hasattr(Version, "Start"): @@ -1114,6 +1123,35 @@ class TosaSerializer: Version.Add_patch = Version.VersionAdd_patch Version.Add_draft = Version.VersionAdd_draft Version.End = Version.VersionEnd + from tosa import MatMulAttribute + + if not hasattr(MatMulAttribute, "Start"): + MatMulAttribute.Start = MatMulAttribute.MatMulAttributeStart + MatMulAttribute.AddAZp = MatMulAttribute.MatMulAttributeAddAZp + MatMulAttribute.AddBZp = MatMulAttribute.MatMulAttributeAddBZp + MatMulAttribute.End = MatMulAttribute.MatMulAttributeEnd + from tosa import FullyConnectedAttribute + + if not hasattr(FullyConnectedAttribute, "Start"): + FullyConnectedAttribute.Start = ( + FullyConnectedAttribute.FullyConnectedAttributeStart + ) + FullyConnectedAttribute.AddInputZp = ( + FullyConnectedAttribute.FullyConnectedAttributeAddInputZp + ) + FullyConnectedAttribute.AddWeightZp = ( + FullyConnectedAttribute.FullyConnectedAttributeAddWeightZp + ) + FullyConnectedAttribute.End = ( + FullyConnectedAttribute.FullyConnectedAttributeEnd + ) + from tosa import NegateAttribute + + if not hasattr(NegateAttribute, "Start"): + NegateAttribute.Start = NegateAttribute.NegateAttributeStart + NegateAttribute.AddInput1Zp = NegateAttribute.NegateAttributeAddInput1Zp + NegateAttribute.AddOutputZp = NegateAttribute.NegateAttributeAddOutputZp + NegateAttribute.End = NegateAttribute.NegateAttributeEnd from tosa import WhileLoopAttribute if not hasattr(WhileLoopAttribute, "Start"): diff --git a/python/tosa/Attribute.py b/python/tosa/Attribute.py index f0307af..166de8e 100644 --- a/python/tosa/Attribute.py +++ b/python/tosa/Attribute.py @@ -21,4 +21,7 @@ class Attribute(object): WhileLoopAttribute = 15 TransposeAttribute = 16 TableAttribute = 17 + MatMulAttribute = 18 + FullyConnectedAttribute = 19 + NegateAttribute = 20 diff --git a/python/tosa/ConvAttribute.py b/python/tosa/ConvAttribute.py index 72a24ce..8244ea5 100644 --- a/python/tosa/ConvAttribute.py +++ b/python/tosa/ConvAttribute.py @@ -105,11 +105,27 @@ class ConvAttribute(object): o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) return o == 0 -def ConvAttributeStart(builder): builder.StartObject(3) + # ConvAttribute + def InputZp(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos) + return 0 + + # ConvAttribute + def WeightZp(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(12)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos) + return 0 + +def ConvAttributeStart(builder): builder.StartObject(5) def ConvAttributeAddPad(builder, pad): builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(pad), 0) def ConvAttributeStartPadVector(builder, numElems): return builder.StartVector(4, numElems, 4) def ConvAttributeAddStride(builder, stride): builder.PrependUOffsetTRelativeSlot(1, flatbuffers.number_types.UOffsetTFlags.py_type(stride), 0) def ConvAttributeStartStrideVector(builder, numElems): return builder.StartVector(4, numElems, 4) def ConvAttributeAddDilation(builder, dilation): builder.PrependUOffsetTRelativeSlot(2, flatbuffers.number_types.UOffsetTFlags.py_type(dilation), 0) def ConvAttributeStartDilationVector(builder, numElems): return builder.StartVector(4, numElems, 4) +def ConvAttributeAddInputZp(builder, inputZp): builder.PrependInt32Slot(3, inputZp, 0) +def ConvAttributeAddWeightZp(builder, weightZp): builder.PrependInt32Slot(4, weightZp, 0) def ConvAttributeEnd(builder): return builder.EndObject() diff --git a/python/tosa/ConvQuantInfo.py b/python/tosa/ConvQuantInfo.py deleted file mode 100644 index 7902c67..0000000 --- a/python/tosa/ConvQuantInfo.py +++ /dev/null @@ -1,44 +0,0 @@ -# automatically generated by the FlatBuffers compiler, do not modify - -# namespace: tosa - -import flatbuffers -from flatbuffers.compat import import_numpy -np = import_numpy() - -class ConvQuantInfo(object): - __slots__ = ['_tab'] - - @classmethod - def GetRootAsConvQuantInfo(cls, buf, offset): - n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) - x = ConvQuantInfo() - x.Init(buf, n + offset) - return x - - @classmethod - def ConvQuantInfoBufferHasIdentifier(cls, buf, offset, size_prefixed=False): - return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x54\x4F\x53\x41", size_prefixed=size_prefixed) - - # ConvQuantInfo - def Init(self, buf, pos): - self._tab = flatbuffers.table.Table(buf, pos) - - # ConvQuantInfo - def InputZp(self): - o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) - if o != 0: - return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos) - return 0 - - # ConvQuantInfo - def WeightZp(self): - o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) - if o != 0: - return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos) - return 0 - -def ConvQuantInfoStart(builder): builder.StartObject(2) -def ConvQuantInfoAddInputZp(builder, inputZp): builder.PrependInt32Slot(0, inputZp, 0) -def ConvQuantInfoAddWeightZp(builder, weightZp): builder.PrependInt32Slot(1, weightZp, 0) -def ConvQuantInfoEnd(builder): return builder.EndObject() diff --git a/python/tosa/FullyConnectedAttribute.py b/python/tosa/FullyConnectedAttribute.py new file mode 100644 index 0000000..62b480d --- /dev/null +++ b/python/tosa/FullyConnectedAttribute.py @@ -0,0 +1,44 @@ +# automatically generated by the FlatBuffers compiler, do not modify + +# namespace: tosa + +import flatbuffers +from flatbuffers.compat import import_numpy +np = import_numpy() + +class FullyConnectedAttribute(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAsFullyConnectedAttribute(cls, buf, offset): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = FullyConnectedAttribute() + x.Init(buf, n + offset) + return x + + @classmethod + def FullyConnectedAttributeBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x54\x4F\x53\x41", size_prefixed=size_prefixed) + + # FullyConnectedAttribute + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + # FullyConnectedAttribute + def InputZp(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos) + return 0 + + # FullyConnectedAttribute + def WeightZp(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos) + return 0 + +def FullyConnectedAttributeStart(builder): builder.StartObject(2) +def FullyConnectedAttributeAddInputZp(builder, inputZp): builder.PrependInt32Slot(0, inputZp, 0) +def FullyConnectedAttributeAddWeightZp(builder, weightZp): builder.PrependInt32Slot(1, weightZp, 0) +def FullyConnectedAttributeEnd(builder): return builder.EndObject() diff --git a/python/tosa/MatMulAttribute.py b/python/tosa/MatMulAttribute.py new file mode 100644 index 0000000..601f13f --- /dev/null +++ b/python/tosa/MatMulAttribute.py @@ -0,0 +1,44 @@ +# automatically generated by the FlatBuffers compiler, do not modify + +# namespace: tosa + +import flatbuffers +from flatbuffers.compat import import_numpy +np = import_numpy() + +class MatMulAttribute(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAsMatMulAttribute(cls, buf, offset): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = MatMulAttribute() + x.Init(buf, n + offset) + return x + + @classmethod + def MatMulAttributeBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x54\x4F\x53\x41", size_prefixed=size_prefixed) + + # MatMulAttribute + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + # MatMulAttribute + def AZp(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos) + return 0 + + # MatMulAttribute + def BZp(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos) + return 0 + +def MatMulAttributeStart(builder): builder.StartObject(2) +def MatMulAttributeAddAZp(builder, aZp): builder.PrependInt32Slot(0, aZp, 0) +def MatMulAttributeAddBZp(builder, bZp): builder.PrependInt32Slot(1, bZp, 0) +def MatMulAttributeEnd(builder): return builder.EndObject() diff --git a/python/tosa/MatMulQuantInfo.py b/python/tosa/MatMulQuantInfo.py deleted file mode 100644 index 457da66..0000000 --- a/python/tosa/MatMulQuantInfo.py +++ /dev/null @@ -1,44 +0,0 @@ -# automatically generated by the FlatBuffers compiler, do not modify - -# namespace: tosa - -import flatbuffers -from flatbuffers.compat import import_numpy -np = import_numpy() - -class MatMulQuantInfo(object): - __slots__ = ['_tab'] - - @classmethod - def GetRootAsMatMulQuantInfo(cls, buf, offset): - n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) - x = MatMulQuantInfo() - x.Init(buf, n + offset) - return x - - @classmethod - def MatMulQuantInfoBufferHasIdentifier(cls, buf, offset, size_prefixed=False): - return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x54\x4F\x53\x41", size_prefixed=size_prefixed) - - # MatMulQuantInfo - def Init(self, buf, pos): - self._tab = flatbuffers.table.Table(buf, pos) - - # MatMulQuantInfo - def AZp(self): - o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) - if o != 0: - return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos) - return 0 - - # MatMulQuantInfo - def BZp(self): - o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) - if o != 0: - return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos) - return 0 - -def MatMulQuantInfoStart(builder): builder.StartObject(2) -def MatMulQuantInfoAddAZp(builder, aZp): builder.PrependInt32Slot(0, aZp, 0) -def MatMulQuantInfoAddBZp(builder, bZp): builder.PrependInt32Slot(1, bZp, 0) -def MatMulQuantInfoEnd(builder): return builder.EndObject() diff --git a/python/tosa/NegateAttribute.py b/python/tosa/NegateAttribute.py new file mode 100644 index 0000000..24a57dc --- /dev/null +++ b/python/tosa/NegateAttribute.py @@ -0,0 +1,44 @@ +# automatically generated by the FlatBuffers compiler, do not modify + +# namespace: tosa + +import flatbuffers +from flatbuffers.compat import import_numpy +np = import_numpy() + +class NegateAttribute(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAsNegateAttribute(cls, buf, offset): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = NegateAttribute() + x.Init(buf, n + offset) + return x + + @classmethod + def NegateAttributeBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x54\x4F\x53\x41", size_prefixed=size_prefixed) + + # NegateAttribute + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + # NegateAttribute + def Input1Zp(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos) + return 0 + + # NegateAttribute + def OutputZp(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos) + return 0 + +def NegateAttributeStart(builder): builder.StartObject(2) +def NegateAttributeAddInput1Zp(builder, input1Zp): builder.PrependInt32Slot(0, input1Zp, 0) +def NegateAttributeAddOutputZp(builder, outputZp): builder.PrependInt32Slot(1, outputZp, 0) +def NegateAttributeEnd(builder): return builder.EndObject() diff --git a/python/tosa/PadQuantInfo.py b/python/tosa/PadQuantInfo.py deleted file mode 100644 index c07db07..0000000 --- a/python/tosa/PadQuantInfo.py +++ /dev/null @@ -1,36 +0,0 @@ -# automatically generated by the FlatBuffers compiler, do not modify - -# namespace: tosa - -import flatbuffers -from flatbuffers.compat import import_numpy -np = import_numpy() - -class PadQuantInfo(object): - __slots__ = ['_tab'] - - @classmethod - def GetRootAsPadQuantInfo(cls, buf, offset): - n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) - x = PadQuantInfo() - x.Init(buf, n + offset) - return x - - @classmethod - def PadQuantInfoBufferHasIdentifier(cls, buf, offset, size_prefixed=False): - return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x54\x4F\x53\x41", size_prefixed=size_prefixed) - - # PadQuantInfo - def Init(self, buf, pos): - self._tab = flatbuffers.table.Table(buf, pos) - - # PadQuantInfo - def InputZp(self): - o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) - if o != 0: - return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos) - return 0 - -def PadQuantInfoStart(builder): builder.StartObject(1) -def PadQuantInfoAddInputZp(builder, inputZp): builder.PrependInt32Slot(0, inputZp, 0) -def PadQuantInfoEnd(builder): return builder.EndObject() diff --git a/python/tosa/PoolAttribute.py b/python/tosa/PoolAttribute.py index f2c12cf..8b6903e 100644 --- a/python/tosa/PoolAttribute.py +++ b/python/tosa/PoolAttribute.py @@ -105,11 +105,27 @@ class PoolAttribute(object): o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) return o == 0 -def PoolAttributeStart(builder): builder.StartObject(3) + # PoolAttribute + def InputZp(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos) + return 0 + + # PoolAttribute + def OutputZp(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(12)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos) + return 0 + +def PoolAttributeStart(builder): builder.StartObject(5) def PoolAttributeAddPad(builder, pad): builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(pad), 0) def PoolAttributeStartPadVector(builder, numElems): return builder.StartVector(4, numElems, 4) def PoolAttributeAddKernel(builder, kernel): builder.PrependUOffsetTRelativeSlot(1, flatbuffers.number_types.UOffsetTFlags.py_type(kernel), 0) def PoolAttributeStartKernelVector(builder, numElems): return builder.StartVector(4, numElems, 4) def PoolAttributeAddStride(builder, stride): builder.PrependUOffsetTRelativeSlot(2, flatbuffers.number_types.UOffsetTFlags.py_type(stride), 0) def PoolAttributeStartStrideVector(builder, numElems): return builder.StartVector(4, numElems, 4) +def PoolAttributeAddInputZp(builder, inputZp): builder.PrependInt32Slot(3, inputZp, 0) +def PoolAttributeAddOutputZp(builder, outputZp): builder.PrependInt32Slot(4, outputZp, 0) def PoolAttributeEnd(builder): return builder.EndObject() diff --git a/python/tosa/QuantInfo.py b/python/tosa/QuantInfo.py deleted file mode 100644 index ffdfd32..0000000 --- a/python/tosa/QuantInfo.py +++ /dev/null @@ -1,11 +0,0 @@ -# automatically generated by the FlatBuffers compiler, do not modify - -# namespace: tosa - -class QuantInfo(object): - NONE = 0 - UnaryQuantInfo = 1 - ConvQuantInfo = 2 - MatMulQuantInfo = 3 - PadQuantInfo = 4 - diff --git a/python/tosa/TosaOperator.py b/python/tosa/TosaOperator.py index 040c2dc..fd11f76 100644 --- a/python/tosa/TosaOperator.py +++ b/python/tosa/TosaOperator.py @@ -88,24 +88,7 @@ class TosaOperator(object): o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(12)) return o == 0 - # TosaOperator - def QuantInfoType(self): - o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(14)) - if o != 0: - return self._tab.Get(flatbuffers.number_types.Uint8Flags, o + self._tab.Pos) - return 0 - - # TosaOperator - def QuantInfo(self): - o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(16)) - if o != 0: - from flatbuffers.table import Table - obj = Table(bytearray(), 0) - self._tab.Union(obj, o) - return obj - return None - -def TosaOperatorStart(builder): builder.StartObject(7) +def TosaOperatorStart(builder): builder.StartObject(5) def TosaOperatorAddOp(builder, op): builder.PrependUint32Slot(0, op, 0) def TosaOperatorAddAttributeType(builder, attributeType): builder.PrependUint8Slot(1, attributeType, 0) def TosaOperatorAddAttribute(builder, attribute): builder.PrependUOffsetTRelativeSlot(2, flatbuffers.number_types.UOffsetTFlags.py_type(attribute), 0) @@ -113,6 +96,4 @@ def TosaOperatorAddInputs(builder, inputs): builder.PrependUOffsetTRelativeSlot( def TosaOperatorStartInputsVector(builder, numElems): return builder.StartVector(4, numElems, 4) def TosaOperatorAddOutputs(builder, outputs): builder.PrependUOffsetTRelativeSlot(4, flatbuffers.number_types.UOffsetTFlags.py_type(outputs), 0) def TosaOperatorStartOutputsVector(builder, numElems): return builder.StartVector(4, numElems, 4) -def TosaOperatorAddQuantInfoType(builder, quantInfoType): builder.PrependUint8Slot(5, quantInfoType, 0) -def TosaOperatorAddQuantInfo(builder, quantInfo): builder.PrependUOffsetTRelativeSlot(6, flatbuffers.number_types.UOffsetTFlags.py_type(quantInfo), 0) def TosaOperatorEnd(builder): return builder.EndObject() diff --git a/python/tosa/TransposeConvAttribute.py b/python/tosa/TransposeConvAttribute.py index 2a675a0..8ca5ba7 100644 --- a/python/tosa/TransposeConvAttribute.py +++ b/python/tosa/TransposeConvAttribute.py @@ -105,11 +105,27 @@ class TransposeConvAttribute(object): o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) return o == 0 -def TransposeConvAttributeStart(builder): builder.StartObject(3) + # TransposeConvAttribute + def InputZp(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos) + return 0 + + # TransposeConvAttribute + def WeightZp(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(12)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos) + return 0 + +def TransposeConvAttributeStart(builder): builder.StartObject(5) def TransposeConvAttributeAddOutPad(builder, outPad): builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(outPad), 0) def TransposeConvAttributeStartOutPadVector(builder, numElems): return builder.StartVector(4, numElems, 4) def TransposeConvAttributeAddStride(builder, stride): builder.PrependUOffsetTRelativeSlot(1, flatbuffers.number_types.UOffsetTFlags.py_type(stride), 0) def TransposeConvAttributeStartStrideVector(builder, numElems): return builder.StartVector(4, numElems, 4) def TransposeConvAttributeAddOutputShape(builder, outputShape): builder.PrependUOffsetTRelativeSlot(2, flatbuffers.number_types.UOffsetTFlags.py_type(outputShape), 0) def TransposeConvAttributeStartOutputShapeVector(builder, numElems): return builder.StartVector(4, numElems, 4) +def TransposeConvAttributeAddInputZp(builder, inputZp): builder.PrependInt32Slot(3, inputZp, 0) +def TransposeConvAttributeAddWeightZp(builder, weightZp): builder.PrependInt32Slot(4, weightZp, 0) def TransposeConvAttributeEnd(builder): return builder.EndObject() diff --git a/python/tosa/UnaryQuantInfo.py b/python/tosa/UnaryQuantInfo.py deleted file mode 100644 index 648fc34..0000000 --- a/python/tosa/UnaryQuantInfo.py +++ /dev/null @@ -1,44 +0,0 @@ -# automatically generated by the FlatBuffers compiler, do not modify - -# namespace: tosa - -import flatbuffers -from flatbuffers.compat import import_numpy -np = import_numpy() - -class UnaryQuantInfo(object): - __slots__ = ['_tab'] - - @classmethod - def GetRootAsUnaryQuantInfo(cls, buf, offset): - n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) - x = UnaryQuantInfo() - x.Init(buf, n + offset) - return x - - @classmethod - def UnaryQuantInfoBufferHasIdentifier(cls, buf, offset, size_prefixed=False): - return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x54\x4F\x53\x41", size_prefixed=size_prefixed) - - # UnaryQuantInfo - def Init(self, buf, pos): - self._tab = flatbuffers.table.Table(buf, pos) - - # UnaryQuantInfo - def InputZp(self): - o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) - if o != 0: - return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos) - return 0 - - # UnaryQuantInfo - def OutputZp(self): - o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) - if o != 0: - return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos) - return 0 - -def UnaryQuantInfoStart(builder): builder.StartObject(2) -def UnaryQuantInfoAddInputZp(builder, inputZp): builder.PrependInt32Slot(0, inputZp, 0) -def UnaryQuantInfoAddOutputZp(builder, outputZp): builder.PrependInt32Slot(1, outputZp, 0) -def UnaryQuantInfoEnd(builder): return builder.EndObject() diff --git a/python/tosa/Version.py b/python/tosa/Version.py index 27dea53..bdac948 100644 --- a/python/tosa/Version.py +++ b/python/tosa/Version.py @@ -36,7 +36,7 @@ class Version(object): o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) if o != 0: return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos) - return 25 + return 30 # Version def _patch(self): @@ -54,7 +54,7 @@ class Version(object): def VersionStart(builder): builder.StartObject(4) def VersionAdd_major(builder, Major): builder.PrependInt32Slot(0, Major, 0) -def VersionAdd_minor(builder, Minor): builder.PrependInt32Slot(1, Minor, 25) +def VersionAdd_minor(builder, Minor): builder.PrependInt32Slot(1, Minor, 30) def VersionAdd_patch(builder, Patch): builder.PrependInt32Slot(2, Patch, 0) def VersionAdd_draft(builder, Draft): builder.PrependBoolSlot(3, Draft, 1) def VersionEnd(builder): return builder.EndObject() diff --git a/schema/tosa.fbs b/schema/tosa.fbs index 7830bd2..b342103 100644 --- a/schema/tosa.fbs +++ b/schema/tosa.fbs @@ -157,24 +157,33 @@ union Attribute { WhileLoopAttribute, TransposeAttribute, TableAttribute, + MatMulAttribute, + FullyConnectedAttribute, + NegateAttribute } table PoolAttribute { pad: [int32]; kernel: [int32]; stride: [int32]; + input_zp: int32; + output_zp: int32; } table ConvAttribute { pad: [int32]; stride: [int32]; dilation: [int32]; + input_zp: int32; + weight_zp: int32; } table TransposeConvAttribute { out_pad: [int32]; stride: [int32]; output_shape: [int32]; + input_zp: int32; + weight_zp: int32; } table PadAttribute { @@ -253,35 +262,24 @@ table TableAttribute { table: [int16]; } -union QuantInfo { - UnaryQuantInfo, - ConvQuantInfo, - MatMulQuantInfo, - PadQuantInfo, -} - -table UnaryQuantInfo { - input_zp: int32; - output_zp: int32; +table MatMulAttribute { + a_zp: int32; + b_zp: int32; } -table ConvQuantInfo { +table FullyConnectedAttribute { input_zp: int32; weight_zp: int32; } -table MatMulQuantInfo { - a_zp: int32; - b_zp: int32; -} - -table PadQuantInfo { - input_zp: int32; +table NegateAttribute { + input1_zp: int32; + output_zp: int32; } table Version { _major: int32 = 0; - _minor: int32 = 25; + _minor: int32 = 30; _patch: int32 = 0; _draft: bool = true; } @@ -298,7 +296,6 @@ table TosaOperator { attribute: Attribute; // union structure. operator attribute inputs:[string]; // list of input tensor names outputs:[string]; // list of output tensor names - quant_info: QuantInfo; // op-based quantization information } table TosaBasicBlock { diff --git a/src/tosa_serialization_handler.cpp b/src/tosa_serialization_handler.cpp index 32725a0..99da2e7 100644 --- a/src/tosa_serialization_handler.cpp +++ b/src/tosa_serialization_handler.cpp @@ -56,10 +56,7 @@ TosaSerializationTensor::TosaSerializationTensor() TosaSerializationTensor::~TosaSerializationTensor() {} -void TosaSerializationOperator::InitializeAttributeQinfo(Attribute attribute_type, - const TosaAttributeBase* attribute, - QuantInfo qinfo_type, - const TosaQuantInfoBase* qinfo) +void TosaSerializationOperator::InitializeAttribute(Attribute attribute_type, const TosaAttributeBase* attribute) { _attribute_type = attribute_type; @@ -80,32 +77,12 @@ void TosaSerializationOperator::InitializeAttributeQinfo(Attribute attribute_typ assert(0); } - _qinfo_type = qinfo_type; - switch (qinfo_type) - { - case QuantInfo_NONE: - _qinfo = new TosaNoneQuantInfo(); - break; -#define DEF_QUANTIZATION_INFO(NAME, ...) \ - case QuantInfo_##NAME##QuantInfo: \ - _qinfo = new Tosa##NAME##QuantInfo(qinfo); \ - break; -#include "quant_info.def" -#undef DEF_QUANTIZATION_INFO - default: - printf("TosaSerializationOperator::TosaSerializationOperator(): QuantInfo %s not implemented yet\n", - EnumNamesQuantInfo()[qinfo_type]); - assert(0); - } - - assert(_attribute && _qinfo); + assert(_attribute); } TosaSerializationOperator::TosaSerializationOperator(Op op, Attribute attribute_type, const TosaAttributeBase* attribute, - QuantInfo qinfo_type, - const TosaQuantInfoBase* qinfo, const std::vector& input_tensor_names, const std::vector& output_tensor_names) { @@ -113,14 +90,12 @@ TosaSerializationOperator::TosaSerializationOperator(Op op, _input_tensor_names = input_tensor_names; _output_tensor_names = output_tensor_names; - InitializeAttributeQinfo(attribute_type, attribute, qinfo_type, qinfo); + InitializeAttribute(attribute_type, attribute); } TosaSerializationOperator::TosaSerializationOperator(Op op, Attribute attribute_type, const TosaAttributeBase* attribute, - QuantInfo qinfo_type, - const TosaQuantInfoBase* qinfo, std::vector&& input_tensor_names, std::vector&& output_tensor_names) { @@ -128,13 +103,12 @@ TosaSerializationOperator::TosaSerializationOperator(Op op, _input_tensor_names = std::move(input_tensor_names); _output_tensor_names = std::move(output_tensor_names); - InitializeAttributeQinfo(attribute_type, attribute, qinfo_type, qinfo); + InitializeAttribute(attribute_type, attribute); } TosaSerializationOperator::~TosaSerializationOperator() { delete _attribute; - delete _qinfo; } TosaSerializationBasicBlock::TosaSerializationBasicBlock(const std::string& name, @@ -453,7 +427,6 @@ tosa_err_t TosaSerializationHandler::Deserialize(const uint8_t* buf) std::vector block_outputs_container; TosaAttributeBase* typed_attribute = NULL; - TosaQuantInfoBase* typed_qinfo = NULL; TosaSerializationOperator* new_operator = NULL; TosaSerializationBasicBlock* new_block = NULL; TosaSerializationTensor* new_tensor = NULL; @@ -491,11 +464,9 @@ tosa_err_t TosaSerializationHandler::Deserialize(const uint8_t* buf) { auto curr_operator = fb_tosa_operators->Get(j); - auto operator_op = curr_operator->op(); - auto attribute_type = curr_operator->attribute_type(); - auto attribute = curr_operator->attribute(); - auto operator_qinfo_type = curr_operator->quant_info_type(); - auto operator_qinfo = curr_operator->quant_info(); + auto operator_op = curr_operator->op(); + auto attribute_type = curr_operator->attribute_type(); + auto attribute = curr_operator->attribute(); // input tensors auto operator_inputs = curr_operator->inputs(); @@ -538,27 +509,8 @@ tosa_err_t TosaSerializationHandler::Deserialize(const uint8_t* buf) return TOSA_INTERNAL_ERROR; } - switch (operator_qinfo_type) - { - case QuantInfo_NONE: - typed_qinfo = new TosaNoneQuantInfo(); - break; -#define DEF_QUANTIZATION_INFO(NAME, ...) \ - case QuantInfo_##NAME##QuantInfo: \ - typed_qinfo = new Tosa##NAME##QuantInfo(operator_qinfo); \ - break; - -#include "quant_info.def" -#undef DEF_QUANTIZATION_INFO - default: - printf("TosaSerializationHandler::Deserialize(): QuantInfo %s not implemented yet\n", - EnumNamesQuantInfo()[operator_qinfo_type]); - return TOSA_INTERNAL_ERROR; - } - - new_operator = - new TosaSerializationOperator(operator_op, attribute_type, typed_attribute, operator_qinfo_type, - typed_qinfo, operator_inputs_container, operator_outputs_container); + new_operator = new TosaSerializationOperator(operator_op, attribute_type, typed_attribute, + operator_inputs_container, operator_outputs_container); if (new_operator) { block_operators_container.push_back(new_operator); @@ -570,8 +522,6 @@ tosa_err_t TosaSerializationHandler::Deserialize(const uint8_t* buf) if (typed_attribute) delete typed_attribute; - if (typed_qinfo) - delete typed_qinfo; } auto fb_tosa_tensors = curr_block->tensors(); @@ -751,75 +701,8 @@ tosa_err_t TosaSerializationHandler::Serialize() return TOSA_INTERNAL_ERROR; } - auto qinfo_type = op->GetQInfoType(); - flatbuffers::Offset fb_operator_qinfo; - switch (qinfo_type) - { - case QuantInfo_NONE: - fb_operator_qinfo = 0; - break; -#define DEF_ARGS_S(NAME, T, V) , reinterpret_cast(op->GetQInfo())->V() -#define DEF_ARGS_V(NAME, T, V) , _builder.CreateVector(reinterpret_cast(op->GetQInfo())->V()) - -#define DEF_ARGS_1(NAME, T0, F0, V0) DEF_ARGS_##F0(NAME, T0, V0) -#define DEF_ARGS_2(NAME, T0, F0, V0, T1, F1, V1) DEF_ARGS_##F0(NAME, T0, V0) DEF_ARGS_##F1(NAME, T1, V1) -#define DEF_ARGS_3(NAME, T0, F0, V0, T1, F1, V1, T2, F2, V2) \ - DEF_ARGS_##F0(NAME, T0, V0) DEF_ARGS_##F1(NAME, T1, V1) DEF_ARGS_##F2(NAME, T2, V2) -#define DEF_ARGS_4(NAME, T0, F0, V0, T1, F1, V1, T2, F2, V2, T3, F3, V3) \ - DEF_ARGS_##F0(NAME, T0, V0) DEF_ARGS_##F1(NAME, T1, V1) DEF_ARGS_##F2(NAME, T2, V2) DEF_ARGS_##F3(NAME, T3, V3) -#define DEF_ARGS_5(NAME, T0, F0, V0, T1, F1, V1, T2, F2, V2, T3, F3, V3, T4, F4, V4) \ - DEF_ARGS_##F0(NAME, T0, V0) DEF_ARGS_##F1(NAME, T1, V1) DEF_ARGS_##F2(NAME, T2, V2) DEF_ARGS_##F3(NAME, T3, V3) \ - DEF_ARGS_##F4(NAME, T4, V4) -#define DEF_ARGS_6(NAME, T0, F0, V0, T1, F1, V1, T2, F2, V2, T3, F3, V3, T4, F4, V4, T5, F5, V5) \ - DEF_ARGS_##F0(NAME, T0, V0) DEF_ARGS_##F1(NAME, T1, V1) DEF_ARGS_##F2(NAME, T2, V2) DEF_ARGS_##F3(NAME, T3, V3) \ - DEF_ARGS_##F4(NAME, T4, V4) DEF_ARGS_##F5(NAME, T5, V5) -#define DEF_ARGS_7(NAME, T0, F0, V0, T1, F1, V1, T2, F2, V2, T3, F3, V3, T4, F4, V4, T5, F5, V5, T6, F6, V6) \ - DEF_ARGS_##F0(NAME, T0, V0) DEF_ARGS_##F1(NAME, T1, V1) DEF_ARGS_##F2(NAME, T2, V2) DEF_ARGS_##F3(NAME, T3, V3) \ - DEF_ARGS_##F4(NAME, T4, V4) DEF_ARGS_##F5(NAME, T5, V5) DEF_ARGS_##F6(NAME, T6, V6) -#define DEF_ARGS_8(NAME, T0, F0, V0, T1, F1, V1, T2, F2, V2, T3, F3, V3, T4, F4, V4, T5, F5, V5, T6, F6, V6, T7, F7, \ - V7) \ - DEF_ARGS_##F0(NAME, T0, V0) DEF_ARGS_##F1(NAME, T1, V1) DEF_ARGS_##F2(NAME, T2, V2) DEF_ARGS_##F3(NAME, T3, V3) \ - DEF_ARGS_##F4(NAME, T4, V4) DEF_ARGS_##F5(NAME, T5, V5) DEF_ARGS_##F6(NAME, T6, V6) \ - DEF_ARGS_##F7(NAME, T7, V7) -#define DEF_ARGS_9(NAME, T0, F0, V0, T1, F1, V1, T2, F2, V2, T3, F3, V3, T4, F4, V4, T5, F5, V5, T6, F6, V6, T7, F7, \ - V7, T8, F8, V8) \ - DEF_ARGS_##F0(NAME, T0, V0) DEF_ARGS_##F1(NAME, T1, V1) DEF_ARGS_##F2(NAME, T2, V2) DEF_ARGS_##F3(NAME, T3, V3) \ - DEF_ARGS_##F4(NAME, T4, V4) DEF_ARGS_##F5(NAME, T5, V5) DEF_ARGS_##F6(NAME, T6, V6) \ - DEF_ARGS_##F7(NAME, T7, V7) DEF_ARGS_##F8(NAME, T8, V8) -#define DEF_ARGS_10(NAME, T0, F0, V0, T1, F1, V1, T2, F2, V2, T3, F3, V3, T4, F4, V4, T5, F5, V5, T6, F6, V6, T7, F7, \ - V7, T8, F8, V8, T9, F9, V9) \ - DEF_ARGS_##F0(NAME, T0, V0) DEF_ARGS_##F1(NAME, T1, V1) DEF_ARGS_##F2(NAME, T2, V2) DEF_ARGS_##F3(NAME, T3, V3) \ - DEF_ARGS_##F4(NAME, T4, V4) DEF_ARGS_##F5(NAME, T5, V5) DEF_ARGS_##F6(NAME, T6, V6) \ - DEF_ARGS_##F7(NAME, T7, V7) DEF_ARGS_##F8(NAME, T8, V8) DEF_ARGS_##F9(NAME, T9, V9) -#define DEF_QUANTIZATION_INFO(NAME, NUM_ARGS, ...) \ - case QuantInfo_##NAME##QuantInfo: \ - fb_operator_qinfo = \ - Create##NAME##QuantInfo(_builder DEF_ARGS_##NUM_ARGS(NAME##QuantInfo, __VA_ARGS__)).Union(); \ - break; - -#include "quant_info.def" -#undef DEF_QUANTIZATION_INFO -#undef DEF_ARGS_1 -#undef DEF_ARGS_2 -#undef DEF_ARGS_3 -#undef DEF_ARGS_4 -#undef DEF_ARGS_5 -#undef DEF_ARGS_6 -#undef DEF_ARGS_7 -#undef DEF_ARGS_8 -#undef DEF_ARGS_9 -#undef DEF_ARGS_10 -#undef DEF_ARGS_S -#undef DEF_ARGS_V - default: - printf("TosaSerializationHandler::Serialize(): Attribute %s not implemented yet\n", - EnumNamesAttribute()[attribute_type]); - return TOSA_INTERNAL_ERROR; - } - - auto fboffset_operator = - CreateTosaOperator(_builder, operator_op, attribute_type, fb_attribute, fb_operator_inputs, - fb_operator_outputs, qinfo_type, fb_operator_qinfo); + auto fboffset_operator = CreateTosaOperator(_builder, operator_op, attribute_type, fb_attribute, + fb_operator_inputs, fb_operator_outputs); fboffset_block_operators.push_back(fboffset_operator); } -- cgit v1.2.1