From 2364dcd7241d730021bf68e000e5a6411b9f09d1 Mon Sep 17 00:00:00 2001 From: Eric Kunze Date: Mon, 26 Apr 2021 11:06:57 -0700 Subject: Initial commit of serialization library code Change-Id: Ie09a7245176aa799e59622e5118b145833b23590 Signed-off-by: Eric Kunze --- .clang-format | 34 + .gitmodules | 3 + .pre-commit-config.yaml | 35 + CMakeLists.txt | 53 + README.md | 160 ++ include/attribute.def | 98 + include/attribute.h | 181 ++ include/numpy_utils.h | 81 + include/operator.def | 124 ++ include/quant_info.def | 43 + include/quant_info.h | 164 ++ include/tosa_generated.h | 2649 ++++++++++++++++++++++++++ include/tosa_serialization_handler.h | 349 ++++ python/tosa/ArithmeticRightShiftAttribute.py | 51 + python/tosa/Attribute.py | 37 + python/tosa/AxisAttribute.py | 51 + python/tosa/ClampAttribute.py | 75 + python/tosa/CondIfAttribute.py | 59 + python/tosa/Conv2dAttribute.py | 130 ++ python/tosa/ConvQuantInfo.py | 59 + python/tosa/DType.py | 30 + python/tosa/MatMulQuantInfo.py | 59 + python/tosa/MulAttribute.py | 51 + python/tosa/Op.py | 91 + python/tosa/PadQuantInfo.py | 51 + python/tosa/Pool2dAttribute.py | 130 ++ python/tosa/QuantInfo.py | 26 + python/tosa/ReluNAttribute.py | 59 + python/tosa/RescaleAttribute.py | 141 ++ python/tosa/ReshapeAttribute.py | 72 + python/tosa/ResizeAttribute.py | 204 ++ python/tosa/ResizeMode.py | 24 + python/tosa/SliceAttribute.py | 101 + python/tosa/TileAttribute.py | 72 + python/tosa/TosaBasicBlock.py | 149 ++ python/tosa/TosaGraph.py | 82 + python/tosa/TosaOperator.py | 133 ++ python/tosa/TosaTensor.py | 96 + python/tosa/TransposeConv2dAttribute.py | 159 ++ python/tosa/UnaryQuantInfo.py | 59 + python/tosa/Version.py | 75 + python/tosa/WhileLoopAttribute.py | 59 + python/tosa/__init__.py | 15 + regenerate_headers.sh | 35 + schema/tosa.fbs | 307 +++ src/numpy_utils.cpp | 415 ++++ src/tosa_serialization_handler.cpp | 762 ++++++++ test/scripts/test_npy_fileio.py | 152 ++ test/scripts/test_serialization.py | 197 ++ test/scripts/testfiles/test.tosa | Bin 0 -> 544 bytes test/scripts/xunit/xunit.py | 109 ++ test/src/serialization_npy_test.cpp | 225 +++ test/src/serialization_read_write.cpp | 50 + third_party/CMakeLists.txt | 25 + third_party/flatbuffers | 1 + 55 files changed, 8652 insertions(+) create mode 100644 .clang-format create mode 100644 .gitmodules create mode 100644 .pre-commit-config.yaml create mode 100644 CMakeLists.txt create mode 100644 README.md create mode 100644 include/attribute.def create mode 100644 include/attribute.h create mode 100644 include/numpy_utils.h create mode 100644 include/operator.def create mode 100644 include/quant_info.def create mode 100644 include/quant_info.h create mode 100644 include/tosa_generated.h create mode 100644 include/tosa_serialization_handler.h create mode 100644 python/tosa/ArithmeticRightShiftAttribute.py create mode 100644 python/tosa/Attribute.py create mode 100644 python/tosa/AxisAttribute.py create mode 100644 python/tosa/ClampAttribute.py create mode 100644 python/tosa/CondIfAttribute.py create mode 100644 python/tosa/Conv2dAttribute.py create mode 100644 python/tosa/ConvQuantInfo.py create mode 100644 python/tosa/DType.py create mode 100644 python/tosa/MatMulQuantInfo.py create mode 100644 python/tosa/MulAttribute.py create mode 100644 python/tosa/Op.py create mode 100644 python/tosa/PadQuantInfo.py create mode 100644 python/tosa/Pool2dAttribute.py create mode 100644 python/tosa/QuantInfo.py create mode 100644 python/tosa/ReluNAttribute.py create mode 100644 python/tosa/RescaleAttribute.py create mode 100644 python/tosa/ReshapeAttribute.py create mode 100644 python/tosa/ResizeAttribute.py create mode 100644 python/tosa/ResizeMode.py create mode 100644 python/tosa/SliceAttribute.py create mode 100644 python/tosa/TileAttribute.py create mode 100644 python/tosa/TosaBasicBlock.py create mode 100644 python/tosa/TosaGraph.py create mode 100644 python/tosa/TosaOperator.py create mode 100644 python/tosa/TosaTensor.py create mode 100644 python/tosa/TransposeConv2dAttribute.py create mode 100644 python/tosa/UnaryQuantInfo.py create mode 100644 python/tosa/Version.py create mode 100644 python/tosa/WhileLoopAttribute.py create mode 100644 python/tosa/__init__.py create mode 100755 regenerate_headers.sh create mode 100644 schema/tosa.fbs create mode 100644 src/numpy_utils.cpp create mode 100644 src/tosa_serialization_handler.cpp create mode 100755 test/scripts/test_npy_fileio.py create mode 100755 test/scripts/test_serialization.py create mode 100644 test/scripts/testfiles/test.tosa create mode 100644 test/scripts/xunit/xunit.py create mode 100644 test/src/serialization_npy_test.cpp create mode 100644 test/src/serialization_read_write.cpp create mode 100644 third_party/CMakeLists.txt create mode 160000 third_party/flatbuffers diff --git a/.clang-format b/.clang-format new file mode 100644 index 0000000..66b0148 --- /dev/null +++ b/.clang-format @@ -0,0 +1,34 @@ +BasedOnStyle: LLVM +AccessModifierOffset: -4 +AllowShortFunctionsOnASingleLine: None +AlwaysBreakTemplateDeclarations: true +BinPackParameters: false +BraceWrapping: + AfterClass: true + AfterControlStatement: true + AfterEnum: true + AfterFunction: true + AfterNamespace: true + AfterObjCDeclaration: true + AfterStruct: true + AfterUnion: true + AfterExternBlock: true + BeforeCatch: true + BeforeElse: true + IndentBraces: false + SplitEmptyFunction: false + SplitEmptyRecord: false + SplitEmptyNamespace: true +BreakBeforeBraces: Custom +BreakConstructorInitializersBeforeComma: true +BreakConstructorInitializers: BeforeColon +Cpp11BracedListStyle: false +IndentCaseLabels: true +IndentWidth: 4 +IndentWrappedFunctionNames: true +PointerAlignment: Left +SpacesInContainerLiterals: false +AlignConsecutiveAssignments: true +ColumnLimit: 120 +ReflowComments: false +SpacesBeforeTrailingComments: 4 diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 0000000..fa78bd5 --- /dev/null +++ b/.gitmodules @@ -0,0 +1,3 @@ +[submodule "third_party/flatbuffers"] + path = third_party/flatbuffers + url = https://github.com/google/flatbuffers.git diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..55d630f --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,35 @@ +# See https://pre-commit.com for more information +# See https://pre-commit.com/hooks.html for more hooks +repos: +#- repo: https://github.com/asottile/reorder_python_imports +# rev: v2.2.0 +# hooks: +# - id: reorder-python-imports + + #- repo: https://github.com/psf/black + #j.rev: 20.8b1 + #j.hooks: + #j.- id: black + +- repo: https://gitlab.com/pycqa/flake8 + rev: 3.7.9 + hooks: + - id: flake8 + exclude: python/tosa + args: [--max-line-length=88, --extend-ignore=E203] + +- repo: https://github.com/psf/black + rev: 20.8b1 + hooks: + - id: black + exclude: python/tosa + +- repo: local + hooks: + - id: clang-format + name: clang-format + exclude: tosa_generated.h|build|third_party + language: system + entry: clang-format + types: ["c++"] + args: ["-i"] \ No newline at end of file diff --git a/CMakeLists.txt b/CMakeLists.txt new file mode 100644 index 0000000..44af9c0 --- /dev/null +++ b/CMakeLists.txt @@ -0,0 +1,53 @@ +#TOSA serialization library + +# Copyright (c) 2020-2021, ARM Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Contains TOSA flatbuffer serialization library content. + +cmake_minimum_required(VERSION 3.13.4) +project(TosaSerialization) + +set(CMAKE_CXX_STANDARD 14 CACHE STRING "C++ standard to conform to") +set(CMAKE_CXX_STANDARD_REQUIRED YES) + +set(CMAKE_VERBOSE_MAKEFILE ON) + +include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include) +include_directories(${PROJECT_SOURCE_DIR}/third_party/flatbuffers/include) + +add_library(tosa_serialization_lib STATIC + src/tosa_serialization_handler.cpp + src/numpy_utils.cpp + ) + +add_subdirectory(third_party) + +add_executable(serialization_read_write + test/src/serialization_read_write.cpp +) + +target_link_libraries(serialization_read_write + tosa_serialization_lib + flatbuffers +) + +add_executable(serialization_npy_test + test/src/serialization_npy_test.cpp +) + +target_link_libraries(serialization_npy_test + tosa_serialization_lib + flatbuffers +) diff --git a/README.md b/README.md new file mode 100644 index 0000000..979b76f --- /dev/null +++ b/README.md @@ -0,0 +1,160 @@ +TOSA Serialization Library +========================== + +# Introduction + +The *TOSA Serialization library* provides methods to read and write serialized +TOSA graphs (). The library includes +a FlatBuffers schema and a C++ API for reading and writing TOSA graphs. + +# Usage + +The section below describes serialization_lib API usage. For more +details, please refer to `include/tosa_serialization_handler.h`. + +## TosaSerializationHandler + +This is the top-level class that contains the entire TOSA graph. In +particular, it contains a vector of `TosaSerializationBasicBlock` objects, +and provides API for file IO, block access, and version checking. + + a. `LoadFileJson(filename)`: + + Load json-formatted file "filename" from disk, and initialize the + internal graph structure. + + Requires the schema file to be loaded via `LoadFileSchema()`. + + b. `SaveFileJson(filename)`: + + Snapshots the internal graph structure and saves out JSON-formatted file + `filename` to disk. + Requires the schema file to be loaded via `LoadFileSchema()`. + + c. `LoadFileTosaFlatbuffer(filename)`: + + Load serialized flatbuffer file "filename" from disk, and initialize the + internal graph structure. + + d. `SaveFileTosaFlatbuffer(filename)`: + + Snapshot the internal graph structure and saves out serialized + flatbuffer file `filename` to disk. + + e. `GetTosaVersion()`: + + Return TOSA version implemented by the serialization library. + + f. `GetBlockByName()`: + + Return vector of `TosaSerializationBasicBlock`. Returns `nullptr` if + nothing matches. A valid graph must have one `main` block as the first + block being traversed. + + g. `GetMainBlock()`: + + Shortcut for accessing the `main` block. Equivalent to + `GetBlockByName("main")`. + + h. `GetInputs()` / `GetOutputs()`: + + Shortcut for `main` block's input/output tensor name. Input tensors of + the main block are usually treated as `tosa.PLACEHOLDER`. Output tensors + are the output of the entire graph and should be evaluated when graph + traversal has finished. + +## TosaSerializationBasicBlock + +This is the basic-block class. It contains vectors of +`TosaSerializationOperator` and `TosaSerializationTensor`. Once entering +a basic block, all of the operators within the block will be evaluated +in order. + +Upon reaching a TOSA control flow operator (`tosa.WHILE` and +`tosa.COND_IF`), the status of current unfinished block will be saved, and +the blocks specified in control flow operator will be evaluated first. Once +the control flow blocks finishes its evaluation, the original unfinished +block status will be restored and evaluation continues. This is more +analogous to a function call than a compiler basic block. + + a. `GetName()`: + + Return string of the basic block. + + b. `GetOperators()`: + + Return vector of `TosaSerializationOperator` + + c. `GetTensors()`: + + Return vector of `TosaSerializationTensor` + + d. `GetTensorByName(name)`: + + Return the `TosaSerializationTensor` with name `name`. Returns `nullptr` + if nothing matches. + + e. `GetInputs()` / `GetOutputs()`: + + Return input/output tensor name of the basic block. + +## TosaSerializationOperator + +The operator class contains (1) what TOSA Op, (2) attribute (compile-time- +known input), (3) quantization information and (4) input/output tensor +names. The combination of (Op, attribute, quantization information) is +defined in include/operator.def. + + a. `GetOp()`: + + Return TOSA Op. Defined in schema `tosa.fbs`. + + b. `GetAttribute()` / `GetAttributeType()`: + + `GetAttribute()` returns the base object of attribute. + `GetAttributeType()` returns which type of attribute the base object + needs to be casted to. Type of attribute is defined in `tosa.fbs` and + `include/attribute.def`. + + c. `GetQuantInfo()` + `GetQuantInfoType()`: + + `GetQuantInfo()` returns the base object's quantization information. + `GetQuantInfoType()` returns which type of quantization information the + base object needs to be casted to. Type of quantization information is + defined in `tosa.fbs` and `include/quant_info.def`. + + d. `GetInputTensorNames()` / `GetOutputTensorNames()`: + + Returns the input/output tensor name of the basic block. + + e. `GetInputTensors()` / `GetOutputTensors()`: + + Returns the input/output tensor of the basic block. + +## TosaSerializationTensor + +The tensor class contains (1) data type, (2) shape, (3) symbolic link to +numpy file, (4) format and (5) usage. + + a. `GetName()` / `SetName(name)`: + + `GetName()` returns the name of the tensor. `SetName()` sets the name + of the tensor. + + b. `GetShape()`: + + Returns the shape of the tensor as `vector`. + + c. `GetDtype()` / `SetDtype(dtype)`: + + `GetDtype()` returns the data type of the tensor. `SetDtype()` sets the + data type of the tensor. DType is defined in `tosa.fbs`. + + d. `GetNpyFilePtr()`: + + Return the numpy file pointer of this tensor if this is a constant + tensor. Return `nullptr` if the tensor is not constant. + +# License + +The *TOSA Serialization Library* is licensed under Apache-2.0. diff --git a/include/attribute.def b/include/attribute.def new file mode 100644 index 0000000..12b9c96 --- /dev/null +++ b/include/attribute.def @@ -0,0 +1,98 @@ + +// Copyright (c) 2020-2021, ARM Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +/* + Syntax: + DEF_ATTRIBUTE(ATTRIBUTE_NAME, NUM_ARGS_IN_ATTRIBUTES, ARG0_TYPE, ARG0_SCALAR_OR_VECTOR, ARGS0_NAME, ...) + + Description: + ATTRIBUTE_NAME: corresponding attribute name, must match corresponding "table XXXAttribute" in tosa.fbs + NUM_ARGS_IN_ATTRIBUTES: number of arguments in this attribute + ARG0_TYPE: data type of arg0 in attribute + ARG0_SCALAR_OR_VECTOR: is arg0 a scalar(S) or a vector(V) + ARG0_NAME: name of arg0 + ...: variadic variables for more arguments, depending on NUM_ARGS_IN_ATTRIBUTES +*/ + +DEF_ATTRIBUTE(Pool2d, 3, + int32_t, V, padding, + int32_t, V, kernel, + int32_t, V, stride) + +DEF_ATTRIBUTE(Conv2d, 3, + int32_t, V, padding, + int32_t, V, stride, + int32_t, V, dilation) + +DEF_ATTRIBUTE(TransposeConv2d, 4, + int32_t, V, outpad, + int32_t, V, stride, + int32_t, V, dilation, + int32_t, V, output_shape) + +DEF_ATTRIBUTE(ReluN, 2, + int32_t, S, max_int, + float, S, max_fp) + +DEF_ATTRIBUTE(Axis, 1, + int32_t, S, axis) + +DEF_ATTRIBUTE(Reshape, 1, + int32_t, V, shape) + +DEF_ATTRIBUTE(Slice, 2, + int32_t, V, begin, + int32_t, V, size) + +DEF_ATTRIBUTE(Tile, 1, + int32_t, V, multiples) + +DEF_ATTRIBUTE(Resize, 7, + int32_t, V, output_size, + int32_t, V, stride, + int32_t, V, offset, + int32_t, S, shift, + float, V, stride_fp, + float, V, offset_fp, + ResizeMode, S, mode) + +DEF_ATTRIBUTE(Clamp, 4, + int32_t, S, min_int, + int32_t, S, max_int, + float, S, min_fp, + float, S, max_fp) + +DEF_ATTRIBUTE(Rescale, 7, + int32_t, S, input_zp, + int32_t, S, output_zp, + int32_t, V, multiplier, + int32_t, V, shift, + bool, S, scale32, + bool, S, double_round, + bool, S, per_channel) + +DEF_ATTRIBUTE(Mul, 1, + int32_t, S, shift) + +DEF_ATTRIBUTE(ArithmeticRightShift, 1, + bool, S, round) + +DEF_ATTRIBUTE(CondIf, 2, + string, S, then_branch, + string, S, else_branch) + +DEF_ATTRIBUTE(WhileLoop, 2, + string, S, cond_branch, + string, S, body_branch) diff --git a/include/attribute.h b/include/attribute.h new file mode 100644 index 0000000..ff354cb --- /dev/null +++ b/include/attribute.h @@ -0,0 +1,181 @@ + +// Copyright (c) 2020-2021, ARM Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef _TOSA_SERIALIZATION_ATTRIBUTE_H +#define _TOSA_SERIALIZATION_ATTRIBUTE_H +#include "flatbuffers/idl.h" +#include "flatbuffers/util.h" +#include "tosa_generated.h" + +using std::string; + +namespace tosa +{ + +class TosaAttributeBase +{ +public: + virtual ~TosaAttributeBase() + {} +}; + +class TosaNoneAttribute : public TosaAttributeBase +{ +public: + TosaNoneAttribute() + {} + TosaNoneAttribute(TosaNoneAttribute* p) + {} +}; + +#define DEF_ARGS_VER0_S_STR(V) _##V = p->V()->str(); +#define DEF_ARGS_VER0_S_DEFAULT(V) _##V = p->V(); + +#define DEF_ARGS_VER0_S_int32_t(V) DEF_ARGS_VER0_S_DEFAULT(V) +#define DEF_ARGS_VER0_S_float(V) DEF_ARGS_VER0_S_DEFAULT(V) +#define DEF_ARGS_VER0_S_bool(V) DEF_ARGS_VER0_S_DEFAULT(V) +#define DEF_ARGS_VER0_S_ResizeMode(V) DEF_ARGS_VER0_S_DEFAULT(V) +#define DEF_ARGS_VER0_S_string(V) DEF_ARGS_VER0_S_STR(V) + +#define DEF_ARGS_VER0_S(T, V) DEF_ARGS_VER0_S_##T(V) +#define DEF_ARGS_VER0_V(T, V) _##V = std::vector(p->V()->begin(), p->V()->end()); + +#define DEF_ARGS_VER1_S(T, V) const T& V +#define DEF_ARGS_VER1_V(T, V) const std::vector& V +#define DEF_ARGS_VER2_S(T, V) _##V = V; +#define DEF_ARGS_VER2_V(T, V) _##V = V; +#define DEF_ARGS_VER3_S(T, V) \ + T V() const \ + { \ + return _##V; \ + } +#define DEF_ARGS_VER3_V(T, V) \ + std::vector V() const \ + { \ + return _##V; \ + } +#define DEF_ARGS_VER4_S(T, V) T _##V; +#define DEF_ARGS_VER4_V(T, V) std::vector _##V; + +// another level of preprocessor indirection to handle ", " as function's input argument +#define DEF_ARGS_VER1_TRUE(T, F, V) DEF_ARGS_VER1_##F(T, V) +#define DEF_ARGS_VER1_FALSE(T, F, V) , DEF_ARGS_VER1_##F(T, V) + +#define DEF_ARGS_VER0(FIRST, T, F, V) DEF_ARGS_VER0_##F(T, V) +#define DEF_ARGS_VER1(FIRST, T, F, V) DEF_ARGS_VER1_##FIRST(T, F, V) +#define DEF_ARGS_VER2(FIRST, T, F, V) DEF_ARGS_VER2_##F(T, V) +#define DEF_ARGS_VER3(FIRST, T, F, V) DEF_ARGS_VER3_##F(T, V) +#define DEF_ARGS_VER4(FIRST, T, F, V) DEF_ARGS_VER4_##F(T, V) + +#define DEF_ARGS_0(VER, ...) +#define DEF_ARGS_1(VER, T0, F0, V0) DEF_ARGS_##VER(TRUE, T0, F0, V0) +#define DEF_ARGS_2(VER, T0, F0, V0, T1, F1, V1) DEF_ARGS_##VER(TRUE, T0, F0, V0) DEF_ARGS_##VER(FALSE, T1, F1, V1) +#define DEF_ARGS_3(VER, T0, F0, V0, T1, F1, V1, T2, F2, V2) \ + DEF_ARGS_##VER(TRUE, T0, F0, V0) DEF_ARGS_##VER(FALSE, T1, F1, V1) DEF_ARGS_##VER(FALSE, T2, F2, V2) +#define DEF_ARGS_4(VER, T0, F0, V0, T1, F1, V1, T2, F2, V2, T3, F3, V3) \ + DEF_ARGS_##VER(TRUE, T0, F0, V0) DEF_ARGS_##VER(FALSE, T1, F1, V1) DEF_ARGS_##VER(FALSE, T2, F2, V2) \ + DEF_ARGS_##VER(FALSE, T3, F3, V3) +#define DEF_ARGS_5(VER, T0, F0, V0, T1, F1, V1, T2, F2, V2, T3, F3, V3, T4, F4, V4) \ + DEF_ARGS_##VER(TRUE, T0, F0, V0) DEF_ARGS_##VER(FALSE, T1, F1, V1) DEF_ARGS_##VER(FALSE, T2, F2, V2) \ + DEF_ARGS_##VER(FALSE, T3, F3, V3) DEF_ARGS_##VER(FALSE, T4, F4, V4) + +#define DEF_ARGS_6(VER, T0, F0, V0, T1, F1, V1, T2, F2, V2, T3, F3, V3, T4, F4, V4, T5, F5, V5) \ + DEF_ARGS_##VER(TRUE, T0, F0, V0) DEF_ARGS_##VER(FALSE, T1, F1, V1) DEF_ARGS_##VER(FALSE, T2, F2, V2) \ + DEF_ARGS_##VER(FALSE, T3, F3, V3) DEF_ARGS_##VER(FALSE, T4, F4, V4) DEF_ARGS_##VER(FALSE, T5, F5, V5) + +#define DEF_ARGS_7(VER, T0, F0, V0, T1, F1, V1, T2, F2, V2, T3, F3, V3, T4, F4, V4, T5, F5, V5, T6, F6, V6) \ + DEF_ARGS_##VER(TRUE, T0, F0, V0) DEF_ARGS_##VER(FALSE, T1, F1, V1) DEF_ARGS_##VER(FALSE, T2, F2, V2) \ + DEF_ARGS_##VER(FALSE, T3, F3, V3) DEF_ARGS_##VER(FALSE, T4, F4, V4) DEF_ARGS_##VER(FALSE, T5, F5, V5) \ + DEF_ARGS_##VER(FALSE, T6, F6, V6) + +#define DEF_VER0_VAR_DECL_PTR(NAME) const NAME* p = static_cast(options); +#define DEF_VER0_VAR_0(NAME) +#define DEF_VER0_VAR_1(NAME) DEF_VER0_VAR_DECL_PTR(NAME) +#define DEF_VER0_VAR_2(NAME) DEF_VER0_VAR_DECL_PTR(NAME) +#define DEF_VER0_VAR_3(NAME) DEF_VER0_VAR_DECL_PTR(NAME) +#define DEF_VER0_VAR_4(NAME) DEF_VER0_VAR_DECL_PTR(NAME) +#define DEF_VER0_VAR_5(NAME) DEF_VER0_VAR_DECL_PTR(NAME) +#define DEF_VER0_VAR_6(NAME) DEF_VER0_VAR_DECL_PTR(NAME) +#define DEF_VER0_VAR_7(NAME) DEF_VER0_VAR_DECL_PTR(NAME) + +#define DEF_ATTRIBUTE(NAME, NUM_ARGS, ...) \ + class Tosa##NAME##Attribute : public TosaAttributeBase \ + { \ + public: \ + Tosa##NAME##Attribute(const TosaAttributeBase* options) \ + { \ + const Tosa##NAME##Attribute* p = reinterpret_cast(options); \ + *this = *p; \ + } \ + Tosa##NAME##Attribute(const Tosa##NAME##Attribute* p) \ + { \ + *this = *p; \ + } \ + Tosa##NAME##Attribute(const void* options){ DEF_VER0_VAR_##NUM_ARGS(NAME##Attribute) \ + DEF_ARGS_##NUM_ARGS(VER0, __VA_ARGS__) } Tosa##NAME \ + ##Attribute(DEF_ARGS_##NUM_ARGS(VER1, __VA_ARGS__)) \ + { \ + DEF_ARGS_##NUM_ARGS(VER2, __VA_ARGS__) \ + } \ + virtual ~Tosa##NAME##Attribute() \ + {} \ + DEF_ARGS_##NUM_ARGS(VER3, __VA_ARGS__) private : DEF_ARGS_##NUM_ARGS(VER4, __VA_ARGS__) \ + }; + +#include "attribute.def" +#undef DEF_ATTRIBUTE +#undef DEF_ARGS_0 +#undef DEF_ARGS_1 +#undef DEF_ARGS_2 +#undef DEF_ARGS_3 +#undef DEF_ARGS_4 +#undef DEF_ARGS_5 +#undef DEF_ARGS_6 +#undef DEF_ARGS_7 +#undef DEF_ARGS_VER0 +#undef DEF_ARGS_VER1 +#undef DEF_ARGS_VER2 +#undef DEF_ARGS_VER3 +#undef DEF_ARGS_VER4 +#undef DEF_ARGS_VER0_S_int32_t +#undef DEF_ARGS_VER0_S_float +#undef DEF_ARGS_VER0_S_bool +#undef DEF_ARGS_VER0_S_ResizeMode +#undef DEF_ARGS_VER0_S_string +#undef DEF_ARGS_VER0_S_STR +#undef DEF_ARGS_VER0_S_DEFAULT +#undef DEF_ARGS_VER1_TRUE +#undef DEF_ARGS_VER1_FALSE +#undef DEF_ARGS_VER0_S +#undef DEF_ARGS_VER0_V +#undef DEF_ARGS_VER1_S +#undef DEF_ARGS_VER1_V +#undef DEF_ARGS_VER2_S +#undef DEF_ARGS_VER2_V +#undef DEF_ARGS_VER3_S +#undef DEF_ARGS_VER3_V +#undef DEF_ARGS_VER4_S +#undef DEF_ARGS_VER4_V +#undef DEF_VER0_VAR_0 +#undef DEF_VER0_VAR_1 +#undef DEF_VER0_VAR_2 +#undef DEF_VER0_VAR_3 +#undef DEF_VER0_VAR_4 +#undef DEF_VER0_VAR_5 +#undef DEF_VER0_VAR_DECL_PTR + +} // namespace tosa + +#endif // _TOSA_SERIALIZATION_ATTRIBUTE_H diff --git a/include/numpy_utils.h b/include/numpy_utils.h new file mode 100644 index 0000000..c64bc17 --- /dev/null +++ b/include/numpy_utils.h @@ -0,0 +1,81 @@ + +// Copyright (c) 2020-2021, ARM Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef _TOSA_NUMPY_UTILS_H +#define _TOSA_NUMPY_UTILS_H + +#include +#include +#include +#include +#include +#include +#include + +class NumpyUtilities +{ +public: + enum NPError + { + NO_ERROR = 0, + FILE_NOT_FOUND, + FILE_IO_ERROR, + FILE_TYPE_MISMATCH, + HEADER_PARSE_ERROR, + BUFFER_SIZE_MISMATCH, + }; + + static NPError readFromNpyFile(const char* filename, const uint32_t elems, float* databuf); + + static NPError readFromNpyFile(const char* filename, const uint32_t elems, int32_t* databuf); + + static NPError readFromNpyFile(const char* filename, const uint32_t elems, int64_t* databuf); + + static NPError readFromNpyFile(const char* filename, const uint32_t elems, bool* databuf); + + static NPError writeToNpyFile(const char* filename, const std::vector& shape, const bool* databuf); + + static NPError writeToNpyFile(const char* filename, const uint32_t elems, const bool* databuf); + + static NPError writeToNpyFile(const char* filename, const std::vector& shape, const int32_t* databuf); + + static NPError writeToNpyFile(const char* filename, const uint32_t elems, const int32_t* databuf); + + static NPError writeToNpyFile(const char* filename, const std::vector& shape, const int64_t* databuf); + + static NPError writeToNpyFile(const char* filename, const uint32_t elems, const int64_t* databuf); + + static NPError writeToNpyFile(const char* filename, const std::vector& shape, const float* databuf); + + static NPError writeToNpyFile(const char* filename, const uint32_t elems, const float* databuf); + +private: + static NPError writeToNpyFileCommon(const char* filename, + const char* dtype_str, + const size_t elementsize, + const std::vector& shape, + const void* databuf, + bool bool_translate); + static NPError readFromNpyFileCommon(const char* filename, + const char* dtype_str, + const size_t elementsize, + const uint32_t elems, + void* databuf, + bool bool_translate); + static NPError checkNpyHeader(FILE* infile, const uint32_t elems, const char* dtype_str); + static NPError writeNpyHeader(FILE* outfile, const std::vector& shape, const char* dtype_str); +}; + +#endif // _TOSA_NUMPY_UTILS_H diff --git a/include/operator.def b/include/operator.def new file mode 100644 index 0000000..80ae547 --- /dev/null +++ b/include/operator.def @@ -0,0 +1,124 @@ + +// Copyright (c) 2020-2021, ARM Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +/* + Syntax: + DEF_OPERATOR(MLIR_NAME, SCHEMA_NAME, REF_IMPL_NAME, OPTIONS, QUANT_INFO) + + Description: + MLIR_NAME: the symbolic string of this op, must match tosa_ops.td + SCHEMA_NAME: corresponding operator name, must match "enum Op" in serialization/tosa.fbs + REF_IMPL_NAME: name used internally in tosa reference implementation + OPTIONS: compile time constant options of this op, corresponding to operator_option.def + QUANT_INFO: quantization infomation of this op, corresponding to quant_info.def +*/ + + +/* tensor operators */ +DEF_OPERATOR(argmax, ARGMAX, ArgMax, Axis, None) +DEF_OPERATOR(avg_pool2d, AVG_POOL2D, AvgPool2d, Pool2d, Unary) +DEF_OPERATOR(conv2d, CONV2D, Conv2d, Conv2d, Conv) +DEF_OPERATOR(conv3d, CONV3D, Conv3d, None, None) +DEF_OPERATOR(depthwise_conv2d, DEPTHWISE_CONV2D, DepthwiseConv2d, Conv2d, Conv) +DEF_OPERATOR(fully_connected, FULLY_CONNECTED, FullyConnected, None, Conv) +DEF_OPERATOR(matmul, MATMUL, MatMul, None, MatMul) +DEF_OPERATOR(max_pool2d, MAX_POOL2D, MaxPool2d, Pool2d, None) +DEF_OPERATOR(transpose_conv2d, TRANSPOSE_CONV2D, TransposeConv2d, TransposeConv2d, Conv) + +/* activation */ +DEF_OPERATOR(clamp, CLAMP, Clamp, Clamp, None) +DEF_OPERATOR(reluN, RELUN, ReluN, ReluN, None) +DEF_OPERATOR(sigmoid, SIGMOID, Sigmoid, None, None) +DEF_OPERATOR(tanh, TANH, Tanh, None, None) + +/* elementwise - binary */ +DEF_OPERATOR(add, ADD, Add, None, None) +DEF_OPERATOR(arithmetic_right_shift, ARITHMETIC_RIGHT_SHIFT, ArithmeticRightShift, ArithmeticRightShift, None) +DEF_OPERATOR(bitwise_and, BITWISE_AND, BitwiseAnd, None, None) +DEF_OPERATOR(bitwise_or, BITWISE_OR, BitwiseOr, None, None) +DEF_OPERATOR(bitwise_xor, BITWISE_XOR, BitwiseXor, None, None) +DEF_OPERATOR(logical_and, LOGICAL_AND, LogicalAnd, None, None) +DEF_OPERATOR(logical_left_shift, LOGICAL_LEFT_SHIFT, LogicalLeftShift, None, None) +DEF_OPERATOR(logical_right_shift, LOGICAL_RIGHT_SHIFT, LogicalRightShift, None, None) +DEF_OPERATOR(logical_or, LOGICAL_OR, LogicalOr, None, None) +DEF_OPERATOR(logical_xor, LOGICAL_XOR, LogicalXor, None, None) +DEF_OPERATOR(maximum, MAXIMUM, Maximum, None, None) +DEF_OPERATOR(minimum, MINIMUM, Minimum, None, None) +DEF_OPERATOR(mul, MUL, Mul, Mul, None) +DEF_OPERATOR(pow, POW, Pow, None, None) +DEF_OPERATOR(sub, SUB, Sub, None, None) +DEF_OPERATOR(table, TABLE, Table, None, None) + +/* elementwise - unary */ +DEF_OPERATOR(abs, ABS, Abs, None, None) +DEF_OPERATOR(bitwise_not, BITWISE_NOT, BitwiseNot, None, None) +DEF_OPERATOR(ceil, CEIL, Ceil, None, None) +DEF_OPERATOR(clz, CLZ, Clz, None, None) +DEF_OPERATOR(exp, EXP, Exp, None, None) +DEF_OPERATOR(floor, FLOOR, Floor, None, None) +DEF_OPERATOR(log, LOG, Log, None, None) +DEF_OPERATOR(logical_not, LOGICAL_NOT, LogicalNot, None, None) +DEF_OPERATOR(negate, NEGATE, Negate, None, Unary) +DEF_OPERATOR(reciprocal, RECIPROCAL, Reciprocal, None, None) +DEF_OPERATOR(rsqrt, RSQRT, Rsqrt, None, None) + +/* elementwise - ternary */ +DEF_OPERATOR(select, SELECT, Select, None, None) + +/* logical */ +DEF_OPERATOR(equal, EQUAL, Equal, None, None) +DEF_OPERATOR(greater, GREATER, Greater, None, None) +DEF_OPERATOR(greater_equal, GREATER_EQUAL, GreaterEqual, None, None) + +/* reduction */ +DEF_OPERATOR(reduce_any, REDUCE_ANY, ReduceAny, Reduce, None) +DEF_OPERATOR(reduce_all, REDUCE_ALL, ReduceAll, Reduce, None) +DEF_OPERATOR(reduce_max, REDUCE_MAX, ReduceMax, Reduce, None) +DEF_OPERATOR(reduce_min, REDUCE_MIN, ReduceMin, Reduce, None) +DEF_OPERATOR(reduce_prod, REDUCE_PRODUCT, ReduceProduct, Reduce, None) +DEF_OPERATOR(reduce_sum, REDUCE_SUM, ReduceSum, Reduce, None) + +/* memory operation */ +DEF_OPERATOR(concat, CONCAT, Concat, Axis, None) +DEF_OPERATOR(pad, PAD, Pad, None, Pad) +DEF_OPERATOR(reshape, RESHAPE, Reshape, Reshape, None) +DEF_OPERATOR(reverse, REVERSE, Reverse, Reverse, None) +DEF_OPERATOR(slice, SLICE, Slice, Slice, None) +DEF_OPERATOR(tile, TILE, Tile, Tile, None) +DEF_OPERATOR(transpose, TRANSPOSE, Transpose, None, None) + +/* gather/scatter */ +DEF_OPERATOR(gather, GATHER, Gather, None, None) +DEF_OPERATOR(scatter, SCATTER, Scatter, None, None) + +/* image */ +DEF_OPERATOR(resize, RESIZE, Resize, Resize, None) + +/* quantization */ +DEF_OPERATOR(cast, CAST, Cast, None, None) +DEF_OPERATOR(rescale, RESCALE, Rescale, Rescale, None) + +/* data nodes */ +DEF_OPERATOR(const, CONST, Const, None, None) +DEF_OPERATOR(placeholder, PLACEHOLDER, Placeholder, None, None) +DEF_OPERATOR(identity, IDENTITY, Identity, None, None) +DEF_OPERATOR(identityn, IDENTITYN, IdentityN, None, None) + +/* custom operations */ +DEF_OPERATOR(custom, CUSTOM, Custom, None, None) + +/* control flow operators */ +DEF_OPERATOR(cond_if, COND_IF, CondIf, CondIf, None) +DEF_OPERATOR(while_loop, WHILE_LOOP, WhileLoop, WhileLoop, None) diff --git a/include/quant_info.def b/include/quant_info.def new file mode 100644 index 0000000..888c183 --- /dev/null +++ b/include/quant_info.def @@ -0,0 +1,43 @@ + +// Copyright (c) 2020-2021, ARM Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +/* + Syntax: + DEF_QUANTIZATION_INFO(NAME, NUM_ARGS_IN_OPTIONS, ARG0_TYPE, ARG0_SCALAR_OR_VECTOR, ARGS0_NAME, ...) + + Description: + NAME: corresponding quantization info name, must match corresponding "table XXXQuantInfo" in tosa.fbs + NUM_ARGS_IN_QINFO: number of arguments in this quantization info + ARG0_TYPE: data type of arg0 + ARG0_SCALAR_OR_VECTOR: is arg0 a scalar (S) or a vector (V) + ARG0_NAME: name of arg0 + ...: variadic variables for more arguments, depending on NUM_ARGS_IN_QINFO +*/ + + +DEF_QUANTIZATION_INFO(Unary, 2, + int32_t, S, input_zp, + int32_t, S, output_zp) + +DEF_QUANTIZATION_INFO(Conv, 2, + int32_t, S, input_zp, + int32_t, S, weight_zp) + +DEF_QUANTIZATION_INFO(MatMul, 2, + int32_t, S, a_zp, + int32_t, S, b_zp) + +DEF_QUANTIZATION_INFO(Pad, 1, + int32_t, S, input_zp) diff --git a/include/quant_info.h b/include/quant_info.h new file mode 100644 index 0000000..d83063d --- /dev/null +++ b/include/quant_info.h @@ -0,0 +1,164 @@ + +// Copyright (c) 2020-2021, ARM Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef _TOSA_SERIALIZATION_QUANT_INFO_H +#define _TOSA_SERIALIZATION_QUANT_INFO_H +#include "flatbuffers/idl.h" +#include "flatbuffers/util.h" +#include "tosa_generated.h" + +namespace tosa +{ + +class TosaQuantInfoBase +{ +public: + virtual ~TosaQuantInfoBase() + {} +}; + +class TosaNoneQuantInfo : public TosaQuantInfoBase +{ +public: + TosaNoneQuantInfo() + {} + TosaNoneQuantInfo(TosaNoneQuantInfo* p) + {} +}; + +#define DEF_ARGS_VER0_S(T, V) _##V = p->V(); +#define DEF_ARGS_VER0_V(T, V) _##V = std::vector(p->V()->begin(), p->V()->end()); +#define DEF_ARGS_VER1_S(T, V) const T& V +#define DEF_ARGS_VER1_V(T, V) const std::vector& V +#define DEF_ARGS_VER2_S(T, V) _##V = V; +#define DEF_ARGS_VER2_V(T, V) _##V = V; +#define DEF_ARGS_VER3_S(T, V) \ + T V() const \ + { \ + return _##V; \ + } +#define DEF_ARGS_VER3_V(T, V) \ + std::vector V() const \ + { \ + return _##V; \ + } +#define DEF_ARGS_VER4_S(T, V) T _##V; +#define DEF_ARGS_VER4_V(T, V) std::vector _##V; + +// another level of preprocessor indirection to handle ", " as function's input argument +#define DEF_ARGS_VER1_TRUE(T, F, V) DEF_ARGS_VER1_##F(T, V) +#define DEF_ARGS_VER1_FALSE(T, F, V) , DEF_ARGS_VER1_##F(T, V) + +#define DEF_ARGS_VER0(FIRST, T, F, V) DEF_ARGS_VER0_##F(T, V) +#define DEF_ARGS_VER1(FIRST, T, F, V) DEF_ARGS_VER1_##FIRST(T, F, V) +#define DEF_ARGS_VER2(FIRST, T, F, V) DEF_ARGS_VER2_##F(T, V) +#define DEF_ARGS_VER3(FIRST, T, F, V) DEF_ARGS_VER3_##F(T, V) +#define DEF_ARGS_VER4(FIRST, T, F, V) DEF_ARGS_VER4_##F(T, V) + +#define DEF_ARGS_1(VER, T0, F0, V0) DEF_ARGS_##VER(TRUE, T0, F0, V0) +#define DEF_ARGS_2(VER, T0, F0, V0, T1, F1, V1) DEF_ARGS_##VER(TRUE, T0, F0, V0) DEF_ARGS_##VER(FALSE, T1, F1, V1) +#define DEF_ARGS_3(VER, T0, F0, V0, T1, F1, V1, T2, F2, V2) \ + DEF_ARGS_##VER(TRUE, T0, F0, V0) DEF_ARGS_##VER(FALSE, T1, F1, V1) DEF_ARGS_##VER(FALSE, T2, F2, V2) +#define DEF_ARGS_4(VER, T0, F0, V0, T1, F1, V1, T2, F2, V2, T3, F3, V3) \ + DEF_ARGS_##VER(TRUE, T0, F0, V0) DEF_ARGS_##VER(FALSE, T1, F1, V1) DEF_ARGS_##VER(FALSE, T2, F2, V2) \ + DEF_ARGS_##VER(FALSE, T3, F3, V3) +#define DEF_ARGS_5(VER, T0, F0, V0, T1, F1, V1, T2, F2, V2, T3, F3, V3, T4, F4, V4) \ + DEF_ARGS_##VER(TRUE, T0, F0, V0) DEF_ARGS_##VER(FALSE, T1, F1, V1) DEF_ARGS_##VER(FALSE, T2, F2, V2) \ + DEF_ARGS_##VER(FALSE, T3, F3, V3) DEF_ARGS_##VER(FALSE, T4, F4, V4) +#define DEF_ARGS_6(VER, T0, F0, V0, T1, F1, V1, T2, F2, V2, T3, F3, V3, T4, F4, V4, T5, F5, V5) \ + DEF_ARGS_##VER(TRUE, T0, F0, V0) DEF_ARGS_##VER(FALSE, T1, F1, V1) DEF_ARGS_##VER(FALSE, T2, F2, V2) \ + DEF_ARGS_##VER(FALSE, T3, F3, V3) DEF_ARGS_##VER(FALSE, T4, F4, V4) DEF_ARGS_##VER(FALSE, T5, F5, V5) +#define DEF_ARGS_7(VER, T0, F0, V0, T1, F1, V1, T2, F2, V2, T3, F3, V3, T4, F4, V4, T5, F5, V5, T6, F6, V6) \ + DEF_ARGS_##VER(TRUE, T0, F0, V0) DEF_ARGS_##VER(FALSE, T1, F1, V1) DEF_ARGS_##VER(FALSE, T2, F2, V2) \ + DEF_ARGS_##VER(FALSE, T3, F3, V3) DEF_ARGS_##VER(FALSE, T4, F4, V4) DEF_ARGS_##VER(FALSE, T5, F5, V5) \ + DEF_ARGS_##VER(FALSE, T6, F6, V6) +#define DEF_ARGS_8(VER, T0, F0, V0, T1, F1, V1, T2, F2, V2, T3, F3, V3, T4, F4, V4, T5, F5, V5, T6, F6, V6, T7, F7, \ + V7) \ + DEF_ARGS_##VER(TRUE, T0, F0, V0) DEF_ARGS_##VER(FALSE, T1, F1, V1) DEF_ARGS_##VER(FALSE, T2, F2, V2) \ + DEF_ARGS_##VER(FALSE, T3, F3, V3) DEF_ARGS_##VER(FALSE, T4, F4, V4) DEF_ARGS_##VER(FALSE, T5, F5, V5) \ + DEF_ARGS_##VER(FALSE, T6, F6, V6) DEF_ARGS_##VER(FALSE, T7, F7, V7) +#define DEF_ARGS_9(VER, T0, F0, V0, T1, F1, V1, T2, F2, V2, T3, F3, V3, T4, F4, V4, T5, F5, V5, T6, F6, V6, T7, F7, \ + V7, T8, F8, V8) \ + DEF_ARGS_##VER(TRUE, T0, F0, V0) DEF_ARGS_##VER(FALSE, T1, F1, V1) DEF_ARGS_##VER(FALSE, T2, F2, V2) \ + DEF_ARGS_##VER(FALSE, T3, F3, V3) DEF_ARGS_##VER(FALSE, T4, F4, V4) DEF_ARGS_##VER(FALSE, T5, F5, V5) \ + DEF_ARGS_##VER(FALSE, T6, F6, V6) DEF_ARGS_##VER(FALSE, T7, F7, V7) DEF_ARGS_##VER(FALSE, T8, F8, V8) +#define DEF_ARGS_10(VER, T0, F0, V0, T1, F1, V1, T2, F2, V2, T3, F3, V3, T4, F4, V4, T5, F5, V5, T6, F6, V6, T7, F7, \ + V7, T8, F8, V8, T9, F9, V9) \ + DEF_ARGS_##VER(TRUE, T0, F0, V0) DEF_ARGS_##VER(FALSE, T1, F1, V1) DEF_ARGS_##VER(FALSE, T2, F2, V2) \ + DEF_ARGS_##VER(FALSE, T3, F3, V3) DEF_ARGS_##VER(FALSE, T4, F4, V4) DEF_ARGS_##VER(FALSE, T5, F5, V5) \ + DEF_ARGS_##VER(FALSE, T6, F6, V6) DEF_ARGS_##VER(FALSE, T7, F7, V7) DEF_ARGS_##VER(FALSE, T8, F8, V8) \ + DEF_ARGS_##VER(FALSE, T9, F9, V9) + +#define DEF_QUANTIZATION_INFO(NAME, NUM_ARGS, ...) \ + class Tosa##NAME##QuantInfo : public TosaQuantInfoBase \ + { \ + public: \ + Tosa##NAME##QuantInfo(const TosaQuantInfoBase* qinfo) \ + { \ + const Tosa##NAME##QuantInfo* p = dynamic_cast(qinfo); \ + assert(p); \ + *this = *p; \ + } \ + Tosa##NAME##QuantInfo(const Tosa##NAME##QuantInfo* p) \ + { \ + *this = *p; \ + } \ + Tosa##NAME##QuantInfo(const void* qinfo) \ + { \ + const NAME##QuantInfo* p = static_cast(qinfo); \ + DEF_ARGS_##NUM_ARGS(VER0, __VA_ARGS__) \ + } \ + Tosa##NAME##QuantInfo(DEF_ARGS_##NUM_ARGS(VER1, __VA_ARGS__)) \ + { \ + DEF_ARGS_##NUM_ARGS(VER2, __VA_ARGS__) \ + } \ + virtual ~Tosa##NAME##QuantInfo() \ + {} \ + DEF_ARGS_##NUM_ARGS(VER3, __VA_ARGS__) private : DEF_ARGS_##NUM_ARGS(VER4, __VA_ARGS__) \ + }; + +#include "quant_info.def" +#undef DEF_QUANTIZATION_INFO +#undef DEF_ARGS_1 +#undef DEF_ARGS_2 +#undef DEF_ARGS_3 +#undef DEF_ARGS_4 +#undef DEF_ARGS_5 +#undef DEF_ARGS_6 +#undef DEF_ARGS_7 +#undef DEF_ARGS_8 +#undef DEF_ARGS_9 +#undef DEF_ARGS_10 +#undef DEF_ARGS_VER0 +#undef DEF_ARGS_VER1 +#undef DEF_ARGS_VER2 +#undef DEF_ARGS_VER3 +#undef DEF_ARGS_VER4 +#undef DEF_ARGS_VER1_TRUE +#undef DEF_ARGS_VER1_FALSE +#undef DEF_ARGS_VER0_S +#undef DEF_ARGS_VER0_V +#undef DEF_ARGS_VER1_S +#undef DEF_ARGS_VER1_V +#undef DEF_ARGS_VER2_S +#undef DEF_ARGS_VER2_V +#undef DEF_ARGS_VER3_S +#undef DEF_ARGS_VER3_V +#undef DEF_ARGS_VER4_S +#undef DEF_ARGS_VER4_V + +} // namespace tosa + +#endif diff --git a/include/tosa_generated.h b/include/tosa_generated.h new file mode 100644 index 0000000..5e883a1 --- /dev/null +++ b/include/tosa_generated.h @@ -0,0 +1,2649 @@ + +// Copyright (c) 2020-2021, ARM Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// automatically generated by the FlatBuffers compiler, do not modify + + +#ifndef FLATBUFFERS_GENERATED_TOSA_TOSA_H_ +#define FLATBUFFERS_GENERATED_TOSA_TOSA_H_ + +#include "flatbuffers/flatbuffers.h" + +namespace tosa { + +struct Pool2dAttribute; +struct Pool2dAttributeBuilder; + +struct Conv2dAttribute; +struct Conv2dAttributeBuilder; + +struct TransposeConv2dAttribute; +struct TransposeConv2dAttributeBuilder; + +struct ReluNAttribute; +struct ReluNAttributeBuilder; + +struct AxisAttribute; +struct AxisAttributeBuilder; + +struct ReshapeAttribute; +struct ReshapeAttributeBuilder; + +struct SliceAttribute; +struct SliceAttributeBuilder; + +struct TileAttribute; +struct TileAttributeBuilder; + +struct ResizeAttribute; +struct ResizeAttributeBuilder; + +struct ClampAttribute; +struct ClampAttributeBuilder; + +struct RescaleAttribute; +struct RescaleAttributeBuilder; + +struct MulAttribute; +struct MulAttributeBuilder; + +struct ArithmeticRightShiftAttribute; +struct ArithmeticRightShiftAttributeBuilder; + +struct CondIfAttribute; +struct CondIfAttributeBuilder; + +struct WhileLoopAttribute; +struct WhileLoopAttributeBuilder; + +struct UnaryQuantInfo; +struct UnaryQuantInfoBuilder; + +struct ConvQuantInfo; +struct ConvQuantInfoBuilder; + +struct MatMulQuantInfo; +struct MatMulQuantInfoBuilder; + +struct PadQuantInfo; +struct PadQuantInfoBuilder; + +struct Version; +struct VersionBuilder; + +struct TosaTensor; +struct TosaTensorBuilder; + +struct TosaOperator; +struct TosaOperatorBuilder; + +struct TosaBasicBlock; +struct TosaBasicBlockBuilder; + +struct TosaGraph; +struct TosaGraphBuilder; + +enum DType { + DType_UNKNOWN = 0, + DType_BOOL = 1, + DType_UINT8 = 2, + DType_INT4 = 3, + DType_INT8 = 4, + DType_INT16 = 5, + DType_INT32 = 6, + DType_INT48 = 7, + DType_FLOAT = 8, + DType_MIN = DType_UNKNOWN, + DType_MAX = DType_FLOAT +}; + +inline const DType (&EnumValuesDType())[9] { + static const DType values[] = { + DType_UNKNOWN, + DType_BOOL, + DType_UINT8, + DType_INT4, + DType_INT8, + DType_INT16, + DType_INT32, + DType_INT48, + DType_FLOAT + }; + return values; +} + +inline const char * const *EnumNamesDType() { + static const char * const names[10] = { + "UNKNOWN", + "BOOL", + "UINT8", + "INT4", + "INT8", + "INT16", + "INT32", + "INT48", + "FLOAT", + nullptr + }; + return names; +} + +inline const char *EnumNameDType(DType e) { + if (flatbuffers::IsOutRange(e, DType_UNKNOWN, DType_FLOAT)) return ""; + const size_t index = static_cast(e); + return EnumNamesDType()[index]; +} + +enum ResizeMode { + ResizeMode_UNKNOWN = 0, + ResizeMode_NEAREST = 1, + ResizeMode_BILINEAR = 2, + ResizeMode_MIN = ResizeMode_UNKNOWN, + ResizeMode_MAX = ResizeMode_BILINEAR +}; + +inline const ResizeMode (&EnumValuesResizeMode())[3] { + static const ResizeMode values[] = { + ResizeMode_UNKNOWN, + ResizeMode_NEAREST, + ResizeMode_BILINEAR + }; + return values; +} + +inline const char * const *EnumNamesResizeMode() { + static const char * const names[4] = { + "UNKNOWN", + "NEAREST", + "BILINEAR", + nullptr + }; + return names; +} + +inline const char *EnumNameResizeMode(ResizeMode e) { + if (flatbuffers::IsOutRange(e, ResizeMode_UNKNOWN, ResizeMode_BILINEAR)) return ""; + const size_t index = static_cast(e); + return EnumNamesResizeMode()[index]; +} + +enum Op { + Op_UNKNOWN = 0, + Op_ARGMAX = 1, + Op_AVG_POOL2D = 2, + Op_CONV2D = 3, + Op_CONV3D = 4, + Op_DEPTHWISE_CONV2D = 5, + Op_FULLY_CONNECTED = 6, + Op_MATMUL = 7, + Op_MAX_POOL2D = 8, + Op_TRANSPOSE_CONV2D = 9, + Op_CLAMP = 10, + Op_RELUN = 11, + Op_SIGMOID = 12, + Op_TANH = 13, + Op_ADD = 14, + Op_ARITHMETIC_RIGHT_SHIFT = 15, + Op_BITWISE_AND = 16, + Op_BITWISE_OR = 17, + Op_BITWISE_XOR = 18, + Op_LOGICAL_AND = 19, + Op_LOGICAL_LEFT_SHIFT = 20, + Op_LOGICAL_RIGHT_SHIFT = 21, + Op_LOGICAL_OR = 22, + Op_LOGICAL_XOR = 23, + Op_MAXIMUM = 24, + Op_MINIMUM = 25, + Op_MUL = 26, + Op_POW = 27, + Op_SUB = 28, + Op_TABLE = 29, + Op_ABS = 30, + Op_BITWISE_NOT = 31, + Op_CEIL = 32, + Op_CLZ = 33, + Op_EXP = 34, + Op_FLOOR = 35, + Op_LOG = 36, + Op_LOGICAL_NOT = 37, + Op_NEGATE = 38, + Op_RECIPROCAL = 39, + Op_RSQRT = 40, + Op_SELECT = 41, + Op_EQUAL = 42, + Op_GREATER = 43, + Op_GREATER_EQUAL = 44, + Op_REDUCE_ANY = 45, + Op_REDUCE_ALL = 46, + Op_REDUCE_MAX = 47, + Op_REDUCE_MIN = 48, + Op_REDUCE_PRODUCT = 49, + Op_REDUCE_SUM = 50, + Op_CONCAT = 51, + Op_PAD = 52, + Op_RESHAPE = 53, + Op_REVERSE = 54, + Op_SLICE = 55, + Op_TILE = 56, + Op_TRANSPOSE = 57, + Op_GATHER = 58, + Op_SCATTER = 59, + Op_RESIZE = 60, + Op_CAST = 61, + Op_RESCALE = 62, + Op_CONST = 63, + Op_PLACEHOLDER = 64, + Op_IDENTITY = 65, + Op_IDENTITYN = 66, + Op_CUSTOM = 67, + Op_COND_IF = 68, + Op_WHILE_LOOP = 69, + Op_MIN = Op_UNKNOWN, + Op_MAX = Op_WHILE_LOOP +}; + +inline const Op (&EnumValuesOp())[70] { + static const Op values[] = { + Op_UNKNOWN, + Op_ARGMAX, + Op_AVG_POOL2D, + Op_CONV2D, + Op_CONV3D, + Op_DEPTHWISE_CONV2D, + Op_FULLY_CONNECTED, + Op_MATMUL, + Op_MAX_POOL2D, + Op_TRANSPOSE_CONV2D, + Op_CLAMP, + Op_RELUN, + Op_SIGMOID, + Op_TANH, + Op_ADD, + Op_ARITHMETIC_RIGHT_SHIFT, + Op_BITWISE_AND, + Op_BITWISE_OR, + Op_BITWISE_XOR, + Op_LOGICAL_AND, + Op_LOGICAL_LEFT_SHIFT, + Op_LOGICAL_RIGHT_SHIFT, + Op_LOGICAL_OR, + Op_LOGICAL_XOR, + Op_MAXIMUM, + Op_MINIMUM, + Op_MUL, + Op_POW, + Op_SUB, + Op_TABLE, + Op_ABS, + Op_BITWISE_NOT, + Op_CEIL, + Op_CLZ, + Op_EXP, + Op_FLOOR, + Op_LOG, + Op_LOGICAL_NOT, + Op_NEGATE, + Op_RECIPROCAL, + Op_RSQRT, + Op_SELECT, + Op_EQUAL, + Op_GREATER, + Op_GREATER_EQUAL, + Op_REDUCE_ANY, + Op_REDUCE_ALL, + Op_REDUCE_MAX, + Op_REDUCE_MIN, + Op_REDUCE_PRODUCT, + Op_REDUCE_SUM, + Op_CONCAT, + Op_PAD, + Op_RESHAPE, + Op_REVERSE, + Op_SLICE, + Op_TILE, + Op_TRANSPOSE, + Op_GATHER, + Op_SCATTER, + Op_RESIZE, + Op_CAST, + Op_RESCALE, + Op_CONST, + Op_PLACEHOLDER, + Op_IDENTITY, + Op_IDENTITYN, + Op_CUSTOM, + Op_COND_IF, + Op_WHILE_LOOP + }; + return values; +} + +inline const char * const *EnumNamesOp() { + static const char * const names[71] = { + "UNKNOWN", + "ARGMAX", + "AVG_POOL2D", + "CONV2D", + "CONV3D", + "DEPTHWISE_CONV2D", + "FULLY_CONNECTED", + "MATMUL", + "MAX_POOL2D", + "TRANSPOSE_CONV2D", + "CLAMP", + "RELUN", + "SIGMOID", + "TANH", + "ADD", + "ARITHMETIC_RIGHT_SHIFT", + "BITWISE_AND", + "BITWISE_OR", + "BITWISE_XOR", + "LOGICAL_AND", + "LOGICAL_LEFT_SHIFT", + "LOGICAL_RIGHT_SHIFT", + "LOGICAL_OR", + "LOGICAL_XOR", + "MAXIMUM", + "MINIMUM", + "MUL", + "POW", + "SUB", + "TABLE", + "ABS", + "BITWISE_NOT", + "CEIL", + "CLZ", + "EXP", + "FLOOR", + "LOG", + "LOGICAL_NOT", + "NEGATE", + "RECIPROCAL", + "RSQRT", + "SELECT", + "EQUAL", + "GREATER", + "GREATER_EQUAL", + "REDUCE_ANY", + "REDUCE_ALL", + "REDUCE_MAX", + "REDUCE_MIN", + "REDUCE_PRODUCT", + "REDUCE_SUM", + "CONCAT", + "PAD", + "RESHAPE", + "REVERSE", + "SLICE", + "TILE", + "TRANSPOSE", + "GATHER", + "SCATTER", + "RESIZE", + "CAST", + "RESCALE", + "CONST", + "PLACEHOLDER", + "IDENTITY", + "IDENTITYN", + "CUSTOM", + "COND_IF", + "WHILE_LOOP", + nullptr + }; + return names; +} + +inline const char *EnumNameOp(Op e) { + if (flatbuffers::IsOutRange(e, Op_UNKNOWN, Op_WHILE_LOOP)) return ""; + const size_t index = static_cast(e); + return EnumNamesOp()[index]; +} + +enum Attribute { + Attribute_NONE = 0, + Attribute_Pool2dAttribute = 1, + Attribute_Conv2dAttribute = 2, + Attribute_TransposeConv2dAttribute = 3, + Attribute_ReluNAttribute = 4, + Attribute_AxisAttribute = 5, + Attribute_ReshapeAttribute = 6, + Attribute_SliceAttribute = 7, + Attribute_TileAttribute = 8, + Attribute_ResizeAttribute = 9, + Attribute_ClampAttribute = 10, + Attribute_RescaleAttribute = 11, + Attribute_MulAttribute = 12, + Attribute_ArithmeticRightShiftAttribute = 13, + Attribute_CondIfAttribute = 14, + Attribute_WhileLoopAttribute = 15, + Attribute_MIN = Attribute_NONE, + Attribute_MAX = Attribute_WhileLoopAttribute +}; + +inline const Attribute (&EnumValuesAttribute())[16] { + static const Attribute values[] = { + Attribute_NONE, + Attribute_Pool2dAttribute, + Attribute_Conv2dAttribute, + Attribute_TransposeConv2dAttribute, + Attribute_ReluNAttribute, + Attribute_AxisAttribute, + Attribute_ReshapeAttribute, + Attribute_SliceAttribute, + Attribute_TileAttribute, + Attribute_ResizeAttribute, + Attribute_ClampAttribute, + Attribute_RescaleAttribute, + Attribute_MulAttribute, + Attribute_ArithmeticRightShiftAttribute, + Attribute_CondIfAttribute, + Attribute_WhileLoopAttribute + }; + return values; +} + +inline const char * const *EnumNamesAttribute() { + static const char * const names[17] = { + "NONE", + "Pool2dAttribute", + "Conv2dAttribute", + "TransposeConv2dAttribute", + "ReluNAttribute", + "AxisAttribute", + "ReshapeAttribute", + "SliceAttribute", + "TileAttribute", + "ResizeAttribute", + "ClampAttribute", + "RescaleAttribute", + "MulAttribute", + "ArithmeticRightShiftAttribute", + "CondIfAttribute", + "WhileLoopAttribute", + nullptr + }; + return names; +} + +inline const char *EnumNameAttribute(Attribute e) { + if (flatbuffers::IsOutRange(e, Attribute_NONE, Attribute_WhileLoopAttribute)) return ""; + const size_t index = static_cast(e); + return EnumNamesAttribute()[index]; +} + +template struct AttributeTraits { + static const Attribute enum_value = Attribute_NONE; +}; + +template<> struct AttributeTraits { + static const Attribute enum_value = Attribute_Pool2dAttribute; +}; + +template<> struct AttributeTraits { + static const Attribute enum_value = Attribute_Conv2dAttribute; +}; + +template<> struct AttributeTraits { + static const Attribute enum_value = Attribute_TransposeConv2dAttribute; +}; + +template<> struct AttributeTraits { + static const Attribute enum_value = Attribute_ReluNAttribute; +}; + +template<> struct AttributeTraits { + static const Attribute enum_value = Attribute_AxisAttribute; +}; + +template<> struct AttributeTraits { + static const Attribute enum_value = Attribute_ReshapeAttribute; +}; + +template<> struct AttributeTraits { + static const Attribute enum_value = Attribute_SliceAttribute; +}; + +template<> struct AttributeTraits { + static const Attribute enum_value = Attribute_TileAttribute; +}; + +template<> struct AttributeTraits { + static const Attribute enum_value = Attribute_ResizeAttribute; +}; + +template<> struct AttributeTraits { + static const Attribute enum_value = Attribute_ClampAttribute; +}; + +template<> struct AttributeTraits { + static const Attribute enum_value = Attribute_RescaleAttribute; +}; + +template<> struct AttributeTraits { + static const Attribute enum_value = Attribute_MulAttribute; +}; + +template<> struct AttributeTraits { + static const Attribute enum_value = Attribute_ArithmeticRightShiftAttribute; +}; + +template<> struct AttributeTraits { + static const Attribute enum_value = Attribute_CondIfAttribute; +}; + +template<> struct AttributeTraits { + static const Attribute enum_value = Attribute_WhileLoopAttribute; +}; + +bool VerifyAttribute(flatbuffers::Verifier &verifier, const void *obj, Attribute type); +bool VerifyAttributeVector(flatbuffers::Verifier &verifier, const flatbuffers::Vector> *values, const flatbuffers::Vector *types); + +enum QuantInfo { + QuantInfo_NONE = 0, + QuantInfo_UnaryQuantInfo = 1, + QuantInfo_ConvQuantInfo = 2, + QuantInfo_MatMulQuantInfo = 3, + QuantInfo_PadQuantInfo = 4, + QuantInfo_MIN = QuantInfo_NONE, + QuantInfo_MAX = QuantInfo_PadQuantInfo +}; + +inline const QuantInfo (&EnumValuesQuantInfo())[5] { + static const QuantInfo values[] = { + QuantInfo_NONE, + QuantInfo_UnaryQuantInfo, + QuantInfo_ConvQuantInfo, + QuantInfo_MatMulQuantInfo, + QuantInfo_PadQuantInfo + }; + return values; +} + +inline const char * const *EnumNamesQuantInfo() { + static const char * const names[6] = { + "NONE", + "UnaryQuantInfo", + "ConvQuantInfo", + "MatMulQuantInfo", + "PadQuantInfo", + nullptr + }; + return names; +} + +inline const char *EnumNameQuantInfo(QuantInfo e) { + if (flatbuffers::IsOutRange(e, QuantInfo_NONE, QuantInfo_PadQuantInfo)) return ""; + const size_t index = static_cast(e); + return EnumNamesQuantInfo()[index]; +} + +template struct QuantInfoTraits { + static const QuantInfo enum_value = QuantInfo_NONE; +}; + +template<> struct QuantInfoTraits { + static const QuantInfo enum_value = QuantInfo_UnaryQuantInfo; +}; + +template<> struct QuantInfoTraits { + static const QuantInfo enum_value = QuantInfo_ConvQuantInfo; +}; + +template<> struct QuantInfoTraits { + static const QuantInfo enum_value = QuantInfo_MatMulQuantInfo; +}; + +template<> struct QuantInfoTraits { + static const QuantInfo enum_value = QuantInfo_PadQuantInfo; +}; + +bool VerifyQuantInfo(flatbuffers::Verifier &verifier, const void *obj, QuantInfo type); +bool VerifyQuantInfoVector(flatbuffers::Verifier &verifier, const flatbuffers::Vector> *values, const flatbuffers::Vector *types); + +struct Pool2dAttribute FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef Pool2dAttributeBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_PADDING = 4, + VT_KERNEL = 6, + VT_STRIDE = 8 + }; + const flatbuffers::Vector *padding() const { + return GetPointer *>(VT_PADDING); + } + const flatbuffers::Vector *kernel() const { + return GetPointer *>(VT_KERNEL); + } + const flatbuffers::Vector *stride() const { + return GetPointer *>(VT_STRIDE); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_PADDING) && + verifier.VerifyVector(padding()) && + VerifyOffset(verifier, VT_KERNEL) && + verifier.VerifyVector(kernel()) && + VerifyOffset(verifier, VT_STRIDE) && + verifier.VerifyVector(stride()) && + verifier.EndTable(); + } +}; + +struct Pool2dAttributeBuilder { + typedef Pool2dAttribute Table; + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_padding(flatbuffers::Offset> padding) { + fbb_.AddOffset(Pool2dAttribute::VT_PADDING, padding); + } + void add_kernel(flatbuffers::Offset> kernel) { + fbb_.AddOffset(Pool2dAttribute::VT_KERNEL, kernel); + } + void add_stride(flatbuffers::Offset> stride) { + fbb_.AddOffset(Pool2dAttribute::VT_STRIDE, stride); + } + explicit Pool2dAttributeBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + Pool2dAttributeBuilder &operator=(const Pool2dAttributeBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreatePool2dAttribute( + flatbuffers::FlatBufferBuilder &_fbb, + flatbuffers::Offset> padding = 0, + flatbuffers::Offset> kernel = 0, + flatbuffers::Offset> stride = 0) { + Pool2dAttributeBuilder builder_(_fbb); + builder_.add_stride(stride); + builder_.add_kernel(kernel); + builder_.add_padding(padding); + return builder_.Finish(); +} + +inline flatbuffers::Offset CreatePool2dAttributeDirect( + flatbuffers::FlatBufferBuilder &_fbb, + const std::vector *padding = nullptr, + const std::vector *kernel = nullptr, + const std::vector *stride = nullptr) { + auto padding__ = padding ? _fbb.CreateVector(*padding) : 0; + auto kernel__ = kernel ? _fbb.CreateVector(*kernel) : 0; + auto stride__ = stride ? _fbb.CreateVector(*stride) : 0; + return tosa::CreatePool2dAttribute( + _fbb, + padding__, + kernel__, + stride__); +} + +struct Conv2dAttribute FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef Conv2dAttributeBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_PADDING = 4, + VT_STRIDE = 6, + VT_DILATION = 8 + }; + const flatbuffers::Vector *padding() const { + return GetPointer *>(VT_PADDING); + } + const flatbuffers::Vector *stride() const { + return GetPointer *>(VT_STRIDE); + } + const flatbuffers::Vector *dilation() const { + return GetPointer *>(VT_DILATION); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_PADDING) && + verifier.VerifyVector(padding()) && + VerifyOffset(verifier, VT_STRIDE) && + verifier.VerifyVector(stride()) && + VerifyOffset(verifier, VT_DILATION) && + verifier.VerifyVector(dilation()) && + verifier.EndTable(); + } +}; + +struct Conv2dAttributeBuilder { + typedef Conv2dAttribute Table; + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_padding(flatbuffers::Offset> padding) { + fbb_.AddOffset(Conv2dAttribute::VT_PADDING, padding); + } + void add_stride(flatbuffers::Offset> stride) { + fbb_.AddOffset(Conv2dAttribute::VT_STRIDE, stride); + } + void add_dilation(flatbuffers::Offset> dilation) { + fbb_.AddOffset(Conv2dAttribute::VT_DILATION, dilation); + } + explicit Conv2dAttributeBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + Conv2dAttributeBuilder &operator=(const Conv2dAttributeBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateConv2dAttribute( + flatbuffers::FlatBufferBuilder &_fbb, + flatbuffers::Offset> padding = 0, + flatbuffers::Offset> stride = 0, + flatbuffers::Offset> dilation = 0) { + Conv2dAttributeBuilder builder_(_fbb); + builder_.add_dilation(dilation); + builder_.add_stride(stride); + builder_.add_padding(padding); + return builder_.Finish(); +} + +inline flatbuffers::Offset CreateConv2dAttributeDirect( + flatbuffers::FlatBufferBuilder &_fbb, + const std::vector *padding = nullptr, + const std::vector *stride = nullptr, + const std::vector *dilation = nullptr) { + auto padding__ = padding ? _fbb.CreateVector(*padding) : 0; + auto stride__ = stride ? _fbb.CreateVector(*stride) : 0; + auto dilation__ = dilation ? _fbb.CreateVector(*dilation) : 0; + return tosa::CreateConv2dAttribute( + _fbb, + padding__, + stride__, + dilation__); +} + +struct TransposeConv2dAttribute FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef TransposeConv2dAttributeBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_OUTPAD = 4, + VT_STRIDE = 6, + VT_DILATION = 8, + VT_OUTPUT_SHAPE = 10 + }; + const flatbuffers::Vector *outpad() const { + return GetPointer *>(VT_OUTPAD); + } + const flatbuffers::Vector *stride() const { + return GetPointer *>(VT_STRIDE); + } + const flatbuffers::Vector *dilation() const { + return GetPointer *>(VT_DILATION); + } + const flatbuffers::Vector *output_shape() const { + return GetPointer *>(VT_OUTPUT_SHAPE); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_OUTPAD) && + verifier.VerifyVector(outpad()) && + VerifyOffset(verifier, VT_STRIDE) && + verifier.VerifyVector(stride()) && + VerifyOffset(verifier, VT_DILATION) && + verifier.VerifyVector(dilation()) && + VerifyOffset(verifier, VT_OUTPUT_SHAPE) && + verifier.VerifyVector(output_shape()) && + verifier.EndTable(); + } +}; + +struct TransposeConv2dAttributeBuilder { + typedef TransposeConv2dAttribute Table; + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_outpad(flatbuffers::Offset> outpad) { + fbb_.AddOffset(TransposeConv2dAttribute::VT_OUTPAD, outpad); + } + void add_stride(flatbuffers::Offset> stride) { + fbb_.AddOffset(TransposeConv2dAttribute::VT_STRIDE, stride); + } + void add_dilation(flatbuffers::Offset> dilation) { + fbb_.AddOffset(TransposeConv2dAttribute::VT_DILATION, dilation); + } + void add_output_shape(flatbuffers::Offset> output_shape) { + fbb_.AddOffset(TransposeConv2dAttribute::VT_OUTPUT_SHAPE, output_shape); + } + explicit TransposeConv2dAttributeBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + TransposeConv2dAttributeBuilder &operator=(const TransposeConv2dAttributeBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateTransposeConv2dAttribute( + flatbuffers::FlatBufferBuilder &_fbb, + flatbuffers::Offset> outpad = 0, + flatbuffers::Offset> stride = 0, + flatbuffers::Offset> dilation = 0, + flatbuffers::Offset> output_shape = 0) { + TransposeConv2dAttributeBuilder builder_(_fbb); + builder_.add_output_shape(output_shape); + builder_.add_dilation(dilation); + builder_.add_stride(stride); + builder_.add_outpad(outpad); + return builder_.Finish(); +} + +inline flatbuffers::Offset CreateTransposeConv2dAttributeDirect( + flatbuffers::FlatBufferBuilder &_fbb, + const std::vector *outpad = nullptr, + const std::vector *stride = nullptr, + const std::vector *dilation = nullptr, + const std::vector *output_shape = nullptr) { + auto outpad__ = outpad ? _fbb.CreateVector(*outpad) : 0; + auto stride__ = stride ? _fbb.CreateVector(*stride) : 0; + auto dilation__ = dilation ? _fbb.CreateVector(*dilation) : 0; + auto output_shape__ = output_shape ? _fbb.CreateVector(*output_shape) : 0; + return tosa::CreateTransposeConv2dAttribute( + _fbb, + outpad__, + stride__, + dilation__, + output_shape__); +} + +struct ReluNAttribute FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef ReluNAttributeBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_MAX_INT = 4, + VT_MAX_FP = 6 + }; + int32_t max_int() const { + return GetField(VT_MAX_INT, 0); + } + float max_fp() const { + return GetField(VT_MAX_FP, 0.0f); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_MAX_INT) && + VerifyField(verifier, VT_MAX_FP) && + verifier.EndTable(); + } +}; + +struct ReluNAttributeBuilder { + typedef ReluNAttribute Table; + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_max_int(int32_t max_int) { + fbb_.AddElement(ReluNAttribute::VT_MAX_INT, max_int, 0); + } + void add_max_fp(float max_fp) { + fbb_.AddElement(ReluNAttribute::VT_MAX_FP, max_fp, 0.0f); + } + explicit ReluNAttributeBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ReluNAttributeBuilder &operator=(const ReluNAttributeBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateReluNAttribute( + flatbuffers::FlatBufferBuilder &_fbb, + int32_t max_int = 0, + float max_fp = 0.0f) { + ReluNAttributeBuilder builder_(_fbb); + builder_.add_max_fp(max_fp); + builder_.add_max_int(max_int); + return builder_.Finish(); +} + +struct AxisAttribute FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef AxisAttributeBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_AXIS = 4 + }; + int32_t axis() const { + return GetField(VT_AXIS, 0); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_AXIS) && + verifier.EndTable(); + } +}; + +struct AxisAttributeBuilder { + typedef AxisAttribute Table; + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_axis(int32_t axis) { + fbb_.AddElement(AxisAttribute::VT_AXIS, axis, 0); + } + explicit AxisAttributeBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + AxisAttributeBuilder &operator=(const AxisAttributeBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateAxisAttribute( + flatbuffers::FlatBufferBuilder &_fbb, + int32_t axis = 0) { + AxisAttributeBuilder builder_(_fbb); + builder_.add_axis(axis); + return builder_.Finish(); +} + +struct ReshapeAttribute FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef ReshapeAttributeBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_SHAPE = 4 + }; + const flatbuffers::Vector *shape() const { + return GetPointer *>(VT_SHAPE); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_SHAPE) && + verifier.VerifyVector(shape()) && + verifier.EndTable(); + } +}; + +struct ReshapeAttributeBuilder { + typedef ReshapeAttribute Table; + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_shape(flatbuffers::Offset> shape) { + fbb_.AddOffset(ReshapeAttribute::VT_SHAPE, shape); + } + explicit ReshapeAttributeBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ReshapeAttributeBuilder &operator=(const ReshapeAttributeBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateReshapeAttribute( + flatbuffers::FlatBufferBuilder &_fbb, + flatbuffers::Offset> shape = 0) { + ReshapeAttributeBuilder builder_(_fbb); + builder_.add_shape(shape); + return builder_.Finish(); +} + +inline flatbuffers::Offset CreateReshapeAttributeDirect( + flatbuffers::FlatBufferBuilder &_fbb, + const std::vector *shape = nullptr) { + auto shape__ = shape ? _fbb.CreateVector(*shape) : 0; + return tosa::CreateReshapeAttribute( + _fbb, + shape__); +} + +struct SliceAttribute FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef SliceAttributeBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_BEGIN = 4, + VT_SIZE = 6 + }; + const flatbuffers::Vector *begin() const { + return GetPointer *>(VT_BEGIN); + } + const flatbuffers::Vector *size() const { + return GetPointer *>(VT_SIZE); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_BEGIN) && + verifier.VerifyVector(begin()) && + VerifyOffset(verifier, VT_SIZE) && + verifier.VerifyVector(size()) && + verifier.EndTable(); + } +}; + +struct SliceAttributeBuilder { + typedef SliceAttribute Table; + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_begin(flatbuffers::Offset> begin) { + fbb_.AddOffset(SliceAttribute::VT_BEGIN, begin); + } + void add_size(flatbuffers::Offset> size) { + fbb_.AddOffset(SliceAttribute::VT_SIZE, size); + } + explicit SliceAttributeBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + SliceAttributeBuilder &operator=(const SliceAttributeBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateSliceAttribute( + flatbuffers::FlatBufferBuilder &_fbb, + flatbuffers::Offset> begin = 0, + flatbuffers::Offset> size = 0) { + SliceAttributeBuilder builder_(_fbb); + builder_.add_size(size); + builder_.add_begin(begin); + return builder_.Finish(); +} + +inline flatbuffers::Offset CreateSliceAttributeDirect( + flatbuffers::FlatBufferBuilder &_fbb, + const std::vector *begin = nullptr, + const std::vector *size = nullptr) { + auto begin__ = begin ? _fbb.CreateVector(*begin) : 0; + auto size__ = size ? _fbb.CreateVector(*size) : 0; + return tosa::CreateSliceAttribute( + _fbb, + begin__, + size__); +} + +struct TileAttribute FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef TileAttributeBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_MULTIPLES = 4 + }; + const flatbuffers::Vector *multiples() const { + return GetPointer *>(VT_MULTIPLES); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_MULTIPLES) && + verifier.VerifyVector(multiples()) && + verifier.EndTable(); + } +}; + +struct TileAttributeBuilder { + typedef TileAttribute Table; + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_multiples(flatbuffers::Offset> multiples) { + fbb_.AddOffset(TileAttribute::VT_MULTIPLES, multiples); + } + explicit TileAttributeBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + TileAttributeBuilder &operator=(const TileAttributeBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateTileAttribute( + flatbuffers::FlatBufferBuilder &_fbb, + flatbuffers::Offset> multiples = 0) { + TileAttributeBuilder builder_(_fbb); + builder_.add_multiples(multiples); + return builder_.Finish(); +} + +inline flatbuffers::Offset CreateTileAttributeDirect( + flatbuffers::FlatBufferBuilder &_fbb, + const std::vector *multiples = nullptr) { + auto multiples__ = multiples ? _fbb.CreateVector(*multiples) : 0; + return tosa::CreateTileAttribute( + _fbb, + multiples__); +} + +struct ResizeAttribute FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef ResizeAttributeBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_OUTPUT_SIZE = 4, + VT_STRIDE = 6, + VT_OFFSET = 8, + VT_SHIFT = 10, + VT_STRIDE_FP = 12, + VT_OFFSET_FP = 14, + VT_MODE = 16 + }; + const flatbuffers::Vector *output_size() const { + return GetPointer *>(VT_OUTPUT_SIZE); + } + const flatbuffers::Vector *stride() const { + return GetPointer *>(VT_STRIDE); + } + const flatbuffers::Vector *offset() const { + return GetPointer *>(VT_OFFSET); + } + int32_t shift() const { + return GetField(VT_SHIFT, 0); + } + const flatbuffers::Vector *stride_fp() const { + return GetPointer *>(VT_STRIDE_FP); + } + const flatbuffers::Vector *offset_fp() const { + return GetPointer *>(VT_OFFSET_FP); + } + tosa::ResizeMode mode() const { + return static_cast(GetField(VT_MODE, 0)); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_OUTPUT_SIZE) && + verifier.VerifyVector(output_size()) && + VerifyOffset(verifier, VT_STRIDE) && + verifier.VerifyVector(stride()) && + VerifyOffset(verifier, VT_OFFSET) && + verifier.VerifyVector(offset()) && + VerifyField(verifier, VT_SHIFT) && + VerifyOffset(verifier, VT_STRIDE_FP) && + verifier.VerifyVector(stride_fp()) && + VerifyOffset(verifier, VT_OFFSET_FP) && + verifier.VerifyVector(offset_fp()) && + VerifyField(verifier, VT_MODE) && + verifier.EndTable(); + } +}; + +struct ResizeAttributeBuilder { + typedef ResizeAttribute Table; + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_output_size(flatbuffers::Offset> output_size) { + fbb_.AddOffset(ResizeAttribute::VT_OUTPUT_SIZE, output_size); + } + void add_stride(flatbuffers::Offset> stride) { + fbb_.AddOffset(ResizeAttribute::VT_STRIDE, stride); + } + void add_offset(flatbuffers::Offset> offset) { + fbb_.AddOffset(ResizeAttribute::VT_OFFSET, offset); + } + void add_shift(int32_t shift) { + fbb_.AddElement(ResizeAttribute::VT_SHIFT, shift, 0); + } + void add_stride_fp(flatbuffers::Offset> stride_fp) { + fbb_.AddOffset(ResizeAttribute::VT_STRIDE_FP, stride_fp); + } + void add_offset_fp(flatbuffers::Offset> offset_fp) { + fbb_.AddOffset(ResizeAttribute::VT_OFFSET_FP, offset_fp); + } + void add_mode(tosa::ResizeMode mode) { + fbb_.AddElement(ResizeAttribute::VT_MODE, static_cast(mode), 0); + } + explicit ResizeAttributeBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ResizeAttributeBuilder &operator=(const ResizeAttributeBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateResizeAttribute( + flatbuffers::FlatBufferBuilder &_fbb, + flatbuffers::Offset> output_size = 0, + flatbuffers::Offset> stride = 0, + flatbuffers::Offset> offset = 0, + int32_t shift = 0, + flatbuffers::Offset> stride_fp = 0, + flatbuffers::Offset> offset_fp = 0, + tosa::ResizeMode mode = tosa::ResizeMode_UNKNOWN) { + ResizeAttributeBuilder builder_(_fbb); + builder_.add_mode(mode); + builder_.add_offset_fp(offset_fp); + builder_.add_stride_fp(stride_fp); + builder_.add_shift(shift); + builder_.add_offset(offset); + builder_.add_stride(stride); + builder_.add_output_size(output_size); + return builder_.Finish(); +} + +inline flatbuffers::Offset CreateResizeAttributeDirect( + flatbuffers::FlatBufferBuilder &_fbb, + const std::vector *output_size = nullptr, + const std::vector *stride = nullptr, + const std::vector *offset = nullptr, + int32_t shift = 0, + const std::vector *stride_fp = nullptr, + const std::vector *offset_fp = nullptr, + tosa::ResizeMode mode = tosa::ResizeMode_UNKNOWN) { + auto output_size__ = output_size ? _fbb.CreateVector(*output_size) : 0; + auto stride__ = stride ? _fbb.CreateVector(*stride) : 0; + auto offset__ = offset ? _fbb.CreateVector(*offset) : 0; + auto stride_fp__ = stride_fp ? _fbb.CreateVector(*stride_fp) : 0; + auto offset_fp__ = offset_fp ? _fbb.CreateVector(*offset_fp) : 0; + return tosa::CreateResizeAttribute( + _fbb, + output_size__, + stride__, + offset__, + shift, + stride_fp__, + offset_fp__, + mode); +} + +struct ClampAttribute FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef ClampAttributeBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_MIN_INT = 4, + VT_MAX_INT = 6, + VT_MIN_FP = 8, + VT_MAX_FP = 10 + }; + int32_t min_int() const { + return GetField(VT_MIN_INT, 0); + } + int32_t max_int() const { + return GetField(VT_MAX_INT, 0); + } + float min_fp() const { + return GetField(VT_MIN_FP, 0.0f); + } + float max_fp() const { + return GetField(VT_MAX_FP, 0.0f); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_MIN_INT) && + VerifyField(verifier, VT_MAX_INT) && + VerifyField(verifier, VT_MIN_FP) && + VerifyField(verifier, VT_MAX_FP) && + verifier.EndTable(); + } +}; + +struct ClampAttributeBuilder { + typedef ClampAttribute Table; + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_min_int(int32_t min_int) { + fbb_.AddElement(ClampAttribute::VT_MIN_INT, min_int, 0); + } + void add_max_int(int32_t max_int) { + fbb_.AddElement(ClampAttribute::VT_MAX_INT, max_int, 0); + } + void add_min_fp(float min_fp) { + fbb_.AddElement(ClampAttribute::VT_MIN_FP, min_fp, 0.0f); + } + void add_max_fp(float max_fp) { + fbb_.AddElement(ClampAttribute::VT_MAX_FP, max_fp, 0.0f); + } + explicit ClampAttributeBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ClampAttributeBuilder &operator=(const ClampAttributeBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateClampAttribute( + flatbuffers::FlatBufferBuilder &_fbb, + int32_t min_int = 0, + int32_t max_int = 0, + float min_fp = 0.0f, + float max_fp = 0.0f) { + ClampAttributeBuilder builder_(_fbb); + builder_.add_max_fp(max_fp); + builder_.add_min_fp(min_fp); + builder_.add_max_int(max_int); + builder_.add_min_int(min_int); + return builder_.Finish(); +} + +struct RescaleAttribute FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef RescaleAttributeBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_INPUT_ZP = 4, + VT_OUTPUT_ZP = 6, + VT_MULTIPLIER = 8, + VT_SHIFT = 10, + VT_SCALE32 = 12, + VT_DOUBLE_ROUND = 14, + VT_PER_CHANNEL = 16 + }; + int32_t input_zp() const { + return GetField(VT_INPUT_ZP, 0); + } + int32_t output_zp() const { + return GetField(VT_OUTPUT_ZP, 0); + } + const flatbuffers::Vector *multiplier() const { + return GetPointer *>(VT_MULTIPLIER); + } + const flatbuffers::Vector *shift() const { + return GetPointer *>(VT_SHIFT); + } + bool scale32() const { + return GetField(VT_SCALE32, 0) != 0; + } + bool double_round() const { + return GetField(VT_DOUBLE_ROUND, 0) != 0; + } + bool per_channel() const { + return GetField(VT_PER_CHANNEL, 0) != 0; + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_INPUT_ZP) && + VerifyField(verifier, VT_OUTPUT_ZP) && + VerifyOffset(verifier, VT_MULTIPLIER) && + verifier.VerifyVector(multiplier()) && + VerifyOffset(verifier, VT_SHIFT) && + verifier.VerifyVector(shift()) && + VerifyField(verifier, VT_SCALE32) && + VerifyField(verifier, VT_DOUBLE_ROUND) && + VerifyField(verifier, VT_PER_CHANNEL) && + verifier.EndTable(); + } +}; + +struct RescaleAttributeBuilder { + typedef RescaleAttribute Table; + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_input_zp(int32_t input_zp) { + fbb_.AddElement(RescaleAttribute::VT_INPUT_ZP, input_zp, 0); + } + void add_output_zp(int32_t output_zp) { + fbb_.AddElement(RescaleAttribute::VT_OUTPUT_ZP, output_zp, 0); + } + void add_multiplier(flatbuffers::Offset> multiplier) { + fbb_.AddOffset(RescaleAttribute::VT_MULTIPLIER, multiplier); + } + void add_shift(flatbuffers::Offset> shift) { + fbb_.AddOffset(RescaleAttribute::VT_SHIFT, shift); + } + void add_scale32(bool scale32) { + fbb_.AddElement(RescaleAttribute::VT_SCALE32, static_cast(scale32), 0); + } + void add_double_round(bool double_round) { + fbb_.AddElement(RescaleAttribute::VT_DOUBLE_ROUND, static_cast(double_round), 0); + } + void add_per_channel(bool per_channel) { + fbb_.AddElement(RescaleAttribute::VT_PER_CHANNEL, static_cast(per_channel), 0); + } + explicit RescaleAttributeBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + RescaleAttributeBuilder &operator=(const RescaleAttributeBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateRescaleAttribute( + flatbuffers::FlatBufferBuilder &_fbb, + int32_t input_zp = 0, + int32_t output_zp = 0, + flatbuffers::Offset> multiplier = 0, + flatbuffers::Offset> shift = 0, + bool scale32 = false, + bool double_round = false, + bool per_channel = false) { + RescaleAttributeBuilder builder_(_fbb); + builder_.add_shift(shift); + builder_.add_multiplier(multiplier); + builder_.add_output_zp(output_zp); + builder_.add_input_zp(input_zp); + builder_.add_per_channel(per_channel); + builder_.add_double_round(double_round); + builder_.add_scale32(scale32); + return builder_.Finish(); +} + +inline flatbuffers::Offset CreateRescaleAttributeDirect( + flatbuffers::FlatBufferBuilder &_fbb, + int32_t input_zp = 0, + int32_t output_zp = 0, + const std::vector *multiplier = nullptr, + const std::vector *shift = nullptr, + bool scale32 = false, + bool double_round = false, + bool per_channel = false) { + auto multiplier__ = multiplier ? _fbb.CreateVector(*multiplier) : 0; + auto shift__ = shift ? _fbb.CreateVector(*shift) : 0; + return tosa::CreateRescaleAttribute( + _fbb, + input_zp, + output_zp, + multiplier__, + shift__, + scale32, + double_round, + per_channel); +} + +struct MulAttribute FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef MulAttributeBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_SHIFT = 4 + }; + int32_t shift() const { + return GetField(VT_SHIFT, 0); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_SHIFT) && + verifier.EndTable(); + } +}; + +struct MulAttributeBuilder { + typedef MulAttribute Table; + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_shift(int32_t shift) { + fbb_.AddElement(MulAttribute::VT_SHIFT, shift, 0); + } + explicit MulAttributeBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + MulAttributeBuilder &operator=(const MulAttributeBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateMulAttribute( + flatbuffers::FlatBufferBuilder &_fbb, + int32_t shift = 0) { + MulAttributeBuilder builder_(_fbb); + builder_.add_shift(shift); + return builder_.Finish(); +} + +struct ArithmeticRightShiftAttribute FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef ArithmeticRightShiftAttributeBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_ROUND = 4 + }; + bool round() const { + return GetField(VT_ROUND, 0) != 0; + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_ROUND) && + verifier.EndTable(); + } +}; + +struct ArithmeticRightShiftAttributeBuilder { + typedef ArithmeticRightShiftAttribute Table; + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_round(bool round) { + fbb_.AddElement(ArithmeticRightShiftAttribute::VT_ROUND, static_cast(round), 0); + } + explicit ArithmeticRightShiftAttributeBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ArithmeticRightShiftAttributeBuilder &operator=(const ArithmeticRightShiftAttributeBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateArithmeticRightShiftAttribute( + flatbuffers::FlatBufferBuilder &_fbb, + bool round = false) { + ArithmeticRightShiftAttributeBuilder builder_(_fbb); + builder_.add_round(round); + return builder_.Finish(); +} + +struct CondIfAttribute FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef CondIfAttributeBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_THEN_BRANCH = 4, + VT_ELSE_BRANCH = 6 + }; + const flatbuffers::String *then_branch() const { + return GetPointer(VT_THEN_BRANCH); + } + const flatbuffers::String *else_branch() const { + return GetPointer(VT_ELSE_BRANCH); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_THEN_BRANCH) && + verifier.VerifyString(then_branch()) && + VerifyOffset(verifier, VT_ELSE_BRANCH) && + verifier.VerifyString(else_branch()) && + verifier.EndTable(); + } +}; + +struct CondIfAttributeBuilder { + typedef CondIfAttribute Table; + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_then_branch(flatbuffers::Offset then_branch) { + fbb_.AddOffset(CondIfAttribute::VT_THEN_BRANCH, then_branch); + } + void add_else_branch(flatbuffers::Offset else_branch) { + fbb_.AddOffset(CondIfAttribute::VT_ELSE_BRANCH, else_branch); + } + explicit CondIfAttributeBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + CondIfAttributeBuilder &operator=(const CondIfAttributeBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateCondIfAttribute( + flatbuffers::FlatBufferBuilder &_fbb, + flatbuffers::Offset then_branch = 0, + flatbuffers::Offset else_branch = 0) { + CondIfAttributeBuilder builder_(_fbb); + builder_.add_else_branch(else_branch); + builder_.add_then_branch(then_branch); + return builder_.Finish(); +} + +inline flatbuffers::Offset CreateCondIfAttributeDirect( + flatbuffers::FlatBufferBuilder &_fbb, + const char *then_branch = nullptr, + const char *else_branch = nullptr) { + auto then_branch__ = then_branch ? _fbb.CreateString(then_branch) : 0; + auto else_branch__ = else_branch ? _fbb.CreateString(else_branch) : 0; + return tosa::CreateCondIfAttribute( + _fbb, + then_branch__, + else_branch__); +} + +struct WhileLoopAttribute FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef WhileLoopAttributeBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_COND_BRANCH = 4, + VT_BODY_BRANCH = 6 + }; + const flatbuffers::String *cond_branch() const { + return GetPointer(VT_COND_BRANCH); + } + const flatbuffers::String *body_branch() const { + return GetPointer(VT_BODY_BRANCH); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_COND_BRANCH) && + verifier.VerifyString(cond_branch()) && + VerifyOffset(verifier, VT_BODY_BRANCH) && + verifier.VerifyString(body_branch()) && + verifier.EndTable(); + } +}; + +struct WhileLoopAttributeBuilder { + typedef WhileLoopAttribute Table; + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_cond_branch(flatbuffers::Offset cond_branch) { + fbb_.AddOffset(WhileLoopAttribute::VT_COND_BRANCH, cond_branch); + } + void add_body_branch(flatbuffers::Offset body_branch) { + fbb_.AddOffset(WhileLoopAttribute::VT_BODY_BRANCH, body_branch); + } + explicit WhileLoopAttributeBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + WhileLoopAttributeBuilder &operator=(const WhileLoopAttributeBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateWhileLoopAttribute( + flatbuffers::FlatBufferBuilder &_fbb, + flatbuffers::Offset cond_branch = 0, + flatbuffers::Offset body_branch = 0) { + WhileLoopAttributeBuilder builder_(_fbb); + builder_.add_body_branch(body_branch); + builder_.add_cond_branch(cond_branch); + return builder_.Finish(); +} + +inline flatbuffers::Offset CreateWhileLoopAttributeDirect( + flatbuffers::FlatBufferBuilder &_fbb, + const char *cond_branch = nullptr, + const char *body_branch = nullptr) { + auto cond_branch__ = cond_branch ? _fbb.CreateString(cond_branch) : 0; + auto body_branch__ = body_branch ? _fbb.CreateString(body_branch) : 0; + return tosa::CreateWhileLoopAttribute( + _fbb, + cond_branch__, + body_branch__); +} + +struct UnaryQuantInfo FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef UnaryQuantInfoBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_INPUT_ZP = 4, + VT_OUTPUT_ZP = 6 + }; + int32_t input_zp() const { + return GetField(VT_INPUT_ZP, 0); + } + int32_t output_zp() const { + return GetField(VT_OUTPUT_ZP, 0); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_INPUT_ZP) && + VerifyField(verifier, VT_OUTPUT_ZP) && + verifier.EndTable(); + } +}; + +struct UnaryQuantInfoBuilder { + typedef UnaryQuantInfo Table; + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_input_zp(int32_t input_zp) { + fbb_.AddElement(UnaryQuantInfo::VT_INPUT_ZP, input_zp, 0); + } + void add_output_zp(int32_t output_zp) { + fbb_.AddElement(UnaryQuantInfo::VT_OUTPUT_ZP, output_zp, 0); + } + explicit UnaryQuantInfoBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + UnaryQuantInfoBuilder &operator=(const UnaryQuantInfoBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateUnaryQuantInfo( + flatbuffers::FlatBufferBuilder &_fbb, + int32_t input_zp = 0, + int32_t output_zp = 0) { + UnaryQuantInfoBuilder builder_(_fbb); + builder_.add_output_zp(output_zp); + builder_.add_input_zp(input_zp); + return builder_.Finish(); +} + +struct ConvQuantInfo FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef ConvQuantInfoBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_INPUT_ZP = 4, + VT_WEIGHT_ZP = 6 + }; + int32_t input_zp() const { + return GetField(VT_INPUT_ZP, 0); + } + int32_t weight_zp() const { + return GetField(VT_WEIGHT_ZP, 0); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_INPUT_ZP) && + VerifyField(verifier, VT_WEIGHT_ZP) && + verifier.EndTable(); + } +}; + +struct ConvQuantInfoBuilder { + typedef ConvQuantInfo Table; + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_input_zp(int32_t input_zp) { + fbb_.AddElement(ConvQuantInfo::VT_INPUT_ZP, input_zp, 0); + } + void add_weight_zp(int32_t weight_zp) { + fbb_.AddElement(ConvQuantInfo::VT_WEIGHT_ZP, weight_zp, 0); + } + explicit ConvQuantInfoBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ConvQuantInfoBuilder &operator=(const ConvQuantInfoBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateConvQuantInfo( + flatbuffers::FlatBufferBuilder &_fbb, + int32_t input_zp = 0, + int32_t weight_zp = 0) { + ConvQuantInfoBuilder builder_(_fbb); + builder_.add_weight_zp(weight_zp); + builder_.add_input_zp(input_zp); + return builder_.Finish(); +} + +struct MatMulQuantInfo FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef MatMulQuantInfoBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_A_ZP = 4, + VT_B_ZP = 6 + }; + int32_t a_zp() const { + return GetField(VT_A_ZP, 0); + } + int32_t b_zp() const { + return GetField(VT_B_ZP, 0); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_A_ZP) && + VerifyField(verifier, VT_B_ZP) && + verifier.EndTable(); + } +}; + +struct MatMulQuantInfoBuilder { + typedef MatMulQuantInfo Table; + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_a_zp(int32_t a_zp) { + fbb_.AddElement(MatMulQuantInfo::VT_A_ZP, a_zp, 0); + } + void add_b_zp(int32_t b_zp) { + fbb_.AddElement(MatMulQuantInfo::VT_B_ZP, b_zp, 0); + } + explicit MatMulQuantInfoBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + MatMulQuantInfoBuilder &operator=(const MatMulQuantInfoBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateMatMulQuantInfo( + flatbuffers::FlatBufferBuilder &_fbb, + int32_t a_zp = 0, + int32_t b_zp = 0) { + MatMulQuantInfoBuilder builder_(_fbb); + builder_.add_b_zp(b_zp); + builder_.add_a_zp(a_zp); + return builder_.Finish(); +} + +struct PadQuantInfo FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef PadQuantInfoBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_INPUT_ZP = 4 + }; + int32_t input_zp() const { + return GetField(VT_INPUT_ZP, 0); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_INPUT_ZP) && + verifier.EndTable(); + } +}; + +struct PadQuantInfoBuilder { + typedef PadQuantInfo Table; + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_input_zp(int32_t input_zp) { + fbb_.AddElement(PadQuantInfo::VT_INPUT_ZP, input_zp, 0); + } + explicit PadQuantInfoBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + PadQuantInfoBuilder &operator=(const PadQuantInfoBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreatePadQuantInfo( + flatbuffers::FlatBufferBuilder &_fbb, + int32_t input_zp = 0) { + PadQuantInfoBuilder builder_(_fbb); + builder_.add_input_zp(input_zp); + return builder_.Finish(); +} + +struct Version FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef VersionBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT__MAJOR = 4, + VT__MINOR = 6, + VT__PATCH = 8, + VT__EXPERIMENTAL = 10 + }; + int32_t _major() const { + return GetField(VT__MAJOR, 0); + } + int32_t _minor() const { + return GetField(VT__MINOR, 21); + } + int32_t _patch() const { + return GetField(VT__PATCH, 0); + } + bool _experimental() const { + return GetField(VT__EXPERIMENTAL, 0) != 0; + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT__MAJOR) && + VerifyField(verifier, VT__MINOR) && + VerifyField(verifier, VT__PATCH) && + VerifyField(verifier, VT__EXPERIMENTAL) && + verifier.EndTable(); + } +}; + +struct VersionBuilder { + typedef Version Table; + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add__major(int32_t _major) { + fbb_.AddElement(Version::VT__MAJOR, _major, 0); + } + void add__minor(int32_t _minor) { + fbb_.AddElement(Version::VT__MINOR, _minor, 21); + } + void add__patch(int32_t _patch) { + fbb_.AddElement(Version::VT__PATCH, _patch, 0); + } + void add__experimental(bool _experimental) { + fbb_.AddElement(Version::VT__EXPERIMENTAL, static_cast(_experimental), 0); + } + explicit VersionBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + VersionBuilder &operator=(const VersionBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateVersion( + flatbuffers::FlatBufferBuilder &_fbb, + int32_t _major = 0, + int32_t _minor = 21, + int32_t _patch = 0, + bool _experimental = false) { + VersionBuilder builder_(_fbb); + builder_.add__patch(_patch); + builder_.add__minor(_minor); + builder_.add__major(_major); + builder_.add__experimental(_experimental); + return builder_.Finish(); +} + +struct TosaTensor FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef TosaTensorBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_NAME = 4, + VT_SHAPE = 6, + VT_TYPE = 8, + VT_NPY_FILENAME = 10 + }; + const flatbuffers::String *name() const { + return GetPointer(VT_NAME); + } + const flatbuffers::Vector *shape() const { + return GetPointer *>(VT_SHAPE); + } + tosa::DType type() const { + return static_cast(GetField(VT_TYPE, 0)); + } + const flatbuffers::String *npy_filename() const { + return GetPointer(VT_NPY_FILENAME); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_NAME) && + verifier.VerifyString(name()) && + VerifyOffset(verifier, VT_SHAPE) && + verifier.VerifyVector(shape()) && + VerifyField(verifier, VT_TYPE) && + VerifyOffset(verifier, VT_NPY_FILENAME) && + verifier.VerifyString(npy_filename()) && + verifier.EndTable(); + } +}; + +struct TosaTensorBuilder { + typedef TosaTensor Table; + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_name(flatbuffers::Offset name) { + fbb_.AddOffset(TosaTensor::VT_NAME, name); + } + void add_shape(flatbuffers::Offset> shape) { + fbb_.AddOffset(TosaTensor::VT_SHAPE, shape); + } + void add_type(tosa::DType type) { + fbb_.AddElement(TosaTensor::VT_TYPE, static_cast(type), 0); + } + void add_npy_filename(flatbuffers::Offset npy_filename) { + fbb_.AddOffset(TosaTensor::VT_NPY_FILENAME, npy_filename); + } + explicit TosaTensorBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + TosaTensorBuilder &operator=(const TosaTensorBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateTosaTensor( + flatbuffers::FlatBufferBuilder &_fbb, + flatbuffers::Offset name = 0, + flatbuffers::Offset> shape = 0, + tosa::DType type = tosa::DType_UNKNOWN, + flatbuffers::Offset npy_filename = 0) { + TosaTensorBuilder builder_(_fbb); + builder_.add_npy_filename(npy_filename); + builder_.add_type(type); + builder_.add_shape(shape); + builder_.add_name(name); + return builder_.Finish(); +} + +inline flatbuffers::Offset CreateTosaTensorDirect( + flatbuffers::FlatBufferBuilder &_fbb, + const char *name = nullptr, + const std::vector *shape = nullptr, + tosa::DType type = tosa::DType_UNKNOWN, + const char *npy_filename = nullptr) { + auto name__ = name ? _fbb.CreateString(name) : 0; + auto shape__ = shape ? _fbb.CreateVector(*shape) : 0; + auto npy_filename__ = npy_filename ? _fbb.CreateString(npy_filename) : 0; + return tosa::CreateTosaTensor( + _fbb, + name__, + shape__, + type, + npy_filename__); +} + +struct TosaOperator FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef TosaOperatorBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_OP = 4, + VT_ATTRIBUTE_TYPE = 6, + VT_ATTRIBUTE = 8, + VT_INPUTS = 10, + VT_OUTPUTS = 12, + VT_QUANT_INFO_TYPE = 14, + VT_QUANT_INFO = 16 + }; + tosa::Op op() const { + return static_cast(GetField(VT_OP, 0)); + } + tosa::Attribute attribute_type() const { + return static_cast(GetField(VT_ATTRIBUTE_TYPE, 0)); + } + const void *attribute() const { + return GetPointer(VT_ATTRIBUTE); + } + template const T *attribute_as() const; + const tosa::Pool2dAttribute *attribute_as_Pool2dAttribute() const { + return attribute_type() == tosa::Attribute_Pool2dAttribute ? static_cast(attribute()) : nullptr; + } + const tosa::Conv2dAttribute *attribute_as_Conv2dAttribute() const { + return attribute_type() == tosa::Attribute_Conv2dAttribute ? static_cast(attribute()) : nullptr; + } + const tosa::TransposeConv2dAttribute *attribute_as_TransposeConv2dAttribute() const { + return attribute_type() == tosa::Attribute_TransposeConv2dAttribute ? static_cast(attribute()) : nullptr; + } + const tosa::ReluNAttribute *attribute_as_ReluNAttribute() const { + return attribute_type() == tosa::Attribute_ReluNAttribute ? static_cast(attribute()) : nullptr; + } + const tosa::AxisAttribute *attribute_as_AxisAttribute() const { + return attribute_type() == tosa::Attribute_AxisAttribute ? static_cast(attribute()) : nullptr; + } + const tosa::ReshapeAttribute *attribute_as_ReshapeAttribute() const { + return attribute_type() == tosa::Attribute_ReshapeAttribute ? static_cast(attribute()) : nullptr; + } + const tosa::SliceAttribute *attribute_as_SliceAttribute() const { + return attribute_type() == tosa::Attribute_SliceAttribute ? static_cast(attribute()) : nullptr; + } + const tosa::TileAttribute *attribute_as_TileAttribute() const { + return attribute_type() == tosa::Attribute_TileAttribute ? static_cast(attribute()) : nullptr; + } + const tosa::ResizeAttribute *attribute_as_ResizeAttribute() const { + return attribute_type() == tosa::Attribute_ResizeAttribute ? static_cast(attribute()) : nullptr; + } + const tosa::ClampAttribute *attribute_as_ClampAttribute() const { + return attribute_type() == tosa::Attribute_ClampAttribute ? static_cast(attribute()) : nullptr; + } + const tosa::RescaleAttribute *attribute_as_RescaleAttribute() const { + return attribute_type() == tosa::Attribute_RescaleAttribute ? static_cast(attribute()) : nullptr; + } + const tosa::MulAttribute *attribute_as_MulAttribute() const { + return attribute_type() == tosa::Attribute_MulAttribute ? static_cast(attribute()) : nullptr; + } + const tosa::ArithmeticRightShiftAttribute *attribute_as_ArithmeticRightShiftAttribute() const { + return attribute_type() == tosa::Attribute_ArithmeticRightShiftAttribute ? static_cast(attribute()) : nullptr; + } + const tosa::CondIfAttribute *attribute_as_CondIfAttribute() const { + return attribute_type() == tosa::Attribute_CondIfAttribute ? static_cast(attribute()) : nullptr; + } + const tosa::WhileLoopAttribute *attribute_as_WhileLoopAttribute() const { + return attribute_type() == tosa::Attribute_WhileLoopAttribute ? static_cast(attribute()) : nullptr; + } + const flatbuffers::Vector> *inputs() const { + return GetPointer> *>(VT_INPUTS); + } + const flatbuffers::Vector> *outputs() const { + return GetPointer> *>(VT_OUTPUTS); + } + tosa::QuantInfo quant_info_type() const { + return static_cast(GetField(VT_QUANT_INFO_TYPE, 0)); + } + const void *quant_info() const { + return GetPointer(VT_QUANT_INFO); + } + template const T *quant_info_as() const; + const tosa::UnaryQuantInfo *quant_info_as_UnaryQuantInfo() const { + return quant_info_type() == tosa::QuantInfo_UnaryQuantInfo ? static_cast(quant_info()) : nullptr; + } + const tosa::ConvQuantInfo *quant_info_as_ConvQuantInfo() const { + return quant_info_type() == tosa::QuantInfo_ConvQuantInfo ? static_cast(quant_info()) : nullptr; + } + const tosa::MatMulQuantInfo *quant_info_as_MatMulQuantInfo() const { + return quant_info_type() == tosa::QuantInfo_MatMulQuantInfo ? static_cast(quant_info()) : nullptr; + } + const tosa::PadQuantInfo *quant_info_as_PadQuantInfo() const { + return quant_info_type() == tosa::QuantInfo_PadQuantInfo ? static_cast(quant_info()) : nullptr; + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_OP) && + VerifyField(verifier, VT_ATTRIBUTE_TYPE) && + VerifyOffset(verifier, VT_ATTRIBUTE) && + VerifyAttribute(verifier, attribute(), attribute_type()) && + VerifyOffset(verifier, VT_INPUTS) && + verifier.VerifyVector(inputs()) && + verifier.VerifyVectorOfStrings(inputs()) && + VerifyOffset(verifier, VT_OUTPUTS) && + verifier.VerifyVector(outputs()) && + verifier.VerifyVectorOfStrings(outputs()) && + VerifyField(verifier, VT_QUANT_INFO_TYPE) && + VerifyOffset(verifier, VT_QUANT_INFO) && + VerifyQuantInfo(verifier, quant_info(), quant_info_type()) && + verifier.EndTable(); + } +}; + +template<> inline const tosa::Pool2dAttribute *TosaOperator::attribute_as() const { + return attribute_as_Pool2dAttribute(); +} + +template<> inline const tosa::Conv2dAttribute *TosaOperator::attribute_as() const { + return attribute_as_Conv2dAttribute(); +} + +template<> inline const tosa::TransposeConv2dAttribute *TosaOperator::attribute_as() const { + return attribute_as_TransposeConv2dAttribute(); +} + +template<> inline const tosa::ReluNAttribute *TosaOperator::attribute_as() const { + return attribute_as_ReluNAttribute(); +} + +template<> inline const tosa::AxisAttribute *TosaOperator::attribute_as() const { + return attribute_as_AxisAttribute(); +} + +template<> inline const tosa::ReshapeAttribute *TosaOperator::attribute_as() const { + return attribute_as_ReshapeAttribute(); +} + +template<> inline const tosa::SliceAttribute *TosaOperator::attribute_as() const { + return attribute_as_SliceAttribute(); +} + +template<> inline const tosa::TileAttribute *TosaOperator::attribute_as() const { + return attribute_as_TileAttribute(); +} + +template<> inline const tosa::ResizeAttribute *TosaOperator::attribute_as() const { + return attribute_as_ResizeAttribute(); +} + +template<> inline const tosa::ClampAttribute *TosaOperator::attribute_as() const { + return attribute_as_ClampAttribute(); +} + +template<> inline const tosa::RescaleAttribute *TosaOperator::attribute_as() const { + return attribute_as_RescaleAttribute(); +} + +template<> inline const tosa::MulAttribute *TosaOperator::attribute_as() const { + return attribute_as_MulAttribute(); +} + +template<> inline const tosa::ArithmeticRightShiftAttribute *TosaOperator::attribute_as() const { + return attribute_as_ArithmeticRightShiftAttribute(); +} + +template<> inline const tosa::CondIfAttribute *TosaOperator::attribute_as() const { + return attribute_as_CondIfAttribute(); +} + +template<> inline const tosa::WhileLoopAttribute *TosaOperator::attribute_as() const { + return attribute_as_WhileLoopAttribute(); +} + +template<> inline const tosa::UnaryQuantInfo *TosaOperator::quant_info_as() const { + return quant_info_as_UnaryQuantInfo(); +} + +template<> inline const tosa::ConvQuantInfo *TosaOperator::quant_info_as() const { + return quant_info_as_ConvQuantInfo(); +} + +template<> inline const tosa::MatMulQuantInfo *TosaOperator::quant_info_as() const { + return quant_info_as_MatMulQuantInfo(); +} + +template<> inline const tosa::PadQuantInfo *TosaOperator::quant_info_as() const { + return quant_info_as_PadQuantInfo(); +} + +struct TosaOperatorBuilder { + typedef TosaOperator Table; + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_op(tosa::Op op) { + fbb_.AddElement(TosaOperator::VT_OP, static_cast(op), 0); + } + void add_attribute_type(tosa::Attribute attribute_type) { + fbb_.AddElement(TosaOperator::VT_ATTRIBUTE_TYPE, static_cast(attribute_type), 0); + } + void add_attribute(flatbuffers::Offset attribute) { + fbb_.AddOffset(TosaOperator::VT_ATTRIBUTE, attribute); + } + void add_inputs(flatbuffers::Offset>> inputs) { + fbb_.AddOffset(TosaOperator::VT_INPUTS, inputs); + } + void add_outputs(flatbuffers::Offset>> outputs) { + fbb_.AddOffset(TosaOperator::VT_OUTPUTS, outputs); + } + void add_quant_info_type(tosa::QuantInfo quant_info_type) { + fbb_.AddElement(TosaOperator::VT_QUANT_INFO_TYPE, static_cast(quant_info_type), 0); + } + void add_quant_info(flatbuffers::Offset quant_info) { + fbb_.AddOffset(TosaOperator::VT_QUANT_INFO, quant_info); + } + explicit TosaOperatorBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + TosaOperatorBuilder &operator=(const TosaOperatorBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateTosaOperator( + flatbuffers::FlatBufferBuilder &_fbb, + tosa::Op op = tosa::Op_UNKNOWN, + tosa::Attribute attribute_type = tosa::Attribute_NONE, + flatbuffers::Offset attribute = 0, + flatbuffers::Offset>> inputs = 0, + flatbuffers::Offset>> outputs = 0, + tosa::QuantInfo quant_info_type = tosa::QuantInfo_NONE, + flatbuffers::Offset quant_info = 0) { + TosaOperatorBuilder builder_(_fbb); + builder_.add_quant_info(quant_info); + builder_.add_outputs(outputs); + builder_.add_inputs(inputs); + builder_.add_attribute(attribute); + builder_.add_op(op); + builder_.add_quant_info_type(quant_info_type); + builder_.add_attribute_type(attribute_type); + return builder_.Finish(); +} + +inline flatbuffers::Offset CreateTosaOperatorDirect( + flatbuffers::FlatBufferBuilder &_fbb, + tosa::Op op = tosa::Op_UNKNOWN, + tosa::Attribute attribute_type = tosa::Attribute_NONE, + flatbuffers::Offset attribute = 0, + const std::vector> *inputs = nullptr, + const std::vector> *outputs = nullptr, + tosa::QuantInfo quant_info_type = tosa::QuantInfo_NONE, + flatbuffers::Offset quant_info = 0) { + auto inputs__ = inputs ? _fbb.CreateVector>(*inputs) : 0; + auto outputs__ = outputs ? _fbb.CreateVector>(*outputs) : 0; + return tosa::CreateTosaOperator( + _fbb, + op, + attribute_type, + attribute, + inputs__, + outputs__, + quant_info_type, + quant_info); +} + +struct TosaBasicBlock FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef TosaBasicBlockBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_NAME = 4, + VT_OPERATORS = 6, + VT_TENSORS = 8, + VT_INPUTS = 10, + VT_OUTPUTS = 12 + }; + const flatbuffers::String *name() const { + return GetPointer(VT_NAME); + } + const flatbuffers::Vector> *operators() const { + return GetPointer> *>(VT_OPERATORS); + } + const flatbuffers::Vector> *tensors() const { + return GetPointer> *>(VT_TENSORS); + } + const flatbuffers::Vector> *inputs() const { + return GetPointer> *>(VT_INPUTS); + } + const flatbuffers::Vector> *outputs() const { + return GetPointer> *>(VT_OUTPUTS); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_NAME) && + verifier.VerifyString(name()) && + VerifyOffset(verifier, VT_OPERATORS) && + verifier.VerifyVector(operators()) && + verifier.VerifyVectorOfTables(operators()) && + VerifyOffset(verifier, VT_TENSORS) && + verifier.VerifyVector(tensors()) && + verifier.VerifyVectorOfTables(tensors()) && + VerifyOffset(verifier, VT_INPUTS) && + verifier.VerifyVector(inputs()) && + verifier.VerifyVectorOfStrings(inputs()) && + VerifyOffset(verifier, VT_OUTPUTS) && + verifier.VerifyVector(outputs()) && + verifier.VerifyVectorOfStrings(outputs()) && + verifier.EndTable(); + } +}; + +struct TosaBasicBlockBuilder { + typedef TosaBasicBlock Table; + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_name(flatbuffers::Offset name) { + fbb_.AddOffset(TosaBasicBlock::VT_NAME, name); + } + void add_operators(flatbuffers::Offset>> operators) { + fbb_.AddOffset(TosaBasicBlock::VT_OPERATORS, operators); + } + void add_tensors(flatbuffers::Offset>> tensors) { + fbb_.AddOffset(TosaBasicBlock::VT_TENSORS, tensors); + } + void add_inputs(flatbuffers::Offset>> inputs) { + fbb_.AddOffset(TosaBasicBlock::VT_INPUTS, inputs); + } + void add_outputs(flatbuffers::Offset>> outputs) { + fbb_.AddOffset(TosaBasicBlock::VT_OUTPUTS, outputs); + } + explicit TosaBasicBlockBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + TosaBasicBlockBuilder &operator=(const TosaBasicBlockBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateTosaBasicBlock( + flatbuffers::FlatBufferBuilder &_fbb, + flatbuffers::Offset name = 0, + flatbuffers::Offset>> operators = 0, + flatbuffers::Offset>> tensors = 0, + flatbuffers::Offset>> inputs = 0, + flatbuffers::Offset>> outputs = 0) { + TosaBasicBlockBuilder builder_(_fbb); + builder_.add_outputs(outputs); + builder_.add_inputs(inputs); + builder_.add_tensors(tensors); + builder_.add_operators(operators); + builder_.add_name(name); + return builder_.Finish(); +} + +inline flatbuffers::Offset CreateTosaBasicBlockDirect( + flatbuffers::FlatBufferBuilder &_fbb, + const char *name = nullptr, + const std::vector> *operators = nullptr, + const std::vector> *tensors = nullptr, + const std::vector> *inputs = nullptr, + const std::vector> *outputs = nullptr) { + auto name__ = name ? _fbb.CreateString(name) : 0; + auto operators__ = operators ? _fbb.CreateVector>(*operators) : 0; + auto tensors__ = tensors ? _fbb.CreateVector>(*tensors) : 0; + auto inputs__ = inputs ? _fbb.CreateVector>(*inputs) : 0; + auto outputs__ = outputs ? _fbb.CreateVector>(*outputs) : 0; + return tosa::CreateTosaBasicBlock( + _fbb, + name__, + operators__, + tensors__, + inputs__, + outputs__); +} + +struct TosaGraph FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef TosaGraphBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_VERSION = 4, + VT_BLOCKS = 6 + }; + const tosa::Version *version() const { + return GetPointer(VT_VERSION); + } + const flatbuffers::Vector> *blocks() const { + return GetPointer> *>(VT_BLOCKS); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_VERSION) && + verifier.VerifyTable(version()) && + VerifyOffset(verifier, VT_BLOCKS) && + verifier.VerifyVector(blocks()) && + verifier.VerifyVectorOfTables(blocks()) && + verifier.EndTable(); + } +}; + +struct TosaGraphBuilder { + typedef TosaGraph Table; + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_version(flatbuffers::Offset version) { + fbb_.AddOffset(TosaGraph::VT_VERSION, version); + } + void add_blocks(flatbuffers::Offset>> blocks) { + fbb_.AddOffset(TosaGraph::VT_BLOCKS, blocks); + } + explicit TosaGraphBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + TosaGraphBuilder &operator=(const TosaGraphBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateTosaGraph( + flatbuffers::FlatBufferBuilder &_fbb, + flatbuffers::Offset version = 0, + flatbuffers::Offset>> blocks = 0) { + TosaGraphBuilder builder_(_fbb); + builder_.add_blocks(blocks); + builder_.add_version(version); + return builder_.Finish(); +} + +inline flatbuffers::Offset CreateTosaGraphDirect( + flatbuffers::FlatBufferBuilder &_fbb, + flatbuffers::Offset version = 0, + const std::vector> *blocks = nullptr) { + auto blocks__ = blocks ? _fbb.CreateVector>(*blocks) : 0; + return tosa::CreateTosaGraph( + _fbb, + version, + blocks__); +} + +inline bool VerifyAttribute(flatbuffers::Verifier &verifier, const void *obj, Attribute type) { + switch (type) { + case Attribute_NONE: { + return true; + } + case Attribute_Pool2dAttribute: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case Attribute_Conv2dAttribute: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case Attribute_TransposeConv2dAttribute: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case Attribute_ReluNAttribute: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case Attribute_AxisAttribute: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case Attribute_ReshapeAttribute: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case Attribute_SliceAttribute: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case Attribute_TileAttribute: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case Attribute_ResizeAttribute: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case Attribute_ClampAttribute: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case Attribute_RescaleAttribute: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case Attribute_MulAttribute: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case Attribute_ArithmeticRightShiftAttribute: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case Attribute_CondIfAttribute: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case Attribute_WhileLoopAttribute: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + default: return true; + } +} + +inline bool VerifyAttributeVector(flatbuffers::Verifier &verifier, const flatbuffers::Vector> *values, const flatbuffers::Vector *types) { + if (!values || !types) return !values && !types; + if (values->size() != types->size()) return false; + for (flatbuffers::uoffset_t i = 0; i < values->size(); ++i) { + if (!VerifyAttribute( + verifier, values->Get(i), types->GetEnum(i))) { + return false; + } + } + return true; +} + +inline bool VerifyQuantInfo(flatbuffers::Verifier &verifier, const void *obj, QuantInfo type) { + switch (type) { + case QuantInfo_NONE: { + return true; + } + case QuantInfo_UnaryQuantInfo: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case QuantInfo_ConvQuantInfo: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case QuantInfo_MatMulQuantInfo: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case QuantInfo_PadQuantInfo: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + default: return true; + } +} + +inline bool VerifyQuantInfoVector(flatbuffers::Verifier &verifier, const flatbuffers::Vector> *values, const flatbuffers::Vector *types) { + if (!values || !types) return !values && !types; + if (values->size() != types->size()) return false; + for (flatbuffers::uoffset_t i = 0; i < values->size(); ++i) { + if (!VerifyQuantInfo( + verifier, values->Get(i), types->GetEnum(i))) { + return false; + } + } + return true; +} + +inline const tosa::TosaGraph *GetTosaGraph(const void *buf) { + return flatbuffers::GetRoot(buf); +} + +inline const tosa::TosaGraph *GetSizePrefixedTosaGraph(const void *buf) { + return flatbuffers::GetSizePrefixedRoot(buf); +} + +inline const char *TosaGraphIdentifier() { + return "TOSA"; +} + +inline bool TosaGraphBufferHasIdentifier(const void *buf) { + return flatbuffers::BufferHasIdentifier( + buf, TosaGraphIdentifier()); +} + +inline bool VerifyTosaGraphBuffer( + flatbuffers::Verifier &verifier) { + return verifier.VerifyBuffer(TosaGraphIdentifier()); +} + +inline bool VerifySizePrefixedTosaGraphBuffer( + flatbuffers::Verifier &verifier) { + return verifier.VerifySizePrefixedBuffer(TosaGraphIdentifier()); +} + +inline const char *TosaGraphExtension() { + return "tosa"; +} + +inline void FinishTosaGraphBuffer( + flatbuffers::FlatBufferBuilder &fbb, + flatbuffers::Offset root) { + fbb.Finish(root, TosaGraphIdentifier()); +} + +inline void FinishSizePrefixedTosaGraphBuffer( + flatbuffers::FlatBufferBuilder &fbb, + flatbuffers::Offset root) { + fbb.FinishSizePrefixed(root, TosaGraphIdentifier()); +} + +} // namespace tosa + +#endif // FLATBUFFERS_GENERATED_TOSA_TOSA_H_ diff --git a/include/tosa_serialization_handler.h b/include/tosa_serialization_handler.h new file mode 100644 index 0000000..398590d --- /dev/null +++ b/include/tosa_serialization_handler.h @@ -0,0 +1,349 @@ + +// Copyright (c) 2020-2021, ARM Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef _TOSA_SERIALIZATION_HANDLER_H +#define _TOSA_SERIALIZATION_HANDLER_H +#include "attribute.h" +#include "flatbuffers/idl.h" +#include "flatbuffers/util.h" +#include "numpy_utils.h" +#include "quant_info.h" +#include "tosa_generated.h" +#include +#include +#include +#include + +namespace tosa +{ + +enum tosa_err_t +{ + TOSA_OK, + TOSA_USER_ERROR, + TOSA_FILE_ERROR, + TOSA_MEMORY_ERROR, + TOSA_SCHEMA_MISSING, + TOSA_INTERNAL_ERROR, + TOSA_VERSION_MISMATCH, + NUM_TOSA_ERROR +}; + +struct TosaVersion +{ + int32_t _major; + int32_t _minor; + int32_t _patch; + bool _experimental; + bool _valid; + + TosaVersion() + { + _valid = false; + } + + TosaVersion(int32_t major, int32_t minor, int32_t patch, bool experimental) + { + set_version(major, minor, patch, experimental); + } + + void set_version(int32_t major, int32_t minor, int32_t patch, bool experimental) + { + _major = major; + _minor = minor; + _patch = patch; + _experimental = experimental; + _valid = true; + } + + std::string to_string() const + { + std::string str; + assert(_valid); + str += std::to_string(_major) + "."; + str += std::to_string(_minor) + "."; + str += std::to_string(_patch); + if (_experimental) + str += "(experimental)"; + return str; + }; + + bool operator==(const TosaVersion& rhs) + { + assert(_valid); + if (!_valid) + return false; + if (rhs._major == _major && rhs._minor == _minor && rhs._patch == _patch && rhs._experimental == _experimental) + { + return true; + } + return false; + } + + bool operator!=(const TosaVersion& rhs) + { + assert(_valid); + if (!_valid) + return true; + return !((*this) == rhs); + } +}; + +class TosaSerializationHandler; + +class TosaSerializationTensor +{ +public: + // constructor and destructor + TosaSerializationTensor(const flatbuffers::String* name, + const flatbuffers::Vector& shape, + DType dtype, + const flatbuffers::String* npy_filename); + TosaSerializationTensor(std::string& name, + const std::vector& shape, + DType dtype, + const std::string& npy_filename); + TosaSerializationTensor(); + ~TosaSerializationTensor(); + + // accessor + std::string GetName() const + { + return _name; + } + const std::vector& GetShape() const + { + return _shape; + } + DType GetDtype() + { + return _dtype; + } + const std::string& GetNpyFilePtr() const + { + return _npy_filename; + } + + // modifier + void SetDtype(DType dtype) + { + _dtype = dtype; + } + void SetName(std::string name) + { + _name = name; + } + +private: + DType _dtype; /* data type enumeration, see tosa_isa_generated.h */ + std::vector _shape; /* shape of the tensor */ + std::string _name; /* name of the tensor, used for solving dependency */ + std::string _npy_filename; /* numpy array filename if not null. so null is the distinguisher */ +}; + +class TosaSerializationOperator +{ +public: + // use default copy, void constructor + // constructor and destructor + TosaSerializationOperator(Op op, + Attribute attribute_type, + const TosaAttributeBase* attribute, + QuantInfo qinfo_type, + const TosaQuantInfoBase* qinfo, + std::vector input_tensor_names, + std::vector output_tensor_names); + ~TosaSerializationOperator(); + + // accessor + Op GetOp() const + { + return _op; + } + Attribute GetAttributeType() const + { + return _attribute_type; + } + TosaAttributeBase* GetAttribute() const + { + return _attribute; + } + QuantInfo GetQInfoType() const + { + return _qinfo_type; + } + TosaQuantInfoBase* GetQInfo() const + { + return _qinfo; + } + std::vector& GetInputTensorNames() + { + return _input_tensor_names; + } + std::vector& GetOutputTensorNames() + { + return _output_tensor_names; + } + +private: + Op _op; /* operator enum, see tosa_isa_generated.h for enumeration table */ + Attribute _attribute_type; /* operator attribute enum, used for dynamic casting TosaAttributeBase class */ + TosaAttributeBase* _attribute; /* real attribute class goes here */ + QuantInfo _qinfo_type; /* QuantInfo enum */ + TosaQuantInfoBase* _qinfo; /* base class pointer of QuantInfo */ + std::vector _input_tensor_names; /* array of input tensor names */ + std::vector _output_tensor_names; /* array of output tensor names */ +}; + +class TosaSerializationBasicBlock +{ +public: + // constructor and destructor + TosaSerializationBasicBlock(std::string name, + std::vector operators, + std::vector tensors, + std::vector inputs, + std::vector outputs); + ~TosaSerializationBasicBlock(); + + // accessor + std::string GetName() const + { + return _name; + } + std::vector& GetOperators() + { + return _operators; + } + std::vector& GetTensors() + { + return _tensors; + } + + TosaSerializationTensor* GetTensorByName(std::string name) + { + TosaSerializationTensor* result = nullptr; + for (auto tensor : GetTensors()) + { + if (tensor->GetName() == name) + { + result = tensor; + break; + } + } + return result; + } + + std::vector& GetInputs() + { + return _inputs; + } + std::vector& GetOutputs() + { + return _outputs; + } + +private: + std::string _name; /* name of basic block */ + std::vector _operators; /* TosaSerializationOperator list */ + std::vector _tensors; /* TosaSerializationTensor list */ + std::vector _inputs; /* array of string to specify block inputs */ + std::vector _outputs; /* array of string to specify block outputs */ +}; + +/* + * this is a helper class for writing/reading Tosa ISA + * supported format: .tosa (flatbuffer), .json + * and provide high-level std::vector-like interface + * to access internal data structure + */ +class TosaSerializationHandler +{ +public: + // constructor and destructor + TosaSerializationHandler(); + ~TosaSerializationHandler(); + + // file io + tosa_err_t LoadFileJson(const char* filename); + tosa_err_t LoadFileTosaFlatbuffer(const char* filename); + tosa_err_t SaveFileJson(const char* filename); + tosa_err_t SaveFileTosaFlatbuffer(const char* filename); + tosa_err_t LoadFileSchema(const char* schema_filename); + + // version + const TosaVersion& GetTosaVersion() const + { + return _version; + } + + // accessor + std::vector& GetBlocks() + { + return _blocks; + } + + TosaSerializationBasicBlock* GetBlockByName(std::string name) + { + TosaSerializationBasicBlock* result = nullptr; + for (auto block : GetBlocks()) + { + if (block->GetName() == name) + { + result = block; + break; + } + } + return result; + } + TosaSerializationBasicBlock* GetMainBlock() + { + TosaSerializationBasicBlock* main_block = GetBlockByName(std::string("main")); + assert(main_block); + return main_block; + } + + std::vector& GetInputs() + { + return GetMainBlock()->GetInputs(); + } + std::vector& GetOutputs() + { + return GetMainBlock()->GetOutputs(); + } + + bool GetSchemaLoaded() const + { + return _schemaLoaded; + } + +protected: + tosa_err_t Clear(); + tosa_err_t InitWithBuf(const uint8_t* buf); + tosa_err_t FreezeBuilder(); + tosa_err_t SetTosaVersion(); + tosa_err_t CheckTosaVersion(const TosaVersion& read_version); + +private: + TosaVersion _version; /* tosa version */ + flatbuffers::FlatBufferBuilder _builder; /* flatbuffer builder */ + flatbuffers::Parser _parser; /* flatbuffer parser, used for json parsing */ + std::vector _blocks; /* array structure to store all TosaSerializationBasicBlock */ + bool _schemaLoaded; /* is the schema properly loaded? */ +}; + +} // namespace tosa + +#endif // _TOSA_SERIALIZATION_HANDLER_H diff --git a/python/tosa/ArithmeticRightShiftAttribute.py b/python/tosa/ArithmeticRightShiftAttribute.py new file mode 100644 index 0000000..cd19ab9 --- /dev/null +++ b/python/tosa/ArithmeticRightShiftAttribute.py @@ -0,0 +1,51 @@ +# automatically generated by the FlatBuffers compiler, do not modify + +# Copyright (c) 2020-2021, ARM Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# namespace: tosa + +import flatbuffers +from flatbuffers.compat import import_numpy +np = import_numpy() + +class ArithmeticRightShiftAttribute(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAsArithmeticRightShiftAttribute(cls, buf, offset): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = ArithmeticRightShiftAttribute() + x.Init(buf, n + offset) + return x + + @classmethod + def ArithmeticRightShiftAttributeBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x54\x4F\x53\x41", size_prefixed=size_prefixed) + + # ArithmeticRightShiftAttribute + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + # ArithmeticRightShiftAttribute + def Round(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return bool(self._tab.Get(flatbuffers.number_types.BoolFlags, o + self._tab.Pos)) + return False + +def ArithmeticRightShiftAttributeStart(builder): builder.StartObject(1) +def ArithmeticRightShiftAttributeAddRound(builder, round): builder.PrependBoolSlot(0, round, 0) +def ArithmeticRightShiftAttributeEnd(builder): return builder.EndObject() diff --git a/python/tosa/Attribute.py b/python/tosa/Attribute.py new file mode 100644 index 0000000..e70c0ac --- /dev/null +++ b/python/tosa/Attribute.py @@ -0,0 +1,37 @@ +# automatically generated by the FlatBuffers compiler, do not modify + +# Copyright (c) 2020-2021, ARM Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# namespace: tosa + +class Attribute(object): + NONE = 0 + Pool2dAttribute = 1 + Conv2dAttribute = 2 + TransposeConv2dAttribute = 3 + ReluNAttribute = 4 + AxisAttribute = 5 + ReshapeAttribute = 6 + SliceAttribute = 7 + TileAttribute = 8 + ResizeAttribute = 9 + ClampAttribute = 10 + RescaleAttribute = 11 + MulAttribute = 12 + ArithmeticRightShiftAttribute = 13 + CondIfAttribute = 14 + WhileLoopAttribute = 15 + diff --git a/python/tosa/AxisAttribute.py b/python/tosa/AxisAttribute.py new file mode 100644 index 0000000..6ac0053 --- /dev/null +++ b/python/tosa/AxisAttribute.py @@ -0,0 +1,51 @@ +# automatically generated by the FlatBuffers compiler, do not modify + +# Copyright (c) 2020-2021, ARM Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# namespace: tosa + +import flatbuffers +from flatbuffers.compat import import_numpy +np = import_numpy() + +class AxisAttribute(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAsAxisAttribute(cls, buf, offset): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = AxisAttribute() + x.Init(buf, n + offset) + return x + + @classmethod + def AxisAttributeBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x54\x4F\x53\x41", size_prefixed=size_prefixed) + + # AxisAttribute + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + # AxisAttribute + def Axis(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos) + return 0 + +def AxisAttributeStart(builder): builder.StartObject(1) +def AxisAttributeAddAxis(builder, axis): builder.PrependInt32Slot(0, axis, 0) +def AxisAttributeEnd(builder): return builder.EndObject() diff --git a/python/tosa/ClampAttribute.py b/python/tosa/ClampAttribute.py new file mode 100644 index 0000000..d26ee34 --- /dev/null +++ b/python/tosa/ClampAttribute.py @@ -0,0 +1,75 @@ +# automatically generated by the FlatBuffers compiler, do not modify + +# Copyright (c) 2020-2021, ARM Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# namespace: tosa + +import flatbuffers +from flatbuffers.compat import import_numpy +np = import_numpy() + +class ClampAttribute(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAsClampAttribute(cls, buf, offset): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = ClampAttribute() + x.Init(buf, n + offset) + return x + + @classmethod + def ClampAttributeBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x54\x4F\x53\x41", size_prefixed=size_prefixed) + + # ClampAttribute + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + # ClampAttribute + def MinInt(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos) + return 0 + + # ClampAttribute + def MaxInt(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos) + return 0 + + # ClampAttribute + def MinFp(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Float32Flags, o + self._tab.Pos) + return 0.0 + + # ClampAttribute + def MaxFp(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Float32Flags, o + self._tab.Pos) + return 0.0 + +def ClampAttributeStart(builder): builder.StartObject(4) +def ClampAttributeAddMinInt(builder, minInt): builder.PrependInt32Slot(0, minInt, 0) +def ClampAttributeAddMaxInt(builder, maxInt): builder.PrependInt32Slot(1, maxInt, 0) +def ClampAttributeAddMinFp(builder, minFp): builder.PrependFloat32Slot(2, minFp, 0.0) +def ClampAttributeAddMaxFp(builder, maxFp): builder.PrependFloat32Slot(3, maxFp, 0.0) +def ClampAttributeEnd(builder): return builder.EndObject() diff --git a/python/tosa/CondIfAttribute.py b/python/tosa/CondIfAttribute.py new file mode 100644 index 0000000..d5d61c8 --- /dev/null +++ b/python/tosa/CondIfAttribute.py @@ -0,0 +1,59 @@ +# automatically generated by the FlatBuffers compiler, do not modify + +# Copyright (c) 2020-2021, ARM Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# namespace: tosa + +import flatbuffers +from flatbuffers.compat import import_numpy +np = import_numpy() + +class CondIfAttribute(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAsCondIfAttribute(cls, buf, offset): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = CondIfAttribute() + x.Init(buf, n + offset) + return x + + @classmethod + def CondIfAttributeBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x54\x4F\x53\x41", size_prefixed=size_prefixed) + + # CondIfAttribute + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + # CondIfAttribute + def ThenBranch(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.String(o + self._tab.Pos) + return None + + # CondIfAttribute + def ElseBranch(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + if o != 0: + return self._tab.String(o + self._tab.Pos) + return None + +def CondIfAttributeStart(builder): builder.StartObject(2) +def CondIfAttributeAddThenBranch(builder, thenBranch): builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(thenBranch), 0) +def CondIfAttributeAddElseBranch(builder, elseBranch): builder.PrependUOffsetTRelativeSlot(1, flatbuffers.number_types.UOffsetTFlags.py_type(elseBranch), 0) +def CondIfAttributeEnd(builder): return builder.EndObject() diff --git a/python/tosa/Conv2dAttribute.py b/python/tosa/Conv2dAttribute.py new file mode 100644 index 0000000..c5ae257 --- /dev/null +++ b/python/tosa/Conv2dAttribute.py @@ -0,0 +1,130 @@ +# automatically generated by the FlatBuffers compiler, do not modify + +# Copyright (c) 2020-2021, ARM Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# namespace: tosa + +import flatbuffers +from flatbuffers.compat import import_numpy +np = import_numpy() + +class Conv2dAttribute(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAsConv2dAttribute(cls, buf, offset): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = Conv2dAttribute() + x.Init(buf, n + offset) + return x + + @classmethod + def Conv2dAttributeBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x54\x4F\x53\x41", size_prefixed=size_prefixed) + + # Conv2dAttribute + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + # Conv2dAttribute + def Padding(self, j): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + a = self._tab.Vector(o) + return self._tab.Get(flatbuffers.number_types.Int32Flags, a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4)) + return 0 + + # Conv2dAttribute + def PaddingAsNumpy(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Int32Flags, o) + return 0 + + # Conv2dAttribute + def PaddingLength(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.VectorLen(o) + return 0 + + # Conv2dAttribute + def PaddingIsNone(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + return o == 0 + + # Conv2dAttribute + def Stride(self, j): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + if o != 0: + a = self._tab.Vector(o) + return self._tab.Get(flatbuffers.number_types.Int32Flags, a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4)) + return 0 + + # Conv2dAttribute + def StrideAsNumpy(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + if o != 0: + return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Int32Flags, o) + return 0 + + # Conv2dAttribute + def StrideLength(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + if o != 0: + return self._tab.VectorLen(o) + return 0 + + # Conv2dAttribute + def StrideIsNone(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + return o == 0 + + # Conv2dAttribute + def Dilation(self, j): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) + if o != 0: + a = self._tab.Vector(o) + return self._tab.Get(flatbuffers.number_types.Int32Flags, a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4)) + return 0 + + # Conv2dAttribute + def DilationAsNumpy(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) + if o != 0: + return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Int32Flags, o) + return 0 + + # Conv2dAttribute + def DilationLength(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) + if o != 0: + return self._tab.VectorLen(o) + return 0 + + # Conv2dAttribute + def DilationIsNone(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) + return o == 0 + +def Conv2dAttributeStart(builder): builder.StartObject(3) +def Conv2dAttributeAddPadding(builder, padding): builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(padding), 0) +def Conv2dAttributeStartPaddingVector(builder, numElems): return builder.StartVector(4, numElems, 4) +def Conv2dAttributeAddStride(builder, stride): builder.PrependUOffsetTRelativeSlot(1, flatbuffers.number_types.UOffsetTFlags.py_type(stride), 0) +def Conv2dAttributeStartStrideVector(builder, numElems): return builder.StartVector(4, numElems, 4) +def Conv2dAttributeAddDilation(builder, dilation): builder.PrependUOffsetTRelativeSlot(2, flatbuffers.number_types.UOffsetTFlags.py_type(dilation), 0) +def Conv2dAttributeStartDilationVector(builder, numElems): return builder.StartVector(4, numElems, 4) +def Conv2dAttributeEnd(builder): return builder.EndObject() diff --git a/python/tosa/ConvQuantInfo.py b/python/tosa/ConvQuantInfo.py new file mode 100644 index 0000000..85651ee --- /dev/null +++ b/python/tosa/ConvQuantInfo.py @@ -0,0 +1,59 @@ +# automatically generated by the FlatBuffers compiler, do not modify + +# Copyright (c) 2020-2021, ARM Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# namespace: tosa + +import flatbuffers +from flatbuffers.compat import import_numpy +np = import_numpy() + +class ConvQuantInfo(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAsConvQuantInfo(cls, buf, offset): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = ConvQuantInfo() + x.Init(buf, n + offset) + return x + + @classmethod + def ConvQuantInfoBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x54\x4F\x53\x41", size_prefixed=size_prefixed) + + # ConvQuantInfo + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + # ConvQuantInfo + def InputZp(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos) + return 0 + + # ConvQuantInfo + def WeightZp(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos) + return 0 + +def ConvQuantInfoStart(builder): builder.StartObject(2) +def ConvQuantInfoAddInputZp(builder, inputZp): builder.PrependInt32Slot(0, inputZp, 0) +def ConvQuantInfoAddWeightZp(builder, weightZp): builder.PrependInt32Slot(1, weightZp, 0) +def ConvQuantInfoEnd(builder): return builder.EndObject() diff --git a/python/tosa/DType.py b/python/tosa/DType.py new file mode 100644 index 0000000..2e30531 --- /dev/null +++ b/python/tosa/DType.py @@ -0,0 +1,30 @@ +# automatically generated by the FlatBuffers compiler, do not modify + +# Copyright (c) 2020-2021, ARM Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# namespace: tosa + +class DType(object): + UNKNOWN = 0 + BOOL = 1 + UINT8 = 2 + INT4 = 3 + INT8 = 4 + INT16 = 5 + INT32 = 6 + INT48 = 7 + FLOAT = 8 + diff --git a/python/tosa/MatMulQuantInfo.py b/python/tosa/MatMulQuantInfo.py new file mode 100644 index 0000000..da882e0 --- /dev/null +++ b/python/tosa/MatMulQuantInfo.py @@ -0,0 +1,59 @@ +# automatically generated by the FlatBuffers compiler, do not modify + +# Copyright (c) 2020-2021, ARM Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# namespace: tosa + +import flatbuffers +from flatbuffers.compat import import_numpy +np = import_numpy() + +class MatMulQuantInfo(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAsMatMulQuantInfo(cls, buf, offset): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = MatMulQuantInfo() + x.Init(buf, n + offset) + return x + + @classmethod + def MatMulQuantInfoBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x54\x4F\x53\x41", size_prefixed=size_prefixed) + + # MatMulQuantInfo + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + # MatMulQuantInfo + def AZp(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos) + return 0 + + # MatMulQuantInfo + def BZp(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos) + return 0 + +def MatMulQuantInfoStart(builder): builder.StartObject(2) +def MatMulQuantInfoAddAZp(builder, aZp): builder.PrependInt32Slot(0, aZp, 0) +def MatMulQuantInfoAddBZp(builder, bZp): builder.PrependInt32Slot(1, bZp, 0) +def MatMulQuantInfoEnd(builder): return builder.EndObject() diff --git a/python/tosa/MulAttribute.py b/python/tosa/MulAttribute.py new file mode 100644 index 0000000..e123bd1 --- /dev/null +++ b/python/tosa/MulAttribute.py @@ -0,0 +1,51 @@ +# automatically generated by the FlatBuffers compiler, do not modify + +# Copyright (c) 2020-2021, ARM Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# namespace: tosa + +import flatbuffers +from flatbuffers.compat import import_numpy +np = import_numpy() + +class MulAttribute(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAsMulAttribute(cls, buf, offset): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = MulAttribute() + x.Init(buf, n + offset) + return x + + @classmethod + def MulAttributeBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x54\x4F\x53\x41", size_prefixed=size_prefixed) + + # MulAttribute + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + # MulAttribute + def Shift(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos) + return 0 + +def MulAttributeStart(builder): builder.StartObject(1) +def MulAttributeAddShift(builder, shift): builder.PrependInt32Slot(0, shift, 0) +def MulAttributeEnd(builder): return builder.EndObject() diff --git a/python/tosa/Op.py b/python/tosa/Op.py new file mode 100644 index 0000000..584b00e --- /dev/null +++ b/python/tosa/Op.py @@ -0,0 +1,91 @@ +# automatically generated by the FlatBuffers compiler, do not modify + +# Copyright (c) 2020-2021, ARM Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# namespace: tosa + +class Op(object): + UNKNOWN = 0 + ARGMAX = 1 + AVG_POOL2D = 2 + CONV2D = 3 + CONV3D = 4 + DEPTHWISE_CONV2D = 5 + FULLY_CONNECTED = 6 + MATMUL = 7 + MAX_POOL2D = 8 + TRANSPOSE_CONV2D = 9 + CLAMP = 10 + RELUN = 11 + SIGMOID = 12 + TANH = 13 + ADD = 14 + ARITHMETIC_RIGHT_SHIFT = 15 + BITWISE_AND = 16 + BITWISE_OR = 17 + BITWISE_XOR = 18 + LOGICAL_AND = 19 + LOGICAL_LEFT_SHIFT = 20 + LOGICAL_RIGHT_SHIFT = 21 + LOGICAL_OR = 22 + LOGICAL_XOR = 23 + MAXIMUM = 24 + MINIMUM = 25 + MUL = 26 + POW = 27 + SUB = 28 + TABLE = 29 + ABS = 30 + BITWISE_NOT = 31 + CEIL = 32 + CLZ = 33 + EXP = 34 + FLOOR = 35 + LOG = 36 + LOGICAL_NOT = 37 + NEGATE = 38 + RECIPROCAL = 39 + RSQRT = 40 + SELECT = 41 + EQUAL = 42 + GREATER = 43 + GREATER_EQUAL = 44 + REDUCE_ANY = 45 + REDUCE_ALL = 46 + REDUCE_MAX = 47 + REDUCE_MIN = 48 + REDUCE_PRODUCT = 49 + REDUCE_SUM = 50 + CONCAT = 51 + PAD = 52 + RESHAPE = 53 + REVERSE = 54 + SLICE = 55 + TILE = 56 + TRANSPOSE = 57 + GATHER = 58 + SCATTER = 59 + RESIZE = 60 + CAST = 61 + RESCALE = 62 + CONST = 63 + PLACEHOLDER = 64 + IDENTITY = 65 + IDENTITYN = 66 + CUSTOM = 67 + COND_IF = 68 + WHILE_LOOP = 69 + diff --git a/python/tosa/PadQuantInfo.py b/python/tosa/PadQuantInfo.py new file mode 100644 index 0000000..241d1e6 --- /dev/null +++ b/python/tosa/PadQuantInfo.py @@ -0,0 +1,51 @@ +# automatically generated by the FlatBuffers compiler, do not modify + +# Copyright (c) 2020-2021, ARM Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# namespace: tosa + +import flatbuffers +from flatbuffers.compat import import_numpy +np = import_numpy() + +class PadQuantInfo(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAsPadQuantInfo(cls, buf, offset): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = PadQuantInfo() + x.Init(buf, n + offset) + return x + + @classmethod + def PadQuantInfoBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x54\x4F\x53\x41", size_prefixed=size_prefixed) + + # PadQuantInfo + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + # PadQuantInfo + def InputZp(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos) + return 0 + +def PadQuantInfoStart(builder): builder.StartObject(1) +def PadQuantInfoAddInputZp(builder, inputZp): builder.PrependInt32Slot(0, inputZp, 0) +def PadQuantInfoEnd(builder): return builder.EndObject() diff --git a/python/tosa/Pool2dAttribute.py b/python/tosa/Pool2dAttribute.py new file mode 100644 index 0000000..72c09b4 --- /dev/null +++ b/python/tosa/Pool2dAttribute.py @@ -0,0 +1,130 @@ +# automatically generated by the FlatBuffers compiler, do not modify + +# Copyright (c) 2020-2021, ARM Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# namespace: tosa + +import flatbuffers +from flatbuffers.compat import import_numpy +np = import_numpy() + +class Pool2dAttribute(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAsPool2dAttribute(cls, buf, offset): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = Pool2dAttribute() + x.Init(buf, n + offset) + return x + + @classmethod + def Pool2dAttributeBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x54\x4F\x53\x41", size_prefixed=size_prefixed) + + # Pool2dAttribute + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + # Pool2dAttribute + def Padding(self, j): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + a = self._tab.Vector(o) + return self._tab.Get(flatbuffers.number_types.Int32Flags, a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4)) + return 0 + + # Pool2dAttribute + def PaddingAsNumpy(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Int32Flags, o) + return 0 + + # Pool2dAttribute + def PaddingLength(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.VectorLen(o) + return 0 + + # Pool2dAttribute + def PaddingIsNone(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + return o == 0 + + # Pool2dAttribute + def Kernel(self, j): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + if o != 0: + a = self._tab.Vector(o) + return self._tab.Get(flatbuffers.number_types.Int32Flags, a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4)) + return 0 + + # Pool2dAttribute + def KernelAsNumpy(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + if o != 0: + return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Int32Flags, o) + return 0 + + # Pool2dAttribute + def KernelLength(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + if o != 0: + return self._tab.VectorLen(o) + return 0 + + # Pool2dAttribute + def KernelIsNone(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + return o == 0 + + # Pool2dAttribute + def Stride(self, j): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) + if o != 0: + a = self._tab.Vector(o) + return self._tab.Get(flatbuffers.number_types.Int32Flags, a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4)) + return 0 + + # Pool2dAttribute + def StrideAsNumpy(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) + if o != 0: + return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Int32Flags, o) + return 0 + + # Pool2dAttribute + def StrideLength(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) + if o != 0: + return self._tab.VectorLen(o) + return 0 + + # Pool2dAttribute + def StrideIsNone(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) + return o == 0 + +def Pool2dAttributeStart(builder): builder.StartObject(3) +def Pool2dAttributeAddPadding(builder, padding): builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(padding), 0) +def Pool2dAttributeStartPaddingVector(builder, numElems): return builder.StartVector(4, numElems, 4) +def Pool2dAttributeAddKernel(builder, kernel): builder.PrependUOffsetTRelativeSlot(1, flatbuffers.number_types.UOffsetTFlags.py_type(kernel), 0) +def Pool2dAttributeStartKernelVector(builder, numElems): return builder.StartVector(4, numElems, 4) +def Pool2dAttributeAddStride(builder, stride): builder.PrependUOffsetTRelativeSlot(2, flatbuffers.number_types.UOffsetTFlags.py_type(stride), 0) +def Pool2dAttributeStartStrideVector(builder, numElems): return builder.StartVector(4, numElems, 4) +def Pool2dAttributeEnd(builder): return builder.EndObject() diff --git a/python/tosa/QuantInfo.py b/python/tosa/QuantInfo.py new file mode 100644 index 0000000..0d77912 --- /dev/null +++ b/python/tosa/QuantInfo.py @@ -0,0 +1,26 @@ +# automatically generated by the FlatBuffers compiler, do not modify + +# Copyright (c) 2020-2021, ARM Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# namespace: tosa + +class QuantInfo(object): + NONE = 0 + UnaryQuantInfo = 1 + ConvQuantInfo = 2 + MatMulQuantInfo = 3 + PadQuantInfo = 4 + diff --git a/python/tosa/ReluNAttribute.py b/python/tosa/ReluNAttribute.py new file mode 100644 index 0000000..b96d701 --- /dev/null +++ b/python/tosa/ReluNAttribute.py @@ -0,0 +1,59 @@ +# automatically generated by the FlatBuffers compiler, do not modify + +# Copyright (c) 2020-2021, ARM Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# namespace: tosa + +import flatbuffers +from flatbuffers.compat import import_numpy +np = import_numpy() + +class ReluNAttribute(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAsReluNAttribute(cls, buf, offset): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = ReluNAttribute() + x.Init(buf, n + offset) + return x + + @classmethod + def ReluNAttributeBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x54\x4F\x53\x41", size_prefixed=size_prefixed) + + # ReluNAttribute + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + # ReluNAttribute + def MaxInt(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos) + return 0 + + # ReluNAttribute + def MaxFp(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Float32Flags, o + self._tab.Pos) + return 0.0 + +def ReluNAttributeStart(builder): builder.StartObject(2) +def ReluNAttributeAddMaxInt(builder, maxInt): builder.PrependInt32Slot(0, maxInt, 0) +def ReluNAttributeAddMaxFp(builder, maxFp): builder.PrependFloat32Slot(1, maxFp, 0.0) +def ReluNAttributeEnd(builder): return builder.EndObject() diff --git a/python/tosa/RescaleAttribute.py b/python/tosa/RescaleAttribute.py new file mode 100644 index 0000000..238f971 --- /dev/null +++ b/python/tosa/RescaleAttribute.py @@ -0,0 +1,141 @@ +# automatically generated by the FlatBuffers compiler, do not modify + +# Copyright (c) 2020-2021, ARM Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# namespace: tosa + +import flatbuffers +from flatbuffers.compat import import_numpy +np = import_numpy() + +class RescaleAttribute(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAsRescaleAttribute(cls, buf, offset): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = RescaleAttribute() + x.Init(buf, n + offset) + return x + + @classmethod + def RescaleAttributeBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x54\x4F\x53\x41", size_prefixed=size_prefixed) + + # RescaleAttribute + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + # RescaleAttribute + def InputZp(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos) + return 0 + + # RescaleAttribute + def OutputZp(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos) + return 0 + + # RescaleAttribute + def Multiplier(self, j): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) + if o != 0: + a = self._tab.Vector(o) + return self._tab.Get(flatbuffers.number_types.Int32Flags, a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4)) + return 0 + + # RescaleAttribute + def MultiplierAsNumpy(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) + if o != 0: + return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Int32Flags, o) + return 0 + + # RescaleAttribute + def MultiplierLength(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) + if o != 0: + return self._tab.VectorLen(o) + return 0 + + # RescaleAttribute + def MultiplierIsNone(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) + return o == 0 + + # RescaleAttribute + def Shift(self, j): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10)) + if o != 0: + a = self._tab.Vector(o) + return self._tab.Get(flatbuffers.number_types.Int32Flags, a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4)) + return 0 + + # RescaleAttribute + def ShiftAsNumpy(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10)) + if o != 0: + return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Int32Flags, o) + return 0 + + # RescaleAttribute + def ShiftLength(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10)) + if o != 0: + return self._tab.VectorLen(o) + return 0 + + # RescaleAttribute + def ShiftIsNone(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10)) + return o == 0 + + # RescaleAttribute + def Scale32(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(12)) + if o != 0: + return bool(self._tab.Get(flatbuffers.number_types.BoolFlags, o + self._tab.Pos)) + return False + + # RescaleAttribute + def DoubleRound(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(14)) + if o != 0: + return bool(self._tab.Get(flatbuffers.number_types.BoolFlags, o + self._tab.Pos)) + return False + + # RescaleAttribute + def PerChannel(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(16)) + if o != 0: + return bool(self._tab.Get(flatbuffers.number_types.BoolFlags, o + self._tab.Pos)) + return False + +def RescaleAttributeStart(builder): builder.StartObject(7) +def RescaleAttributeAddInputZp(builder, inputZp): builder.PrependInt32Slot(0, inputZp, 0) +def RescaleAttributeAddOutputZp(builder, outputZp): builder.PrependInt32Slot(1, outputZp, 0) +def RescaleAttributeAddMultiplier(builder, multiplier): builder.PrependUOffsetTRelativeSlot(2, flatbuffers.number_types.UOffsetTFlags.py_type(multiplier), 0) +def RescaleAttributeStartMultiplierVector(builder, numElems): return builder.StartVector(4, numElems, 4) +def RescaleAttributeAddShift(builder, shift): builder.PrependUOffsetTRelativeSlot(3, flatbuffers.number_types.UOffsetTFlags.py_type(shift), 0) +def RescaleAttributeStartShiftVector(builder, numElems): return builder.StartVector(4, numElems, 4) +def RescaleAttributeAddScale32(builder, scale32): builder.PrependBoolSlot(4, scale32, 0) +def RescaleAttributeAddDoubleRound(builder, doubleRound): builder.PrependBoolSlot(5, doubleRound, 0) +def RescaleAttributeAddPerChannel(builder, perChannel): builder.PrependBoolSlot(6, perChannel, 0) +def RescaleAttributeEnd(builder): return builder.EndObject() diff --git a/python/tosa/ReshapeAttribute.py b/python/tosa/ReshapeAttribute.py new file mode 100644 index 0000000..7836f49 --- /dev/null +++ b/python/tosa/ReshapeAttribute.py @@ -0,0 +1,72 @@ +# automatically generated by the FlatBuffers compiler, do not modify + +# Copyright (c) 2020-2021, ARM Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# namespace: tosa + +import flatbuffers +from flatbuffers.compat import import_numpy +np = import_numpy() + +class ReshapeAttribute(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAsReshapeAttribute(cls, buf, offset): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = ReshapeAttribute() + x.Init(buf, n + offset) + return x + + @classmethod + def ReshapeAttributeBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x54\x4F\x53\x41", size_prefixed=size_prefixed) + + # ReshapeAttribute + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + # ReshapeAttribute + def Shape(self, j): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + a = self._tab.Vector(o) + return self._tab.Get(flatbuffers.number_types.Int32Flags, a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4)) + return 0 + + # ReshapeAttribute + def ShapeAsNumpy(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Int32Flags, o) + return 0 + + # ReshapeAttribute + def ShapeLength(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.VectorLen(o) + return 0 + + # ReshapeAttribute + def ShapeIsNone(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + return o == 0 + +def ReshapeAttributeStart(builder): builder.StartObject(1) +def ReshapeAttributeAddShape(builder, shape): builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(shape), 0) +def ReshapeAttributeStartShapeVector(builder, numElems): return builder.StartVector(4, numElems, 4) +def ReshapeAttributeEnd(builder): return builder.EndObject() diff --git a/python/tosa/ResizeAttribute.py b/python/tosa/ResizeAttribute.py new file mode 100644 index 0000000..6e5f259 --- /dev/null +++ b/python/tosa/ResizeAttribute.py @@ -0,0 +1,204 @@ +# automatically generated by the FlatBuffers compiler, do not modify + +# Copyright (c) 2020-2021, ARM Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# namespace: tosa + +import flatbuffers +from flatbuffers.compat import import_numpy +np = import_numpy() + +class ResizeAttribute(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAsResizeAttribute(cls, buf, offset): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = ResizeAttribute() + x.Init(buf, n + offset) + return x + + @classmethod + def ResizeAttributeBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x54\x4F\x53\x41", size_prefixed=size_prefixed) + + # ResizeAttribute + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + # ResizeAttribute + def OutputSize(self, j): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + a = self._tab.Vector(o) + return self._tab.Get(flatbuffers.number_types.Int32Flags, a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4)) + return 0 + + # ResizeAttribute + def OutputSizeAsNumpy(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Int32Flags, o) + return 0 + + # ResizeAttribute + def OutputSizeLength(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.VectorLen(o) + return 0 + + # ResizeAttribute + def OutputSizeIsNone(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + return o == 0 + + # ResizeAttribute + def Stride(self, j): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + if o != 0: + a = self._tab.Vector(o) + return self._tab.Get(flatbuffers.number_types.Int32Flags, a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4)) + return 0 + + # ResizeAttribute + def StrideAsNumpy(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + if o != 0: + return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Int32Flags, o) + return 0 + + # ResizeAttribute + def StrideLength(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + if o != 0: + return self._tab.VectorLen(o) + return 0 + + # ResizeAttribute + def StrideIsNone(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + return o == 0 + + # ResizeAttribute + def Offset(self, j): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) + if o != 0: + a = self._tab.Vector(o) + return self._tab.Get(flatbuffers.number_types.Int32Flags, a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4)) + return 0 + + # ResizeAttribute + def OffsetAsNumpy(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) + if o != 0: + return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Int32Flags, o) + return 0 + + # ResizeAttribute + def OffsetLength(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) + if o != 0: + return self._tab.VectorLen(o) + return 0 + + # ResizeAttribute + def OffsetIsNone(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) + return o == 0 + + # ResizeAttribute + def Shift(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos) + return 0 + + # ResizeAttribute + def StrideFp(self, j): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(12)) + if o != 0: + a = self._tab.Vector(o) + return self._tab.Get(flatbuffers.number_types.Float32Flags, a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4)) + return 0 + + # ResizeAttribute + def StrideFpAsNumpy(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(12)) + if o != 0: + return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Float32Flags, o) + return 0 + + # ResizeAttribute + def StrideFpLength(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(12)) + if o != 0: + return self._tab.VectorLen(o) + return 0 + + # ResizeAttribute + def StrideFpIsNone(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(12)) + return o == 0 + + # ResizeAttribute + def OffsetFp(self, j): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(14)) + if o != 0: + a = self._tab.Vector(o) + return self._tab.Get(flatbuffers.number_types.Float32Flags, a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4)) + return 0 + + # ResizeAttribute + def OffsetFpAsNumpy(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(14)) + if o != 0: + return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Float32Flags, o) + return 0 + + # ResizeAttribute + def OffsetFpLength(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(14)) + if o != 0: + return self._tab.VectorLen(o) + return 0 + + # ResizeAttribute + def OffsetFpIsNone(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(14)) + return o == 0 + + # ResizeAttribute + def Mode(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(16)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Uint32Flags, o + self._tab.Pos) + return 0 + +def ResizeAttributeStart(builder): builder.StartObject(7) +def ResizeAttributeAddOutputSize(builder, outputSize): builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(outputSize), 0) +def ResizeAttributeStartOutputSizeVector(builder, numElems): return builder.StartVector(4, numElems, 4) +def ResizeAttributeAddStride(builder, stride): builder.PrependUOffsetTRelativeSlot(1, flatbuffers.number_types.UOffsetTFlags.py_type(stride), 0) +def ResizeAttributeStartStrideVector(builder, numElems): return builder.StartVector(4, numElems, 4) +def ResizeAttributeAddOffset(builder, offset): builder.PrependUOffsetTRelativeSlot(2, flatbuffers.number_types.UOffsetTFlags.py_type(offset), 0) +def ResizeAttributeStartOffsetVector(builder, numElems): return builder.StartVector(4, numElems, 4) +def ResizeAttributeAddShift(builder, shift): builder.PrependInt32Slot(3, shift, 0) +def ResizeAttributeAddStrideFp(builder, strideFp): builder.PrependUOffsetTRelativeSlot(4, flatbuffers.number_types.UOffsetTFlags.py_type(strideFp), 0) +def ResizeAttributeStartStrideFpVector(builder, numElems): return builder.StartVector(4, numElems, 4) +def ResizeAttributeAddOffsetFp(builder, offsetFp): builder.PrependUOffsetTRelativeSlot(5, flatbuffers.number_types.UOffsetTFlags.py_type(offsetFp), 0) +def ResizeAttributeStartOffsetFpVector(builder, numElems): return builder.StartVector(4, numElems, 4) +def ResizeAttributeAddMode(builder, mode): builder.PrependUint32Slot(6, mode, 0) +def ResizeAttributeEnd(builder): return builder.EndObject() diff --git a/python/tosa/ResizeMode.py b/python/tosa/ResizeMode.py new file mode 100644 index 0000000..121fb0d --- /dev/null +++ b/python/tosa/ResizeMode.py @@ -0,0 +1,24 @@ +# automatically generated by the FlatBuffers compiler, do not modify + +# Copyright (c) 2020-2021, ARM Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# namespace: tosa + +class ResizeMode(object): + UNKNOWN = 0 + NEAREST = 1 + BILINEAR = 2 + diff --git a/python/tosa/SliceAttribute.py b/python/tosa/SliceAttribute.py new file mode 100644 index 0000000..d3bf491 --- /dev/null +++ b/python/tosa/SliceAttribute.py @@ -0,0 +1,101 @@ +# automatically generated by the FlatBuffers compiler, do not modify + +# Copyright (c) 2020-2021, ARM Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# namespace: tosa + +import flatbuffers +from flatbuffers.compat import import_numpy +np = import_numpy() + +class SliceAttribute(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAsSliceAttribute(cls, buf, offset): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = SliceAttribute() + x.Init(buf, n + offset) + return x + + @classmethod + def SliceAttributeBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x54\x4F\x53\x41", size_prefixed=size_prefixed) + + # SliceAttribute + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + # SliceAttribute + def Begin(self, j): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + a = self._tab.Vector(o) + return self._tab.Get(flatbuffers.number_types.Int32Flags, a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4)) + return 0 + + # SliceAttribute + def BeginAsNumpy(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Int32Flags, o) + return 0 + + # SliceAttribute + def BeginLength(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.VectorLen(o) + return 0 + + # SliceAttribute + def BeginIsNone(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + return o == 0 + + # SliceAttribute + def Size(self, j): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + if o != 0: + a = self._tab.Vector(o) + return self._tab.Get(flatbuffers.number_types.Int32Flags, a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4)) + return 0 + + # SliceAttribute + def SizeAsNumpy(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + if o != 0: + return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Int32Flags, o) + return 0 + + # SliceAttribute + def SizeLength(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + if o != 0: + return self._tab.VectorLen(o) + return 0 + + # SliceAttribute + def SizeIsNone(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + return o == 0 + +def SliceAttributeStart(builder): builder.StartObject(2) +def SliceAttributeAddBegin(builder, begin): builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(begin), 0) +def SliceAttributeStartBeginVector(builder, numElems): return builder.StartVector(4, numElems, 4) +def SliceAttributeAddSize(builder, size): builder.PrependUOffsetTRelativeSlot(1, flatbuffers.number_types.UOffsetTFlags.py_type(size), 0) +def SliceAttributeStartSizeVector(builder, numElems): return builder.StartVector(4, numElems, 4) +def SliceAttributeEnd(builder): return builder.EndObject() diff --git a/python/tosa/TileAttribute.py b/python/tosa/TileAttribute.py new file mode 100644 index 0000000..5b4a02d --- /dev/null +++ b/python/tosa/TileAttribute.py @@ -0,0 +1,72 @@ +# automatically generated by the FlatBuffers compiler, do not modify + +# Copyright (c) 2020-2021, ARM Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# namespace: tosa + +import flatbuffers +from flatbuffers.compat import import_numpy +np = import_numpy() + +class TileAttribute(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAsTileAttribute(cls, buf, offset): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = TileAttribute() + x.Init(buf, n + offset) + return x + + @classmethod + def TileAttributeBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x54\x4F\x53\x41", size_prefixed=size_prefixed) + + # TileAttribute + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + # TileAttribute + def Multiples(self, j): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + a = self._tab.Vector(o) + return self._tab.Get(flatbuffers.number_types.Int32Flags, a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4)) + return 0 + + # TileAttribute + def MultiplesAsNumpy(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Int32Flags, o) + return 0 + + # TileAttribute + def MultiplesLength(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.VectorLen(o) + return 0 + + # TileAttribute + def MultiplesIsNone(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + return o == 0 + +def TileAttributeStart(builder): builder.StartObject(1) +def TileAttributeAddMultiples(builder, multiples): builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(multiples), 0) +def TileAttributeStartMultiplesVector(builder, numElems): return builder.StartVector(4, numElems, 4) +def TileAttributeEnd(builder): return builder.EndObject() diff --git a/python/tosa/TosaBasicBlock.py b/python/tosa/TosaBasicBlock.py new file mode 100644 index 0000000..324f33f --- /dev/null +++ b/python/tosa/TosaBasicBlock.py @@ -0,0 +1,149 @@ +# automatically generated by the FlatBuffers compiler, do not modify + +# Copyright (c) 2020-2021, ARM Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# namespace: tosa + +import flatbuffers +from flatbuffers.compat import import_numpy +np = import_numpy() + +class TosaBasicBlock(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAsTosaBasicBlock(cls, buf, offset): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = TosaBasicBlock() + x.Init(buf, n + offset) + return x + + @classmethod + def TosaBasicBlockBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x54\x4F\x53\x41", size_prefixed=size_prefixed) + + # TosaBasicBlock + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + # TosaBasicBlock + def Name(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.String(o + self._tab.Pos) + return None + + # TosaBasicBlock + def Operators(self, j): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + if o != 0: + x = self._tab.Vector(o) + x += flatbuffers.number_types.UOffsetTFlags.py_type(j) * 4 + x = self._tab.Indirect(x) + from tosa.TosaOperator import TosaOperator + obj = TosaOperator() + obj.Init(self._tab.Bytes, x) + return obj + return None + + # TosaBasicBlock + def OperatorsLength(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + if o != 0: + return self._tab.VectorLen(o) + return 0 + + # TosaBasicBlock + def OperatorsIsNone(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + return o == 0 + + # TosaBasicBlock + def Tensors(self, j): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) + if o != 0: + x = self._tab.Vector(o) + x += flatbuffers.number_types.UOffsetTFlags.py_type(j) * 4 + x = self._tab.Indirect(x) + from tosa.TosaTensor import TosaTensor + obj = TosaTensor() + obj.Init(self._tab.Bytes, x) + return obj + return None + + # TosaBasicBlock + def TensorsLength(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) + if o != 0: + return self._tab.VectorLen(o) + return 0 + + # TosaBasicBlock + def TensorsIsNone(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) + return o == 0 + + # TosaBasicBlock + def Inputs(self, j): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10)) + if o != 0: + a = self._tab.Vector(o) + return self._tab.String(a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4)) + return "" + + # TosaBasicBlock + def InputsLength(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10)) + if o != 0: + return self._tab.VectorLen(o) + return 0 + + # TosaBasicBlock + def InputsIsNone(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10)) + return o == 0 + + # TosaBasicBlock + def Outputs(self, j): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(12)) + if o != 0: + a = self._tab.Vector(o) + return self._tab.String(a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4)) + return "" + + # TosaBasicBlock + def OutputsLength(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(12)) + if o != 0: + return self._tab.VectorLen(o) + return 0 + + # TosaBasicBlock + def OutputsIsNone(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(12)) + return o == 0 + +def TosaBasicBlockStart(builder): builder.StartObject(5) +def TosaBasicBlockAddName(builder, name): builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(name), 0) +def TosaBasicBlockAddOperators(builder, operators): builder.PrependUOffsetTRelativeSlot(1, flatbuffers.number_types.UOffsetTFlags.py_type(operators), 0) +def TosaBasicBlockStartOperatorsVector(builder, numElems): return builder.StartVector(4, numElems, 4) +def TosaBasicBlockAddTensors(builder, tensors): builder.PrependUOffsetTRelativeSlot(2, flatbuffers.number_types.UOffsetTFlags.py_type(tensors), 0) +def TosaBasicBlockStartTensorsVector(builder, numElems): return builder.StartVector(4, numElems, 4) +def TosaBasicBlockAddInputs(builder, inputs): builder.PrependUOffsetTRelativeSlot(3, flatbuffers.number_types.UOffsetTFlags.py_type(inputs), 0) +def TosaBasicBlockStartInputsVector(builder, numElems): return builder.StartVector(4, numElems, 4) +def TosaBasicBlockAddOutputs(builder, outputs): builder.PrependUOffsetTRelativeSlot(4, flatbuffers.number_types.UOffsetTFlags.py_type(outputs), 0) +def TosaBasicBlockStartOutputsVector(builder, numElems): return builder.StartVector(4, numElems, 4) +def TosaBasicBlockEnd(builder): return builder.EndObject() diff --git a/python/tosa/TosaGraph.py b/python/tosa/TosaGraph.py new file mode 100644 index 0000000..42e2702 --- /dev/null +++ b/python/tosa/TosaGraph.py @@ -0,0 +1,82 @@ +# automatically generated by the FlatBuffers compiler, do not modify + +# Copyright (c) 2020-2021, ARM Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# namespace: tosa + +import flatbuffers +from flatbuffers.compat import import_numpy +np = import_numpy() + +class TosaGraph(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAsTosaGraph(cls, buf, offset): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = TosaGraph() + x.Init(buf, n + offset) + return x + + @classmethod + def TosaGraphBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x54\x4F\x53\x41", size_prefixed=size_prefixed) + + # TosaGraph + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + # TosaGraph + def Version(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + x = self._tab.Indirect(o + self._tab.Pos) + from tosa.Version import Version + obj = Version() + obj.Init(self._tab.Bytes, x) + return obj + return None + + # TosaGraph + def Blocks(self, j): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + if o != 0: + x = self._tab.Vector(o) + x += flatbuffers.number_types.UOffsetTFlags.py_type(j) * 4 + x = self._tab.Indirect(x) + from tosa.TosaBasicBlock import TosaBasicBlock + obj = TosaBasicBlock() + obj.Init(self._tab.Bytes, x) + return obj + return None + + # TosaGraph + def BlocksLength(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + if o != 0: + return self._tab.VectorLen(o) + return 0 + + # TosaGraph + def BlocksIsNone(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + return o == 0 + +def TosaGraphStart(builder): builder.StartObject(2) +def TosaGraphAddVersion(builder, version): builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(version), 0) +def TosaGraphAddBlocks(builder, blocks): builder.PrependUOffsetTRelativeSlot(1, flatbuffers.number_types.UOffsetTFlags.py_type(blocks), 0) +def TosaGraphStartBlocksVector(builder, numElems): return builder.StartVector(4, numElems, 4) +def TosaGraphEnd(builder): return builder.EndObject() diff --git a/python/tosa/TosaOperator.py b/python/tosa/TosaOperator.py new file mode 100644 index 0000000..998357f --- /dev/null +++ b/python/tosa/TosaOperator.py @@ -0,0 +1,133 @@ +# automatically generated by the FlatBuffers compiler, do not modify + +# Copyright (c) 2020-2021, ARM Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# namespace: tosa + +import flatbuffers +from flatbuffers.compat import import_numpy +np = import_numpy() + +class TosaOperator(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAsTosaOperator(cls, buf, offset): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = TosaOperator() + x.Init(buf, n + offset) + return x + + @classmethod + def TosaOperatorBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x54\x4F\x53\x41", size_prefixed=size_prefixed) + + # TosaOperator + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + # TosaOperator + def Op(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Uint32Flags, o + self._tab.Pos) + return 0 + + # TosaOperator + def AttributeType(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Uint8Flags, o + self._tab.Pos) + return 0 + + # TosaOperator + def Attribute(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) + if o != 0: + from flatbuffers.table import Table + obj = Table(bytearray(), 0) + self._tab.Union(obj, o) + return obj + return None + + # TosaOperator + def Inputs(self, j): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10)) + if o != 0: + a = self._tab.Vector(o) + return self._tab.String(a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4)) + return "" + + # TosaOperator + def InputsLength(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10)) + if o != 0: + return self._tab.VectorLen(o) + return 0 + + # TosaOperator + def InputsIsNone(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10)) + return o == 0 + + # TosaOperator + def Outputs(self, j): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(12)) + if o != 0: + a = self._tab.Vector(o) + return self._tab.String(a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4)) + return "" + + # TosaOperator + def OutputsLength(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(12)) + if o != 0: + return self._tab.VectorLen(o) + return 0 + + # TosaOperator + def OutputsIsNone(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(12)) + return o == 0 + + # TosaOperator + def QuantInfoType(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(14)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Uint8Flags, o + self._tab.Pos) + return 0 + + # TosaOperator + def QuantInfo(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(16)) + if o != 0: + from flatbuffers.table import Table + obj = Table(bytearray(), 0) + self._tab.Union(obj, o) + return obj + return None + +def TosaOperatorStart(builder): builder.StartObject(7) +def TosaOperatorAddOp(builder, op): builder.PrependUint32Slot(0, op, 0) +def TosaOperatorAddAttributeType(builder, attributeType): builder.PrependUint8Slot(1, attributeType, 0) +def TosaOperatorAddAttribute(builder, attribute): builder.PrependUOffsetTRelativeSlot(2, flatbuffers.number_types.UOffsetTFlags.py_type(attribute), 0) +def TosaOperatorAddInputs(builder, inputs): builder.PrependUOffsetTRelativeSlot(3, flatbuffers.number_types.UOffsetTFlags.py_type(inputs), 0) +def TosaOperatorStartInputsVector(builder, numElems): return builder.StartVector(4, numElems, 4) +def TosaOperatorAddOutputs(builder, outputs): builder.PrependUOffsetTRelativeSlot(4, flatbuffers.number_types.UOffsetTFlags.py_type(outputs), 0) +def TosaOperatorStartOutputsVector(builder, numElems): return builder.StartVector(4, numElems, 4) +def TosaOperatorAddQuantInfoType(builder, quantInfoType): builder.PrependUint8Slot(5, quantInfoType, 0) +def TosaOperatorAddQuantInfo(builder, quantInfo): builder.PrependUOffsetTRelativeSlot(6, flatbuffers.number_types.UOffsetTFlags.py_type(quantInfo), 0) +def TosaOperatorEnd(builder): return builder.EndObject() diff --git a/python/tosa/TosaTensor.py b/python/tosa/TosaTensor.py new file mode 100644 index 0000000..760c091 --- /dev/null +++ b/python/tosa/TosaTensor.py @@ -0,0 +1,96 @@ +# automatically generated by the FlatBuffers compiler, do not modify + +# Copyright (c) 2020-2021, ARM Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# namespace: tosa + +import flatbuffers +from flatbuffers.compat import import_numpy +np = import_numpy() + +class TosaTensor(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAsTosaTensor(cls, buf, offset): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = TosaTensor() + x.Init(buf, n + offset) + return x + + @classmethod + def TosaTensorBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x54\x4F\x53\x41", size_prefixed=size_prefixed) + + # TosaTensor + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + # TosaTensor + def Name(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.String(o + self._tab.Pos) + return None + + # TosaTensor + def Shape(self, j): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + if o != 0: + a = self._tab.Vector(o) + return self._tab.Get(flatbuffers.number_types.Int32Flags, a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4)) + return 0 + + # TosaTensor + def ShapeAsNumpy(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + if o != 0: + return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Int32Flags, o) + return 0 + + # TosaTensor + def ShapeLength(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + if o != 0: + return self._tab.VectorLen(o) + return 0 + + # TosaTensor + def ShapeIsNone(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + return o == 0 + + # TosaTensor + def Type(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Uint32Flags, o + self._tab.Pos) + return 0 + + # TosaTensor + def NpyFilename(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10)) + if o != 0: + return self._tab.String(o + self._tab.Pos) + return None + +def TosaTensorStart(builder): builder.StartObject(4) +def TosaTensorAddName(builder, name): builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(name), 0) +def TosaTensorAddShape(builder, shape): builder.PrependUOffsetTRelativeSlot(1, flatbuffers.number_types.UOffsetTFlags.py_type(shape), 0) +def TosaTensorStartShapeVector(builder, numElems): return builder.StartVector(4, numElems, 4) +def TosaTensorAddType(builder, type): builder.PrependUint32Slot(2, type, 0) +def TosaTensorAddNpyFilename(builder, npyFilename): builder.PrependUOffsetTRelativeSlot(3, flatbuffers.number_types.UOffsetTFlags.py_type(npyFilename), 0) +def TosaTensorEnd(builder): return builder.EndObject() diff --git a/python/tosa/TransposeConv2dAttribute.py b/python/tosa/TransposeConv2dAttribute.py new file mode 100644 index 0000000..02af44d --- /dev/null +++ b/python/tosa/TransposeConv2dAttribute.py @@ -0,0 +1,159 @@ +# automatically generated by the FlatBuffers compiler, do not modify + +# Copyright (c) 2020-2021, ARM Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# namespace: tosa + +import flatbuffers +from flatbuffers.compat import import_numpy +np = import_numpy() + +class TransposeConv2dAttribute(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAsTransposeConv2dAttribute(cls, buf, offset): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = TransposeConv2dAttribute() + x.Init(buf, n + offset) + return x + + @classmethod + def TransposeConv2dAttributeBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x54\x4F\x53\x41", size_prefixed=size_prefixed) + + # TransposeConv2dAttribute + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + # TransposeConv2dAttribute + def Outpad(self, j): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + a = self._tab.Vector(o) + return self._tab.Get(flatbuffers.number_types.Int32Flags, a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4)) + return 0 + + # TransposeConv2dAttribute + def OutpadAsNumpy(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Int32Flags, o) + return 0 + + # TransposeConv2dAttribute + def OutpadLength(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.VectorLen(o) + return 0 + + # TransposeConv2dAttribute + def OutpadIsNone(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + return o == 0 + + # TransposeConv2dAttribute + def Stride(self, j): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + if o != 0: + a = self._tab.Vector(o) + return self._tab.Get(flatbuffers.number_types.Int32Flags, a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4)) + return 0 + + # TransposeConv2dAttribute + def StrideAsNumpy(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + if o != 0: + return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Int32Flags, o) + return 0 + + # TransposeConv2dAttribute + def StrideLength(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + if o != 0: + return self._tab.VectorLen(o) + return 0 + + # TransposeConv2dAttribute + def StrideIsNone(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + return o == 0 + + # TransposeConv2dAttribute + def Dilation(self, j): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) + if o != 0: + a = self._tab.Vector(o) + return self._tab.Get(flatbuffers.number_types.Int32Flags, a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4)) + return 0 + + # TransposeConv2dAttribute + def DilationAsNumpy(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) + if o != 0: + return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Int32Flags, o) + return 0 + + # TransposeConv2dAttribute + def DilationLength(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) + if o != 0: + return self._tab.VectorLen(o) + return 0 + + # TransposeConv2dAttribute + def DilationIsNone(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) + return o == 0 + + # TransposeConv2dAttribute + def OutputShape(self, j): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10)) + if o != 0: + a = self._tab.Vector(o) + return self._tab.Get(flatbuffers.number_types.Int32Flags, a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4)) + return 0 + + # TransposeConv2dAttribute + def OutputShapeAsNumpy(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10)) + if o != 0: + return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Int32Flags, o) + return 0 + + # TransposeConv2dAttribute + def OutputShapeLength(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10)) + if o != 0: + return self._tab.VectorLen(o) + return 0 + + # TransposeConv2dAttribute + def OutputShapeIsNone(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10)) + return o == 0 + +def TransposeConv2dAttributeStart(builder): builder.StartObject(4) +def TransposeConv2dAttributeAddOutpad(builder, outpad): builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(outpad), 0) +def TransposeConv2dAttributeStartOutpadVector(builder, numElems): return builder.StartVector(4, numElems, 4) +def TransposeConv2dAttributeAddStride(builder, stride): builder.PrependUOffsetTRelativeSlot(1, flatbuffers.number_types.UOffsetTFlags.py_type(stride), 0) +def TransposeConv2dAttributeStartStrideVector(builder, numElems): return builder.StartVector(4, numElems, 4) +def TransposeConv2dAttributeAddDilation(builder, dilation): builder.PrependUOffsetTRelativeSlot(2, flatbuffers.number_types.UOffsetTFlags.py_type(dilation), 0) +def TransposeConv2dAttributeStartDilationVector(builder, numElems): return builder.StartVector(4, numElems, 4) +def TransposeConv2dAttributeAddOutputShape(builder, outputShape): builder.PrependUOffsetTRelativeSlot(3, flatbuffers.number_types.UOffsetTFlags.py_type(outputShape), 0) +def TransposeConv2dAttributeStartOutputShapeVector(builder, numElems): return builder.StartVector(4, numElems, 4) +def TransposeConv2dAttributeEnd(builder): return builder.EndObject() diff --git a/python/tosa/UnaryQuantInfo.py b/python/tosa/UnaryQuantInfo.py new file mode 100644 index 0000000..b0c4083 --- /dev/null +++ b/python/tosa/UnaryQuantInfo.py @@ -0,0 +1,59 @@ +# automatically generated by the FlatBuffers compiler, do not modify + +# Copyright (c) 2020-2021, ARM Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# namespace: tosa + +import flatbuffers +from flatbuffers.compat import import_numpy +np = import_numpy() + +class UnaryQuantInfo(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAsUnaryQuantInfo(cls, buf, offset): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = UnaryQuantInfo() + x.Init(buf, n + offset) + return x + + @classmethod + def UnaryQuantInfoBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x54\x4F\x53\x41", size_prefixed=size_prefixed) + + # UnaryQuantInfo + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + # UnaryQuantInfo + def InputZp(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos) + return 0 + + # UnaryQuantInfo + def OutputZp(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos) + return 0 + +def UnaryQuantInfoStart(builder): builder.StartObject(2) +def UnaryQuantInfoAddInputZp(builder, inputZp): builder.PrependInt32Slot(0, inputZp, 0) +def UnaryQuantInfoAddOutputZp(builder, outputZp): builder.PrependInt32Slot(1, outputZp, 0) +def UnaryQuantInfoEnd(builder): return builder.EndObject() diff --git a/python/tosa/Version.py b/python/tosa/Version.py new file mode 100644 index 0000000..2aeab9b --- /dev/null +++ b/python/tosa/Version.py @@ -0,0 +1,75 @@ +# automatically generated by the FlatBuffers compiler, do not modify + +# Copyright (c) 2020-2021, ARM Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# namespace: tosa + +import flatbuffers +from flatbuffers.compat import import_numpy +np = import_numpy() + +class Version(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAsVersion(cls, buf, offset): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = Version() + x.Init(buf, n + offset) + return x + + @classmethod + def VersionBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x54\x4F\x53\x41", size_prefixed=size_prefixed) + + # Version + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + # Version + def _major(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos) + return 0 + + # Version + def _minor(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos) + return 21 + + # Version + def _patch(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos) + return 0 + + # Version + def _experimental(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10)) + if o != 0: + return bool(self._tab.Get(flatbuffers.number_types.BoolFlags, o + self._tab.Pos)) + return False + +def VersionStart(builder): builder.StartObject(4) +def VersionAdd_major(builder, Major): builder.PrependInt32Slot(0, Major, 0) +def VersionAdd_minor(builder, Minor): builder.PrependInt32Slot(1, Minor, 21) +def VersionAdd_patch(builder, Patch): builder.PrependInt32Slot(2, Patch, 0) +def VersionAdd_experimental(builder, Experimental): builder.PrependBoolSlot(3, Experimental, 0) +def VersionEnd(builder): return builder.EndObject() diff --git a/python/tosa/WhileLoopAttribute.py b/python/tosa/WhileLoopAttribute.py new file mode 100644 index 0000000..3e10d48 --- /dev/null +++ b/python/tosa/WhileLoopAttribute.py @@ -0,0 +1,59 @@ +# automatically generated by the FlatBuffers compiler, do not modify + +# Copyright (c) 2020-2021, ARM Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# namespace: tosa + +import flatbuffers +from flatbuffers.compat import import_numpy +np = import_numpy() + +class WhileLoopAttribute(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAsWhileLoopAttribute(cls, buf, offset): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = WhileLoopAttribute() + x.Init(buf, n + offset) + return x + + @classmethod + def WhileLoopAttributeBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x54\x4F\x53\x41", size_prefixed=size_prefixed) + + # WhileLoopAttribute + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + # WhileLoopAttribute + def CondBranch(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.String(o + self._tab.Pos) + return None + + # WhileLoopAttribute + def BodyBranch(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + if o != 0: + return self._tab.String(o + self._tab.Pos) + return None + +def WhileLoopAttributeStart(builder): builder.StartObject(2) +def WhileLoopAttributeAddCondBranch(builder, condBranch): builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(condBranch), 0) +def WhileLoopAttributeAddBodyBranch(builder, bodyBranch): builder.PrependUOffsetTRelativeSlot(1, flatbuffers.number_types.UOffsetTFlags.py_type(bodyBranch), 0) +def WhileLoopAttributeEnd(builder): return builder.EndObject() diff --git a/python/tosa/__init__.py b/python/tosa/__init__.py new file mode 100644 index 0000000..69a6bc4 --- /dev/null +++ b/python/tosa/__init__.py @@ -0,0 +1,15 @@ + +# Copyright (c) 2020-2021, ARM Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + diff --git a/regenerate_headers.sh b/regenerate_headers.sh new file mode 100755 index 0000000..a7d2141 --- /dev/null +++ b/regenerate_headers.sh @@ -0,0 +1,35 @@ +#!/bin/bash + +# Copyright (c) 2020-2021, ARM Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +if test -f "third_party/flatbuffers/flatc"; +then + echo "Found flatc, skip building..." +else + echo "flatc not found, building now..." + pushd third_party/flatbuffers/ + cmake . + make flatc -j8 + popd +fi + +pushd include/ + ../third_party/flatbuffers/flatc --cpp ../schema/tosa.fbs +popd +pushd python/ + ../third_party/flatbuffers/flatc --python ../schema/tosa.fbs +popd + diff --git a/schema/tosa.fbs b/schema/tosa.fbs new file mode 100644 index 0000000..c02154d --- /dev/null +++ b/schema/tosa.fbs @@ -0,0 +1,307 @@ + +// Copyright (c) 2020-2021, ARM Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +namespace tosa; + +// This corresponds to the version. +file_identifier "TOSA"; +// File extension of any written files. +file_extension "tosa"; + +enum DType:uint32 { + UNKNOWN = 0, + BOOL, + UINT8, + INT4, + INT8, + INT16, + INT32, + INT48, + FLOAT, +} + +enum ResizeMode:uint32 { + UNKNOWN = 0, + NEAREST, + BILINEAR, +} + +enum Op:uint32 { + UNKNOWN = 0, + + // Tensor Operator + ARGMAX, + AVG_POOL2D, + CONV2D, + CONV3D, + DEPTHWISE_CONV2D, + FULLY_CONNECTED, + MATMUL, + MAX_POOL2D, + TRANSPOSE_CONV2D, + + // Activation + CLAMP, + RELUN, + SIGMOID, + TANH, + + // Elementwise-Binary + ADD, + ARITHMETIC_RIGHT_SHIFT, + BITWISE_AND, + BITWISE_OR, + BITWISE_XOR, + LOGICAL_AND, + LOGICAL_LEFT_SHIFT, + LOGICAL_RIGHT_SHIFT, + LOGICAL_OR, + LOGICAL_XOR, + MAXIMUM, + MINIMUM, + MUL, + POW, + SUB, + TABLE, + + // Elementwise-Unary + ABS, + BITWISE_NOT, + CEIL, + CLZ, + EXP, + FLOOR, + LOG, + LOGICAL_NOT, + NEGATE, + RECIPROCAL, + RSQRT, + + // Elementwise-Ternary + SELECT, + + // Logical + EQUAL, + GREATER, + GREATER_EQUAL, + + // Reduction + REDUCE_ANY, + REDUCE_ALL, + REDUCE_MAX, + REDUCE_MIN, + REDUCE_PRODUCT, + REDUCE_SUM, + + // Data layout operation + CONCAT, + PAD, + RESHAPE, + REVERSE, + SLICE, + TILE, + TRANSPOSE, + + // Gather/scatter operation + GATHER, + SCATTER, + + // Image + RESIZE, + + // Type conversion + CAST, + RESCALE, + + // Data Nodes + CONST, + PLACEHOLDER, + IDENTITY, + IDENTITYN, + + // Custom operations + CUSTOM, + + // Control flow operators + COND_IF, + WHILE_LOOP, +} + +union Attribute { + Pool2dAttribute, + Conv2dAttribute, + TransposeConv2dAttribute, + ReluNAttribute, + AxisAttribute, + ReshapeAttribute, + SliceAttribute, + TileAttribute, + ResizeAttribute, + ClampAttribute, + RescaleAttribute, + MulAttribute, + ArithmeticRightShiftAttribute, + CondIfAttribute, + WhileLoopAttribute, +} + +table Pool2dAttribute { + padding: [int32]; + kernel: [int32]; + stride: [int32]; +} + +table Conv2dAttribute { + padding: [int32]; + stride: [int32]; + dilation: [int32]; +} + +table TransposeConv2dAttribute { + outpad: [int32]; + stride: [int32]; + dilation: [int32]; + output_shape: [int32]; +} + +table ReluNAttribute { + max_int: int32; + max_fp: float; +} + +table AxisAttribute { + axis: int32; +} + +table ReshapeAttribute { + shape: [int32]; +} + +table SliceAttribute { + begin: [int32]; + size: [int32]; +} + +table TileAttribute { + multiples: [int32]; +} + +table ResizeAttribute { + output_size: [int32]; + stride: [int32]; + offset: [int32]; + shift: int32; + stride_fp: [float]; + offset_fp: [float]; + mode: ResizeMode; +} + +table ClampAttribute { + min_int: int32; + max_int: int32; + min_fp: float; + max_fp: float; +} + +table RescaleAttribute { + input_zp: int32; + output_zp: int32; + multiplier: [int32]; + shift: [int32]; + scale32: bool; + double_round: bool; + per_channel: bool; +} + +table MulAttribute { + shift: int32; +} + +table ArithmeticRightShiftAttribute { + round: bool; +} + +table CondIfAttribute { + then_branch: string; + else_branch: string; +} + +table WhileLoopAttribute { + cond_branch: string; + body_branch: string; +} + +union QuantInfo { + UnaryQuantInfo, + ConvQuantInfo, + MatMulQuantInfo, + PadQuantInfo, +} + +table UnaryQuantInfo { + input_zp: int32; + output_zp: int32; +} + +table ConvQuantInfo { + input_zp: int32; + weight_zp: int32; +} + +table MatMulQuantInfo { + a_zp: int32; + b_zp: int32; +} + +table PadQuantInfo { + input_zp: int32; +} + +table Version { + _major: int32 = 0; + _minor: int32 = 21; + _patch: int32 = 0; + _experimental: bool = false; +} + +table TosaTensor { + name:string; // name of the tensor, used for solving dependency + shape:[int32]; // shape of the tensor + type:DType; // data type of the tensor + npy_filename: string; // numpy array filename +} + +table TosaOperator { + op:Op; // operator enum + attribute: Attribute; // union structure. operator attribute + inputs:[string]; // list of input tensor names + outputs:[string]; // list of output tensor names + quant_info: QuantInfo; // op-based quantization information +} + +table TosaBasicBlock { + name:string; // basic block name + operators:[TosaOperator]; // operators array + tensors:[TosaTensor]; // tensors array + inputs:[string]; // name of graph inputs + outputs:[string]; // name of graph outputs +} + +table TosaGraph { + version: Version; + blocks:[TosaBasicBlock]; // basic blocks array +} + +root_type TosaGraph; diff --git a/src/numpy_utils.cpp b/src/numpy_utils.cpp new file mode 100644 index 0000000..e438235 --- /dev/null +++ b/src/numpy_utils.cpp @@ -0,0 +1,415 @@ + +// Copyright (c) 2020-2021, ARM Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "numpy_utils.h" + +// Magic NUMPY header +static const char NUMPY_HEADER_STR[] = "\x93NUMPY\x1\x0\x76\x0{"; +static const int NUMPY_HEADER_SZ = 128; + +NumpyUtilities::NPError NumpyUtilities::readFromNpyFile(const char* filename, const uint32_t elems, bool* databuf) +{ + const char dtype_str[] = "'|b1'"; + return readFromNpyFileCommon(filename, dtype_str, 1, elems, databuf, true); +} + +NumpyUtilities::NPError NumpyUtilities::readFromNpyFile(const char* filename, const uint32_t elems, int32_t* databuf) +{ + const char dtype_str[] = "'(databuf); + for (uint32_t i = 0; i < elems; i++) + { + int val = fgetc(infile); + + if (val == EOF) + { + rc = FILE_IO_ERROR; + } + + buf[i] = val; + } + } + else + { + // Now we are at the beginning of the data + // Parse based on the datatype and number of dimensions + if (fread(databuf, elementsize, elems, infile) != elems) + { + rc = FILE_IO_ERROR; + } + } + } + + if (infile) + fclose(infile); + + return rc; +} + +NumpyUtilities::NPError NumpyUtilities::checkNpyHeader(FILE* infile, const uint32_t elems, const char* dtype_str) +{ + char buf[NUMPY_HEADER_SZ + 1]; + char* ptr = nullptr; + NPError rc = NO_ERROR; + bool foundFormat = false; + bool foundOrder = false; + bool foundShape = false; + bool fortranOrder = false; + std::vector shape; + uint32_t totalElems = 1; + char* outer_end = NULL; + + assert(infile); + assert(elems > 0); + + if (fread(buf, NUMPY_HEADER_SZ, 1, infile) != 1) + { + return HEADER_PARSE_ERROR; + } + + if (memcmp(buf, NUMPY_HEADER_STR, sizeof(NUMPY_HEADER_STR) - 1)) + { + return HEADER_PARSE_ERROR; + } + + ptr = strtok_r(buf + sizeof(NUMPY_HEADER_STR) - 1, ":", &outer_end); + + // Read in the data type, order, and shape + while (ptr && (!foundFormat || !foundOrder || !foundShape)) + { + + // End of string? + if (!ptr) + break; + + // Skip whitespace + while (isspace(*ptr)) + ptr++; + + // Parse the dictionary field name + if (!strcmp(ptr, "'descr'")) + { + ptr = strtok_r(NULL, ",", &outer_end); + if (!ptr) + break; + + while (isspace(*ptr)) + ptr++; + + if (strcmp(ptr, dtype_str)) + { + return FILE_TYPE_MISMATCH; + } + + foundFormat = true; + } + else if (!strcmp(ptr, "'fortran_order'")) + { + ptr = strtok_r(NULL, ",", &outer_end); + if (!ptr) + break; + + while (isspace(*ptr)) + ptr++; + + if (!strcmp(ptr, "False")) + { + fortranOrder = false; + } + else + { + return FILE_TYPE_MISMATCH; + } + + foundOrder = true; + } + else if (!strcmp(ptr, "'shape'")) + { + + ptr = strtok_r(NULL, "(", &outer_end); + if (!ptr) + break; + ptr = strtok_r(NULL, ")", &outer_end); + if (!ptr) + break; + + while (isspace(*ptr)) + ptr++; + + // The shape contains N comma-separated integers. Read up to 4. + char* end = NULL; + + ptr = strtok_r(ptr, ",", &end); + for (int i = 0; i < 4; i++) + { + // Out of dimensions + if (!ptr) + break; + + int dim = atoi(ptr); + + // Dimension is 0 + if (dim == 0) + break; + + shape.push_back(dim); + totalElems *= dim; + ptr = strtok_r(NULL, ",", &end); + } + + foundShape = true; + } + else + { + return HEADER_PARSE_ERROR; + } + + if (!ptr) + break; + + ptr = strtok_r(NULL, ":", &outer_end); + } + + if (!foundShape || !foundFormat || !foundOrder) + { + return HEADER_PARSE_ERROR; + } + + // Validate header + if (fortranOrder) + { + return FILE_TYPE_MISMATCH; + } + + if (totalElems != elems) + { + return BUFFER_SIZE_MISMATCH; + } + + // Go back to the begininng and read until the end of the header dictionary + rewind(infile); + int val; + + do + { + val = fgetc(infile); + } while (val != EOF && val != '\n'); + + return rc; +} + +NumpyUtilities::NPError NumpyUtilities::writeToNpyFile(const char* filename, const uint32_t elems, const bool* databuf) +{ + std::vector shape = { (int32_t)elems }; + return writeToNpyFile(filename, shape, databuf); +} + +NumpyUtilities::NPError + NumpyUtilities::writeToNpyFile(const char* filename, const std::vector& shape, const bool* databuf) +{ + const char dtype_str[] = "'|b1'"; + return writeToNpyFileCommon(filename, dtype_str, 1, shape, databuf, true); // bools written as size 1 +} + +NumpyUtilities::NPError + NumpyUtilities::writeToNpyFile(const char* filename, const uint32_t elems, const int32_t* databuf) +{ + std::vector shape = { (int32_t)elems }; + return writeToNpyFile(filename, shape, databuf); +} + +NumpyUtilities::NPError + NumpyUtilities::writeToNpyFile(const char* filename, const std::vector& shape, const int32_t* databuf) +{ + const char dtype_str[] = "' shape = { (int32_t)elems }; + return writeToNpyFile(filename, shape, databuf); +} + +NumpyUtilities::NPError + NumpyUtilities::writeToNpyFile(const char* filename, const std::vector& shape, const int64_t* databuf) +{ + const char dtype_str[] = "' shape = { (int32_t)elems }; + return writeToNpyFile(filename, shape, databuf); +} + +NumpyUtilities::NPError + NumpyUtilities::writeToNpyFile(const char* filename, const std::vector& shape, const float* databuf) +{ + const char dtype_str[] = "'& shape, + const void* databuf, + bool bool_translate) +{ + FILE* outfile = nullptr; + NPError rc = NO_ERROR; + uint32_t totalElems = 1; + + assert(filename); + assert(shape.size() >= 0); + assert(databuf); + + outfile = fopen(filename, "wb"); + + if (!outfile) + { + return FILE_NOT_FOUND; + } + + for (uint32_t i = 0; i < shape.size(); i++) + { + totalElems *= shape[i]; + } + + rc = writeNpyHeader(outfile, shape, dtype_str); + + if (rc == NO_ERROR) + { + if (bool_translate) + { + // Numpy save format stores booleans as a byte array + // with one byte per boolean. This somewhat inefficiently + // remaps from system bool[] to this format. + const bool* buf = reinterpret_cast(databuf); + for (uint32_t i = 0; i < totalElems; i++) + { + int val = buf[i] ? 1 : 0; + if (fputc(val, outfile) == EOF) + { + rc = FILE_IO_ERROR; + } + } + } + else + { + if (fwrite(databuf, elementsize, totalElems, outfile) != totalElems) + { + rc = FILE_IO_ERROR; + } + } + } + + if (outfile) + fclose(outfile); + + return rc; +} + +NumpyUtilities::NPError + NumpyUtilities::writeNpyHeader(FILE* outfile, const std::vector& shape, const char* dtype_str) +{ + NPError rc = NO_ERROR; + uint32_t i; + char header[NUMPY_HEADER_SZ + 1]; + int headerPos = 0; + + assert(outfile); + assert(shape.size() >= 0); + + // Space-fill the header and end with a newline to start per numpy spec + memset(header, 0x20, NUMPY_HEADER_SZ); + header[NUMPY_HEADER_SZ - 1] = '\n'; + header[NUMPY_HEADER_SZ] = 0; + + // Write out the hard-coded header. We only support a 128-byte 1.0 header + // for now, which should be sufficient for simple tensor types of any + // reasonable rank. + memcpy(header, NUMPY_HEADER_STR, sizeof(NUMPY_HEADER_STR) - 1); + headerPos += sizeof(NUMPY_HEADER_STR) - 1; + + // Output the format dictionary + // Hard-coded for I32 for now + headerPos += + snprintf(header + headerPos, NUMPY_HEADER_SZ - headerPos, "'descr': %s, 'fortran_order': False, 'shape': (%d,", + dtype_str, shape.empty() ? 1 : shape[0]); + + // Remainder of shape array + for (i = 1; i < shape.size(); i++) + { + headerPos += snprintf(header + headerPos, NUMPY_HEADER_SZ - headerPos, " %d,", shape[i]); + } + + // Close off the dictionary + headerPos += snprintf(header + headerPos, NUMPY_HEADER_SZ - headerPos, "), }"); + + // snprintf leaves a NULL at the end. Replace with a space + header[headerPos] = 0x20; + + if (fwrite(header, NUMPY_HEADER_SZ, 1, outfile) != 1) + { + rc = FILE_IO_ERROR; + } + + return rc; +} diff --git a/src/tosa_serialization_handler.cpp b/src/tosa_serialization_handler.cpp new file mode 100644 index 0000000..4fe152f --- /dev/null +++ b/src/tosa_serialization_handler.cpp @@ -0,0 +1,762 @@ + +// Copyright (c) 2020-2021, ARM Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tosa_serialization_handler.h" + +#include +using namespace tosa; + +TosaSerializationTensor::TosaSerializationTensor(const flatbuffers::String* name, + const flatbuffers::Vector& shape, + DType dtype, + const flatbuffers::String* npy_filename) +{ + _dtype = dtype; + + std::copy(shape.begin(), shape.end(), std::back_inserter(_shape)); + + assert(name); + _name = name->str(); + + if (npy_filename) + { + _npy_filename = npy_filename->str(); + } +} + +TosaSerializationTensor::TosaSerializationTensor(std::string& name, + const std::vector& shape, + DType dtype, + const std::string& npy_filename) +{ + _dtype = dtype; + _shape = shape; + _name = name; + _npy_filename = npy_filename; +} + +TosaSerializationTensor::TosaSerializationTensor() +{ + _dtype = DType_UNKNOWN; + + _name = "UNKNOWN"; +} + +TosaSerializationTensor::~TosaSerializationTensor() +{} + +TosaSerializationOperator::TosaSerializationOperator(Op op, + Attribute attribute_type, + const TosaAttributeBase* attribute, + QuantInfo qinfo_type, + const TosaQuantInfoBase* qinfo, + std::vector input_tensor_names, + std::vector output_tensor_names) +{ + _op = op; + _attribute_type = attribute_type; + + switch (attribute_type) + { + case Attribute_NONE: + _attribute = new TosaNoneAttribute(); + break; +#define DEF_ATTRIBUTE(NAME, ...) \ + case Attribute_##NAME##Attribute: \ + _attribute = new Tosa##NAME##Attribute(attribute); \ + break; +#include "attribute.def" +#undef DEF_ATTRIBUTE + default: + printf("TosaSerializationOperator::TosaSerializationOperator(): Attribute %s not implemented yet\n", + EnumNamesAttribute()[attribute_type]); + assert(0); + } + + _qinfo_type = qinfo_type; + switch (qinfo_type) + { + case QuantInfo_NONE: + _qinfo = new TosaNoneQuantInfo(); + break; +#define DEF_QUANTIZATION_INFO(NAME, ...) \ + case QuantInfo_##NAME##QuantInfo: \ + _qinfo = new Tosa##NAME##QuantInfo(qinfo); \ + break; +#include "quant_info.def" +#undef DEF_QUANTIZATION_INFO + default: + printf("TosaSerializationOperator::TosaSerializationOperator(): QuantInfo %s not implemented yet\n", + EnumNamesQuantInfo()[qinfo_type]); + assert(0); + } + + assert(_attribute && _qinfo); + + _input_tensor_names = input_tensor_names; + _output_tensor_names = output_tensor_names; +} + +TosaSerializationOperator::~TosaSerializationOperator() +{ + delete _attribute; + delete _qinfo; + // TosaSerializationTensor should be free'd in TosaSerializationSerializationHandler destructor +} + +TosaSerializationBasicBlock::TosaSerializationBasicBlock(std::string name, + std::vector operators, + std::vector tensors, + std::vector inputs, + std::vector outputs) +{ + + _name = name; + _operators = operators; + _tensors = tensors; + _inputs = inputs; + _outputs = outputs; +} + +TosaSerializationBasicBlock::~TosaSerializationBasicBlock() +{ + // deallocate all operators + for (auto op : GetOperators()) + { + delete op; // ~TosaSerializationOperator() + } + + // deallocate all tensors + for (auto ts : GetTensors()) + { + delete ts; // ~TosaSerializationTensor() + } +} + +TosaSerializationHandler::TosaSerializationHandler() +{ + _schemaLoaded = false; + + SetTosaVersion(); +} + +TosaSerializationHandler::~TosaSerializationHandler() +{ + Clear(); // deallocate all basic blocks +} + +tosa_err_t TosaSerializationHandler::SetTosaVersion() +{ + // version is specified within .fbs + // and it's encoded as defaulted value of CreateTosaVersion() + // need to write out one object to read that value out + // TODO: very costly now. is there any better way to encode constant in .fbs? + auto fboffset_version = CreateVersion(_builder); + auto fboffset_tosa_graph = CreateTosaGraphDirect(_builder, fboffset_version, nullptr); + _builder.Finish(fboffset_tosa_graph); + std::string jsongen; + uint8_t* buf = _builder.GetBufferPointer(); + auto fb_tosa_graph = GetTosaGraph(buf); + auto fb_tosa_version = fb_tosa_graph->version(); + + _version.set_version(fb_tosa_version->_major(), fb_tosa_version->_minor(), fb_tosa_version->_patch(), + fb_tosa_version->_experimental()); + + return TOSA_OK; +} + +tosa_err_t TosaSerializationHandler::LoadFileSchema(const char* schema_filename) +{ + std::string schema; + bool ok; + + ok = flatbuffers::LoadFile(schema_filename, false, &schema); + if (!ok) + { + printf("Error loading schema file: %s\n", schema_filename); + return TOSA_FILE_ERROR; + } + + ok = _parser.Parse(schema.c_str()); + if (!ok) + { + printf("Error parsing ISA schema file: %s\n", schema_filename); + return TOSA_FILE_ERROR; + } + _schemaLoaded = true; + + return TOSA_OK; +} + +tosa_err_t TosaSerializationHandler::LoadFileJson(const char* filename) +{ + std::string jsonfile; + bool ok; + tosa_err_t err; + + if (!_schemaLoaded) + { + return TOSA_SCHEMA_MISSING; + } + + ok = flatbuffers::LoadFile(filename, false, &jsonfile); + if (!ok) + { + printf("Error loading json file: %s\n", filename); + return TOSA_FILE_ERROR; + } + + ok = _parser.Parse(jsonfile.c_str()); + if (!ok) + { + printf("Error parsing json file: %s\n", filename); + return TOSA_FILE_ERROR; + } + + uint8_t* buf = _parser.builder_.GetBufferPointer(); + + err = InitWithBuf(buf); + if (err != TOSA_OK) + { + return err; + } + + return TOSA_OK; +} + +tosa_err_t TosaSerializationHandler::SaveFileJson(const char* filename) +{ + std::string jsongen; + tosa_err_t err; + + if (!_schemaLoaded) + { + return TOSA_SCHEMA_MISSING; + } + + err = FreezeBuilder(); + if (err != TOSA_OK) + { + return err; + } + + uint8_t* buf = _builder.GetBufferPointer(); + + if (!GenerateText(_parser, buf, &jsongen)) + { + printf("Couldn't serialize parsed data to JSON!\n"); + return TOSA_FILE_ERROR; + } + + FILE* file = fopen(filename, "wb"); + + if (!file) + { + printf("Couldn't open output file: %s\n", filename); + return TOSA_FILE_ERROR; + } + + if (fwrite(jsongen.c_str(), sizeof(char), jsongen.size(), file) != jsongen.size()) + { + printf("Error writing to json output file: %s\n", filename); + fclose(file); + return TOSA_FILE_ERROR; + } + + if (file) + fclose(file); + + return TOSA_OK; +} + +tosa_err_t TosaSerializationHandler::LoadFileTosaFlatbuffer(const char* filename) +{ + std::string read_buffer; + tosa_err_t err; + uint8_t* buf; + bool ok; + + ok = flatbuffers::LoadFile(filename, false, &read_buffer); + if (!ok) + { + printf("Error loading flatbuffer file: %s\n", filename); + return TOSA_FILE_ERROR; + } + + buf = (uint8_t*)read_buffer.data(); + + err = InitWithBuf(buf); + if (err != TOSA_OK) + { + return err; + } + + return TOSA_OK; +} + +tosa_err_t TosaSerializationHandler::SaveFileTosaFlatbuffer(const char* filename) +{ + tosa_err_t err; + + err = FreezeBuilder(); + if (err != TOSA_OK) + { + return err; + } + + uint8_t* buf = _builder.GetBufferPointer(); + + bool ok = flatbuffers::SaveFile(filename, (const char*)buf, _builder.GetSize(), false); + if (!ok) + { + printf("Error saving floatbuffer file: %s\n", filename); + return TOSA_FILE_ERROR; + } + + return TOSA_OK; +} + +tosa_err_t TosaSerializationHandler::Clear() +{ + // deallocate all basic blocks + for (auto bb : GetBlocks()) + { + delete bb; + } + _blocks.clear(); + + return TOSA_OK; +} + +tosa_err_t TosaSerializationHandler::CheckTosaVersion(const TosaVersion& read_version) +{ + if (_version != read_version) + { + printf("WARNING: read tosa version: %s != schema tosa version %s\n", read_version.to_string().c_str(), + _version.to_string().c_str()); + return TOSA_VERSION_MISMATCH; + } + + return TOSA_OK; +} + +tosa_err_t TosaSerializationHandler::InitWithBuf(const uint8_t* buf) +{ + auto fb_tosa_graph = GetTosaGraph(buf); + auto fb_tosa_version = fb_tosa_graph->version(); + auto fb_tosa_blocks = fb_tosa_graph->blocks(); + + std::vector operator_inputs_container; + std::vector operator_outputs_container; + + std::vector block_operators_container; + std::vector block_tensors_container; + std::vector block_inputs_container; + std::vector block_outputs_container; + + TosaAttributeBase* typed_attribute = NULL; + TosaQuantInfoBase* typed_qinfo = NULL; + TosaSerializationOperator* new_operator = NULL; + TosaSerializationBasicBlock* new_block = NULL; + TosaSerializationTensor* new_tensor = NULL; + + // erase container + Clear(); + + TosaVersion read_version(fb_tosa_version->_major(), fb_tosa_version->_minor(), fb_tosa_version->_patch(), + fb_tosa_version->_experimental()); + tosa_err_t err = CheckTosaVersion(read_version); + + if (err != TOSA_OK) + return err; + + for (size_t i = 0; i < fb_tosa_blocks->size(); i++) + { + auto curr_block = fb_tosa_blocks->Get(i); + + auto block_name = curr_block->name()->str(); + + auto fb_tosa_operators = curr_block->operators(); + block_operators_container.clear(); + for (size_t j = 0; j < fb_tosa_operators->size(); j++) + { + auto curr_operator = fb_tosa_operators->Get(j); + + auto operator_op = curr_operator->op(); + auto attribute_type = curr_operator->attribute_type(); + auto attribute = curr_operator->attribute(); + auto operator_qinfo_type = curr_operator->quant_info_type(); + auto operator_qinfo = curr_operator->quant_info(); + + // input tensors + auto operator_inputs = curr_operator->inputs(); + operator_inputs_container.clear(); + if (operator_inputs) + { + for (size_t k = 0; k < operator_inputs->size(); k++) + { + auto curr_input = operator_inputs->Get(k); + operator_inputs_container.push_back(curr_input->str()); + } + } + + // output tensors + auto operator_outputs = curr_operator->outputs(); + operator_outputs_container.clear(); + if (operator_outputs) + { + for (size_t k = 0; k < operator_outputs->size(); k++) + { + auto curr_output = operator_outputs->Get(k); + operator_outputs_container.push_back(curr_output->str()); + } + } + + switch (attribute_type) + { + case Attribute_NONE: + typed_attribute = new TosaNoneAttribute(); + break; +#define DEF_ATTRIBUTE(NAME, ...) \ + case Attribute_##NAME##Attribute: \ + typed_attribute = new Tosa##NAME##Attribute(attribute); \ + break; +#include "attribute.def" +#undef DEF_ATTRIBUTE + default: + printf("TosaSerializationHandler::InitWithBuf(): Attribute %s not implemented yet\n", + EnumNamesAttribute()[attribute_type]); + return TOSA_INTERNAL_ERROR; + } + + switch (operator_qinfo_type) + { + case QuantInfo_NONE: + typed_qinfo = new TosaNoneQuantInfo(); + break; +#define DEF_QUANTIZATION_INFO(NAME, ...) \ + case QuantInfo_##NAME##QuantInfo: \ + typed_qinfo = new Tosa##NAME##QuantInfo(operator_qinfo); \ + break; + +#include "quant_info.def" +#undef DEF_QUANTIZATION_INFO + default: + printf("TosaSerializationHandler::InitWithBuf(): QuantInfo %s not implemented yet\n", + EnumNamesQuantInfo()[operator_qinfo_type]); + return TOSA_INTERNAL_ERROR; + } + + new_operator = + new TosaSerializationOperator(operator_op, attribute_type, typed_attribute, operator_qinfo_type, + typed_qinfo, operator_inputs_container, operator_outputs_container); + if (new_operator) + { + block_operators_container.push_back(new_operator); + } + else + { + return TOSA_MEMORY_ERROR; + } + + if (typed_attribute) + delete typed_attribute; + if (typed_qinfo) + delete typed_qinfo; + } + + auto fb_tosa_tensors = curr_block->tensors(); + block_tensors_container.clear(); + for (size_t j = 0; j < fb_tosa_tensors->size(); j++) + { + auto curr_tensor = fb_tosa_tensors->Get(j); + + auto tensor_name = curr_tensor->name(); + auto tensor_shape = curr_tensor->shape(); + auto tensor_type = curr_tensor->type(); + auto tensor_npy_filename = curr_tensor->npy_filename(); + + new_tensor = new TosaSerializationTensor(tensor_name, *tensor_shape, tensor_type, tensor_npy_filename); + if (new_tensor) + { + block_tensors_container.push_back(new_tensor); + } + else + { + return TOSA_MEMORY_ERROR; + } + } + + auto block_inputs = curr_block->inputs(); + auto block_outputs = curr_block->outputs(); + + block_inputs_container.clear(); + block_outputs_container.clear(); + + for (size_t j = 0; j < block_inputs->size(); j++) + { + auto curr_block_input = block_inputs->Get(j); + block_inputs_container.push_back(curr_block_input->str()); + } + for (size_t j = 0; j < block_outputs->size(); j++) + { + auto curr_block_output = block_outputs->Get(j); + block_outputs_container.push_back(curr_block_output->str()); + } + + new_block = new TosaSerializationBasicBlock(block_name, block_operators_container, block_tensors_container, + block_inputs_container, block_outputs_container); + if (new_block) + { + this->GetBlocks().push_back(new_block); + } + else + { + return TOSA_MEMORY_ERROR; + } + } + + return TOSA_OK; +} + +tosa_err_t TosaSerializationHandler::FreezeBuilder() +{ + std::vector> fboffset_blocks; + + std::vector> fboffset_block_operators; + std::vector> fboffset_block_tensors; + std::vector> fboffset_block_inputs; + std::vector> fboffset_block_outputs; + + std::vector> fboffset_operator_inputs; + std::vector> fboffset_operator_outputs; + + // translate TosaFlatbufferOperator to flatbuffers::Offset + for (auto block : GetBlocks()) + { + fboffset_block_operators.clear(); + fboffset_block_tensors.clear(); + fboffset_block_inputs.clear(); + fboffset_block_outputs.clear(); + + auto block_name = _builder.CreateString(block->GetName().c_str()); + + for (auto tensor_str : block->GetInputs()) + { + auto tensor_name = _builder.CreateString(tensor_str.c_str()); + fboffset_block_inputs.push_back(tensor_name); + } + + for (auto tensor_str : block->GetOutputs()) + { + auto tensor_name = _builder.CreateString(tensor_str.c_str()); + fboffset_block_outputs.push_back(tensor_name); + } + + auto fb_block_inputs = _builder.CreateVector(fboffset_block_inputs); + auto fb_block_outputs = _builder.CreateVector(fboffset_block_outputs); + + for (auto op : block->GetOperators()) + { + fboffset_operator_inputs.clear(); + fboffset_operator_outputs.clear(); + + auto operator_op = op->GetOp(); + auto attribute_type = op->GetAttributeType(); + + for (auto tensor_str : op->GetInputTensorNames()) + { + auto tensor_name = _builder.CreateString(tensor_str.c_str()); + fboffset_operator_inputs.push_back(tensor_name); + } + + for (auto tensor_str : op->GetOutputTensorNames()) + { + auto tensor_name = _builder.CreateString(tensor_str.c_str()); + fboffset_operator_outputs.push_back(tensor_name); + } + + auto fb_operator_inputs = _builder.CreateVector(fboffset_operator_inputs); + auto fb_operator_outputs = _builder.CreateVector(fboffset_operator_outputs); + + flatbuffers::Offset fb_attribute; + switch (attribute_type) + { + case Attribute_NONE: + fb_attribute = 0; + break; + +#define DEF_ARGS_S_STR(NAME, V) , _builder.CreateString(reinterpret_cast(op->GetAttribute())->V().c_str()) +#define DEF_ARGS_S_DEFAULT(NAME, V) , reinterpret_cast(op->GetAttribute())->V() + +#define DEF_ARGS_S_int32_t(NAME, V) DEF_ARGS_S_DEFAULT(NAME, V) +#define DEF_ARGS_S_float(NAME, V) DEF_ARGS_S_DEFAULT(NAME, V) +#define DEF_ARGS_S_bool(NAME, V) DEF_ARGS_S_DEFAULT(NAME, V) +#define DEF_ARGS_S_ResizeMode(NAME, V) DEF_ARGS_S_DEFAULT(NAME, V) +#define DEF_ARGS_S_string(NAME, V) DEF_ARGS_S_STR(NAME, V) + +#define DEF_ARGS_S(NAME, T, V) DEF_ARGS_S_##T(NAME, V) +#define DEF_ARGS_V(NAME, T, V) , _builder.CreateVector(reinterpret_cast(op->GetAttribute())->V()) + +#define DEF_ARGS_1(NAME, T0, F0, V0) DEF_ARGS_##F0(NAME, T0, V0) +#define DEF_ARGS_2(NAME, T0, F0, V0, T1, F1, V1) DEF_ARGS_##F0(NAME, T0, V0) DEF_ARGS_##F1(NAME, T1, V1) +#define DEF_ARGS_3(NAME, T0, F0, V0, T1, F1, V1, T2, F2, V2) \ + DEF_ARGS_##F0(NAME, T0, V0) DEF_ARGS_##F1(NAME, T1, V1) DEF_ARGS_##F2(NAME, T2, V2) +#define DEF_ARGS_4(NAME, T0, F0, V0, T1, F1, V1, T2, F2, V2, T3, F3, V3) \ + DEF_ARGS_##F0(NAME, T0, V0) DEF_ARGS_##F1(NAME, T1, V1) DEF_ARGS_##F2(NAME, T2, V2) DEF_ARGS_##F3(NAME, T3, V3) +#define DEF_ARGS_5(NAME, T0, F0, V0, T1, F1, V1, T2, F2, V2, T3, F3, V3, T4, F4, V4) \ + DEF_ARGS_##F0(NAME, T0, V0) DEF_ARGS_##F1(NAME, T1, V1) DEF_ARGS_##F2(NAME, T2, V2) DEF_ARGS_##F3(NAME, T3, V3) \ + DEF_ARGS_##F4(NAME, T4, V4) +#define DEF_ARGS_6(NAME, T0, F0, V0, T1, F1, V1, T2, F2, V2, T3, F3, V3, T4, F4, V4, T5, F5, V5) \ + DEF_ARGS_##F0(NAME, T0, V0) DEF_ARGS_##F1(NAME, T1, V1) DEF_ARGS_##F2(NAME, T2, V2) DEF_ARGS_##F3(NAME, T3, V3) \ + DEF_ARGS_##F4(NAME, T4, V4) DEF_ARGS_##F5(NAME, T5, V5) +#define DEF_ARGS_7(NAME, T0, F0, V0, T1, F1, V1, T2, F2, V2, T3, F3, V3, T4, F4, V4, T5, F5, V5, T6, F6, V6) \ + DEF_ARGS_##F0(NAME, T0, V0) DEF_ARGS_##F1(NAME, T1, V1) DEF_ARGS_##F2(NAME, T2, V2) DEF_ARGS_##F3(NAME, T3, V3) \ + DEF_ARGS_##F4(NAME, T4, V4) DEF_ARGS_##F5(NAME, T5, V5) DEF_ARGS_##F6(NAME, T6, V6) +#define DEF_ATTRIBUTE(NAME, NUM_ARGS, ...) \ + case Attribute_##NAME##Attribute: \ + fb_attribute = Create##NAME##Attribute(_builder DEF_ARGS_##NUM_ARGS(NAME##Attribute, __VA_ARGS__)).Union(); \ + break; + +#include "attribute.def" +#undef DEF_ATTRIBUTE +#undef DEF_ARGS_1 +#undef DEF_ARGS_2 +#undef DEF_ARGS_3 +#undef DEF_ARGS_4 +#undef DEF_ARGS_5 +#undef DEF_ARGS_6 +#undef DEF_ARGS_7 +#undef DEF_ARGS_S +#undef DEF_ARGS_V +#undef DEF_ARGS_S_int32_t +#undef DEF_ARGS_S_float +#undef DEF_ARGS_S_bool +#undef DEF_ARGS_S_ResizeMode +#undef DEF_ARGS_S_string +#undef DEF_ARGS_S_STR +#undef DEF_ARGS_S_DEFAULT + default: + printf("TosaSerializationHandler::FreezeBuilder(): Attribute %s not implemented yet\n", + EnumNamesAttribute()[attribute_type]); + return TOSA_INTERNAL_ERROR; + } + + auto qinfo_type = op->GetQInfoType(); + flatbuffers::Offset fb_operator_qinfo; + switch (qinfo_type) + { + case QuantInfo_NONE: + fb_operator_qinfo = 0; + break; +#define DEF_ARGS_S(NAME, T, V) , reinterpret_cast(op->GetQInfo())->V() +#define DEF_ARGS_V(NAME, T, V) , _builder.CreateVector(reinterpret_cast(op->GetQInfo())->V()) + +#define DEF_ARGS_1(NAME, T0, F0, V0) DEF_ARGS_##F0(NAME, T0, V0) +#define DEF_ARGS_2(NAME, T0, F0, V0, T1, F1, V1) DEF_ARGS_##F0(NAME, T0, V0) DEF_ARGS_##F1(NAME, T1, V1) +#define DEF_ARGS_3(NAME, T0, F0, V0, T1, F1, V1, T2, F2, V2) \ + DEF_ARGS_##F0(NAME, T0, V0) DEF_ARGS_##F1(NAME, T1, V1) DEF_ARGS_##F2(NAME, T2, V2) +#define DEF_ARGS_4(NAME, T0, F0, V0, T1, F1, V1, T2, F2, V2, T3, F3, V3) \ + DEF_ARGS_##F0(NAME, T0, V0) DEF_ARGS_##F1(NAME, T1, V1) DEF_ARGS_##F2(NAME, T2, V2) DEF_ARGS_##F3(NAME, T3, V3) +#define DEF_ARGS_5(NAME, T0, F0, V0, T1, F1, V1, T2, F2, V2, T3, F3, V3, T4, F4, V4) \ + DEF_ARGS_##F0(NAME, T0, V0) DEF_ARGS_##F1(NAME, T1, V1) DEF_ARGS_##F2(NAME, T2, V2) DEF_ARGS_##F3(NAME, T3, V3) \ + DEF_ARGS_##F4(NAME, T4, V4) +#define DEF_ARGS_6(NAME, T0, F0, V0, T1, F1, V1, T2, F2, V2, T3, F3, V3, T4, F4, V4, T5, F5, V5) \ + DEF_ARGS_##F0(NAME, T0, V0) DEF_ARGS_##F1(NAME, T1, V1) DEF_ARGS_##F2(NAME, T2, V2) DEF_ARGS_##F3(NAME, T3, V3) \ + DEF_ARGS_##F4(NAME, T4, V4) DEF_ARGS_##F5(NAME, T5, V5) +#define DEF_ARGS_7(NAME, T0, F0, V0, T1, F1, V1, T2, F2, V2, T3, F3, V3, T4, F4, V4, T5, F5, V5, T6, F6, V6) \ + DEF_ARGS_##F0(NAME, T0, V0) DEF_ARGS_##F1(NAME, T1, V1) DEF_ARGS_##F2(NAME, T2, V2) DEF_ARGS_##F3(NAME, T3, V3) \ + DEF_ARGS_##F4(NAME, T4, V4) DEF_ARGS_##F5(NAME, T5, V5) DEF_ARGS_##F6(NAME, T6, V6) +#define DEF_ARGS_8(NAME, T0, F0, V0, T1, F1, V1, T2, F2, V2, T3, F3, V3, T4, F4, V4, T5, F5, V5, T6, F6, V6, T7, F7, \ + V7) \ + DEF_ARGS_##F0(NAME, T0, V0) DEF_ARGS_##F1(NAME, T1, V1) DEF_ARGS_##F2(NAME, T2, V2) DEF_ARGS_##F3(NAME, T3, V3) \ + DEF_ARGS_##F4(NAME, T4, V4) DEF_ARGS_##F5(NAME, T5, V5) DEF_ARGS_##F6(NAME, T6, V6) \ + DEF_ARGS_##F7(NAME, T7, V7) +#define DEF_ARGS_9(NAME, T0, F0, V0, T1, F1, V1, T2, F2, V2, T3, F3, V3, T4, F4, V4, T5, F5, V5, T6, F6, V6, T7, F7, \ + V7, T8, F8, V8) \ + DEF_ARGS_##F0(NAME, T0, V0) DEF_ARGS_##F1(NAME, T1, V1) DEF_ARGS_##F2(NAME, T2, V2) DEF_ARGS_##F3(NAME, T3, V3) \ + DEF_ARGS_##F4(NAME, T4, V4) DEF_ARGS_##F5(NAME, T5, V5) DEF_ARGS_##F6(NAME, T6, V6) \ + DEF_ARGS_##F7(NAME, T7, V7) DEF_ARGS_##F8(NAME, T8, V8) +#define DEF_ARGS_10(NAME, T0, F0, V0, T1, F1, V1, T2, F2, V2, T3, F3, V3, T4, F4, V4, T5, F5, V5, T6, F6, V6, T7, F7, \ + V7, T8, F8, V8, T9, F9, V9) \ + DEF_ARGS_##F0(NAME, T0, V0) DEF_ARGS_##F1(NAME, T1, V1) DEF_ARGS_##F2(NAME, T2, V2) DEF_ARGS_##F3(NAME, T3, V3) \ + DEF_ARGS_##F4(NAME, T4, V4) DEF_ARGS_##F5(NAME, T5, V5) DEF_ARGS_##F6(NAME, T6, V6) \ + DEF_ARGS_##F7(NAME, T7, V7) DEF_ARGS_##F8(NAME, T8, V8) DEF_ARGS_##F9(NAME, T9, V9) +#define DEF_QUANTIZATION_INFO(NAME, NUM_ARGS, ...) \ + case QuantInfo_##NAME##QuantInfo: \ + fb_operator_qinfo = \ + Create##NAME##QuantInfo(_builder DEF_ARGS_##NUM_ARGS(NAME##QuantInfo, __VA_ARGS__)).Union(); \ + break; + +#include "quant_info.def" +#undef DEF_QUANTIZATION_INFO +#undef DEF_ARGS_1 +#undef DEF_ARGS_2 +#undef DEF_ARGS_3 +#undef DEF_ARGS_4 +#undef DEF_ARGS_5 +#undef DEF_ARGS_6 +#undef DEF_ARGS_7 +#undef DEF_ARGS_8 +#undef DEF_ARGS_9 +#undef DEF_ARGS_10 +#undef DEF_ARGS_S +#undef DEF_ARGS_V + default: + printf("TosaSerializationHandler::FreezeBuilder(): Attribute %s not implemented yet\n", + EnumNamesAttribute()[attribute_type]); + return TOSA_INTERNAL_ERROR; + } + + auto fboffset_operator = + CreateTosaOperator(_builder, operator_op, attribute_type, fb_attribute, fb_operator_inputs, + fb_operator_outputs, qinfo_type, fb_operator_qinfo); + fboffset_block_operators.push_back(fboffset_operator); + } + + auto fb_block_operators = _builder.CreateVector(fboffset_block_operators); + + for (auto tensor : block->GetTensors()) + { + + auto tensor_name = _builder.CreateString(tensor->GetName().c_str()); + auto tensor_shape = _builder.CreateVector(tensor->GetShape()); + auto tensor_dtype = tensor->GetDtype(); + flatbuffers::Offset tensor_npy_filename = 0; + if (!tensor->GetNpyFilePtr().empty()) + tensor_npy_filename = _builder.CreateString(tensor->GetNpyFilePtr().c_str()); + + auto fboffset_tensor = + CreateTosaTensor(_builder, tensor_name, tensor_shape, tensor_dtype, tensor_npy_filename); + fboffset_block_tensors.push_back(fboffset_tensor); + } + + auto fb_block_tensors = _builder.CreateVector(fboffset_block_tensors); + + auto fboffset_block = CreateTosaBasicBlock(_builder, block_name, fb_block_operators, fb_block_tensors, + fb_block_inputs, fb_block_outputs); + fboffset_blocks.push_back(fboffset_block); + } + + auto fb_blocks = _builder.CreateVector(fboffset_blocks); + + auto fb_version = CreateVersion(_builder, GetTosaVersion()._major, GetTosaVersion()._minor, GetTosaVersion()._patch, + GetTosaVersion()._experimental); + + auto fb_graph = CreateTosaGraph(_builder, fb_version, fb_blocks); + _builder.Finish(fb_graph); + + return TOSA_OK; +} diff --git a/test/scripts/test_npy_fileio.py b/test/scripts/test_npy_fileio.py new file mode 100755 index 0000000..e0a6f5d --- /dev/null +++ b/test/scripts/test_npy_fileio.py @@ -0,0 +1,152 @@ +#!/usr/bin/env python3 + +# Copyright (c) 2021, ARM Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" Simple test script which tests numpy file read/write""" + +import argparse +import random +import shlex +import subprocess +from datetime import datetime +from enum import IntEnum, unique +from pathlib import Path +from xunit.xunit import xunit_results, xunit_test + + +@unique +class TestResult(IntEnum): + PASS = 0 + COMMAND_ERROR = 1 + MISMATCH = 2 + SKIPPED = 3 + + +def parseArgs(): + baseDir = (Path(__file__).parent / "../..").resolve() + buildDir = (baseDir / "build").resolve() + parser = argparse.ArgumentParser() + + parser.add_argument( + "-c", + "--cmd", + default=str(buildDir / "serialization_npy_test"), + help="Command to write/read test file", + ) + parser.add_argument("-s", "--seed", default=1, help="Random number seed") + parser.add_argument( + "-v", "--verbose", action="store_true", help="verbose", default=False + ) + parser.add_argument( + "--xunit-file", default="npy-result.xml", help="xunit result output file" + ) + args = parser.parse_args() + + # check that required files exist + if not Path(args.cmd).exists(): + print("command not found at location " + args.cmd) + parser.print_help() + exit(1) + return args + + +def run_sh_command(full_cmd, verbose=False, capture_output=False): + """Utility function to run an external command. Optionally return captured + stdout/stderr""" + + # Quote the command line for printing + full_cmd_esc = [shlex.quote(x) for x in full_cmd] + + if verbose: + print("### Running {}".format(" ".join(full_cmd_esc))) + + if capture_output: + rc = subprocess.run(full_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + if rc.returncode != 0: + print(rc.stdout.decode("utf-8")) + print(rc.stderr.decode("utf-8")) + raise Exception( + "Error running command: {}.\n{}".format( + " ".join(full_cmd_esc), rc.stderr.decode("utf-8") + ) + ) + return (rc.stdout, rc.stderr) + else: + rc = subprocess.run(full_cmd) + if rc.returncode != 0: + raise Exception("Error running command: {}".format(" ".join(full_cmd_esc))) + + +def runTest(args, dtype, shape): + start_time = datetime.now() + result = TestResult.PASS + message = "" + + target = Path(f"npytest-{random.randint(0,10000)}.npy") + shape_str = ",".join(shape) + # Remove any previous files + if target.exists(): + target.unlink() + + try: + cmd = [args.cmd, "-d", dtype, "-f", str(target), "-t", shape_str] + run_sh_command(cmd, args.verbose) + target.unlink() + + except Exception as e: + message = str(e) + result = TestResult.COMMAND_ERROR + end_time = datetime.now() + return result, message, end_time - start_time + + +def main(): + args = parseArgs() + + suitename = "basic_serialization" + classname = "npy_test" + + xunit_result = xunit_results() + xunit_suite = xunit_result.create_suite("basic_serialization") + + max_size = 128 + datatypes = ["int32", "int64", "float", "bool"] + random.seed(args.seed) + + failed = 0 + count = 0 + for test in datatypes: + count = count + 1 + shape = [] + for i in range(4): + shape.append(str(random.randint(1, max_size))) + (result, message, time_delta) = runTest(args, test, shape) + xt = xunit_test(str(test), f"{suitename}.{classname}") + xt.time = str( + float(time_delta.seconds) + (float(time_delta.microseconds) * 1e-6) + ) + if result == TestResult.PASS: + pass + else: + xt.failed(message) + failed = failed + 1 + xunit_suite.tests.append(xt) + + xunit_result.write_results(args.xunit_file) + print(f"Total tests run: {count} failures: {failed}") + + +if __name__ == "__main__": + exit(main()) diff --git a/test/scripts/test_serialization.py b/test/scripts/test_serialization.py new file mode 100755 index 0000000..834bc1d --- /dev/null +++ b/test/scripts/test_serialization.py @@ -0,0 +1,197 @@ +#!/usr/bin/env python3 + +# Copyright (c) 2021, ARM Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" Simple test script which uses serialization_read_write to copy tosa files. It +uses flatc to convert to json for comparison since the binary files may +differ. """ + +import argparse +import filecmp +import random +import shlex +import subprocess +from datetime import datetime +from enum import IntEnum, unique +from pathlib import Path +from xunit.xunit import xunit_results, xunit_test + + +@unique +class TestResult(IntEnum): + PASS = 0 + COMMAND_ERROR = 1 + MISMATCH = 2 + SKIPPED = 3 + + +def parseArgs(): + baseDir = (Path(__file__).parent / "../..").resolve() + buildDir = (baseDir / "build").resolve() + parser = argparse.ArgumentParser() + parser.add_argument( + "-t", + "--testdir", + dest="test", + type=str, + required=True, + help="Directory of tosa files to verify", + ) + parser.add_argument( + "--flatc", + default=str(buildDir / "third_party/flatbuffers/flatc"), + help="location of flatc compiler", + ) + parser.add_argument( + "-s", + "--schema", + default=str(baseDir / "schema/tosa.fbs"), + help="location of schema file", + ) + parser.add_argument( + "-c", + "--cmd", + default=str(buildDir / "serialization_read_write"), + help="Command to read/write test file", + ) + parser.add_argument( + "-v", "--verbose", action="store_true", help="verbose", default=False + ) + parser.add_argument( + "--xunit-file", default="result.xml", help="xunit result output file" + ) + args = parser.parse_args() + + # check that required files exist + if not Path(args.flatc).exists(): + print("flatc not found at location " + args.flatc) + parser.print_help() + exit(1) + if not Path(args.cmd).exists(): + print("command not found at location " + args.cmd) + parser.print_help() + exit(1) + if not Path(args.schema).exists(): + print("schema not found at location " + args.schema) + parser.print_help() + exit(1) + return args + + +def run_sh_command(full_cmd, verbose=False, capture_output=False): + """Utility function to run an external command. Optionally return captured + stdout/stderr""" + + # Quote the command line for printing + full_cmd_esc = [shlex.quote(x) for x in full_cmd] + + if verbose: + print("### Running {}".format(" ".join(full_cmd_esc))) + + if capture_output: + rc = subprocess.run(full_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + if rc.returncode != 0: + print(rc.stdout.decode("utf-8")) + print(rc.stderr.decode("utf-8")) + raise Exception( + "Error running command: {}.\n{}".format( + " ".join(full_cmd_esc), rc.stderr.decode("utf-8") + ) + ) + return (rc.stdout, rc.stderr) + else: + rc = subprocess.run(full_cmd) + if rc.returncode != 0: + raise Exception("Error running command: {}".format(" ".join(full_cmd_esc))) + + +def runTest(args, testfile): + start_time = datetime.now() + result = TestResult.PASS + message = "" + + target = Path(f"serialization_script_output-{random.randint(0,10000)}.tosa") + source_json = Path(testfile.stem + ".json") + target_json = Path(target.stem + ".json") + + # Remove any previous files + if target.exists(): + target.unlink() + if source_json.exists(): + source_json.unlink() + if target_json.exists(): + target_json.unlink() + + try: + cmd = [args.cmd, str(testfile), str(target)] + run_sh_command(cmd, args.verbose) + # Create result json + cmd = [args.flatc, "--json", "--raw-binary", args.schema, "--", str(target)] + run_sh_command(cmd, args.verbose) + # Create source json + cmd = [args.flatc, "--json", "--raw-binary", args.schema, "--", str(testfile)] + run_sh_command(cmd, args.verbose) + if not filecmp.cmp(str(target_json), str(source_json), False): + print("Failed to compare files on " + str(testfile)) + result = TestResult.MISMATCH + # Cleanup generated files + source_json.unlink() + target_json.unlink() + target.unlink() + + except Exception as e: + message = str(e) + result = TestResult.COMMAND_ERROR + end_time = datetime.now() + return result, message, end_time - start_time + + +def getTestFiles(dir): + files = Path(dir).glob("**/*.tosa") + return files + + +def main(): + args = parseArgs() + testfiles = getTestFiles(args.test) + + suitename = "basic_serialization" + classname = "copy_test" + + xunit_result = xunit_results() + xunit_suite = xunit_result.create_suite("basic_serialization") + + failed = 0 + count = 0 + for test in testfiles: + count = count + 1 + (result, message, time_delta) = runTest(args, test) + xt = xunit_test(str(test), f"{suitename}.{classname}") + xt.time = str( + float(time_delta.seconds) + (float(time_delta.microseconds) * 1e-6) + ) + if result == TestResult.PASS: + pass + else: + xt.failed(message) + failed = failed + 1 + xunit_suite.tests.append(xt) + + xunit_result.write_results(args.xunit_file) + print(f"Total tests run: {count} failures: {failed}") + + +if __name__ == "__main__": + exit(main()) diff --git a/test/scripts/testfiles/test.tosa b/test/scripts/testfiles/test.tosa new file mode 100644 index 0000000..3b4ca56 Binary files /dev/null and b/test/scripts/testfiles/test.tosa differ diff --git a/test/scripts/xunit/xunit.py b/test/scripts/xunit/xunit.py new file mode 100644 index 0000000..2de0d5c --- /dev/null +++ b/test/scripts/xunit/xunit.py @@ -0,0 +1,109 @@ +# Copyright (c) 2020-2021, ARM Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import xml.etree.ElementTree as ET +from xml.dom import minidom + + +class xunit_results: + def __init__(self): + self.name = "testsuites" + self.suites = [] + + def create_suite(self, name): + s = xunit_suite(name) + self.suites.append(s) + return s + + def write_results(self, filename): + suites = ET.Element(self.name) + tree = ET.ElementTree(suites) + for s in self.suites: + testsuite = ET.SubElement( + suites, "testsuite", {"name": s.name, "errors": "0"} + ) + tests = 0 + failures = 0 + skip = 0 + for t in s.tests: + test = ET.SubElement( + testsuite, + "testcase", + {"name": t.name, "classname": t.classname, "time": t.time}, + ) + tests += 1 + if t.skip: + skip += 1 + ET.SubElement(test, "skipped", {"type": "Skipped test"}) + if t.fail: + failures += 1 + fail = ET.SubElement(test, "failure", {"type": "Test failed"}) + fail.text = t.fail + if t.sysout: + sysout = ET.SubElement(test, "system-out") + sysout.text = t.sysout + if t.syserr: + syserr = ET.SubElement(test, "system-err") + syserr.text = t.syserr + testsuite.attrib["tests"] = str(tests) + testsuite.attrib["failures"] = str(failures) + testsuite.attrib["skip"] = str(skip) + xmlstr = minidom.parseString(ET.tostring(tree.getroot())).toprettyxml( + indent=" " + ) + with open(filename, "w") as f: + f.write(xmlstr) + + +class xunit_suite: + def __init__(self, name): + self.name = name + self.tests = [] + + +# classname should be of the form suite.class/subclass/subclass2/... It appears +# you can have an unlimited number of subclasses in this manner + + +class xunit_test: + def __init__(self, name, classname=None): + self.name = name + if classname: + self.classname = classname + else: + self.classname = name + self.time = "0.000" + self.fail = None + self.skip = False + self.sysout = None + self.syserr = None + + def failed(self, text): + self.fail = text + + def skipped(self): + self.skip = True + + +if __name__ == "__main__": + r = xunit_results() + s = r.create_suite("selftest") + for i in range(0, 10): + t = xunit_test("atest" + str(i), "selftest") + if i == 3: + t.failed("Unknown failure foo") + if i == 7: + t.skipped() + s.tests.append(t) + r.write_results("foo.xml") diff --git a/test/src/serialization_npy_test.cpp b/test/src/serialization_npy_test.cpp new file mode 100644 index 0000000..27ec464 --- /dev/null +++ b/test/src/serialization_npy_test.cpp @@ -0,0 +1,225 @@ +// Copyright (c) 2021, ARM Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include + +using namespace tosa; + +void usage() +{ + std::cout << "Usage: serialization_npy_test -f -t -d -s " << std::endl; +} + +template +int test_int_type(std::vector shape, std::default_random_engine& gen, std::string& filename) +{ + size_t total_size = 1; + std::uniform_int_distribution gen_data(std::numeric_limits::min(), std::numeric_limits::max()); + + for (auto i : shape) + { + total_size *= i; + } + + auto buffer = std::make_unique(total_size); + for (int i = 0; i < total_size; i++) + { + buffer[i] = gen_data(gen); + } + + NumpyUtilities::NPError err = NumpyUtilities::writeToNpyFile(filename.c_str(), shape, buffer.get()); + if (err != NumpyUtilities::NO_ERROR) + { + std::cout << "Error writing file, code " << err << std::endl; + return 1; + } + + auto read_buffer = std::make_unique(total_size); + err = NumpyUtilities::readFromNpyFile(filename.c_str(), total_size, read_buffer.get()); + if (err != NumpyUtilities::NO_ERROR) + { + std::cout << "Error reading file, code " << err << std::endl; + return 1; + } + if (memcmp(buffer.get(), read_buffer.get(), total_size * sizeof(T))) + { + std::cout << "Miscompare" << std::endl; + return 1; + } + return 0; +} + +template +int test_float_type(std::vector shape, std::default_random_engine& gen, std::string& filename) +{ + size_t total_size = 1; + std::uniform_real_distribution gen_data(std::numeric_limits::min(), std::numeric_limits::max()); + + for (auto i : shape) + { + total_size *= i; + } + + auto buffer = std::make_unique(total_size); + for (int i = 0; i < total_size; i++) + { + buffer[i] = gen_data(gen); + } + + NumpyUtilities::NPError err = NumpyUtilities::writeToNpyFile(filename.c_str(), shape, buffer.get()); + if (err != NumpyUtilities::NO_ERROR) + { + std::cout << "Error writing file, code " << err << std::endl; + return 1; + } + + auto read_buffer = std::make_unique(total_size); + err = NumpyUtilities::readFromNpyFile(filename.c_str(), total_size, read_buffer.get()); + if (err != NumpyUtilities::NO_ERROR) + { + std::cout << "Error reading file, code " << err << std::endl; + return 1; + } + if (memcmp(buffer.get(), read_buffer.get(), total_size * sizeof(T))) + { + std::cout << "Miscompare" << std::endl; + return 1; + } + return 0; +} + +int test_bool_type(std::vector shape, std::default_random_engine& gen, std::string& filename) +{ + size_t total_size = 1; + std::uniform_int_distribution gen_data(0, 1); + + for (auto i : shape) + { + total_size *= i; + } + + auto buffer = std::make_unique(total_size); + for (int i = 0; i < total_size; i++) + { + buffer[i] = (gen_data(gen)) ? true : false; + } + + NumpyUtilities::NPError err = NumpyUtilities::writeToNpyFile(filename.c_str(), shape, buffer.get()); + if (err != NumpyUtilities::NO_ERROR) + { + std::cout << "Error writing file, code " << err << std::endl; + return 1; + } + + auto read_buffer = std::make_unique(total_size); + err = NumpyUtilities::readFromNpyFile(filename.c_str(), total_size, read_buffer.get()); + if (err != NumpyUtilities::NO_ERROR) + { + std::cout << "Error reading file, code " << err << std::endl; + return 1; + } + + if (memcmp(buffer.get(), read_buffer.get(), total_size * sizeof(bool))) + { + std::cout << "Miscompare" << std::endl; + return 1; + } + return 0; +} + +int main(int argc, char** argv) +{ + size_t total_size = 1; + int32_t seed = 1; + std::string str_type; + std::string str_shape; + std::string filename = "npytest.npy"; + std::vector shape; + bool verbose = false; + int opt; + while ((opt = getopt(argc, argv, "d:f:s:t:v")) != -1) + { + switch (opt) + { + case 'd': + str_type = optarg; + break; + case 'f': + filename = optarg; + break; + case 's': + seed = strtol(optarg, nullptr, 0); + break; + case 't': + str_shape = optarg; + break; + case 'v': + verbose = true; + break; + default: + std::cerr << "Invalid argument" << std::endl; + break; + } + } + if (str_shape == "") + { + usage(); + return 1; + } + + // parse shape from argument + std::stringstream ss(str_shape); + while (ss.good()) + { + std::string substr; + size_t pos; + std::getline(ss, substr, ','); + if (substr == "") + break; + int val = stoi(substr, &pos, 0); + assert(val); + total_size *= val; + shape.push_back(val); + } + + std::default_random_engine gen(seed); + + // run with type from argument + if (str_type == "int32") + { + return test_int_type(shape, gen, filename); + } + else if (str_type == "int64") + { + return test_int_type(shape, gen, filename); + } + else if (str_type == "float") + { + return test_float_type(shape, gen, filename); + } + else if (str_type == "bool") + { + return test_bool_type(shape, gen, filename); + } + else + { + std::cout << "Unknown type " << str_type << std::endl; + usage(); + return 1; + } +} diff --git a/test/src/serialization_read_write.cpp b/test/src/serialization_read_write.cpp new file mode 100644 index 0000000..1f29fac --- /dev/null +++ b/test/src/serialization_read_write.cpp @@ -0,0 +1,50 @@ +// Copyright (c) 2021, ARM Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +using namespace tosa; + +void usage() +{ + std::cerr << "Usage: " << std::endl; + std::cerr << " : source TOSA serialize filename" << std::endl; + std::cerr << " : destination TOSA serialized filename" << std::endl; +} + +int main(int argc, char** argv) +{ + TosaSerializationHandler handler; + if (argc != 3) + { + usage(); + return 1; + } + + tosa_err_t err = handler.LoadFileTosaFlatbuffer(argv[1]); + if (err != TOSA_OK) + { + std::cout << "error reading file " << argv[1] << " code " << err << std::endl; + return 1; + } + + err = handler.SaveFileTosaFlatbuffer(argv[2]); + if (err != TOSA_OK) + { + std::cout << "error writing file " << argv[2] << " code " << err << std::endl; + return 1; + } + return 0; +} diff --git a/third_party/CMakeLists.txt b/third_party/CMakeLists.txt new file mode 100644 index 0000000..b26a94c --- /dev/null +++ b/third_party/CMakeLists.txt @@ -0,0 +1,25 @@ +cmake_minimum_required (VERSION 3.13.4) + +# Copyright (c) 2020-2021, ARM Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +set(CMAKE_INSTALL_PREFIX "./third_party" CACHE PATH "..." FORCE) + +project(third_party LANGUAGES CXX) + +# Flatbuffers tests are not needed +set(FLATBUFFERS_BUILD_TESTS OFF) + +add_subdirectory(flatbuffers) diff --git a/third_party/flatbuffers b/third_party/flatbuffers new file mode 160000 index 0000000..6df40a2 --- /dev/null +++ b/third_party/flatbuffers @@ -0,0 +1 @@ +Subproject commit 6df40a2471737b27271bdd9b900ab5f3aec746c7 -- cgit v1.2.1