aboutsummaryrefslogtreecommitdiff
path: root/serialization
diff options
context:
space:
mode:
authorEric Kunze <eric.kunze@arm.com>2020-10-13 16:11:07 -0700
committerKevin Cheng <kevin.cheng@arm.com>2020-10-14 11:11:43 -0700
commite5e2676409a936431f87d31fb74d825257b20804 (patch)
tree304d93d993ef6417b02a515025f9030367682774 /serialization
parent88b7860f180f91b5b66764c61cfd97d8bc53cece (diff)
downloadreference_model-e5e2676409a936431f87d31fb74d825257b20804.tar.gz
Initial checkin of TOSA reference_model and tests
Change-Id: I2f8e7fa63e2ae40203e57d2cc8814bde3b312cb6 Signed-off-by: Eric Kunze <eric.kunze@arm.com>
Diffstat (limited to 'serialization')
-rw-r--r--serialization/CMakeLists.txt32
-rw-r--r--serialization/attribute.def90
-rw-r--r--serialization/attribute.h181
-rw-r--r--serialization/operator.def123
-rw-r--r--serialization/quant_info.def43
-rw-r--r--serialization/quant_info.h164
-rw-r--r--serialization/tosa.fbs318
-rw-r--r--serialization/tosa_generated.h2605
-rw-r--r--serialization/tosa_serialization_handler.cpp1526
-rw-r--r--serialization/tosa_serialization_handler.h423
10 files changed, 5505 insertions, 0 deletions
diff --git a/serialization/CMakeLists.txt b/serialization/CMakeLists.txt
new file mode 100644
index 0000000..7bca824
--- /dev/null
+++ b/serialization/CMakeLists.txt
@@ -0,0 +1,32 @@
+cmake_minimum_required (VERSION 3.4)
+
+# Copyright (c) 2020, ARM Limited.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+project (tosa)
+
+set (CMAKE_CXX_STANDARD 11)
+set (CMAKE_CXX_FLAGS "-g -Wall")
+set (FLATBUFFERS_SRC_DIR "../thirdparty/flatbuffers")
+
+set (SOURCE
+ tosa_serialization_handler.cpp
+)
+
+add_library(tosa_serialization STATIC ${SOURCE})
+
+include_directories("./")
+
+target_link_libraries(tosa_serialization PRIVATE flatbuffers)
diff --git a/serialization/attribute.def b/serialization/attribute.def
new file mode 100644
index 0000000..88e8c81
--- /dev/null
+++ b/serialization/attribute.def
@@ -0,0 +1,90 @@
+
+// Copyright (c) 2020, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+/*
+ Syntax:
+ DEF_ATTRIBUTE(ATTRIBUTE_NAME, NUM_ARGS_IN_ATTRIBUTES, ARG0_TYPE, ARG0_SCALAR_OR_VECTOR, ARGS0_NAME, ...)
+
+ Description:
+ ATTRIBUTE_NAME: corresponding attribute name, must match corresponding "table XXXAttribute" in tosa.fbs
+ NUM_ARGS_IN_ATTRIBUTES: number of arguments in this attribute
+ ARG0_TYPE: data type of arg0 in attribute
+ ARG0_SCALAR_OR_VECTOR: is arg0 a scalar(S) or a vector(V)
+ ARG0_NAME: name of arg0
+ ...: variadic variables for more arguments, depending on NUM_ARGS_IN_ATTRIBUTES
+*/
+
+DEF_ATTRIBUTE(Pool2d, 3,
+ int32_t, V, padding,
+ int32_t, V, kernel,
+ int32_t, V, stride)
+
+DEF_ATTRIBUTE(Conv2d, 3,
+ int32_t, V, padding,
+ int32_t, V, stride,
+ int32_t, V, dilation)
+
+DEF_ATTRIBUTE(TransposeConv2d, 4,
+ int32_t, V, outpad,
+ int32_t, V, stride,
+ int32_t, V, dilation,
+ int32_t, V, output_shape)
+
+DEF_ATTRIBUTE(ReluN, 2,
+ int32_t, S, max_int,
+ float, S, max_fp)
+
+DEF_ATTRIBUTE(Axis, 1,
+ int32_t, S, axis)
+
+DEF_ATTRIBUTE(Reshape, 1,
+ int32_t, V, shape)
+
+DEF_ATTRIBUTE(Slice, 2,
+ int32_t, V, begin,
+ int32_t, V, size)
+
+DEF_ATTRIBUTE(Tile, 1,
+ int32_t, V, multiples)
+
+DEF_ATTRIBUTE(Resize, 5,
+ int32_t, V, output_size,
+ int32_t, V, stride,
+ int32_t, V, offset,
+ int32_t, S, shift,
+ ResizeMode, S, mode)
+
+DEF_ATTRIBUTE(Clamp, 4,
+ int32_t, S, min_int,
+ int32_t, S, max_int,
+ float, S, min_fp,
+ float, S, max_fp)
+
+DEF_ATTRIBUTE(Rescale, 7,
+ int32_t, S, input_zp,
+ int32_t, S, output_zp,
+ int32_t, V, multiplier,
+ int32_t, V, shift,
+ bool, S, scale32,
+ bool, S, double_round,
+ bool, S, per_channel)
+
+DEF_ATTRIBUTE(CondIf, 2,
+ string, S, then_branch,
+ string, S, else_branch)
+
+DEF_ATTRIBUTE(WhileLoop, 2,
+ string, S, cond_branch,
+ string, S, body_branch)
diff --git a/serialization/attribute.h b/serialization/attribute.h
new file mode 100644
index 0000000..2a33a8f
--- /dev/null
+++ b/serialization/attribute.h
@@ -0,0 +1,181 @@
+
+// Copyright (c) 2020, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef _TOSA_SERIALIZATION_ATTRIBUTE_H
+#define _TOSA_SERIALIZATION_ATTRIBUTE_H
+#include "flatbuffers/idl.h"
+#include "flatbuffers/util.h"
+#include "tosa_generated.h"
+
+using std::string;
+
+namespace tosa
+{
+
+class TosaAttributeBase
+{
+public:
+ virtual ~TosaAttributeBase()
+ {}
+};
+
+class TosaNoneAttribute : public TosaAttributeBase
+{
+public:
+ TosaNoneAttribute()
+ {}
+ TosaNoneAttribute(TosaNoneAttribute* p)
+ {}
+};
+
+#define DEF_ARGS_VER0_S_STR(V) _##V = p->V()->str();
+#define DEF_ARGS_VER0_S_DEFAULT(V) _##V = p->V();
+
+#define DEF_ARGS_VER0_S_int32_t(V) DEF_ARGS_VER0_S_DEFAULT(V)
+#define DEF_ARGS_VER0_S_float(V) DEF_ARGS_VER0_S_DEFAULT(V)
+#define DEF_ARGS_VER0_S_bool(V) DEF_ARGS_VER0_S_DEFAULT(V)
+#define DEF_ARGS_VER0_S_ResizeMode(V) DEF_ARGS_VER0_S_DEFAULT(V)
+#define DEF_ARGS_VER0_S_string(V) DEF_ARGS_VER0_S_STR(V)
+
+#define DEF_ARGS_VER0_S(T, V) DEF_ARGS_VER0_S_##T(V)
+#define DEF_ARGS_VER0_V(T, V) _##V = std::vector<T>(p->V()->begin(), p->V()->end());
+
+#define DEF_ARGS_VER1_S(T, V) const T& V
+#define DEF_ARGS_VER1_V(T, V) const std::vector<T>& V
+#define DEF_ARGS_VER2_S(T, V) _##V = V;
+#define DEF_ARGS_VER2_V(T, V) _##V = V;
+#define DEF_ARGS_VER3_S(T, V) \
+ T V() const \
+ { \
+ return _##V; \
+ }
+#define DEF_ARGS_VER3_V(T, V) \
+ std::vector<T> V() const \
+ { \
+ return _##V; \
+ }
+#define DEF_ARGS_VER4_S(T, V) T _##V;
+#define DEF_ARGS_VER4_V(T, V) std::vector<T> _##V;
+
+// another level of preprocessor indirection to handle ", " as function's input argument
+#define DEF_ARGS_VER1_TRUE(T, F, V) DEF_ARGS_VER1_##F(T, V)
+#define DEF_ARGS_VER1_FALSE(T, F, V) , DEF_ARGS_VER1_##F(T, V)
+
+#define DEF_ARGS_VER0(FIRST, T, F, V) DEF_ARGS_VER0_##F(T, V)
+#define DEF_ARGS_VER1(FIRST, T, F, V) DEF_ARGS_VER1_##FIRST(T, F, V)
+#define DEF_ARGS_VER2(FIRST, T, F, V) DEF_ARGS_VER2_##F(T, V)
+#define DEF_ARGS_VER3(FIRST, T, F, V) DEF_ARGS_VER3_##F(T, V)
+#define DEF_ARGS_VER4(FIRST, T, F, V) DEF_ARGS_VER4_##F(T, V)
+
+#define DEF_ARGS_0(VER, ...)
+#define DEF_ARGS_1(VER, T0, F0, V0) DEF_ARGS_##VER(TRUE, T0, F0, V0)
+#define DEF_ARGS_2(VER, T0, F0, V0, T1, F1, V1) DEF_ARGS_##VER(TRUE, T0, F0, V0) DEF_ARGS_##VER(FALSE, T1, F1, V1)
+#define DEF_ARGS_3(VER, T0, F0, V0, T1, F1, V1, T2, F2, V2) \
+ DEF_ARGS_##VER(TRUE, T0, F0, V0) DEF_ARGS_##VER(FALSE, T1, F1, V1) DEF_ARGS_##VER(FALSE, T2, F2, V2)
+#define DEF_ARGS_4(VER, T0, F0, V0, T1, F1, V1, T2, F2, V2, T3, F3, V3) \
+ DEF_ARGS_##VER(TRUE, T0, F0, V0) DEF_ARGS_##VER(FALSE, T1, F1, V1) DEF_ARGS_##VER(FALSE, T2, F2, V2) \
+ DEF_ARGS_##VER(FALSE, T3, F3, V3)
+#define DEF_ARGS_5(VER, T0, F0, V0, T1, F1, V1, T2, F2, V2, T3, F3, V3, T4, F4, V4) \
+ DEF_ARGS_##VER(TRUE, T0, F0, V0) DEF_ARGS_##VER(FALSE, T1, F1, V1) DEF_ARGS_##VER(FALSE, T2, F2, V2) \
+ DEF_ARGS_##VER(FALSE, T3, F3, V3) DEF_ARGS_##VER(FALSE, T4, F4, V4)
+
+#define DEF_ARGS_6(VER, T0, F0, V0, T1, F1, V1, T2, F2, V2, T3, F3, V3, T4, F4, V4, T5, F5, V5) \
+ DEF_ARGS_##VER(TRUE, T0, F0, V0) DEF_ARGS_##VER(FALSE, T1, F1, V1) DEF_ARGS_##VER(FALSE, T2, F2, V2) \
+ DEF_ARGS_##VER(FALSE, T3, F3, V3) DEF_ARGS_##VER(FALSE, T4, F4, V4) DEF_ARGS_##VER(FALSE, T5, F5, V5)
+
+#define DEF_ARGS_7(VER, T0, F0, V0, T1, F1, V1, T2, F2, V2, T3, F3, V3, T4, F4, V4, T5, F5, V5, T6, F6, V6) \
+ DEF_ARGS_##VER(TRUE, T0, F0, V0) DEF_ARGS_##VER(FALSE, T1, F1, V1) DEF_ARGS_##VER(FALSE, T2, F2, V2) \
+ DEF_ARGS_##VER(FALSE, T3, F3, V3) DEF_ARGS_##VER(FALSE, T4, F4, V4) DEF_ARGS_##VER(FALSE, T5, F5, V5) \
+ DEF_ARGS_##VER(FALSE, T6, F6, V6)
+
+#define DEF_VER0_VAR_DECL_PTR(NAME) const NAME* p = static_cast<const NAME*>(options);
+#define DEF_VER0_VAR_0(NAME)
+#define DEF_VER0_VAR_1(NAME) DEF_VER0_VAR_DECL_PTR(NAME)
+#define DEF_VER0_VAR_2(NAME) DEF_VER0_VAR_DECL_PTR(NAME)
+#define DEF_VER0_VAR_3(NAME) DEF_VER0_VAR_DECL_PTR(NAME)
+#define DEF_VER0_VAR_4(NAME) DEF_VER0_VAR_DECL_PTR(NAME)
+#define DEF_VER0_VAR_5(NAME) DEF_VER0_VAR_DECL_PTR(NAME)
+#define DEF_VER0_VAR_6(NAME) DEF_VER0_VAR_DECL_PTR(NAME)
+#define DEF_VER0_VAR_7(NAME) DEF_VER0_VAR_DECL_PTR(NAME)
+
+#define DEF_ATTRIBUTE(NAME, NUM_ARGS, ...) \
+ class Tosa##NAME##Attribute : public TosaAttributeBase \
+ { \
+ public: \
+ Tosa##NAME##Attribute(const TosaAttributeBase* options) \
+ { \
+ const Tosa##NAME##Attribute* p = reinterpret_cast<const Tosa##NAME##Attribute*>(options); \
+ *this = *p; \
+ } \
+ Tosa##NAME##Attribute(const Tosa##NAME##Attribute* p) \
+ { \
+ *this = *p; \
+ } \
+ Tosa##NAME##Attribute(const void* options){ DEF_VER0_VAR_##NUM_ARGS(NAME##Attribute) \
+ DEF_ARGS_##NUM_ARGS(VER0, __VA_ARGS__) } Tosa##NAME \
+ ##Attribute(DEF_ARGS_##NUM_ARGS(VER1, __VA_ARGS__)) \
+ { \
+ DEF_ARGS_##NUM_ARGS(VER2, __VA_ARGS__) \
+ } \
+ virtual ~Tosa##NAME##Attribute() \
+ {} \
+ DEF_ARGS_##NUM_ARGS(VER3, __VA_ARGS__) private : DEF_ARGS_##NUM_ARGS(VER4, __VA_ARGS__) \
+ };
+
+#include "attribute.def"
+#undef DEF_ATTRIBUTE
+#undef DEF_ARGS_0
+#undef DEF_ARGS_1
+#undef DEF_ARGS_2
+#undef DEF_ARGS_3
+#undef DEF_ARGS_4
+#undef DEF_ARGS_5
+#undef DEF_ARGS_6
+#undef DEF_ARGS_7
+#undef DEF_ARGS_VER0
+#undef DEF_ARGS_VER1
+#undef DEF_ARGS_VER2
+#undef DEF_ARGS_VER3
+#undef DEF_ARGS_VER4
+#undef DEF_ARGS_VER0_S_int32_t
+#undef DEF_ARGS_VER0_S_float
+#undef DEF_ARGS_VER0_S_bool
+#undef DEF_ARGS_VER0_S_ResizeMode
+#undef DEF_ARGS_VER0_S_string
+#undef DEF_ARGS_VER0_S_STR
+#undef DEF_ARGS_VER0_S_DEFAULT
+#undef DEF_ARGS_VER1_TRUE
+#undef DEF_ARGS_VER1_FALSE
+#undef DEF_ARGS_VER0_S
+#undef DEF_ARGS_VER0_V
+#undef DEF_ARGS_VER1_S
+#undef DEF_ARGS_VER1_V
+#undef DEF_ARGS_VER2_S
+#undef DEF_ARGS_VER2_V
+#undef DEF_ARGS_VER3_S
+#undef DEF_ARGS_VER3_V
+#undef DEF_ARGS_VER4_S
+#undef DEF_ARGS_VER4_V
+#undef DEF_VER0_VAR_0
+#undef DEF_VER0_VAR_1
+#undef DEF_VER0_VAR_2
+#undef DEF_VER0_VAR_3
+#undef DEF_VER0_VAR_4
+#undef DEF_VER0_VAR_5
+#undef DEF_VER0_VAR_DECL_PTR
+
+} // namespace tosa
+
+#endif // _TOSA_SERIALIZATION_ATTRIBUTE_H
diff --git a/serialization/operator.def b/serialization/operator.def
new file mode 100644
index 0000000..66d3784
--- /dev/null
+++ b/serialization/operator.def
@@ -0,0 +1,123 @@
+
+// Copyright (c) 2020, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+/*
+ Syntax:
+ DEF_OPERATOR(MLIR_NAME, SCHEMA_NAME, REF_IMPL_NAME, OPTIONS, QUANT_INFO)
+
+ Description:
+ MLIR_NAME: the symbolic string of this op, must match tosa_ops.td
+ SCHEMA_NAME: corresponding operator name, must match "enum Op" in serialization/tosa.fbs
+ REF_IMPL_NAME: name used internally in tosa reference implementation
+ OPTIONS: compile time constant options of this op, corresponding to operator_option.def
+ QUANT_INFO: quantization infomation of this op, corresponding to quant_info.def
+*/
+
+
+/* tensor operators */
+DEF_OPERATOR(argmax, ARGMAX, ArgMax, Axis, None)
+DEF_OPERATOR(avg_pool2d, AVG_POOL2D, AvgPool2d, Pool2d, Unary)
+DEF_OPERATOR(conv2d, CONV2D, Conv2d, Conv2d, Conv)
+DEF_OPERATOR(conv3d, CONV3D, Conv3d, None, None)
+DEF_OPERATOR(depthwise_conv2d, DEPTHWISE_CONV2D, DepthwiseConv2d, Conv2d, Conv)
+DEF_OPERATOR(fully_connected, FULLY_CONNECTED, FullyConnected, None, Conv)
+DEF_OPERATOR(matmul, MATMUL, MatMul, None, MatMul)
+DEF_OPERATOR(max_pool2d, MAX_POOL2D, MaxPool2d, Pool2d, None)
+DEF_OPERATOR(transpose_conv2d, TRANSPOSE_CONV2D, TransposeConv2d, TransposeConv2d, Conv)
+
+/* activation */
+DEF_OPERATOR(clamp, CLAMP, Clamp, Clamp, None)
+DEF_OPERATOR(reluN, RELUN, ReluN, ReluN, None)
+DEF_OPERATOR(sigmoid, SIGMOID, Sigmoid, None, None)
+DEF_OPERATOR(tanh, TANH, Tanh, None, None)
+
+/* elementwise - binary */
+DEF_OPERATOR(add, ADD, Add, None, None)
+DEF_OPERATOR(arithmetic_right_shift, ARITHMETIC_RIGHT_SHIFT, ArithmeticRightShift, None, None)
+DEF_OPERATOR(bitwise_and, BITWISE_AND, BitwiseAnd, None, None)
+DEF_OPERATOR(bitwise_or, BITWISE_OR, BitwiseOr, None, None)
+DEF_OPERATOR(bitwise_xor, BITWISE_XOR, BitwiseXor, None, None)
+DEF_OPERATOR(logical_and, LOGICAL_AND, LogicalAnd, None, None)
+DEF_OPERATOR(logical_left_shift, LOGICAL_LEFT_SHIFT, LogicalLeftShift, None, None)
+DEF_OPERATOR(logical_right_shift, LOGICAL_RIGHT_SHIFT, LogicalRightShift, None, None)
+DEF_OPERATOR(logical_or, LOGICAL_OR, LogicalOr, None, None)
+DEF_OPERATOR(logical_xor, LOGICAL_XOR, LogicalXor, None, None)
+DEF_OPERATOR(maximum, MAXIMUM, Maximum, None, None)
+DEF_OPERATOR(minimum, MINIMUM, Minimum, None, None)
+DEF_OPERATOR(mul, MUL, Mul, None, None)
+DEF_OPERATOR(pow, POW, Pow, None, None)
+DEF_OPERATOR(sub, SUB, Sub, None, None)
+DEF_OPERATOR(table, TABLE, Table, None, None)
+
+/* elementwise - unary */
+DEF_OPERATOR(abs, ABS, Abs, None, None)
+DEF_OPERATOR(bitwise_not, BITWISE_NOT, BitwiseNot, None, None)
+DEF_OPERATOR(ceil, CEIL, Ceil, None, None)
+DEF_OPERATOR(clz, CLZ, Clz, None, None)
+DEF_OPERATOR(exp, EXP, Exp, None, None)
+DEF_OPERATOR(floor, FLOOR, Floor, None, None)
+DEF_OPERATOR(log, LOG, Log, None, None)
+DEF_OPERATOR(logical_not, LOGICAL_NOT, LogicalNot, None, None)
+DEF_OPERATOR(negate, NEGATE, Negate, None, Unary)
+DEF_OPERATOR(reciprocal, RECIPROCAL, Reciprocal, None, None)
+DEF_OPERATOR(rsqrt, RSQRT, Rsqrt, None, None)
+
+/* elementwise - ternary */
+DEF_OPERATOR(select, SELECT, Select, None, None)
+
+/* logical */
+DEF_OPERATOR(equal, EQUAL, Equal, None, None)
+DEF_OPERATOR(greater, GREATER, Greater, None, None)
+DEF_OPERATOR(greater_equal, GREATER_EQUAL, GreaterEqual, None, None)
+
+/* reduction */
+DEF_OPERATOR(reduce_any, REDUCE_ANY, ReduceAny, Reduce, None)
+DEF_OPERATOR(reduce_all, REDUCE_ALL, ReduceAll, Reduce, None)
+DEF_OPERATOR(reduce_max, REDUCE_MAX, ReduceMax, Reduce, None)
+DEF_OPERATOR(reduce_min, REDUCE_MIN, ReduceMin, Reduce, None)
+DEF_OPERATOR(reduce_prod, REDUCE_PRODUCT, ReduceProduct, Reduce, None)
+DEF_OPERATOR(reduce_sum, REDUCE_SUM, ReduceSum, Reduce, None)
+
+/* memory operation */
+DEF_OPERATOR(concat, CONCAT, Concat, Axis, None)
+DEF_OPERATOR(pad, PAD, Pad, None, Pad)
+DEF_OPERATOR(reshape, RESHAPE, Reshape, Reshape, None)
+DEF_OPERATOR(reverse, REVERSE, Reverse, Reverse, None)
+DEF_OPERATOR(slice, SLICE, Slice, Slice, None)
+DEF_OPERATOR(tile, TILE, Tile, Tile, None)
+DEF_OPERATOR(transpose, TRANSPOSE, Transpose, None, None)
+
+/* gather/scatter */
+DEF_OPERATOR(gather, GATHER, Gather, Axis, None)
+
+/* image */
+DEF_OPERATOR(resize, RESIZE, Resize, Resize, None)
+
+/* quantization */
+DEF_OPERATOR(cast, CAST, Cast, None, None)
+DEF_OPERATOR(rescale, RESCALE, Rescale, Rescale, None)
+
+/* data nodes */
+DEF_OPERATOR(const, CONST, Const, None, None)
+DEF_OPERATOR(placeholder, PLACEHOLDER, Placeholder, None, None)
+DEF_OPERATOR(identity, IDENTITY, Identity, None, None)
+DEF_OPERATOR(identityn, IDENTITYN, IdentityN, None, None)
+
+/* custom operations */
+DEF_OPERATOR(custom, CUSTOM, Custom, None, None)
+
+/* control flow operators */
+DEF_OPERATOR(cond_if, COND_IF, CondIf, CondIf, None)
+DEF_OPERATOR(while_loop, WHILE_LOOP, WhileLoop, WhileLoop, None)
diff --git a/serialization/quant_info.def b/serialization/quant_info.def
new file mode 100644
index 0000000..39dc101
--- /dev/null
+++ b/serialization/quant_info.def
@@ -0,0 +1,43 @@
+
+// Copyright (c) 2020, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+/*
+ Syntax:
+ DEF_QUANTIZATION_INFO(NAME, NUM_ARGS_IN_OPTIONS, ARG0_TYPE, ARG0_SCALAR_OR_VECTOR, ARGS0_NAME, ...)
+
+ Description:
+ NAME: corresponding quantization info name, must match corresponding "table XXXQuantInfo" in tosa.fbs
+ NUM_ARGS_IN_QINFO: number of arguments in this quantization info
+ ARG0_TYPE: data type of arg0
+ ARG0_SCALAR_OR_VECTOR: is arg0 a scalar (S) or a vector (V)
+ ARG0_NAME: name of arg0
+ ...: variadic variables for more arguments, depending on NUM_ARGS_IN_QINFO
+*/
+
+
+DEF_QUANTIZATION_INFO(Unary, 2,
+ int32_t, S, input_zp,
+ int32_t, S, output_zp)
+
+DEF_QUANTIZATION_INFO(Conv, 2,
+ int32_t, S, input_zp,
+ int32_t, S, weight_zp)
+
+DEF_QUANTIZATION_INFO(MatMul, 2,
+ int32_t, S, a_zp,
+ int32_t, S, b_zp)
+
+DEF_QUANTIZATION_INFO(Pad, 1,
+ int32_t, S, input_zp)
diff --git a/serialization/quant_info.h b/serialization/quant_info.h
new file mode 100644
index 0000000..03dcab9
--- /dev/null
+++ b/serialization/quant_info.h
@@ -0,0 +1,164 @@
+
+// Copyright (c) 2020, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef _TOSA_SERIALIZATION_QUANT_INFO_H
+#define _TOSA_SERIALIZATION_QUANT_INFO_H
+#include "flatbuffers/idl.h"
+#include "flatbuffers/util.h"
+#include "tosa_generated.h"
+
+namespace tosa
+{
+
+class TosaQuantInfoBase
+{
+public:
+ virtual ~TosaQuantInfoBase()
+ {}
+};
+
+class TosaNoneQuantInfo : public TosaQuantInfoBase
+{
+public:
+ TosaNoneQuantInfo()
+ {}
+ TosaNoneQuantInfo(TosaNoneQuantInfo* p)
+ {}
+};
+
+#define DEF_ARGS_VER0_S(T, V) _##V = p->V();
+#define DEF_ARGS_VER0_V(T, V) _##V = std::vector<T>(p->V()->begin(), p->V()->end());
+#define DEF_ARGS_VER1_S(T, V) const T& V
+#define DEF_ARGS_VER1_V(T, V) const std::vector<T>& V
+#define DEF_ARGS_VER2_S(T, V) _##V = V;
+#define DEF_ARGS_VER2_V(T, V) _##V = V;
+#define DEF_ARGS_VER3_S(T, V) \
+ T V() const \
+ { \
+ return _##V; \
+ }
+#define DEF_ARGS_VER3_V(T, V) \
+ std::vector<T> V() const \
+ { \
+ return _##V; \
+ }
+#define DEF_ARGS_VER4_S(T, V) T _##V;
+#define DEF_ARGS_VER4_V(T, V) std::vector<T> _##V;
+
+// another level of preprocessor indirection to handle ", " as function's input argument
+#define DEF_ARGS_VER1_TRUE(T, F, V) DEF_ARGS_VER1_##F(T, V)
+#define DEF_ARGS_VER1_FALSE(T, F, V) , DEF_ARGS_VER1_##F(T, V)
+
+#define DEF_ARGS_VER0(FIRST, T, F, V) DEF_ARGS_VER0_##F(T, V)
+#define DEF_ARGS_VER1(FIRST, T, F, V) DEF_ARGS_VER1_##FIRST(T, F, V)
+#define DEF_ARGS_VER2(FIRST, T, F, V) DEF_ARGS_VER2_##F(T, V)
+#define DEF_ARGS_VER3(FIRST, T, F, V) DEF_ARGS_VER3_##F(T, V)
+#define DEF_ARGS_VER4(FIRST, T, F, V) DEF_ARGS_VER4_##F(T, V)
+
+#define DEF_ARGS_1(VER, T0, F0, V0) DEF_ARGS_##VER(TRUE, T0, F0, V0)
+#define DEF_ARGS_2(VER, T0, F0, V0, T1, F1, V1) DEF_ARGS_##VER(TRUE, T0, F0, V0) DEF_ARGS_##VER(FALSE, T1, F1, V1)
+#define DEF_ARGS_3(VER, T0, F0, V0, T1, F1, V1, T2, F2, V2) \
+ DEF_ARGS_##VER(TRUE, T0, F0, V0) DEF_ARGS_##VER(FALSE, T1, F1, V1) DEF_ARGS_##VER(FALSE, T2, F2, V2)
+#define DEF_ARGS_4(VER, T0, F0, V0, T1, F1, V1, T2, F2, V2, T3, F3, V3) \
+ DEF_ARGS_##VER(TRUE, T0, F0, V0) DEF_ARGS_##VER(FALSE, T1, F1, V1) DEF_ARGS_##VER(FALSE, T2, F2, V2) \
+ DEF_ARGS_##VER(FALSE, T3, F3, V3)
+#define DEF_ARGS_5(VER, T0, F0, V0, T1, F1, V1, T2, F2, V2, T3, F3, V3, T4, F4, V4) \
+ DEF_ARGS_##VER(TRUE, T0, F0, V0) DEF_ARGS_##VER(FALSE, T1, F1, V1) DEF_ARGS_##VER(FALSE, T2, F2, V2) \
+ DEF_ARGS_##VER(FALSE, T3, F3, V3) DEF_ARGS_##VER(FALSE, T4, F4, V4)
+#define DEF_ARGS_6(VER, T0, F0, V0, T1, F1, V1, T2, F2, V2, T3, F3, V3, T4, F4, V4, T5, F5, V5) \
+ DEF_ARGS_##VER(TRUE, T0, F0, V0) DEF_ARGS_##VER(FALSE, T1, F1, V1) DEF_ARGS_##VER(FALSE, T2, F2, V2) \
+ DEF_ARGS_##VER(FALSE, T3, F3, V3) DEF_ARGS_##VER(FALSE, T4, F4, V4) DEF_ARGS_##VER(FALSE, T5, F5, V5)
+#define DEF_ARGS_7(VER, T0, F0, V0, T1, F1, V1, T2, F2, V2, T3, F3, V3, T4, F4, V4, T5, F5, V5, T6, F6, V6) \
+ DEF_ARGS_##VER(TRUE, T0, F0, V0) DEF_ARGS_##VER(FALSE, T1, F1, V1) DEF_ARGS_##VER(FALSE, T2, F2, V2) \
+ DEF_ARGS_##VER(FALSE, T3, F3, V3) DEF_ARGS_##VER(FALSE, T4, F4, V4) DEF_ARGS_##VER(FALSE, T5, F5, V5) \
+ DEF_ARGS_##VER(FALSE, T6, F6, V6)
+#define DEF_ARGS_8(VER, T0, F0, V0, T1, F1, V1, T2, F2, V2, T3, F3, V3, T4, F4, V4, T5, F5, V5, T6, F6, V6, T7, F7, \
+ V7) \
+ DEF_ARGS_##VER(TRUE, T0, F0, V0) DEF_ARGS_##VER(FALSE, T1, F1, V1) DEF_ARGS_##VER(FALSE, T2, F2, V2) \
+ DEF_ARGS_##VER(FALSE, T3, F3, V3) DEF_ARGS_##VER(FALSE, T4, F4, V4) DEF_ARGS_##VER(FALSE, T5, F5, V5) \
+ DEF_ARGS_##VER(FALSE, T6, F6, V6) DEF_ARGS_##VER(FALSE, T7, F7, V7)
+#define DEF_ARGS_9(VER, T0, F0, V0, T1, F1, V1, T2, F2, V2, T3, F3, V3, T4, F4, V4, T5, F5, V5, T6, F6, V6, T7, F7, \
+ V7, T8, F8, V8) \
+ DEF_ARGS_##VER(TRUE, T0, F0, V0) DEF_ARGS_##VER(FALSE, T1, F1, V1) DEF_ARGS_##VER(FALSE, T2, F2, V2) \
+ DEF_ARGS_##VER(FALSE, T3, F3, V3) DEF_ARGS_##VER(FALSE, T4, F4, V4) DEF_ARGS_##VER(FALSE, T5, F5, V5) \
+ DEF_ARGS_##VER(FALSE, T6, F6, V6) DEF_ARGS_##VER(FALSE, T7, F7, V7) DEF_ARGS_##VER(FALSE, T8, F8, V8)
+#define DEF_ARGS_10(VER, T0, F0, V0, T1, F1, V1, T2, F2, V2, T3, F3, V3, T4, F4, V4, T5, F5, V5, T6, F6, V6, T7, F7, \
+ V7, T8, F8, V8, T9, F9, V9) \
+ DEF_ARGS_##VER(TRUE, T0, F0, V0) DEF_ARGS_##VER(FALSE, T1, F1, V1) DEF_ARGS_##VER(FALSE, T2, F2, V2) \
+ DEF_ARGS_##VER(FALSE, T3, F3, V3) DEF_ARGS_##VER(FALSE, T4, F4, V4) DEF_ARGS_##VER(FALSE, T5, F5, V5) \
+ DEF_ARGS_##VER(FALSE, T6, F6, V6) DEF_ARGS_##VER(FALSE, T7, F7, V7) DEF_ARGS_##VER(FALSE, T8, F8, V8) \
+ DEF_ARGS_##VER(FALSE, T9, F9, V9)
+
+#define DEF_QUANTIZATION_INFO(NAME, NUM_ARGS, ...) \
+ class Tosa##NAME##QuantInfo : public TosaQuantInfoBase \
+ { \
+ public: \
+ Tosa##NAME##QuantInfo(const TosaQuantInfoBase* qinfo) \
+ { \
+ const Tosa##NAME##QuantInfo* p = dynamic_cast<const Tosa##NAME##QuantInfo*>(qinfo); \
+ assert(p); \
+ *this = *p; \
+ } \
+ Tosa##NAME##QuantInfo(const Tosa##NAME##QuantInfo* p) \
+ { \
+ *this = *p; \
+ } \
+ Tosa##NAME##QuantInfo(const void* qinfo) \
+ { \
+ const NAME##QuantInfo* p = static_cast<const NAME##QuantInfo*>(qinfo); \
+ DEF_ARGS_##NUM_ARGS(VER0, __VA_ARGS__) \
+ } \
+ Tosa##NAME##QuantInfo(DEF_ARGS_##NUM_ARGS(VER1, __VA_ARGS__)) \
+ { \
+ DEF_ARGS_##NUM_ARGS(VER2, __VA_ARGS__) \
+ } \
+ virtual ~Tosa##NAME##QuantInfo() \
+ {} \
+ DEF_ARGS_##NUM_ARGS(VER3, __VA_ARGS__) private : DEF_ARGS_##NUM_ARGS(VER4, __VA_ARGS__) \
+ };
+
+#include "quant_info.def"
+#undef DEF_QUANTIZATION_INFO
+#undef DEF_ARGS_1
+#undef DEF_ARGS_2
+#undef DEF_ARGS_3
+#undef DEF_ARGS_4
+#undef DEF_ARGS_5
+#undef DEF_ARGS_6
+#undef DEF_ARGS_7
+#undef DEF_ARGS_8
+#undef DEF_ARGS_9
+#undef DEF_ARGS_10
+#undef DEF_ARGS_VER0
+#undef DEF_ARGS_VER1
+#undef DEF_ARGS_VER2
+#undef DEF_ARGS_VER3
+#undef DEF_ARGS_VER4
+#undef DEF_ARGS_VER1_TRUE
+#undef DEF_ARGS_VER1_FALSE
+#undef DEF_ARGS_VER0_S
+#undef DEF_ARGS_VER0_V
+#undef DEF_ARGS_VER1_S
+#undef DEF_ARGS_VER1_V
+#undef DEF_ARGS_VER2_S
+#undef DEF_ARGS_VER2_V
+#undef DEF_ARGS_VER3_S
+#undef DEF_ARGS_VER3_V
+#undef DEF_ARGS_VER4_S
+#undef DEF_ARGS_VER4_V
+
+} // namespace tosa
+
+#endif
diff --git a/serialization/tosa.fbs b/serialization/tosa.fbs
new file mode 100644
index 0000000..841cf3d
--- /dev/null
+++ b/serialization/tosa.fbs
@@ -0,0 +1,318 @@
+
+// Copyright (c) 2020, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+namespace tosa;
+
+// This corresponds to the version.
+file_identifier "TOSA";
+// File extension of any written files.
+file_extension "tosa";
+
+enum DType:uint32 {
+ UNKNOWN = 0,
+ BOOL,
+ AINT8,
+ UINT8,
+ INT4,
+ INT8,
+ INT16,
+ INT32,
+ INT48,
+ FLOAT,
+}
+
+enum Format:uint32 {
+ UNKNOWN = 0,
+ NHWC,
+ NDHWC,
+ OHWI,
+ HWIM,
+ DOHWI,
+}
+
+enum Usage:uint32 {
+ UNKNOWN = 0,
+ ACTIVATION,
+ WEIGHT,
+ INDEX,
+}
+
+enum ResizeMode:uint32 {
+ UNKNOWN = 0,
+ NEAREST,
+ BILINEAR,
+}
+
+enum Op:uint32 {
+ UNKNOWN = 0,
+
+ // Tensor Operator
+ ARGMAX,
+ AVG_POOL2D,
+ CONV2D,
+ CONV3D,
+ DEPTHWISE_CONV2D,
+ FULLY_CONNECTED,
+ MATMUL,
+ MAX_POOL2D,
+ TRANSPOSE_CONV2D,
+
+ // Activation
+ CLAMP,
+ RELUN,
+ SIGMOID,
+ TANH,
+
+ // Elementwise-Binary
+ ADD,
+ ARITHMETIC_RIGHT_SHIFT,
+ BITWISE_AND,
+ BITWISE_OR,
+ BITWISE_XOR,
+ LOGICAL_AND,
+ LOGICAL_LEFT_SHIFT,
+ LOGICAL_RIGHT_SHIFT,
+ LOGICAL_OR,
+ LOGICAL_XOR,
+ MAXIMUM,
+ MINIMUM,
+ MUL,
+ POW,
+ SUB,
+ TABLE,
+
+ // Elementwise-Unary
+ ABS,
+ BITWISE_NOT,
+ CEIL,
+ CLZ,
+ EXP,
+ FLOOR,
+ LOG,
+ LOGICAL_NOT,
+ NEGATE,
+ RECIPROCAL,
+ RSQRT,
+
+ // Elementwise-Ternary
+ SELECT,
+
+ // Logical
+ EQUAL,
+ GREATER,
+ GREATER_EQUAL,
+
+ // Reduction
+ REDUCE_ANY,
+ REDUCE_ALL,
+ REDUCE_MAX,
+ REDUCE_MIN,
+ REDUCE_PRODUCT,
+ REDUCE_SUM,
+
+ // Data layout operation
+ CONCAT,
+ PAD,
+ RESHAPE,
+ REVERSE,
+ SLICE,
+ TILE,
+ TRANSPOSE,
+
+ // Gather/scatter operation
+ GATHER,
+
+ // Image
+ RESIZE,
+
+ // Type conversion
+ CAST,
+ RESCALE,
+
+ // Data Nodes
+ CONST,
+ PLACEHOLDER,
+ IDENTITY,
+ IDENTITYN,
+
+ // Custom operations
+ CUSTOM,
+
+ // Control flow operators
+ COND_IF,
+ WHILE_LOOP,
+}
+
+union Attribute {
+ Pool2dAttribute,
+ Conv2dAttribute,
+ TransposeConv2dAttribute,
+ ReluNAttribute,
+ AxisAttribute,
+ ReshapeAttribute,
+ SliceAttribute,
+ TileAttribute,
+ ResizeAttribute,
+ ClampAttribute,
+ RescaleAttribute,
+ CustomAttribute,
+ CondIfAttribute,
+ WhileLoopAttribute,
+}
+
+table Pool2dAttribute {
+ padding: [int32];
+ kernel: [int32];
+ stride: [int32];
+}
+
+table Conv2dAttribute {
+ padding: [int32];
+ stride: [int32];
+ dilation: [int32];
+}
+
+table TransposeConv2dAttribute {
+ outpad: [int32];
+ stride: [int32];
+ dilation: [int32];
+ output_shape: [int32];
+}
+
+table ReluNAttribute {
+ max_int: int32;
+ max_fp: float;
+}
+
+table AxisAttribute {
+ axis: int32;
+}
+
+table ReshapeAttribute {
+ shape: [int32];
+}
+
+table SliceAttribute {
+ begin: [int32];
+ size: [int32];
+}
+
+table TileAttribute {
+ multiples: [int32];
+}
+
+table ResizeAttribute {
+ output_size: [int32];
+ stride: [int32];
+ offset: [int32];
+ shift: int32;
+ mode: ResizeMode;
+}
+
+table ClampAttribute {
+ min_int: int32;
+ max_int: int32;
+ min_fp: float;
+ max_fp: float;
+}
+
+table RescaleAttribute {
+ input_zp: int32;
+ output_zp: int32;
+ multiplier: [int32];
+ shift: [int32];
+ scale32: bool;
+ double_round: bool;
+ per_channel: bool;
+}
+
+table CustomAttribute {
+ identifier: string;
+}
+
+table CondIfAttribute {
+ then_branch: string;
+ else_branch: string;
+}
+
+table WhileLoopAttribute {
+ cond_branch: string;
+ body_branch: string;
+}
+
+union QuantInfo {
+ UnaryQuantInfo,
+ ConvQuantInfo,
+ MatMulQuantInfo,
+ PadQuantInfo,
+}
+
+table UnaryQuantInfo {
+ input_zp: int32;
+ output_zp: int32;
+}
+
+table ConvQuantInfo {
+ input_zp: int32;
+ weight_zp: int32;
+}
+
+table MatMulQuantInfo {
+ a_zp: int32;
+ b_zp: int32;
+}
+
+table PadQuantInfo {
+ input_zp: int32;
+}
+
+table Version {
+ _major: int32 = 0;
+ _minor: int32 = 20;
+ _patch: int32 = 0;
+ _experimental: bool = false;
+}
+
+table TosaTensor {
+ name:string; // name of the tensor, used for solving dependency
+ shape:[int32]; // shape of the tensor
+ type:DType; // data type of the tensor
+ usage:[Usage]; // vector of possible usages. for the convenience of debugging only.
+ format:[Format]; // vector of possible formats. for the convenience of debugging only.
+ npy_filename: string; // numpy array filename
+}
+
+table TosaOperator {
+ op:Op; // operator enum
+ attribute: Attribute; // union structure. operator attribute
+ inputs:[string]; // list of input tensor names
+ outputs:[string]; // list of output tensor names
+ quant_info: QuantInfo; // op-based quantization information
+}
+
+table TosaBasicBlock {
+ name:string; // basic block name
+ operators:[TosaOperator]; // operators array
+ tensors:[TosaTensor]; // tensors array
+ inputs:[string]; // name of graph inputs
+ outputs:[string]; // name of graph outputs
+}
+
+table TosaGraph {
+ version: Version;
+ blocks:[TosaBasicBlock]; // basic blocks array
+}
+
+root_type TosaGraph;
diff --git a/serialization/tosa_generated.h b/serialization/tosa_generated.h
new file mode 100644
index 0000000..5bb21f3
--- /dev/null
+++ b/serialization/tosa_generated.h
@@ -0,0 +1,2605 @@
+
+// Copyright (c) 2020, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// automatically generated by the FlatBuffers compiler, do not modify
+
+
+#ifndef FLATBUFFERS_GENERATED_TOSA_TOSA_H_
+#define FLATBUFFERS_GENERATED_TOSA_TOSA_H_
+
+#include "flatbuffers/flatbuffers.h"
+
+namespace tosa {
+
+struct Pool2dAttribute;
+
+struct Conv2dAttribute;
+
+struct TransposeConv2dAttribute;
+
+struct ReluNAttribute;
+
+struct AxisAttribute;
+
+struct ReshapeAttribute;
+
+struct SliceAttribute;
+
+struct TileAttribute;
+
+struct ResizeAttribute;
+
+struct ClampAttribute;
+
+struct RescaleAttribute;
+
+struct CustomAttribute;
+
+struct CondIfAttribute;
+
+struct WhileLoopAttribute;
+
+struct UnaryQuantInfo;
+
+struct ConvQuantInfo;
+
+struct MatMulQuantInfo;
+
+struct PadQuantInfo;
+
+struct Version;
+
+struct TosaTensor;
+
+struct TosaOperator;
+
+struct TosaBasicBlock;
+
+struct TosaGraph;
+
+enum DType {
+ DType_UNKNOWN = 0,
+ DType_BOOL = 1,
+ DType_AINT8 = 2,
+ DType_UINT8 = 3,
+ DType_INT4 = 4,
+ DType_INT8 = 5,
+ DType_INT16 = 6,
+ DType_INT32 = 7,
+ DType_INT48 = 8,
+ DType_FLOAT = 9,
+ DType_MIN = DType_UNKNOWN,
+ DType_MAX = DType_FLOAT
+};
+
+inline const DType (&EnumValuesDType())[10] {
+ static const DType values[] = {
+ DType_UNKNOWN,
+ DType_BOOL,
+ DType_AINT8,
+ DType_UINT8,
+ DType_INT4,
+ DType_INT8,
+ DType_INT16,
+ DType_INT32,
+ DType_INT48,
+ DType_FLOAT
+ };
+ return values;
+}
+
+inline const char * const *EnumNamesDType() {
+ static const char * const names[] = {
+ "UNKNOWN",
+ "BOOL",
+ "AINT8",
+ "UINT8",
+ "INT4",
+ "INT8",
+ "INT16",
+ "INT32",
+ "INT48",
+ "FLOAT",
+ nullptr
+ };
+ return names;
+}
+
+inline const char *EnumNameDType(DType e) {
+ if (e < DType_UNKNOWN || e > DType_FLOAT) return "";
+ const size_t index = static_cast<size_t>(e);
+ return EnumNamesDType()[index];
+}
+
+enum Format {
+ Format_UNKNOWN = 0,
+ Format_NHWC = 1,
+ Format_NDHWC = 2,
+ Format_OHWI = 3,
+ Format_HWIM = 4,
+ Format_DOHWI = 5,
+ Format_MIN = Format_UNKNOWN,
+ Format_MAX = Format_DOHWI
+};
+
+inline const Format (&EnumValuesFormat())[6] {
+ static const Format values[] = {
+ Format_UNKNOWN,
+ Format_NHWC,
+ Format_NDHWC,
+ Format_OHWI,
+ Format_HWIM,
+ Format_DOHWI
+ };
+ return values;
+}
+
+inline const char * const *EnumNamesFormat() {
+ static const char * const names[] = {
+ "UNKNOWN",
+ "NHWC",
+ "NDHWC",
+ "OHWI",
+ "HWIM",
+ "DOHWI",
+ nullptr
+ };
+ return names;
+}
+
+inline const char *EnumNameFormat(Format e) {
+ if (e < Format_UNKNOWN || e > Format_DOHWI) return "";
+ const size_t index = static_cast<size_t>(e);
+ return EnumNamesFormat()[index];
+}
+
+enum Usage {
+ Usage_UNKNOWN = 0,
+ Usage_ACTIVATION = 1,
+ Usage_WEIGHT = 2,
+ Usage_INDEX = 3,
+ Usage_MIN = Usage_UNKNOWN,
+ Usage_MAX = Usage_INDEX
+};
+
+inline const Usage (&EnumValuesUsage())[4] {
+ static const Usage values[] = {
+ Usage_UNKNOWN,
+ Usage_ACTIVATION,
+ Usage_WEIGHT,
+ Usage_INDEX
+ };
+ return values;
+}
+
+inline const char * const *EnumNamesUsage() {
+ static const char * const names[] = {
+ "UNKNOWN",
+ "ACTIVATION",
+ "WEIGHT",
+ "INDEX",
+ nullptr
+ };
+ return names;
+}
+
+inline const char *EnumNameUsage(Usage e) {
+ if (e < Usage_UNKNOWN || e > Usage_INDEX) return "";
+ const size_t index = static_cast<size_t>(e);
+ return EnumNamesUsage()[index];
+}
+
+enum ResizeMode {
+ ResizeMode_UNKNOWN = 0,
+ ResizeMode_NEAREST = 1,
+ ResizeMode_BILINEAR = 2,
+ ResizeMode_MIN = ResizeMode_UNKNOWN,
+ ResizeMode_MAX = ResizeMode_BILINEAR
+};
+
+inline const ResizeMode (&EnumValuesResizeMode())[3] {
+ static const ResizeMode values[] = {
+ ResizeMode_UNKNOWN,
+ ResizeMode_NEAREST,
+ ResizeMode_BILINEAR
+ };
+ return values;
+}
+
+inline const char * const *EnumNamesResizeMode() {
+ static const char * const names[] = {
+ "UNKNOWN",
+ "NEAREST",
+ "BILINEAR",
+ nullptr
+ };
+ return names;
+}
+
+inline const char *EnumNameResizeMode(ResizeMode e) {
+ if (e < ResizeMode_UNKNOWN || e > ResizeMode_BILINEAR) return "";
+ const size_t index = static_cast<size_t>(e);
+ return EnumNamesResizeMode()[index];
+}
+
+enum Op {
+ Op_UNKNOWN = 0,
+ Op_ARGMAX = 1,
+ Op_AVG_POOL2D = 2,
+ Op_CONV2D = 3,
+ Op_CONV3D = 4,
+ Op_DEPTHWISE_CONV2D = 5,
+ Op_FULLY_CONNECTED = 6,
+ Op_MATMUL = 7,
+ Op_MAX_POOL2D = 8,
+ Op_TRANSPOSE_CONV2D = 9,
+ Op_CLAMP = 10,
+ Op_RELUN = 11,
+ Op_SIGMOID = 12,
+ Op_TANH = 13,
+ Op_ADD = 14,
+ Op_ARITHMETIC_RIGHT_SHIFT = 15,
+ Op_BITWISE_AND = 16,
+ Op_BITWISE_OR = 17,
+ Op_BITWISE_XOR = 18,
+ Op_LOGICAL_AND = 19,
+ Op_LOGICAL_LEFT_SHIFT = 20,
+ Op_LOGICAL_RIGHT_SHIFT = 21,
+ Op_LOGICAL_OR = 22,
+ Op_LOGICAL_XOR = 23,
+ Op_MAXIMUM = 24,
+ Op_MINIMUM = 25,
+ Op_MUL = 26,
+ Op_POW = 27,
+ Op_SUB = 28,
+ Op_TABLE = 29,
+ Op_ABS = 30,
+ Op_BITWISE_NOT = 31,
+ Op_CEIL = 32,
+ Op_CLZ = 33,
+ Op_EXP = 34,
+ Op_FLOOR = 35,
+ Op_LOG = 36,
+ Op_LOGICAL_NOT = 37,
+ Op_NEGATE = 38,
+ Op_RECIPROCAL = 39,
+ Op_RSQRT = 40,
+ Op_SELECT = 41,
+ Op_EQUAL = 42,
+ Op_GREATER = 43,
+ Op_GREATER_EQUAL = 44,
+ Op_REDUCE_ANY = 45,
+ Op_REDUCE_ALL = 46,
+ Op_REDUCE_MAX = 47,
+ Op_REDUCE_MIN = 48,
+ Op_REDUCE_PRODUCT = 49,
+ Op_REDUCE_SUM = 50,
+ Op_CONCAT = 51,
+ Op_PAD = 52,
+ Op_RESHAPE = 53,
+ Op_REVERSE = 54,
+ Op_SLICE = 55,
+ Op_TILE = 56,
+ Op_TRANSPOSE = 57,
+ Op_GATHER = 58,
+ Op_RESIZE = 59,
+ Op_CAST = 60,
+ Op_RESCALE = 61,
+ Op_CONST = 62,
+ Op_PLACEHOLDER = 63,
+ Op_IDENTITY = 64,
+ Op_IDENTITYN = 65,
+ Op_CUSTOM = 66,
+ Op_COND_IF = 67,
+ Op_WHILE_LOOP = 68,
+ Op_MIN = Op_UNKNOWN,
+ Op_MAX = Op_WHILE_LOOP
+};
+
+inline const Op (&EnumValuesOp())[69] {
+ static const Op values[] = {
+ Op_UNKNOWN,
+ Op_ARGMAX,
+ Op_AVG_POOL2D,
+ Op_CONV2D,
+ Op_CONV3D,
+ Op_DEPTHWISE_CONV2D,
+ Op_FULLY_CONNECTED,
+ Op_MATMUL,
+ Op_MAX_POOL2D,
+ Op_TRANSPOSE_CONV2D,
+ Op_CLAMP,
+ Op_RELUN,
+ Op_SIGMOID,
+ Op_TANH,
+ Op_ADD,
+ Op_ARITHMETIC_RIGHT_SHIFT,
+ Op_BITWISE_AND,
+ Op_BITWISE_OR,
+ Op_BITWISE_XOR,
+ Op_LOGICAL_AND,
+ Op_LOGICAL_LEFT_SHIFT,
+ Op_LOGICAL_RIGHT_SHIFT,
+ Op_LOGICAL_OR,
+ Op_LOGICAL_XOR,
+ Op_MAXIMUM,
+ Op_MINIMUM,
+ Op_MUL,
+ Op_POW,
+ Op_SUB,
+ Op_TABLE,
+ Op_ABS,
+ Op_BITWISE_NOT,
+ Op_CEIL,
+ Op_CLZ,
+ Op_EXP,
+ Op_FLOOR,
+ Op_LOG,
+ Op_LOGICAL_NOT,
+ Op_NEGATE,
+ Op_RECIPROCAL,
+ Op_RSQRT,
+ Op_SELECT,
+ Op_EQUAL,
+ Op_GREATER,
+ Op_GREATER_EQUAL,
+ Op_REDUCE_ANY,
+ Op_REDUCE_ALL,
+ Op_REDUCE_MAX,
+ Op_REDUCE_MIN,
+ Op_REDUCE_PRODUCT,
+ Op_REDUCE_SUM,
+ Op_CONCAT,
+ Op_PAD,
+ Op_RESHAPE,
+ Op_REVERSE,
+ Op_SLICE,
+ Op_TILE,
+ Op_TRANSPOSE,
+ Op_GATHER,
+ Op_RESIZE,
+ Op_CAST,
+ Op_RESCALE,
+ Op_CONST,
+ Op_PLACEHOLDER,
+ Op_IDENTITY,
+ Op_IDENTITYN,
+ Op_CUSTOM,
+ Op_COND_IF,
+ Op_WHILE_LOOP
+ };
+ return values;
+}
+
+inline const char * const *EnumNamesOp() {
+ static const char * const names[] = {
+ "UNKNOWN",
+ "ARGMAX",
+ "AVG_POOL2D",
+ "CONV2D",
+ "CONV3D",
+ "DEPTHWISE_CONV2D",
+ "FULLY_CONNECTED",
+ "MATMUL",
+ "MAX_POOL2D",
+ "TRANSPOSE_CONV2D",
+ "CLAMP",
+ "RELUN",
+ "SIGMOID",
+ "TANH",
+ "ADD",
+ "ARITHMETIC_RIGHT_SHIFT",
+ "BITWISE_AND",
+ "BITWISE_OR",
+ "BITWISE_XOR",
+ "LOGICAL_AND",
+ "LOGICAL_LEFT_SHIFT",
+ "LOGICAL_RIGHT_SHIFT",
+ "LOGICAL_OR",
+ "LOGICAL_XOR",
+ "MAXIMUM",
+ "MINIMUM",
+ "MUL",
+ "POW",
+ "SUB",
+ "TABLE",
+ "ABS",
+ "BITWISE_NOT",
+ "CEIL",
+ "CLZ",
+ "EXP",
+ "FLOOR",
+ "LOG",
+ "LOGICAL_NOT",
+ "NEGATE",
+ "RECIPROCAL",
+ "RSQRT",
+ "SELECT",
+ "EQUAL",
+ "GREATER",
+ "GREATER_EQUAL",
+ "REDUCE_ANY",
+ "REDUCE_ALL",
+ "REDUCE_MAX",
+ "REDUCE_MIN",
+ "REDUCE_PRODUCT",
+ "REDUCE_SUM",
+ "CONCAT",
+ "PAD",
+ "RESHAPE",
+ "REVERSE",
+ "SLICE",
+ "TILE",
+ "TRANSPOSE",
+ "GATHER",
+ "RESIZE",
+ "CAST",
+ "RESCALE",
+ "CONST",
+ "PLACEHOLDER",
+ "IDENTITY",
+ "IDENTITYN",
+ "CUSTOM",
+ "COND_IF",
+ "WHILE_LOOP",
+ nullptr
+ };
+ return names;
+}
+
+inline const char *EnumNameOp(Op e) {
+ if (e < Op_UNKNOWN || e > Op_WHILE_LOOP) return "";
+ const size_t index = static_cast<size_t>(e);
+ return EnumNamesOp()[index];
+}
+
+enum Attribute {
+ Attribute_NONE = 0,
+ Attribute_Pool2dAttribute = 1,
+ Attribute_Conv2dAttribute = 2,
+ Attribute_TransposeConv2dAttribute = 3,
+ Attribute_ReluNAttribute = 4,
+ Attribute_AxisAttribute = 5,
+ Attribute_ReshapeAttribute = 6,
+ Attribute_SliceAttribute = 7,
+ Attribute_TileAttribute = 8,
+ Attribute_ResizeAttribute = 9,
+ Attribute_ClampAttribute = 10,
+ Attribute_RescaleAttribute = 11,
+ Attribute_CustomAttribute = 12,
+ Attribute_CondIfAttribute = 13,
+ Attribute_WhileLoopAttribute = 14,
+ Attribute_MIN = Attribute_NONE,
+ Attribute_MAX = Attribute_WhileLoopAttribute
+};
+
+inline const Attribute (&EnumValuesAttribute())[15] {
+ static const Attribute values[] = {
+ Attribute_NONE,
+ Attribute_Pool2dAttribute,
+ Attribute_Conv2dAttribute,
+ Attribute_TransposeConv2dAttribute,
+ Attribute_ReluNAttribute,
+ Attribute_AxisAttribute,
+ Attribute_ReshapeAttribute,
+ Attribute_SliceAttribute,
+ Attribute_TileAttribute,
+ Attribute_ResizeAttribute,
+ Attribute_ClampAttribute,
+ Attribute_RescaleAttribute,
+ Attribute_CustomAttribute,
+ Attribute_CondIfAttribute,
+ Attribute_WhileLoopAttribute
+ };
+ return values;
+}
+
+inline const char * const *EnumNamesAttribute() {
+ static const char * const names[] = {
+ "NONE",
+ "Pool2dAttribute",
+ "Conv2dAttribute",
+ "TransposeConv2dAttribute",
+ "ReluNAttribute",
+ "AxisAttribute",
+ "ReshapeAttribute",
+ "SliceAttribute",
+ "TileAttribute",
+ "ResizeAttribute",
+ "ClampAttribute",
+ "RescaleAttribute",
+ "CustomAttribute",
+ "CondIfAttribute",
+ "WhileLoopAttribute",
+ nullptr
+ };
+ return names;
+}
+
+inline const char *EnumNameAttribute(Attribute e) {
+ if (e < Attribute_NONE || e > Attribute_WhileLoopAttribute) return "";
+ const size_t index = static_cast<size_t>(e);
+ return EnumNamesAttribute()[index];
+}
+
+template<typename T> struct AttributeTraits {
+ static const Attribute enum_value = Attribute_NONE;
+};
+
+template<> struct AttributeTraits<Pool2dAttribute> {
+ static const Attribute enum_value = Attribute_Pool2dAttribute;
+};
+
+template<> struct AttributeTraits<Conv2dAttribute> {
+ static const Attribute enum_value = Attribute_Conv2dAttribute;
+};
+
+template<> struct AttributeTraits<TransposeConv2dAttribute> {
+ static const Attribute enum_value = Attribute_TransposeConv2dAttribute;
+};
+
+template<> struct AttributeTraits<ReluNAttribute> {
+ static const Attribute enum_value = Attribute_ReluNAttribute;
+};
+
+template<> struct AttributeTraits<AxisAttribute> {
+ static const Attribute enum_value = Attribute_AxisAttribute;
+};
+
+template<> struct AttributeTraits<ReshapeAttribute> {
+ static const Attribute enum_value = Attribute_ReshapeAttribute;
+};
+
+template<> struct AttributeTraits<SliceAttribute> {
+ static const Attribute enum_value = Attribute_SliceAttribute;
+};
+
+template<> struct AttributeTraits<TileAttribute> {
+ static const Attribute enum_value = Attribute_TileAttribute;
+};
+
+template<> struct AttributeTraits<ResizeAttribute> {
+ static const Attribute enum_value = Attribute_ResizeAttribute;
+};
+
+template<> struct AttributeTraits<ClampAttribute> {
+ static const Attribute enum_value = Attribute_ClampAttribute;
+};
+
+template<> struct AttributeTraits<RescaleAttribute> {
+ static const Attribute enum_value = Attribute_RescaleAttribute;
+};
+
+template<> struct AttributeTraits<CustomAttribute> {
+ static const Attribute enum_value = Attribute_CustomAttribute;
+};
+
+template<> struct AttributeTraits<CondIfAttribute> {
+ static const Attribute enum_value = Attribute_CondIfAttribute;
+};
+
+template<> struct AttributeTraits<WhileLoopAttribute> {
+ static const Attribute enum_value = Attribute_WhileLoopAttribute;
+};
+
+bool VerifyAttribute(flatbuffers::Verifier &verifier, const void *obj, Attribute type);
+bool VerifyAttributeVector(flatbuffers::Verifier &verifier, const flatbuffers::Vector<flatbuffers::Offset<void>> *values, const flatbuffers::Vector<uint8_t> *types);
+
+enum QuantInfo {
+ QuantInfo_NONE = 0,
+ QuantInfo_UnaryQuantInfo = 1,
+ QuantInfo_ConvQuantInfo = 2,
+ QuantInfo_MatMulQuantInfo = 3,
+ QuantInfo_PadQuantInfo = 4,
+ QuantInfo_MIN = QuantInfo_NONE,
+ QuantInfo_MAX = QuantInfo_PadQuantInfo
+};
+
+inline const QuantInfo (&EnumValuesQuantInfo())[5] {
+ static const QuantInfo values[] = {
+ QuantInfo_NONE,
+ QuantInfo_UnaryQuantInfo,
+ QuantInfo_ConvQuantInfo,
+ QuantInfo_MatMulQuantInfo,
+ QuantInfo_PadQuantInfo
+ };
+ return values;
+}
+
+inline const char * const *EnumNamesQuantInfo() {
+ static const char * const names[] = {
+ "NONE",
+ "UnaryQuantInfo",
+ "ConvQuantInfo",
+ "MatMulQuantInfo",
+ "PadQuantInfo",
+ nullptr
+ };
+ return names;
+}
+
+inline const char *EnumNameQuantInfo(QuantInfo e) {
+ if (e < QuantInfo_NONE || e > QuantInfo_PadQuantInfo) return "";
+ const size_t index = static_cast<size_t>(e);
+ return EnumNamesQuantInfo()[index];
+}
+
+template<typename T> struct QuantInfoTraits {
+ static const QuantInfo enum_value = QuantInfo_NONE;
+};
+
+template<> struct QuantInfoTraits<UnaryQuantInfo> {
+ static const QuantInfo enum_value = QuantInfo_UnaryQuantInfo;
+};
+
+template<> struct QuantInfoTraits<ConvQuantInfo> {
+ static const QuantInfo enum_value = QuantInfo_ConvQuantInfo;
+};
+
+template<> struct QuantInfoTraits<MatMulQuantInfo> {
+ static const QuantInfo enum_value = QuantInfo_MatMulQuantInfo;
+};
+
+template<> struct QuantInfoTraits<PadQuantInfo> {
+ static const QuantInfo enum_value = QuantInfo_PadQuantInfo;
+};
+
+bool VerifyQuantInfo(flatbuffers::Verifier &verifier, const void *obj, QuantInfo type);
+bool VerifyQuantInfoVector(flatbuffers::Verifier &verifier, const flatbuffers::Vector<flatbuffers::Offset<void>> *values, const flatbuffers::Vector<uint8_t> *types);
+
+struct Pool2dAttribute FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_PADDING = 4,
+ VT_KERNEL = 6,
+ VT_STRIDE = 8
+ };
+ const flatbuffers::Vector<int32_t> *padding() const {
+ return GetPointer<const flatbuffers::Vector<int32_t> *>(VT_PADDING);
+ }
+ const flatbuffers::Vector<int32_t> *kernel() const {
+ return GetPointer<const flatbuffers::Vector<int32_t> *>(VT_KERNEL);
+ }
+ const flatbuffers::Vector<int32_t> *stride() const {
+ return GetPointer<const flatbuffers::Vector<int32_t> *>(VT_STRIDE);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyOffset(verifier, VT_PADDING) &&
+ verifier.VerifyVector(padding()) &&
+ VerifyOffset(verifier, VT_KERNEL) &&
+ verifier.VerifyVector(kernel()) &&
+ VerifyOffset(verifier, VT_STRIDE) &&
+ verifier.VerifyVector(stride()) &&
+ verifier.EndTable();
+ }
+};
+
+struct Pool2dAttributeBuilder {
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_padding(flatbuffers::Offset<flatbuffers::Vector<int32_t>> padding) {
+ fbb_.AddOffset(Pool2dAttribute::VT_PADDING, padding);
+ }
+ void add_kernel(flatbuffers::Offset<flatbuffers::Vector<int32_t>> kernel) {
+ fbb_.AddOffset(Pool2dAttribute::VT_KERNEL, kernel);
+ }
+ void add_stride(flatbuffers::Offset<flatbuffers::Vector<int32_t>> stride) {
+ fbb_.AddOffset(Pool2dAttribute::VT_STRIDE, stride);
+ }
+ explicit Pool2dAttributeBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ Pool2dAttributeBuilder &operator=(const Pool2dAttributeBuilder &);
+ flatbuffers::Offset<Pool2dAttribute> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<Pool2dAttribute>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<Pool2dAttribute> CreatePool2dAttribute(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ flatbuffers::Offset<flatbuffers::Vector<int32_t>> padding = 0,
+ flatbuffers::Offset<flatbuffers::Vector<int32_t>> kernel = 0,
+ flatbuffers::Offset<flatbuffers::Vector<int32_t>> stride = 0) {
+ Pool2dAttributeBuilder builder_(_fbb);
+ builder_.add_stride(stride);
+ builder_.add_kernel(kernel);
+ builder_.add_padding(padding);
+ return builder_.Finish();
+}
+
+inline flatbuffers::Offset<Pool2dAttribute> CreatePool2dAttributeDirect(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ const std::vector<int32_t> *padding = nullptr,
+ const std::vector<int32_t> *kernel = nullptr,
+ const std::vector<int32_t> *stride = nullptr) {
+ auto padding__ = padding ? _fbb.CreateVector<int32_t>(*padding) : 0;
+ auto kernel__ = kernel ? _fbb.CreateVector<int32_t>(*kernel) : 0;
+ auto stride__ = stride ? _fbb.CreateVector<int32_t>(*stride) : 0;
+ return tosa::CreatePool2dAttribute(
+ _fbb,
+ padding__,
+ kernel__,
+ stride__);
+}
+
+struct Conv2dAttribute FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_PADDING = 4,
+ VT_STRIDE = 6,
+ VT_DILATION = 8
+ };
+ const flatbuffers::Vector<int32_t> *padding() const {
+ return GetPointer<const flatbuffers::Vector<int32_t> *>(VT_PADDING);
+ }
+ const flatbuffers::Vector<int32_t> *stride() const {
+ return GetPointer<const flatbuffers::Vector<int32_t> *>(VT_STRIDE);
+ }
+ const flatbuffers::Vector<int32_t> *dilation() const {
+ return GetPointer<const flatbuffers::Vector<int32_t> *>(VT_DILATION);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyOffset(verifier, VT_PADDING) &&
+ verifier.VerifyVector(padding()) &&
+ VerifyOffset(verifier, VT_STRIDE) &&
+ verifier.VerifyVector(stride()) &&
+ VerifyOffset(verifier, VT_DILATION) &&
+ verifier.VerifyVector(dilation()) &&
+ verifier.EndTable();
+ }
+};
+
+struct Conv2dAttributeBuilder {
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_padding(flatbuffers::Offset<flatbuffers::Vector<int32_t>> padding) {
+ fbb_.AddOffset(Conv2dAttribute::VT_PADDING, padding);
+ }
+ void add_stride(flatbuffers::Offset<flatbuffers::Vector<int32_t>> stride) {
+ fbb_.AddOffset(Conv2dAttribute::VT_STRIDE, stride);
+ }
+ void add_dilation(flatbuffers::Offset<flatbuffers::Vector<int32_t>> dilation) {
+ fbb_.AddOffset(Conv2dAttribute::VT_DILATION, dilation);
+ }
+ explicit Conv2dAttributeBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ Conv2dAttributeBuilder &operator=(const Conv2dAttributeBuilder &);
+ flatbuffers::Offset<Conv2dAttribute> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<Conv2dAttribute>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<Conv2dAttribute> CreateConv2dAttribute(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ flatbuffers::Offset<flatbuffers::Vector<int32_t>> padding = 0,
+ flatbuffers::Offset<flatbuffers::Vector<int32_t>> stride = 0,
+ flatbuffers::Offset<flatbuffers::Vector<int32_t>> dilation = 0) {
+ Conv2dAttributeBuilder builder_(_fbb);
+ builder_.add_dilation(dilation);
+ builder_.add_stride(stride);
+ builder_.add_padding(padding);
+ return builder_.Finish();
+}
+
+inline flatbuffers::Offset<Conv2dAttribute> CreateConv2dAttributeDirect(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ const std::vector<int32_t> *padding = nullptr,
+ const std::vector<int32_t> *stride = nullptr,
+ const std::vector<int32_t> *dilation = nullptr) {
+ auto padding__ = padding ? _fbb.CreateVector<int32_t>(*padding) : 0;
+ auto stride__ = stride ? _fbb.CreateVector<int32_t>(*stride) : 0;
+ auto dilation__ = dilation ? _fbb.CreateVector<int32_t>(*dilation) : 0;
+ return tosa::CreateConv2dAttribute(
+ _fbb,
+ padding__,
+ stride__,
+ dilation__);
+}
+
+struct TransposeConv2dAttribute FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_OUTPAD = 4,
+ VT_STRIDE = 6,
+ VT_DILATION = 8,
+ VT_OUTPUT_SHAPE = 10
+ };
+ const flatbuffers::Vector<int32_t> *outpad() const {
+ return GetPointer<const flatbuffers::Vector<int32_t> *>(VT_OUTPAD);
+ }
+ const flatbuffers::Vector<int32_t> *stride() const {
+ return GetPointer<const flatbuffers::Vector<int32_t> *>(VT_STRIDE);
+ }
+ const flatbuffers::Vector<int32_t> *dilation() const {
+ return GetPointer<const flatbuffers::Vector<int32_t> *>(VT_DILATION);
+ }
+ const flatbuffers::Vector<int32_t> *output_shape() const {
+ return GetPointer<const flatbuffers::Vector<int32_t> *>(VT_OUTPUT_SHAPE);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyOffset(verifier, VT_OUTPAD) &&
+ verifier.VerifyVector(outpad()) &&
+ VerifyOffset(verifier, VT_STRIDE) &&
+ verifier.VerifyVector(stride()) &&
+ VerifyOffset(verifier, VT_DILATION) &&
+ verifier.VerifyVector(dilation()) &&
+ VerifyOffset(verifier, VT_OUTPUT_SHAPE) &&
+ verifier.VerifyVector(output_shape()) &&
+ verifier.EndTable();
+ }
+};
+
+struct TransposeConv2dAttributeBuilder {
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_outpad(flatbuffers::Offset<flatbuffers::Vector<int32_t>> outpad) {
+ fbb_.AddOffset(TransposeConv2dAttribute::VT_OUTPAD, outpad);
+ }
+ void add_stride(flatbuffers::Offset<flatbuffers::Vector<int32_t>> stride) {
+ fbb_.AddOffset(TransposeConv2dAttribute::VT_STRIDE, stride);
+ }
+ void add_dilation(flatbuffers::Offset<flatbuffers::Vector<int32_t>> dilation) {
+ fbb_.AddOffset(TransposeConv2dAttribute::VT_DILATION, dilation);
+ }
+ void add_output_shape(flatbuffers::Offset<flatbuffers::Vector<int32_t>> output_shape) {
+ fbb_.AddOffset(TransposeConv2dAttribute::VT_OUTPUT_SHAPE, output_shape);
+ }
+ explicit TransposeConv2dAttributeBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ TransposeConv2dAttributeBuilder &operator=(const TransposeConv2dAttributeBuilder &);
+ flatbuffers::Offset<TransposeConv2dAttribute> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<TransposeConv2dAttribute>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<TransposeConv2dAttribute> CreateTransposeConv2dAttribute(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ flatbuffers::Offset<flatbuffers::Vector<int32_t>> outpad = 0,
+ flatbuffers::Offset<flatbuffers::Vector<int32_t>> stride = 0,
+ flatbuffers::Offset<flatbuffers::Vector<int32_t>> dilation = 0,
+ flatbuffers::Offset<flatbuffers::Vector<int32_t>> output_shape = 0) {
+ TransposeConv2dAttributeBuilder builder_(_fbb);
+ builder_.add_output_shape(output_shape);
+ builder_.add_dilation(dilation);
+ builder_.add_stride(stride);
+ builder_.add_outpad(outpad);
+ return builder_.Finish();
+}
+
+inline flatbuffers::Offset<TransposeConv2dAttribute> CreateTransposeConv2dAttributeDirect(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ const std::vector<int32_t> *outpad = nullptr,
+ const std::vector<int32_t> *stride = nullptr,
+ const std::vector<int32_t> *dilation = nullptr,
+ const std::vector<int32_t> *output_shape = nullptr) {
+ auto outpad__ = outpad ? _fbb.CreateVector<int32_t>(*outpad) : 0;
+ auto stride__ = stride ? _fbb.CreateVector<int32_t>(*stride) : 0;
+ auto dilation__ = dilation ? _fbb.CreateVector<int32_t>(*dilation) : 0;
+ auto output_shape__ = output_shape ? _fbb.CreateVector<int32_t>(*output_shape) : 0;
+ return tosa::CreateTransposeConv2dAttribute(
+ _fbb,
+ outpad__,
+ stride__,
+ dilation__,
+ output_shape__);
+}
+
+struct ReluNAttribute FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_MAX_INT = 4,
+ VT_MAX_FP = 6
+ };
+ int32_t max_int() const {
+ return GetField<int32_t>(VT_MAX_INT, 0);
+ }
+ float max_fp() const {
+ return GetField<float>(VT_MAX_FP, 0.0f);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyField<int32_t>(verifier, VT_MAX_INT) &&
+ VerifyField<float>(verifier, VT_MAX_FP) &&
+ verifier.EndTable();
+ }
+};
+
+struct ReluNAttributeBuilder {
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_max_int(int32_t max_int) {
+ fbb_.AddElement<int32_t>(ReluNAttribute::VT_MAX_INT, max_int, 0);
+ }
+ void add_max_fp(float max_fp) {
+ fbb_.AddElement<float>(ReluNAttribute::VT_MAX_FP, max_fp, 0.0f);
+ }
+ explicit ReluNAttributeBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ ReluNAttributeBuilder &operator=(const ReluNAttributeBuilder &);
+ flatbuffers::Offset<ReluNAttribute> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<ReluNAttribute>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<ReluNAttribute> CreateReluNAttribute(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ int32_t max_int = 0,
+ float max_fp = 0.0f) {
+ ReluNAttributeBuilder builder_(_fbb);
+ builder_.add_max_fp(max_fp);
+ builder_.add_max_int(max_int);
+ return builder_.Finish();
+}
+
+struct AxisAttribute FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_AXIS = 4
+ };
+ int32_t axis() const {
+ return GetField<int32_t>(VT_AXIS, 0);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyField<int32_t>(verifier, VT_AXIS) &&
+ verifier.EndTable();
+ }
+};
+
+struct AxisAttributeBuilder {
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_axis(int32_t axis) {
+ fbb_.AddElement<int32_t>(AxisAttribute::VT_AXIS, axis, 0);
+ }
+ explicit AxisAttributeBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ AxisAttributeBuilder &operator=(const AxisAttributeBuilder &);
+ flatbuffers::Offset<AxisAttribute> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<AxisAttribute>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<AxisAttribute> CreateAxisAttribute(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ int32_t axis = 0) {
+ AxisAttributeBuilder builder_(_fbb);
+ builder_.add_axis(axis);
+ return builder_.Finish();
+}
+
+struct ReshapeAttribute FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_SHAPE = 4
+ };
+ const flatbuffers::Vector<int32_t> *shape() const {
+ return GetPointer<const flatbuffers::Vector<int32_t> *>(VT_SHAPE);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyOffset(verifier, VT_SHAPE) &&
+ verifier.VerifyVector(shape()) &&
+ verifier.EndTable();
+ }
+};
+
+struct ReshapeAttributeBuilder {
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_shape(flatbuffers::Offset<flatbuffers::Vector<int32_t>> shape) {
+ fbb_.AddOffset(ReshapeAttribute::VT_SHAPE, shape);
+ }
+ explicit ReshapeAttributeBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ ReshapeAttributeBuilder &operator=(const ReshapeAttributeBuilder &);
+ flatbuffers::Offset<ReshapeAttribute> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<ReshapeAttribute>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<ReshapeAttribute> CreateReshapeAttribute(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ flatbuffers::Offset<flatbuffers::Vector<int32_t>> shape = 0) {
+ ReshapeAttributeBuilder builder_(_fbb);
+ builder_.add_shape(shape);
+ return builder_.Finish();
+}
+
+inline flatbuffers::Offset<ReshapeAttribute> CreateReshapeAttributeDirect(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ const std::vector<int32_t> *shape = nullptr) {
+ auto shape__ = shape ? _fbb.CreateVector<int32_t>(*shape) : 0;
+ return tosa::CreateReshapeAttribute(
+ _fbb,
+ shape__);
+}
+
+struct SliceAttribute FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_BEGIN = 4,
+ VT_SIZE = 6
+ };
+ const flatbuffers::Vector<int32_t> *begin() const {
+ return GetPointer<const flatbuffers::Vector<int32_t> *>(VT_BEGIN);
+ }
+ const flatbuffers::Vector<int32_t> *size() const {
+ return GetPointer<const flatbuffers::Vector<int32_t> *>(VT_SIZE);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyOffset(verifier, VT_BEGIN) &&
+ verifier.VerifyVector(begin()) &&
+ VerifyOffset(verifier, VT_SIZE) &&
+ verifier.VerifyVector(size()) &&
+ verifier.EndTable();
+ }
+};
+
+struct SliceAttributeBuilder {
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_begin(flatbuffers::Offset<flatbuffers::Vector<int32_t>> begin) {
+ fbb_.AddOffset(SliceAttribute::VT_BEGIN, begin);
+ }
+ void add_size(flatbuffers::Offset<flatbuffers::Vector<int32_t>> size) {
+ fbb_.AddOffset(SliceAttribute::VT_SIZE, size);
+ }
+ explicit SliceAttributeBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ SliceAttributeBuilder &operator=(const SliceAttributeBuilder &);
+ flatbuffers::Offset<SliceAttribute> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<SliceAttribute>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<SliceAttribute> CreateSliceAttribute(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ flatbuffers::Offset<flatbuffers::Vector<int32_t>> begin = 0,
+ flatbuffers::Offset<flatbuffers::Vector<int32_t>> size = 0) {
+ SliceAttributeBuilder builder_(_fbb);
+ builder_.add_size(size);
+ builder_.add_begin(begin);
+ return builder_.Finish();
+}
+
+inline flatbuffers::Offset<SliceAttribute> CreateSliceAttributeDirect(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ const std::vector<int32_t> *begin = nullptr,
+ const std::vector<int32_t> *size = nullptr) {
+ auto begin__ = begin ? _fbb.CreateVector<int32_t>(*begin) : 0;
+ auto size__ = size ? _fbb.CreateVector<int32_t>(*size) : 0;
+ return tosa::CreateSliceAttribute(
+ _fbb,
+ begin__,
+ size__);
+}
+
+struct TileAttribute FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_MULTIPLES = 4
+ };
+ const flatbuffers::Vector<int32_t> *multiples() const {
+ return GetPointer<const flatbuffers::Vector<int32_t> *>(VT_MULTIPLES);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyOffset(verifier, VT_MULTIPLES) &&
+ verifier.VerifyVector(multiples()) &&
+ verifier.EndTable();
+ }
+};
+
+struct TileAttributeBuilder {
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_multiples(flatbuffers::Offset<flatbuffers::Vector<int32_t>> multiples) {
+ fbb_.AddOffset(TileAttribute::VT_MULTIPLES, multiples);
+ }
+ explicit TileAttributeBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ TileAttributeBuilder &operator=(const TileAttributeBuilder &);
+ flatbuffers::Offset<TileAttribute> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<TileAttribute>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<TileAttribute> CreateTileAttribute(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ flatbuffers::Offset<flatbuffers::Vector<int32_t>> multiples = 0) {
+ TileAttributeBuilder builder_(_fbb);
+ builder_.add_multiples(multiples);
+ return builder_.Finish();
+}
+
+inline flatbuffers::Offset<TileAttribute> CreateTileAttributeDirect(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ const std::vector<int32_t> *multiples = nullptr) {
+ auto multiples__ = multiples ? _fbb.CreateVector<int32_t>(*multiples) : 0;
+ return tosa::CreateTileAttribute(
+ _fbb,
+ multiples__);
+}
+
+struct ResizeAttribute FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_OUTPUT_SIZE = 4,
+ VT_STRIDE = 6,
+ VT_OFFSET = 8,
+ VT_SHIFT = 10,
+ VT_MODE = 12
+ };
+ const flatbuffers::Vector<int32_t> *output_size() const {
+ return GetPointer<const flatbuffers::Vector<int32_t> *>(VT_OUTPUT_SIZE);
+ }
+ const flatbuffers::Vector<int32_t> *stride() const {
+ return GetPointer<const flatbuffers::Vector<int32_t> *>(VT_STRIDE);
+ }
+ const flatbuffers::Vector<int32_t> *offset() const {
+ return GetPointer<const flatbuffers::Vector<int32_t> *>(VT_OFFSET);
+ }
+ int32_t shift() const {
+ return GetField<int32_t>(VT_SHIFT, 0);
+ }
+ ResizeMode mode() const {
+ return static_cast<ResizeMode>(GetField<uint32_t>(VT_MODE, 0));
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyOffset(verifier, VT_OUTPUT_SIZE) &&
+ verifier.VerifyVector(output_size()) &&
+ VerifyOffset(verifier, VT_STRIDE) &&
+ verifier.VerifyVector(stride()) &&
+ VerifyOffset(verifier, VT_OFFSET) &&
+ verifier.VerifyVector(offset()) &&
+ VerifyField<int32_t>(verifier, VT_SHIFT) &&
+ VerifyField<uint32_t>(verifier, VT_MODE) &&
+ verifier.EndTable();
+ }
+};
+
+struct ResizeAttributeBuilder {
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_output_size(flatbuffers::Offset<flatbuffers::Vector<int32_t>> output_size) {
+ fbb_.AddOffset(ResizeAttribute::VT_OUTPUT_SIZE, output_size);
+ }
+ void add_stride(flatbuffers::Offset<flatbuffers::Vector<int32_t>> stride) {
+ fbb_.AddOffset(ResizeAttribute::VT_STRIDE, stride);
+ }
+ void add_offset(flatbuffers::Offset<flatbuffers::Vector<int32_t>> offset) {
+ fbb_.AddOffset(ResizeAttribute::VT_OFFSET, offset);
+ }
+ void add_shift(int32_t shift) {
+ fbb_.AddElement<int32_t>(ResizeAttribute::VT_SHIFT, shift, 0);
+ }
+ void add_mode(ResizeMode mode) {
+ fbb_.AddElement<uint32_t>(ResizeAttribute::VT_MODE, static_cast<uint32_t>(mode), 0);
+ }
+ explicit ResizeAttributeBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ ResizeAttributeBuilder &operator=(const ResizeAttributeBuilder &);
+ flatbuffers::Offset<ResizeAttribute> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<ResizeAttribute>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<ResizeAttribute> CreateResizeAttribute(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ flatbuffers::Offset<flatbuffers::Vector<int32_t>> output_size = 0,
+ flatbuffers::Offset<flatbuffers::Vector<int32_t>> stride = 0,
+ flatbuffers::Offset<flatbuffers::Vector<int32_t>> offset = 0,
+ int32_t shift = 0,
+ ResizeMode mode = ResizeMode_UNKNOWN) {
+ ResizeAttributeBuilder builder_(_fbb);
+ builder_.add_mode(mode);
+ builder_.add_shift(shift);
+ builder_.add_offset(offset);
+ builder_.add_stride(stride);
+ builder_.add_output_size(output_size);
+ return builder_.Finish();
+}
+
+inline flatbuffers::Offset<ResizeAttribute> CreateResizeAttributeDirect(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ const std::vector<int32_t> *output_size = nullptr,
+ const std::vector<int32_t> *stride = nullptr,
+ const std::vector<int32_t> *offset = nullptr,
+ int32_t shift = 0,
+ ResizeMode mode = ResizeMode_UNKNOWN) {
+ auto output_size__ = output_size ? _fbb.CreateVector<int32_t>(*output_size) : 0;
+ auto stride__ = stride ? _fbb.CreateVector<int32_t>(*stride) : 0;
+ auto offset__ = offset ? _fbb.CreateVector<int32_t>(*offset) : 0;
+ return tosa::CreateResizeAttribute(
+ _fbb,
+ output_size__,
+ stride__,
+ offset__,
+ shift,
+ mode);
+}
+
+struct ClampAttribute FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_MIN_INT = 4,
+ VT_MAX_INT = 6,
+ VT_MIN_FP = 8,
+ VT_MAX_FP = 10
+ };
+ int32_t min_int() const {
+ return GetField<int32_t>(VT_MIN_INT, 0);
+ }
+ int32_t max_int() const {
+ return GetField<int32_t>(VT_MAX_INT, 0);
+ }
+ float min_fp() const {
+ return GetField<float>(VT_MIN_FP, 0.0f);
+ }
+ float max_fp() const {
+ return GetField<float>(VT_MAX_FP, 0.0f);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyField<int32_t>(verifier, VT_MIN_INT) &&
+ VerifyField<int32_t>(verifier, VT_MAX_INT) &&
+ VerifyField<float>(verifier, VT_MIN_FP) &&
+ VerifyField<float>(verifier, VT_MAX_FP) &&
+ verifier.EndTable();
+ }
+};
+
+struct ClampAttributeBuilder {
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_min_int(int32_t min_int) {
+ fbb_.AddElement<int32_t>(ClampAttribute::VT_MIN_INT, min_int, 0);
+ }
+ void add_max_int(int32_t max_int) {
+ fbb_.AddElement<int32_t>(ClampAttribute::VT_MAX_INT, max_int, 0);
+ }
+ void add_min_fp(float min_fp) {
+ fbb_.AddElement<float>(ClampAttribute::VT_MIN_FP, min_fp, 0.0f);
+ }
+ void add_max_fp(float max_fp) {
+ fbb_.AddElement<float>(ClampAttribute::VT_MAX_FP, max_fp, 0.0f);
+ }
+ explicit ClampAttributeBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ ClampAttributeBuilder &operator=(const ClampAttributeBuilder &);
+ flatbuffers::Offset<ClampAttribute> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<ClampAttribute>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<ClampAttribute> CreateClampAttribute(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ int32_t min_int = 0,
+ int32_t max_int = 0,
+ float min_fp = 0.0f,
+ float max_fp = 0.0f) {
+ ClampAttributeBuilder builder_(_fbb);
+ builder_.add_max_fp(max_fp);
+ builder_.add_min_fp(min_fp);
+ builder_.add_max_int(max_int);
+ builder_.add_min_int(min_int);
+ return builder_.Finish();
+}
+
+struct RescaleAttribute FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_INPUT_ZP = 4,
+ VT_OUTPUT_ZP = 6,
+ VT_MULTIPLIER = 8,
+ VT_SHIFT = 10,
+ VT_SCALE32 = 12,
+ VT_DOUBLE_ROUND = 14,
+ VT_PER_CHANNEL = 16
+ };
+ int32_t input_zp() const {
+ return GetField<int32_t>(VT_INPUT_ZP, 0);
+ }
+ int32_t output_zp() const {
+ return GetField<int32_t>(VT_OUTPUT_ZP, 0);
+ }
+ const flatbuffers::Vector<int32_t> *multiplier() const {
+ return GetPointer<const flatbuffers::Vector<int32_t> *>(VT_MULTIPLIER);
+ }
+ const flatbuffers::Vector<int32_t> *shift() const {
+ return GetPointer<const flatbuffers::Vector<int32_t> *>(VT_SHIFT);
+ }
+ bool scale32() const {
+ return GetField<uint8_t>(VT_SCALE32, 0) != 0;
+ }
+ bool double_round() const {
+ return GetField<uint8_t>(VT_DOUBLE_ROUND, 0) != 0;
+ }
+ bool per_channel() const {
+ return GetField<uint8_t>(VT_PER_CHANNEL, 0) != 0;
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyField<int32_t>(verifier, VT_INPUT_ZP) &&
+ VerifyField<int32_t>(verifier, VT_OUTPUT_ZP) &&
+ VerifyOffset(verifier, VT_MULTIPLIER) &&
+ verifier.VerifyVector(multiplier()) &&
+ VerifyOffset(verifier, VT_SHIFT) &&
+ verifier.VerifyVector(shift()) &&
+ VerifyField<uint8_t>(verifier, VT_SCALE32) &&
+ VerifyField<uint8_t>(verifier, VT_DOUBLE_ROUND) &&
+ VerifyField<uint8_t>(verifier, VT_PER_CHANNEL) &&
+ verifier.EndTable();
+ }
+};
+
+struct RescaleAttributeBuilder {
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_input_zp(int32_t input_zp) {
+ fbb_.AddElement<int32_t>(RescaleAttribute::VT_INPUT_ZP, input_zp, 0);
+ }
+ void add_output_zp(int32_t output_zp) {
+ fbb_.AddElement<int32_t>(RescaleAttribute::VT_OUTPUT_ZP, output_zp, 0);
+ }
+ void add_multiplier(flatbuffers::Offset<flatbuffers::Vector<int32_t>> multiplier) {
+ fbb_.AddOffset(RescaleAttribute::VT_MULTIPLIER, multiplier);
+ }
+ void add_shift(flatbuffers::Offset<flatbuffers::Vector<int32_t>> shift) {
+ fbb_.AddOffset(RescaleAttribute::VT_SHIFT, shift);
+ }
+ void add_scale32(bool scale32) {
+ fbb_.AddElement<uint8_t>(RescaleAttribute::VT_SCALE32, static_cast<uint8_t>(scale32), 0);
+ }
+ void add_double_round(bool double_round) {
+ fbb_.AddElement<uint8_t>(RescaleAttribute::VT_DOUBLE_ROUND, static_cast<uint8_t>(double_round), 0);
+ }
+ void add_per_channel(bool per_channel) {
+ fbb_.AddElement<uint8_t>(RescaleAttribute::VT_PER_CHANNEL, static_cast<uint8_t>(per_channel), 0);
+ }
+ explicit RescaleAttributeBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ RescaleAttributeBuilder &operator=(const RescaleAttributeBuilder &);
+ flatbuffers::Offset<RescaleAttribute> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<RescaleAttribute>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<RescaleAttribute> CreateRescaleAttribute(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ int32_t input_zp = 0,
+ int32_t output_zp = 0,
+ flatbuffers::Offset<flatbuffers::Vector<int32_t>> multiplier = 0,
+ flatbuffers::Offset<flatbuffers::Vector<int32_t>> shift = 0,
+ bool scale32 = false,
+ bool double_round = false,
+ bool per_channel = false) {
+ RescaleAttributeBuilder builder_(_fbb);
+ builder_.add_shift(shift);
+ builder_.add_multiplier(multiplier);
+ builder_.add_output_zp(output_zp);
+ builder_.add_input_zp(input_zp);
+ builder_.add_per_channel(per_channel);
+ builder_.add_double_round(double_round);
+ builder_.add_scale32(scale32);
+ return builder_.Finish();
+}
+
+inline flatbuffers::Offset<RescaleAttribute> CreateRescaleAttributeDirect(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ int32_t input_zp = 0,
+ int32_t output_zp = 0,
+ const std::vector<int32_t> *multiplier = nullptr,
+ const std::vector<int32_t> *shift = nullptr,
+ bool scale32 = false,
+ bool double_round = false,
+ bool per_channel = false) {
+ auto multiplier__ = multiplier ? _fbb.CreateVector<int32_t>(*multiplier) : 0;
+ auto shift__ = shift ? _fbb.CreateVector<int32_t>(*shift) : 0;
+ return tosa::CreateRescaleAttribute(
+ _fbb,
+ input_zp,
+ output_zp,
+ multiplier__,
+ shift__,
+ scale32,
+ double_round,
+ per_channel);
+}
+
+struct CustomAttribute FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_IDENTIFIER = 4
+ };
+ const flatbuffers::String *identifier() const {
+ return GetPointer<const flatbuffers::String *>(VT_IDENTIFIER);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyOffset(verifier, VT_IDENTIFIER) &&
+ verifier.VerifyString(identifier()) &&
+ verifier.EndTable();
+ }
+};
+
+struct CustomAttributeBuilder {
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_identifier(flatbuffers::Offset<flatbuffers::String> identifier) {
+ fbb_.AddOffset(CustomAttribute::VT_IDENTIFIER, identifier);
+ }
+ explicit CustomAttributeBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ CustomAttributeBuilder &operator=(const CustomAttributeBuilder &);
+ flatbuffers::Offset<CustomAttribute> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<CustomAttribute>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<CustomAttribute> CreateCustomAttribute(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ flatbuffers::Offset<flatbuffers::String> identifier = 0) {
+ CustomAttributeBuilder builder_(_fbb);
+ builder_.add_identifier(identifier);
+ return builder_.Finish();
+}
+
+inline flatbuffers::Offset<CustomAttribute> CreateCustomAttributeDirect(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ const char *identifier = nullptr) {
+ auto identifier__ = identifier ? _fbb.CreateString(identifier) : 0;
+ return tosa::CreateCustomAttribute(
+ _fbb,
+ identifier__);
+}
+
+struct CondIfAttribute FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_THEN_BRANCH = 4,
+ VT_ELSE_BRANCH = 6
+ };
+ const flatbuffers::String *then_branch() const {
+ return GetPointer<const flatbuffers::String *>(VT_THEN_BRANCH);
+ }
+ const flatbuffers::String *else_branch() const {
+ return GetPointer<const flatbuffers::String *>(VT_ELSE_BRANCH);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyOffset(verifier, VT_THEN_BRANCH) &&
+ verifier.VerifyString(then_branch()) &&
+ VerifyOffset(verifier, VT_ELSE_BRANCH) &&
+ verifier.VerifyString(else_branch()) &&
+ verifier.EndTable();
+ }
+};
+
+struct CondIfAttributeBuilder {
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_then_branch(flatbuffers::Offset<flatbuffers::String> then_branch) {
+ fbb_.AddOffset(CondIfAttribute::VT_THEN_BRANCH, then_branch);
+ }
+ void add_else_branch(flatbuffers::Offset<flatbuffers::String> else_branch) {
+ fbb_.AddOffset(CondIfAttribute::VT_ELSE_BRANCH, else_branch);
+ }
+ explicit CondIfAttributeBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ CondIfAttributeBuilder &operator=(const CondIfAttributeBuilder &);
+ flatbuffers::Offset<CondIfAttribute> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<CondIfAttribute>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<CondIfAttribute> CreateCondIfAttribute(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ flatbuffers::Offset<flatbuffers::String> then_branch = 0,
+ flatbuffers::Offset<flatbuffers::String> else_branch = 0) {
+ CondIfAttributeBuilder builder_(_fbb);
+ builder_.add_else_branch(else_branch);
+ builder_.add_then_branch(then_branch);
+ return builder_.Finish();
+}
+
+inline flatbuffers::Offset<CondIfAttribute> CreateCondIfAttributeDirect(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ const char *then_branch = nullptr,
+ const char *else_branch = nullptr) {
+ auto then_branch__ = then_branch ? _fbb.CreateString(then_branch) : 0;
+ auto else_branch__ = else_branch ? _fbb.CreateString(else_branch) : 0;
+ return tosa::CreateCondIfAttribute(
+ _fbb,
+ then_branch__,
+ else_branch__);
+}
+
+struct WhileLoopAttribute FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_COND_BRANCH = 4,
+ VT_BODY_BRANCH = 6
+ };
+ const flatbuffers::String *cond_branch() const {
+ return GetPointer<const flatbuffers::String *>(VT_COND_BRANCH);
+ }
+ const flatbuffers::String *body_branch() const {
+ return GetPointer<const flatbuffers::String *>(VT_BODY_BRANCH);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyOffset(verifier, VT_COND_BRANCH) &&
+ verifier.VerifyString(cond_branch()) &&
+ VerifyOffset(verifier, VT_BODY_BRANCH) &&
+ verifier.VerifyString(body_branch()) &&
+ verifier.EndTable();
+ }
+};
+
+struct WhileLoopAttributeBuilder {
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_cond_branch(flatbuffers::Offset<flatbuffers::String> cond_branch) {
+ fbb_.AddOffset(WhileLoopAttribute::VT_COND_BRANCH, cond_branch);
+ }
+ void add_body_branch(flatbuffers::Offset<flatbuffers::String> body_branch) {
+ fbb_.AddOffset(WhileLoopAttribute::VT_BODY_BRANCH, body_branch);
+ }
+ explicit WhileLoopAttributeBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ WhileLoopAttributeBuilder &operator=(const WhileLoopAttributeBuilder &);
+ flatbuffers::Offset<WhileLoopAttribute> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<WhileLoopAttribute>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<WhileLoopAttribute> CreateWhileLoopAttribute(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ flatbuffers::Offset<flatbuffers::String> cond_branch = 0,
+ flatbuffers::Offset<flatbuffers::String> body_branch = 0) {
+ WhileLoopAttributeBuilder builder_(_fbb);
+ builder_.add_body_branch(body_branch);
+ builder_.add_cond_branch(cond_branch);
+ return builder_.Finish();
+}
+
+inline flatbuffers::Offset<WhileLoopAttribute> CreateWhileLoopAttributeDirect(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ const char *cond_branch = nullptr,
+ const char *body_branch = nullptr) {
+ auto cond_branch__ = cond_branch ? _fbb.CreateString(cond_branch) : 0;
+ auto body_branch__ = body_branch ? _fbb.CreateString(body_branch) : 0;
+ return tosa::CreateWhileLoopAttribute(
+ _fbb,
+ cond_branch__,
+ body_branch__);
+}
+
+struct UnaryQuantInfo FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_INPUT_ZP = 4,
+ VT_OUTPUT_ZP = 6
+ };
+ int32_t input_zp() const {
+ return GetField<int32_t>(VT_INPUT_ZP, 0);
+ }
+ int32_t output_zp() const {
+ return GetField<int32_t>(VT_OUTPUT_ZP, 0);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyField<int32_t>(verifier, VT_INPUT_ZP) &&
+ VerifyField<int32_t>(verifier, VT_OUTPUT_ZP) &&
+ verifier.EndTable();
+ }
+};
+
+struct UnaryQuantInfoBuilder {
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_input_zp(int32_t input_zp) {
+ fbb_.AddElement<int32_t>(UnaryQuantInfo::VT_INPUT_ZP, input_zp, 0);
+ }
+ void add_output_zp(int32_t output_zp) {
+ fbb_.AddElement<int32_t>(UnaryQuantInfo::VT_OUTPUT_ZP, output_zp, 0);
+ }
+ explicit UnaryQuantInfoBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ UnaryQuantInfoBuilder &operator=(const UnaryQuantInfoBuilder &);
+ flatbuffers::Offset<UnaryQuantInfo> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<UnaryQuantInfo>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<UnaryQuantInfo> CreateUnaryQuantInfo(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ int32_t input_zp = 0,
+ int32_t output_zp = 0) {
+ UnaryQuantInfoBuilder builder_(_fbb);
+ builder_.add_output_zp(output_zp);
+ builder_.add_input_zp(input_zp);
+ return builder_.Finish();
+}
+
+struct ConvQuantInfo FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_INPUT_ZP = 4,
+ VT_WEIGHT_ZP = 6
+ };
+ int32_t input_zp() const {
+ return GetField<int32_t>(VT_INPUT_ZP, 0);
+ }
+ int32_t weight_zp() const {
+ return GetField<int32_t>(VT_WEIGHT_ZP, 0);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyField<int32_t>(verifier, VT_INPUT_ZP) &&
+ VerifyField<int32_t>(verifier, VT_WEIGHT_ZP) &&
+ verifier.EndTable();
+ }
+};
+
+struct ConvQuantInfoBuilder {
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_input_zp(int32_t input_zp) {
+ fbb_.AddElement<int32_t>(ConvQuantInfo::VT_INPUT_ZP, input_zp, 0);
+ }
+ void add_weight_zp(int32_t weight_zp) {
+ fbb_.AddElement<int32_t>(ConvQuantInfo::VT_WEIGHT_ZP, weight_zp, 0);
+ }
+ explicit ConvQuantInfoBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ ConvQuantInfoBuilder &operator=(const ConvQuantInfoBuilder &);
+ flatbuffers::Offset<ConvQuantInfo> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<ConvQuantInfo>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<ConvQuantInfo> CreateConvQuantInfo(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ int32_t input_zp = 0,
+ int32_t weight_zp = 0) {
+ ConvQuantInfoBuilder builder_(_fbb);
+ builder_.add_weight_zp(weight_zp);
+ builder_.add_input_zp(input_zp);
+ return builder_.Finish();
+}
+
+struct MatMulQuantInfo FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_A_ZP = 4,
+ VT_B_ZP = 6
+ };
+ int32_t a_zp() const {
+ return GetField<int32_t>(VT_A_ZP, 0);
+ }
+ int32_t b_zp() const {
+ return GetField<int32_t>(VT_B_ZP, 0);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyField<int32_t>(verifier, VT_A_ZP) &&
+ VerifyField<int32_t>(verifier, VT_B_ZP) &&
+ verifier.EndTable();
+ }
+};
+
+struct MatMulQuantInfoBuilder {
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_a_zp(int32_t a_zp) {
+ fbb_.AddElement<int32_t>(MatMulQuantInfo::VT_A_ZP, a_zp, 0);
+ }
+ void add_b_zp(int32_t b_zp) {
+ fbb_.AddElement<int32_t>(MatMulQuantInfo::VT_B_ZP, b_zp, 0);
+ }
+ explicit MatMulQuantInfoBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ MatMulQuantInfoBuilder &operator=(const MatMulQuantInfoBuilder &);
+ flatbuffers::Offset<MatMulQuantInfo> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<MatMulQuantInfo>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<MatMulQuantInfo> CreateMatMulQuantInfo(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ int32_t a_zp = 0,
+ int32_t b_zp = 0) {
+ MatMulQuantInfoBuilder builder_(_fbb);
+ builder_.add_b_zp(b_zp);
+ builder_.add_a_zp(a_zp);
+ return builder_.Finish();
+}
+
+struct PadQuantInfo FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_INPUT_ZP = 4
+ };
+ int32_t input_zp() const {
+ return GetField<int32_t>(VT_INPUT_ZP, 0);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyField<int32_t>(verifier, VT_INPUT_ZP) &&
+ verifier.EndTable();
+ }
+};
+
+struct PadQuantInfoBuilder {
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_input_zp(int32_t input_zp) {
+ fbb_.AddElement<int32_t>(PadQuantInfo::VT_INPUT_ZP, input_zp, 0);
+ }
+ explicit PadQuantInfoBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ PadQuantInfoBuilder &operator=(const PadQuantInfoBuilder &);
+ flatbuffers::Offset<PadQuantInfo> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<PadQuantInfo>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<PadQuantInfo> CreatePadQuantInfo(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ int32_t input_zp = 0) {
+ PadQuantInfoBuilder builder_(_fbb);
+ builder_.add_input_zp(input_zp);
+ return builder_.Finish();
+}
+
+struct Version FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT__MAJOR = 4,
+ VT__MINOR = 6,
+ VT__PATCH = 8,
+ VT__EXPERIMENTAL = 10
+ };
+ int32_t _major() const {
+ return GetField<int32_t>(VT__MAJOR, 0);
+ }
+ int32_t _minor() const {
+ return GetField<int32_t>(VT__MINOR, 20);
+ }
+ int32_t _patch() const {
+ return GetField<int32_t>(VT__PATCH, 0);
+ }
+ bool _experimental() const {
+ return GetField<uint8_t>(VT__EXPERIMENTAL, 0) != 0;
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyField<int32_t>(verifier, VT__MAJOR) &&
+ VerifyField<int32_t>(verifier, VT__MINOR) &&
+ VerifyField<int32_t>(verifier, VT__PATCH) &&
+ VerifyField<uint8_t>(verifier, VT__EXPERIMENTAL) &&
+ verifier.EndTable();
+ }
+};
+
+struct VersionBuilder {
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add__major(int32_t _major) {
+ fbb_.AddElement<int32_t>(Version::VT__MAJOR, _major, 0);
+ }
+ void add__minor(int32_t _minor) {
+ fbb_.AddElement<int32_t>(Version::VT__MINOR, _minor, 20);
+ }
+ void add__patch(int32_t _patch) {
+ fbb_.AddElement<int32_t>(Version::VT__PATCH, _patch, 0);
+ }
+ void add__experimental(bool _experimental) {
+ fbb_.AddElement<uint8_t>(Version::VT__EXPERIMENTAL, static_cast<uint8_t>(_experimental), 0);
+ }
+ explicit VersionBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ VersionBuilder &operator=(const VersionBuilder &);
+ flatbuffers::Offset<Version> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<Version>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<Version> CreateVersion(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ int32_t _major = 0,
+ int32_t _minor = 20,
+ int32_t _patch = 0,
+ bool _experimental = false) {
+ VersionBuilder builder_(_fbb);
+ builder_.add__patch(_patch);
+ builder_.add__minor(_minor);
+ builder_.add__major(_major);
+ builder_.add__experimental(_experimental);
+ return builder_.Finish();
+}
+
+struct TosaTensor FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_NAME = 4,
+ VT_SHAPE = 6,
+ VT_TYPE = 8,
+ VT_USAGE = 10,
+ VT_FORMAT = 12,
+ VT_NPY_FILENAME = 14
+ };
+ const flatbuffers::String *name() const {
+ return GetPointer<const flatbuffers::String *>(VT_NAME);
+ }
+ const flatbuffers::Vector<int32_t> *shape() const {
+ return GetPointer<const flatbuffers::Vector<int32_t> *>(VT_SHAPE);
+ }
+ DType type() const {
+ return static_cast<DType>(GetField<uint32_t>(VT_TYPE, 0));
+ }
+ const flatbuffers::Vector<uint32_t> *usage() const {
+ return GetPointer<const flatbuffers::Vector<uint32_t> *>(VT_USAGE);
+ }
+ const flatbuffers::Vector<uint32_t> *format() const {
+ return GetPointer<const flatbuffers::Vector<uint32_t> *>(VT_FORMAT);
+ }
+ const flatbuffers::String *npy_filename() const {
+ return GetPointer<const flatbuffers::String *>(VT_NPY_FILENAME);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyOffset(verifier, VT_NAME) &&
+ verifier.VerifyString(name()) &&
+ VerifyOffset(verifier, VT_SHAPE) &&
+ verifier.VerifyVector(shape()) &&
+ VerifyField<uint32_t>(verifier, VT_TYPE) &&
+ VerifyOffset(verifier, VT_USAGE) &&
+ verifier.VerifyVector(usage()) &&
+ VerifyOffset(verifier, VT_FORMAT) &&
+ verifier.VerifyVector(format()) &&
+ VerifyOffset(verifier, VT_NPY_FILENAME) &&
+ verifier.VerifyString(npy_filename()) &&
+ verifier.EndTable();
+ }
+};
+
+struct TosaTensorBuilder {
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_name(flatbuffers::Offset<flatbuffers::String> name) {
+ fbb_.AddOffset(TosaTensor::VT_NAME, name);
+ }
+ void add_shape(flatbuffers::Offset<flatbuffers::Vector<int32_t>> shape) {
+ fbb_.AddOffset(TosaTensor::VT_SHAPE, shape);
+ }
+ void add_type(DType type) {
+ fbb_.AddElement<uint32_t>(TosaTensor::VT_TYPE, static_cast<uint32_t>(type), 0);
+ }
+ void add_usage(flatbuffers::Offset<flatbuffers::Vector<uint32_t>> usage) {
+ fbb_.AddOffset(TosaTensor::VT_USAGE, usage);
+ }
+ void add_format(flatbuffers::Offset<flatbuffers::Vector<uint32_t>> format) {
+ fbb_.AddOffset(TosaTensor::VT_FORMAT, format);
+ }
+ void add_npy_filename(flatbuffers::Offset<flatbuffers::String> npy_filename) {
+ fbb_.AddOffset(TosaTensor::VT_NPY_FILENAME, npy_filename);
+ }
+ explicit TosaTensorBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ TosaTensorBuilder &operator=(const TosaTensorBuilder &);
+ flatbuffers::Offset<TosaTensor> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<TosaTensor>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<TosaTensor> CreateTosaTensor(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ flatbuffers::Offset<flatbuffers::String> name = 0,
+ flatbuffers::Offset<flatbuffers::Vector<int32_t>> shape = 0,
+ DType type = DType_UNKNOWN,
+ flatbuffers::Offset<flatbuffers::Vector<uint32_t>> usage = 0,
+ flatbuffers::Offset<flatbuffers::Vector<uint32_t>> format = 0,
+ flatbuffers::Offset<flatbuffers::String> npy_filename = 0) {
+ TosaTensorBuilder builder_(_fbb);
+ builder_.add_npy_filename(npy_filename);
+ builder_.add_format(format);
+ builder_.add_usage(usage);
+ builder_.add_type(type);
+ builder_.add_shape(shape);
+ builder_.add_name(name);
+ return builder_.Finish();
+}
+
+inline flatbuffers::Offset<TosaTensor> CreateTosaTensorDirect(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ const char *name = nullptr,
+ const std::vector<int32_t> *shape = nullptr,
+ DType type = DType_UNKNOWN,
+ const std::vector<uint32_t> *usage = nullptr,
+ const std::vector<uint32_t> *format = nullptr,
+ const char *npy_filename = nullptr) {
+ auto name__ = name ? _fbb.CreateString(name) : 0;
+ auto shape__ = shape ? _fbb.CreateVector<int32_t>(*shape) : 0;
+ auto usage__ = usage ? _fbb.CreateVector<uint32_t>(*usage) : 0;
+ auto format__ = format ? _fbb.CreateVector<uint32_t>(*format) : 0;
+ auto npy_filename__ = npy_filename ? _fbb.CreateString(npy_filename) : 0;
+ return tosa::CreateTosaTensor(
+ _fbb,
+ name__,
+ shape__,
+ type,
+ usage__,
+ format__,
+ npy_filename__);
+}
+
+struct TosaOperator FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_OP = 4,
+ VT_ATTRIBUTE_TYPE = 6,
+ VT_ATTRIBUTE = 8,
+ VT_INPUTS = 10,
+ VT_OUTPUTS = 12,
+ VT_QUANT_INFO_TYPE = 14,
+ VT_QUANT_INFO = 16
+ };
+ Op op() const {
+ return static_cast<Op>(GetField<uint32_t>(VT_OP, 0));
+ }
+ Attribute attribute_type() const {
+ return static_cast<Attribute>(GetField<uint8_t>(VT_ATTRIBUTE_TYPE, 0));
+ }
+ const void *attribute() const {
+ return GetPointer<const void *>(VT_ATTRIBUTE);
+ }
+ template<typename T> const T *attribute_as() const;
+ const Pool2dAttribute *attribute_as_Pool2dAttribute() const {
+ return attribute_type() == Attribute_Pool2dAttribute ? static_cast<const Pool2dAttribute *>(attribute()) : nullptr;
+ }
+ const Conv2dAttribute *attribute_as_Conv2dAttribute() const {
+ return attribute_type() == Attribute_Conv2dAttribute ? static_cast<const Conv2dAttribute *>(attribute()) : nullptr;
+ }
+ const TransposeConv2dAttribute *attribute_as_TransposeConv2dAttribute() const {
+ return attribute_type() == Attribute_TransposeConv2dAttribute ? static_cast<const TransposeConv2dAttribute *>(attribute()) : nullptr;
+ }
+ const ReluNAttribute *attribute_as_ReluNAttribute() const {
+ return attribute_type() == Attribute_ReluNAttribute ? static_cast<const ReluNAttribute *>(attribute()) : nullptr;
+ }
+ const AxisAttribute *attribute_as_AxisAttribute() const {
+ return attribute_type() == Attribute_AxisAttribute ? static_cast<const AxisAttribute *>(attribute()) : nullptr;
+ }
+ const ReshapeAttribute *attribute_as_ReshapeAttribute() const {
+ return attribute_type() == Attribute_ReshapeAttribute ? static_cast<const ReshapeAttribute *>(attribute()) : nullptr;
+ }
+ const SliceAttribute *attribute_as_SliceAttribute() const {
+ return attribute_type() == Attribute_SliceAttribute ? static_cast<const SliceAttribute *>(attribute()) : nullptr;
+ }
+ const TileAttribute *attribute_as_TileAttribute() const {
+ return attribute_type() == Attribute_TileAttribute ? static_cast<const TileAttribute *>(attribute()) : nullptr;
+ }
+ const ResizeAttribute *attribute_as_ResizeAttribute() const {
+ return attribute_type() == Attribute_ResizeAttribute ? static_cast<const ResizeAttribute *>(attribute()) : nullptr;
+ }
+ const ClampAttribute *attribute_as_ClampAttribute() const {
+ return attribute_type() == Attribute_ClampAttribute ? static_cast<const ClampAttribute *>(attribute()) : nullptr;
+ }
+ const RescaleAttribute *attribute_as_RescaleAttribute() const {
+ return attribute_type() == Attribute_RescaleAttribute ? static_cast<const RescaleAttribute *>(attribute()) : nullptr;
+ }
+ const CustomAttribute *attribute_as_CustomAttribute() const {
+ return attribute_type() == Attribute_CustomAttribute ? static_cast<const CustomAttribute *>(attribute()) : nullptr;
+ }
+ const CondIfAttribute *attribute_as_CondIfAttribute() const {
+ return attribute_type() == Attribute_CondIfAttribute ? static_cast<const CondIfAttribute *>(attribute()) : nullptr;
+ }
+ const WhileLoopAttribute *attribute_as_WhileLoopAttribute() const {
+ return attribute_type() == Attribute_WhileLoopAttribute ? static_cast<const WhileLoopAttribute *>(attribute()) : nullptr;
+ }
+ const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *inputs() const {
+ return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *>(VT_INPUTS);
+ }
+ const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *outputs() const {
+ return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *>(VT_OUTPUTS);
+ }
+ QuantInfo quant_info_type() const {
+ return static_cast<QuantInfo>(GetField<uint8_t>(VT_QUANT_INFO_TYPE, 0));
+ }
+ const void *quant_info() const {
+ return GetPointer<const void *>(VT_QUANT_INFO);
+ }
+ template<typename T> const T *quant_info_as() const;
+ const UnaryQuantInfo *quant_info_as_UnaryQuantInfo() const {
+ return quant_info_type() == QuantInfo_UnaryQuantInfo ? static_cast<const UnaryQuantInfo *>(quant_info()) : nullptr;
+ }
+ const ConvQuantInfo *quant_info_as_ConvQuantInfo() const {
+ return quant_info_type() == QuantInfo_ConvQuantInfo ? static_cast<const ConvQuantInfo *>(quant_info()) : nullptr;
+ }
+ const MatMulQuantInfo *quant_info_as_MatMulQuantInfo() const {
+ return quant_info_type() == QuantInfo_MatMulQuantInfo ? static_cast<const MatMulQuantInfo *>(quant_info()) : nullptr;
+ }
+ const PadQuantInfo *quant_info_as_PadQuantInfo() const {
+ return quant_info_type() == QuantInfo_PadQuantInfo ? static_cast<const PadQuantInfo *>(quant_info()) : nullptr;
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyField<uint32_t>(verifier, VT_OP) &&
+ VerifyField<uint8_t>(verifier, VT_ATTRIBUTE_TYPE) &&
+ VerifyOffset(verifier, VT_ATTRIBUTE) &&
+ VerifyAttribute(verifier, attribute(), attribute_type()) &&
+ VerifyOffset(verifier, VT_INPUTS) &&
+ verifier.VerifyVector(inputs()) &&
+ verifier.VerifyVectorOfStrings(inputs()) &&
+ VerifyOffset(verifier, VT_OUTPUTS) &&
+ verifier.VerifyVector(outputs()) &&
+ verifier.VerifyVectorOfStrings(outputs()) &&
+ VerifyField<uint8_t>(verifier, VT_QUANT_INFO_TYPE) &&
+ VerifyOffset(verifier, VT_QUANT_INFO) &&
+ VerifyQuantInfo(verifier, quant_info(), quant_info_type()) &&
+ verifier.EndTable();
+ }
+};
+
+template<> inline const Pool2dAttribute *TosaOperator::attribute_as<Pool2dAttribute>() const {
+ return attribute_as_Pool2dAttribute();
+}
+
+template<> inline const Conv2dAttribute *TosaOperator::attribute_as<Conv2dAttribute>() const {
+ return attribute_as_Conv2dAttribute();
+}
+
+template<> inline const TransposeConv2dAttribute *TosaOperator::attribute_as<TransposeConv2dAttribute>() const {
+ return attribute_as_TransposeConv2dAttribute();
+}
+
+template<> inline const ReluNAttribute *TosaOperator::attribute_as<ReluNAttribute>() const {
+ return attribute_as_ReluNAttribute();
+}
+
+template<> inline const AxisAttribute *TosaOperator::attribute_as<AxisAttribute>() const {
+ return attribute_as_AxisAttribute();
+}
+
+template<> inline const ReshapeAttribute *TosaOperator::attribute_as<ReshapeAttribute>() const {
+ return attribute_as_ReshapeAttribute();
+}
+
+template<> inline const SliceAttribute *TosaOperator::attribute_as<SliceAttribute>() const {
+ return attribute_as_SliceAttribute();
+}
+
+template<> inline const TileAttribute *TosaOperator::attribute_as<TileAttribute>() const {
+ return attribute_as_TileAttribute();
+}
+
+template<> inline const ResizeAttribute *TosaOperator::attribute_as<ResizeAttribute>() const {
+ return attribute_as_ResizeAttribute();
+}
+
+template<> inline const ClampAttribute *TosaOperator::attribute_as<ClampAttribute>() const {
+ return attribute_as_ClampAttribute();
+}
+
+template<> inline const RescaleAttribute *TosaOperator::attribute_as<RescaleAttribute>() const {
+ return attribute_as_RescaleAttribute();
+}
+
+template<> inline const CustomAttribute *TosaOperator::attribute_as<CustomAttribute>() const {
+ return attribute_as_CustomAttribute();
+}
+
+template<> inline const CondIfAttribute *TosaOperator::attribute_as<CondIfAttribute>() const {
+ return attribute_as_CondIfAttribute();
+}
+
+template<> inline const WhileLoopAttribute *TosaOperator::attribute_as<WhileLoopAttribute>() const {
+ return attribute_as_WhileLoopAttribute();
+}
+
+template<> inline const UnaryQuantInfo *TosaOperator::quant_info_as<UnaryQuantInfo>() const {
+ return quant_info_as_UnaryQuantInfo();
+}
+
+template<> inline const ConvQuantInfo *TosaOperator::quant_info_as<ConvQuantInfo>() const {
+ return quant_info_as_ConvQuantInfo();
+}
+
+template<> inline const MatMulQuantInfo *TosaOperator::quant_info_as<MatMulQuantInfo>() const {
+ return quant_info_as_MatMulQuantInfo();
+}
+
+template<> inline const PadQuantInfo *TosaOperator::quant_info_as<PadQuantInfo>() const {
+ return quant_info_as_PadQuantInfo();
+}
+
+struct TosaOperatorBuilder {
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_op(Op op) {
+ fbb_.AddElement<uint32_t>(TosaOperator::VT_OP, static_cast<uint32_t>(op), 0);
+ }
+ void add_attribute_type(Attribute attribute_type) {
+ fbb_.AddElement<uint8_t>(TosaOperator::VT_ATTRIBUTE_TYPE, static_cast<uint8_t>(attribute_type), 0);
+ }
+ void add_attribute(flatbuffers::Offset<void> attribute) {
+ fbb_.AddOffset(TosaOperator::VT_ATTRIBUTE, attribute);
+ }
+ void add_inputs(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> inputs) {
+ fbb_.AddOffset(TosaOperator::VT_INPUTS, inputs);
+ }
+ void add_outputs(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> outputs) {
+ fbb_.AddOffset(TosaOperator::VT_OUTPUTS, outputs);
+ }
+ void add_quant_info_type(QuantInfo quant_info_type) {
+ fbb_.AddElement<uint8_t>(TosaOperator::VT_QUANT_INFO_TYPE, static_cast<uint8_t>(quant_info_type), 0);
+ }
+ void add_quant_info(flatbuffers::Offset<void> quant_info) {
+ fbb_.AddOffset(TosaOperator::VT_QUANT_INFO, quant_info);
+ }
+ explicit TosaOperatorBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ TosaOperatorBuilder &operator=(const TosaOperatorBuilder &);
+ flatbuffers::Offset<TosaOperator> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<TosaOperator>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<TosaOperator> CreateTosaOperator(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ Op op = Op_UNKNOWN,
+ Attribute attribute_type = Attribute_NONE,
+ flatbuffers::Offset<void> attribute = 0,
+ flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> inputs = 0,
+ flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> outputs = 0,
+ QuantInfo quant_info_type = QuantInfo_NONE,
+ flatbuffers::Offset<void> quant_info = 0) {
+ TosaOperatorBuilder builder_(_fbb);
+ builder_.add_quant_info(quant_info);
+ builder_.add_outputs(outputs);
+ builder_.add_inputs(inputs);
+ builder_.add_attribute(attribute);
+ builder_.add_op(op);
+ builder_.add_quant_info_type(quant_info_type);
+ builder_.add_attribute_type(attribute_type);
+ return builder_.Finish();
+}
+
+inline flatbuffers::Offset<TosaOperator> CreateTosaOperatorDirect(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ Op op = Op_UNKNOWN,
+ Attribute attribute_type = Attribute_NONE,
+ flatbuffers::Offset<void> attribute = 0,
+ const std::vector<flatbuffers::Offset<flatbuffers::String>> *inputs = nullptr,
+ const std::vector<flatbuffers::Offset<flatbuffers::String>> *outputs = nullptr,
+ QuantInfo quant_info_type = QuantInfo_NONE,
+ flatbuffers::Offset<void> quant_info = 0) {
+ auto inputs__ = inputs ? _fbb.CreateVector<flatbuffers::Offset<flatbuffers::String>>(*inputs) : 0;
+ auto outputs__ = outputs ? _fbb.CreateVector<flatbuffers::Offset<flatbuffers::String>>(*outputs) : 0;
+ return tosa::CreateTosaOperator(
+ _fbb,
+ op,
+ attribute_type,
+ attribute,
+ inputs__,
+ outputs__,
+ quant_info_type,
+ quant_info);
+}
+
+struct TosaBasicBlock FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_NAME = 4,
+ VT_OPERATORS = 6,
+ VT_TENSORS = 8,
+ VT_INPUTS = 10,
+ VT_OUTPUTS = 12
+ };
+ const flatbuffers::String *name() const {
+ return GetPointer<const flatbuffers::String *>(VT_NAME);
+ }
+ const flatbuffers::Vector<flatbuffers::Offset<TosaOperator>> *operators() const {
+ return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<TosaOperator>> *>(VT_OPERATORS);
+ }
+ const flatbuffers::Vector<flatbuffers::Offset<TosaTensor>> *tensors() const {
+ return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<TosaTensor>> *>(VT_TENSORS);
+ }
+ const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *inputs() const {
+ return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *>(VT_INPUTS);
+ }
+ const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *outputs() const {
+ return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *>(VT_OUTPUTS);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyOffset(verifier, VT_NAME) &&
+ verifier.VerifyString(name()) &&
+ VerifyOffset(verifier, VT_OPERATORS) &&
+ verifier.VerifyVector(operators()) &&
+ verifier.VerifyVectorOfTables(operators()) &&
+ VerifyOffset(verifier, VT_TENSORS) &&
+ verifier.VerifyVector(tensors()) &&
+ verifier.VerifyVectorOfTables(tensors()) &&
+ VerifyOffset(verifier, VT_INPUTS) &&
+ verifier.VerifyVector(inputs()) &&
+ verifier.VerifyVectorOfStrings(inputs()) &&
+ VerifyOffset(verifier, VT_OUTPUTS) &&
+ verifier.VerifyVector(outputs()) &&
+ verifier.VerifyVectorOfStrings(outputs()) &&
+ verifier.EndTable();
+ }
+};
+
+struct TosaBasicBlockBuilder {
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_name(flatbuffers::Offset<flatbuffers::String> name) {
+ fbb_.AddOffset(TosaBasicBlock::VT_NAME, name);
+ }
+ void add_operators(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<TosaOperator>>> operators) {
+ fbb_.AddOffset(TosaBasicBlock::VT_OPERATORS, operators);
+ }
+ void add_tensors(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<TosaTensor>>> tensors) {
+ fbb_.AddOffset(TosaBasicBlock::VT_TENSORS, tensors);
+ }
+ void add_inputs(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> inputs) {
+ fbb_.AddOffset(TosaBasicBlock::VT_INPUTS, inputs);
+ }
+ void add_outputs(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> outputs) {
+ fbb_.AddOffset(TosaBasicBlock::VT_OUTPUTS, outputs);
+ }
+ explicit TosaBasicBlockBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ TosaBasicBlockBuilder &operator=(const TosaBasicBlockBuilder &);
+ flatbuffers::Offset<TosaBasicBlock> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<TosaBasicBlock>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<TosaBasicBlock> CreateTosaBasicBlock(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ flatbuffers::Offset<flatbuffers::String> name = 0,
+ flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<TosaOperator>>> operators = 0,
+ flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<TosaTensor>>> tensors = 0,
+ flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> inputs = 0,
+ flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> outputs = 0) {
+ TosaBasicBlockBuilder builder_(_fbb);
+ builder_.add_outputs(outputs);
+ builder_.add_inputs(inputs);
+ builder_.add_tensors(tensors);
+ builder_.add_operators(operators);
+ builder_.add_name(name);
+ return builder_.Finish();
+}
+
+inline flatbuffers::Offset<TosaBasicBlock> CreateTosaBasicBlockDirect(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ const char *name = nullptr,
+ const std::vector<flatbuffers::Offset<TosaOperator>> *operators = nullptr,
+ const std::vector<flatbuffers::Offset<TosaTensor>> *tensors = nullptr,
+ const std::vector<flatbuffers::Offset<flatbuffers::String>> *inputs = nullptr,
+ const std::vector<flatbuffers::Offset<flatbuffers::String>> *outputs = nullptr) {
+ auto name__ = name ? _fbb.CreateString(name) : 0;
+ auto operators__ = operators ? _fbb.CreateVector<flatbuffers::Offset<TosaOperator>>(*operators) : 0;
+ auto tensors__ = tensors ? _fbb.CreateVector<flatbuffers::Offset<TosaTensor>>(*tensors) : 0;
+ auto inputs__ = inputs ? _fbb.CreateVector<flatbuffers::Offset<flatbuffers::String>>(*inputs) : 0;
+ auto outputs__ = outputs ? _fbb.CreateVector<flatbuffers::Offset<flatbuffers::String>>(*outputs) : 0;
+ return tosa::CreateTosaBasicBlock(
+ _fbb,
+ name__,
+ operators__,
+ tensors__,
+ inputs__,
+ outputs__);
+}
+
+struct TosaGraph FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_VERSION = 4,
+ VT_BLOCKS = 6
+ };
+ const Version *version() const {
+ return GetPointer<const Version *>(VT_VERSION);
+ }
+ const flatbuffers::Vector<flatbuffers::Offset<TosaBasicBlock>> *blocks() const {
+ return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<TosaBasicBlock>> *>(VT_BLOCKS);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyOffset(verifier, VT_VERSION) &&
+ verifier.VerifyTable(version()) &&
+ VerifyOffset(verifier, VT_BLOCKS) &&
+ verifier.VerifyVector(blocks()) &&
+ verifier.VerifyVectorOfTables(blocks()) &&
+ verifier.EndTable();
+ }
+};
+
+struct TosaGraphBuilder {
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_version(flatbuffers::Offset<Version> version) {
+ fbb_.AddOffset(TosaGraph::VT_VERSION, version);
+ }
+ void add_blocks(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<TosaBasicBlock>>> blocks) {
+ fbb_.AddOffset(TosaGraph::VT_BLOCKS, blocks);
+ }
+ explicit TosaGraphBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ TosaGraphBuilder &operator=(const TosaGraphBuilder &);
+ flatbuffers::Offset<TosaGraph> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<TosaGraph>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<TosaGraph> CreateTosaGraph(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ flatbuffers::Offset<Version> version = 0,
+ flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<TosaBasicBlock>>> blocks = 0) {
+ TosaGraphBuilder builder_(_fbb);
+ builder_.add_blocks(blocks);
+ builder_.add_version(version);
+ return builder_.Finish();
+}
+
+inline flatbuffers::Offset<TosaGraph> CreateTosaGraphDirect(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ flatbuffers::Offset<Version> version = 0,
+ const std::vector<flatbuffers::Offset<TosaBasicBlock>> *blocks = nullptr) {
+ auto blocks__ = blocks ? _fbb.CreateVector<flatbuffers::Offset<TosaBasicBlock>>(*blocks) : 0;
+ return tosa::CreateTosaGraph(
+ _fbb,
+ version,
+ blocks__);
+}
+
+inline bool VerifyAttribute(flatbuffers::Verifier &verifier, const void *obj, Attribute type) {
+ switch (type) {
+ case Attribute_NONE: {
+ return true;
+ }
+ case Attribute_Pool2dAttribute: {
+ auto ptr = reinterpret_cast<const Pool2dAttribute *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case Attribute_Conv2dAttribute: {
+ auto ptr = reinterpret_cast<const Conv2dAttribute *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case Attribute_TransposeConv2dAttribute: {
+ auto ptr = reinterpret_cast<const TransposeConv2dAttribute *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case Attribute_ReluNAttribute: {
+ auto ptr = reinterpret_cast<const ReluNAttribute *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case Attribute_AxisAttribute: {
+ auto ptr = reinterpret_cast<const AxisAttribute *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case Attribute_ReshapeAttribute: {
+ auto ptr = reinterpret_cast<const ReshapeAttribute *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case Attribute_SliceAttribute: {
+ auto ptr = reinterpret_cast<const SliceAttribute *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case Attribute_TileAttribute: {
+ auto ptr = reinterpret_cast<const TileAttribute *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case Attribute_ResizeAttribute: {
+ auto ptr = reinterpret_cast<const ResizeAttribute *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case Attribute_ClampAttribute: {
+ auto ptr = reinterpret_cast<const ClampAttribute *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case Attribute_RescaleAttribute: {
+ auto ptr = reinterpret_cast<const RescaleAttribute *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case Attribute_CustomAttribute: {
+ auto ptr = reinterpret_cast<const CustomAttribute *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case Attribute_CondIfAttribute: {
+ auto ptr = reinterpret_cast<const CondIfAttribute *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case Attribute_WhileLoopAttribute: {
+ auto ptr = reinterpret_cast<const WhileLoopAttribute *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ default: return false;
+ }
+}
+
+inline bool VerifyAttributeVector(flatbuffers::Verifier &verifier, const flatbuffers::Vector<flatbuffers::Offset<void>> *values, const flatbuffers::Vector<uint8_t> *types) {
+ if (!values || !types) return !values && !types;
+ if (values->size() != types->size()) return false;
+ for (flatbuffers::uoffset_t i = 0; i < values->size(); ++i) {
+ if (!VerifyAttribute(
+ verifier, values->Get(i), types->GetEnum<Attribute>(i))) {
+ return false;
+ }
+ }
+ return true;
+}
+
+inline bool VerifyQuantInfo(flatbuffers::Verifier &verifier, const void *obj, QuantInfo type) {
+ switch (type) {
+ case QuantInfo_NONE: {
+ return true;
+ }
+ case QuantInfo_UnaryQuantInfo: {
+ auto ptr = reinterpret_cast<const UnaryQuantInfo *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case QuantInfo_ConvQuantInfo: {
+ auto ptr = reinterpret_cast<const ConvQuantInfo *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case QuantInfo_MatMulQuantInfo: {
+ auto ptr = reinterpret_cast<const MatMulQuantInfo *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case QuantInfo_PadQuantInfo: {
+ auto ptr = reinterpret_cast<const PadQuantInfo *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ default: return false;
+ }
+}
+
+inline bool VerifyQuantInfoVector(flatbuffers::Verifier &verifier, const flatbuffers::Vector<flatbuffers::Offset<void>> *values, const flatbuffers::Vector<uint8_t> *types) {
+ if (!values || !types) return !values && !types;
+ if (values->size() != types->size()) return false;
+ for (flatbuffers::uoffset_t i = 0; i < values->size(); ++i) {
+ if (!VerifyQuantInfo(
+ verifier, values->Get(i), types->GetEnum<QuantInfo>(i))) {
+ return false;
+ }
+ }
+ return true;
+}
+
+inline const tosa::TosaGraph *GetTosaGraph(const void *buf) {
+ return flatbuffers::GetRoot<tosa::TosaGraph>(buf);
+}
+
+inline const tosa::TosaGraph *GetSizePrefixedTosaGraph(const void *buf) {
+ return flatbuffers::GetSizePrefixedRoot<tosa::TosaGraph>(buf);
+}
+
+inline const char *TosaGraphIdentifier() {
+ return "TOSA";
+}
+
+inline bool TosaGraphBufferHasIdentifier(const void *buf) {
+ return flatbuffers::BufferHasIdentifier(
+ buf, TosaGraphIdentifier());
+}
+
+inline bool VerifyTosaGraphBuffer(
+ flatbuffers::Verifier &verifier) {
+ return verifier.VerifyBuffer<tosa::TosaGraph>(TosaGraphIdentifier());
+}
+
+inline bool VerifySizePrefixedTosaGraphBuffer(
+ flatbuffers::Verifier &verifier) {
+ return verifier.VerifySizePrefixedBuffer<tosa::TosaGraph>(TosaGraphIdentifier());
+}
+
+inline const char *TosaGraphExtension() {
+ return "tosa";
+}
+
+inline void FinishTosaGraphBuffer(
+ flatbuffers::FlatBufferBuilder &fbb,
+ flatbuffers::Offset<tosa::TosaGraph> root) {
+ fbb.Finish(root, TosaGraphIdentifier());
+}
+
+inline void FinishSizePrefixedTosaGraphBuffer(
+ flatbuffers::FlatBufferBuilder &fbb,
+ flatbuffers::Offset<tosa::TosaGraph> root) {
+ fbb.FinishSizePrefixed(root, TosaGraphIdentifier());
+}
+
+} // namespace tosa
+
+#endif // FLATBUFFERS_GENERATED_TOSA_TOSA_H_
diff --git a/serialization/tosa_serialization_handler.cpp b/serialization/tosa_serialization_handler.cpp
new file mode 100644
index 0000000..7fe9f47
--- /dev/null
+++ b/serialization/tosa_serialization_handler.cpp
@@ -0,0 +1,1526 @@
+
+// Copyright (c) 2020, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "tosa_serialization_handler.h"
+
+#include <iostream>
+using namespace tosa;
+
+TosaSerializationTensor::TosaSerializationTensor(const flatbuffers::String* name,
+ const flatbuffers::Vector<uint32_t>& usage,
+ const flatbuffers::Vector<int32_t>& shape,
+ DType dtype,
+ const flatbuffers::Vector<uint32_t>& format,
+ const flatbuffers::String* npy_filename)
+{
+ _dtype = dtype;
+
+ _usage = new std::vector<Usage>(usage.size());
+ for (uint32_t us : usage)
+ {
+ _usage->push_back((Usage)us);
+ }
+ assert(_usage);
+
+ _format = new std::vector<Format>(format.size());
+ for (uint32_t fm : format)
+ {
+ _format->push_back((Format)fm);
+ }
+ assert(_format);
+
+ _shape = new std::vector<int32_t>(shape.begin(), shape.end());
+
+ _shape = new std::vector<int32_t>(shape.begin(), shape.end());
+ assert(_shape);
+
+ assert(name);
+ _name = new std::string(name->str());
+ assert(_name);
+
+ if (npy_filename)
+ {
+ _npy_filename = new std::string(npy_filename->str());
+ assert(_npy_filename);
+ }
+ else
+ {
+ _npy_filename = nullptr;
+ }
+}
+
+TosaSerializationTensor::TosaSerializationTensor(std::string name,
+ const std::vector<Usage>& usage,
+ const std::vector<int32_t>& shape,
+ DType dtype,
+ const std::vector<Format>& format,
+ const std::string* npy_filename)
+{
+
+ _dtype = dtype;
+
+ _usage = new std::vector<Usage>(usage);
+ assert(_usage);
+
+ _format = new std::vector<Format>(format);
+ assert(_format);
+
+ _shape = new std::vector<int32_t>(shape);
+ assert(_shape);
+
+ _name = new std::string(name);
+ assert(_name);
+
+ if (npy_filename)
+ {
+ _npy_filename = new std::string(*npy_filename);
+ assert(_npy_filename);
+ }
+ else
+ {
+ _npy_filename = nullptr;
+ }
+}
+
+TosaSerializationTensor::TosaSerializationTensor()
+{
+ _dtype = DType_UNKNOWN;
+
+ _usage = new std::vector<Usage>();
+ _format = new std::vector<Format>();
+ _shape = new std::vector<int32_t>();
+ _name = new std::string("UNKNOWN");
+ assert(_usage && _format && _shape && _name);
+
+ _npy_filename = nullptr;
+}
+
+TosaSerializationTensor::TosaSerializationTensor(const TosaSerializationTensor& rhs)
+{
+ _dtype = rhs._dtype;
+
+ assert(rhs._usage);
+ _usage = new std::vector<Usage>(*rhs._usage);
+ assert(_usage);
+
+ assert(rhs._format);
+ _format = new std::vector<Format>(*rhs._format);
+ assert(_format);
+
+ assert(rhs._shape);
+ _shape = new std::vector<int32_t>(*rhs._shape);
+ assert(_shape);
+
+ assert(rhs._name);
+ _name = new std::string(*rhs._name);
+ assert(_name);
+
+ if (rhs._npy_filename)
+ {
+ _npy_filename = new std::string(*rhs._npy_filename);
+ assert(_npy_filename);
+ }
+ else
+ {
+ _npy_filename = nullptr;
+ }
+}
+
+TosaSerializationTensor& TosaSerializationTensor::operator=(const TosaSerializationTensor& rhs)
+{
+ _dtype = rhs._dtype;
+
+ delete _usage;
+ assert(rhs._usage);
+ _usage = new std::vector<Usage>(*rhs._usage);
+ assert(_usage);
+
+ delete _format;
+ assert(rhs._format);
+ _format = new std::vector<Format>(*rhs._format);
+ assert(_format);
+
+ delete _shape;
+ assert(rhs._shape);
+ _shape = new std::vector<int32_t>(*rhs._shape);
+ assert(_shape);
+
+ delete _name;
+ assert(rhs._name);
+ _name = new std::string(*rhs._name);
+ assert(_name);
+
+ if (_npy_filename)
+ delete _npy_filename;
+
+ if (rhs._npy_filename)
+ {
+ _npy_filename = new std::string(*rhs._npy_filename);
+ }
+ else
+ {
+ _npy_filename = nullptr;
+ }
+ return *this;
+}
+
+TosaSerializationTensor::TosaSerializationTensor(TosaSerializationTensor&& rhs)
+{
+ _dtype = rhs._dtype;
+ std::swap(_format, rhs._format);
+ std::swap(_usage, rhs._usage);
+ std::swap(_shape, rhs._shape);
+ std::swap(_name, rhs._name);
+ std::swap(_npy_filename, rhs._npy_filename);
+}
+
+TosaSerializationTensor& TosaSerializationTensor::operator=(TosaSerializationTensor&& rhs)
+{
+ _dtype = rhs._dtype;
+ std::swap(_format, rhs._format);
+ std::swap(_usage, rhs._usage);
+ std::swap(_shape, rhs._shape);
+ std::swap(_name, rhs._name);
+ std::swap(_npy_filename, rhs._npy_filename);
+ return *this;
+}
+
+TosaSerializationTensor::~TosaSerializationTensor()
+{
+ delete _usage;
+ delete _format;
+ delete _shape;
+ delete _name;
+ if (_npy_filename)
+ delete _npy_filename;
+}
+
+TosaSerializationOperator::TosaSerializationOperator(Op op,
+ Attribute attribute_type,
+ const TosaAttributeBase* attribute,
+ QuantInfo qinfo_type,
+ const TosaQuantInfoBase* qinfo,
+ std::vector<std::string> input_tensor_names,
+ std::vector<std::string> output_tensor_names)
+{
+ _op = op;
+ _attribute_type = attribute_type;
+
+ switch (attribute_type)
+ {
+ case Attribute_NONE:
+ _attribute = new TosaNoneAttribute();
+ break;
+#define DEF_ATTRIBUTE(NAME, ...) \
+ case Attribute_##NAME##Attribute: \
+ _attribute = new Tosa##NAME##Attribute(attribute); \
+ break;
+#include "attribute.def"
+#undef DEF_ATTRIBUTE
+ default:
+ printf("TosaSerializationOperator::TosaSerializationOperator(): Attribute %s not implemented yet\n",
+ EnumNamesAttribute()[attribute_type]);
+ assert(0);
+ }
+
+ _qinfo_type = qinfo_type;
+ switch (qinfo_type)
+ {
+ case QuantInfo_NONE:
+ _qinfo = new TosaNoneQuantInfo();
+ break;
+#define DEF_QUANTIZATION_INFO(NAME, ...) \
+ case QuantInfo_##NAME##QuantInfo: \
+ _qinfo = new Tosa##NAME##QuantInfo(qinfo); \
+ break;
+#include "quant_info.def"
+#undef DEF_QUANTIZATION_INFO
+ default:
+ printf("TosaSerializationOperator::TosaSerializationOperator(): QuantInfo %s not implemented yet\n",
+ EnumNamesQuantInfo()[qinfo_type]);
+ assert(0);
+ }
+
+ assert(_attribute && _qinfo);
+
+ _input_tensor_names = new std::vector<std::string>(input_tensor_names);
+ _output_tensor_names = new std::vector<std::string>(output_tensor_names);
+
+ assert(_input_tensor_names && _output_tensor_names);
+
+ _input_tensors = new std::vector<TosaSerializationTensor*>();
+ _output_tensors = new std::vector<TosaSerializationTensor*>();
+
+ assert(_input_tensors && _output_tensors);
+}
+
+TosaSerializationOperator::~TosaSerializationOperator()
+{
+ delete _attribute;
+ delete _qinfo;
+ delete _input_tensor_names;
+ delete _output_tensor_names;
+ // TosaSerializationTensor should be free'd in TosaSerializationSerializationHandler destructor
+ delete _input_tensors;
+ delete _output_tensors;
+}
+
+TosaSerializationBasicBlock::TosaSerializationBasicBlock(std::string name,
+ std::vector<TosaSerializationOperator*> operators,
+ std::vector<TosaSerializationTensor*> tensors,
+ std::vector<std::string> inputs,
+ std::vector<std::string> outputs)
+{
+
+ _name = new std::string(name);
+ assert(_name);
+
+ _operators = new std::vector<TosaSerializationOperator*>(operators);
+ assert(_operators);
+
+ _tensors = new std::vector<TosaSerializationTensor*>(tensors);
+ assert(_tensors);
+
+ _inputs = new std::vector<std::string>(inputs);
+ assert(_inputs);
+
+ _outputs = new std::vector<std::string>(outputs);
+ assert(_outputs);
+}
+
+TosaSerializationBasicBlock::~TosaSerializationBasicBlock()
+{
+ delete _name;
+
+ // deallocate all operators
+ for (auto op : GetOperators())
+ {
+ delete op; // ~TosaSerializationOperator()
+ }
+ delete _operators;
+
+ // deallocate all tensors
+ for (auto ts : GetTensors())
+ {
+ delete ts; // ~TosaSerializationTensor()
+ }
+ _tensors->clear();
+
+ delete _inputs;
+ delete _outputs;
+}
+
+TosaSerializationHandler::TosaSerializationHandler()
+{
+ _schemaLoaded = false;
+ _builder = new flatbuffers::FlatBufferBuilder();
+ _parser = new flatbuffers::Parser();
+ _blocks = new std::vector<TosaSerializationBasicBlock*>();
+
+ assert(_builder && _parser && _blocks);
+
+ SetTosaVersion();
+}
+
+TosaSerializationHandler::~TosaSerializationHandler()
+{
+ if (_version)
+ delete _version;
+ delete _builder;
+ delete _parser;
+
+ Clear(); // deallocate all basic blocks
+
+ delete _blocks;
+}
+
+tosa_err_t TosaSerializationHandler::SetTosaVersion()
+{
+ // version is specified within .fbs
+ // and it's encoded as defaulted value of CreateTosaVersion()
+ // need to write out one object to read that value out
+ // TODO: very costly now. is there any better way to encode constant in .fbs?
+ auto fboffset_version = CreateVersion(*_builder);
+ auto fboffset_tosa_graph = CreateTosaGraphDirect(*_builder, fboffset_version, nullptr);
+ _builder->Finish(fboffset_tosa_graph);
+ std::string jsongen;
+ uint8_t* buf = _builder->GetBufferPointer();
+ auto fb_tosa_graph = GetTosaGraph(buf);
+ auto fb_tosa_version = fb_tosa_graph->version();
+
+ _version = new TosaVersion(fb_tosa_version->_major(), fb_tosa_version->_minor(), fb_tosa_version->_patch(),
+ fb_tosa_version->_experimental());
+
+ assert(_version);
+ return TOSA_OK;
+}
+
+tosa_err_t TosaSerializationHandler::LoadFileSchema(const char* schema_filename)
+{
+ std::string schema;
+ bool ok;
+
+ ok = flatbuffers::LoadFile(schema_filename, false, &schema);
+ if (!ok)
+ {
+ printf("Error loading schema file: %s\n", schema_filename);
+ return TOSA_FILE_ERROR;
+ }
+
+ ok = _parser->Parse(schema.c_str());
+ if (!ok)
+ {
+ printf("Error parsing ISA schema file: %s\n", schema_filename);
+ return TOSA_FILE_ERROR;
+ }
+ _schemaLoaded = true;
+
+ return TOSA_OK;
+}
+
+tosa_err_t TosaSerializationHandler::LoadFileJson(const char* filename)
+{
+ std::string jsonfile;
+ bool ok;
+ tosa_err_t err;
+
+ if (!_schemaLoaded)
+ {
+ return TOSA_SCHEMA_MISSING;
+ }
+
+ ok = flatbuffers::LoadFile(filename, false, &jsonfile);
+ if (!ok)
+ {
+ printf("Error loading json file: %s\n", filename);
+ return TOSA_FILE_ERROR;
+ }
+
+ ok = _parser->Parse(jsonfile.c_str());
+ if (!ok)
+ {
+ printf("Error parsing json file: %s\n", filename);
+ return TOSA_FILE_ERROR;
+ }
+
+ uint8_t* buf = _parser->builder_.GetBufferPointer();
+
+ err = InitWithBuf(buf);
+ if (err != TOSA_OK)
+ {
+ return err;
+ }
+
+ return TOSA_OK;
+}
+
+tosa_err_t TosaSerializationHandler::SaveFileJson(const char* filename)
+{
+ std::string jsongen;
+ tosa_err_t err;
+
+ if (!_schemaLoaded)
+ {
+ return TOSA_SCHEMA_MISSING;
+ }
+
+ err = FreezeBuilder();
+ if (err != TOSA_OK)
+ {
+ return err;
+ }
+
+ uint8_t* buf = _builder->GetBufferPointer();
+
+ if (!GenerateText(*_parser, buf, &jsongen))
+ {
+ printf("Couldn't serialize parsed data to JSON!\n");
+ return TOSA_FILE_ERROR;
+ }
+
+ FILE* file = fopen(filename, "wb");
+
+ if (!file)
+ {
+ printf("Couldn't open output file: %s\n", filename);
+ return TOSA_FILE_ERROR;
+ }
+
+ if (fwrite(jsongen.c_str(), sizeof(char), jsongen.size(), file) != jsongen.size())
+ {
+ printf("Error writing to json output file: %s\n", filename);
+ fclose(file);
+ return TOSA_FILE_ERROR;
+ }
+
+ if (file)
+ fclose(file);
+
+ return TOSA_OK;
+}
+
+tosa_err_t TosaSerializationHandler::LoadFileTosaFlatbuffer(const char* filename)
+{
+ std::string read_buffer;
+ tosa_err_t err;
+ uint8_t* buf;
+ bool ok;
+
+ ok = flatbuffers::LoadFile(filename, false, &read_buffer);
+ if (!ok)
+ {
+ printf("Error loading flatbuffer file: %s\n", filename);
+ return TOSA_FILE_ERROR;
+ }
+
+ buf = (uint8_t*)read_buffer.data();
+
+ err = InitWithBuf(buf);
+ if (err != TOSA_OK)
+ {
+ return err;
+ }
+
+ return TOSA_OK;
+}
+
+tosa_err_t TosaSerializationHandler::SaveFileTosaFlatbuffer(const char* filename)
+{
+ tosa_err_t err;
+
+ err = FreezeBuilder();
+ if (err != TOSA_OK)
+ {
+ return err;
+ }
+
+ uint8_t* buf = _builder->GetBufferPointer();
+
+ bool ok = flatbuffers::SaveFile(filename, (const char*)buf, _builder->GetSize(), false);
+ if (!ok)
+ {
+ printf("Error saving floatbuffer file: %s\n", filename);
+ return TOSA_FILE_ERROR;
+ }
+
+ return TOSA_OK;
+}
+
+tosa_err_t TosaSerializationHandler::Clear()
+{
+ // deallocate all basic blocks
+ for (auto bb : GetBlocks())
+ {
+ delete bb;
+ }
+ _blocks->clear();
+
+ return TOSA_OK;
+}
+
+tosa_err_t TosaSerializationHandler::CheckTosaVersion(const TosaVersion& read_version)
+{
+ if ((*_version) != read_version)
+ {
+ printf("WARNING: read tosa version: %s != schema tosa version %s\n", read_version.to_string().c_str(),
+ this->_version->to_string().c_str());
+ return TOSA_VERSION_MISMATCH;
+ }
+
+ return TOSA_OK;
+}
+
+tosa_err_t TosaSerializationHandler::InitWithBuf(const uint8_t* buf)
+{
+ auto fb_tosa_graph = GetTosaGraph(buf);
+ auto fb_tosa_version = fb_tosa_graph->version();
+ auto fb_tosa_blocks = fb_tosa_graph->blocks();
+
+ std::vector<std::string> operator_inputs_container;
+ std::vector<std::string> operator_outputs_container;
+
+ std::vector<TosaSerializationOperator*> block_operators_container;
+ std::vector<TosaSerializationTensor*> block_tensors_container;
+ std::vector<std::string> block_inputs_container;
+ std::vector<std::string> block_outputs_container;
+
+ TosaAttributeBase* typed_attribute = NULL;
+ TosaQuantInfoBase* typed_qinfo = NULL;
+ TosaSerializationOperator* new_operator = NULL;
+ TosaSerializationBasicBlock* new_block = NULL;
+ TosaSerializationTensor* new_tensor = NULL;
+
+ // erase container
+ Clear();
+
+ TosaVersion read_version(fb_tosa_version->_major(), fb_tosa_version->_minor(), fb_tosa_version->_patch(),
+ fb_tosa_version->_experimental());
+ tosa_err_t err = CheckTosaVersion(read_version);
+
+ if (err != TOSA_OK)
+ return err;
+
+ for (size_t i = 0; i < fb_tosa_blocks->size(); i++)
+ {
+ auto curr_block = fb_tosa_blocks->Get(i);
+
+ auto block_name = curr_block->name()->str();
+
+ auto fb_tosa_operators = curr_block->operators();
+ block_operators_container.clear();
+ for (size_t j = 0; j < fb_tosa_operators->size(); j++)
+ {
+ auto curr_operator = fb_tosa_operators->Get(j);
+
+ auto operator_op = curr_operator->op();
+ auto attribute_type = curr_operator->attribute_type();
+ auto attribute = curr_operator->attribute();
+ auto operator_qinfo_type = curr_operator->quant_info_type();
+ auto operator_qinfo = curr_operator->quant_info();
+
+ // input tensors
+ auto operator_inputs = curr_operator->inputs();
+ operator_inputs_container.clear();
+ if (operator_inputs)
+ {
+ for (size_t k = 0; k < operator_inputs->size(); k++)
+ {
+ auto curr_input = operator_inputs->Get(k);
+ operator_inputs_container.push_back(curr_input->str());
+ }
+ }
+
+ // output tensors
+ auto operator_outputs = curr_operator->outputs();
+ operator_outputs_container.clear();
+ if (operator_outputs)
+ {
+ for (size_t k = 0; k < operator_outputs->size(); k++)
+ {
+ auto curr_output = operator_outputs->Get(k);
+ operator_outputs_container.push_back(curr_output->str());
+ }
+ }
+
+ switch (attribute_type)
+ {
+ case Attribute_NONE:
+ typed_attribute = new TosaNoneAttribute();
+ break;
+#define DEF_ATTRIBUTE(NAME, ...) \
+ case Attribute_##NAME##Attribute: \
+ typed_attribute = new Tosa##NAME##Attribute(attribute); \
+ break;
+#include "attribute.def"
+#undef DEF_ATTRIBUTE
+ default:
+ printf("TosaSerializationHandler::InitWithBuf(): Attribute %s not implemented yet\n",
+ EnumNamesAttribute()[attribute_type]);
+ return TOSA_INTERNAL_ERROR;
+ }
+
+ switch (operator_qinfo_type)
+ {
+ case QuantInfo_NONE:
+ typed_qinfo = new TosaNoneQuantInfo();
+ break;
+#define DEF_QUANTIZATION_INFO(NAME, ...) \
+ case QuantInfo_##NAME##QuantInfo: \
+ typed_qinfo = new Tosa##NAME##QuantInfo(operator_qinfo); \
+ break;
+
+#include "quant_info.def"
+#undef DEF_QUANTIZATION_INFO
+ default:
+ printf("TosaSerializationHandler::InitWithBuf(): QuantInfo %s not implemented yet\n",
+ EnumNamesQuantInfo()[operator_qinfo_type]);
+ return TOSA_INTERNAL_ERROR;
+ }
+
+ new_operator =
+ new TosaSerializationOperator(operator_op, attribute_type, typed_attribute, operator_qinfo_type,
+ typed_qinfo, operator_inputs_container, operator_outputs_container);
+ if (new_operator)
+ {
+ block_operators_container.push_back(new_operator);
+ }
+ else
+ {
+ return TOSA_MEMORY_ERROR;
+ }
+
+ if (typed_attribute)
+ delete typed_attribute;
+ if (typed_qinfo)
+ delete typed_qinfo;
+ }
+
+ auto fb_tosa_tensors = curr_block->tensors();
+ block_tensors_container.clear();
+ for (size_t j = 0; j < fb_tosa_tensors->size(); j++)
+ {
+ auto curr_tensor = fb_tosa_tensors->Get(j);
+
+ auto tensor_name = curr_tensor->name();
+ auto tensor_usage = curr_tensor->usage();
+ auto tensor_shape = curr_tensor->shape();
+ auto tensor_type = curr_tensor->type();
+ auto tensor_format = curr_tensor->format();
+ auto tensor_npy_filename = curr_tensor->npy_filename();
+
+ new_tensor = new TosaSerializationTensor(tensor_name, *tensor_usage, *tensor_shape, tensor_type,
+ *tensor_format, tensor_npy_filename);
+ if (new_tensor)
+ {
+ block_tensors_container.push_back(new_tensor);
+ }
+ else
+ {
+ return TOSA_MEMORY_ERROR;
+ }
+ }
+
+ auto block_inputs = curr_block->inputs();
+ auto block_outputs = curr_block->outputs();
+
+ block_inputs_container.clear();
+ block_outputs_container.clear();
+
+ for (size_t j = 0; j < block_inputs->size(); j++)
+ {
+ auto curr_block_input = block_inputs->Get(j);
+ block_inputs_container.push_back(curr_block_input->str());
+ }
+ for (size_t j = 0; j < block_outputs->size(); j++)
+ {
+ auto curr_block_output = block_outputs->Get(j);
+ block_outputs_container.push_back(curr_block_output->str());
+ }
+
+ new_block = new TosaSerializationBasicBlock(block_name, block_operators_container, block_tensors_container,
+ block_inputs_container, block_outputs_container);
+ if (new_block)
+ {
+ this->GetBlocks().push_back(new_block);
+ }
+ else
+ {
+ return TOSA_MEMORY_ERROR;
+ }
+ }
+
+ return TOSA_OK;
+}
+
+tosa_err_t TosaSerializationHandler::FreezeBuilder()
+{
+ std::vector<flatbuffers::Offset<TosaBasicBlock>> fboffset_blocks;
+
+ std::vector<flatbuffers::Offset<TosaOperator>> fboffset_block_operators;
+ std::vector<flatbuffers::Offset<TosaTensor>> fboffset_block_tensors;
+ std::vector<flatbuffers::Offset<flatbuffers::String>> fboffset_block_inputs;
+ std::vector<flatbuffers::Offset<flatbuffers::String>> fboffset_block_outputs;
+
+ std::vector<flatbuffers::Offset<flatbuffers::String>> fboffset_operator_inputs;
+ std::vector<flatbuffers::Offset<flatbuffers::String>> fboffset_operator_outputs;
+
+ // translate TosaFlatbufferOperator to flatbuffers::Offset<TosaOperator>
+ for (auto block : GetBlocks())
+ {
+ fboffset_block_operators.clear();
+ fboffset_block_tensors.clear();
+ fboffset_block_inputs.clear();
+ fboffset_block_outputs.clear();
+
+ auto block_name = _builder->CreateString(block->GetName().c_str());
+
+ for (auto tensor_str : block->GetInputs())
+ {
+ auto tensor_name = _builder->CreateString(tensor_str.c_str());
+ fboffset_block_inputs.push_back(tensor_name);
+ }
+
+ for (auto tensor_str : block->GetOutputs())
+ {
+ auto tensor_name = _builder->CreateString(tensor_str.c_str());
+ fboffset_block_outputs.push_back(tensor_name);
+ }
+
+ auto fb_block_inputs = _builder->CreateVector(fboffset_block_inputs);
+ auto fb_block_outputs = _builder->CreateVector(fboffset_block_outputs);
+
+ for (auto op : block->GetOperators())
+ {
+ fboffset_operator_inputs.clear();
+ fboffset_operator_outputs.clear();
+
+ auto operator_op = op->GetOp();
+ auto attribute_type = op->GetAttributeType();
+
+ for (auto tensor_str : op->GetInputTensorNames())
+ {
+ auto tensor_name = _builder->CreateString(tensor_str.c_str());
+ fboffset_operator_inputs.push_back(tensor_name);
+ }
+
+ for (auto tensor_str : op->GetOutputTensorNames())
+ {
+ auto tensor_name = _builder->CreateString(tensor_str.c_str());
+ fboffset_operator_outputs.push_back(tensor_name);
+ }
+
+ auto fb_operator_inputs = _builder->CreateVector(fboffset_operator_inputs);
+ auto fb_operator_outputs = _builder->CreateVector(fboffset_operator_outputs);
+
+ flatbuffers::Offset<void> fb_attribute;
+ switch (attribute_type)
+ {
+ case Attribute_NONE:
+ fb_attribute = 0;
+ break;
+
+#define DEF_ARGS_S_STR(NAME, V) , _builder->CreateString(reinterpret_cast<Tosa##NAME*>(op->GetAttribute())->V().c_str())
+#define DEF_ARGS_S_DEFAULT(NAME, V) , reinterpret_cast<Tosa##NAME*>(op->GetAttribute())->V()
+
+#define DEF_ARGS_S_int32_t(NAME, V) DEF_ARGS_S_DEFAULT(NAME, V)
+#define DEF_ARGS_S_float(NAME, V) DEF_ARGS_S_DEFAULT(NAME, V)
+#define DEF_ARGS_S_bool(NAME, V) DEF_ARGS_S_DEFAULT(NAME, V)
+#define DEF_ARGS_S_ResizeMode(NAME, V) DEF_ARGS_S_DEFAULT(NAME, V)
+#define DEF_ARGS_S_string(NAME, V) DEF_ARGS_S_STR(NAME, V)
+
+#define DEF_ARGS_S(NAME, T, V) DEF_ARGS_S_##T(NAME, V)
+#define DEF_ARGS_V(NAME, T, V) , _builder->CreateVector<T>(reinterpret_cast<Tosa##NAME*>(op->GetAttribute())->V())
+
+#define DEF_ARGS_1(NAME, T0, F0, V0) DEF_ARGS_##F0(NAME, T0, V0)
+#define DEF_ARGS_2(NAME, T0, F0, V0, T1, F1, V1) DEF_ARGS_##F0(NAME, T0, V0) DEF_ARGS_##F1(NAME, T1, V1)
+#define DEF_ARGS_3(NAME, T0, F0, V0, T1, F1, V1, T2, F2, V2) \
+ DEF_ARGS_##F0(NAME, T0, V0) DEF_ARGS_##F1(NAME, T1, V1) DEF_ARGS_##F2(NAME, T2, V2)
+#define DEF_ARGS_4(NAME, T0, F0, V0, T1, F1, V1, T2, F2, V2, T3, F3, V3) \
+ DEF_ARGS_##F0(NAME, T0, V0) DEF_ARGS_##F1(NAME, T1, V1) DEF_ARGS_##F2(NAME, T2, V2) DEF_ARGS_##F3(NAME, T3, V3)
+#define DEF_ARGS_5(NAME, T0, F0, V0, T1, F1, V1, T2, F2, V2, T3, F3, V3, T4, F4, V4) \
+ DEF_ARGS_##F0(NAME, T0, V0) DEF_ARGS_##F1(NAME, T1, V1) DEF_ARGS_##F2(NAME, T2, V2) DEF_ARGS_##F3(NAME, T3, V3) \
+ DEF_ARGS_##F4(NAME, T4, V4)
+#define DEF_ARGS_6(NAME, T0, F0, V0, T1, F1, V1, T2, F2, V2, T3, F3, V3, T4, F4, V4, T5, F5, V5) \
+ DEF_ARGS_##F0(NAME, T0, V0) DEF_ARGS_##F1(NAME, T1, V1) DEF_ARGS_##F2(NAME, T2, V2) DEF_ARGS_##F3(NAME, T3, V3) \
+ DEF_ARGS_##F4(NAME, T4, V4) DEF_ARGS_##F5(NAME, T5, V5)
+#define DEF_ARGS_7(NAME, T0, F0, V0, T1, F1, V1, T2, F2, V2, T3, F3, V3, T4, F4, V4, T5, F5, V5, T6, F6, V6) \
+ DEF_ARGS_##F0(NAME, T0, V0) DEF_ARGS_##F1(NAME, T1, V1) DEF_ARGS_##F2(NAME, T2, V2) DEF_ARGS_##F3(NAME, T3, V3) \
+ DEF_ARGS_##F4(NAME, T4, V4) DEF_ARGS_##F5(NAME, T5, V5) DEF_ARGS_##F6(NAME, T6, V6)
+#define DEF_ATTRIBUTE(NAME, NUM_ARGS, ...) \
+ case Attribute_##NAME##Attribute: \
+ fb_attribute = Create##NAME##Attribute(*_builder DEF_ARGS_##NUM_ARGS(NAME##Attribute, __VA_ARGS__)).Union(); \
+ break;
+
+#include "attribute.def"
+#undef DEF_ATTRIBUTE
+#undef DEF_ARGS_1
+#undef DEF_ARGS_2
+#undef DEF_ARGS_3
+#undef DEF_ARGS_4
+#undef DEF_ARGS_5
+#undef DEF_ARGS_6
+#undef DEF_ARGS_7
+#undef DEF_ARGS_S
+#undef DEF_ARGS_V
+#undef DEF_ARGS_S_int32_t
+#undef DEF_ARGS_S_float
+#undef DEF_ARGS_S_bool
+#undef DEF_ARGS_S_ResizeMode
+#undef DEF_ARGS_S_string
+#undef DEF_ARGS_S_STR
+#undef DEF_ARGS_S_DEFAULT
+ default:
+ printf("TosaSerializationHandler::FreezeBuilder(): Attribute %s not implemented yet\n",
+ EnumNamesAttribute()[attribute_type]);
+ return TOSA_INTERNAL_ERROR;
+ }
+
+ auto qinfo_type = op->GetQInfoType();
+ flatbuffers::Offset<void> fb_operator_qinfo;
+ switch (qinfo_type)
+ {
+ case QuantInfo_NONE:
+ fb_operator_qinfo = 0;
+ break;
+#define DEF_ARGS_S(NAME, T, V) , reinterpret_cast<Tosa##NAME*>(op->GetQInfo())->V()
+#define DEF_ARGS_V(NAME, T, V) , _builder->CreateVector<T>(reinterpret_cast<Tosa##NAME*>(op->GetQInfo())->V())
+
+#define DEF_ARGS_1(NAME, T0, F0, V0) DEF_ARGS_##F0(NAME, T0, V0)
+#define DEF_ARGS_2(NAME, T0, F0, V0, T1, F1, V1) DEF_ARGS_##F0(NAME, T0, V0) DEF_ARGS_##F1(NAME, T1, V1)
+#define DEF_ARGS_3(NAME, T0, F0, V0, T1, F1, V1, T2, F2, V2) \
+ DEF_ARGS_##F0(NAME, T0, V0) DEF_ARGS_##F1(NAME, T1, V1) DEF_ARGS_##F2(NAME, T2, V2)
+#define DEF_ARGS_4(NAME, T0, F0, V0, T1, F1, V1, T2, F2, V2, T3, F3, V3) \
+ DEF_ARGS_##F0(NAME, T0, V0) DEF_ARGS_##F1(NAME, T1, V1) DEF_ARGS_##F2(NAME, T2, V2) DEF_ARGS_##F3(NAME, T3, V3)
+#define DEF_ARGS_5(NAME, T0, F0, V0, T1, F1, V1, T2, F2, V2, T3, F3, V3, T4, F4, V4) \
+ DEF_ARGS_##F0(NAME, T0, V0) DEF_ARGS_##F1(NAME, T1, V1) DEF_ARGS_##F2(NAME, T2, V2) DEF_ARGS_##F3(NAME, T3, V3) \
+ DEF_ARGS_##F4(NAME, T4, V4)
+#define DEF_ARGS_6(NAME, T0, F0, V0, T1, F1, V1, T2, F2, V2, T3, F3, V3, T4, F4, V4, T5, F5, V5) \
+ DEF_ARGS_##F0(NAME, T0, V0) DEF_ARGS_##F1(NAME, T1, V1) DEF_ARGS_##F2(NAME, T2, V2) DEF_ARGS_##F3(NAME, T3, V3) \
+ DEF_ARGS_##F4(NAME, T4, V4) DEF_ARGS_##F5(NAME, T5, V5)
+#define DEF_ARGS_7(NAME, T0, F0, V0, T1, F1, V1, T2, F2, V2, T3, F3, V3, T4, F4, V4, T5, F5, V5, T6, F6, V6) \
+ DEF_ARGS_##F0(NAME, T0, V0) DEF_ARGS_##F1(NAME, T1, V1) DEF_ARGS_##F2(NAME, T2, V2) DEF_ARGS_##F3(NAME, T3, V3) \
+ DEF_ARGS_##F4(NAME, T4, V4) DEF_ARGS_##F5(NAME, T5, V5) DEF_ARGS_##F6(NAME, T6, V6)
+#define DEF_ARGS_8(NAME, T0, F0, V0, T1, F1, V1, T2, F2, V2, T3, F3, V3, T4, F4, V4, T5, F5, V5, T6, F6, V6, T7, F7, \
+ V7) \
+ DEF_ARGS_##F0(NAME, T0, V0) DEF_ARGS_##F1(NAME, T1, V1) DEF_ARGS_##F2(NAME, T2, V2) DEF_ARGS_##F3(NAME, T3, V3) \
+ DEF_ARGS_##F4(NAME, T4, V4) DEF_ARGS_##F5(NAME, T5, V5) DEF_ARGS_##F6(NAME, T6, V6) \
+ DEF_ARGS_##F7(NAME, T7, V7)
+#define DEF_ARGS_9(NAME, T0, F0, V0, T1, F1, V1, T2, F2, V2, T3, F3, V3, T4, F4, V4, T5, F5, V5, T6, F6, V6, T7, F7, \
+ V7, T8, F8, V8) \
+ DEF_ARGS_##F0(NAME, T0, V0) DEF_ARGS_##F1(NAME, T1, V1) DEF_ARGS_##F2(NAME, T2, V2) DEF_ARGS_##F3(NAME, T3, V3) \
+ DEF_ARGS_##F4(NAME, T4, V4) DEF_ARGS_##F5(NAME, T5, V5) DEF_ARGS_##F6(NAME, T6, V6) \
+ DEF_ARGS_##F7(NAME, T7, V7) DEF_ARGS_##F8(NAME, T8, V8)
+#define DEF_ARGS_10(NAME, T0, F0, V0, T1, F1, V1, T2, F2, V2, T3, F3, V3, T4, F4, V4, T5, F5, V5, T6, F6, V6, T7, F7, \
+ V7, T8, F8, V8, T9, F9, V9) \
+ DEF_ARGS_##F0(NAME, T0, V0) DEF_ARGS_##F1(NAME, T1, V1) DEF_ARGS_##F2(NAME, T2, V2) DEF_ARGS_##F3(NAME, T3, V3) \
+ DEF_ARGS_##F4(NAME, T4, V4) DEF_ARGS_##F5(NAME, T5, V5) DEF_ARGS_##F6(NAME, T6, V6) \
+ DEF_ARGS_##F7(NAME, T7, V7) DEF_ARGS_##F8(NAME, T8, V8) DEF_ARGS_##F9(NAME, T9, V9)
+#define DEF_QUANTIZATION_INFO(NAME, NUM_ARGS, ...) \
+ case QuantInfo_##NAME##QuantInfo: \
+ fb_operator_qinfo = \
+ Create##NAME##QuantInfo(*_builder DEF_ARGS_##NUM_ARGS(NAME##QuantInfo, __VA_ARGS__)).Union(); \
+ break;
+
+#include "quant_info.def"
+#undef DEF_QUANTIZATION_INFO
+#undef DEF_ARGS_1
+#undef DEF_ARGS_2
+#undef DEF_ARGS_3
+#undef DEF_ARGS_4
+#undef DEF_ARGS_5
+#undef DEF_ARGS_6
+#undef DEF_ARGS_7
+#undef DEF_ARGS_8
+#undef DEF_ARGS_9
+#undef DEF_ARGS_10
+#undef DEF_ARGS_S
+#undef DEF_ARGS_V
+ default:
+ printf("TosaSerializationHandler::FreezeBuilder(): Attribute %s not implemented yet\n",
+ EnumNamesAttribute()[attribute_type]);
+ return TOSA_INTERNAL_ERROR;
+ }
+
+ auto fboffset_operator =
+ CreateTosaOperator(*_builder, operator_op, attribute_type, fb_attribute, fb_operator_inputs,
+ fb_operator_outputs, qinfo_type, fb_operator_qinfo);
+ fboffset_block_operators.push_back(fboffset_operator);
+ }
+
+ auto fb_block_operators = _builder->CreateVector(fboffset_block_operators);
+
+ for (auto tensor : block->GetTensors())
+ {
+
+ auto tensor_name = _builder->CreateString(tensor->GetName().c_str());
+ auto tensor_usage =
+ _builder->CreateVector(std::vector<uint32_t>(tensor->GetUsage().begin(), tensor->GetUsage().end()));
+ auto tensor_shape = _builder->CreateVector(tensor->GetShape());
+ auto tensor_dtype = tensor->GetDtype();
+ auto tensor_format =
+ _builder->CreateVector(std::vector<uint32_t>(tensor->GetFormat().begin(), tensor->GetFormat().end()));
+ flatbuffers::Offset<flatbuffers::String> tensor_npy_filename = 0;
+ if (tensor->GetNpyFilePtr())
+ tensor_npy_filename = _builder->CreateString(tensor->GetNpyFilePtr()->c_str());
+
+ auto fboffset_tensor = CreateTosaTensor(*_builder, tensor_name, tensor_shape, tensor_dtype, tensor_usage,
+ tensor_format, tensor_npy_filename);
+ fboffset_block_tensors.push_back(fboffset_tensor);
+ }
+
+ auto fb_block_tensors = _builder->CreateVector(fboffset_block_tensors);
+
+ auto fboffset_block = CreateTosaBasicBlock(*_builder, block_name, fb_block_operators, fb_block_tensors,
+ fb_block_inputs, fb_block_outputs);
+ fboffset_blocks.push_back(fboffset_block);
+ }
+
+ auto fb_blocks = _builder->CreateVector(fboffset_blocks);
+
+ auto fb_version = CreateVersion(*_builder, GetTosaVersion()->_major, GetTosaVersion()->_minor,
+ GetTosaVersion()->_patch, GetTosaVersion()->_experimental);
+
+ auto fb_graph = CreateTosaGraph(*_builder, fb_version, fb_blocks);
+ _builder->Finish(fb_graph);
+
+ return TOSA_OK;
+}
+
+// Magic NUMPY header
+static const char NUMPY_HEADER_STR[] = "\x93NUMPY\x1\x0\x76\x0{";
+static const int NUMPY_HEADER_SZ = 128;
+
+NumpyUtilities::NPError NumpyUtilities::readFromNpyFile(const char* filename, const uint32_t elems, bool* databuf)
+{
+ const char dtype_str[] = "'|b1'";
+ FILE* infile = nullptr;
+ NPError rc = NO_ERROR;
+
+ assert(filename);
+ assert(databuf);
+
+ infile = fopen(filename, "rb");
+ if (!infile)
+ {
+ rc = FILE_NOT_FOUND;
+ goto done;
+ }
+
+ rc = checkNpyHeader(infile, elems, dtype_str);
+ if (rc != NO_ERROR)
+ {
+ goto done;
+ }
+
+ // Read in the data from numpy byte array to native bool
+ // array format
+ for (uint32_t i = 0; i < elems; i++)
+ {
+ int val = fgetc(infile);
+
+ if (val == EOF)
+ {
+ rc = FILE_IO_ERROR;
+ goto done;
+ }
+
+ databuf[i] = val;
+ }
+
+done:
+
+ if (infile)
+ fclose(infile);
+
+ return rc;
+}
+
+NumpyUtilities::NPError NumpyUtilities::readFromNpyFile(const char* filename, const uint32_t elems, int32_t* databuf)
+{
+ const char dtype_str[] = "'<i4'";
+ FILE* infile = nullptr;
+ NPError rc = NO_ERROR;
+
+ assert(filename);
+ assert(databuf);
+
+ infile = fopen(filename, "rb");
+ if (!infile)
+ {
+ rc = FILE_NOT_FOUND;
+ goto done;
+ }
+
+ rc = checkNpyHeader(infile, elems, dtype_str);
+ if (rc != NO_ERROR)
+ {
+ goto done;
+ }
+
+ // Now we are at the beginning of the data
+ // Parse based on the datatype and number of dimensions
+ if (fread(databuf, sizeof(int32_t), elems, infile) != elems)
+ {
+ rc = FILE_IO_ERROR;
+ goto done;
+ }
+
+done:
+
+ if (infile)
+ fclose(infile);
+
+ return rc;
+}
+
+NumpyUtilities::NPError NumpyUtilities::readFromNpyFile(const char* filename, const uint32_t elems, int64_t* databuf)
+{
+ const char dtype_str[] = "'<i8'";
+ FILE* infile = nullptr;
+ NPError rc = NO_ERROR;
+
+ assert(filename);
+ assert(databuf);
+
+ infile = fopen(filename, "rb");
+ if (!infile)
+ {
+ rc = FILE_NOT_FOUND;
+ goto done;
+ }
+
+ rc = checkNpyHeader(infile, elems, dtype_str);
+ if (rc != NO_ERROR)
+ {
+ goto done;
+ }
+
+ // Now we are at the beginning of the data
+ // Parse based on the datatype and number of dimensions
+ if (fread(databuf, sizeof(int64_t), elems, infile) != elems)
+ {
+ rc = FILE_IO_ERROR;
+ goto done;
+ }
+
+done:
+
+ if (infile)
+ fclose(infile);
+
+ return rc;
+}
+
+NumpyUtilities::NPError NumpyUtilities::readFromNpyFile(const char* filename, const uint32_t elems, float* databuf)
+{
+ const char dtype_str[] = "'<f4'";
+ FILE* infile = nullptr;
+ NPError rc = NO_ERROR;
+
+ assert(filename);
+ assert(databuf);
+
+ infile = fopen(filename, "rb");
+ if (!infile)
+ {
+ rc = FILE_NOT_FOUND;
+ goto done;
+ }
+
+ rc = checkNpyHeader(infile, elems, dtype_str);
+ if (rc != NO_ERROR)
+ {
+ goto done;
+ }
+
+ // Now we are at the beginning of the data
+ // Parse based on the datatype and number of dimensions
+ if (fread(databuf, sizeof(float), elems, infile) != elems)
+ {
+ rc = FILE_IO_ERROR;
+ goto done;
+ }
+
+done:
+
+ if (infile)
+ fclose(infile);
+
+ return rc;
+}
+
+NumpyUtilities::NPError NumpyUtilities::checkNpyHeader(FILE* infile, const uint32_t elems, const char* dtype_str)
+{
+ char buf[NUMPY_HEADER_SZ + 1];
+ char* ptr = nullptr;
+ NPError rc = NO_ERROR;
+ bool foundFormat = false;
+ bool foundOrder = false;
+ bool foundShape = false;
+ bool fortranOrder = false;
+ std::vector<int> shape;
+ uint32_t totalElems = 1;
+ char* outer_end = NULL;
+
+ assert(infile);
+ assert(elems > 0);
+
+ if (fread(buf, NUMPY_HEADER_SZ, 1, infile) != 1)
+ {
+ rc = HEADER_PARSE_ERROR;
+ goto done;
+ }
+
+ if (memcmp(buf, NUMPY_HEADER_STR, sizeof(NUMPY_HEADER_STR) - 1))
+ {
+ rc = HEADER_PARSE_ERROR;
+ goto done;
+ }
+
+ ptr = strtok_r(buf + sizeof(NUMPY_HEADER_STR) - 1, ":", &outer_end);
+
+ // Read in the data type, order, and shape
+ while (ptr && (!foundFormat || !foundOrder || !foundShape))
+ {
+
+ // End of string?
+ if (!ptr)
+ break;
+
+ // Skip whitespace
+ while (isspace(*ptr))
+ ptr++;
+
+ // Parse the dictionary field name
+ if (!strcmp(ptr, "'descr'"))
+ {
+ ptr = strtok_r(NULL, ",", &outer_end);
+ if (!ptr)
+ break;
+
+ while (isspace(*ptr))
+ ptr++;
+
+ if (strcmp(ptr, dtype_str))
+ {
+ rc = FILE_TYPE_MISMATCH;
+ goto done;
+ }
+
+ foundFormat = true;
+ }
+ else if (!strcmp(ptr, "'fortran_order'"))
+ {
+ ptr = strtok_r(NULL, ",", &outer_end);
+ if (!ptr)
+ break;
+
+ while (isspace(*ptr))
+ ptr++;
+
+ if (!strcmp(ptr, "False"))
+ {
+ fortranOrder = false;
+ }
+ else
+ {
+ rc = FILE_TYPE_MISMATCH;
+ goto done;
+ }
+
+ foundOrder = true;
+ }
+ else if (!strcmp(ptr, "'shape'"))
+ {
+
+ ptr = strtok_r(NULL, "(", &outer_end);
+ if (!ptr)
+ break;
+ ptr = strtok_r(NULL, ")", &outer_end);
+ if (!ptr)
+ break;
+
+ while (isspace(*ptr))
+ ptr++;
+
+ // The shape contains N comma-separated integers. Read up to 4.
+ char* end = NULL;
+
+ ptr = strtok_r(ptr, ",", &end);
+ for (int i = 0; i < 4; i++)
+ {
+ // Out of dimensions
+ if (!ptr)
+ break;
+
+ shape.push_back(atoi(ptr));
+ totalElems *= atoi(ptr);
+ ptr = strtok_r(NULL, ",", &end);
+ }
+
+ foundShape = true;
+ }
+ else
+ {
+ rc = HEADER_PARSE_ERROR;
+ goto done;
+ }
+
+ if (!ptr)
+ break;
+
+ ptr = strtok_r(NULL, ":", &outer_end);
+ }
+
+ if (!foundShape || !foundFormat || !foundOrder)
+ {
+ rc = HEADER_PARSE_ERROR;
+ goto done;
+ }
+
+ // Validate header
+ if (fortranOrder != false)
+ {
+ rc = FILE_TYPE_MISMATCH;
+ goto done;
+ }
+
+ if (totalElems != elems)
+ {
+ rc = BUFFER_SIZE_MISMATCH;
+ goto done;
+ }
+
+ // Go back to the begininng and read until the end of the header dictionary
+ rewind(infile);
+ int val;
+
+ do
+ {
+ val = fgetc(infile);
+ } while (val != EOF && val != '\n');
+
+done:
+
+ return rc;
+}
+
+NumpyUtilities::NPError NumpyUtilities::writeToNpyFile(const char* filename, const uint32_t elems, const bool* databuf)
+{
+ std::vector<int32_t> shape = { (int32_t)elems };
+ return writeToNpyFile(filename, shape, databuf);
+}
+
+NumpyUtilities::NPError
+ NumpyUtilities::writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const bool* databuf)
+{
+ const char dtype_str[] = "'|b1'";
+ FILE* outfile = nullptr;
+ NPError rc = NO_ERROR;
+ uint32_t totalElems = 1;
+
+ assert(filename);
+ assert(shape.size() >= 0);
+ assert(databuf);
+
+ outfile = fopen(filename, "wb");
+
+ if (!outfile)
+ {
+ rc = FILE_NOT_FOUND;
+ goto done;
+ }
+
+ for (uint32_t i = 0; i < shape.size(); i++)
+ {
+ totalElems *= shape[i];
+ }
+
+ rc = writeNpyHeader(outfile, shape, dtype_str);
+
+ // Numpy save format stores booleans as a byte array
+ // with one byte per boolean. This somewhat inefficiently
+ // remaps from system bool[] to this format.
+ for (uint32_t i = 0; i < totalElems; i++)
+ {
+ int val = databuf[i] ? 1 : 0;
+ if (fputc(val, outfile) == EOF)
+ {
+ rc = FILE_IO_ERROR;
+ goto done;
+ }
+ }
+
+done:
+
+ if (outfile)
+ fclose(outfile);
+
+ return rc;
+}
+
+NumpyUtilities::NPError
+ NumpyUtilities::writeToNpyFile(const char* filename, const uint32_t elems, const int32_t* databuf)
+{
+ std::vector<int32_t> shape = { (int32_t)elems };
+ return writeToNpyFile(filename, shape, databuf);
+}
+
+NumpyUtilities::NPError
+ NumpyUtilities::writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const int32_t* databuf)
+{
+ const char dtype_str[] = "'<i4'";
+ FILE* outfile = nullptr;
+ NPError rc = NO_ERROR;
+ uint32_t totalElems = 1;
+
+ assert(filename);
+ assert(shape.size() >= 0);
+ assert(databuf);
+
+ outfile = fopen(filename, "wb");
+
+ if (!outfile)
+ {
+ rc = FILE_NOT_FOUND;
+ goto done;
+ }
+
+ for (uint32_t i = 0; i < shape.size(); i++)
+ {
+ totalElems *= shape[i];
+ }
+
+ rc = writeNpyHeader(outfile, shape, dtype_str);
+
+ if (fwrite(databuf, sizeof(int32_t), totalElems, outfile) != totalElems)
+ {
+ rc = FILE_IO_ERROR;
+ goto done;
+ }
+
+done:
+
+ if (outfile)
+ fclose(outfile);
+
+ return rc;
+}
+
+NumpyUtilities::NPError
+ NumpyUtilities::writeToNpyFile(const char* filename, const uint32_t elems, const int64_t* databuf)
+{
+ std::vector<int32_t> shape = { (int32_t)elems };
+ return writeToNpyFile(filename, shape, databuf);
+}
+
+NumpyUtilities::NPError
+ NumpyUtilities::writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const int64_t* databuf)
+{
+ const char dtype_str[] = "'<i8'";
+ FILE* outfile = nullptr;
+ NPError rc = NO_ERROR;
+ uint32_t totalElems = 1;
+
+ assert(filename);
+ assert(shape.size() >= 0);
+ assert(databuf);
+
+ outfile = fopen(filename, "wb");
+
+ if (!outfile)
+ {
+ rc = FILE_NOT_FOUND;
+ goto done;
+ }
+
+ for (uint32_t i = 0; i < shape.size(); i++)
+ {
+ totalElems *= shape[i];
+ }
+
+ rc = writeNpyHeader(outfile, shape, dtype_str);
+
+ if (fwrite(databuf, sizeof(int64_t), totalElems, outfile) != totalElems)
+ {
+ rc = FILE_IO_ERROR;
+ goto done;
+ }
+
+done:
+
+ if (outfile)
+ fclose(outfile);
+
+ return rc;
+}
+
+NumpyUtilities::NPError NumpyUtilities::writeToNpyFile(const char* filename, const uint32_t elems, const float* databuf)
+{
+ std::vector<int32_t> shape = { (int32_t)elems };
+ return writeToNpyFile(filename, shape, databuf);
+}
+
+NumpyUtilities::NPError
+ NumpyUtilities::writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const float* databuf)
+{
+ const char dtype_str[] = "'<f4'";
+ FILE* outfile = nullptr;
+ NPError rc = NO_ERROR;
+ uint32_t totalElems = 1;
+
+ assert(filename);
+ assert(shape.size() >= 0);
+ assert(databuf);
+
+ outfile = fopen(filename, "wb");
+
+ if (!outfile)
+ {
+ rc = FILE_NOT_FOUND;
+ goto done;
+ }
+
+ for (uint32_t i = 0; i < shape.size(); i++)
+ {
+ totalElems *= shape[i];
+ }
+
+ rc = writeNpyHeader(outfile, shape, dtype_str);
+
+ if (fwrite(databuf, sizeof(float), totalElems, outfile) != totalElems)
+ {
+ rc = FILE_IO_ERROR;
+ goto done;
+ }
+
+done:
+
+ if (outfile)
+ fclose(outfile);
+
+ return rc;
+}
+
+NumpyUtilities::NPError
+ NumpyUtilities::writeNpyHeader(FILE* outfile, const std::vector<int32_t>& shape, const char* dtype_str)
+{
+ NPError rc = NO_ERROR;
+ uint32_t i;
+ char header[NUMPY_HEADER_SZ + 1];
+ int headerPos = 0;
+
+ assert(outfile);
+ assert(shape.size() >= 0);
+
+ // Space-fill the header and end with a newline to start per numpy spec
+ memset(header, 0x20, NUMPY_HEADER_SZ);
+ header[NUMPY_HEADER_SZ - 1] = '\n';
+ header[NUMPY_HEADER_SZ] = 0;
+
+ // Write out the hard-coded header. We only support a 128-byte 1.0 header
+ // for now, which should be sufficient for simple tensor types of any
+ // reasonable rank.
+ memcpy(header, NUMPY_HEADER_STR, sizeof(NUMPY_HEADER_STR) - 1);
+ headerPos += sizeof(NUMPY_HEADER_STR) - 1;
+
+ // Output the format dictionary
+ // Hard-coded for I32 for now
+ headerPos +=
+ snprintf(header + headerPos, NUMPY_HEADER_SZ - headerPos, "'descr': %s, 'fortran_order': False, 'shape': (%d,",
+ dtype_str, shape.size() > 0 ? shape[0] : 1);
+
+ // Remainder of shape array
+ for (i = 1; i < shape.size(); i++)
+ {
+ headerPos += snprintf(header + headerPos, NUMPY_HEADER_SZ - headerPos, " %d,", shape[i]);
+ }
+
+ // Close off the dictionary
+ headerPos += snprintf(header + headerPos, NUMPY_HEADER_SZ - headerPos, "), }");
+
+ // snprintf leaves a NULL at the end. Replace with a space
+ header[headerPos] = 0x20;
+
+ if (fwrite(header, NUMPY_HEADER_SZ, 1, outfile) != 1)
+ {
+ rc = FILE_IO_ERROR;
+ goto done;
+ }
+
+done:
+
+ return rc;
+}
diff --git a/serialization/tosa_serialization_handler.h b/serialization/tosa_serialization_handler.h
new file mode 100644
index 0000000..124b8e0
--- /dev/null
+++ b/serialization/tosa_serialization_handler.h
@@ -0,0 +1,423 @@
+
+// Copyright (c) 2020, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef _TOSA_SERIALIZATION_HANDLER_H
+#define _TOSA_SERIALIZATION_HANDLER_H
+#include "attribute.h"
+#include "flatbuffers/idl.h"
+#include "flatbuffers/util.h"
+#include "quant_info.h"
+#include "tosa_generated.h"
+#include <cstdint>
+#include <memory>
+#include <string>
+#include <vector>
+
+namespace tosa
+{
+
+enum tosa_err_t
+{
+ TOSA_OK,
+ TOSA_USER_ERROR,
+ TOSA_FILE_ERROR,
+ TOSA_MEMORY_ERROR,
+ TOSA_SCHEMA_MISSING,
+ TOSA_INTERNAL_ERROR,
+ TOSA_VERSION_MISMATCH,
+ NUM_TOSA_ERROR
+};
+
+struct TosaVersion
+{
+ int32_t _major;
+ int32_t _minor;
+ int32_t _patch;
+ bool _experimental;
+
+ TosaVersion() = delete;
+ TosaVersion(int32_t major, int32_t minor, int32_t patch, bool experimental)
+ {
+ _major = major;
+ _minor = minor;
+ _patch = patch;
+ _experimental = experimental;
+ }
+
+ std::string to_string() const
+ {
+ std::string str;
+ str += std::to_string(_major) + ".";
+ str += std::to_string(_minor) + ".";
+ str += std::to_string(_patch);
+ if (_experimental)
+ str += "(experimental)";
+ return str;
+ };
+
+ bool operator==(const TosaVersion& rhs)
+ {
+ if (rhs._major == _major && rhs._minor == _minor && rhs._patch == _patch && rhs._experimental == _experimental)
+ {
+ return true;
+ }
+ return false;
+ }
+
+ bool operator!=(const TosaVersion& rhs)
+ {
+ return !((*this) == rhs);
+ }
+};
+
+class TosaSerializationHandler;
+
+class TosaSerializationTensor
+{
+public:
+ // constructor and destructor
+ TosaSerializationTensor(const flatbuffers::String* name,
+ const flatbuffers::Vector<uint32_t>& usage,
+ const flatbuffers::Vector<int32_t>& shape,
+ DType dtype,
+ const flatbuffers::Vector<uint32_t>& format,
+ const flatbuffers::String* npy_filename);
+ TosaSerializationTensor(std::string name,
+ const std::vector<Usage>& usage,
+ const std::vector<int32_t>& shape,
+ DType dtype,
+ const std::vector<Format>& format,
+ const std::string* npy_filename);
+ TosaSerializationTensor();
+ ~TosaSerializationTensor();
+
+ // copy constructor/assignment
+ TosaSerializationTensor(const TosaSerializationTensor& rhs);
+ TosaSerializationTensor& operator=(const TosaSerializationTensor& rhs);
+
+ // move constructor/assignment
+ TosaSerializationTensor(TosaSerializationTensor&& rhs);
+ TosaSerializationTensor& operator=(TosaSerializationTensor&& rhs);
+
+ // accessor
+ std::string GetName() const
+ {
+ return *_name;
+ }
+ const std::vector<int32_t>& GetShape() const
+ {
+ return *_shape;
+ }
+ DType GetDtype()
+ {
+ return _dtype;
+ }
+ bool HasFormat(Format format)
+ {
+ for (Format us : *_format)
+ {
+ if (us == format)
+ return true;
+ }
+ return false;
+ }
+ std::vector<Format>& GetFormat()
+ {
+ return *_format;
+ }
+ bool HasUsage(Usage usage)
+ {
+ for (Usage us : *_usage)
+ {
+ if (us == usage)
+ return true;
+ }
+ return false;
+ }
+ std::vector<Usage>& GetUsage()
+ {
+ return *_usage;
+ }
+ std::string* GetNpyFilePtr() const
+ {
+ return _npy_filename;
+ }
+
+ // modifier
+ void SetDtype(DType dtype)
+ {
+ _dtype = dtype;
+ }
+ void SetName(std::string name)
+ {
+ *_name = name;
+ }
+
+private:
+ DType _dtype; /* data type enumeration, see tosa_isa_generated.h */
+ std::vector<Format>* _format; /* list of possible tensor format */
+ std::vector<Usage>* _usage; /* list of possible tensor usage */
+ std::vector<int32_t>* _shape; /* shape of the tensor */
+ std::string* _name; /* name of the tensor, used for solving dependency */
+ std::string* _npy_filename; /* numpy array filename if not null. so null is the distinguisher */
+};
+
+class TosaSerializationOperator
+{
+public:
+ // use default copy, void constructor
+ // constructor and destructor
+ TosaSerializationOperator(Op op_name,
+ Attribute attribute_type,
+ const TosaAttributeBase* attribute,
+ QuantInfo qinfo_type,
+ const TosaQuantInfoBase* qinfo,
+ std::vector<std::string> input_tensor_names,
+ std::vector<std::string> output_tensor_names);
+ ~TosaSerializationOperator();
+
+ // accessor
+ Op GetOp() const
+ {
+ return _op;
+ }
+ Attribute GetAttributeType() const
+ {
+ return _attribute_type;
+ }
+ TosaAttributeBase* GetAttribute() const
+ {
+ return _attribute;
+ }
+ QuantInfo GetQInfoType() const
+ {
+ return _qinfo_type;
+ }
+ TosaQuantInfoBase* GetQInfo() const
+ {
+ return _qinfo;
+ }
+ std::vector<std::string>& GetInputTensorNames() const
+ {
+ return *_input_tensor_names;
+ }
+ std::vector<std::string>& GetOutputTensorNames() const
+ {
+ return *_output_tensor_names;
+ }
+ std::vector<TosaSerializationTensor*>& GetInputTensors() const
+ {
+ return *_input_tensors;
+ }
+ std::vector<TosaSerializationTensor*>& GetOutputTensors() const
+ {
+ return *_output_tensors;
+ }
+
+private:
+ Op _op; /* operator enum, see tosa_isa_generated.h for enumeration table */
+ Attribute _attribute_type; /* operator attribute enum, used for dynamic casting TosaAttributeBase class */
+ TosaAttributeBase* _attribute; /* real attribute class goes here */
+ QuantInfo _qinfo_type; /* QuantInfo enum */
+ TosaQuantInfoBase* _qinfo; /* base class pointer of QuantInfo */
+ std::vector<std::string>* _input_tensor_names; /* array of input tensor names */
+ std::vector<std::string>* _output_tensor_names; /* array of output tensor names */
+
+ std::vector<TosaSerializationTensor*>* _input_tensors; /* array of input TosaSerializationTensor */
+ std::vector<TosaSerializationTensor*>* _output_tensors; /* array of output TosaSerializationTensor */
+};
+
+class TosaSerializationBasicBlock
+{
+public:
+ // constructor and destructor
+ TosaSerializationBasicBlock(std::string name,
+ std::vector<TosaSerializationOperator*> operators,
+ std::vector<TosaSerializationTensor*> tensors,
+ std::vector<std::string> inputs,
+ std::vector<std::string> outputs);
+ ~TosaSerializationBasicBlock();
+
+ // accessor
+ std::string GetName() const
+ {
+ return *_name;
+ }
+ std::vector<TosaSerializationOperator*>& GetOperators()
+ {
+ return *_operators;
+ }
+ std::vector<TosaSerializationTensor*>& GetTensors()
+ {
+ return *_tensors;
+ }
+
+ TosaSerializationTensor* GetTensorByName(std::string name)
+ {
+ TosaSerializationTensor* result = nullptr;
+ for (auto tensor : GetTensors())
+ {
+ if (tensor->GetName() == name)
+ {
+ result = tensor;
+ break;
+ }
+ }
+ return result;
+ }
+
+ std::vector<std::string>& GetInputs()
+ {
+ return *_inputs;
+ }
+ std::vector<std::string>& GetOutputs()
+ {
+ return *_outputs;
+ }
+
+private:
+ std::string* _name; /* name of basic block */
+ std::vector<TosaSerializationOperator*>* _operators; /* TosaSerializationOperator list */
+ std::vector<TosaSerializationTensor*>* _tensors; /* TosaSerializationTensor list */
+ std::vector<std::string>* _inputs; /* array of string to specify block inputs */
+ std::vector<std::string>* _outputs; /* array of string to specify block outputs */
+};
+
+/*
+ * this is a helper class for writing/reading Tosa ISA
+ * supported format: .tosa (flatbuffer), .json
+ * and provide high-level std::vector-like interface
+ * to access internal data structure
+ */
+class TosaSerializationHandler
+{
+public:
+ // constructor and destructor
+ TosaSerializationHandler();
+ ~TosaSerializationHandler();
+
+ // file io
+ tosa_err_t LoadFileJson(const char* filename);
+ tosa_err_t LoadFileTosaFlatbuffer(const char* filename);
+ tosa_err_t SaveFileJson(const char* filename);
+ tosa_err_t SaveFileTosaFlatbuffer(const char* filename);
+ tosa_err_t LoadFileSchema(const char* filename);
+
+ // version
+ TosaVersion* GetTosaVersion() const
+ {
+ return _version;
+ }
+
+ // accessor
+ std::vector<TosaSerializationBasicBlock*>& GetBlocks()
+ {
+ return *_blocks;
+ }
+
+ TosaSerializationBasicBlock* GetBlockByName(std::string name)
+ {
+ TosaSerializationBasicBlock* result = nullptr;
+ for (auto block : GetBlocks())
+ {
+ if (block->GetName() == name)
+ {
+ result = block;
+ break;
+ }
+ }
+ return result;
+ }
+ TosaSerializationBasicBlock* GetMainBlock()
+ {
+ TosaSerializationBasicBlock* main_block = GetBlockByName(std::string("main"));
+ assert(main_block);
+ return main_block;
+ }
+
+ std::vector<std::string>& GetInputs()
+ {
+ return GetMainBlock()->GetInputs();
+ }
+ std::vector<std::string>& GetOutputs()
+ {
+ return GetMainBlock()->GetOutputs();
+ }
+
+ bool GetSchemaLoaded() const
+ {
+ return _schemaLoaded;
+ }
+
+protected:
+ tosa_err_t Clear();
+ tosa_err_t InitWithBuf(const uint8_t* buf);
+ tosa_err_t FreezeBuilder();
+ tosa_err_t SetTosaVersion();
+ tosa_err_t CheckTosaVersion(const TosaVersion& read_version);
+
+private:
+ TosaVersion* _version; /* tosa version */
+ flatbuffers::FlatBufferBuilder* _builder; /* flatbuffer builder */
+ flatbuffers::Parser* _parser; /* flatbuffer parser, used for json parsing */
+ std::vector<TosaSerializationBasicBlock*>* _blocks; /* array structure to store all TosaSerializationBasicBlock */
+ bool _schemaLoaded; /* is the schema properly loaded? */
+};
+
+class NumpyUtilities
+{
+public:
+ enum NPError
+ {
+ NO_ERROR = 0,
+ FILE_NOT_FOUND,
+ FILE_IO_ERROR,
+ FILE_TYPE_MISMATCH,
+ HEADER_PARSE_ERROR,
+ BUFFER_SIZE_MISMATCH,
+ };
+
+ static NPError readFromNpyFile(const char* filename, const uint32_t elems, float* buf);
+
+ static NPError readFromNpyFile(const char* filename, const uint32_t elems, int32_t* buf);
+
+ static NPError readFromNpyFile(const char* filename, const uint32_t elems, int64_t* buf);
+
+ static NPError readFromNpyFile(const char* filename, const uint32_t elems, bool* buf);
+
+ static NPError writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const bool* buf);
+
+ static NPError writeToNpyFile(const char* filename, const uint32_t elems, const bool* buf);
+
+ static NPError writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const int32_t* buf);
+
+ static NPError writeToNpyFile(const char* filename, const uint32_t elems, const int32_t* buf);
+
+ static NPError writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const int64_t* buf);
+
+ static NPError writeToNpyFile(const char* filename, const uint32_t elems, const int64_t* buf);
+
+ static NPError writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const float* buf);
+
+ static NPError writeToNpyFile(const char* filename, const uint32_t elems, const float* buf);
+
+private:
+ static NPError checkNpyHeader(FILE* infile, const uint32_t elems, const char* dtype_str);
+ static NPError writeNpyHeader(FILE* infile, const std::vector<int32_t>& shape, const char* dtype_str);
+};
+
+} // namespace tosa
+
+#endif // _TOSA_SERIALIZATION_HANDLER_H