diff options
author | Eric Kunze <eric.kunze@arm.com> | 2020-10-13 16:11:07 -0700 |
---|---|---|
committer | Kevin Cheng <kevin.cheng@arm.com> | 2020-10-14 11:11:43 -0700 |
commit | e5e2676409a936431f87d31fb74d825257b20804 (patch) | |
tree | 304d93d993ef6417b02a515025f9030367682774 /serialization | |
parent | 88b7860f180f91b5b66764c61cfd97d8bc53cece (diff) | |
download | reference_model-e5e2676409a936431f87d31fb74d825257b20804.tar.gz |
Initial checkin of TOSA reference_model and tests
Change-Id: I2f8e7fa63e2ae40203e57d2cc8814bde3b312cb6
Signed-off-by: Eric Kunze <eric.kunze@arm.com>
Diffstat (limited to 'serialization')
-rw-r--r-- | serialization/CMakeLists.txt | 32 | ||||
-rw-r--r-- | serialization/attribute.def | 90 | ||||
-rw-r--r-- | serialization/attribute.h | 181 | ||||
-rw-r--r-- | serialization/operator.def | 123 | ||||
-rw-r--r-- | serialization/quant_info.def | 43 | ||||
-rw-r--r-- | serialization/quant_info.h | 164 | ||||
-rw-r--r-- | serialization/tosa.fbs | 318 | ||||
-rw-r--r-- | serialization/tosa_generated.h | 2605 | ||||
-rw-r--r-- | serialization/tosa_serialization_handler.cpp | 1526 | ||||
-rw-r--r-- | serialization/tosa_serialization_handler.h | 423 |
10 files changed, 5505 insertions, 0 deletions
diff --git a/serialization/CMakeLists.txt b/serialization/CMakeLists.txt new file mode 100644 index 0000000..7bca824 --- /dev/null +++ b/serialization/CMakeLists.txt @@ -0,0 +1,32 @@ +cmake_minimum_required (VERSION 3.4) + +# Copyright (c) 2020, ARM Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +project (tosa) + +set (CMAKE_CXX_STANDARD 11) +set (CMAKE_CXX_FLAGS "-g -Wall") +set (FLATBUFFERS_SRC_DIR "../thirdparty/flatbuffers") + +set (SOURCE + tosa_serialization_handler.cpp +) + +add_library(tosa_serialization STATIC ${SOURCE}) + +include_directories("./") + +target_link_libraries(tosa_serialization PRIVATE flatbuffers) diff --git a/serialization/attribute.def b/serialization/attribute.def new file mode 100644 index 0000000..88e8c81 --- /dev/null +++ b/serialization/attribute.def @@ -0,0 +1,90 @@ + +// Copyright (c) 2020, ARM Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +/* + Syntax: + DEF_ATTRIBUTE(ATTRIBUTE_NAME, NUM_ARGS_IN_ATTRIBUTES, ARG0_TYPE, ARG0_SCALAR_OR_VECTOR, ARGS0_NAME, ...) + + Description: + ATTRIBUTE_NAME: corresponding attribute name, must match corresponding "table XXXAttribute" in tosa.fbs + NUM_ARGS_IN_ATTRIBUTES: number of arguments in this attribute + ARG0_TYPE: data type of arg0 in attribute + ARG0_SCALAR_OR_VECTOR: is arg0 a scalar(S) or a vector(V) + ARG0_NAME: name of arg0 + ...: variadic variables for more arguments, depending on NUM_ARGS_IN_ATTRIBUTES +*/ + +DEF_ATTRIBUTE(Pool2d, 3, + int32_t, V, padding, + int32_t, V, kernel, + int32_t, V, stride) + +DEF_ATTRIBUTE(Conv2d, 3, + int32_t, V, padding, + int32_t, V, stride, + int32_t, V, dilation) + +DEF_ATTRIBUTE(TransposeConv2d, 4, + int32_t, V, outpad, + int32_t, V, stride, + int32_t, V, dilation, + int32_t, V, output_shape) + +DEF_ATTRIBUTE(ReluN, 2, + int32_t, S, max_int, + float, S, max_fp) + +DEF_ATTRIBUTE(Axis, 1, + int32_t, S, axis) + +DEF_ATTRIBUTE(Reshape, 1, + int32_t, V, shape) + +DEF_ATTRIBUTE(Slice, 2, + int32_t, V, begin, + int32_t, V, size) + +DEF_ATTRIBUTE(Tile, 1, + int32_t, V, multiples) + +DEF_ATTRIBUTE(Resize, 5, + int32_t, V, output_size, + int32_t, V, stride, + int32_t, V, offset, + int32_t, S, shift, + ResizeMode, S, mode) + +DEF_ATTRIBUTE(Clamp, 4, + int32_t, S, min_int, + int32_t, S, max_int, + float, S, min_fp, + float, S, max_fp) + +DEF_ATTRIBUTE(Rescale, 7, + int32_t, S, input_zp, + int32_t, S, output_zp, + int32_t, V, multiplier, + int32_t, V, shift, + bool, S, scale32, + bool, S, double_round, + bool, S, per_channel) + +DEF_ATTRIBUTE(CondIf, 2, + string, S, then_branch, + string, S, else_branch) + +DEF_ATTRIBUTE(WhileLoop, 2, + string, S, cond_branch, + string, S, body_branch) diff --git a/serialization/attribute.h b/serialization/attribute.h new file mode 100644 index 0000000..2a33a8f --- /dev/null +++ b/serialization/attribute.h @@ -0,0 +1,181 @@ + +// Copyright (c) 2020, ARM Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef _TOSA_SERIALIZATION_ATTRIBUTE_H +#define _TOSA_SERIALIZATION_ATTRIBUTE_H +#include "flatbuffers/idl.h" +#include "flatbuffers/util.h" +#include "tosa_generated.h" + +using std::string; + +namespace tosa +{ + +class TosaAttributeBase +{ +public: + virtual ~TosaAttributeBase() + {} +}; + +class TosaNoneAttribute : public TosaAttributeBase +{ +public: + TosaNoneAttribute() + {} + TosaNoneAttribute(TosaNoneAttribute* p) + {} +}; + +#define DEF_ARGS_VER0_S_STR(V) _##V = p->V()->str(); +#define DEF_ARGS_VER0_S_DEFAULT(V) _##V = p->V(); + +#define DEF_ARGS_VER0_S_int32_t(V) DEF_ARGS_VER0_S_DEFAULT(V) +#define DEF_ARGS_VER0_S_float(V) DEF_ARGS_VER0_S_DEFAULT(V) +#define DEF_ARGS_VER0_S_bool(V) DEF_ARGS_VER0_S_DEFAULT(V) +#define DEF_ARGS_VER0_S_ResizeMode(V) DEF_ARGS_VER0_S_DEFAULT(V) +#define DEF_ARGS_VER0_S_string(V) DEF_ARGS_VER0_S_STR(V) + +#define DEF_ARGS_VER0_S(T, V) DEF_ARGS_VER0_S_##T(V) +#define DEF_ARGS_VER0_V(T, V) _##V = std::vector<T>(p->V()->begin(), p->V()->end()); + +#define DEF_ARGS_VER1_S(T, V) const T& V +#define DEF_ARGS_VER1_V(T, V) const std::vector<T>& V +#define DEF_ARGS_VER2_S(T, V) _##V = V; +#define DEF_ARGS_VER2_V(T, V) _##V = V; +#define DEF_ARGS_VER3_S(T, V) \ + T V() const \ + { \ + return _##V; \ + } +#define DEF_ARGS_VER3_V(T, V) \ + std::vector<T> V() const \ + { \ + return _##V; \ + } +#define DEF_ARGS_VER4_S(T, V) T _##V; +#define DEF_ARGS_VER4_V(T, V) std::vector<T> _##V; + +// another level of preprocessor indirection to handle ", " as function's input argument +#define DEF_ARGS_VER1_TRUE(T, F, V) DEF_ARGS_VER1_##F(T, V) +#define DEF_ARGS_VER1_FALSE(T, F, V) , DEF_ARGS_VER1_##F(T, V) + +#define DEF_ARGS_VER0(FIRST, T, F, V) DEF_ARGS_VER0_##F(T, V) +#define DEF_ARGS_VER1(FIRST, T, F, V) DEF_ARGS_VER1_##FIRST(T, F, V) +#define DEF_ARGS_VER2(FIRST, T, F, V) DEF_ARGS_VER2_##F(T, V) +#define DEF_ARGS_VER3(FIRST, T, F, V) DEF_ARGS_VER3_##F(T, V) +#define DEF_ARGS_VER4(FIRST, T, F, V) DEF_ARGS_VER4_##F(T, V) + +#define DEF_ARGS_0(VER, ...) +#define DEF_ARGS_1(VER, T0, F0, V0) DEF_ARGS_##VER(TRUE, T0, F0, V0) +#define DEF_ARGS_2(VER, T0, F0, V0, T1, F1, V1) DEF_ARGS_##VER(TRUE, T0, F0, V0) DEF_ARGS_##VER(FALSE, T1, F1, V1) +#define DEF_ARGS_3(VER, T0, F0, V0, T1, F1, V1, T2, F2, V2) \ + DEF_ARGS_##VER(TRUE, T0, F0, V0) DEF_ARGS_##VER(FALSE, T1, F1, V1) DEF_ARGS_##VER(FALSE, T2, F2, V2) +#define DEF_ARGS_4(VER, T0, F0, V0, T1, F1, V1, T2, F2, V2, T3, F3, V3) \ + DEF_ARGS_##VER(TRUE, T0, F0, V0) DEF_ARGS_##VER(FALSE, T1, F1, V1) DEF_ARGS_##VER(FALSE, T2, F2, V2) \ + DEF_ARGS_##VER(FALSE, T3, F3, V3) +#define DEF_ARGS_5(VER, T0, F0, V0, T1, F1, V1, T2, F2, V2, T3, F3, V3, T4, F4, V4) \ + DEF_ARGS_##VER(TRUE, T0, F0, V0) DEF_ARGS_##VER(FALSE, T1, F1, V1) DEF_ARGS_##VER(FALSE, T2, F2, V2) \ + DEF_ARGS_##VER(FALSE, T3, F3, V3) DEF_ARGS_##VER(FALSE, T4, F4, V4) + +#define DEF_ARGS_6(VER, T0, F0, V0, T1, F1, V1, T2, F2, V2, T3, F3, V3, T4, F4, V4, T5, F5, V5) \ + DEF_ARGS_##VER(TRUE, T0, F0, V0) DEF_ARGS_##VER(FALSE, T1, F1, V1) DEF_ARGS_##VER(FALSE, T2, F2, V2) \ + DEF_ARGS_##VER(FALSE, T3, F3, V3) DEF_ARGS_##VER(FALSE, T4, F4, V4) DEF_ARGS_##VER(FALSE, T5, F5, V5) + +#define DEF_ARGS_7(VER, T0, F0, V0, T1, F1, V1, T2, F2, V2, T3, F3, V3, T4, F4, V4, T5, F5, V5, T6, F6, V6) \ + DEF_ARGS_##VER(TRUE, T0, F0, V0) DEF_ARGS_##VER(FALSE, T1, F1, V1) DEF_ARGS_##VER(FALSE, T2, F2, V2) \ + DEF_ARGS_##VER(FALSE, T3, F3, V3) DEF_ARGS_##VER(FALSE, T4, F4, V4) DEF_ARGS_##VER(FALSE, T5, F5, V5) \ + DEF_ARGS_##VER(FALSE, T6, F6, V6) + +#define DEF_VER0_VAR_DECL_PTR(NAME) const NAME* p = static_cast<const NAME*>(options); +#define DEF_VER0_VAR_0(NAME) +#define DEF_VER0_VAR_1(NAME) DEF_VER0_VAR_DECL_PTR(NAME) +#define DEF_VER0_VAR_2(NAME) DEF_VER0_VAR_DECL_PTR(NAME) +#define DEF_VER0_VAR_3(NAME) DEF_VER0_VAR_DECL_PTR(NAME) +#define DEF_VER0_VAR_4(NAME) DEF_VER0_VAR_DECL_PTR(NAME) +#define DEF_VER0_VAR_5(NAME) DEF_VER0_VAR_DECL_PTR(NAME) +#define DEF_VER0_VAR_6(NAME) DEF_VER0_VAR_DECL_PTR(NAME) +#define DEF_VER0_VAR_7(NAME) DEF_VER0_VAR_DECL_PTR(NAME) + +#define DEF_ATTRIBUTE(NAME, NUM_ARGS, ...) \ + class Tosa##NAME##Attribute : public TosaAttributeBase \ + { \ + public: \ + Tosa##NAME##Attribute(const TosaAttributeBase* options) \ + { \ + const Tosa##NAME##Attribute* p = reinterpret_cast<const Tosa##NAME##Attribute*>(options); \ + *this = *p; \ + } \ + Tosa##NAME##Attribute(const Tosa##NAME##Attribute* p) \ + { \ + *this = *p; \ + } \ + Tosa##NAME##Attribute(const void* options){ DEF_VER0_VAR_##NUM_ARGS(NAME##Attribute) \ + DEF_ARGS_##NUM_ARGS(VER0, __VA_ARGS__) } Tosa##NAME \ + ##Attribute(DEF_ARGS_##NUM_ARGS(VER1, __VA_ARGS__)) \ + { \ + DEF_ARGS_##NUM_ARGS(VER2, __VA_ARGS__) \ + } \ + virtual ~Tosa##NAME##Attribute() \ + {} \ + DEF_ARGS_##NUM_ARGS(VER3, __VA_ARGS__) private : DEF_ARGS_##NUM_ARGS(VER4, __VA_ARGS__) \ + }; + +#include "attribute.def" +#undef DEF_ATTRIBUTE +#undef DEF_ARGS_0 +#undef DEF_ARGS_1 +#undef DEF_ARGS_2 +#undef DEF_ARGS_3 +#undef DEF_ARGS_4 +#undef DEF_ARGS_5 +#undef DEF_ARGS_6 +#undef DEF_ARGS_7 +#undef DEF_ARGS_VER0 +#undef DEF_ARGS_VER1 +#undef DEF_ARGS_VER2 +#undef DEF_ARGS_VER3 +#undef DEF_ARGS_VER4 +#undef DEF_ARGS_VER0_S_int32_t +#undef DEF_ARGS_VER0_S_float +#undef DEF_ARGS_VER0_S_bool +#undef DEF_ARGS_VER0_S_ResizeMode +#undef DEF_ARGS_VER0_S_string +#undef DEF_ARGS_VER0_S_STR +#undef DEF_ARGS_VER0_S_DEFAULT +#undef DEF_ARGS_VER1_TRUE +#undef DEF_ARGS_VER1_FALSE +#undef DEF_ARGS_VER0_S +#undef DEF_ARGS_VER0_V +#undef DEF_ARGS_VER1_S +#undef DEF_ARGS_VER1_V +#undef DEF_ARGS_VER2_S +#undef DEF_ARGS_VER2_V +#undef DEF_ARGS_VER3_S +#undef DEF_ARGS_VER3_V +#undef DEF_ARGS_VER4_S +#undef DEF_ARGS_VER4_V +#undef DEF_VER0_VAR_0 +#undef DEF_VER0_VAR_1 +#undef DEF_VER0_VAR_2 +#undef DEF_VER0_VAR_3 +#undef DEF_VER0_VAR_4 +#undef DEF_VER0_VAR_5 +#undef DEF_VER0_VAR_DECL_PTR + +} // namespace tosa + +#endif // _TOSA_SERIALIZATION_ATTRIBUTE_H diff --git a/serialization/operator.def b/serialization/operator.def new file mode 100644 index 0000000..66d3784 --- /dev/null +++ b/serialization/operator.def @@ -0,0 +1,123 @@ + +// Copyright (c) 2020, ARM Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +/* + Syntax: + DEF_OPERATOR(MLIR_NAME, SCHEMA_NAME, REF_IMPL_NAME, OPTIONS, QUANT_INFO) + + Description: + MLIR_NAME: the symbolic string of this op, must match tosa_ops.td + SCHEMA_NAME: corresponding operator name, must match "enum Op" in serialization/tosa.fbs + REF_IMPL_NAME: name used internally in tosa reference implementation + OPTIONS: compile time constant options of this op, corresponding to operator_option.def + QUANT_INFO: quantization infomation of this op, corresponding to quant_info.def +*/ + + +/* tensor operators */ +DEF_OPERATOR(argmax, ARGMAX, ArgMax, Axis, None) +DEF_OPERATOR(avg_pool2d, AVG_POOL2D, AvgPool2d, Pool2d, Unary) +DEF_OPERATOR(conv2d, CONV2D, Conv2d, Conv2d, Conv) +DEF_OPERATOR(conv3d, CONV3D, Conv3d, None, None) +DEF_OPERATOR(depthwise_conv2d, DEPTHWISE_CONV2D, DepthwiseConv2d, Conv2d, Conv) +DEF_OPERATOR(fully_connected, FULLY_CONNECTED, FullyConnected, None, Conv) +DEF_OPERATOR(matmul, MATMUL, MatMul, None, MatMul) +DEF_OPERATOR(max_pool2d, MAX_POOL2D, MaxPool2d, Pool2d, None) +DEF_OPERATOR(transpose_conv2d, TRANSPOSE_CONV2D, TransposeConv2d, TransposeConv2d, Conv) + +/* activation */ +DEF_OPERATOR(clamp, CLAMP, Clamp, Clamp, None) +DEF_OPERATOR(reluN, RELUN, ReluN, ReluN, None) +DEF_OPERATOR(sigmoid, SIGMOID, Sigmoid, None, None) +DEF_OPERATOR(tanh, TANH, Tanh, None, None) + +/* elementwise - binary */ +DEF_OPERATOR(add, ADD, Add, None, None) +DEF_OPERATOR(arithmetic_right_shift, ARITHMETIC_RIGHT_SHIFT, ArithmeticRightShift, None, None) +DEF_OPERATOR(bitwise_and, BITWISE_AND, BitwiseAnd, None, None) +DEF_OPERATOR(bitwise_or, BITWISE_OR, BitwiseOr, None, None) +DEF_OPERATOR(bitwise_xor, BITWISE_XOR, BitwiseXor, None, None) +DEF_OPERATOR(logical_and, LOGICAL_AND, LogicalAnd, None, None) +DEF_OPERATOR(logical_left_shift, LOGICAL_LEFT_SHIFT, LogicalLeftShift, None, None) +DEF_OPERATOR(logical_right_shift, LOGICAL_RIGHT_SHIFT, LogicalRightShift, None, None) +DEF_OPERATOR(logical_or, LOGICAL_OR, LogicalOr, None, None) +DEF_OPERATOR(logical_xor, LOGICAL_XOR, LogicalXor, None, None) +DEF_OPERATOR(maximum, MAXIMUM, Maximum, None, None) +DEF_OPERATOR(minimum, MINIMUM, Minimum, None, None) +DEF_OPERATOR(mul, MUL, Mul, None, None) +DEF_OPERATOR(pow, POW, Pow, None, None) +DEF_OPERATOR(sub, SUB, Sub, None, None) +DEF_OPERATOR(table, TABLE, Table, None, None) + +/* elementwise - unary */ +DEF_OPERATOR(abs, ABS, Abs, None, None) +DEF_OPERATOR(bitwise_not, BITWISE_NOT, BitwiseNot, None, None) +DEF_OPERATOR(ceil, CEIL, Ceil, None, None) +DEF_OPERATOR(clz, CLZ, Clz, None, None) +DEF_OPERATOR(exp, EXP, Exp, None, None) +DEF_OPERATOR(floor, FLOOR, Floor, None, None) +DEF_OPERATOR(log, LOG, Log, None, None) +DEF_OPERATOR(logical_not, LOGICAL_NOT, LogicalNot, None, None) +DEF_OPERATOR(negate, NEGATE, Negate, None, Unary) +DEF_OPERATOR(reciprocal, RECIPROCAL, Reciprocal, None, None) +DEF_OPERATOR(rsqrt, RSQRT, Rsqrt, None, None) + +/* elementwise - ternary */ +DEF_OPERATOR(select, SELECT, Select, None, None) + +/* logical */ +DEF_OPERATOR(equal, EQUAL, Equal, None, None) +DEF_OPERATOR(greater, GREATER, Greater, None, None) +DEF_OPERATOR(greater_equal, GREATER_EQUAL, GreaterEqual, None, None) + +/* reduction */ +DEF_OPERATOR(reduce_any, REDUCE_ANY, ReduceAny, Reduce, None) +DEF_OPERATOR(reduce_all, REDUCE_ALL, ReduceAll, Reduce, None) +DEF_OPERATOR(reduce_max, REDUCE_MAX, ReduceMax, Reduce, None) +DEF_OPERATOR(reduce_min, REDUCE_MIN, ReduceMin, Reduce, None) +DEF_OPERATOR(reduce_prod, REDUCE_PRODUCT, ReduceProduct, Reduce, None) +DEF_OPERATOR(reduce_sum, REDUCE_SUM, ReduceSum, Reduce, None) + +/* memory operation */ +DEF_OPERATOR(concat, CONCAT, Concat, Axis, None) +DEF_OPERATOR(pad, PAD, Pad, None, Pad) +DEF_OPERATOR(reshape, RESHAPE, Reshape, Reshape, None) +DEF_OPERATOR(reverse, REVERSE, Reverse, Reverse, None) +DEF_OPERATOR(slice, SLICE, Slice, Slice, None) +DEF_OPERATOR(tile, TILE, Tile, Tile, None) +DEF_OPERATOR(transpose, TRANSPOSE, Transpose, None, None) + +/* gather/scatter */ +DEF_OPERATOR(gather, GATHER, Gather, Axis, None) + +/* image */ +DEF_OPERATOR(resize, RESIZE, Resize, Resize, None) + +/* quantization */ +DEF_OPERATOR(cast, CAST, Cast, None, None) +DEF_OPERATOR(rescale, RESCALE, Rescale, Rescale, None) + +/* data nodes */ +DEF_OPERATOR(const, CONST, Const, None, None) +DEF_OPERATOR(placeholder, PLACEHOLDER, Placeholder, None, None) +DEF_OPERATOR(identity, IDENTITY, Identity, None, None) +DEF_OPERATOR(identityn, IDENTITYN, IdentityN, None, None) + +/* custom operations */ +DEF_OPERATOR(custom, CUSTOM, Custom, None, None) + +/* control flow operators */ +DEF_OPERATOR(cond_if, COND_IF, CondIf, CondIf, None) +DEF_OPERATOR(while_loop, WHILE_LOOP, WhileLoop, WhileLoop, None) diff --git a/serialization/quant_info.def b/serialization/quant_info.def new file mode 100644 index 0000000..39dc101 --- /dev/null +++ b/serialization/quant_info.def @@ -0,0 +1,43 @@ + +// Copyright (c) 2020, ARM Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +/* + Syntax: + DEF_QUANTIZATION_INFO(NAME, NUM_ARGS_IN_OPTIONS, ARG0_TYPE, ARG0_SCALAR_OR_VECTOR, ARGS0_NAME, ...) + + Description: + NAME: corresponding quantization info name, must match corresponding "table XXXQuantInfo" in tosa.fbs + NUM_ARGS_IN_QINFO: number of arguments in this quantization info + ARG0_TYPE: data type of arg0 + ARG0_SCALAR_OR_VECTOR: is arg0 a scalar (S) or a vector (V) + ARG0_NAME: name of arg0 + ...: variadic variables for more arguments, depending on NUM_ARGS_IN_QINFO +*/ + + +DEF_QUANTIZATION_INFO(Unary, 2, + int32_t, S, input_zp, + int32_t, S, output_zp) + +DEF_QUANTIZATION_INFO(Conv, 2, + int32_t, S, input_zp, + int32_t, S, weight_zp) + +DEF_QUANTIZATION_INFO(MatMul, 2, + int32_t, S, a_zp, + int32_t, S, b_zp) + +DEF_QUANTIZATION_INFO(Pad, 1, + int32_t, S, input_zp) diff --git a/serialization/quant_info.h b/serialization/quant_info.h new file mode 100644 index 0000000..03dcab9 --- /dev/null +++ b/serialization/quant_info.h @@ -0,0 +1,164 @@ + +// Copyright (c) 2020, ARM Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef _TOSA_SERIALIZATION_QUANT_INFO_H +#define _TOSA_SERIALIZATION_QUANT_INFO_H +#include "flatbuffers/idl.h" +#include "flatbuffers/util.h" +#include "tosa_generated.h" + +namespace tosa +{ + +class TosaQuantInfoBase +{ +public: + virtual ~TosaQuantInfoBase() + {} +}; + +class TosaNoneQuantInfo : public TosaQuantInfoBase +{ +public: + TosaNoneQuantInfo() + {} + TosaNoneQuantInfo(TosaNoneQuantInfo* p) + {} +}; + +#define DEF_ARGS_VER0_S(T, V) _##V = p->V(); +#define DEF_ARGS_VER0_V(T, V) _##V = std::vector<T>(p->V()->begin(), p->V()->end()); +#define DEF_ARGS_VER1_S(T, V) const T& V +#define DEF_ARGS_VER1_V(T, V) const std::vector<T>& V +#define DEF_ARGS_VER2_S(T, V) _##V = V; +#define DEF_ARGS_VER2_V(T, V) _##V = V; +#define DEF_ARGS_VER3_S(T, V) \ + T V() const \ + { \ + return _##V; \ + } +#define DEF_ARGS_VER3_V(T, V) \ + std::vector<T> V() const \ + { \ + return _##V; \ + } +#define DEF_ARGS_VER4_S(T, V) T _##V; +#define DEF_ARGS_VER4_V(T, V) std::vector<T> _##V; + +// another level of preprocessor indirection to handle ", " as function's input argument +#define DEF_ARGS_VER1_TRUE(T, F, V) DEF_ARGS_VER1_##F(T, V) +#define DEF_ARGS_VER1_FALSE(T, F, V) , DEF_ARGS_VER1_##F(T, V) + +#define DEF_ARGS_VER0(FIRST, T, F, V) DEF_ARGS_VER0_##F(T, V) +#define DEF_ARGS_VER1(FIRST, T, F, V) DEF_ARGS_VER1_##FIRST(T, F, V) +#define DEF_ARGS_VER2(FIRST, T, F, V) DEF_ARGS_VER2_##F(T, V) +#define DEF_ARGS_VER3(FIRST, T, F, V) DEF_ARGS_VER3_##F(T, V) +#define DEF_ARGS_VER4(FIRST, T, F, V) DEF_ARGS_VER4_##F(T, V) + +#define DEF_ARGS_1(VER, T0, F0, V0) DEF_ARGS_##VER(TRUE, T0, F0, V0) +#define DEF_ARGS_2(VER, T0, F0, V0, T1, F1, V1) DEF_ARGS_##VER(TRUE, T0, F0, V0) DEF_ARGS_##VER(FALSE, T1, F1, V1) +#define DEF_ARGS_3(VER, T0, F0, V0, T1, F1, V1, T2, F2, V2) \ + DEF_ARGS_##VER(TRUE, T0, F0, V0) DEF_ARGS_##VER(FALSE, T1, F1, V1) DEF_ARGS_##VER(FALSE, T2, F2, V2) +#define DEF_ARGS_4(VER, T0, F0, V0, T1, F1, V1, T2, F2, V2, T3, F3, V3) \ + DEF_ARGS_##VER(TRUE, T0, F0, V0) DEF_ARGS_##VER(FALSE, T1, F1, V1) DEF_ARGS_##VER(FALSE, T2, F2, V2) \ + DEF_ARGS_##VER(FALSE, T3, F3, V3) +#define DEF_ARGS_5(VER, T0, F0, V0, T1, F1, V1, T2, F2, V2, T3, F3, V3, T4, F4, V4) \ + DEF_ARGS_##VER(TRUE, T0, F0, V0) DEF_ARGS_##VER(FALSE, T1, F1, V1) DEF_ARGS_##VER(FALSE, T2, F2, V2) \ + DEF_ARGS_##VER(FALSE, T3, F3, V3) DEF_ARGS_##VER(FALSE, T4, F4, V4) +#define DEF_ARGS_6(VER, T0, F0, V0, T1, F1, V1, T2, F2, V2, T3, F3, V3, T4, F4, V4, T5, F5, V5) \ + DEF_ARGS_##VER(TRUE, T0, F0, V0) DEF_ARGS_##VER(FALSE, T1, F1, V1) DEF_ARGS_##VER(FALSE, T2, F2, V2) \ + DEF_ARGS_##VER(FALSE, T3, F3, V3) DEF_ARGS_##VER(FALSE, T4, F4, V4) DEF_ARGS_##VER(FALSE, T5, F5, V5) +#define DEF_ARGS_7(VER, T0, F0, V0, T1, F1, V1, T2, F2, V2, T3, F3, V3, T4, F4, V4, T5, F5, V5, T6, F6, V6) \ + DEF_ARGS_##VER(TRUE, T0, F0, V0) DEF_ARGS_##VER(FALSE, T1, F1, V1) DEF_ARGS_##VER(FALSE, T2, F2, V2) \ + DEF_ARGS_##VER(FALSE, T3, F3, V3) DEF_ARGS_##VER(FALSE, T4, F4, V4) DEF_ARGS_##VER(FALSE, T5, F5, V5) \ + DEF_ARGS_##VER(FALSE, T6, F6, V6) +#define DEF_ARGS_8(VER, T0, F0, V0, T1, F1, V1, T2, F2, V2, T3, F3, V3, T4, F4, V4, T5, F5, V5, T6, F6, V6, T7, F7, \ + V7) \ + DEF_ARGS_##VER(TRUE, T0, F0, V0) DEF_ARGS_##VER(FALSE, T1, F1, V1) DEF_ARGS_##VER(FALSE, T2, F2, V2) \ + DEF_ARGS_##VER(FALSE, T3, F3, V3) DEF_ARGS_##VER(FALSE, T4, F4, V4) DEF_ARGS_##VER(FALSE, T5, F5, V5) \ + DEF_ARGS_##VER(FALSE, T6, F6, V6) DEF_ARGS_##VER(FALSE, T7, F7, V7) +#define DEF_ARGS_9(VER, T0, F0, V0, T1, F1, V1, T2, F2, V2, T3, F3, V3, T4, F4, V4, T5, F5, V5, T6, F6, V6, T7, F7, \ + V7, T8, F8, V8) \ + DEF_ARGS_##VER(TRUE, T0, F0, V0) DEF_ARGS_##VER(FALSE, T1, F1, V1) DEF_ARGS_##VER(FALSE, T2, F2, V2) \ + DEF_ARGS_##VER(FALSE, T3, F3, V3) DEF_ARGS_##VER(FALSE, T4, F4, V4) DEF_ARGS_##VER(FALSE, T5, F5, V5) \ + DEF_ARGS_##VER(FALSE, T6, F6, V6) DEF_ARGS_##VER(FALSE, T7, F7, V7) DEF_ARGS_##VER(FALSE, T8, F8, V8) +#define DEF_ARGS_10(VER, T0, F0, V0, T1, F1, V1, T2, F2, V2, T3, F3, V3, T4, F4, V4, T5, F5, V5, T6, F6, V6, T7, F7, \ + V7, T8, F8, V8, T9, F9, V9) \ + DEF_ARGS_##VER(TRUE, T0, F0, V0) DEF_ARGS_##VER(FALSE, T1, F1, V1) DEF_ARGS_##VER(FALSE, T2, F2, V2) \ + DEF_ARGS_##VER(FALSE, T3, F3, V3) DEF_ARGS_##VER(FALSE, T4, F4, V4) DEF_ARGS_##VER(FALSE, T5, F5, V5) \ + DEF_ARGS_##VER(FALSE, T6, F6, V6) DEF_ARGS_##VER(FALSE, T7, F7, V7) DEF_ARGS_##VER(FALSE, T8, F8, V8) \ + DEF_ARGS_##VER(FALSE, T9, F9, V9) + +#define DEF_QUANTIZATION_INFO(NAME, NUM_ARGS, ...) \ + class Tosa##NAME##QuantInfo : public TosaQuantInfoBase \ + { \ + public: \ + Tosa##NAME##QuantInfo(const TosaQuantInfoBase* qinfo) \ + { \ + const Tosa##NAME##QuantInfo* p = dynamic_cast<const Tosa##NAME##QuantInfo*>(qinfo); \ + assert(p); \ + *this = *p; \ + } \ + Tosa##NAME##QuantInfo(const Tosa##NAME##QuantInfo* p) \ + { \ + *this = *p; \ + } \ + Tosa##NAME##QuantInfo(const void* qinfo) \ + { \ + const NAME##QuantInfo* p = static_cast<const NAME##QuantInfo*>(qinfo); \ + DEF_ARGS_##NUM_ARGS(VER0, __VA_ARGS__) \ + } \ + Tosa##NAME##QuantInfo(DEF_ARGS_##NUM_ARGS(VER1, __VA_ARGS__)) \ + { \ + DEF_ARGS_##NUM_ARGS(VER2, __VA_ARGS__) \ + } \ + virtual ~Tosa##NAME##QuantInfo() \ + {} \ + DEF_ARGS_##NUM_ARGS(VER3, __VA_ARGS__) private : DEF_ARGS_##NUM_ARGS(VER4, __VA_ARGS__) \ + }; + +#include "quant_info.def" +#undef DEF_QUANTIZATION_INFO +#undef DEF_ARGS_1 +#undef DEF_ARGS_2 +#undef DEF_ARGS_3 +#undef DEF_ARGS_4 +#undef DEF_ARGS_5 +#undef DEF_ARGS_6 +#undef DEF_ARGS_7 +#undef DEF_ARGS_8 +#undef DEF_ARGS_9 +#undef DEF_ARGS_10 +#undef DEF_ARGS_VER0 +#undef DEF_ARGS_VER1 +#undef DEF_ARGS_VER2 +#undef DEF_ARGS_VER3 +#undef DEF_ARGS_VER4 +#undef DEF_ARGS_VER1_TRUE +#undef DEF_ARGS_VER1_FALSE +#undef DEF_ARGS_VER0_S +#undef DEF_ARGS_VER0_V +#undef DEF_ARGS_VER1_S +#undef DEF_ARGS_VER1_V +#undef DEF_ARGS_VER2_S +#undef DEF_ARGS_VER2_V +#undef DEF_ARGS_VER3_S +#undef DEF_ARGS_VER3_V +#undef DEF_ARGS_VER4_S +#undef DEF_ARGS_VER4_V + +} // namespace tosa + +#endif diff --git a/serialization/tosa.fbs b/serialization/tosa.fbs new file mode 100644 index 0000000..841cf3d --- /dev/null +++ b/serialization/tosa.fbs @@ -0,0 +1,318 @@ + +// Copyright (c) 2020, ARM Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +namespace tosa; + +// This corresponds to the version. +file_identifier "TOSA"; +// File extension of any written files. +file_extension "tosa"; + +enum DType:uint32 { + UNKNOWN = 0, + BOOL, + AINT8, + UINT8, + INT4, + INT8, + INT16, + INT32, + INT48, + FLOAT, +} + +enum Format:uint32 { + UNKNOWN = 0, + NHWC, + NDHWC, + OHWI, + HWIM, + DOHWI, +} + +enum Usage:uint32 { + UNKNOWN = 0, + ACTIVATION, + WEIGHT, + INDEX, +} + +enum ResizeMode:uint32 { + UNKNOWN = 0, + NEAREST, + BILINEAR, +} + +enum Op:uint32 { + UNKNOWN = 0, + + // Tensor Operator + ARGMAX, + AVG_POOL2D, + CONV2D, + CONV3D, + DEPTHWISE_CONV2D, + FULLY_CONNECTED, + MATMUL, + MAX_POOL2D, + TRANSPOSE_CONV2D, + + // Activation + CLAMP, + RELUN, + SIGMOID, + TANH, + + // Elementwise-Binary + ADD, + ARITHMETIC_RIGHT_SHIFT, + BITWISE_AND, + BITWISE_OR, + BITWISE_XOR, + LOGICAL_AND, + LOGICAL_LEFT_SHIFT, + LOGICAL_RIGHT_SHIFT, + LOGICAL_OR, + LOGICAL_XOR, + MAXIMUM, + MINIMUM, + MUL, + POW, + SUB, + TABLE, + + // Elementwise-Unary + ABS, + BITWISE_NOT, + CEIL, + CLZ, + EXP, + FLOOR, + LOG, + LOGICAL_NOT, + NEGATE, + RECIPROCAL, + RSQRT, + + // Elementwise-Ternary + SELECT, + + // Logical + EQUAL, + GREATER, + GREATER_EQUAL, + + // Reduction + REDUCE_ANY, + REDUCE_ALL, + REDUCE_MAX, + REDUCE_MIN, + REDUCE_PRODUCT, + REDUCE_SUM, + + // Data layout operation + CONCAT, + PAD, + RESHAPE, + REVERSE, + SLICE, + TILE, + TRANSPOSE, + + // Gather/scatter operation + GATHER, + + // Image + RESIZE, + + // Type conversion + CAST, + RESCALE, + + // Data Nodes + CONST, + PLACEHOLDER, + IDENTITY, + IDENTITYN, + + // Custom operations + CUSTOM, + + // Control flow operators + COND_IF, + WHILE_LOOP, +} + +union Attribute { + Pool2dAttribute, + Conv2dAttribute, + TransposeConv2dAttribute, + ReluNAttribute, + AxisAttribute, + ReshapeAttribute, + SliceAttribute, + TileAttribute, + ResizeAttribute, + ClampAttribute, + RescaleAttribute, + CustomAttribute, + CondIfAttribute, + WhileLoopAttribute, +} + +table Pool2dAttribute { + padding: [int32]; + kernel: [int32]; + stride: [int32]; +} + +table Conv2dAttribute { + padding: [int32]; + stride: [int32]; + dilation: [int32]; +} + +table TransposeConv2dAttribute { + outpad: [int32]; + stride: [int32]; + dilation: [int32]; + output_shape: [int32]; +} + +table ReluNAttribute { + max_int: int32; + max_fp: float; +} + +table AxisAttribute { + axis: int32; +} + +table ReshapeAttribute { + shape: [int32]; +} + +table SliceAttribute { + begin: [int32]; + size: [int32]; +} + +table TileAttribute { + multiples: [int32]; +} + +table ResizeAttribute { + output_size: [int32]; + stride: [int32]; + offset: [int32]; + shift: int32; + mode: ResizeMode; +} + +table ClampAttribute { + min_int: int32; + max_int: int32; + min_fp: float; + max_fp: float; +} + +table RescaleAttribute { + input_zp: int32; + output_zp: int32; + multiplier: [int32]; + shift: [int32]; + scale32: bool; + double_round: bool; + per_channel: bool; +} + +table CustomAttribute { + identifier: string; +} + +table CondIfAttribute { + then_branch: string; + else_branch: string; +} + +table WhileLoopAttribute { + cond_branch: string; + body_branch: string; +} + +union QuantInfo { + UnaryQuantInfo, + ConvQuantInfo, + MatMulQuantInfo, + PadQuantInfo, +} + +table UnaryQuantInfo { + input_zp: int32; + output_zp: int32; +} + +table ConvQuantInfo { + input_zp: int32; + weight_zp: int32; +} + +table MatMulQuantInfo { + a_zp: int32; + b_zp: int32; +} + +table PadQuantInfo { + input_zp: int32; +} + +table Version { + _major: int32 = 0; + _minor: int32 = 20; + _patch: int32 = 0; + _experimental: bool = false; +} + +table TosaTensor { + name:string; // name of the tensor, used for solving dependency + shape:[int32]; // shape of the tensor + type:DType; // data type of the tensor + usage:[Usage]; // vector of possible usages. for the convenience of debugging only. + format:[Format]; // vector of possible formats. for the convenience of debugging only. + npy_filename: string; // numpy array filename +} + +table TosaOperator { + op:Op; // operator enum + attribute: Attribute; // union structure. operator attribute + inputs:[string]; // list of input tensor names + outputs:[string]; // list of output tensor names + quant_info: QuantInfo; // op-based quantization information +} + +table TosaBasicBlock { + name:string; // basic block name + operators:[TosaOperator]; // operators array + tensors:[TosaTensor]; // tensors array + inputs:[string]; // name of graph inputs + outputs:[string]; // name of graph outputs +} + +table TosaGraph { + version: Version; + blocks:[TosaBasicBlock]; // basic blocks array +} + +root_type TosaGraph; diff --git a/serialization/tosa_generated.h b/serialization/tosa_generated.h new file mode 100644 index 0000000..5bb21f3 --- /dev/null +++ b/serialization/tosa_generated.h @@ -0,0 +1,2605 @@ + +// Copyright (c) 2020, ARM Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// automatically generated by the FlatBuffers compiler, do not modify + + +#ifndef FLATBUFFERS_GENERATED_TOSA_TOSA_H_ +#define FLATBUFFERS_GENERATED_TOSA_TOSA_H_ + +#include "flatbuffers/flatbuffers.h" + +namespace tosa { + +struct Pool2dAttribute; + +struct Conv2dAttribute; + +struct TransposeConv2dAttribute; + +struct ReluNAttribute; + +struct AxisAttribute; + +struct ReshapeAttribute; + +struct SliceAttribute; + +struct TileAttribute; + +struct ResizeAttribute; + +struct ClampAttribute; + +struct RescaleAttribute; + +struct CustomAttribute; + +struct CondIfAttribute; + +struct WhileLoopAttribute; + +struct UnaryQuantInfo; + +struct ConvQuantInfo; + +struct MatMulQuantInfo; + +struct PadQuantInfo; + +struct Version; + +struct TosaTensor; + +struct TosaOperator; + +struct TosaBasicBlock; + +struct TosaGraph; + +enum DType { + DType_UNKNOWN = 0, + DType_BOOL = 1, + DType_AINT8 = 2, + DType_UINT8 = 3, + DType_INT4 = 4, + DType_INT8 = 5, + DType_INT16 = 6, + DType_INT32 = 7, + DType_INT48 = 8, + DType_FLOAT = 9, + DType_MIN = DType_UNKNOWN, + DType_MAX = DType_FLOAT +}; + +inline const DType (&EnumValuesDType())[10] { + static const DType values[] = { + DType_UNKNOWN, + DType_BOOL, + DType_AINT8, + DType_UINT8, + DType_INT4, + DType_INT8, + DType_INT16, + DType_INT32, + DType_INT48, + DType_FLOAT + }; + return values; +} + +inline const char * const *EnumNamesDType() { + static const char * const names[] = { + "UNKNOWN", + "BOOL", + "AINT8", + "UINT8", + "INT4", + "INT8", + "INT16", + "INT32", + "INT48", + "FLOAT", + nullptr + }; + return names; +} + +inline const char *EnumNameDType(DType e) { + if (e < DType_UNKNOWN || e > DType_FLOAT) return ""; + const size_t index = static_cast<size_t>(e); + return EnumNamesDType()[index]; +} + +enum Format { + Format_UNKNOWN = 0, + Format_NHWC = 1, + Format_NDHWC = 2, + Format_OHWI = 3, + Format_HWIM = 4, + Format_DOHWI = 5, + Format_MIN = Format_UNKNOWN, + Format_MAX = Format_DOHWI +}; + +inline const Format (&EnumValuesFormat())[6] { + static const Format values[] = { + Format_UNKNOWN, + Format_NHWC, + Format_NDHWC, + Format_OHWI, + Format_HWIM, + Format_DOHWI + }; + return values; +} + +inline const char * const *EnumNamesFormat() { + static const char * const names[] = { + "UNKNOWN", + "NHWC", + "NDHWC", + "OHWI", + "HWIM", + "DOHWI", + nullptr + }; + return names; +} + +inline const char *EnumNameFormat(Format e) { + if (e < Format_UNKNOWN || e > Format_DOHWI) return ""; + const size_t index = static_cast<size_t>(e); + return EnumNamesFormat()[index]; +} + +enum Usage { + Usage_UNKNOWN = 0, + Usage_ACTIVATION = 1, + Usage_WEIGHT = 2, + Usage_INDEX = 3, + Usage_MIN = Usage_UNKNOWN, + Usage_MAX = Usage_INDEX +}; + +inline const Usage (&EnumValuesUsage())[4] { + static const Usage values[] = { + Usage_UNKNOWN, + Usage_ACTIVATION, + Usage_WEIGHT, + Usage_INDEX + }; + return values; +} + +inline const char * const *EnumNamesUsage() { + static const char * const names[] = { + "UNKNOWN", + "ACTIVATION", + "WEIGHT", + "INDEX", + nullptr + }; + return names; +} + +inline const char *EnumNameUsage(Usage e) { + if (e < Usage_UNKNOWN || e > Usage_INDEX) return ""; + const size_t index = static_cast<size_t>(e); + return EnumNamesUsage()[index]; +} + +enum ResizeMode { + ResizeMode_UNKNOWN = 0, + ResizeMode_NEAREST = 1, + ResizeMode_BILINEAR = 2, + ResizeMode_MIN = ResizeMode_UNKNOWN, + ResizeMode_MAX = ResizeMode_BILINEAR +}; + +inline const ResizeMode (&EnumValuesResizeMode())[3] { + static const ResizeMode values[] = { + ResizeMode_UNKNOWN, + ResizeMode_NEAREST, + ResizeMode_BILINEAR + }; + return values; +} + +inline const char * const *EnumNamesResizeMode() { + static const char * const names[] = { + "UNKNOWN", + "NEAREST", + "BILINEAR", + nullptr + }; + return names; +} + +inline const char *EnumNameResizeMode(ResizeMode e) { + if (e < ResizeMode_UNKNOWN || e > ResizeMode_BILINEAR) return ""; + const size_t index = static_cast<size_t>(e); + return EnumNamesResizeMode()[index]; +} + +enum Op { + Op_UNKNOWN = 0, + Op_ARGMAX = 1, + Op_AVG_POOL2D = 2, + Op_CONV2D = 3, + Op_CONV3D = 4, + Op_DEPTHWISE_CONV2D = 5, + Op_FULLY_CONNECTED = 6, + Op_MATMUL = 7, + Op_MAX_POOL2D = 8, + Op_TRANSPOSE_CONV2D = 9, + Op_CLAMP = 10, + Op_RELUN = 11, + Op_SIGMOID = 12, + Op_TANH = 13, + Op_ADD = 14, + Op_ARITHMETIC_RIGHT_SHIFT = 15, + Op_BITWISE_AND = 16, + Op_BITWISE_OR = 17, + Op_BITWISE_XOR = 18, + Op_LOGICAL_AND = 19, + Op_LOGICAL_LEFT_SHIFT = 20, + Op_LOGICAL_RIGHT_SHIFT = 21, + Op_LOGICAL_OR = 22, + Op_LOGICAL_XOR = 23, + Op_MAXIMUM = 24, + Op_MINIMUM = 25, + Op_MUL = 26, + Op_POW = 27, + Op_SUB = 28, + Op_TABLE = 29, + Op_ABS = 30, + Op_BITWISE_NOT = 31, + Op_CEIL = 32, + Op_CLZ = 33, + Op_EXP = 34, + Op_FLOOR = 35, + Op_LOG = 36, + Op_LOGICAL_NOT = 37, + Op_NEGATE = 38, + Op_RECIPROCAL = 39, + Op_RSQRT = 40, + Op_SELECT = 41, + Op_EQUAL = 42, + Op_GREATER = 43, + Op_GREATER_EQUAL = 44, + Op_REDUCE_ANY = 45, + Op_REDUCE_ALL = 46, + Op_REDUCE_MAX = 47, + Op_REDUCE_MIN = 48, + Op_REDUCE_PRODUCT = 49, + Op_REDUCE_SUM = 50, + Op_CONCAT = 51, + Op_PAD = 52, + Op_RESHAPE = 53, + Op_REVERSE = 54, + Op_SLICE = 55, + Op_TILE = 56, + Op_TRANSPOSE = 57, + Op_GATHER = 58, + Op_RESIZE = 59, + Op_CAST = 60, + Op_RESCALE = 61, + Op_CONST = 62, + Op_PLACEHOLDER = 63, + Op_IDENTITY = 64, + Op_IDENTITYN = 65, + Op_CUSTOM = 66, + Op_COND_IF = 67, + Op_WHILE_LOOP = 68, + Op_MIN = Op_UNKNOWN, + Op_MAX = Op_WHILE_LOOP +}; + +inline const Op (&EnumValuesOp())[69] { + static const Op values[] = { + Op_UNKNOWN, + Op_ARGMAX, + Op_AVG_POOL2D, + Op_CONV2D, + Op_CONV3D, + Op_DEPTHWISE_CONV2D, + Op_FULLY_CONNECTED, + Op_MATMUL, + Op_MAX_POOL2D, + Op_TRANSPOSE_CONV2D, + Op_CLAMP, + Op_RELUN, + Op_SIGMOID, + Op_TANH, + Op_ADD, + Op_ARITHMETIC_RIGHT_SHIFT, + Op_BITWISE_AND, + Op_BITWISE_OR, + Op_BITWISE_XOR, + Op_LOGICAL_AND, + Op_LOGICAL_LEFT_SHIFT, + Op_LOGICAL_RIGHT_SHIFT, + Op_LOGICAL_OR, + Op_LOGICAL_XOR, + Op_MAXIMUM, + Op_MINIMUM, + Op_MUL, + Op_POW, + Op_SUB, + Op_TABLE, + Op_ABS, + Op_BITWISE_NOT, + Op_CEIL, + Op_CLZ, + Op_EXP, + Op_FLOOR, + Op_LOG, + Op_LOGICAL_NOT, + Op_NEGATE, + Op_RECIPROCAL, + Op_RSQRT, + Op_SELECT, + Op_EQUAL, + Op_GREATER, + Op_GREATER_EQUAL, + Op_REDUCE_ANY, + Op_REDUCE_ALL, + Op_REDUCE_MAX, + Op_REDUCE_MIN, + Op_REDUCE_PRODUCT, + Op_REDUCE_SUM, + Op_CONCAT, + Op_PAD, + Op_RESHAPE, + Op_REVERSE, + Op_SLICE, + Op_TILE, + Op_TRANSPOSE, + Op_GATHER, + Op_RESIZE, + Op_CAST, + Op_RESCALE, + Op_CONST, + Op_PLACEHOLDER, + Op_IDENTITY, + Op_IDENTITYN, + Op_CUSTOM, + Op_COND_IF, + Op_WHILE_LOOP + }; + return values; +} + +inline const char * const *EnumNamesOp() { + static const char * const names[] = { + "UNKNOWN", + "ARGMAX", + "AVG_POOL2D", + "CONV2D", + "CONV3D", + "DEPTHWISE_CONV2D", + "FULLY_CONNECTED", + "MATMUL", + "MAX_POOL2D", + "TRANSPOSE_CONV2D", + "CLAMP", + "RELUN", + "SIGMOID", + "TANH", + "ADD", + "ARITHMETIC_RIGHT_SHIFT", + "BITWISE_AND", + "BITWISE_OR", + "BITWISE_XOR", + "LOGICAL_AND", + "LOGICAL_LEFT_SHIFT", + "LOGICAL_RIGHT_SHIFT", + "LOGICAL_OR", + "LOGICAL_XOR", + "MAXIMUM", + "MINIMUM", + "MUL", + "POW", + "SUB", + "TABLE", + "ABS", + "BITWISE_NOT", + "CEIL", + "CLZ", + "EXP", + "FLOOR", + "LOG", + "LOGICAL_NOT", + "NEGATE", + "RECIPROCAL", + "RSQRT", + "SELECT", + "EQUAL", + "GREATER", + "GREATER_EQUAL", + "REDUCE_ANY", + "REDUCE_ALL", + "REDUCE_MAX", + "REDUCE_MIN", + "REDUCE_PRODUCT", + "REDUCE_SUM", + "CONCAT", + "PAD", + "RESHAPE", + "REVERSE", + "SLICE", + "TILE", + "TRANSPOSE", + "GATHER", + "RESIZE", + "CAST", + "RESCALE", + "CONST", + "PLACEHOLDER", + "IDENTITY", + "IDENTITYN", + "CUSTOM", + "COND_IF", + "WHILE_LOOP", + nullptr + }; + return names; +} + +inline const char *EnumNameOp(Op e) { + if (e < Op_UNKNOWN || e > Op_WHILE_LOOP) return ""; + const size_t index = static_cast<size_t>(e); + return EnumNamesOp()[index]; +} + +enum Attribute { + Attribute_NONE = 0, + Attribute_Pool2dAttribute = 1, + Attribute_Conv2dAttribute = 2, + Attribute_TransposeConv2dAttribute = 3, + Attribute_ReluNAttribute = 4, + Attribute_AxisAttribute = 5, + Attribute_ReshapeAttribute = 6, + Attribute_SliceAttribute = 7, + Attribute_TileAttribute = 8, + Attribute_ResizeAttribute = 9, + Attribute_ClampAttribute = 10, + Attribute_RescaleAttribute = 11, + Attribute_CustomAttribute = 12, + Attribute_CondIfAttribute = 13, + Attribute_WhileLoopAttribute = 14, + Attribute_MIN = Attribute_NONE, + Attribute_MAX = Attribute_WhileLoopAttribute +}; + +inline const Attribute (&EnumValuesAttribute())[15] { + static const Attribute values[] = { + Attribute_NONE, + Attribute_Pool2dAttribute, + Attribute_Conv2dAttribute, + Attribute_TransposeConv2dAttribute, + Attribute_ReluNAttribute, + Attribute_AxisAttribute, + Attribute_ReshapeAttribute, + Attribute_SliceAttribute, + Attribute_TileAttribute, + Attribute_ResizeAttribute, + Attribute_ClampAttribute, + Attribute_RescaleAttribute, + Attribute_CustomAttribute, + Attribute_CondIfAttribute, + Attribute_WhileLoopAttribute + }; + return values; +} + +inline const char * const *EnumNamesAttribute() { + static const char * const names[] = { + "NONE", + "Pool2dAttribute", + "Conv2dAttribute", + "TransposeConv2dAttribute", + "ReluNAttribute", + "AxisAttribute", + "ReshapeAttribute", + "SliceAttribute", + "TileAttribute", + "ResizeAttribute", + "ClampAttribute", + "RescaleAttribute", + "CustomAttribute", + "CondIfAttribute", + "WhileLoopAttribute", + nullptr + }; + return names; +} + +inline const char *EnumNameAttribute(Attribute e) { + if (e < Attribute_NONE || e > Attribute_WhileLoopAttribute) return ""; + const size_t index = static_cast<size_t>(e); + return EnumNamesAttribute()[index]; +} + +template<typename T> struct AttributeTraits { + static const Attribute enum_value = Attribute_NONE; +}; + +template<> struct AttributeTraits<Pool2dAttribute> { + static const Attribute enum_value = Attribute_Pool2dAttribute; +}; + +template<> struct AttributeTraits<Conv2dAttribute> { + static const Attribute enum_value = Attribute_Conv2dAttribute; +}; + +template<> struct AttributeTraits<TransposeConv2dAttribute> { + static const Attribute enum_value = Attribute_TransposeConv2dAttribute; +}; + +template<> struct AttributeTraits<ReluNAttribute> { + static const Attribute enum_value = Attribute_ReluNAttribute; +}; + +template<> struct AttributeTraits<AxisAttribute> { + static const Attribute enum_value = Attribute_AxisAttribute; +}; + +template<> struct AttributeTraits<ReshapeAttribute> { + static const Attribute enum_value = Attribute_ReshapeAttribute; +}; + +template<> struct AttributeTraits<SliceAttribute> { + static const Attribute enum_value = Attribute_SliceAttribute; +}; + +template<> struct AttributeTraits<TileAttribute> { + static const Attribute enum_value = Attribute_TileAttribute; +}; + +template<> struct AttributeTraits<ResizeAttribute> { + static const Attribute enum_value = Attribute_ResizeAttribute; +}; + +template<> struct AttributeTraits<ClampAttribute> { + static const Attribute enum_value = Attribute_ClampAttribute; +}; + +template<> struct AttributeTraits<RescaleAttribute> { + static const Attribute enum_value = Attribute_RescaleAttribute; +}; + +template<> struct AttributeTraits<CustomAttribute> { + static const Attribute enum_value = Attribute_CustomAttribute; +}; + +template<> struct AttributeTraits<CondIfAttribute> { + static const Attribute enum_value = Attribute_CondIfAttribute; +}; + +template<> struct AttributeTraits<WhileLoopAttribute> { + static const Attribute enum_value = Attribute_WhileLoopAttribute; +}; + +bool VerifyAttribute(flatbuffers::Verifier &verifier, const void *obj, Attribute type); +bool VerifyAttributeVector(flatbuffers::Verifier &verifier, const flatbuffers::Vector<flatbuffers::Offset<void>> *values, const flatbuffers::Vector<uint8_t> *types); + +enum QuantInfo { + QuantInfo_NONE = 0, + QuantInfo_UnaryQuantInfo = 1, + QuantInfo_ConvQuantInfo = 2, + QuantInfo_MatMulQuantInfo = 3, + QuantInfo_PadQuantInfo = 4, + QuantInfo_MIN = QuantInfo_NONE, + QuantInfo_MAX = QuantInfo_PadQuantInfo +}; + +inline const QuantInfo (&EnumValuesQuantInfo())[5] { + static const QuantInfo values[] = { + QuantInfo_NONE, + QuantInfo_UnaryQuantInfo, + QuantInfo_ConvQuantInfo, + QuantInfo_MatMulQuantInfo, + QuantInfo_PadQuantInfo + }; + return values; +} + +inline const char * const *EnumNamesQuantInfo() { + static const char * const names[] = { + "NONE", + "UnaryQuantInfo", + "ConvQuantInfo", + "MatMulQuantInfo", + "PadQuantInfo", + nullptr + }; + return names; +} + +inline const char *EnumNameQuantInfo(QuantInfo e) { + if (e < QuantInfo_NONE || e > QuantInfo_PadQuantInfo) return ""; + const size_t index = static_cast<size_t>(e); + return EnumNamesQuantInfo()[index]; +} + +template<typename T> struct QuantInfoTraits { + static const QuantInfo enum_value = QuantInfo_NONE; +}; + +template<> struct QuantInfoTraits<UnaryQuantInfo> { + static const QuantInfo enum_value = QuantInfo_UnaryQuantInfo; +}; + +template<> struct QuantInfoTraits<ConvQuantInfo> { + static const QuantInfo enum_value = QuantInfo_ConvQuantInfo; +}; + +template<> struct QuantInfoTraits<MatMulQuantInfo> { + static const QuantInfo enum_value = QuantInfo_MatMulQuantInfo; +}; + +template<> struct QuantInfoTraits<PadQuantInfo> { + static const QuantInfo enum_value = QuantInfo_PadQuantInfo; +}; + +bool VerifyQuantInfo(flatbuffers::Verifier &verifier, const void *obj, QuantInfo type); +bool VerifyQuantInfoVector(flatbuffers::Verifier &verifier, const flatbuffers::Vector<flatbuffers::Offset<void>> *values, const flatbuffers::Vector<uint8_t> *types); + +struct Pool2dAttribute FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_PADDING = 4, + VT_KERNEL = 6, + VT_STRIDE = 8 + }; + const flatbuffers::Vector<int32_t> *padding() const { + return GetPointer<const flatbuffers::Vector<int32_t> *>(VT_PADDING); + } + const flatbuffers::Vector<int32_t> *kernel() const { + return GetPointer<const flatbuffers::Vector<int32_t> *>(VT_KERNEL); + } + const flatbuffers::Vector<int32_t> *stride() const { + return GetPointer<const flatbuffers::Vector<int32_t> *>(VT_STRIDE); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_PADDING) && + verifier.VerifyVector(padding()) && + VerifyOffset(verifier, VT_KERNEL) && + verifier.VerifyVector(kernel()) && + VerifyOffset(verifier, VT_STRIDE) && + verifier.VerifyVector(stride()) && + verifier.EndTable(); + } +}; + +struct Pool2dAttributeBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_padding(flatbuffers::Offset<flatbuffers::Vector<int32_t>> padding) { + fbb_.AddOffset(Pool2dAttribute::VT_PADDING, padding); + } + void add_kernel(flatbuffers::Offset<flatbuffers::Vector<int32_t>> kernel) { + fbb_.AddOffset(Pool2dAttribute::VT_KERNEL, kernel); + } + void add_stride(flatbuffers::Offset<flatbuffers::Vector<int32_t>> stride) { + fbb_.AddOffset(Pool2dAttribute::VT_STRIDE, stride); + } + explicit Pool2dAttributeBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + Pool2dAttributeBuilder &operator=(const Pool2dAttributeBuilder &); + flatbuffers::Offset<Pool2dAttribute> Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset<Pool2dAttribute>(end); + return o; + } +}; + +inline flatbuffers::Offset<Pool2dAttribute> CreatePool2dAttribute( + flatbuffers::FlatBufferBuilder &_fbb, + flatbuffers::Offset<flatbuffers::Vector<int32_t>> padding = 0, + flatbuffers::Offset<flatbuffers::Vector<int32_t>> kernel = 0, + flatbuffers::Offset<flatbuffers::Vector<int32_t>> stride = 0) { + Pool2dAttributeBuilder builder_(_fbb); + builder_.add_stride(stride); + builder_.add_kernel(kernel); + builder_.add_padding(padding); + return builder_.Finish(); +} + +inline flatbuffers::Offset<Pool2dAttribute> CreatePool2dAttributeDirect( + flatbuffers::FlatBufferBuilder &_fbb, + const std::vector<int32_t> *padding = nullptr, + const std::vector<int32_t> *kernel = nullptr, + const std::vector<int32_t> *stride = nullptr) { + auto padding__ = padding ? _fbb.CreateVector<int32_t>(*padding) : 0; + auto kernel__ = kernel ? _fbb.CreateVector<int32_t>(*kernel) : 0; + auto stride__ = stride ? _fbb.CreateVector<int32_t>(*stride) : 0; + return tosa::CreatePool2dAttribute( + _fbb, + padding__, + kernel__, + stride__); +} + +struct Conv2dAttribute FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_PADDING = 4, + VT_STRIDE = 6, + VT_DILATION = 8 + }; + const flatbuffers::Vector<int32_t> *padding() const { + return GetPointer<const flatbuffers::Vector<int32_t> *>(VT_PADDING); + } + const flatbuffers::Vector<int32_t> *stride() const { + return GetPointer<const flatbuffers::Vector<int32_t> *>(VT_STRIDE); + } + const flatbuffers::Vector<int32_t> *dilation() const { + return GetPointer<const flatbuffers::Vector<int32_t> *>(VT_DILATION); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_PADDING) && + verifier.VerifyVector(padding()) && + VerifyOffset(verifier, VT_STRIDE) && + verifier.VerifyVector(stride()) && + VerifyOffset(verifier, VT_DILATION) && + verifier.VerifyVector(dilation()) && + verifier.EndTable(); + } +}; + +struct Conv2dAttributeBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_padding(flatbuffers::Offset<flatbuffers::Vector<int32_t>> padding) { + fbb_.AddOffset(Conv2dAttribute::VT_PADDING, padding); + } + void add_stride(flatbuffers::Offset<flatbuffers::Vector<int32_t>> stride) { + fbb_.AddOffset(Conv2dAttribute::VT_STRIDE, stride); + } + void add_dilation(flatbuffers::Offset<flatbuffers::Vector<int32_t>> dilation) { + fbb_.AddOffset(Conv2dAttribute::VT_DILATION, dilation); + } + explicit Conv2dAttributeBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + Conv2dAttributeBuilder &operator=(const Conv2dAttributeBuilder &); + flatbuffers::Offset<Conv2dAttribute> Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset<Conv2dAttribute>(end); + return o; + } +}; + +inline flatbuffers::Offset<Conv2dAttribute> CreateConv2dAttribute( + flatbuffers::FlatBufferBuilder &_fbb, + flatbuffers::Offset<flatbuffers::Vector<int32_t>> padding = 0, + flatbuffers::Offset<flatbuffers::Vector<int32_t>> stride = 0, + flatbuffers::Offset<flatbuffers::Vector<int32_t>> dilation = 0) { + Conv2dAttributeBuilder builder_(_fbb); + builder_.add_dilation(dilation); + builder_.add_stride(stride); + builder_.add_padding(padding); + return builder_.Finish(); +} + +inline flatbuffers::Offset<Conv2dAttribute> CreateConv2dAttributeDirect( + flatbuffers::FlatBufferBuilder &_fbb, + const std::vector<int32_t> *padding = nullptr, + const std::vector<int32_t> *stride = nullptr, + const std::vector<int32_t> *dilation = nullptr) { + auto padding__ = padding ? _fbb.CreateVector<int32_t>(*padding) : 0; + auto stride__ = stride ? _fbb.CreateVector<int32_t>(*stride) : 0; + auto dilation__ = dilation ? _fbb.CreateVector<int32_t>(*dilation) : 0; + return tosa::CreateConv2dAttribute( + _fbb, + padding__, + stride__, + dilation__); +} + +struct TransposeConv2dAttribute FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_OUTPAD = 4, + VT_STRIDE = 6, + VT_DILATION = 8, + VT_OUTPUT_SHAPE = 10 + }; + const flatbuffers::Vector<int32_t> *outpad() const { + return GetPointer<const flatbuffers::Vector<int32_t> *>(VT_OUTPAD); + } + const flatbuffers::Vector<int32_t> *stride() const { + return GetPointer<const flatbuffers::Vector<int32_t> *>(VT_STRIDE); + } + const flatbuffers::Vector<int32_t> *dilation() const { + return GetPointer<const flatbuffers::Vector<int32_t> *>(VT_DILATION); + } + const flatbuffers::Vector<int32_t> *output_shape() const { + return GetPointer<const flatbuffers::Vector<int32_t> *>(VT_OUTPUT_SHAPE); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_OUTPAD) && + verifier.VerifyVector(outpad()) && + VerifyOffset(verifier, VT_STRIDE) && + verifier.VerifyVector(stride()) && + VerifyOffset(verifier, VT_DILATION) && + verifier.VerifyVector(dilation()) && + VerifyOffset(verifier, VT_OUTPUT_SHAPE) && + verifier.VerifyVector(output_shape()) && + verifier.EndTable(); + } +}; + +struct TransposeConv2dAttributeBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_outpad(flatbuffers::Offset<flatbuffers::Vector<int32_t>> outpad) { + fbb_.AddOffset(TransposeConv2dAttribute::VT_OUTPAD, outpad); + } + void add_stride(flatbuffers::Offset<flatbuffers::Vector<int32_t>> stride) { + fbb_.AddOffset(TransposeConv2dAttribute::VT_STRIDE, stride); + } + void add_dilation(flatbuffers::Offset<flatbuffers::Vector<int32_t>> dilation) { + fbb_.AddOffset(TransposeConv2dAttribute::VT_DILATION, dilation); + } + void add_output_shape(flatbuffers::Offset<flatbuffers::Vector<int32_t>> output_shape) { + fbb_.AddOffset(TransposeConv2dAttribute::VT_OUTPUT_SHAPE, output_shape); + } + explicit TransposeConv2dAttributeBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + TransposeConv2dAttributeBuilder &operator=(const TransposeConv2dAttributeBuilder &); + flatbuffers::Offset<TransposeConv2dAttribute> Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset<TransposeConv2dAttribute>(end); + return o; + } +}; + +inline flatbuffers::Offset<TransposeConv2dAttribute> CreateTransposeConv2dAttribute( + flatbuffers::FlatBufferBuilder &_fbb, + flatbuffers::Offset<flatbuffers::Vector<int32_t>> outpad = 0, + flatbuffers::Offset<flatbuffers::Vector<int32_t>> stride = 0, + flatbuffers::Offset<flatbuffers::Vector<int32_t>> dilation = 0, + flatbuffers::Offset<flatbuffers::Vector<int32_t>> output_shape = 0) { + TransposeConv2dAttributeBuilder builder_(_fbb); + builder_.add_output_shape(output_shape); + builder_.add_dilation(dilation); + builder_.add_stride(stride); + builder_.add_outpad(outpad); + return builder_.Finish(); +} + +inline flatbuffers::Offset<TransposeConv2dAttribute> CreateTransposeConv2dAttributeDirect( + flatbuffers::FlatBufferBuilder &_fbb, + const std::vector<int32_t> *outpad = nullptr, + const std::vector<int32_t> *stride = nullptr, + const std::vector<int32_t> *dilation = nullptr, + const std::vector<int32_t> *output_shape = nullptr) { + auto outpad__ = outpad ? _fbb.CreateVector<int32_t>(*outpad) : 0; + auto stride__ = stride ? _fbb.CreateVector<int32_t>(*stride) : 0; + auto dilation__ = dilation ? _fbb.CreateVector<int32_t>(*dilation) : 0; + auto output_shape__ = output_shape ? _fbb.CreateVector<int32_t>(*output_shape) : 0; + return tosa::CreateTransposeConv2dAttribute( + _fbb, + outpad__, + stride__, + dilation__, + output_shape__); +} + +struct ReluNAttribute FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_MAX_INT = 4, + VT_MAX_FP = 6 + }; + int32_t max_int() const { + return GetField<int32_t>(VT_MAX_INT, 0); + } + float max_fp() const { + return GetField<float>(VT_MAX_FP, 0.0f); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField<int32_t>(verifier, VT_MAX_INT) && + VerifyField<float>(verifier, VT_MAX_FP) && + verifier.EndTable(); + } +}; + +struct ReluNAttributeBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_max_int(int32_t max_int) { + fbb_.AddElement<int32_t>(ReluNAttribute::VT_MAX_INT, max_int, 0); + } + void add_max_fp(float max_fp) { + fbb_.AddElement<float>(ReluNAttribute::VT_MAX_FP, max_fp, 0.0f); + } + explicit ReluNAttributeBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ReluNAttributeBuilder &operator=(const ReluNAttributeBuilder &); + flatbuffers::Offset<ReluNAttribute> Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset<ReluNAttribute>(end); + return o; + } +}; + +inline flatbuffers::Offset<ReluNAttribute> CreateReluNAttribute( + flatbuffers::FlatBufferBuilder &_fbb, + int32_t max_int = 0, + float max_fp = 0.0f) { + ReluNAttributeBuilder builder_(_fbb); + builder_.add_max_fp(max_fp); + builder_.add_max_int(max_int); + return builder_.Finish(); +} + +struct AxisAttribute FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_AXIS = 4 + }; + int32_t axis() const { + return GetField<int32_t>(VT_AXIS, 0); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField<int32_t>(verifier, VT_AXIS) && + verifier.EndTable(); + } +}; + +struct AxisAttributeBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_axis(int32_t axis) { + fbb_.AddElement<int32_t>(AxisAttribute::VT_AXIS, axis, 0); + } + explicit AxisAttributeBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + AxisAttributeBuilder &operator=(const AxisAttributeBuilder &); + flatbuffers::Offset<AxisAttribute> Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset<AxisAttribute>(end); + return o; + } +}; + +inline flatbuffers::Offset<AxisAttribute> CreateAxisAttribute( + flatbuffers::FlatBufferBuilder &_fbb, + int32_t axis = 0) { + AxisAttributeBuilder builder_(_fbb); + builder_.add_axis(axis); + return builder_.Finish(); +} + +struct ReshapeAttribute FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_SHAPE = 4 + }; + const flatbuffers::Vector<int32_t> *shape() const { + return GetPointer<const flatbuffers::Vector<int32_t> *>(VT_SHAPE); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_SHAPE) && + verifier.VerifyVector(shape()) && + verifier.EndTable(); + } +}; + +struct ReshapeAttributeBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_shape(flatbuffers::Offset<flatbuffers::Vector<int32_t>> shape) { + fbb_.AddOffset(ReshapeAttribute::VT_SHAPE, shape); + } + explicit ReshapeAttributeBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ReshapeAttributeBuilder &operator=(const ReshapeAttributeBuilder &); + flatbuffers::Offset<ReshapeAttribute> Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset<ReshapeAttribute>(end); + return o; + } +}; + +inline flatbuffers::Offset<ReshapeAttribute> CreateReshapeAttribute( + flatbuffers::FlatBufferBuilder &_fbb, + flatbuffers::Offset<flatbuffers::Vector<int32_t>> shape = 0) { + ReshapeAttributeBuilder builder_(_fbb); + builder_.add_shape(shape); + return builder_.Finish(); +} + +inline flatbuffers::Offset<ReshapeAttribute> CreateReshapeAttributeDirect( + flatbuffers::FlatBufferBuilder &_fbb, + const std::vector<int32_t> *shape = nullptr) { + auto shape__ = shape ? _fbb.CreateVector<int32_t>(*shape) : 0; + return tosa::CreateReshapeAttribute( + _fbb, + shape__); +} + +struct SliceAttribute FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_BEGIN = 4, + VT_SIZE = 6 + }; + const flatbuffers::Vector<int32_t> *begin() const { + return GetPointer<const flatbuffers::Vector<int32_t> *>(VT_BEGIN); + } + const flatbuffers::Vector<int32_t> *size() const { + return GetPointer<const flatbuffers::Vector<int32_t> *>(VT_SIZE); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_BEGIN) && + verifier.VerifyVector(begin()) && + VerifyOffset(verifier, VT_SIZE) && + verifier.VerifyVector(size()) && + verifier.EndTable(); + } +}; + +struct SliceAttributeBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_begin(flatbuffers::Offset<flatbuffers::Vector<int32_t>> begin) { + fbb_.AddOffset(SliceAttribute::VT_BEGIN, begin); + } + void add_size(flatbuffers::Offset<flatbuffers::Vector<int32_t>> size) { + fbb_.AddOffset(SliceAttribute::VT_SIZE, size); + } + explicit SliceAttributeBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + SliceAttributeBuilder &operator=(const SliceAttributeBuilder &); + flatbuffers::Offset<SliceAttribute> Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset<SliceAttribute>(end); + return o; + } +}; + +inline flatbuffers::Offset<SliceAttribute> CreateSliceAttribute( + flatbuffers::FlatBufferBuilder &_fbb, + flatbuffers::Offset<flatbuffers::Vector<int32_t>> begin = 0, + flatbuffers::Offset<flatbuffers::Vector<int32_t>> size = 0) { + SliceAttributeBuilder builder_(_fbb); + builder_.add_size(size); + builder_.add_begin(begin); + return builder_.Finish(); +} + +inline flatbuffers::Offset<SliceAttribute> CreateSliceAttributeDirect( + flatbuffers::FlatBufferBuilder &_fbb, + const std::vector<int32_t> *begin = nullptr, + const std::vector<int32_t> *size = nullptr) { + auto begin__ = begin ? _fbb.CreateVector<int32_t>(*begin) : 0; + auto size__ = size ? _fbb.CreateVector<int32_t>(*size) : 0; + return tosa::CreateSliceAttribute( + _fbb, + begin__, + size__); +} + +struct TileAttribute FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_MULTIPLES = 4 + }; + const flatbuffers::Vector<int32_t> *multiples() const { + return GetPointer<const flatbuffers::Vector<int32_t> *>(VT_MULTIPLES); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_MULTIPLES) && + verifier.VerifyVector(multiples()) && + verifier.EndTable(); + } +}; + +struct TileAttributeBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_multiples(flatbuffers::Offset<flatbuffers::Vector<int32_t>> multiples) { + fbb_.AddOffset(TileAttribute::VT_MULTIPLES, multiples); + } + explicit TileAttributeBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + TileAttributeBuilder &operator=(const TileAttributeBuilder &); + flatbuffers::Offset<TileAttribute> Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset<TileAttribute>(end); + return o; + } +}; + +inline flatbuffers::Offset<TileAttribute> CreateTileAttribute( + flatbuffers::FlatBufferBuilder &_fbb, + flatbuffers::Offset<flatbuffers::Vector<int32_t>> multiples = 0) { + TileAttributeBuilder builder_(_fbb); + builder_.add_multiples(multiples); + return builder_.Finish(); +} + +inline flatbuffers::Offset<TileAttribute> CreateTileAttributeDirect( + flatbuffers::FlatBufferBuilder &_fbb, + const std::vector<int32_t> *multiples = nullptr) { + auto multiples__ = multiples ? _fbb.CreateVector<int32_t>(*multiples) : 0; + return tosa::CreateTileAttribute( + _fbb, + multiples__); +} + +struct ResizeAttribute FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_OUTPUT_SIZE = 4, + VT_STRIDE = 6, + VT_OFFSET = 8, + VT_SHIFT = 10, + VT_MODE = 12 + }; + const flatbuffers::Vector<int32_t> *output_size() const { + return GetPointer<const flatbuffers::Vector<int32_t> *>(VT_OUTPUT_SIZE); + } + const flatbuffers::Vector<int32_t> *stride() const { + return GetPointer<const flatbuffers::Vector<int32_t> *>(VT_STRIDE); + } + const flatbuffers::Vector<int32_t> *offset() const { + return GetPointer<const flatbuffers::Vector<int32_t> *>(VT_OFFSET); + } + int32_t shift() const { + return GetField<int32_t>(VT_SHIFT, 0); + } + ResizeMode mode() const { + return static_cast<ResizeMode>(GetField<uint32_t>(VT_MODE, 0)); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_OUTPUT_SIZE) && + verifier.VerifyVector(output_size()) && + VerifyOffset(verifier, VT_STRIDE) && + verifier.VerifyVector(stride()) && + VerifyOffset(verifier, VT_OFFSET) && + verifier.VerifyVector(offset()) && + VerifyField<int32_t>(verifier, VT_SHIFT) && + VerifyField<uint32_t>(verifier, VT_MODE) && + verifier.EndTable(); + } +}; + +struct ResizeAttributeBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_output_size(flatbuffers::Offset<flatbuffers::Vector<int32_t>> output_size) { + fbb_.AddOffset(ResizeAttribute::VT_OUTPUT_SIZE, output_size); + } + void add_stride(flatbuffers::Offset<flatbuffers::Vector<int32_t>> stride) { + fbb_.AddOffset(ResizeAttribute::VT_STRIDE, stride); + } + void add_offset(flatbuffers::Offset<flatbuffers::Vector<int32_t>> offset) { + fbb_.AddOffset(ResizeAttribute::VT_OFFSET, offset); + } + void add_shift(int32_t shift) { + fbb_.AddElement<int32_t>(ResizeAttribute::VT_SHIFT, shift, 0); + } + void add_mode(ResizeMode mode) { + fbb_.AddElement<uint32_t>(ResizeAttribute::VT_MODE, static_cast<uint32_t>(mode), 0); + } + explicit ResizeAttributeBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ResizeAttributeBuilder &operator=(const ResizeAttributeBuilder &); + flatbuffers::Offset<ResizeAttribute> Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset<ResizeAttribute>(end); + return o; + } +}; + +inline flatbuffers::Offset<ResizeAttribute> CreateResizeAttribute( + flatbuffers::FlatBufferBuilder &_fbb, + flatbuffers::Offset<flatbuffers::Vector<int32_t>> output_size = 0, + flatbuffers::Offset<flatbuffers::Vector<int32_t>> stride = 0, + flatbuffers::Offset<flatbuffers::Vector<int32_t>> offset = 0, + int32_t shift = 0, + ResizeMode mode = ResizeMode_UNKNOWN) { + ResizeAttributeBuilder builder_(_fbb); + builder_.add_mode(mode); + builder_.add_shift(shift); + builder_.add_offset(offset); + builder_.add_stride(stride); + builder_.add_output_size(output_size); + return builder_.Finish(); +} + +inline flatbuffers::Offset<ResizeAttribute> CreateResizeAttributeDirect( + flatbuffers::FlatBufferBuilder &_fbb, + const std::vector<int32_t> *output_size = nullptr, + const std::vector<int32_t> *stride = nullptr, + const std::vector<int32_t> *offset = nullptr, + int32_t shift = 0, + ResizeMode mode = ResizeMode_UNKNOWN) { + auto output_size__ = output_size ? _fbb.CreateVector<int32_t>(*output_size) : 0; + auto stride__ = stride ? _fbb.CreateVector<int32_t>(*stride) : 0; + auto offset__ = offset ? _fbb.CreateVector<int32_t>(*offset) : 0; + return tosa::CreateResizeAttribute( + _fbb, + output_size__, + stride__, + offset__, + shift, + mode); +} + +struct ClampAttribute FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_MIN_INT = 4, + VT_MAX_INT = 6, + VT_MIN_FP = 8, + VT_MAX_FP = 10 + }; + int32_t min_int() const { + return GetField<int32_t>(VT_MIN_INT, 0); + } + int32_t max_int() const { + return GetField<int32_t>(VT_MAX_INT, 0); + } + float min_fp() const { + return GetField<float>(VT_MIN_FP, 0.0f); + } + float max_fp() const { + return GetField<float>(VT_MAX_FP, 0.0f); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField<int32_t>(verifier, VT_MIN_INT) && + VerifyField<int32_t>(verifier, VT_MAX_INT) && + VerifyField<float>(verifier, VT_MIN_FP) && + VerifyField<float>(verifier, VT_MAX_FP) && + verifier.EndTable(); + } +}; + +struct ClampAttributeBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_min_int(int32_t min_int) { + fbb_.AddElement<int32_t>(ClampAttribute::VT_MIN_INT, min_int, 0); + } + void add_max_int(int32_t max_int) { + fbb_.AddElement<int32_t>(ClampAttribute::VT_MAX_INT, max_int, 0); + } + void add_min_fp(float min_fp) { + fbb_.AddElement<float>(ClampAttribute::VT_MIN_FP, min_fp, 0.0f); + } + void add_max_fp(float max_fp) { + fbb_.AddElement<float>(ClampAttribute::VT_MAX_FP, max_fp, 0.0f); + } + explicit ClampAttributeBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ClampAttributeBuilder &operator=(const ClampAttributeBuilder &); + flatbuffers::Offset<ClampAttribute> Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset<ClampAttribute>(end); + return o; + } +}; + +inline flatbuffers::Offset<ClampAttribute> CreateClampAttribute( + flatbuffers::FlatBufferBuilder &_fbb, + int32_t min_int = 0, + int32_t max_int = 0, + float min_fp = 0.0f, + float max_fp = 0.0f) { + ClampAttributeBuilder builder_(_fbb); + builder_.add_max_fp(max_fp); + builder_.add_min_fp(min_fp); + builder_.add_max_int(max_int); + builder_.add_min_int(min_int); + return builder_.Finish(); +} + +struct RescaleAttribute FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_INPUT_ZP = 4, + VT_OUTPUT_ZP = 6, + VT_MULTIPLIER = 8, + VT_SHIFT = 10, + VT_SCALE32 = 12, + VT_DOUBLE_ROUND = 14, + VT_PER_CHANNEL = 16 + }; + int32_t input_zp() const { + return GetField<int32_t>(VT_INPUT_ZP, 0); + } + int32_t output_zp() const { + return GetField<int32_t>(VT_OUTPUT_ZP, 0); + } + const flatbuffers::Vector<int32_t> *multiplier() const { + return GetPointer<const flatbuffers::Vector<int32_t> *>(VT_MULTIPLIER); + } + const flatbuffers::Vector<int32_t> *shift() const { + return GetPointer<const flatbuffers::Vector<int32_t> *>(VT_SHIFT); + } + bool scale32() const { + return GetField<uint8_t>(VT_SCALE32, 0) != 0; + } + bool double_round() const { + return GetField<uint8_t>(VT_DOUBLE_ROUND, 0) != 0; + } + bool per_channel() const { + return GetField<uint8_t>(VT_PER_CHANNEL, 0) != 0; + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField<int32_t>(verifier, VT_INPUT_ZP) && + VerifyField<int32_t>(verifier, VT_OUTPUT_ZP) && + VerifyOffset(verifier, VT_MULTIPLIER) && + verifier.VerifyVector(multiplier()) && + VerifyOffset(verifier, VT_SHIFT) && + verifier.VerifyVector(shift()) && + VerifyField<uint8_t>(verifier, VT_SCALE32) && + VerifyField<uint8_t>(verifier, VT_DOUBLE_ROUND) && + VerifyField<uint8_t>(verifier, VT_PER_CHANNEL) && + verifier.EndTable(); + } +}; + +struct RescaleAttributeBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_input_zp(int32_t input_zp) { + fbb_.AddElement<int32_t>(RescaleAttribute::VT_INPUT_ZP, input_zp, 0); + } + void add_output_zp(int32_t output_zp) { + fbb_.AddElement<int32_t>(RescaleAttribute::VT_OUTPUT_ZP, output_zp, 0); + } + void add_multiplier(flatbuffers::Offset<flatbuffers::Vector<int32_t>> multiplier) { + fbb_.AddOffset(RescaleAttribute::VT_MULTIPLIER, multiplier); + } + void add_shift(flatbuffers::Offset<flatbuffers::Vector<int32_t>> shift) { + fbb_.AddOffset(RescaleAttribute::VT_SHIFT, shift); + } + void add_scale32(bool scale32) { + fbb_.AddElement<uint8_t>(RescaleAttribute::VT_SCALE32, static_cast<uint8_t>(scale32), 0); + } + void add_double_round(bool double_round) { + fbb_.AddElement<uint8_t>(RescaleAttribute::VT_DOUBLE_ROUND, static_cast<uint8_t>(double_round), 0); + } + void add_per_channel(bool per_channel) { + fbb_.AddElement<uint8_t>(RescaleAttribute::VT_PER_CHANNEL, static_cast<uint8_t>(per_channel), 0); + } + explicit RescaleAttributeBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + RescaleAttributeBuilder &operator=(const RescaleAttributeBuilder &); + flatbuffers::Offset<RescaleAttribute> Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset<RescaleAttribute>(end); + return o; + } +}; + +inline flatbuffers::Offset<RescaleAttribute> CreateRescaleAttribute( + flatbuffers::FlatBufferBuilder &_fbb, + int32_t input_zp = 0, + int32_t output_zp = 0, + flatbuffers::Offset<flatbuffers::Vector<int32_t>> multiplier = 0, + flatbuffers::Offset<flatbuffers::Vector<int32_t>> shift = 0, + bool scale32 = false, + bool double_round = false, + bool per_channel = false) { + RescaleAttributeBuilder builder_(_fbb); + builder_.add_shift(shift); + builder_.add_multiplier(multiplier); + builder_.add_output_zp(output_zp); + builder_.add_input_zp(input_zp); + builder_.add_per_channel(per_channel); + builder_.add_double_round(double_round); + builder_.add_scale32(scale32); + return builder_.Finish(); +} + +inline flatbuffers::Offset<RescaleAttribute> CreateRescaleAttributeDirect( + flatbuffers::FlatBufferBuilder &_fbb, + int32_t input_zp = 0, + int32_t output_zp = 0, + const std::vector<int32_t> *multiplier = nullptr, + const std::vector<int32_t> *shift = nullptr, + bool scale32 = false, + bool double_round = false, + bool per_channel = false) { + auto multiplier__ = multiplier ? _fbb.CreateVector<int32_t>(*multiplier) : 0; + auto shift__ = shift ? _fbb.CreateVector<int32_t>(*shift) : 0; + return tosa::CreateRescaleAttribute( + _fbb, + input_zp, + output_zp, + multiplier__, + shift__, + scale32, + double_round, + per_channel); +} + +struct CustomAttribute FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_IDENTIFIER = 4 + }; + const flatbuffers::String *identifier() const { + return GetPointer<const flatbuffers::String *>(VT_IDENTIFIER); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_IDENTIFIER) && + verifier.VerifyString(identifier()) && + verifier.EndTable(); + } +}; + +struct CustomAttributeBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_identifier(flatbuffers::Offset<flatbuffers::String> identifier) { + fbb_.AddOffset(CustomAttribute::VT_IDENTIFIER, identifier); + } + explicit CustomAttributeBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + CustomAttributeBuilder &operator=(const CustomAttributeBuilder &); + flatbuffers::Offset<CustomAttribute> Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset<CustomAttribute>(end); + return o; + } +}; + +inline flatbuffers::Offset<CustomAttribute> CreateCustomAttribute( + flatbuffers::FlatBufferBuilder &_fbb, + flatbuffers::Offset<flatbuffers::String> identifier = 0) { + CustomAttributeBuilder builder_(_fbb); + builder_.add_identifier(identifier); + return builder_.Finish(); +} + +inline flatbuffers::Offset<CustomAttribute> CreateCustomAttributeDirect( + flatbuffers::FlatBufferBuilder &_fbb, + const char *identifier = nullptr) { + auto identifier__ = identifier ? _fbb.CreateString(identifier) : 0; + return tosa::CreateCustomAttribute( + _fbb, + identifier__); +} + +struct CondIfAttribute FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_THEN_BRANCH = 4, + VT_ELSE_BRANCH = 6 + }; + const flatbuffers::String *then_branch() const { + return GetPointer<const flatbuffers::String *>(VT_THEN_BRANCH); + } + const flatbuffers::String *else_branch() const { + return GetPointer<const flatbuffers::String *>(VT_ELSE_BRANCH); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_THEN_BRANCH) && + verifier.VerifyString(then_branch()) && + VerifyOffset(verifier, VT_ELSE_BRANCH) && + verifier.VerifyString(else_branch()) && + verifier.EndTable(); + } +}; + +struct CondIfAttributeBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_then_branch(flatbuffers::Offset<flatbuffers::String> then_branch) { + fbb_.AddOffset(CondIfAttribute::VT_THEN_BRANCH, then_branch); + } + void add_else_branch(flatbuffers::Offset<flatbuffers::String> else_branch) { + fbb_.AddOffset(CondIfAttribute::VT_ELSE_BRANCH, else_branch); + } + explicit CondIfAttributeBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + CondIfAttributeBuilder &operator=(const CondIfAttributeBuilder &); + flatbuffers::Offset<CondIfAttribute> Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset<CondIfAttribute>(end); + return o; + } +}; + +inline flatbuffers::Offset<CondIfAttribute> CreateCondIfAttribute( + flatbuffers::FlatBufferBuilder &_fbb, + flatbuffers::Offset<flatbuffers::String> then_branch = 0, + flatbuffers::Offset<flatbuffers::String> else_branch = 0) { + CondIfAttributeBuilder builder_(_fbb); + builder_.add_else_branch(else_branch); + builder_.add_then_branch(then_branch); + return builder_.Finish(); +} + +inline flatbuffers::Offset<CondIfAttribute> CreateCondIfAttributeDirect( + flatbuffers::FlatBufferBuilder &_fbb, + const char *then_branch = nullptr, + const char *else_branch = nullptr) { + auto then_branch__ = then_branch ? _fbb.CreateString(then_branch) : 0; + auto else_branch__ = else_branch ? _fbb.CreateString(else_branch) : 0; + return tosa::CreateCondIfAttribute( + _fbb, + then_branch__, + else_branch__); +} + +struct WhileLoopAttribute FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_COND_BRANCH = 4, + VT_BODY_BRANCH = 6 + }; + const flatbuffers::String *cond_branch() const { + return GetPointer<const flatbuffers::String *>(VT_COND_BRANCH); + } + const flatbuffers::String *body_branch() const { + return GetPointer<const flatbuffers::String *>(VT_BODY_BRANCH); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_COND_BRANCH) && + verifier.VerifyString(cond_branch()) && + VerifyOffset(verifier, VT_BODY_BRANCH) && + verifier.VerifyString(body_branch()) && + verifier.EndTable(); + } +}; + +struct WhileLoopAttributeBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_cond_branch(flatbuffers::Offset<flatbuffers::String> cond_branch) { + fbb_.AddOffset(WhileLoopAttribute::VT_COND_BRANCH, cond_branch); + } + void add_body_branch(flatbuffers::Offset<flatbuffers::String> body_branch) { + fbb_.AddOffset(WhileLoopAttribute::VT_BODY_BRANCH, body_branch); + } + explicit WhileLoopAttributeBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + WhileLoopAttributeBuilder &operator=(const WhileLoopAttributeBuilder &); + flatbuffers::Offset<WhileLoopAttribute> Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset<WhileLoopAttribute>(end); + return o; + } +}; + +inline flatbuffers::Offset<WhileLoopAttribute> CreateWhileLoopAttribute( + flatbuffers::FlatBufferBuilder &_fbb, + flatbuffers::Offset<flatbuffers::String> cond_branch = 0, + flatbuffers::Offset<flatbuffers::String> body_branch = 0) { + WhileLoopAttributeBuilder builder_(_fbb); + builder_.add_body_branch(body_branch); + builder_.add_cond_branch(cond_branch); + return builder_.Finish(); +} + +inline flatbuffers::Offset<WhileLoopAttribute> CreateWhileLoopAttributeDirect( + flatbuffers::FlatBufferBuilder &_fbb, + const char *cond_branch = nullptr, + const char *body_branch = nullptr) { + auto cond_branch__ = cond_branch ? _fbb.CreateString(cond_branch) : 0; + auto body_branch__ = body_branch ? _fbb.CreateString(body_branch) : 0; + return tosa::CreateWhileLoopAttribute( + _fbb, + cond_branch__, + body_branch__); +} + +struct UnaryQuantInfo FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_INPUT_ZP = 4, + VT_OUTPUT_ZP = 6 + }; + int32_t input_zp() const { + return GetField<int32_t>(VT_INPUT_ZP, 0); + } + int32_t output_zp() const { + return GetField<int32_t>(VT_OUTPUT_ZP, 0); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField<int32_t>(verifier, VT_INPUT_ZP) && + VerifyField<int32_t>(verifier, VT_OUTPUT_ZP) && + verifier.EndTable(); + } +}; + +struct UnaryQuantInfoBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_input_zp(int32_t input_zp) { + fbb_.AddElement<int32_t>(UnaryQuantInfo::VT_INPUT_ZP, input_zp, 0); + } + void add_output_zp(int32_t output_zp) { + fbb_.AddElement<int32_t>(UnaryQuantInfo::VT_OUTPUT_ZP, output_zp, 0); + } + explicit UnaryQuantInfoBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + UnaryQuantInfoBuilder &operator=(const UnaryQuantInfoBuilder &); + flatbuffers::Offset<UnaryQuantInfo> Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset<UnaryQuantInfo>(end); + return o; + } +}; + +inline flatbuffers::Offset<UnaryQuantInfo> CreateUnaryQuantInfo( + flatbuffers::FlatBufferBuilder &_fbb, + int32_t input_zp = 0, + int32_t output_zp = 0) { + UnaryQuantInfoBuilder builder_(_fbb); + builder_.add_output_zp(output_zp); + builder_.add_input_zp(input_zp); + return builder_.Finish(); +} + +struct ConvQuantInfo FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_INPUT_ZP = 4, + VT_WEIGHT_ZP = 6 + }; + int32_t input_zp() const { + return GetField<int32_t>(VT_INPUT_ZP, 0); + } + int32_t weight_zp() const { + return GetField<int32_t>(VT_WEIGHT_ZP, 0); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField<int32_t>(verifier, VT_INPUT_ZP) && + VerifyField<int32_t>(verifier, VT_WEIGHT_ZP) && + verifier.EndTable(); + } +}; + +struct ConvQuantInfoBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_input_zp(int32_t input_zp) { + fbb_.AddElement<int32_t>(ConvQuantInfo::VT_INPUT_ZP, input_zp, 0); + } + void add_weight_zp(int32_t weight_zp) { + fbb_.AddElement<int32_t>(ConvQuantInfo::VT_WEIGHT_ZP, weight_zp, 0); + } + explicit ConvQuantInfoBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ConvQuantInfoBuilder &operator=(const ConvQuantInfoBuilder &); + flatbuffers::Offset<ConvQuantInfo> Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset<ConvQuantInfo>(end); + return o; + } +}; + +inline flatbuffers::Offset<ConvQuantInfo> CreateConvQuantInfo( + flatbuffers::FlatBufferBuilder &_fbb, + int32_t input_zp = 0, + int32_t weight_zp = 0) { + ConvQuantInfoBuilder builder_(_fbb); + builder_.add_weight_zp(weight_zp); + builder_.add_input_zp(input_zp); + return builder_.Finish(); +} + +struct MatMulQuantInfo FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_A_ZP = 4, + VT_B_ZP = 6 + }; + int32_t a_zp() const { + return GetField<int32_t>(VT_A_ZP, 0); + } + int32_t b_zp() const { + return GetField<int32_t>(VT_B_ZP, 0); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField<int32_t>(verifier, VT_A_ZP) && + VerifyField<int32_t>(verifier, VT_B_ZP) && + verifier.EndTable(); + } +}; + +struct MatMulQuantInfoBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_a_zp(int32_t a_zp) { + fbb_.AddElement<int32_t>(MatMulQuantInfo::VT_A_ZP, a_zp, 0); + } + void add_b_zp(int32_t b_zp) { + fbb_.AddElement<int32_t>(MatMulQuantInfo::VT_B_ZP, b_zp, 0); + } + explicit MatMulQuantInfoBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + MatMulQuantInfoBuilder &operator=(const MatMulQuantInfoBuilder &); + flatbuffers::Offset<MatMulQuantInfo> Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset<MatMulQuantInfo>(end); + return o; + } +}; + +inline flatbuffers::Offset<MatMulQuantInfo> CreateMatMulQuantInfo( + flatbuffers::FlatBufferBuilder &_fbb, + int32_t a_zp = 0, + int32_t b_zp = 0) { + MatMulQuantInfoBuilder builder_(_fbb); + builder_.add_b_zp(b_zp); + builder_.add_a_zp(a_zp); + return builder_.Finish(); +} + +struct PadQuantInfo FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_INPUT_ZP = 4 + }; + int32_t input_zp() const { + return GetField<int32_t>(VT_INPUT_ZP, 0); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField<int32_t>(verifier, VT_INPUT_ZP) && + verifier.EndTable(); + } +}; + +struct PadQuantInfoBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_input_zp(int32_t input_zp) { + fbb_.AddElement<int32_t>(PadQuantInfo::VT_INPUT_ZP, input_zp, 0); + } + explicit PadQuantInfoBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + PadQuantInfoBuilder &operator=(const PadQuantInfoBuilder &); + flatbuffers::Offset<PadQuantInfo> Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset<PadQuantInfo>(end); + return o; + } +}; + +inline flatbuffers::Offset<PadQuantInfo> CreatePadQuantInfo( + flatbuffers::FlatBufferBuilder &_fbb, + int32_t input_zp = 0) { + PadQuantInfoBuilder builder_(_fbb); + builder_.add_input_zp(input_zp); + return builder_.Finish(); +} + +struct Version FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT__MAJOR = 4, + VT__MINOR = 6, + VT__PATCH = 8, + VT__EXPERIMENTAL = 10 + }; + int32_t _major() const { + return GetField<int32_t>(VT__MAJOR, 0); + } + int32_t _minor() const { + return GetField<int32_t>(VT__MINOR, 20); + } + int32_t _patch() const { + return GetField<int32_t>(VT__PATCH, 0); + } + bool _experimental() const { + return GetField<uint8_t>(VT__EXPERIMENTAL, 0) != 0; + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField<int32_t>(verifier, VT__MAJOR) && + VerifyField<int32_t>(verifier, VT__MINOR) && + VerifyField<int32_t>(verifier, VT__PATCH) && + VerifyField<uint8_t>(verifier, VT__EXPERIMENTAL) && + verifier.EndTable(); + } +}; + +struct VersionBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add__major(int32_t _major) { + fbb_.AddElement<int32_t>(Version::VT__MAJOR, _major, 0); + } + void add__minor(int32_t _minor) { + fbb_.AddElement<int32_t>(Version::VT__MINOR, _minor, 20); + } + void add__patch(int32_t _patch) { + fbb_.AddElement<int32_t>(Version::VT__PATCH, _patch, 0); + } + void add__experimental(bool _experimental) { + fbb_.AddElement<uint8_t>(Version::VT__EXPERIMENTAL, static_cast<uint8_t>(_experimental), 0); + } + explicit VersionBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + VersionBuilder &operator=(const VersionBuilder &); + flatbuffers::Offset<Version> Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset<Version>(end); + return o; + } +}; + +inline flatbuffers::Offset<Version> CreateVersion( + flatbuffers::FlatBufferBuilder &_fbb, + int32_t _major = 0, + int32_t _minor = 20, + int32_t _patch = 0, + bool _experimental = false) { + VersionBuilder builder_(_fbb); + builder_.add__patch(_patch); + builder_.add__minor(_minor); + builder_.add__major(_major); + builder_.add__experimental(_experimental); + return builder_.Finish(); +} + +struct TosaTensor FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_NAME = 4, + VT_SHAPE = 6, + VT_TYPE = 8, + VT_USAGE = 10, + VT_FORMAT = 12, + VT_NPY_FILENAME = 14 + }; + const flatbuffers::String *name() const { + return GetPointer<const flatbuffers::String *>(VT_NAME); + } + const flatbuffers::Vector<int32_t> *shape() const { + return GetPointer<const flatbuffers::Vector<int32_t> *>(VT_SHAPE); + } + DType type() const { + return static_cast<DType>(GetField<uint32_t>(VT_TYPE, 0)); + } + const flatbuffers::Vector<uint32_t> *usage() const { + return GetPointer<const flatbuffers::Vector<uint32_t> *>(VT_USAGE); + } + const flatbuffers::Vector<uint32_t> *format() const { + return GetPointer<const flatbuffers::Vector<uint32_t> *>(VT_FORMAT); + } + const flatbuffers::String *npy_filename() const { + return GetPointer<const flatbuffers::String *>(VT_NPY_FILENAME); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_NAME) && + verifier.VerifyString(name()) && + VerifyOffset(verifier, VT_SHAPE) && + verifier.VerifyVector(shape()) && + VerifyField<uint32_t>(verifier, VT_TYPE) && + VerifyOffset(verifier, VT_USAGE) && + verifier.VerifyVector(usage()) && + VerifyOffset(verifier, VT_FORMAT) && + verifier.VerifyVector(format()) && + VerifyOffset(verifier, VT_NPY_FILENAME) && + verifier.VerifyString(npy_filename()) && + verifier.EndTable(); + } +}; + +struct TosaTensorBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_name(flatbuffers::Offset<flatbuffers::String> name) { + fbb_.AddOffset(TosaTensor::VT_NAME, name); + } + void add_shape(flatbuffers::Offset<flatbuffers::Vector<int32_t>> shape) { + fbb_.AddOffset(TosaTensor::VT_SHAPE, shape); + } + void add_type(DType type) { + fbb_.AddElement<uint32_t>(TosaTensor::VT_TYPE, static_cast<uint32_t>(type), 0); + } + void add_usage(flatbuffers::Offset<flatbuffers::Vector<uint32_t>> usage) { + fbb_.AddOffset(TosaTensor::VT_USAGE, usage); + } + void add_format(flatbuffers::Offset<flatbuffers::Vector<uint32_t>> format) { + fbb_.AddOffset(TosaTensor::VT_FORMAT, format); + } + void add_npy_filename(flatbuffers::Offset<flatbuffers::String> npy_filename) { + fbb_.AddOffset(TosaTensor::VT_NPY_FILENAME, npy_filename); + } + explicit TosaTensorBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + TosaTensorBuilder &operator=(const TosaTensorBuilder &); + flatbuffers::Offset<TosaTensor> Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset<TosaTensor>(end); + return o; + } +}; + +inline flatbuffers::Offset<TosaTensor> CreateTosaTensor( + flatbuffers::FlatBufferBuilder &_fbb, + flatbuffers::Offset<flatbuffers::String> name = 0, + flatbuffers::Offset<flatbuffers::Vector<int32_t>> shape = 0, + DType type = DType_UNKNOWN, + flatbuffers::Offset<flatbuffers::Vector<uint32_t>> usage = 0, + flatbuffers::Offset<flatbuffers::Vector<uint32_t>> format = 0, + flatbuffers::Offset<flatbuffers::String> npy_filename = 0) { + TosaTensorBuilder builder_(_fbb); + builder_.add_npy_filename(npy_filename); + builder_.add_format(format); + builder_.add_usage(usage); + builder_.add_type(type); + builder_.add_shape(shape); + builder_.add_name(name); + return builder_.Finish(); +} + +inline flatbuffers::Offset<TosaTensor> CreateTosaTensorDirect( + flatbuffers::FlatBufferBuilder &_fbb, + const char *name = nullptr, + const std::vector<int32_t> *shape = nullptr, + DType type = DType_UNKNOWN, + const std::vector<uint32_t> *usage = nullptr, + const std::vector<uint32_t> *format = nullptr, + const char *npy_filename = nullptr) { + auto name__ = name ? _fbb.CreateString(name) : 0; + auto shape__ = shape ? _fbb.CreateVector<int32_t>(*shape) : 0; + auto usage__ = usage ? _fbb.CreateVector<uint32_t>(*usage) : 0; + auto format__ = format ? _fbb.CreateVector<uint32_t>(*format) : 0; + auto npy_filename__ = npy_filename ? _fbb.CreateString(npy_filename) : 0; + return tosa::CreateTosaTensor( + _fbb, + name__, + shape__, + type, + usage__, + format__, + npy_filename__); +} + +struct TosaOperator FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_OP = 4, + VT_ATTRIBUTE_TYPE = 6, + VT_ATTRIBUTE = 8, + VT_INPUTS = 10, + VT_OUTPUTS = 12, + VT_QUANT_INFO_TYPE = 14, + VT_QUANT_INFO = 16 + }; + Op op() const { + return static_cast<Op>(GetField<uint32_t>(VT_OP, 0)); + } + Attribute attribute_type() const { + return static_cast<Attribute>(GetField<uint8_t>(VT_ATTRIBUTE_TYPE, 0)); + } + const void *attribute() const { + return GetPointer<const void *>(VT_ATTRIBUTE); + } + template<typename T> const T *attribute_as() const; + const Pool2dAttribute *attribute_as_Pool2dAttribute() const { + return attribute_type() == Attribute_Pool2dAttribute ? static_cast<const Pool2dAttribute *>(attribute()) : nullptr; + } + const Conv2dAttribute *attribute_as_Conv2dAttribute() const { + return attribute_type() == Attribute_Conv2dAttribute ? static_cast<const Conv2dAttribute *>(attribute()) : nullptr; + } + const TransposeConv2dAttribute *attribute_as_TransposeConv2dAttribute() const { + return attribute_type() == Attribute_TransposeConv2dAttribute ? static_cast<const TransposeConv2dAttribute *>(attribute()) : nullptr; + } + const ReluNAttribute *attribute_as_ReluNAttribute() const { + return attribute_type() == Attribute_ReluNAttribute ? static_cast<const ReluNAttribute *>(attribute()) : nullptr; + } + const AxisAttribute *attribute_as_AxisAttribute() const { + return attribute_type() == Attribute_AxisAttribute ? static_cast<const AxisAttribute *>(attribute()) : nullptr; + } + const ReshapeAttribute *attribute_as_ReshapeAttribute() const { + return attribute_type() == Attribute_ReshapeAttribute ? static_cast<const ReshapeAttribute *>(attribute()) : nullptr; + } + const SliceAttribute *attribute_as_SliceAttribute() const { + return attribute_type() == Attribute_SliceAttribute ? static_cast<const SliceAttribute *>(attribute()) : nullptr; + } + const TileAttribute *attribute_as_TileAttribute() const { + return attribute_type() == Attribute_TileAttribute ? static_cast<const TileAttribute *>(attribute()) : nullptr; + } + const ResizeAttribute *attribute_as_ResizeAttribute() const { + return attribute_type() == Attribute_ResizeAttribute ? static_cast<const ResizeAttribute *>(attribute()) : nullptr; + } + const ClampAttribute *attribute_as_ClampAttribute() const { + return attribute_type() == Attribute_ClampAttribute ? static_cast<const ClampAttribute *>(attribute()) : nullptr; + } + const RescaleAttribute *attribute_as_RescaleAttribute() const { + return attribute_type() == Attribute_RescaleAttribute ? static_cast<const RescaleAttribute *>(attribute()) : nullptr; + } + const CustomAttribute *attribute_as_CustomAttribute() const { + return attribute_type() == Attribute_CustomAttribute ? static_cast<const CustomAttribute *>(attribute()) : nullptr; + } + const CondIfAttribute *attribute_as_CondIfAttribute() const { + return attribute_type() == Attribute_CondIfAttribute ? static_cast<const CondIfAttribute *>(attribute()) : nullptr; + } + const WhileLoopAttribute *attribute_as_WhileLoopAttribute() const { + return attribute_type() == Attribute_WhileLoopAttribute ? static_cast<const WhileLoopAttribute *>(attribute()) : nullptr; + } + const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *inputs() const { + return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *>(VT_INPUTS); + } + const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *outputs() const { + return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *>(VT_OUTPUTS); + } + QuantInfo quant_info_type() const { + return static_cast<QuantInfo>(GetField<uint8_t>(VT_QUANT_INFO_TYPE, 0)); + } + const void *quant_info() const { + return GetPointer<const void *>(VT_QUANT_INFO); + } + template<typename T> const T *quant_info_as() const; + const UnaryQuantInfo *quant_info_as_UnaryQuantInfo() const { + return quant_info_type() == QuantInfo_UnaryQuantInfo ? static_cast<const UnaryQuantInfo *>(quant_info()) : nullptr; + } + const ConvQuantInfo *quant_info_as_ConvQuantInfo() const { + return quant_info_type() == QuantInfo_ConvQuantInfo ? static_cast<const ConvQuantInfo *>(quant_info()) : nullptr; + } + const MatMulQuantInfo *quant_info_as_MatMulQuantInfo() const { + return quant_info_type() == QuantInfo_MatMulQuantInfo ? static_cast<const MatMulQuantInfo *>(quant_info()) : nullptr; + } + const PadQuantInfo *quant_info_as_PadQuantInfo() const { + return quant_info_type() == QuantInfo_PadQuantInfo ? static_cast<const PadQuantInfo *>(quant_info()) : nullptr; + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField<uint32_t>(verifier, VT_OP) && + VerifyField<uint8_t>(verifier, VT_ATTRIBUTE_TYPE) && + VerifyOffset(verifier, VT_ATTRIBUTE) && + VerifyAttribute(verifier, attribute(), attribute_type()) && + VerifyOffset(verifier, VT_INPUTS) && + verifier.VerifyVector(inputs()) && + verifier.VerifyVectorOfStrings(inputs()) && + VerifyOffset(verifier, VT_OUTPUTS) && + verifier.VerifyVector(outputs()) && + verifier.VerifyVectorOfStrings(outputs()) && + VerifyField<uint8_t>(verifier, VT_QUANT_INFO_TYPE) && + VerifyOffset(verifier, VT_QUANT_INFO) && + VerifyQuantInfo(verifier, quant_info(), quant_info_type()) && + verifier.EndTable(); + } +}; + +template<> inline const Pool2dAttribute *TosaOperator::attribute_as<Pool2dAttribute>() const { + return attribute_as_Pool2dAttribute(); +} + +template<> inline const Conv2dAttribute *TosaOperator::attribute_as<Conv2dAttribute>() const { + return attribute_as_Conv2dAttribute(); +} + +template<> inline const TransposeConv2dAttribute *TosaOperator::attribute_as<TransposeConv2dAttribute>() const { + return attribute_as_TransposeConv2dAttribute(); +} + +template<> inline const ReluNAttribute *TosaOperator::attribute_as<ReluNAttribute>() const { + return attribute_as_ReluNAttribute(); +} + +template<> inline const AxisAttribute *TosaOperator::attribute_as<AxisAttribute>() const { + return attribute_as_AxisAttribute(); +} + +template<> inline const ReshapeAttribute *TosaOperator::attribute_as<ReshapeAttribute>() const { + return attribute_as_ReshapeAttribute(); +} + +template<> inline const SliceAttribute *TosaOperator::attribute_as<SliceAttribute>() const { + return attribute_as_SliceAttribute(); +} + +template<> inline const TileAttribute *TosaOperator::attribute_as<TileAttribute>() const { + return attribute_as_TileAttribute(); +} + +template<> inline const ResizeAttribute *TosaOperator::attribute_as<ResizeAttribute>() const { + return attribute_as_ResizeAttribute(); +} + +template<> inline const ClampAttribute *TosaOperator::attribute_as<ClampAttribute>() const { + return attribute_as_ClampAttribute(); +} + +template<> inline const RescaleAttribute *TosaOperator::attribute_as<RescaleAttribute>() const { + return attribute_as_RescaleAttribute(); +} + +template<> inline const CustomAttribute *TosaOperator::attribute_as<CustomAttribute>() const { + return attribute_as_CustomAttribute(); +} + +template<> inline const CondIfAttribute *TosaOperator::attribute_as<CondIfAttribute>() const { + return attribute_as_CondIfAttribute(); +} + +template<> inline const WhileLoopAttribute *TosaOperator::attribute_as<WhileLoopAttribute>() const { + return attribute_as_WhileLoopAttribute(); +} + +template<> inline const UnaryQuantInfo *TosaOperator::quant_info_as<UnaryQuantInfo>() const { + return quant_info_as_UnaryQuantInfo(); +} + +template<> inline const ConvQuantInfo *TosaOperator::quant_info_as<ConvQuantInfo>() const { + return quant_info_as_ConvQuantInfo(); +} + +template<> inline const MatMulQuantInfo *TosaOperator::quant_info_as<MatMulQuantInfo>() const { + return quant_info_as_MatMulQuantInfo(); +} + +template<> inline const PadQuantInfo *TosaOperator::quant_info_as<PadQuantInfo>() const { + return quant_info_as_PadQuantInfo(); +} + +struct TosaOperatorBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_op(Op op) { + fbb_.AddElement<uint32_t>(TosaOperator::VT_OP, static_cast<uint32_t>(op), 0); + } + void add_attribute_type(Attribute attribute_type) { + fbb_.AddElement<uint8_t>(TosaOperator::VT_ATTRIBUTE_TYPE, static_cast<uint8_t>(attribute_type), 0); + } + void add_attribute(flatbuffers::Offset<void> attribute) { + fbb_.AddOffset(TosaOperator::VT_ATTRIBUTE, attribute); + } + void add_inputs(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> inputs) { + fbb_.AddOffset(TosaOperator::VT_INPUTS, inputs); + } + void add_outputs(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> outputs) { + fbb_.AddOffset(TosaOperator::VT_OUTPUTS, outputs); + } + void add_quant_info_type(QuantInfo quant_info_type) { + fbb_.AddElement<uint8_t>(TosaOperator::VT_QUANT_INFO_TYPE, static_cast<uint8_t>(quant_info_type), 0); + } + void add_quant_info(flatbuffers::Offset<void> quant_info) { + fbb_.AddOffset(TosaOperator::VT_QUANT_INFO, quant_info); + } + explicit TosaOperatorBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + TosaOperatorBuilder &operator=(const TosaOperatorBuilder &); + flatbuffers::Offset<TosaOperator> Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset<TosaOperator>(end); + return o; + } +}; + +inline flatbuffers::Offset<TosaOperator> CreateTosaOperator( + flatbuffers::FlatBufferBuilder &_fbb, + Op op = Op_UNKNOWN, + Attribute attribute_type = Attribute_NONE, + flatbuffers::Offset<void> attribute = 0, + flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> inputs = 0, + flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> outputs = 0, + QuantInfo quant_info_type = QuantInfo_NONE, + flatbuffers::Offset<void> quant_info = 0) { + TosaOperatorBuilder builder_(_fbb); + builder_.add_quant_info(quant_info); + builder_.add_outputs(outputs); + builder_.add_inputs(inputs); + builder_.add_attribute(attribute); + builder_.add_op(op); + builder_.add_quant_info_type(quant_info_type); + builder_.add_attribute_type(attribute_type); + return builder_.Finish(); +} + +inline flatbuffers::Offset<TosaOperator> CreateTosaOperatorDirect( + flatbuffers::FlatBufferBuilder &_fbb, + Op op = Op_UNKNOWN, + Attribute attribute_type = Attribute_NONE, + flatbuffers::Offset<void> attribute = 0, + const std::vector<flatbuffers::Offset<flatbuffers::String>> *inputs = nullptr, + const std::vector<flatbuffers::Offset<flatbuffers::String>> *outputs = nullptr, + QuantInfo quant_info_type = QuantInfo_NONE, + flatbuffers::Offset<void> quant_info = 0) { + auto inputs__ = inputs ? _fbb.CreateVector<flatbuffers::Offset<flatbuffers::String>>(*inputs) : 0; + auto outputs__ = outputs ? _fbb.CreateVector<flatbuffers::Offset<flatbuffers::String>>(*outputs) : 0; + return tosa::CreateTosaOperator( + _fbb, + op, + attribute_type, + attribute, + inputs__, + outputs__, + quant_info_type, + quant_info); +} + +struct TosaBasicBlock FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_NAME = 4, + VT_OPERATORS = 6, + VT_TENSORS = 8, + VT_INPUTS = 10, + VT_OUTPUTS = 12 + }; + const flatbuffers::String *name() const { + return GetPointer<const flatbuffers::String *>(VT_NAME); + } + const flatbuffers::Vector<flatbuffers::Offset<TosaOperator>> *operators() const { + return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<TosaOperator>> *>(VT_OPERATORS); + } + const flatbuffers::Vector<flatbuffers::Offset<TosaTensor>> *tensors() const { + return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<TosaTensor>> *>(VT_TENSORS); + } + const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *inputs() const { + return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *>(VT_INPUTS); + } + const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *outputs() const { + return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *>(VT_OUTPUTS); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_NAME) && + verifier.VerifyString(name()) && + VerifyOffset(verifier, VT_OPERATORS) && + verifier.VerifyVector(operators()) && + verifier.VerifyVectorOfTables(operators()) && + VerifyOffset(verifier, VT_TENSORS) && + verifier.VerifyVector(tensors()) && + verifier.VerifyVectorOfTables(tensors()) && + VerifyOffset(verifier, VT_INPUTS) && + verifier.VerifyVector(inputs()) && + verifier.VerifyVectorOfStrings(inputs()) && + VerifyOffset(verifier, VT_OUTPUTS) && + verifier.VerifyVector(outputs()) && + verifier.VerifyVectorOfStrings(outputs()) && + verifier.EndTable(); + } +}; + +struct TosaBasicBlockBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_name(flatbuffers::Offset<flatbuffers::String> name) { + fbb_.AddOffset(TosaBasicBlock::VT_NAME, name); + } + void add_operators(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<TosaOperator>>> operators) { + fbb_.AddOffset(TosaBasicBlock::VT_OPERATORS, operators); + } + void add_tensors(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<TosaTensor>>> tensors) { + fbb_.AddOffset(TosaBasicBlock::VT_TENSORS, tensors); + } + void add_inputs(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> inputs) { + fbb_.AddOffset(TosaBasicBlock::VT_INPUTS, inputs); + } + void add_outputs(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> outputs) { + fbb_.AddOffset(TosaBasicBlock::VT_OUTPUTS, outputs); + } + explicit TosaBasicBlockBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + TosaBasicBlockBuilder &operator=(const TosaBasicBlockBuilder &); + flatbuffers::Offset<TosaBasicBlock> Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset<TosaBasicBlock>(end); + return o; + } +}; + +inline flatbuffers::Offset<TosaBasicBlock> CreateTosaBasicBlock( + flatbuffers::FlatBufferBuilder &_fbb, + flatbuffers::Offset<flatbuffers::String> name = 0, + flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<TosaOperator>>> operators = 0, + flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<TosaTensor>>> tensors = 0, + flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> inputs = 0, + flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> outputs = 0) { + TosaBasicBlockBuilder builder_(_fbb); + builder_.add_outputs(outputs); + builder_.add_inputs(inputs); + builder_.add_tensors(tensors); + builder_.add_operators(operators); + builder_.add_name(name); + return builder_.Finish(); +} + +inline flatbuffers::Offset<TosaBasicBlock> CreateTosaBasicBlockDirect( + flatbuffers::FlatBufferBuilder &_fbb, + const char *name = nullptr, + const std::vector<flatbuffers::Offset<TosaOperator>> *operators = nullptr, + const std::vector<flatbuffers::Offset<TosaTensor>> *tensors = nullptr, + const std::vector<flatbuffers::Offset<flatbuffers::String>> *inputs = nullptr, + const std::vector<flatbuffers::Offset<flatbuffers::String>> *outputs = nullptr) { + auto name__ = name ? _fbb.CreateString(name) : 0; + auto operators__ = operators ? _fbb.CreateVector<flatbuffers::Offset<TosaOperator>>(*operators) : 0; + auto tensors__ = tensors ? _fbb.CreateVector<flatbuffers::Offset<TosaTensor>>(*tensors) : 0; + auto inputs__ = inputs ? _fbb.CreateVector<flatbuffers::Offset<flatbuffers::String>>(*inputs) : 0; + auto outputs__ = outputs ? _fbb.CreateVector<flatbuffers::Offset<flatbuffers::String>>(*outputs) : 0; + return tosa::CreateTosaBasicBlock( + _fbb, + name__, + operators__, + tensors__, + inputs__, + outputs__); +} + +struct TosaGraph FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_VERSION = 4, + VT_BLOCKS = 6 + }; + const Version *version() const { + return GetPointer<const Version *>(VT_VERSION); + } + const flatbuffers::Vector<flatbuffers::Offset<TosaBasicBlock>> *blocks() const { + return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<TosaBasicBlock>> *>(VT_BLOCKS); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_VERSION) && + verifier.VerifyTable(version()) && + VerifyOffset(verifier, VT_BLOCKS) && + verifier.VerifyVector(blocks()) && + verifier.VerifyVectorOfTables(blocks()) && + verifier.EndTable(); + } +}; + +struct TosaGraphBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_version(flatbuffers::Offset<Version> version) { + fbb_.AddOffset(TosaGraph::VT_VERSION, version); + } + void add_blocks(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<TosaBasicBlock>>> blocks) { + fbb_.AddOffset(TosaGraph::VT_BLOCKS, blocks); + } + explicit TosaGraphBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + TosaGraphBuilder &operator=(const TosaGraphBuilder &); + flatbuffers::Offset<TosaGraph> Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset<TosaGraph>(end); + return o; + } +}; + +inline flatbuffers::Offset<TosaGraph> CreateTosaGraph( + flatbuffers::FlatBufferBuilder &_fbb, + flatbuffers::Offset<Version> version = 0, + flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<TosaBasicBlock>>> blocks = 0) { + TosaGraphBuilder builder_(_fbb); + builder_.add_blocks(blocks); + builder_.add_version(version); + return builder_.Finish(); +} + +inline flatbuffers::Offset<TosaGraph> CreateTosaGraphDirect( + flatbuffers::FlatBufferBuilder &_fbb, + flatbuffers::Offset<Version> version = 0, + const std::vector<flatbuffers::Offset<TosaBasicBlock>> *blocks = nullptr) { + auto blocks__ = blocks ? _fbb.CreateVector<flatbuffers::Offset<TosaBasicBlock>>(*blocks) : 0; + return tosa::CreateTosaGraph( + _fbb, + version, + blocks__); +} + +inline bool VerifyAttribute(flatbuffers::Verifier &verifier, const void *obj, Attribute type) { + switch (type) { + case Attribute_NONE: { + return true; + } + case Attribute_Pool2dAttribute: { + auto ptr = reinterpret_cast<const Pool2dAttribute *>(obj); + return verifier.VerifyTable(ptr); + } + case Attribute_Conv2dAttribute: { + auto ptr = reinterpret_cast<const Conv2dAttribute *>(obj); + return verifier.VerifyTable(ptr); + } + case Attribute_TransposeConv2dAttribute: { + auto ptr = reinterpret_cast<const TransposeConv2dAttribute *>(obj); + return verifier.VerifyTable(ptr); + } + case Attribute_ReluNAttribute: { + auto ptr = reinterpret_cast<const ReluNAttribute *>(obj); + return verifier.VerifyTable(ptr); + } + case Attribute_AxisAttribute: { + auto ptr = reinterpret_cast<const AxisAttribute *>(obj); + return verifier.VerifyTable(ptr); + } + case Attribute_ReshapeAttribute: { + auto ptr = reinterpret_cast<const ReshapeAttribute *>(obj); + return verifier.VerifyTable(ptr); + } + case Attribute_SliceAttribute: { + auto ptr = reinterpret_cast<const SliceAttribute *>(obj); + return verifier.VerifyTable(ptr); + } + case Attribute_TileAttribute: { + auto ptr = reinterpret_cast<const TileAttribute *>(obj); + return verifier.VerifyTable(ptr); + } + case Attribute_ResizeAttribute: { + auto ptr = reinterpret_cast<const ResizeAttribute *>(obj); + return verifier.VerifyTable(ptr); + } + case Attribute_ClampAttribute: { + auto ptr = reinterpret_cast<const ClampAttribute *>(obj); + return verifier.VerifyTable(ptr); + } + case Attribute_RescaleAttribute: { + auto ptr = reinterpret_cast<const RescaleAttribute *>(obj); + return verifier.VerifyTable(ptr); + } + case Attribute_CustomAttribute: { + auto ptr = reinterpret_cast<const CustomAttribute *>(obj); + return verifier.VerifyTable(ptr); + } + case Attribute_CondIfAttribute: { + auto ptr = reinterpret_cast<const CondIfAttribute *>(obj); + return verifier.VerifyTable(ptr); + } + case Attribute_WhileLoopAttribute: { + auto ptr = reinterpret_cast<const WhileLoopAttribute *>(obj); + return verifier.VerifyTable(ptr); + } + default: return false; + } +} + +inline bool VerifyAttributeVector(flatbuffers::Verifier &verifier, const flatbuffers::Vector<flatbuffers::Offset<void>> *values, const flatbuffers::Vector<uint8_t> *types) { + if (!values || !types) return !values && !types; + if (values->size() != types->size()) return false; + for (flatbuffers::uoffset_t i = 0; i < values->size(); ++i) { + if (!VerifyAttribute( + verifier, values->Get(i), types->GetEnum<Attribute>(i))) { + return false; + } + } + return true; +} + +inline bool VerifyQuantInfo(flatbuffers::Verifier &verifier, const void *obj, QuantInfo type) { + switch (type) { + case QuantInfo_NONE: { + return true; + } + case QuantInfo_UnaryQuantInfo: { + auto ptr = reinterpret_cast<const UnaryQuantInfo *>(obj); + return verifier.VerifyTable(ptr); + } + case QuantInfo_ConvQuantInfo: { + auto ptr = reinterpret_cast<const ConvQuantInfo *>(obj); + return verifier.VerifyTable(ptr); + } + case QuantInfo_MatMulQuantInfo: { + auto ptr = reinterpret_cast<const MatMulQuantInfo *>(obj); + return verifier.VerifyTable(ptr); + } + case QuantInfo_PadQuantInfo: { + auto ptr = reinterpret_cast<const PadQuantInfo *>(obj); + return verifier.VerifyTable(ptr); + } + default: return false; + } +} + +inline bool VerifyQuantInfoVector(flatbuffers::Verifier &verifier, const flatbuffers::Vector<flatbuffers::Offset<void>> *values, const flatbuffers::Vector<uint8_t> *types) { + if (!values || !types) return !values && !types; + if (values->size() != types->size()) return false; + for (flatbuffers::uoffset_t i = 0; i < values->size(); ++i) { + if (!VerifyQuantInfo( + verifier, values->Get(i), types->GetEnum<QuantInfo>(i))) { + return false; + } + } + return true; +} + +inline const tosa::TosaGraph *GetTosaGraph(const void *buf) { + return flatbuffers::GetRoot<tosa::TosaGraph>(buf); +} + +inline const tosa::TosaGraph *GetSizePrefixedTosaGraph(const void *buf) { + return flatbuffers::GetSizePrefixedRoot<tosa::TosaGraph>(buf); +} + +inline const char *TosaGraphIdentifier() { + return "TOSA"; +} + +inline bool TosaGraphBufferHasIdentifier(const void *buf) { + return flatbuffers::BufferHasIdentifier( + buf, TosaGraphIdentifier()); +} + +inline bool VerifyTosaGraphBuffer( + flatbuffers::Verifier &verifier) { + return verifier.VerifyBuffer<tosa::TosaGraph>(TosaGraphIdentifier()); +} + +inline bool VerifySizePrefixedTosaGraphBuffer( + flatbuffers::Verifier &verifier) { + return verifier.VerifySizePrefixedBuffer<tosa::TosaGraph>(TosaGraphIdentifier()); +} + +inline const char *TosaGraphExtension() { + return "tosa"; +} + +inline void FinishTosaGraphBuffer( + flatbuffers::FlatBufferBuilder &fbb, + flatbuffers::Offset<tosa::TosaGraph> root) { + fbb.Finish(root, TosaGraphIdentifier()); +} + +inline void FinishSizePrefixedTosaGraphBuffer( + flatbuffers::FlatBufferBuilder &fbb, + flatbuffers::Offset<tosa::TosaGraph> root) { + fbb.FinishSizePrefixed(root, TosaGraphIdentifier()); +} + +} // namespace tosa + +#endif // FLATBUFFERS_GENERATED_TOSA_TOSA_H_ diff --git a/serialization/tosa_serialization_handler.cpp b/serialization/tosa_serialization_handler.cpp new file mode 100644 index 0000000..7fe9f47 --- /dev/null +++ b/serialization/tosa_serialization_handler.cpp @@ -0,0 +1,1526 @@ + +// Copyright (c) 2020, ARM Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tosa_serialization_handler.h" + +#include <iostream> +using namespace tosa; + +TosaSerializationTensor::TosaSerializationTensor(const flatbuffers::String* name, + const flatbuffers::Vector<uint32_t>& usage, + const flatbuffers::Vector<int32_t>& shape, + DType dtype, + const flatbuffers::Vector<uint32_t>& format, + const flatbuffers::String* npy_filename) +{ + _dtype = dtype; + + _usage = new std::vector<Usage>(usage.size()); + for (uint32_t us : usage) + { + _usage->push_back((Usage)us); + } + assert(_usage); + + _format = new std::vector<Format>(format.size()); + for (uint32_t fm : format) + { + _format->push_back((Format)fm); + } + assert(_format); + + _shape = new std::vector<int32_t>(shape.begin(), shape.end()); + + _shape = new std::vector<int32_t>(shape.begin(), shape.end()); + assert(_shape); + + assert(name); + _name = new std::string(name->str()); + assert(_name); + + if (npy_filename) + { + _npy_filename = new std::string(npy_filename->str()); + assert(_npy_filename); + } + else + { + _npy_filename = nullptr; + } +} + +TosaSerializationTensor::TosaSerializationTensor(std::string name, + const std::vector<Usage>& usage, + const std::vector<int32_t>& shape, + DType dtype, + const std::vector<Format>& format, + const std::string* npy_filename) +{ + + _dtype = dtype; + + _usage = new std::vector<Usage>(usage); + assert(_usage); + + _format = new std::vector<Format>(format); + assert(_format); + + _shape = new std::vector<int32_t>(shape); + assert(_shape); + + _name = new std::string(name); + assert(_name); + + if (npy_filename) + { + _npy_filename = new std::string(*npy_filename); + assert(_npy_filename); + } + else + { + _npy_filename = nullptr; + } +} + +TosaSerializationTensor::TosaSerializationTensor() +{ + _dtype = DType_UNKNOWN; + + _usage = new std::vector<Usage>(); + _format = new std::vector<Format>(); + _shape = new std::vector<int32_t>(); + _name = new std::string("UNKNOWN"); + assert(_usage && _format && _shape && _name); + + _npy_filename = nullptr; +} + +TosaSerializationTensor::TosaSerializationTensor(const TosaSerializationTensor& rhs) +{ + _dtype = rhs._dtype; + + assert(rhs._usage); + _usage = new std::vector<Usage>(*rhs._usage); + assert(_usage); + + assert(rhs._format); + _format = new std::vector<Format>(*rhs._format); + assert(_format); + + assert(rhs._shape); + _shape = new std::vector<int32_t>(*rhs._shape); + assert(_shape); + + assert(rhs._name); + _name = new std::string(*rhs._name); + assert(_name); + + if (rhs._npy_filename) + { + _npy_filename = new std::string(*rhs._npy_filename); + assert(_npy_filename); + } + else + { + _npy_filename = nullptr; + } +} + +TosaSerializationTensor& TosaSerializationTensor::operator=(const TosaSerializationTensor& rhs) +{ + _dtype = rhs._dtype; + + delete _usage; + assert(rhs._usage); + _usage = new std::vector<Usage>(*rhs._usage); + assert(_usage); + + delete _format; + assert(rhs._format); + _format = new std::vector<Format>(*rhs._format); + assert(_format); + + delete _shape; + assert(rhs._shape); + _shape = new std::vector<int32_t>(*rhs._shape); + assert(_shape); + + delete _name; + assert(rhs._name); + _name = new std::string(*rhs._name); + assert(_name); + + if (_npy_filename) + delete _npy_filename; + + if (rhs._npy_filename) + { + _npy_filename = new std::string(*rhs._npy_filename); + } + else + { + _npy_filename = nullptr; + } + return *this; +} + +TosaSerializationTensor::TosaSerializationTensor(TosaSerializationTensor&& rhs) +{ + _dtype = rhs._dtype; + std::swap(_format, rhs._format); + std::swap(_usage, rhs._usage); + std::swap(_shape, rhs._shape); + std::swap(_name, rhs._name); + std::swap(_npy_filename, rhs._npy_filename); +} + +TosaSerializationTensor& TosaSerializationTensor::operator=(TosaSerializationTensor&& rhs) +{ + _dtype = rhs._dtype; + std::swap(_format, rhs._format); + std::swap(_usage, rhs._usage); + std::swap(_shape, rhs._shape); + std::swap(_name, rhs._name); + std::swap(_npy_filename, rhs._npy_filename); + return *this; +} + +TosaSerializationTensor::~TosaSerializationTensor() +{ + delete _usage; + delete _format; + delete _shape; + delete _name; + if (_npy_filename) + delete _npy_filename; +} + +TosaSerializationOperator::TosaSerializationOperator(Op op, + Attribute attribute_type, + const TosaAttributeBase* attribute, + QuantInfo qinfo_type, + const TosaQuantInfoBase* qinfo, + std::vector<std::string> input_tensor_names, + std::vector<std::string> output_tensor_names) +{ + _op = op; + _attribute_type = attribute_type; + + switch (attribute_type) + { + case Attribute_NONE: + _attribute = new TosaNoneAttribute(); + break; +#define DEF_ATTRIBUTE(NAME, ...) \ + case Attribute_##NAME##Attribute: \ + _attribute = new Tosa##NAME##Attribute(attribute); \ + break; +#include "attribute.def" +#undef DEF_ATTRIBUTE + default: + printf("TosaSerializationOperator::TosaSerializationOperator(): Attribute %s not implemented yet\n", + EnumNamesAttribute()[attribute_type]); + assert(0); + } + + _qinfo_type = qinfo_type; + switch (qinfo_type) + { + case QuantInfo_NONE: + _qinfo = new TosaNoneQuantInfo(); + break; +#define DEF_QUANTIZATION_INFO(NAME, ...) \ + case QuantInfo_##NAME##QuantInfo: \ + _qinfo = new Tosa##NAME##QuantInfo(qinfo); \ + break; +#include "quant_info.def" +#undef DEF_QUANTIZATION_INFO + default: + printf("TosaSerializationOperator::TosaSerializationOperator(): QuantInfo %s not implemented yet\n", + EnumNamesQuantInfo()[qinfo_type]); + assert(0); + } + + assert(_attribute && _qinfo); + + _input_tensor_names = new std::vector<std::string>(input_tensor_names); + _output_tensor_names = new std::vector<std::string>(output_tensor_names); + + assert(_input_tensor_names && _output_tensor_names); + + _input_tensors = new std::vector<TosaSerializationTensor*>(); + _output_tensors = new std::vector<TosaSerializationTensor*>(); + + assert(_input_tensors && _output_tensors); +} + +TosaSerializationOperator::~TosaSerializationOperator() +{ + delete _attribute; + delete _qinfo; + delete _input_tensor_names; + delete _output_tensor_names; + // TosaSerializationTensor should be free'd in TosaSerializationSerializationHandler destructor + delete _input_tensors; + delete _output_tensors; +} + +TosaSerializationBasicBlock::TosaSerializationBasicBlock(std::string name, + std::vector<TosaSerializationOperator*> operators, + std::vector<TosaSerializationTensor*> tensors, + std::vector<std::string> inputs, + std::vector<std::string> outputs) +{ + + _name = new std::string(name); + assert(_name); + + _operators = new std::vector<TosaSerializationOperator*>(operators); + assert(_operators); + + _tensors = new std::vector<TosaSerializationTensor*>(tensors); + assert(_tensors); + + _inputs = new std::vector<std::string>(inputs); + assert(_inputs); + + _outputs = new std::vector<std::string>(outputs); + assert(_outputs); +} + +TosaSerializationBasicBlock::~TosaSerializationBasicBlock() +{ + delete _name; + + // deallocate all operators + for (auto op : GetOperators()) + { + delete op; // ~TosaSerializationOperator() + } + delete _operators; + + // deallocate all tensors + for (auto ts : GetTensors()) + { + delete ts; // ~TosaSerializationTensor() + } + _tensors->clear(); + + delete _inputs; + delete _outputs; +} + +TosaSerializationHandler::TosaSerializationHandler() +{ + _schemaLoaded = false; + _builder = new flatbuffers::FlatBufferBuilder(); + _parser = new flatbuffers::Parser(); + _blocks = new std::vector<TosaSerializationBasicBlock*>(); + + assert(_builder && _parser && _blocks); + + SetTosaVersion(); +} + +TosaSerializationHandler::~TosaSerializationHandler() +{ + if (_version) + delete _version; + delete _builder; + delete _parser; + + Clear(); // deallocate all basic blocks + + delete _blocks; +} + +tosa_err_t TosaSerializationHandler::SetTosaVersion() +{ + // version is specified within .fbs + // and it's encoded as defaulted value of CreateTosaVersion() + // need to write out one object to read that value out + // TODO: very costly now. is there any better way to encode constant in .fbs? + auto fboffset_version = CreateVersion(*_builder); + auto fboffset_tosa_graph = CreateTosaGraphDirect(*_builder, fboffset_version, nullptr); + _builder->Finish(fboffset_tosa_graph); + std::string jsongen; + uint8_t* buf = _builder->GetBufferPointer(); + auto fb_tosa_graph = GetTosaGraph(buf); + auto fb_tosa_version = fb_tosa_graph->version(); + + _version = new TosaVersion(fb_tosa_version->_major(), fb_tosa_version->_minor(), fb_tosa_version->_patch(), + fb_tosa_version->_experimental()); + + assert(_version); + return TOSA_OK; +} + +tosa_err_t TosaSerializationHandler::LoadFileSchema(const char* schema_filename) +{ + std::string schema; + bool ok; + + ok = flatbuffers::LoadFile(schema_filename, false, &schema); + if (!ok) + { + printf("Error loading schema file: %s\n", schema_filename); + return TOSA_FILE_ERROR; + } + + ok = _parser->Parse(schema.c_str()); + if (!ok) + { + printf("Error parsing ISA schema file: %s\n", schema_filename); + return TOSA_FILE_ERROR; + } + _schemaLoaded = true; + + return TOSA_OK; +} + +tosa_err_t TosaSerializationHandler::LoadFileJson(const char* filename) +{ + std::string jsonfile; + bool ok; + tosa_err_t err; + + if (!_schemaLoaded) + { + return TOSA_SCHEMA_MISSING; + } + + ok = flatbuffers::LoadFile(filename, false, &jsonfile); + if (!ok) + { + printf("Error loading json file: %s\n", filename); + return TOSA_FILE_ERROR; + } + + ok = _parser->Parse(jsonfile.c_str()); + if (!ok) + { + printf("Error parsing json file: %s\n", filename); + return TOSA_FILE_ERROR; + } + + uint8_t* buf = _parser->builder_.GetBufferPointer(); + + err = InitWithBuf(buf); + if (err != TOSA_OK) + { + return err; + } + + return TOSA_OK; +} + +tosa_err_t TosaSerializationHandler::SaveFileJson(const char* filename) +{ + std::string jsongen; + tosa_err_t err; + + if (!_schemaLoaded) + { + return TOSA_SCHEMA_MISSING; + } + + err = FreezeBuilder(); + if (err != TOSA_OK) + { + return err; + } + + uint8_t* buf = _builder->GetBufferPointer(); + + if (!GenerateText(*_parser, buf, &jsongen)) + { + printf("Couldn't serialize parsed data to JSON!\n"); + return TOSA_FILE_ERROR; + } + + FILE* file = fopen(filename, "wb"); + + if (!file) + { + printf("Couldn't open output file: %s\n", filename); + return TOSA_FILE_ERROR; + } + + if (fwrite(jsongen.c_str(), sizeof(char), jsongen.size(), file) != jsongen.size()) + { + printf("Error writing to json output file: %s\n", filename); + fclose(file); + return TOSA_FILE_ERROR; + } + + if (file) + fclose(file); + + return TOSA_OK; +} + +tosa_err_t TosaSerializationHandler::LoadFileTosaFlatbuffer(const char* filename) +{ + std::string read_buffer; + tosa_err_t err; + uint8_t* buf; + bool ok; + + ok = flatbuffers::LoadFile(filename, false, &read_buffer); + if (!ok) + { + printf("Error loading flatbuffer file: %s\n", filename); + return TOSA_FILE_ERROR; + } + + buf = (uint8_t*)read_buffer.data(); + + err = InitWithBuf(buf); + if (err != TOSA_OK) + { + return err; + } + + return TOSA_OK; +} + +tosa_err_t TosaSerializationHandler::SaveFileTosaFlatbuffer(const char* filename) +{ + tosa_err_t err; + + err = FreezeBuilder(); + if (err != TOSA_OK) + { + return err; + } + + uint8_t* buf = _builder->GetBufferPointer(); + + bool ok = flatbuffers::SaveFile(filename, (const char*)buf, _builder->GetSize(), false); + if (!ok) + { + printf("Error saving floatbuffer file: %s\n", filename); + return TOSA_FILE_ERROR; + } + + return TOSA_OK; +} + +tosa_err_t TosaSerializationHandler::Clear() +{ + // deallocate all basic blocks + for (auto bb : GetBlocks()) + { + delete bb; + } + _blocks->clear(); + + return TOSA_OK; +} + +tosa_err_t TosaSerializationHandler::CheckTosaVersion(const TosaVersion& read_version) +{ + if ((*_version) != read_version) + { + printf("WARNING: read tosa version: %s != schema tosa version %s\n", read_version.to_string().c_str(), + this->_version->to_string().c_str()); + return TOSA_VERSION_MISMATCH; + } + + return TOSA_OK; +} + +tosa_err_t TosaSerializationHandler::InitWithBuf(const uint8_t* buf) +{ + auto fb_tosa_graph = GetTosaGraph(buf); + auto fb_tosa_version = fb_tosa_graph->version(); + auto fb_tosa_blocks = fb_tosa_graph->blocks(); + + std::vector<std::string> operator_inputs_container; + std::vector<std::string> operator_outputs_container; + + std::vector<TosaSerializationOperator*> block_operators_container; + std::vector<TosaSerializationTensor*> block_tensors_container; + std::vector<std::string> block_inputs_container; + std::vector<std::string> block_outputs_container; + + TosaAttributeBase* typed_attribute = NULL; + TosaQuantInfoBase* typed_qinfo = NULL; + TosaSerializationOperator* new_operator = NULL; + TosaSerializationBasicBlock* new_block = NULL; + TosaSerializationTensor* new_tensor = NULL; + + // erase container + Clear(); + + TosaVersion read_version(fb_tosa_version->_major(), fb_tosa_version->_minor(), fb_tosa_version->_patch(), + fb_tosa_version->_experimental()); + tosa_err_t err = CheckTosaVersion(read_version); + + if (err != TOSA_OK) + return err; + + for (size_t i = 0; i < fb_tosa_blocks->size(); i++) + { + auto curr_block = fb_tosa_blocks->Get(i); + + auto block_name = curr_block->name()->str(); + + auto fb_tosa_operators = curr_block->operators(); + block_operators_container.clear(); + for (size_t j = 0; j < fb_tosa_operators->size(); j++) + { + auto curr_operator = fb_tosa_operators->Get(j); + + auto operator_op = curr_operator->op(); + auto attribute_type = curr_operator->attribute_type(); + auto attribute = curr_operator->attribute(); + auto operator_qinfo_type = curr_operator->quant_info_type(); + auto operator_qinfo = curr_operator->quant_info(); + + // input tensors + auto operator_inputs = curr_operator->inputs(); + operator_inputs_container.clear(); + if (operator_inputs) + { + for (size_t k = 0; k < operator_inputs->size(); k++) + { + auto curr_input = operator_inputs->Get(k); + operator_inputs_container.push_back(curr_input->str()); + } + } + + // output tensors + auto operator_outputs = curr_operator->outputs(); + operator_outputs_container.clear(); + if (operator_outputs) + { + for (size_t k = 0; k < operator_outputs->size(); k++) + { + auto curr_output = operator_outputs->Get(k); + operator_outputs_container.push_back(curr_output->str()); + } + } + + switch (attribute_type) + { + case Attribute_NONE: + typed_attribute = new TosaNoneAttribute(); + break; +#define DEF_ATTRIBUTE(NAME, ...) \ + case Attribute_##NAME##Attribute: \ + typed_attribute = new Tosa##NAME##Attribute(attribute); \ + break; +#include "attribute.def" +#undef DEF_ATTRIBUTE + default: + printf("TosaSerializationHandler::InitWithBuf(): Attribute %s not implemented yet\n", + EnumNamesAttribute()[attribute_type]); + return TOSA_INTERNAL_ERROR; + } + + switch (operator_qinfo_type) + { + case QuantInfo_NONE: + typed_qinfo = new TosaNoneQuantInfo(); + break; +#define DEF_QUANTIZATION_INFO(NAME, ...) \ + case QuantInfo_##NAME##QuantInfo: \ + typed_qinfo = new Tosa##NAME##QuantInfo(operator_qinfo); \ + break; + +#include "quant_info.def" +#undef DEF_QUANTIZATION_INFO + default: + printf("TosaSerializationHandler::InitWithBuf(): QuantInfo %s not implemented yet\n", + EnumNamesQuantInfo()[operator_qinfo_type]); + return TOSA_INTERNAL_ERROR; + } + + new_operator = + new TosaSerializationOperator(operator_op, attribute_type, typed_attribute, operator_qinfo_type, + typed_qinfo, operator_inputs_container, operator_outputs_container); + if (new_operator) + { + block_operators_container.push_back(new_operator); + } + else + { + return TOSA_MEMORY_ERROR; + } + + if (typed_attribute) + delete typed_attribute; + if (typed_qinfo) + delete typed_qinfo; + } + + auto fb_tosa_tensors = curr_block->tensors(); + block_tensors_container.clear(); + for (size_t j = 0; j < fb_tosa_tensors->size(); j++) + { + auto curr_tensor = fb_tosa_tensors->Get(j); + + auto tensor_name = curr_tensor->name(); + auto tensor_usage = curr_tensor->usage(); + auto tensor_shape = curr_tensor->shape(); + auto tensor_type = curr_tensor->type(); + auto tensor_format = curr_tensor->format(); + auto tensor_npy_filename = curr_tensor->npy_filename(); + + new_tensor = new TosaSerializationTensor(tensor_name, *tensor_usage, *tensor_shape, tensor_type, + *tensor_format, tensor_npy_filename); + if (new_tensor) + { + block_tensors_container.push_back(new_tensor); + } + else + { + return TOSA_MEMORY_ERROR; + } + } + + auto block_inputs = curr_block->inputs(); + auto block_outputs = curr_block->outputs(); + + block_inputs_container.clear(); + block_outputs_container.clear(); + + for (size_t j = 0; j < block_inputs->size(); j++) + { + auto curr_block_input = block_inputs->Get(j); + block_inputs_container.push_back(curr_block_input->str()); + } + for (size_t j = 0; j < block_outputs->size(); j++) + { + auto curr_block_output = block_outputs->Get(j); + block_outputs_container.push_back(curr_block_output->str()); + } + + new_block = new TosaSerializationBasicBlock(block_name, block_operators_container, block_tensors_container, + block_inputs_container, block_outputs_container); + if (new_block) + { + this->GetBlocks().push_back(new_block); + } + else + { + return TOSA_MEMORY_ERROR; + } + } + + return TOSA_OK; +} + +tosa_err_t TosaSerializationHandler::FreezeBuilder() +{ + std::vector<flatbuffers::Offset<TosaBasicBlock>> fboffset_blocks; + + std::vector<flatbuffers::Offset<TosaOperator>> fboffset_block_operators; + std::vector<flatbuffers::Offset<TosaTensor>> fboffset_block_tensors; + std::vector<flatbuffers::Offset<flatbuffers::String>> fboffset_block_inputs; + std::vector<flatbuffers::Offset<flatbuffers::String>> fboffset_block_outputs; + + std::vector<flatbuffers::Offset<flatbuffers::String>> fboffset_operator_inputs; + std::vector<flatbuffers::Offset<flatbuffers::String>> fboffset_operator_outputs; + + // translate TosaFlatbufferOperator to flatbuffers::Offset<TosaOperator> + for (auto block : GetBlocks()) + { + fboffset_block_operators.clear(); + fboffset_block_tensors.clear(); + fboffset_block_inputs.clear(); + fboffset_block_outputs.clear(); + + auto block_name = _builder->CreateString(block->GetName().c_str()); + + for (auto tensor_str : block->GetInputs()) + { + auto tensor_name = _builder->CreateString(tensor_str.c_str()); + fboffset_block_inputs.push_back(tensor_name); + } + + for (auto tensor_str : block->GetOutputs()) + { + auto tensor_name = _builder->CreateString(tensor_str.c_str()); + fboffset_block_outputs.push_back(tensor_name); + } + + auto fb_block_inputs = _builder->CreateVector(fboffset_block_inputs); + auto fb_block_outputs = _builder->CreateVector(fboffset_block_outputs); + + for (auto op : block->GetOperators()) + { + fboffset_operator_inputs.clear(); + fboffset_operator_outputs.clear(); + + auto operator_op = op->GetOp(); + auto attribute_type = op->GetAttributeType(); + + for (auto tensor_str : op->GetInputTensorNames()) + { + auto tensor_name = _builder->CreateString(tensor_str.c_str()); + fboffset_operator_inputs.push_back(tensor_name); + } + + for (auto tensor_str : op->GetOutputTensorNames()) + { + auto tensor_name = _builder->CreateString(tensor_str.c_str()); + fboffset_operator_outputs.push_back(tensor_name); + } + + auto fb_operator_inputs = _builder->CreateVector(fboffset_operator_inputs); + auto fb_operator_outputs = _builder->CreateVector(fboffset_operator_outputs); + + flatbuffers::Offset<void> fb_attribute; + switch (attribute_type) + { + case Attribute_NONE: + fb_attribute = 0; + break; + +#define DEF_ARGS_S_STR(NAME, V) , _builder->CreateString(reinterpret_cast<Tosa##NAME*>(op->GetAttribute())->V().c_str()) +#define DEF_ARGS_S_DEFAULT(NAME, V) , reinterpret_cast<Tosa##NAME*>(op->GetAttribute())->V() + +#define DEF_ARGS_S_int32_t(NAME, V) DEF_ARGS_S_DEFAULT(NAME, V) +#define DEF_ARGS_S_float(NAME, V) DEF_ARGS_S_DEFAULT(NAME, V) +#define DEF_ARGS_S_bool(NAME, V) DEF_ARGS_S_DEFAULT(NAME, V) +#define DEF_ARGS_S_ResizeMode(NAME, V) DEF_ARGS_S_DEFAULT(NAME, V) +#define DEF_ARGS_S_string(NAME, V) DEF_ARGS_S_STR(NAME, V) + +#define DEF_ARGS_S(NAME, T, V) DEF_ARGS_S_##T(NAME, V) +#define DEF_ARGS_V(NAME, T, V) , _builder->CreateVector<T>(reinterpret_cast<Tosa##NAME*>(op->GetAttribute())->V()) + +#define DEF_ARGS_1(NAME, T0, F0, V0) DEF_ARGS_##F0(NAME, T0, V0) +#define DEF_ARGS_2(NAME, T0, F0, V0, T1, F1, V1) DEF_ARGS_##F0(NAME, T0, V0) DEF_ARGS_##F1(NAME, T1, V1) +#define DEF_ARGS_3(NAME, T0, F0, V0, T1, F1, V1, T2, F2, V2) \ + DEF_ARGS_##F0(NAME, T0, V0) DEF_ARGS_##F1(NAME, T1, V1) DEF_ARGS_##F2(NAME, T2, V2) +#define DEF_ARGS_4(NAME, T0, F0, V0, T1, F1, V1, T2, F2, V2, T3, F3, V3) \ + DEF_ARGS_##F0(NAME, T0, V0) DEF_ARGS_##F1(NAME, T1, V1) DEF_ARGS_##F2(NAME, T2, V2) DEF_ARGS_##F3(NAME, T3, V3) +#define DEF_ARGS_5(NAME, T0, F0, V0, T1, F1, V1, T2, F2, V2, T3, F3, V3, T4, F4, V4) \ + DEF_ARGS_##F0(NAME, T0, V0) DEF_ARGS_##F1(NAME, T1, V1) DEF_ARGS_##F2(NAME, T2, V2) DEF_ARGS_##F3(NAME, T3, V3) \ + DEF_ARGS_##F4(NAME, T4, V4) +#define DEF_ARGS_6(NAME, T0, F0, V0, T1, F1, V1, T2, F2, V2, T3, F3, V3, T4, F4, V4, T5, F5, V5) \ + DEF_ARGS_##F0(NAME, T0, V0) DEF_ARGS_##F1(NAME, T1, V1) DEF_ARGS_##F2(NAME, T2, V2) DEF_ARGS_##F3(NAME, T3, V3) \ + DEF_ARGS_##F4(NAME, T4, V4) DEF_ARGS_##F5(NAME, T5, V5) +#define DEF_ARGS_7(NAME, T0, F0, V0, T1, F1, V1, T2, F2, V2, T3, F3, V3, T4, F4, V4, T5, F5, V5, T6, F6, V6) \ + DEF_ARGS_##F0(NAME, T0, V0) DEF_ARGS_##F1(NAME, T1, V1) DEF_ARGS_##F2(NAME, T2, V2) DEF_ARGS_##F3(NAME, T3, V3) \ + DEF_ARGS_##F4(NAME, T4, V4) DEF_ARGS_##F5(NAME, T5, V5) DEF_ARGS_##F6(NAME, T6, V6) +#define DEF_ATTRIBUTE(NAME, NUM_ARGS, ...) \ + case Attribute_##NAME##Attribute: \ + fb_attribute = Create##NAME##Attribute(*_builder DEF_ARGS_##NUM_ARGS(NAME##Attribute, __VA_ARGS__)).Union(); \ + break; + +#include "attribute.def" +#undef DEF_ATTRIBUTE +#undef DEF_ARGS_1 +#undef DEF_ARGS_2 +#undef DEF_ARGS_3 +#undef DEF_ARGS_4 +#undef DEF_ARGS_5 +#undef DEF_ARGS_6 +#undef DEF_ARGS_7 +#undef DEF_ARGS_S +#undef DEF_ARGS_V +#undef DEF_ARGS_S_int32_t +#undef DEF_ARGS_S_float +#undef DEF_ARGS_S_bool +#undef DEF_ARGS_S_ResizeMode +#undef DEF_ARGS_S_string +#undef DEF_ARGS_S_STR +#undef DEF_ARGS_S_DEFAULT + default: + printf("TosaSerializationHandler::FreezeBuilder(): Attribute %s not implemented yet\n", + EnumNamesAttribute()[attribute_type]); + return TOSA_INTERNAL_ERROR; + } + + auto qinfo_type = op->GetQInfoType(); + flatbuffers::Offset<void> fb_operator_qinfo; + switch (qinfo_type) + { + case QuantInfo_NONE: + fb_operator_qinfo = 0; + break; +#define DEF_ARGS_S(NAME, T, V) , reinterpret_cast<Tosa##NAME*>(op->GetQInfo())->V() +#define DEF_ARGS_V(NAME, T, V) , _builder->CreateVector<T>(reinterpret_cast<Tosa##NAME*>(op->GetQInfo())->V()) + +#define DEF_ARGS_1(NAME, T0, F0, V0) DEF_ARGS_##F0(NAME, T0, V0) +#define DEF_ARGS_2(NAME, T0, F0, V0, T1, F1, V1) DEF_ARGS_##F0(NAME, T0, V0) DEF_ARGS_##F1(NAME, T1, V1) +#define DEF_ARGS_3(NAME, T0, F0, V0, T1, F1, V1, T2, F2, V2) \ + DEF_ARGS_##F0(NAME, T0, V0) DEF_ARGS_##F1(NAME, T1, V1) DEF_ARGS_##F2(NAME, T2, V2) +#define DEF_ARGS_4(NAME, T0, F0, V0, T1, F1, V1, T2, F2, V2, T3, F3, V3) \ + DEF_ARGS_##F0(NAME, T0, V0) DEF_ARGS_##F1(NAME, T1, V1) DEF_ARGS_##F2(NAME, T2, V2) DEF_ARGS_##F3(NAME, T3, V3) +#define DEF_ARGS_5(NAME, T0, F0, V0, T1, F1, V1, T2, F2, V2, T3, F3, V3, T4, F4, V4) \ + DEF_ARGS_##F0(NAME, T0, V0) DEF_ARGS_##F1(NAME, T1, V1) DEF_ARGS_##F2(NAME, T2, V2) DEF_ARGS_##F3(NAME, T3, V3) \ + DEF_ARGS_##F4(NAME, T4, V4) +#define DEF_ARGS_6(NAME, T0, F0, V0, T1, F1, V1, T2, F2, V2, T3, F3, V3, T4, F4, V4, T5, F5, V5) \ + DEF_ARGS_##F0(NAME, T0, V0) DEF_ARGS_##F1(NAME, T1, V1) DEF_ARGS_##F2(NAME, T2, V2) DEF_ARGS_##F3(NAME, T3, V3) \ + DEF_ARGS_##F4(NAME, T4, V4) DEF_ARGS_##F5(NAME, T5, V5) +#define DEF_ARGS_7(NAME, T0, F0, V0, T1, F1, V1, T2, F2, V2, T3, F3, V3, T4, F4, V4, T5, F5, V5, T6, F6, V6) \ + DEF_ARGS_##F0(NAME, T0, V0) DEF_ARGS_##F1(NAME, T1, V1) DEF_ARGS_##F2(NAME, T2, V2) DEF_ARGS_##F3(NAME, T3, V3) \ + DEF_ARGS_##F4(NAME, T4, V4) DEF_ARGS_##F5(NAME, T5, V5) DEF_ARGS_##F6(NAME, T6, V6) +#define DEF_ARGS_8(NAME, T0, F0, V0, T1, F1, V1, T2, F2, V2, T3, F3, V3, T4, F4, V4, T5, F5, V5, T6, F6, V6, T7, F7, \ + V7) \ + DEF_ARGS_##F0(NAME, T0, V0) DEF_ARGS_##F1(NAME, T1, V1) DEF_ARGS_##F2(NAME, T2, V2) DEF_ARGS_##F3(NAME, T3, V3) \ + DEF_ARGS_##F4(NAME, T4, V4) DEF_ARGS_##F5(NAME, T5, V5) DEF_ARGS_##F6(NAME, T6, V6) \ + DEF_ARGS_##F7(NAME, T7, V7) +#define DEF_ARGS_9(NAME, T0, F0, V0, T1, F1, V1, T2, F2, V2, T3, F3, V3, T4, F4, V4, T5, F5, V5, T6, F6, V6, T7, F7, \ + V7, T8, F8, V8) \ + DEF_ARGS_##F0(NAME, T0, V0) DEF_ARGS_##F1(NAME, T1, V1) DEF_ARGS_##F2(NAME, T2, V2) DEF_ARGS_##F3(NAME, T3, V3) \ + DEF_ARGS_##F4(NAME, T4, V4) DEF_ARGS_##F5(NAME, T5, V5) DEF_ARGS_##F6(NAME, T6, V6) \ + DEF_ARGS_##F7(NAME, T7, V7) DEF_ARGS_##F8(NAME, T8, V8) +#define DEF_ARGS_10(NAME, T0, F0, V0, T1, F1, V1, T2, F2, V2, T3, F3, V3, T4, F4, V4, T5, F5, V5, T6, F6, V6, T7, F7, \ + V7, T8, F8, V8, T9, F9, V9) \ + DEF_ARGS_##F0(NAME, T0, V0) DEF_ARGS_##F1(NAME, T1, V1) DEF_ARGS_##F2(NAME, T2, V2) DEF_ARGS_##F3(NAME, T3, V3) \ + DEF_ARGS_##F4(NAME, T4, V4) DEF_ARGS_##F5(NAME, T5, V5) DEF_ARGS_##F6(NAME, T6, V6) \ + DEF_ARGS_##F7(NAME, T7, V7) DEF_ARGS_##F8(NAME, T8, V8) DEF_ARGS_##F9(NAME, T9, V9) +#define DEF_QUANTIZATION_INFO(NAME, NUM_ARGS, ...) \ + case QuantInfo_##NAME##QuantInfo: \ + fb_operator_qinfo = \ + Create##NAME##QuantInfo(*_builder DEF_ARGS_##NUM_ARGS(NAME##QuantInfo, __VA_ARGS__)).Union(); \ + break; + +#include "quant_info.def" +#undef DEF_QUANTIZATION_INFO +#undef DEF_ARGS_1 +#undef DEF_ARGS_2 +#undef DEF_ARGS_3 +#undef DEF_ARGS_4 +#undef DEF_ARGS_5 +#undef DEF_ARGS_6 +#undef DEF_ARGS_7 +#undef DEF_ARGS_8 +#undef DEF_ARGS_9 +#undef DEF_ARGS_10 +#undef DEF_ARGS_S +#undef DEF_ARGS_V + default: + printf("TosaSerializationHandler::FreezeBuilder(): Attribute %s not implemented yet\n", + EnumNamesAttribute()[attribute_type]); + return TOSA_INTERNAL_ERROR; + } + + auto fboffset_operator = + CreateTosaOperator(*_builder, operator_op, attribute_type, fb_attribute, fb_operator_inputs, + fb_operator_outputs, qinfo_type, fb_operator_qinfo); + fboffset_block_operators.push_back(fboffset_operator); + } + + auto fb_block_operators = _builder->CreateVector(fboffset_block_operators); + + for (auto tensor : block->GetTensors()) + { + + auto tensor_name = _builder->CreateString(tensor->GetName().c_str()); + auto tensor_usage = + _builder->CreateVector(std::vector<uint32_t>(tensor->GetUsage().begin(), tensor->GetUsage().end())); + auto tensor_shape = _builder->CreateVector(tensor->GetShape()); + auto tensor_dtype = tensor->GetDtype(); + auto tensor_format = + _builder->CreateVector(std::vector<uint32_t>(tensor->GetFormat().begin(), tensor->GetFormat().end())); + flatbuffers::Offset<flatbuffers::String> tensor_npy_filename = 0; + if (tensor->GetNpyFilePtr()) + tensor_npy_filename = _builder->CreateString(tensor->GetNpyFilePtr()->c_str()); + + auto fboffset_tensor = CreateTosaTensor(*_builder, tensor_name, tensor_shape, tensor_dtype, tensor_usage, + tensor_format, tensor_npy_filename); + fboffset_block_tensors.push_back(fboffset_tensor); + } + + auto fb_block_tensors = _builder->CreateVector(fboffset_block_tensors); + + auto fboffset_block = CreateTosaBasicBlock(*_builder, block_name, fb_block_operators, fb_block_tensors, + fb_block_inputs, fb_block_outputs); + fboffset_blocks.push_back(fboffset_block); + } + + auto fb_blocks = _builder->CreateVector(fboffset_blocks); + + auto fb_version = CreateVersion(*_builder, GetTosaVersion()->_major, GetTosaVersion()->_minor, + GetTosaVersion()->_patch, GetTosaVersion()->_experimental); + + auto fb_graph = CreateTosaGraph(*_builder, fb_version, fb_blocks); + _builder->Finish(fb_graph); + + return TOSA_OK; +} + +// Magic NUMPY header +static const char NUMPY_HEADER_STR[] = "\x93NUMPY\x1\x0\x76\x0{"; +static const int NUMPY_HEADER_SZ = 128; + +NumpyUtilities::NPError NumpyUtilities::readFromNpyFile(const char* filename, const uint32_t elems, bool* databuf) +{ + const char dtype_str[] = "'|b1'"; + FILE* infile = nullptr; + NPError rc = NO_ERROR; + + assert(filename); + assert(databuf); + + infile = fopen(filename, "rb"); + if (!infile) + { + rc = FILE_NOT_FOUND; + goto done; + } + + rc = checkNpyHeader(infile, elems, dtype_str); + if (rc != NO_ERROR) + { + goto done; + } + + // Read in the data from numpy byte array to native bool + // array format + for (uint32_t i = 0; i < elems; i++) + { + int val = fgetc(infile); + + if (val == EOF) + { + rc = FILE_IO_ERROR; + goto done; + } + + databuf[i] = val; + } + +done: + + if (infile) + fclose(infile); + + return rc; +} + +NumpyUtilities::NPError NumpyUtilities::readFromNpyFile(const char* filename, const uint32_t elems, int32_t* databuf) +{ + const char dtype_str[] = "'<i4'"; + FILE* infile = nullptr; + NPError rc = NO_ERROR; + + assert(filename); + assert(databuf); + + infile = fopen(filename, "rb"); + if (!infile) + { + rc = FILE_NOT_FOUND; + goto done; + } + + rc = checkNpyHeader(infile, elems, dtype_str); + if (rc != NO_ERROR) + { + goto done; + } + + // Now we are at the beginning of the data + // Parse based on the datatype and number of dimensions + if (fread(databuf, sizeof(int32_t), elems, infile) != elems) + { + rc = FILE_IO_ERROR; + goto done; + } + +done: + + if (infile) + fclose(infile); + + return rc; +} + +NumpyUtilities::NPError NumpyUtilities::readFromNpyFile(const char* filename, const uint32_t elems, int64_t* databuf) +{ + const char dtype_str[] = "'<i8'"; + FILE* infile = nullptr; + NPError rc = NO_ERROR; + + assert(filename); + assert(databuf); + + infile = fopen(filename, "rb"); + if (!infile) + { + rc = FILE_NOT_FOUND; + goto done; + } + + rc = checkNpyHeader(infile, elems, dtype_str); + if (rc != NO_ERROR) + { + goto done; + } + + // Now we are at the beginning of the data + // Parse based on the datatype and number of dimensions + if (fread(databuf, sizeof(int64_t), elems, infile) != elems) + { + rc = FILE_IO_ERROR; + goto done; + } + +done: + + if (infile) + fclose(infile); + + return rc; +} + +NumpyUtilities::NPError NumpyUtilities::readFromNpyFile(const char* filename, const uint32_t elems, float* databuf) +{ + const char dtype_str[] = "'<f4'"; + FILE* infile = nullptr; + NPError rc = NO_ERROR; + + assert(filename); + assert(databuf); + + infile = fopen(filename, "rb"); + if (!infile) + { + rc = FILE_NOT_FOUND; + goto done; + } + + rc = checkNpyHeader(infile, elems, dtype_str); + if (rc != NO_ERROR) + { + goto done; + } + + // Now we are at the beginning of the data + // Parse based on the datatype and number of dimensions + if (fread(databuf, sizeof(float), elems, infile) != elems) + { + rc = FILE_IO_ERROR; + goto done; + } + +done: + + if (infile) + fclose(infile); + + return rc; +} + +NumpyUtilities::NPError NumpyUtilities::checkNpyHeader(FILE* infile, const uint32_t elems, const char* dtype_str) +{ + char buf[NUMPY_HEADER_SZ + 1]; + char* ptr = nullptr; + NPError rc = NO_ERROR; + bool foundFormat = false; + bool foundOrder = false; + bool foundShape = false; + bool fortranOrder = false; + std::vector<int> shape; + uint32_t totalElems = 1; + char* outer_end = NULL; + + assert(infile); + assert(elems > 0); + + if (fread(buf, NUMPY_HEADER_SZ, 1, infile) != 1) + { + rc = HEADER_PARSE_ERROR; + goto done; + } + + if (memcmp(buf, NUMPY_HEADER_STR, sizeof(NUMPY_HEADER_STR) - 1)) + { + rc = HEADER_PARSE_ERROR; + goto done; + } + + ptr = strtok_r(buf + sizeof(NUMPY_HEADER_STR) - 1, ":", &outer_end); + + // Read in the data type, order, and shape + while (ptr && (!foundFormat || !foundOrder || !foundShape)) + { + + // End of string? + if (!ptr) + break; + + // Skip whitespace + while (isspace(*ptr)) + ptr++; + + // Parse the dictionary field name + if (!strcmp(ptr, "'descr'")) + { + ptr = strtok_r(NULL, ",", &outer_end); + if (!ptr) + break; + + while (isspace(*ptr)) + ptr++; + + if (strcmp(ptr, dtype_str)) + { + rc = FILE_TYPE_MISMATCH; + goto done; + } + + foundFormat = true; + } + else if (!strcmp(ptr, "'fortran_order'")) + { + ptr = strtok_r(NULL, ",", &outer_end); + if (!ptr) + break; + + while (isspace(*ptr)) + ptr++; + + if (!strcmp(ptr, "False")) + { + fortranOrder = false; + } + else + { + rc = FILE_TYPE_MISMATCH; + goto done; + } + + foundOrder = true; + } + else if (!strcmp(ptr, "'shape'")) + { + + ptr = strtok_r(NULL, "(", &outer_end); + if (!ptr) + break; + ptr = strtok_r(NULL, ")", &outer_end); + if (!ptr) + break; + + while (isspace(*ptr)) + ptr++; + + // The shape contains N comma-separated integers. Read up to 4. + char* end = NULL; + + ptr = strtok_r(ptr, ",", &end); + for (int i = 0; i < 4; i++) + { + // Out of dimensions + if (!ptr) + break; + + shape.push_back(atoi(ptr)); + totalElems *= atoi(ptr); + ptr = strtok_r(NULL, ",", &end); + } + + foundShape = true; + } + else + { + rc = HEADER_PARSE_ERROR; + goto done; + } + + if (!ptr) + break; + + ptr = strtok_r(NULL, ":", &outer_end); + } + + if (!foundShape || !foundFormat || !foundOrder) + { + rc = HEADER_PARSE_ERROR; + goto done; + } + + // Validate header + if (fortranOrder != false) + { + rc = FILE_TYPE_MISMATCH; + goto done; + } + + if (totalElems != elems) + { + rc = BUFFER_SIZE_MISMATCH; + goto done; + } + + // Go back to the begininng and read until the end of the header dictionary + rewind(infile); + int val; + + do + { + val = fgetc(infile); + } while (val != EOF && val != '\n'); + +done: + + return rc; +} + +NumpyUtilities::NPError NumpyUtilities::writeToNpyFile(const char* filename, const uint32_t elems, const bool* databuf) +{ + std::vector<int32_t> shape = { (int32_t)elems }; + return writeToNpyFile(filename, shape, databuf); +} + +NumpyUtilities::NPError + NumpyUtilities::writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const bool* databuf) +{ + const char dtype_str[] = "'|b1'"; + FILE* outfile = nullptr; + NPError rc = NO_ERROR; + uint32_t totalElems = 1; + + assert(filename); + assert(shape.size() >= 0); + assert(databuf); + + outfile = fopen(filename, "wb"); + + if (!outfile) + { + rc = FILE_NOT_FOUND; + goto done; + } + + for (uint32_t i = 0; i < shape.size(); i++) + { + totalElems *= shape[i]; + } + + rc = writeNpyHeader(outfile, shape, dtype_str); + + // Numpy save format stores booleans as a byte array + // with one byte per boolean. This somewhat inefficiently + // remaps from system bool[] to this format. + for (uint32_t i = 0; i < totalElems; i++) + { + int val = databuf[i] ? 1 : 0; + if (fputc(val, outfile) == EOF) + { + rc = FILE_IO_ERROR; + goto done; + } + } + +done: + + if (outfile) + fclose(outfile); + + return rc; +} + +NumpyUtilities::NPError + NumpyUtilities::writeToNpyFile(const char* filename, const uint32_t elems, const int32_t* databuf) +{ + std::vector<int32_t> shape = { (int32_t)elems }; + return writeToNpyFile(filename, shape, databuf); +} + +NumpyUtilities::NPError + NumpyUtilities::writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const int32_t* databuf) +{ + const char dtype_str[] = "'<i4'"; + FILE* outfile = nullptr; + NPError rc = NO_ERROR; + uint32_t totalElems = 1; + + assert(filename); + assert(shape.size() >= 0); + assert(databuf); + + outfile = fopen(filename, "wb"); + + if (!outfile) + { + rc = FILE_NOT_FOUND; + goto done; + } + + for (uint32_t i = 0; i < shape.size(); i++) + { + totalElems *= shape[i]; + } + + rc = writeNpyHeader(outfile, shape, dtype_str); + + if (fwrite(databuf, sizeof(int32_t), totalElems, outfile) != totalElems) + { + rc = FILE_IO_ERROR; + goto done; + } + +done: + + if (outfile) + fclose(outfile); + + return rc; +} + +NumpyUtilities::NPError + NumpyUtilities::writeToNpyFile(const char* filename, const uint32_t elems, const int64_t* databuf) +{ + std::vector<int32_t> shape = { (int32_t)elems }; + return writeToNpyFile(filename, shape, databuf); +} + +NumpyUtilities::NPError + NumpyUtilities::writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const int64_t* databuf) +{ + const char dtype_str[] = "'<i8'"; + FILE* outfile = nullptr; + NPError rc = NO_ERROR; + uint32_t totalElems = 1; + + assert(filename); + assert(shape.size() >= 0); + assert(databuf); + + outfile = fopen(filename, "wb"); + + if (!outfile) + { + rc = FILE_NOT_FOUND; + goto done; + } + + for (uint32_t i = 0; i < shape.size(); i++) + { + totalElems *= shape[i]; + } + + rc = writeNpyHeader(outfile, shape, dtype_str); + + if (fwrite(databuf, sizeof(int64_t), totalElems, outfile) != totalElems) + { + rc = FILE_IO_ERROR; + goto done; + } + +done: + + if (outfile) + fclose(outfile); + + return rc; +} + +NumpyUtilities::NPError NumpyUtilities::writeToNpyFile(const char* filename, const uint32_t elems, const float* databuf) +{ + std::vector<int32_t> shape = { (int32_t)elems }; + return writeToNpyFile(filename, shape, databuf); +} + +NumpyUtilities::NPError + NumpyUtilities::writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const float* databuf) +{ + const char dtype_str[] = "'<f4'"; + FILE* outfile = nullptr; + NPError rc = NO_ERROR; + uint32_t totalElems = 1; + + assert(filename); + assert(shape.size() >= 0); + assert(databuf); + + outfile = fopen(filename, "wb"); + + if (!outfile) + { + rc = FILE_NOT_FOUND; + goto done; + } + + for (uint32_t i = 0; i < shape.size(); i++) + { + totalElems *= shape[i]; + } + + rc = writeNpyHeader(outfile, shape, dtype_str); + + if (fwrite(databuf, sizeof(float), totalElems, outfile) != totalElems) + { + rc = FILE_IO_ERROR; + goto done; + } + +done: + + if (outfile) + fclose(outfile); + + return rc; +} + +NumpyUtilities::NPError + NumpyUtilities::writeNpyHeader(FILE* outfile, const std::vector<int32_t>& shape, const char* dtype_str) +{ + NPError rc = NO_ERROR; + uint32_t i; + char header[NUMPY_HEADER_SZ + 1]; + int headerPos = 0; + + assert(outfile); + assert(shape.size() >= 0); + + // Space-fill the header and end with a newline to start per numpy spec + memset(header, 0x20, NUMPY_HEADER_SZ); + header[NUMPY_HEADER_SZ - 1] = '\n'; + header[NUMPY_HEADER_SZ] = 0; + + // Write out the hard-coded header. We only support a 128-byte 1.0 header + // for now, which should be sufficient for simple tensor types of any + // reasonable rank. + memcpy(header, NUMPY_HEADER_STR, sizeof(NUMPY_HEADER_STR) - 1); + headerPos += sizeof(NUMPY_HEADER_STR) - 1; + + // Output the format dictionary + // Hard-coded for I32 for now + headerPos += + snprintf(header + headerPos, NUMPY_HEADER_SZ - headerPos, "'descr': %s, 'fortran_order': False, 'shape': (%d,", + dtype_str, shape.size() > 0 ? shape[0] : 1); + + // Remainder of shape array + for (i = 1; i < shape.size(); i++) + { + headerPos += snprintf(header + headerPos, NUMPY_HEADER_SZ - headerPos, " %d,", shape[i]); + } + + // Close off the dictionary + headerPos += snprintf(header + headerPos, NUMPY_HEADER_SZ - headerPos, "), }"); + + // snprintf leaves a NULL at the end. Replace with a space + header[headerPos] = 0x20; + + if (fwrite(header, NUMPY_HEADER_SZ, 1, outfile) != 1) + { + rc = FILE_IO_ERROR; + goto done; + } + +done: + + return rc; +} diff --git a/serialization/tosa_serialization_handler.h b/serialization/tosa_serialization_handler.h new file mode 100644 index 0000000..124b8e0 --- /dev/null +++ b/serialization/tosa_serialization_handler.h @@ -0,0 +1,423 @@ + +// Copyright (c) 2020, ARM Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef _TOSA_SERIALIZATION_HANDLER_H +#define _TOSA_SERIALIZATION_HANDLER_H +#include "attribute.h" +#include "flatbuffers/idl.h" +#include "flatbuffers/util.h" +#include "quant_info.h" +#include "tosa_generated.h" +#include <cstdint> +#include <memory> +#include <string> +#include <vector> + +namespace tosa +{ + +enum tosa_err_t +{ + TOSA_OK, + TOSA_USER_ERROR, + TOSA_FILE_ERROR, + TOSA_MEMORY_ERROR, + TOSA_SCHEMA_MISSING, + TOSA_INTERNAL_ERROR, + TOSA_VERSION_MISMATCH, + NUM_TOSA_ERROR +}; + +struct TosaVersion +{ + int32_t _major; + int32_t _minor; + int32_t _patch; + bool _experimental; + + TosaVersion() = delete; + TosaVersion(int32_t major, int32_t minor, int32_t patch, bool experimental) + { + _major = major; + _minor = minor; + _patch = patch; + _experimental = experimental; + } + + std::string to_string() const + { + std::string str; + str += std::to_string(_major) + "."; + str += std::to_string(_minor) + "."; + str += std::to_string(_patch); + if (_experimental) + str += "(experimental)"; + return str; + }; + + bool operator==(const TosaVersion& rhs) + { + if (rhs._major == _major && rhs._minor == _minor && rhs._patch == _patch && rhs._experimental == _experimental) + { + return true; + } + return false; + } + + bool operator!=(const TosaVersion& rhs) + { + return !((*this) == rhs); + } +}; + +class TosaSerializationHandler; + +class TosaSerializationTensor +{ +public: + // constructor and destructor + TosaSerializationTensor(const flatbuffers::String* name, + const flatbuffers::Vector<uint32_t>& usage, + const flatbuffers::Vector<int32_t>& shape, + DType dtype, + const flatbuffers::Vector<uint32_t>& format, + const flatbuffers::String* npy_filename); + TosaSerializationTensor(std::string name, + const std::vector<Usage>& usage, + const std::vector<int32_t>& shape, + DType dtype, + const std::vector<Format>& format, + const std::string* npy_filename); + TosaSerializationTensor(); + ~TosaSerializationTensor(); + + // copy constructor/assignment + TosaSerializationTensor(const TosaSerializationTensor& rhs); + TosaSerializationTensor& operator=(const TosaSerializationTensor& rhs); + + // move constructor/assignment + TosaSerializationTensor(TosaSerializationTensor&& rhs); + TosaSerializationTensor& operator=(TosaSerializationTensor&& rhs); + + // accessor + std::string GetName() const + { + return *_name; + } + const std::vector<int32_t>& GetShape() const + { + return *_shape; + } + DType GetDtype() + { + return _dtype; + } + bool HasFormat(Format format) + { + for (Format us : *_format) + { + if (us == format) + return true; + } + return false; + } + std::vector<Format>& GetFormat() + { + return *_format; + } + bool HasUsage(Usage usage) + { + for (Usage us : *_usage) + { + if (us == usage) + return true; + } + return false; + } + std::vector<Usage>& GetUsage() + { + return *_usage; + } + std::string* GetNpyFilePtr() const + { + return _npy_filename; + } + + // modifier + void SetDtype(DType dtype) + { + _dtype = dtype; + } + void SetName(std::string name) + { + *_name = name; + } + +private: + DType _dtype; /* data type enumeration, see tosa_isa_generated.h */ + std::vector<Format>* _format; /* list of possible tensor format */ + std::vector<Usage>* _usage; /* list of possible tensor usage */ + std::vector<int32_t>* _shape; /* shape of the tensor */ + std::string* _name; /* name of the tensor, used for solving dependency */ + std::string* _npy_filename; /* numpy array filename if not null. so null is the distinguisher */ +}; + +class TosaSerializationOperator +{ +public: + // use default copy, void constructor + // constructor and destructor + TosaSerializationOperator(Op op_name, + Attribute attribute_type, + const TosaAttributeBase* attribute, + QuantInfo qinfo_type, + const TosaQuantInfoBase* qinfo, + std::vector<std::string> input_tensor_names, + std::vector<std::string> output_tensor_names); + ~TosaSerializationOperator(); + + // accessor + Op GetOp() const + { + return _op; + } + Attribute GetAttributeType() const + { + return _attribute_type; + } + TosaAttributeBase* GetAttribute() const + { + return _attribute; + } + QuantInfo GetQInfoType() const + { + return _qinfo_type; + } + TosaQuantInfoBase* GetQInfo() const + { + return _qinfo; + } + std::vector<std::string>& GetInputTensorNames() const + { + return *_input_tensor_names; + } + std::vector<std::string>& GetOutputTensorNames() const + { + return *_output_tensor_names; + } + std::vector<TosaSerializationTensor*>& GetInputTensors() const + { + return *_input_tensors; + } + std::vector<TosaSerializationTensor*>& GetOutputTensors() const + { + return *_output_tensors; + } + +private: + Op _op; /* operator enum, see tosa_isa_generated.h for enumeration table */ + Attribute _attribute_type; /* operator attribute enum, used for dynamic casting TosaAttributeBase class */ + TosaAttributeBase* _attribute; /* real attribute class goes here */ + QuantInfo _qinfo_type; /* QuantInfo enum */ + TosaQuantInfoBase* _qinfo; /* base class pointer of QuantInfo */ + std::vector<std::string>* _input_tensor_names; /* array of input tensor names */ + std::vector<std::string>* _output_tensor_names; /* array of output tensor names */ + + std::vector<TosaSerializationTensor*>* _input_tensors; /* array of input TosaSerializationTensor */ + std::vector<TosaSerializationTensor*>* _output_tensors; /* array of output TosaSerializationTensor */ +}; + +class TosaSerializationBasicBlock +{ +public: + // constructor and destructor + TosaSerializationBasicBlock(std::string name, + std::vector<TosaSerializationOperator*> operators, + std::vector<TosaSerializationTensor*> tensors, + std::vector<std::string> inputs, + std::vector<std::string> outputs); + ~TosaSerializationBasicBlock(); + + // accessor + std::string GetName() const + { + return *_name; + } + std::vector<TosaSerializationOperator*>& GetOperators() + { + return *_operators; + } + std::vector<TosaSerializationTensor*>& GetTensors() + { + return *_tensors; + } + + TosaSerializationTensor* GetTensorByName(std::string name) + { + TosaSerializationTensor* result = nullptr; + for (auto tensor : GetTensors()) + { + if (tensor->GetName() == name) + { + result = tensor; + break; + } + } + return result; + } + + std::vector<std::string>& GetInputs() + { + return *_inputs; + } + std::vector<std::string>& GetOutputs() + { + return *_outputs; + } + +private: + std::string* _name; /* name of basic block */ + std::vector<TosaSerializationOperator*>* _operators; /* TosaSerializationOperator list */ + std::vector<TosaSerializationTensor*>* _tensors; /* TosaSerializationTensor list */ + std::vector<std::string>* _inputs; /* array of string to specify block inputs */ + std::vector<std::string>* _outputs; /* array of string to specify block outputs */ +}; + +/* + * this is a helper class for writing/reading Tosa ISA + * supported format: .tosa (flatbuffer), .json + * and provide high-level std::vector-like interface + * to access internal data structure + */ +class TosaSerializationHandler +{ +public: + // constructor and destructor + TosaSerializationHandler(); + ~TosaSerializationHandler(); + + // file io + tosa_err_t LoadFileJson(const char* filename); + tosa_err_t LoadFileTosaFlatbuffer(const char* filename); + tosa_err_t SaveFileJson(const char* filename); + tosa_err_t SaveFileTosaFlatbuffer(const char* filename); + tosa_err_t LoadFileSchema(const char* filename); + + // version + TosaVersion* GetTosaVersion() const + { + return _version; + } + + // accessor + std::vector<TosaSerializationBasicBlock*>& GetBlocks() + { + return *_blocks; + } + + TosaSerializationBasicBlock* GetBlockByName(std::string name) + { + TosaSerializationBasicBlock* result = nullptr; + for (auto block : GetBlocks()) + { + if (block->GetName() == name) + { + result = block; + break; + } + } + return result; + } + TosaSerializationBasicBlock* GetMainBlock() + { + TosaSerializationBasicBlock* main_block = GetBlockByName(std::string("main")); + assert(main_block); + return main_block; + } + + std::vector<std::string>& GetInputs() + { + return GetMainBlock()->GetInputs(); + } + std::vector<std::string>& GetOutputs() + { + return GetMainBlock()->GetOutputs(); + } + + bool GetSchemaLoaded() const + { + return _schemaLoaded; + } + +protected: + tosa_err_t Clear(); + tosa_err_t InitWithBuf(const uint8_t* buf); + tosa_err_t FreezeBuilder(); + tosa_err_t SetTosaVersion(); + tosa_err_t CheckTosaVersion(const TosaVersion& read_version); + +private: + TosaVersion* _version; /* tosa version */ + flatbuffers::FlatBufferBuilder* _builder; /* flatbuffer builder */ + flatbuffers::Parser* _parser; /* flatbuffer parser, used for json parsing */ + std::vector<TosaSerializationBasicBlock*>* _blocks; /* array structure to store all TosaSerializationBasicBlock */ + bool _schemaLoaded; /* is the schema properly loaded? */ +}; + +class NumpyUtilities +{ +public: + enum NPError + { + NO_ERROR = 0, + FILE_NOT_FOUND, + FILE_IO_ERROR, + FILE_TYPE_MISMATCH, + HEADER_PARSE_ERROR, + BUFFER_SIZE_MISMATCH, + }; + + static NPError readFromNpyFile(const char* filename, const uint32_t elems, float* buf); + + static NPError readFromNpyFile(const char* filename, const uint32_t elems, int32_t* buf); + + static NPError readFromNpyFile(const char* filename, const uint32_t elems, int64_t* buf); + + static NPError readFromNpyFile(const char* filename, const uint32_t elems, bool* buf); + + static NPError writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const bool* buf); + + static NPError writeToNpyFile(const char* filename, const uint32_t elems, const bool* buf); + + static NPError writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const int32_t* buf); + + static NPError writeToNpyFile(const char* filename, const uint32_t elems, const int32_t* buf); + + static NPError writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const int64_t* buf); + + static NPError writeToNpyFile(const char* filename, const uint32_t elems, const int64_t* buf); + + static NPError writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const float* buf); + + static NPError writeToNpyFile(const char* filename, const uint32_t elems, const float* buf); + +private: + static NPError checkNpyHeader(FILE* infile, const uint32_t elems, const char* dtype_str); + static NPError writeNpyHeader(FILE* infile, const std::vector<int32_t>& shape, const char* dtype_str); +}; + +} // namespace tosa + +#endif // _TOSA_SERIALIZATION_HANDLER_H |