author    Eric Kunze <eric.kunze@arm.com>  2020-10-13 16:11:07 -0700
committer Kevin Cheng <kevin.cheng@arm.com>  2020-10-14 11:11:43 -0700
commit    e5e2676409a936431f87d31fb74d825257b20804 (patch)
tree      304d93d993ef6417b02a515025f9030367682774
parent    88b7860f180f91b5b66764c61cfd97d8bc53cece (diff)
download  reference_model-e5e2676409a936431f87d31fb74d825257b20804.tar.gz

Initial checkin of TOSA reference_model and tests

Change-Id: I2f8e7fa63e2ae40203e57d2cc8814bde3b312cb6
Signed-off-by: Eric Kunze <eric.kunze@arm.com>
-rw-r--r-- .gitmodules | 6
-rw-r--r-- CMakeLists.txt | 17
-rw-r--r-- NOTICE | 204
-rw-r--r-- README.md | 202
-rw-r--r-- examples/test_add_1x4x4x4_f32/InputTensor-tf0.npy | bin 0 -> 192 bytes
-rw-r--r-- examples/test_add_1x4x4x4_f32/InputTensor-tf1.npy | bin 0 -> 384 bytes
-rw-r--r-- examples/test_add_1x4x4x4_f32/InputTensor-tflite0.npy | bin 0 -> 192 bytes
-rw-r--r-- examples/test_add_1x4x4x4_f32/InputTensor-tflite1.npy | bin 0 -> 384 bytes
-rw-r--r-- examples/test_add_1x4x4x4_f32/flatbuffer-tf/test_add_1x4x4x4_f32.tosa | bin 0 -> 644 bytes
-rw-r--r-- examples/test_add_1x4x4x4_f32/flatbuffer-tflite/test_add_1x4x4x4_f32.tosa | bin 0 -> 676 bytes
-rw-r--r-- examples/test_add_1x4x4x4_f32/placeholder_0.npy | bin 0 -> 192 bytes
-rw-r--r-- examples/test_add_1x4x4x4_f32/placeholder_1.npy | bin 0 -> 384 bytes
-rw-r--r-- examples/test_add_1x4x4x4_f32/ref_ofm.npy | bin 0 -> 384 bytes
-rw-r--r-- examples/test_add_1x4x4x4_f32/test.json | 27
-rw-r--r-- examples/test_add_1x4x4x4_f32/test_tf.pb | 112
-rw-r--r-- examples/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11/InputTensor-tf0.npy | bin 0 -> 32896 bytes
-rw-r--r-- examples/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11/InputTensor-tflite0.npy | bin 0 -> 32896 bytes
-rw-r--r-- examples/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11/flatbuffer-tf/layer_1.npy | bin 0 -> 640 bytes
-rw-r--r-- examples/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11/flatbuffer-tf/layer_2.npy | bin 0 -> 144 bytes
-rw-r--r-- examples/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11/flatbuffer-tf/layer_4.npy | bin 0 -> 192 bytes
-rw-r--r-- examples/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11/flatbuffer-tf/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11.tosa | bin 0 -> 1156 bytes
-rw-r--r-- examples/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11/flatbuffer-tflite/layer_1.npy | bin 0 -> 192 bytes
-rw-r--r-- examples/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11/flatbuffer-tflite/layer_2.npy | bin 0 -> 640 bytes
-rw-r--r-- examples/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11/flatbuffer-tflite/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11.tosa | bin 0 -> 896 bytes
-rw-r--r-- examples/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11/placeholder_0.npy | bin 0 -> 32896 bytes
-rw-r--r-- examples/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11/ref_ofm.npy | bin 0 -> 65664 bytes
-rw-r--r-- examples/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11/test.json | 19
-rw-r--r-- examples/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11/test_tf.pb | 163
-rw-r--r-- examples/test_fakequant_conv2d_1x1_1x32x32x8_qi8_st11_padSAME_dilat11/InputTensor-tflite0.npy | bin 0 -> 32896 bytes
-rw-r--r-- examples/test_fakequant_conv2d_1x1_1x32x32x8_qi8_st11_padSAME_dilat11/flatbuffer-tflite/layer_1.npy | bin 0 -> 640 bytes
-rw-r--r-- examples/test_fakequant_conv2d_1x1_1x32x32x8_qi8_st11_padSAME_dilat11/flatbuffer-tflite/layer_2.npy | bin 0 -> 192 bytes
-rw-r--r-- examples/test_fakequant_conv2d_1x1_1x32x32x8_qi8_st11_padSAME_dilat11/flatbuffer-tflite/test_fakequant_conv2d_1x1_1x32x32x8_qi8_st11_padSAME_dilat11.tosa | bin 0 -> 1240 bytes
-rw-r--r-- examples/test_fakequant_conv2d_1x1_1x32x32x8_qi8_st11_padSAME_dilat11/placeholder_0.npy | bin 0 -> 32896 bytes
-rw-r--r-- examples/test_fakequant_conv2d_1x1_1x32x32x8_qi8_st11_padSAME_dilat11/placeholder_0_quant.npy | bin 0 -> 32896 bytes
-rw-r--r-- examples/test_fakequant_conv2d_1x1_1x32x32x8_qi8_st11_padSAME_dilat11/ref_ofm.npy | bin 0 -> 16512 bytes
-rw-r--r-- examples/test_fakequant_conv2d_1x1_1x32x32x8_qi8_st11_padSAME_dilat11/test.json | 27
-rw-r--r-- examples/test_fakequant_conv2d_1x1_1x32x32x8_qi8_st11_padSAME_dilat11/test_tf.pb | 312
-rw-r--r-- reference_model/CMakeLists.txt | 76
-rw-r--r-- reference_model/src/arith_util.h | 194
-rw-r--r-- reference_model/src/debug_modes.def | 20
-rw-r--r-- reference_model/src/debug_types.h | 57
-rw-r--r-- reference_model/src/func_config.cc | 632
-rw-r--r-- reference_model/src/func_config.def | 90
-rw-r--r-- reference_model/src/func_config.h | 55
-rw-r--r-- reference_model/src/func_debug.cc | 436
-rw-r--r-- reference_model/src/func_debug.h | 255
-rw-r--r-- reference_model/src/graph_node.cc | 226
-rw-r--r-- reference_model/src/graph_node.h | 354
-rw-r--r-- reference_model/src/main.cpp | 295
-rw-r--r-- reference_model/src/model_common.h | 28
-rw-r--r-- reference_model/src/ops/activation_funcs.cc | 118
-rw-r--r-- reference_model/src/ops/activation_funcs.h | 101
-rw-r--r-- reference_model/src/ops/comparison.cc | 81
-rw-r--r-- reference_model/src/ops/comparison.h | 71
-rw-r--r-- reference_model/src/ops/control_flow.cc | 353
-rw-r--r-- reference_model/src/ops/control_flow.h | 72
-rw-r--r-- reference_model/src/ops/custom.cc | 40
-rw-r--r-- reference_model/src/ops/custom.h | 38
-rw-r--r-- reference_model/src/ops/data_layout.cc | 644
-rw-r--r-- reference_model/src/ops/data_layout.h | 216
-rw-r--r-- reference_model/src/ops/data_nodes.cc | 172
-rw-r--r-- reference_model/src/ops/data_nodes.h | 86
-rw-r--r-- reference_model/src/ops/ewise_binary.cc | 586
-rw-r--r-- reference_model/src/ops/ewise_binary.h | 195
-rw-r--r-- reference_model/src/ops/ewise_ternary.cc | 115
-rw-r--r-- reference_model/src/ops/ewise_ternary.h | 83
-rw-r--r-- reference_model/src/ops/ewise_unary.cc | 302
-rw-r--r-- reference_model/src/ops/ewise_unary.h | 102
-rw-r--r-- reference_model/src/ops/image.cc | 169
-rw-r--r-- reference_model/src/ops/image.h | 53
-rw-r--r-- reference_model/src/ops/op_factory.cc | 432
-rw-r--r-- reference_model/src/ops/op_factory.h | 294
-rw-r--r-- reference_model/src/ops/reduction.cc | 139
-rw-r--r-- reference_model/src/ops/reduction.h | 109
-rw-r--r-- reference_model/src/ops/scatter_gather.cc | 120
-rw-r--r-- reference_model/src/ops/scatter_gather.h | 54
-rw-r--r-- reference_model/src/ops/template_types.h | 277
-rw-r--r-- reference_model/src/ops/tensor_ops.cc | 1229
-rw-r--r-- reference_model/src/ops/tensor_ops.h | 253
-rw-r--r-- reference_model/src/ops/type_conversion.cc | 299
-rw-r--r-- reference_model/src/ops/type_conversion.h | 162
-rw-r--r-- reference_model/src/quant_util.h | 103
-rw-r--r-- reference_model/src/subgraph_traverser.cc | 649
-rw-r--r-- reference_model/src/subgraph_traverser.h | 90
-rw-r--r-- reference_model/src/tensor.cc | 3008
-rw-r--r-- reference_model/src/tensor.h | 815
-rw-r--r-- scripts/xunit/xunit.py | 91
-rw-r--r-- serialization/CMakeLists.txt | 32
-rw-r--r-- serialization/attribute.def | 90
-rw-r--r-- serialization/attribute.h | 181
-rw-r--r-- serialization/operator.def | 123
-rw-r--r-- serialization/quant_info.def | 43
-rw-r--r-- serialization/quant_info.h | 164
-rw-r--r-- serialization/tosa.fbs | 318
-rw-r--r-- serialization/tosa_generated.h | 2605
-rw-r--r-- serialization/tosa_serialization_handler.cpp | 1526
-rw-r--r-- serialization/tosa_serialization_handler.h | 423
-rw-r--r-- thirdparty/CMakeLists.txt | 10
m--------- thirdparty/eigen | 0
m--------- thirdparty/flatbuffers | 0
-rw-r--r-- verif/tosa/Attribute.py | 36
-rw-r--r-- verif/tosa/AxisAttribute.py | 45
-rw-r--r-- verif/tosa/ClampAttribute.py | 69
-rw-r--r-- verif/tosa/CondIfAttribute.py | 53
-rw-r--r-- verif/tosa/Conv2dAttribute.py | 109
-rw-r--r-- verif/tosa/ConvQuantInfo.py | 53
-rw-r--r-- verif/tosa/CustomAttribute.py | 45
-rw-r--r-- verif/tosa/DType.py | 31
-rw-r--r-- verif/tosa/Format.py | 27
-rw-r--r-- verif/tosa/MatMulQuantInfo.py | 53
-rw-r--r-- verif/tosa/Op.py | 90
-rw-r--r-- verif/tosa/PadQuantInfo.py | 45
-rw-r--r-- verif/tosa/Pool2dAttribute.py | 109
-rw-r--r-- verif/tosa/QuantInfo.py | 26
-rw-r--r-- verif/tosa/README.md | 14
-rw-r--r-- verif/tosa/ReluNAttribute.py | 53
-rw-r--r-- verif/tosa/RescaleAttribute.py | 125
-rw-r--r-- verif/tosa/ReshapeAttribute.py | 61
-rw-r--r-- verif/tosa/ResizeAttribute.py | 125
-rw-r--r-- verif/tosa/ResizeMode.py | 24
-rw-r--r-- verif/tosa/SliceAttribute.py | 85
-rw-r--r-- verif/tosa/TileAttribute.py | 61
-rw-r--r-- verif/tosa/TosaBasicBlock.py | 123
-rw-r--r-- verif/tosa/TosaGraph.py | 71
-rw-r--r-- verif/tosa/TosaOperator.py | 117
-rw-r--r-- verif/tosa/TosaTensor.py | 133
-rw-r--r-- verif/tosa/TransposeConv2dAttribute.py | 133
-rw-r--r-- verif/tosa/UnaryQuantInfo.py | 53
-rw-r--r-- verif/tosa/Usage.py | 25
-rw-r--r-- verif/tosa/Version.py | 69
-rw-r--r-- verif/tosa/WhileLoopAttribute.py | 53
-rw-r--r-- verif/tosa/__init__.py | 15
-rw-r--r-- verif/tosa_ref_run.py | 66
-rw-r--r-- verif/tosa_serializer.py | 718
-rw-r--r-- verif/tosa_test_gen.py | 2301
-rw-r--r-- verif/tosa_test_runner.py | 63
-rwxr-xr-x verif/tosa_verif_build_tests.py | 136
-rwxr-xr-x verif/tosa_verif_run_ref.py | 198
138 files changed, 26656 insertions, 0 deletions
diff --git a/.gitmodules b/.gitmodules
new file mode 100644
index 0000000..9a6276e
--- /dev/null
+++ b/.gitmodules
@@ -0,0 +1,6 @@
+[submodule "thirdparty/eigen"]
+ path = thirdparty/eigen
+ url = https://gitlab.com/libeigen/eigen.git
+[submodule "thirdparty/flatbuffers"]
+ path = thirdparty/flatbuffers
+ url = https://github.com/google/flatbuffers
diff --git a/CMakeLists.txt b/CMakeLists.txt
new file mode 100644
index 0000000..19c5824
--- /dev/null
+++ b/CMakeLists.txt
@@ -0,0 +1,17 @@
+cmake_minimum_required (VERSION 3.4)
+
+set(CMAKE_INSTALL_PREFIX ".")
+project(tosa_tools LANGUAGES CXX)
+
+option(TOSA_TOOLS_BUILD_SERIALIZATION "Enable building of Tosa Serialization Library" ON)
+option(TOSA_TOOLS_BUILD_REFERENCE_MODEL "Enable building of Tosa Reference Model" ON)
+
+add_subdirectory(thirdparty)
+
+if(TOSA_TOOLS_BUILD_SERIALIZATION)
+ add_subdirectory(serialization)
+endif()
+
+if(TOSA_TOOLS_BUILD_REFERENCE_MODEL)
+ add_subdirectory(reference_model)
+endif()
diff --git a/NOTICE b/NOTICE
new file mode 100644
index 0000000..a350989
--- /dev/null
+++ b/NOTICE
@@ -0,0 +1,204 @@
+Copyright (c) 2020 ARM Limited.
+
+
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "{}"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright {yyyy} {name of copyright owner}
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..47901e5
--- /dev/null
+++ b/README.md
@@ -0,0 +1,202 @@
+TOSA Reference Model
+=============
+
+# Introduction
+
+The [*Tensor Operator Set Architecture (TOSA) Specification*](https://git.mlplatform.org/tosa/specification.git/)
+is a set of operators
+with defined accuracy and compatibility constraints that Arm expects
+to be implemented on its Neural Processing Units (NPUs). Most
+operators from the common ML frameworks (TensorFlow, PyTorch, etc.)
+should be expressible in TOSA. TOSA is focused on inference, leaving
+training to the original frameworks.
+
+The *TOSA Reference Model* package provides a reference implementation
+and testing infrastructure for TOSA. The reference model consumes a
+FlatBuffers serialization of the network subgraph generated by the
+TOSA Serialization Library, along with input tensors for placeholder
+nodes in NumPy format. By default, the model validates and evaluates
+the network subgraph, and writes out the resulting output tensors in
+NumPy format.
+
+# Installation Requirements
+
+The *TOSA Reference Model* and testing suite require the following
+tools:
+
+* CMake version 3.4 or later
+* GNU Make 4.1 or later
+* GCC (tested with 7.5.0) or Clang C++ compiler (tested with clang-9)
+ with C++17 support
+
+The model includes the TOSA Serialization Library, Eigen 3.3.7, and
+FlatBuffers 1.11.0 as git submodules. The model is written using
+C++17 and has been primarily tested on Ubuntu x86_64 18.04 LTS Linux
+systems.
+
+The testing infrastructure requires:
+* Python 3.6 or later
+* TensorFlow 2.3 or later
+* NumPy 1.15 or later
+
+Check out the required git submodules with:
+
+``` bash
+$ git submodule init
+$ git submodule update
+```
+
+# Compilation
+
+The *TOSA Reference Model* build can be prepared by creating makefiles using CMake:
+
+``` bash
+$ mkdir -p build
+$ cd build
+$ cmake ..
+```
+
+Optionally, `-DCMAKE_BUILD_TYPE=Debug` can be used on the `cmake`
+command line to create a debug build. Next compile using `make`:
+
+``` bash
+$ make
+```
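+
+As a minimal sketch (assuming a clean `build` directory), the full
+debug configure-and-build sequence relies on the standard CMake build
+type variable:
+
+``` bash
+$ mkdir -p build
+$ cd build
+$ cmake -DCMAKE_BUILD_TYPE=Debug ..
+$ make
+```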
+
+The resulting executable will be named
+`reference_model/tosa_reference_model`. CMake only needs to be re-run
+if the build environment changes (e.g., new dependencies or source
+files). After code changes that do not affect these build rules,
+simply re-run `make` to rebuild.
+
+# Usage
+
+The inputs to the *TOSA Reference Model* consist of a FlatBuffers file
+containing the serialized subgraph, a sequence of placeholder node
+name/input tensor NumPy file pairs (produced by an external tool), and
+a prefix for output tensor NumPy files (produced by the reference model).
+
+An example command is shown below:
+
+``` bash
+$ mkdir -p examples_out/test_add_1x4x4x4_f32
+$ ./build/reference_model/tosa_reference_model \
+ -Csubgraph_dir=examples/test_add_1x4x4x4_f32/flatbuffer-tflite \
+ -Csubgraph_file=test_add_1x4x4x4_f32.tosa \
+ -Cinput_dir=examples/test_add_1x4x4x4_f32/ \
+ -Coutput_dir=examples_out/test_add_1x4x4x4_f32/ \
+ -Coutput_tensor_prefix=ref_model_tflite_ \
+ -Cinput_tensor=InputTensor-tflite0:InputTensor-tflite0.npy,InputTensor-tflite1:InputTensor-tflite1.npy
+```
+
+On a successful execution, the output tensors are written in NumPy
+format into the directory given by `-Coutput_dir`, with each output
+file name prefixed by `-Coutput_tensor_prefix`.
+
+When using JSON-formatted FlatBuffers input (`.json` extension), the
+FlatBuffers schema file from the TOSA Serialization Library must be
+specified using `-Coperator_fbs=`. When using the binary FlatBuffers
+format (`.tosa`), the schema is not necessary.
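+
+For instance, a run using a JSON-formatted subgraph might look like
+the following sketch. The directory and file names here
+(`my_test_dir`, `subgraph.json`, `my_test_out`, `ref_model_`,
+`InputTensor0`) are purely illustrative; only the schema path
+`serialization/tosa.fbs` refers to a file shipped in this tree:
+
+``` bash
+$ ./build/reference_model/tosa_reference_model \
+   -Coperator_fbs=serialization/tosa.fbs \
+   -Csubgraph_dir=my_test_dir \
+   -Csubgraph_file=subgraph.json \
+   -Cinput_dir=my_test_dir/ \
+   -Coutput_dir=my_test_out/ \
+   -Coutput_tensor_prefix=ref_model_ \
+   -Cinput_tensor=InputTensor0:InputTensor0.npy
+```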
+
+## Examples
+
+The TOSA Reference Model distribution contains several example
+networks in the `examples` directory, with inputs and reference
+outputs generated by TensorFlow or TensorFlow Lite.
+
+These examples can be run through the TOSA Reference Model and should
+produce the equivalent TOSA-compliant reference output tensors.
+Please note that differences in floating-point ordering and rounding
+may cause small differences in output for floating-point tests, and
+differences in quantized scaling between TensorFlow Lite and the TOSA
+Specification may cause differences in quantized integer tests.
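+
+For example, the conv2d example can be run in the same manner as the
+add example shown in the Usage section (a sketch patterned on that
+command; the input tensor name is assumed to match the basename of the
+checked-in `InputTensor-tflite0.npy` file, as it does for the add test):
+
+``` bash
+$ mkdir -p examples_out/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11
+$ ./build/reference_model/tosa_reference_model \
+   -Csubgraph_dir=examples/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11/flatbuffer-tflite \
+   -Csubgraph_file=test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11.tosa \
+   -Cinput_dir=examples/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11/ \
+   -Coutput_dir=examples_out/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11/ \
+   -Coutput_tensor_prefix=ref_model_tflite_ \
+   -Cinput_tensor=InputTensor-tflite0:InputTensor-tflite0.npy
+```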
+
+# Debugging
+
+The debugging facility can be enabled by setting a debug scope and
+debug level on the command line. For most purposes, the following
+flags will work: `-dALL -lHIGH`. Debug output can be directed to a
+file using the `-o` switch.
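+
+For example, the add example from the Usage section can be re-run with
+full debugging enabled and the debug output captured to a file (a
+sketch combining the flags above; `ref_debug.log` is an arbitrary file
+name):
+
+``` bash
+$ ./build/reference_model/tosa_reference_model \
+   -Csubgraph_dir=examples/test_add_1x4x4x4_f32/flatbuffer-tflite \
+   -Csubgraph_file=test_add_1x4x4x4_f32.tosa \
+   -Cinput_dir=examples/test_add_1x4x4x4_f32/ \
+   -Coutput_dir=examples_out/test_add_1x4x4x4_f32/ \
+   -Coutput_tensor_prefix=ref_model_tflite_ \
+   -Cinput_tensor=InputTensor-tflite0:InputTensor-tflite0.npy,InputTensor-tflite1:InputTensor-tflite1.npy \
+   -dALL -lHIGH -o ref_debug.log
+```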
+
+# TOSA Unit Test Infrastructure
+
+The TOSA Unit Test infrastructure builds and runs self-contained tests
+for implementations of the *Tensor Operator Set Architecture (TOSA)
+Specification*. These tools directly generate TOSA operators for
+verification of the TOSA reference model against existing frameworks
+or other operator implementations.
+
+The test builder tool generates tests with random arguments and
+reference inputs for each TOSA operator. Currently, the test builder
+focuses on generating a wide range of legal arguments to each
+operator, but it also has limited support for generating tests with
+illegal arguments in order to make sure such usages are properly
+detected.
+
+The unit tests are typically structured as a combination of input
+placeholder nodes, const nodes, and attributes feeding into a single
+TOSA operator. The unit tests use Python bindings for the FlatBuffers
+schema, generated by ``flatc`` into ``verif/tosa``.
+
+Each test has a JSON file which provides machine-readable metadata for
+the test, including the `.tosa` FlatBuffers file and the names, shapes,
+and NumPy filenames for each input and output tensor. There is also a
+boolean value indicating whether the test is expected to fail because
+it deliberately uses an invalid set of operands or attributes.
+
+The test runner tool executes the unit tests on the TOSA Reference
+Model to generate reference output tensor values (for legal tests).
+The test runner is a modular tool which can be extended to run the same
+tests on additional tools or frameworks. The reference output NumPy
+files are generated by this step and can be programmatically compared
+with the output of other tools in order to validate those tools.
+
+## Usage
+
+### Unit Test Builder
+The test builder is in ``verif/tosa_verif_build_tests.py``. The
+builder generates test outputs in ``./vtest/<operator_name>/`` by
+default. To restrict test generation to operators matching a particular
+regular expression, use the ``--filter`` argument. The tool can be run
+with no arguments to generate all tests.
+
+Inputs and certain attributes are created using a random number
+generator, while others are covered exhaustively (within reasonable
+bounds) where the combinatorics allow it. The test generation is
+deterministic for a given random seed, but additional tests can be
+generated using ``--seed``. Because corner-case errors are often
+uncovered by creative tensor shapes, varying the random seed helps get
+coverage of additional shapes.
+
+Additional parameters on some operators can be found in the command
+line help.
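+
+For example, to generate only the tests for the ADD operator with a
+specific random seed (a sketch; the filter string and seed value are
+arbitrary):
+
+``` bash
+$ ./verif/tosa_verif_build_tests.py --filter add --seed 42
+```
+
+The generated tests are written under ``./vtest/<operator_name>/`` as
+described above.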
+
+### Unit Test Runner
+
+The unit test runner script takes self-contained unit tests from the
+builder and runs them on the reference model. Shell wildcards can be
+used to run more than one test at a time, and tests can be run in
+parallel using the ``-j`` switch. For example, to run all of the
+add operator tests:
+
+``` bash
+$ ./verif/tosa_verif_run_ref.py -t vtest/add/add* -j 8
+```
+
+The test runner is quiet by default, so running a large number of
+tests without any obvious errors will show no output while the tests
+are running. The ``-v`` switch will show the command being run in the
+background.
+
+To enable debugging on the reference model, the shortcut options
+``--ref-debug=high`` and ``--ref-intermediates`` are provided to turn
+on debugging and to dump intermediate tensor values, respectively.
+
+Additional Systems Under Test (SUTs), such as reference
+implementations of operators, full frameworks, etc., can be defined by
+extending the TosaTestRunner class. The SUTs can then be enabled by
+using the ``--sut-module`` flag.
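+
+For example, an additional SUT module could be selected at run time
+with something like the following sketch (the module name
+``tosa_mysut_run`` is purely illustrative, and the flag usage assumes
+``--sut-module`` takes the module name as its value):
+
+``` bash
+$ ./verif/tosa_verif_run_ref.py --sut-module tosa_mysut_run -t vtest/add/add* -j 8
+```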
+
+# License
+
+The *TOSA Reference Model* and TOSA Unit Tests are licensed under Apache-2.0.
diff --git a/examples/test_add_1x4x4x4_f32/InputTensor-tf0.npy b/examples/test_add_1x4x4x4_f32/InputTensor-tf0.npy
new file mode 100644
index 0000000..1b3effb
--- /dev/null
+++ b/examples/test_add_1x4x4x4_f32/InputTensor-tf0.npy
Binary files differ
diff --git a/examples/test_add_1x4x4x4_f32/InputTensor-tf1.npy b/examples/test_add_1x4x4x4_f32/InputTensor-tf1.npy
new file mode 100644
index 0000000..f233cd4
--- /dev/null
+++ b/examples/test_add_1x4x4x4_f32/InputTensor-tf1.npy
Binary files differ
diff --git a/examples/test_add_1x4x4x4_f32/InputTensor-tflite0.npy b/examples/test_add_1x4x4x4_f32/InputTensor-tflite0.npy
new file mode 100644
index 0000000..1b3effb
--- /dev/null
+++ b/examples/test_add_1x4x4x4_f32/InputTensor-tflite0.npy
Binary files differ
diff --git a/examples/test_add_1x4x4x4_f32/InputTensor-tflite1.npy b/examples/test_add_1x4x4x4_f32/InputTensor-tflite1.npy
new file mode 100644
index 0000000..f233cd4
--- /dev/null
+++ b/examples/test_add_1x4x4x4_f32/InputTensor-tflite1.npy
Binary files differ
diff --git a/examples/test_add_1x4x4x4_f32/flatbuffer-tf/test_add_1x4x4x4_f32.tosa b/examples/test_add_1x4x4x4_f32/flatbuffer-tf/test_add_1x4x4x4_f32.tosa
new file mode 100644
index 0000000..673efdb
--- /dev/null
+++ b/examples/test_add_1x4x4x4_f32/flatbuffer-tf/test_add_1x4x4x4_f32.tosa
Binary files differ
diff --git a/examples/test_add_1x4x4x4_f32/flatbuffer-tflite/test_add_1x4x4x4_f32.tosa b/examples/test_add_1x4x4x4_f32/flatbuffer-tflite/test_add_1x4x4x4_f32.tosa
new file mode 100644
index 0000000..8035292
--- /dev/null
+++ b/examples/test_add_1x4x4x4_f32/flatbuffer-tflite/test_add_1x4x4x4_f32.tosa
Binary files differ
diff --git a/examples/test_add_1x4x4x4_f32/placeholder_0.npy b/examples/test_add_1x4x4x4_f32/placeholder_0.npy
new file mode 100644
index 0000000..1b3effb
--- /dev/null
+++ b/examples/test_add_1x4x4x4_f32/placeholder_0.npy
Binary files differ
diff --git a/examples/test_add_1x4x4x4_f32/placeholder_1.npy b/examples/test_add_1x4x4x4_f32/placeholder_1.npy
new file mode 100644
index 0000000..f233cd4
--- /dev/null
+++ b/examples/test_add_1x4x4x4_f32/placeholder_1.npy
Binary files differ
diff --git a/examples/test_add_1x4x4x4_f32/ref_ofm.npy b/examples/test_add_1x4x4x4_f32/ref_ofm.npy
new file mode 100644
index 0000000..71e2008
--- /dev/null
+++ b/examples/test_add_1x4x4x4_f32/ref_ofm.npy
Binary files differ
diff --git a/examples/test_add_1x4x4x4_f32/test.json b/examples/test_add_1x4x4x4_f32/test.json
new file mode 100644
index 0000000..cd993ef
--- /dev/null
+++ b/examples/test_add_1x4x4x4_f32/test.json
@@ -0,0 +1,27 @@
+{
+ "pb": "test_tf.pb",
+ "ofm_file": "ref_ofm.npy",
+ "ofm_placeholder": "result",
+ "ifm_file": [
+ "placeholder_0.npy",
+ "placeholder_1.npy"
+ ],
+ "ifm_placeholder": [
+ "placeholder_0:0",
+ "placeholder_1:0"
+ ],
+ "ifm_shape": [
+ [
+ 1,
+ 4,
+ 4,
+ 1
+ ],
+ [
+ 1,
+ 4,
+ 4,
+ 4
+ ]
+ ]
+}
\ No newline at end of file
diff --git a/examples/test_add_1x4x4x4_f32/test_tf.pb b/examples/test_add_1x4x4x4_f32/test_tf.pb
new file mode 100644
index 0000000..dae00ee
--- /dev/null
+++ b/examples/test_add_1x4x4x4_f32/test_tf.pb
@@ -0,0 +1,112 @@
+node {
+ name: "keras_learning_phase/input"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_BOOL
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_BOOL
+ tensor_shape {
+ }
+ bool_val: false
+ }
+ }
+ }
+}
+node {
+ name: "keras_learning_phase"
+ op: "PlaceholderWithDefault"
+ input: "keras_learning_phase/input"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_BOOL
+ }
+ }
+ attr {
+ key: "shape"
+ value {
+ shape {
+ }
+ }
+ }
+}
+node {
+ name: "placeholder_0"
+ op: "Placeholder"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "shape"
+ value {
+ shape {
+ dim {
+ size: 1
+ }
+ dim {
+ size: 4
+ }
+ dim {
+ size: 4
+ }
+ dim {
+ size: 1
+ }
+ }
+ }
+ }
+}
+node {
+ name: "placeholder_1"
+ op: "Placeholder"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "shape"
+ value {
+ shape {
+ dim {
+ size: 1
+ }
+ dim {
+ size: 4
+ }
+ dim {
+ size: 4
+ }
+ dim {
+ size: 4
+ }
+ }
+ }
+ }
+}
+node {
+ name: "result"
+ op: "Add"
+ input: "placeholder_0"
+ input: "placeholder_1"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+versions {
+ producer: 498
+}
diff --git a/examples/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11/InputTensor-tf0.npy b/examples/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11/InputTensor-tf0.npy
new file mode 100644
index 0000000..328dbd8
--- /dev/null
+++ b/examples/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11/InputTensor-tf0.npy
Binary files differ
diff --git a/examples/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11/InputTensor-tflite0.npy b/examples/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11/InputTensor-tflite0.npy
new file mode 100644
index 0000000..328dbd8
--- /dev/null
+++ b/examples/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11/InputTensor-tflite0.npy
Binary files differ
diff --git a/examples/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11/flatbuffer-tf/layer_1.npy b/examples/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11/flatbuffer-tf/layer_1.npy
new file mode 100644
index 0000000..79bf0ec
--- /dev/null
+++ b/examples/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11/flatbuffer-tf/layer_1.npy
Binary files differ
diff --git a/examples/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11/flatbuffer-tf/layer_2.npy b/examples/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11/flatbuffer-tf/layer_2.npy
new file mode 100644
index 0000000..42ff6d7
--- /dev/null
+++ b/examples/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11/flatbuffer-tf/layer_2.npy
Binary files differ
diff --git a/examples/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11/flatbuffer-tf/layer_4.npy b/examples/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11/flatbuffer-tf/layer_4.npy
new file mode 100644
index 0000000..ec9d526
--- /dev/null
+++ b/examples/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11/flatbuffer-tf/layer_4.npy
Binary files differ
diff --git a/examples/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11/flatbuffer-tf/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11.tosa b/examples/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11/flatbuffer-tf/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11.tosa
new file mode 100644
index 0000000..42f11de
--- /dev/null
+++ b/examples/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11/flatbuffer-tf/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11.tosa
Binary files differ
diff --git a/examples/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11/flatbuffer-tflite/layer_1.npy b/examples/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11/flatbuffer-tflite/layer_1.npy
new file mode 100644
index 0000000..ec9d526
--- /dev/null
+++ b/examples/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11/flatbuffer-tflite/layer_1.npy
Binary files differ
diff --git a/examples/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11/flatbuffer-tflite/layer_2.npy b/examples/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11/flatbuffer-tflite/layer_2.npy
new file mode 100644
index 0000000..207be76
--- /dev/null
+++ b/examples/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11/flatbuffer-tflite/layer_2.npy
Binary files differ
diff --git a/examples/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11/flatbuffer-tflite/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11.tosa b/examples/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11/flatbuffer-tflite/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11.tosa
new file mode 100644
index 0000000..105b755
--- /dev/null
+++ b/examples/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11/flatbuffer-tflite/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11.tosa
Binary files differ
diff --git a/examples/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11/placeholder_0.npy b/examples/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11/placeholder_0.npy
new file mode 100644
index 0000000..328dbd8
--- /dev/null
+++ b/examples/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11/placeholder_0.npy
Binary files differ
diff --git a/examples/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11/ref_ofm.npy b/examples/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11/ref_ofm.npy
new file mode 100644
index 0000000..d0ee52f
--- /dev/null
+++ b/examples/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11/ref_ofm.npy
Binary files differ
diff --git a/examples/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11/test.json b/examples/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11/test.json
new file mode 100644
index 0000000..5962306
--- /dev/null
+++ b/examples/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11/test.json
@@ -0,0 +1,19 @@
+{
+ "pb": "test_tf.pb",
+ "ofm_file": "ref_ofm.npy",
+ "ofm_placeholder": "result",
+ "ifm_file": [
+ "placeholder_0.npy"
+ ],
+ "ifm_placeholder": [
+ "placeholder_0:0"
+ ],
+ "ifm_shape": [
+ [
+ 1,
+ 32,
+ 32,
+ 8
+ ]
+ ]
+}
\ No newline at end of file
diff --git a/examples/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11/test_tf.pb b/examples/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11/test_tf.pb
new file mode 100644
index 0000000..3b00d26
--- /dev/null
+++ b/examples/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11/test_tf.pb
@@ -0,0 +1,163 @@
+node {
+ name: "keras_learning_phase/input"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_BOOL
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_BOOL
+ tensor_shape {
+ }
+ bool_val: false
+ }
+ }
+ }
+}
+node {
+ name: "keras_learning_phase"
+ op: "PlaceholderWithDefault"
+ input: "keras_learning_phase/input"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_BOOL
+ }
+ }
+ attr {
+ key: "shape"
+ value {
+ shape {
+ }
+ }
+ }
+}
+node {
+ name: "placeholder_0"
+ op: "Placeholder"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "shape"
+ value {
+ shape {
+ dim {
+ size: 1
+ }
+ dim {
+ size: 32
+ }
+ dim {
+ size: 32
+ }
+ dim {
+ size: 8
+ }
+ }
+ }
+ }
+}
+node {
+ name: "const_1"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ dim {
+ size: 1
+ }
+ dim {
+ size: 1
+ }
+ dim {
+ size: 8
+ }
+ dim {
+ size: 16
+ }
+ }
+ tensor_content: "\234\225\213?\343\017\327>\324\256K\277\371\365->9KC??q\254?Gp\031\276\257I\305\276.UL\277\336\263\260?\272\245#?\2343\337\277\016\312\363\277I\003\353\277\232F\325\277\255\264\301\277t@A\277\311#;\276D-\374?\347/\322?\026U\242?3E\217?\217\225\206?\204v\360\277\033[\217?\310\033\260\276\273\277\331\276w{\257?v\315\310?u\252\343?0\254\222?m\004\274?4\3420\277n\016\261\277D\362\202\276\251}j?\225\242\014\277\240\322\300?\233\205\202?\362\375\362?\377\302\222?\337\t\275=\207\021\r\277\"4\337\277\351?\354\276\037\362f>\326Z\214>2o\364?\201\357\357;\2703\332?[;J\277Ta\253\277\335}\324\277\354) ?&5\253?\353\257\231?\031~\242\277\225i\006\277\025\225\343?\rs\227\274\007\367L?uH\231?\303\027\305\276\315\247\260=_x\374?6\305\310\277\017\337\350?\206@\035\276\217\235\355?\300d\323?H*\266\277\223\300\332\277\325\251\256?9j\007\275\356\025T?\006\240\302\2779)N?\3212\217\274\233p\027\277\312a\212>\265ly?\310?\247?\345%\356\2767\257\342\277\345\276\226=\367\202\267\276!\254\307=\375\326!?\017\256\274\273\006\321\201?@0\242\277\325\333\355?\353\030\212\276\025|\000>\224\2353\277\270\250@\277\213\014\246\277\275q-\277\225\370\366?7\033\363\276\331\246\022?\306\3439\277\242\334\030\277(K\375?W\322G>d\350K\277\236\2754?\177\262\252\277\024\327K>\221\220\306?D\244\304\276\243 \342?t\326\026?u7\363?\365\000\200\276*\215\361\276?\256\035\277\210o\266\277&\026\"\277\342\021i\276;m\317\276\373N\037?T?\205?$\340\361?\230\032\034\277\004\235*?\225\254\324?\263\207\016>"
+ }
+ }
+ }
+}
+node {
+ name: "result"
+ op: "Conv2D"
+ input: "placeholder_0"
+ input: "const_1"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "data_format"
+ value {
+ s: "NHWC"
+ }
+ }
+ attr {
+ key: "dilations"
+ value {
+ list {
+ i: 1
+ i: 1
+ i: 1
+ i: 1
+ }
+ }
+ }
+ attr {
+ key: "explicit_paddings"
+ value {
+ list {
+ }
+ }
+ }
+ attr {
+ key: "padding"
+ value {
+ s: "SAME"
+ }
+ }
+ attr {
+ key: "strides"
+ value {
+ list {
+ i: 1
+ i: 1
+ i: 1
+ i: 1
+ }
+ }
+ }
+ attr {
+ key: "use_cudnn_on_gpu"
+ value {
+ b: true
+ }
+ }
+}
+versions {
+ producer: 498
+}
diff --git a/examples/test_fakequant_conv2d_1x1_1x32x32x8_qi8_st11_padSAME_dilat11/InputTensor-tflite0.npy b/examples/test_fakequant_conv2d_1x1_1x32x32x8_qi8_st11_padSAME_dilat11/InputTensor-tflite0.npy
new file mode 100644
index 0000000..c9e9b19
--- /dev/null
+++ b/examples/test_fakequant_conv2d_1x1_1x32x32x8_qi8_st11_padSAME_dilat11/InputTensor-tflite0.npy
Binary files differ
diff --git a/examples/test_fakequant_conv2d_1x1_1x32x32x8_qi8_st11_padSAME_dilat11/flatbuffer-tflite/layer_1.npy b/examples/test_fakequant_conv2d_1x1_1x32x32x8_qi8_st11_padSAME_dilat11/flatbuffer-tflite/layer_1.npy
new file mode 100644
index 0000000..0bbfb0c
--- /dev/null
+++ b/examples/test_fakequant_conv2d_1x1_1x32x32x8_qi8_st11_padSAME_dilat11/flatbuffer-tflite/layer_1.npy
Binary files differ
diff --git a/examples/test_fakequant_conv2d_1x1_1x32x32x8_qi8_st11_padSAME_dilat11/flatbuffer-tflite/layer_2.npy b/examples/test_fakequant_conv2d_1x1_1x32x32x8_qi8_st11_padSAME_dilat11/flatbuffer-tflite/layer_2.npy
new file mode 100644
index 0000000..67c2421
--- /dev/null
+++ b/examples/test_fakequant_conv2d_1x1_1x32x32x8_qi8_st11_padSAME_dilat11/flatbuffer-tflite/layer_2.npy
Binary files differ
diff --git a/examples/test_fakequant_conv2d_1x1_1x32x32x8_qi8_st11_padSAME_dilat11/flatbuffer-tflite/test_fakequant_conv2d_1x1_1x32x32x8_qi8_st11_padSAME_dilat11.tosa b/examples/test_fakequant_conv2d_1x1_1x32x32x8_qi8_st11_padSAME_dilat11/flatbuffer-tflite/test_fakequant_conv2d_1x1_1x32x32x8_qi8_st11_padSAME_dilat11.tosa
new file mode 100644
index 0000000..d5b4fab
--- /dev/null
+++ b/examples/test_fakequant_conv2d_1x1_1x32x32x8_qi8_st11_padSAME_dilat11/flatbuffer-tflite/test_fakequant_conv2d_1x1_1x32x32x8_qi8_st11_padSAME_dilat11.tosa
Binary files differ
diff --git a/examples/test_fakequant_conv2d_1x1_1x32x32x8_qi8_st11_padSAME_dilat11/placeholder_0.npy b/examples/test_fakequant_conv2d_1x1_1x32x32x8_qi8_st11_padSAME_dilat11/placeholder_0.npy
new file mode 100644
index 0000000..d20c639
--- /dev/null
+++ b/examples/test_fakequant_conv2d_1x1_1x32x32x8_qi8_st11_padSAME_dilat11/placeholder_0.npy
Binary files differ
diff --git a/examples/test_fakequant_conv2d_1x1_1x32x32x8_qi8_st11_padSAME_dilat11/placeholder_0_quant.npy b/examples/test_fakequant_conv2d_1x1_1x32x32x8_qi8_st11_padSAME_dilat11/placeholder_0_quant.npy
new file mode 100644
index 0000000..c9e9b19
--- /dev/null
+++ b/examples/test_fakequant_conv2d_1x1_1x32x32x8_qi8_st11_padSAME_dilat11/placeholder_0_quant.npy
Binary files differ
diff --git a/examples/test_fakequant_conv2d_1x1_1x32x32x8_qi8_st11_padSAME_dilat11/ref_ofm.npy b/examples/test_fakequant_conv2d_1x1_1x32x32x8_qi8_st11_padSAME_dilat11/ref_ofm.npy
new file mode 100644
index 0000000..e8f4b56
--- /dev/null
+++ b/examples/test_fakequant_conv2d_1x1_1x32x32x8_qi8_st11_padSAME_dilat11/ref_ofm.npy
Binary files differ
diff --git a/examples/test_fakequant_conv2d_1x1_1x32x32x8_qi8_st11_padSAME_dilat11/test.json b/examples/test_fakequant_conv2d_1x1_1x32x32x8_qi8_st11_padSAME_dilat11/test.json
new file mode 100644
index 0000000..3331b46
--- /dev/null
+++ b/examples/test_fakequant_conv2d_1x1_1x32x32x8_qi8_st11_padSAME_dilat11/test.json
@@ -0,0 +1,27 @@
+{
+ "pb": "test_tf.pb",
+ "tflite": "test.tflite",
+ "ofm_file": "ref_ofm.npy",
+ "ofm_placeholder": "result",
+ "ifm_file": [
+ "placeholder_0.npy"
+ ],
+ "ifm_quant_file": [
+ "placeholder_0_quant.npy"
+ ],
+ "ifm_placeholder": [
+ "placeholder_0:0"
+ ],
+ "ifm_shape": [
+ [
+ 1,
+ 32,
+ 32,
+ 8
+ ]
+ ],
+ "framework_exclusions": [
+ "tf"
+ ],
+ "quantized": 1
+}
\ No newline at end of file
diff --git a/examples/test_fakequant_conv2d_1x1_1x32x32x8_qi8_st11_padSAME_dilat11/test_tf.pb b/examples/test_fakequant_conv2d_1x1_1x32x32x8_qi8_st11_padSAME_dilat11/test_tf.pb
new file mode 100644
index 0000000..65d7a78
--- /dev/null
+++ b/examples/test_fakequant_conv2d_1x1_1x32x32x8_qi8_st11_padSAME_dilat11/test_tf.pb
@@ -0,0 +1,312 @@
+node {
+ name: "keras_learning_phase/input"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_BOOL
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_BOOL
+ tensor_shape {
+ }
+ bool_val: false
+ }
+ }
+ }
+}
+node {
+ name: "keras_learning_phase"
+ op: "PlaceholderWithDefault"
+ input: "keras_learning_phase/input"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_BOOL
+ }
+ }
+ attr {
+ key: "shape"
+ value {
+ shape {
+ }
+ }
+ }
+}
+node {
+ name: "placeholder_0"
+ op: "Placeholder"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "shape"
+ value {
+ shape {
+ dim {
+ size: 1
+ }
+ dim {
+ size: 32
+ }
+ dim {
+ size: 32
+ }
+ dim {
+ size: 8
+ }
+ }
+ }
+ }
+}
+node {
+ name: "const_1"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ dim {
+ size: 1
+ }
+ dim {
+ size: 1
+ }
+ dim {
+ size: 8
+ }
+ dim {
+ size: 16
+ }
+ }
+ tensor_content: "\266\"\346\277}\306\256\276,\3103\277\334\364\310=\3762\200\277\325@{?\n\312n=\002\202M\276\314\177g\277\215\233\376\277S\315\t\277\370\357\233\277\216\300\226=L0\377?\316\273\323\277\270G\226?\032l\353>\334\376j\277\212dj?\305\032\210?!t\235\277\377X\324\277\3751=?E\022\235\276\033\353\356\277\3570\342>r\203\016\277a\'\266\277\266\021\034>\306\225C\276I\370\247\276\024d\245?\233\262\355\277\022\254*\277\265\0026\277\266\333c?\237Z\220\277~x>\277p2\216?\001C\267\277v\254\247\277\026\322(?\231\r\346\276\252\372\352?\315+\240\277\213f\320\276N\233\367\276\002\262\241\277=\177\335>\234\327\275>+\027M\277$\324\265\277\351\373\227?\261\266L\277tP\024\277\203\353\000\275}\334P=Rr\232?\240y\373?\343\264\230?\037\316\212\276\232\203\347?\257O\361\276\262j\243?\251n(\277\002\266v?\031%\216?\311s\235\277f\274:\277\274\"\205?\262a\225>\251,\351>\025\204\352\277\204?\310?\017{\370>\023\240\226\276w\r\246?\337\274\311?\267\241\325?\206\341\325?\005\246\276\2779\036L\277h\376\260\277\201\247\306?\207R4?s\304\230\277\3503\215\277\310\t\205\277X\336\255\277\275] \277F?\215\275pj\370?\346yK\277\0028\272=\207e\343\277\231,\371?\004Z{>N\261\347\277\355\323\007?\310\275\221\2771\266\206\277\010\307\323\277!5\262?\244\366)?\n\360\321?\300B9\277Py\324?Y5\313\277I\n\244\277\313W<?i\221\227?\003\327\245\277\265<\264?\240\363\253?\225\3607?bX\301\277\230\250\200\276\327C\237\276\023X\373\277\336\235G\277]\252\206?\322\260\366\277^l\264?f_\031\277\355w\340?T\025\331=\001\276\331?\335\231\230="
+ }
+ }
+ }
+}
+node {
+ name: "gen_quant_npy/inputs"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ dim {
+ size: 1
+ }
+ dim {
+ size: 32
+ }
+ dim {
+ size: 32
+ }
+ dim {
+ size: 8
+ }
+ }
+ tensor_content: "\267*\224\277\323\277\231\276hE\323?\321\230\336\277\355n\341\277\031\030\345\276Z\223\037\277 \006\\\276\006\213\333>\345\014\315\277!\210\014\277$Iq?\254s[\2767m\204?\031-\201?;\314\026>/\377\261?\264b\275?\213\314a\277D\370\213?\212\342\361\277\033c\227?\253p\261\276\235\t\217>}\316\213?%.\364\277,bC?\261\177\337?\273l\200>o\035r?\275\327\226?A\346O\277]\340\330?\374{\001\277\303\036u?\331j\237>\357\002\244?l\270\224\277+\254:\277\365\3254\275{*\260\277P\326\311\277;\363\376>W+\236>\226\034\025>\336\353s\276Y?\363\275Z\361\375\276L\240\337\277\205\237\326\277_\t\223>\036\312\335\277I\240\251\277\\\367\222?\322\037\345\277\231e \277\225\260R\277\301P\271?5\016\365?\337r\335\276y\300\364\277lN\234=\245U\210?\\E\265?h~@\277_e&?\373\241\367\277-a[?\003\032|\276\360\307\376?\2023\244?\367\243\344\277p)\215?\367]\272?\013\3341\277[\024\374?\263x\311\276x\341\242\2768F\361?\362\360`?\340\225d\277S\204\214?\257\244\362\276\\z\373?*\031\330\277Pz\363\277\270d\016\277\330\210\024\277\024yN>\234\250\215?\222\344\273\277\346\277\343?\300\320\313?\016\207a?&\236\263?\340\231\243\277\214\270\363\277\344u\346\2760\376\231?Z?\315?\331*\223\277\234#\323\277,o\224?Bw\231?\2452,?\263\377\002?\334\026\337?`\274\022\277\327\277\273?\237\312\276\277\274\227\261?\3714\341?ibC\277Q\364\205\277\206tQ\276J\324\205?F\207\300?\362\2154\277\005\231\202\277+\355\300?\315X\216>t=\376?\013[\333\277\362\366\031>\222W\271\277\002C\216\277\007\265\216=\330C\334?\036\245\305?0\374\365>m\214\201\276\205\303\357>\270\215\305\277;\305\376\277\252\375\214\277\222o\210\2765m\274\277(\3737\277%it\277(V\267?\355g\322?Hz\224;\330\346x\277K\207\354?*\000\026>\265H\272\276Jo\216\276\014X\325?\320\267\245=\353\225\250?\314\361\300\276>8\025?\277\270\255>\006\3714\277\337\340\301<F\177\354\277\242Y\341\277\0313\263\277\217\304\223>\327\316a?\314y~?\nF\310\276C\317\226>\331\274\303?\030\206}?\035\n\272\277\005$5\276\367+\206?b\"\371?\261s\200?\003<\014>\325$\233\277\332]\204?\020\007\222?\025\237\261\277\331\225\375?\t\356\034\277.\273\360\276G\221\325\276\350\010\316>\350\303x\277\247\347\262?`\207\204\277\034F ?\306!@\277\177\355\270\275\306\257j\2779\315\234?=\003#>/\234\213\277?C\245?\331)\"\277\347tH\276=`\260?\\\235\225?r\236\351?\036:\320\277\244]\210?\227SS\277\213\262\362?4\177\215\277$\322\324?\365f7\276\273\350\227\277\271t\215?\211\366~?\216Y\202?\276\235\224\277\033U\243?$i\303\276\235P\371\277\2059\222\277\301\307D?\366o\327?\0139m\277\344\377b>\204\344\202\277\246\234\337?\007K\034\277\314&\360>\213\003\353\277\026\252\341\277\014FT\277\366\224}?\271\260\341\277\236)\333\276\356\266\205?16\352\277\327\332\236\277\322\241\243>B\327\313=\260H\233?\246r\247?\365\030\217\276\362\362\207\2772\357\314\277\230J$?\254\307\236>\221)\256?\356\314\275\277\r\035\360\276\204e\231?\007\367\274\277-\255\007\277\352\234\\\277\325\225\211?\035\263\226<F]\361?i\261*?Y\305\323\277\261\030\305=\362U\264?\235\275O?-\307Q\277[s\261?\302\221\376\276\3203\324\277\005\023\355?\016\250n?\326\353\363>\\9\307\277\277/\324\277\363 
t?u\252s>\366\367;\277d\252\207?\010:\357\277\322\343\220\276\214\\\020?\230\257\367\277k1\005?XO\034?<\272\364?\000\177\312\275\362\376\236?L(\232>\267\346<\276s?\313\277-\304\345\276\335v\276\277\307\004\203\277\331\213\\\277\305\370\337\2753\302\277?!Q\361?w\010<?\025,\353\277\232qa?5\t+\277Vb\256\277dV\203\277\250\252\254\276\233B\323\277/-\330>3\225\200\277x\374\333\277(\345\254?6\325\257\277Y\243\373?\242\210*?\304\311\375\276\264\371\242\277\343\0275?\234\333\350\277|\346\223?*KW?j\303\313?0\326C\277h\371%?\253@@\2770|d\277M\'P\277\270\030t\277\230h\212\277\355\021\203\275\025\037\254\277Q\311\337\277#\374.?\337\203\341?\305\t\215>\\\352!?\000\340\243\277F\002\260\277D\227\202\277+\207\221?L\010a\276xA\273?\237R\262\277\305\215v\277\021\343\225\277O7\226\276`&\366?n\3346>},\001?\262\3735\2775\300\356?b\006\220?\021\262\366?\002\351\345\277\000{\240\277%C\224\277\205\300\345\277\335\370K\276V\336/\276\217\215y?\357\362\356\276\364\262\255?_\327\201?\260Jw\277\264K?<h\007\210?n?\243\277\302\364\235\276\233\373\201\277\023\333o?{\033\270:+\245\373?\024J\250?\322\347\025\277y\010U?i\234\257?n\021f=A\233}\277\344\267\272>\201\352\227>\024bD?_\261\016\277B\241\362\277y\205\244\276\237\265\363\276h[p\277Q\010\313\276a\306\341\277\nl`?\035+\232\277P\315\342\277\275T\273\277J\245\003\277\317\302+\277fwt\277\030W\007>\267\334,?m^\'>\213c\245\276\321\217\024\277\033\026C? \230\361?\265\312\315\276\363\271\305?c0\231?\260\201\244?\263\266\n?>=\214\277D\212(\277\304Y\017\277\013\367\334?\242\177=\277\000\225m>>\227\023\277\020\340\256=p>\237\277n|\372\277\317-\322\277d\343\374?\316\272\371\277\260f,\277\005\267\202?1\375N\277\0367\350?\3059\346?\212$\301\277\315\177\307\277\346\265r\275\233m\346?\020a\303\276p\274\257\277\030\372p\276 \314\303\277\365\204C?\242\220A\277\322\352$?\023\331U\277Q\331\346>B\215\245\277s5;\277*1/?\3419U?\216\317\324\2776\214\364\276\r\261\317\277Z\266\374? 
\303c?\252\360\t\276\344\\d\276\351f\263\277\243\001\224?b\023\026?\300\221\304=2b\275?\253\203\"?\221\322\327?@\315\351<\237\313#=i\020v\277\325\325!?\360?\325?\275h\177?\335\002A?\366m\367\277\302Sw\277V\323/?\024]\252=\252\225\305?\251\246\036?\177\252\342>-\204\347>\002]\215=r\300\200?\233\332\367?\222\311p?*\353\213\276\014\376\371\277\314]\342\277\350\001\277?\362\300\340\276\376T\272?7\226\376\276\016\223\272\275\213\nb\276\014\355Q?\220WR\277\006\360\337?\306\365\211?\230\211\213\277\262\030<\276\355V\311=\2315W?\257\373\276\275\213w|>\243\364\317?\362I\330>\312\343\210\276<\352\350\277y\267\310\276\321\014v\276\027l\356?\"\245\353?\244i\334?\252\341\'>#\033\217\277K\223\331?\377q\235\277\374\342\212\277\372v\000\277\214`\321\276\276\350\263\277Q\261\211?\247~\243?1\377\237\277\354\207\227=\352\231\335?L\272\223>\243J\364\277V\264\320\277\366e\370\276XZ\316\275\304\223;?\333c\332\277-\307\366?\256\203\277?r\3010?\300W`?\270\201\035\277\026\036Y>\274L\232?\350\327\223\277O\351\357?\246\366\351\277\272\206\243?1\274\263\2760\010\342?\274\332\213>\233\253\305\277\314J{\275\325:]?\365d\224\276\212\240w\277X\376u\276<\300\262\276c\'\212?\222\013;\277d1.?\217\301G\276\360G\314\277\026\005\213\277\022@\242\2775\235\334?A\352\016?\3319\343\277\251\227N?\"E\026\276\220a0?\360^\323?^\264\356\277\211\361\007\276O\355\266\277\206\343\250\277\320\033\305\277\336\320+\276O\220\230?\325@\340?\240L\333??\301\202=d\253\347\277Q_\014>\365Z\326\276<\001\034\277U\347\326\277\023M\200\277\004q\302>h;\326?t\266y\277\324~\025\277Kh\201\277?\214K?\256\014\344\276\311\266{\277\371\nX\276c\272\343\277z\246\343\277_\001$\277\023\\Q>\233\204\265\277&v??5I\244>\016F\377?\034Fz\277Z\232\202?\r\332\364>X\235\256\277A\002\232\277\010|\315?\233\356\264\277X\275\200?\263\337\244\277\0005x\277<\272\232?8\352\262\276\365\326\030?\303\255R\277\272\265\346\277\335tN\277v\177\"\277\302\253\356\277\020]X?\224\312\340\276\025\004\323\276\326X\274?\224|\354\277[\240T?\277\025\316\277\240\332e>/C\322>\245\231\303\277[\335\206>\267\244\213\277\205eh\277\035\007\324?\026\247\352?\304\210\257?A\300\322\277\002>\017\276?v\374?\261\002\311?\003//?\324\316R\277\335M\203?\221\276\374?\262\313\034\272\306\344#\276=M\333?\202f\r?\327#\020\277\007\234\215\277\226N\214>)\344\247>\326\t\234\276\177\236\031>C\262\237\276\2446\266>9\236V?\006\242\231?\266\354W\2775d\270\277\242\241*\277{\354\236>\362\323D?.f\322\277\350\264<\277\214\304r?N\035\202?\250\251\304\277\351L\300?\260\037\261\277\307t\300>O\'\350\275\305g\n\277\177y\303\276w6\246\277\314x\354\277\251\321\303?y]\272?\223+~\277\235\347\353>\260\261\302?!\360\253\277\254\221\006?\033\350\232\277\211\275\262\277\354[\320\277\313\350\204?\222\373H?\3439h\277\017Za>^@q\277{\256\222>a\274\350=\374\364\003>\\Y\220\277\230\371\226?\3261\325=#r\347?$\n\336?\374\026\366\275\364D\266?\001\370\376?\006\036\316?\006CO\277\251\265p?\366 0?\'c\240?\351-\335\277\312\243\215\277D\357\235\277f\241\353?\353\026c\276\210u\201\277\371\'/\276\270\217\251\276\256\031\353?\266\0317\277]H\243\277< 
\254\277\2139\340?\021\315\027?|S\344?\372.6>\224\272\204?\0350\272?\311\030\307?\022\322)\277\2521\301?\267\315\363?\273W\353\277F*\204?\332d\232?&\301\205\277\222Wp\2752\277\336?\371\316z\277x\367\266?e\243\372?\363\366=?\016\210\262>\227}\262\277\236\347\201\277\253\354\277\277\245K\276?\271\337\266\277\303\253\204\277UE\357\277RK\241\277\247\342\335\276vm\232\277`\274\000?>\020x?\267~V?\244+\237\277\327\330\235?\021\222\227\277+\253>?\033@\256\277\211W\275?\230\003y>C\267\351\276\312\374\253\277\200\266\006\277Y\325\355?\256\023\335\274\374Y\240=\303\213\360\277b)\317\277\031\307&>\337\220Q\276j!\233??cE?\'Z\247?\021\260X>1\256\362?\332\233\311\277l\316\274?l\371\220?B\265\361\277\266s\331\275,\0132\277D\261{?\350\356\374\277,\206\001\276\247S$>\357\224\350?*q\007?w\014\333\276\334\327\241?\343\316\313?\221\014\306?tq\274?\004Y_\277\257\253\275?\3511\372?\232Sc?\216\362\013\276\006<)\277y\030\207\277^;\320\277\276\024\233?)\376\024\276\245\2669\275P\223\347\275\210\226\241?Q\263\254\277_\270\212?D\3539\276\277\320\367>\003\277\262\277\337\346\300\277W_~?\222\207\365?|[\236\277\032\3073\276N^\\\276\312q\356\277\353\221\360>\010\013\343?\\*\200>v\342\031>\344\377\376>\241?\004\277U\241\326?N\242\201?\301\336\361?\177\270\244\276\237f\342\2778\231r?\345\202[?*`\200\277h\021\221?\301C\360\277\273\021g\276;\033\232\276\216\323\347\277\344e\257?\005\333\352?P\354,\276(\000\235\277eB_?\361Zb>\237\036\315?\002^0=\026\002\263\277\2240\243\272\177\343B\277#\305\253?w\352\332\277~\251k\277\034\304\302?\372\220y\277\310\356\320\276\355\336\213\277\236\221\n\277\332\013\245?\036f\255\277\205=\225?\313\372\212\276$:\230\277\324\243\351?\250\231Y\277\260Y\262\276\367\272\207?}\210\220\276\rX\017?\224\251\366\277\212\0045\276\252\216h?\007\215H?c\304:\277\255\222$>\357=2\277+\236\324\277J\216P\277\024\006\300?q\037\366?\346t\350\276\330\252h\277r\317\205\277\024\204\251\277\266a\020?[\273\335\277}\003\332\277x!\200?de\t\277\265\377\373\276\230*1?\300n\247?^\313.\277\010v7\277<C\343\277,v\027?(\031<\276\231P\345>\240%,\276\251hF>\3650\016\277\360\223\205=\302\374p<\275v\213?\177\222\357>q!\037\277~\037\330?\2276\334\277\271[\342?\223\n\341?4\'\300\276\332\270@\277U\335\371\2774\211\317\276\272\020<\277M\000\206\277I\r\212\277\336\267\277>\242\347\342\276\234*s\277N>\324?\312\226\270\277\311\205#\277f\311\032\277\326\264p?\032\347\264?_o\372\277\033\364\t?d\342\256>\207r\332\277\177\207\342\277\310t\177=\224\215S?\272\310\235?\300\207\242??\177\256\276\036\205\036\276\307\0021\277%_\236\277*d\341?\320+5\276\266\024_?\365\232\260?d\367\254>5\207\014?\351\036\315?\335#\005?g\207\231\277\252w\306?\216\210\262>\314\271@\276\301yE\277\376\033\263\277~:\206=B\324\215?\370\003\027\277d\264g>\347\362\363\277\367\005y>\304\352\024>\355\004\226\276\003\254p>\254\\\264?\302\334\310>\341\304\337?\354\206(>)\013\007\277\336\217C\277!\324:\276\202\311i\276\357\301\262?\243\007$?S\226\326\277]J\212\277\035\372\231\276\241#\237?\3251\247\2770\0068\277\253\201\317>\322|\226\277\225[]?O\332\300>j\257\214>F\221\206\277h\014\336<\346\201\303\277\351\004\306\276<\360\253?\304\336.?h1\215?\374.\362\2771\342\314\277\3211\332?\227\371\225\276\226\006\271\277\270\336\025\277\3019\305\276\213]\333\276\377-\244=\250\311\205\277\247>Y?I\362\226?\030hl\277\3219\000?\300j\264>\2764\220?\373sU?&\014\253\277\035\023\302?s*\205?\243\006\256\277\373\222\225\277\366F\376\276%\026,?\252\241\354?}4\360>`\371\304\277b\234\344\277\300\376N\277}=\206?^\306\207\277\005v\234\276\226\221p\277\2572\234?\\\373\325\277\022tZ>\204r\030\277\024\315\30
1\276Y\206E>x@\300>,S\231\277S\367\226?\232\t\342?l+\261>#\206\346?\017|\247\277]\303\236\277\251\322\031\277>k\373\276\001h\007?\361p\240\277\366\204 ?\233\301H\277B\357\300?\222\021\216\277q\251\320?\233ut>\356\006\337\276\034\211\356?\345,y?#j\207?\n\302\306\277\027L\377\277\t\226\350\277B\316\020\277\315\270\366?\016\006k>\370L\355\277\337`m\276pl\312?p\031\360\275\377\316\301?\223\355\360?\264\326\346?#\377\330\277\373\263\273\277y\"\366\277\211H\237?\374\361/=\355\310\035\277\204\361\232=\"\230!?Z\332\036\277\221\233\253\276X\3050\275\214:\355\277\205\307\332\277|\024Q\277\215\363\324\276\220\375\245\276\260\023\242\277\004mh\277%Hx?\375\270\036?\177\270\327\276\3413\027\277T\317\372?\207\220\355?\01432\276Ll\367?\306x\364\277\037g\224\277cR\253?\256\305\312?/\0041=c\371\317?\371\201\325?nM\267?\331\003K\277\3512\370>$\206\215\277\3479\335\276<\335t\276;bf\277\207\020*=]\210\303?\252\337\256\277\345\t\234>\014E\325?\301\326\303?\3457\264=\343\271\027?\372\2612\277\263\227\266\277u\223\221>Z\r\n\277e\254\200>v\274n?\361\220\245\277\346\340m\277\177|\305\277\032\034\213\277)(w\277o\025\260\276Q\023\332?wM\303?\007\306\245\275\262\227>>\240\346\204?Q\203=?\277x\263\277U\323\370?\302-\230?c$\007?\273\273a\276\314\013\035\277\346i\024\277\023\220\232?\213\335u>8\265\240?\2152\320?\323I\370\277\217PF>\364\344\205\276\352\341\313=\340\254\264\277\201,\201\277+a#?\014P\256?\034\300\343\277\254\032\256\277\241\244\352?\271\231\226\277\207\331\305?C\177\014\277>\033!?\226W\037\277T}\347\276\231\000i\277$\347\327?\266\022\376\275\027\265p=\201\0063\276\366\200\203?\374\253\376=\253\222Q\277\301\020C? \304e\276\345>\344=ZT-\276\350\261\326?\241\256\302\277=\325)?PT\236?Y\275\373\277,\315\002>\333P\352>\263m\225\276\276\335X?[\222E?kFs\277\005L\346?[o\330>\024\311\301\277\3671\214?\027\372\217\277r_\245\276\231\375\330?\325B\245\277#\031\017\275\311Y\253\277\334\3002\275\024\262\032?YP\234>\227\017\365\276\013\276\316?\304\312t\277O\264\364\2771\033\356\277\301\032\201?\364\001f>,\005\262?Y0\274\277+\214\365?\003\364\216\277a\302\262=\365T\363?\337\352\205?s\004s\276\221f??Y*\242\276\003\n\022\277\341\303\036?\233\375\344?[4\240\277*>\211?\373\373\324\2778\227\275?\307\365\325\277\356%\366?\3052\276?\001\207\356?\304\317\364?\356\316\225\277\313\354\202\277\2726\231>\303C\367\276\246\333b>P\213\361\276L.\351\277\3556\261\277\207g\030?\365\322V?\027\353\007\277\245\335\362?\313\r\342\276>\213{?mh\244?\374PI\277\341\017\276\277\223\240\224?\372\223\227\277o&\225?\002d\303\277\036\302\330>\201\322\014?\315\354f\277S\264\221\276\31736?R\213k\277\254\2074\276)K\277?\373\271\352\276,\356\323\277GQ\364\277\355b\207?i?7\277\007\345\313\276oc\265\277\312j\316>\345\334G>Q\361\027\277\241@\303\277g\224\360?\0174\350?\303\232~?\212&\341?\372\307\264>TI\333\277\\\343Y\277~\332\266\276\304o\277?\007\250\312\276[\360\302\277\3564`\2772i\013<1\023\216?E9s?\334\322\365>\273\310\205?\212\304 \277\027\232\367?\2567w?\363\274\357\277\223hP\277V\255\203\2778D\271?\242\333%\277\222\263D\277\\\022\346=\262\006\374?[\376<\277\302\240\221?\345\342\312?%=\210?7\025\330\277\246\352\354\277J\305\031\277Da1\277\363\022;?\304\020\266\277bDr?/\333\t?\031\245\262>\374\025\235?\3469\232\277v\t\177\277\212`\376\277!\024\306\277\330d\025\277&>\270?o\271<?*\273,?\244b\311\276\220`\273?8\361\210>fn ?~\372\324\277gE\367\277U\206Z\276\356\240\303\277\275 
\261?9\264\211\275|\037\223\276\036\371\255\277\300\371\246?v\224\362?0E\344?\273\"E>n\246\010?E4\343?N\2463\277\002\277\222?\312\341\345?K\311\375?\022\233$\277\255J\212?\321i\301?\026\267\340?\346`\263\276\337u\330\277},\217?eZ\317?7\205\242?9[t>U\320\370?\260g\275\274,\315\372>D\325\334\277T6\374?\300[\221\277\346\355F=\251\002\211\277\217AG?m\230U?\032\016\346?\336\260\312\276\263\201\004?\3278\274\276\320\026b?\221\356\354\277W\300^?\204\257\317\275\\-!\277\253\327N\277\277\326\310?\274\235\265?\234\024\371\277\316\214\226\277%#<\277\206\007\177?d\024\350?\367\024\225?\376p\350:\"\013\264?\022\277z?f\013\223?8E\275\277\275\200\311?\037*\300\276\371a\323\276S\031e?N\2433\276\3452\300\277\344\206\307\275k\373\331>2(\355?\235\210\352\277\342\307{\277\3045g\2778\027\263?\275,\212?\226l\313?\333\366\277\276\201y\017?\273i\230\276\254U\247\277\036r\004?\274\036\337?\313\236\376?-\340\347\277\223\3206?\005\357\271?\036\244\214\277\004\272\324?=\341\223?\304\342\357\277\207MJ?\357\271\232\277\244\"\034\277_\270\233\277K\214\346?\365\354\217\277Q\326\177?o\224\032?\036~\325\276\272\032\t\277\256\026\243?\247\273\372?y\013\303\277t\351x\277q\245\260=Wj\341\277\200-\372\277\205]7?\201\024\232=\375G4\277\037\357\213\276x\333\337\276\021\010\335>1\227\216?\275+~\276\354\245(?^\335\017?\230>\214?\360\374\243?\363Y\336\277\253L\221\275d\021\353\274\345\010}?\213\024\361?\247\345\277\277\356[\213? \307\233\277\326\254\303\277iT\367\277\252\020\221?\n\307\214=\226\331\346?\315\014\005>\350o\\\277v\224\370\277\024A\255\277\310\310\247?_\201n?\362\276\363\277\001\031j?d\350Y\276Kq\262>U\325\005>\347\366\364\277e\311\264\277\212\232\217\277\266\222,?\005\272\373\277\204\274\276>\005\342>?\336M\233\275-b\324\275\002\3343?\373\370\376\277\312\360\302\277\211\003\352\277w\376\245\277\336\252\">\014\"U?\340\267\374?d]\312?\027\216\377?\204\247g\275\027\022\272?\334L\341?\340\031\305?\323\004_\276\376^\347\277\261\307\213\276P\341\232\273\351\'\321=\347\276\232\277\211\277\210?Z\r_\277y{\272\277\352\246\214>m\343\305?\374\362\375>\204^9\276\304\254\225\2767\260)\277O\204\277?nK\007?$\035\351\277\341\270\"?\225-x\277Pr0>\377\016\216\277\250\310\207?\225\362\246\276\251\273\332\277\005@\337\277\353q\377\277m\037\205?:T\212?X\215\323\277\336\365#\277\233\007f>\tlt\277V\021\320\2775?=\276\226(p?\240.@\276\271S\333\277\365\242\243\277\232\225H?Z\245\327\277d\343\320\277ee\027?\302\275\275?\273\247W\276h\302\234\277\030\r\206\277\271\316\031>c\303\273\277{V\312\276Y\007\361?s\307\311\277N\367\214\277\217\250\233?5\032\234\277\353\221O\277u\027:\274kV6\275\234\206\311?v\351\250?\330\225\315?\031\357p?\205v\204>\310o\315\277f7\352?\230\346)>\033t\243\277e\364\333>\360ya?\312\235\002?\237m\236?\206\017\223\277\251\201\036?\203\303\217\277\322\2176?\325\034u?\313\036/?\025\026\236?w\361\312?a8\365?y[\212\277\330es\277\325\233\375\277\035h\376?sh\264?\306\032\206\2779\005\341\277\034\225\327?\334B\245>\254{\226\277\201\334\305\276\242\251$?6\200\320\277\314Lg?!\014\271\276\315\025V?\357v\017;\001\024\350\2766\314\366\277,\372\366?S\324\326?\376\261\317\277\315\355\323>\335\224G>\3162\212\277a\3535>Fe\243?x\031\377?\240\027\310?v\311\264?J\024\232?\264\336\315?\207\367\357\277\335D\235? 
\247\335?\317\360\031\277\206\336\246?\032l\315\277m\343\355\277]r\325\276;\270E\277\265\022\245\277\256\250|?\224\256\252\277\222I\260\276\241R\362\277-\351\231\277\204\031\332>\265\211\327?\232\203\314\277\177\216\227\276v9\263\277\214\261\373?\372b\n<\256+\357\277e\321\350\277\232\023\017\276\226\030\316\276l,\324?\354\334\306?<\033T\277\355\365\336>\236\315\244\276O\331S\277O6\233\277&w\000\277\314\235/\277\027MV>\354_\315?\323L\372>[\240\242\277stW?eX\332>?\264\232\277!\256\216\277\353\364G\277\356S\361?\026;w\276\002\206\237\276d\206\232\277\231\021\234\276\006\370w?W\372\035\277\361\320\352\276S\271\374?\235\3072?l\036\323?U\366\037?\257;M?\235L?\276\251h\225><Z\234?\311\022\256\2779Zg>M!\207?\265\326\276?i\260\265\276\334\366X\277=G]\2773\001\024\277\342w\322?E\325\254\275kdK?\235\t\372?\237{\250>\'\256\350\277\265R&\277\2041\346?\026/\006?\247\370\273>\350`\036\277*\254\346?\300\343\330\277\002\355|\277\335g\200\276R\201F\275\302\302|?\265\004\212?7~q\276l\'\311?\234\211s?\231.4\277S\353\022>\t>\365\276\216\001#\277\003\300\037\277^=U\277\277+\254\276\202\361\003\276\257\306\273\277!\332\213\277\014n\366\277\332\313\321\277\027\037\317?~\030\266=\234T\330>N\352f<\255p\220>\326\363u\277\371\314\251\277\351\315>?+\020n\277\275W\200=\225\306x\277?C\270\277\235\360\212?\021\377\272\277\300\260\204?\212\"\345?\301 6?\236Sy?\204{\317>v\"\313=\013WH>\033y\337\277\341\223\207\277CBC>\327\255\270\275!{Y\276\rB\307\277\357\033\226?\251\201T?is\240\277S-\326?n\264(?P\346(?G\350\362?O\210T\277R\376\300\277S\316g\277x\333\365?\221\003\304\277%\207\251=\216\023\257?x^\345>_\203\240\277\210K\346\276\031\306\241?Qv\177\277\006\213\261\276\310\274\230\277\002S`?0\364\351\2763\210\207?\267\256w?\311\352\351?\002\211\205>\302@\334>s\223\303\277<\016\201?\010\263\376?\335\322=?\025\376\255\277\265\315y?\013\352\233\277w\267b>\301\r@>\264\307\350>\347\326\210?\203\267\214\276\301\241\277?\304Xl\276\337\331\272\277\345\237\264\277\326\364\236?\231\324V?\301\217\331\277M2\303?T]f\277\r\235\235\276\t\346\321\277\251h\007\276\033\355\227\277\300\336\t\277\3158\210?\316\374\343?3\266\005\277\242\2522\277\200\302Q\277\237\014\237?\331h\363?\030\257\321?\000 
\032=\203\232\030\276\324\306N?\275\356\263?V\n\260\277\374\255\336?\245\213\233\277e\024{?95L?r1\223>p\240\371?\343,X>\327\234\212\277\350A\230>FI\200\277`B\373\277\362\231\324?\264;\322?\374\217\247\277\001\377\341\277\323/h>\013\352\323>\263\337\357\277\277\221\217?\213gC\276\305\245`\277YH\352\277\236\241Y\274m\217\266\275\010\333\255\277\014\377P?\371\230\364\276k\324j?\304\235\345\277\257\225m\277\242\374\376?\313\344]?\317i\334\2771\207\347\276\211\001\347\277hS\232?\327K\334\276F\357\250?\0132\257?a\001\351?4\321\306?\316Gm?\352\250\223?\200\200\346?\333\367:?\346\234\020\277G!\225>9\177\271\277\n+\276\277\306]\333\276M\260(\276&\331Y?\256\271&>J\277\031\277\374v\254>h\272\335>\021k\241\277-\247q>\254\247\372>`\366\263?\232Z\233?\340aP?\272\216\025\277\2221c\2778*{\277\343\023<?W\202\202?\346.%\277\205\333\211\277t\275\t\277\\]h=\276W\307\277N\202\310?\224\253\204?\354\244\212\277p^\366\274p>\024<\026oy\274\331\"\022\277v\003\233=\345\326\255?\304\307\314?\206\206q\277\305\206\247>\000q\301>\320\363\243\276\331y]?e\327\303?\366\321\376?$\"\221>Q\022L\277u\266\231\277D\210\337\2762\261\351?ot\223?\234\245\250?\345c\247\277,q\225\277\227\207\215\2765\035\272?\222\002\233?\263[\303\277Z\321\246\277\305$\027\277$J\300=yp\361?\321\020\'?\347\205\311?[\376\264\277\004\016\256=\006p\322\276\017\262\233\277\326\364\332?\024K\304?\004\356\345\276\224y\310\276\2025\251\2766\254\225=\242Y\371?\265\306\352?\0219\017\277\317\313\235\276\027\321D?\226\227v>\212MT?Y%\351?\251\300\014\277\302|\226?\267\\\365?)@;?\360B\270\277\320\226\376?\231\264\350\277dE\262?\000\224\316?/\345\200>\220\212\223?p\016H\2777K\227\277\223\337\215\277=\3656?\370\270\022>\377\035e\276\370D\345?\235\362\234\277\270\316\341>\265&\321\277\0344\275\277\337W\372\277}\2627?\317- ?0\271\305\277\364\033\355?\005\374f?a\305\354?\322\253\035?\0165\305\277\233\314\317\277\243\315\350?\316\304\311\277\273I\306\277\235\030\364?\226<\000?\002\277\034\276d\324\250\277}~\222?\035O\347>\024\216\337\277\267\225Y\277\\\300\335\277\n\235\354\277\243q\325?\251{\360?\360\366\225>\200t\257\277\270\221\345>\205\na\277\247^\271?\226\032\311<\017Y\361?S\205\204?\253\263\222?|\271\204\276B\233\250?\367\342\345\277\000\347X?\216\r\244?\244(\037?\352f\355\276rV\245\277a\177\371\277\022\353\'?\250\247\027\277e=\332>\021\004\250?\261\3529?\365\237\242\277\020h\344\277\236\315\317?\177\241H\277sw\333>g\304\032\277\354\225\311?H\001R?Z\007\320?}\264\265\277mF\240?r\322\351?HbF>*\305\212\277>v\231\277:\231y\277\20624?Lg<\277uv6\277\346\221\354>\303\213\266?D\000\304?0\277\227?\027\232\366\277\372\ru?\020)\227\277\\\353\335\277z\264v?\037\363\262\277\036J\254?\306!\302?\257\245\330\277\255\330\022>\226\373O?NE\241?\234\014\005?\354\0045\277(N\345?\260\017Y\277\315\325\205\277\245}\325\277\002\3422>\231f\356\277B\312\314?\222\nF?*\323\274\277\024\244\354\276R^\014\276\023\033 
?\240\005\355?\\\274\333\277\375\337\321?u\214\361\276\273\313\316\277\253\224\341\276gh\t?\032\326\331\276\232\370|\277\353\206\244?\323\354\316?\3128\326?su\312=\376\317\261?\034Q\313>\205\316\271\277On\322\277\237<\204?0\271\036\277\2069\201\277t\210\221\277\036\327\227?s7{\277)\225\326?\022T!?\241_\257\276\256\027d?\037O\211\277_q\035?\364\037\242?\325E-\277\320\241\214\277\013\026P\276\370\030\236>8\200\217?\260\001\340\277\2679\346\277~E\370?\204e\356?\224\234\017?\211\315\250?Z\355\347?^J$\277\005R\312?\202\207=?h\311\263\277\326\3059\277\273T1>\352\227\247\277\007\321\275\277qc\243\276\002Q\361\276\343)\243\277[%\036\277\343]\231\277Y\034\275\277S]\263?O\017t?\024\273\375?\275\375\334?\360\327\336\277\372\367\340>\376t@\275\353\373\302?\033\204\224?\270\354\277?u\350\340?\376\203n>U\224\302\276\346\340\370?N>\331\277\023G&?\037\371^\277\007_F\276\1776\337>\033\245\272\276\304\033\030\276\252\346\341\277\332\232}\2773\210\360?\nQ\303\277\334\324\320?l\035\326\277\341w\210\277V\317\374\277\203\250q\277d@\371\277{J\225?\274\224\347\276\342\232\362>\371\010\263\276e\027n\277\034\202\242\2778x\271\277 v\244?\022\244\312\276L\335\326\277\271r\331?\006\234\211\277(\375\360?7Zk\276\321o3\277\362\006\256?\361B\247={+r?.\001\214?\314\014\337\277@\034\037?2sb\277\332qI?\357\273\257=D2\320\277\030\225\372=\022\013\242\276\362z\375?qv\330?\230\021\362?*\234\260?\335\255\227\275\220\256\022?\370\272-?\316G\010\275\213\215\324?\t\350\344\276p\245\356\277\022\201m?B I\277\201\352\037\276\375\362\256>P9\354\277\323\n\377\275\345\315\321?\026\207y>i\216Y>!w\254\277\017\222\245?\370\362_?\036}\363=\3549\341?\343!\001\276kW\212?UZ\270?\353\216T\276\365T\363<\3645\252?\225X-\276\227\322\217\277\266\225v\277`r\317\275Cp\357?\331\324\302?\311\035;\277\270\234\357\277\374\021Q\277\230\034\235\275.\266\347?|\311\202>=\006n>\255X\354\276\\ 
\014\277N\335\316?\215\374\304\277\321E\360\277\020uD\277<\214\360?\371\333\211?\373(\245?\341v\346?\320X\027\277J\244\001?\024P\222>\214\264y?E\216\273?\370\256\202\276\352|\234\277\352\323\221\277p\032\227?\267e\363\277\311\217\275?\323\177\204\277\232\331\276?\324UH>\320}\232<\3000\375>\236\230\353?\027\370\326\277w\005\003?\336\347\214\277\035\351\245>r\241\336\277\276\3338?#\036\376?\3263\313\277uT\201\274\215Hx\276\210\357O?\250\016\017\277,\010r\275cQ\214\277\260\306\306\276\271\330\304?7\237\201?\374\315\326=\252\363\217?\254N\363\277\224f\'>\206\316\211\277\212\271=?c\234V?\316\272\227\277\325\361\244\277#\001\337?\034C\301\277b?\022\277\215\376\322\2774K\360\276E\224\246>I*\211?\351\\\266\277\244-\232\277\336\227\230?\246w\256\277\022g\210\277\361\201\214\277\350\242\243?J\274\272?\226L\366\277S\353L\2773\377\033?\240\220\337\277\310aV?glR\277\333\342\271\277\007\024\322?W\000O?\325\315\333?\230\002\315?R\206\276\277\210\230\265\277\016m\221\277\025Y\267?X\211\270?\327,\237>\245\227s?\r\002s\277\223\311\277\277\336\224\214\276\277\317\207?\336\312\252?9c\371?\306\263\365?5\315\216?[\271\316?\r^\304\2762$\270?Zq\353\277^yo>>G\272\275!\020\221\277\367\231\310?\343\036\332?U5\210?J3I?\206\231\316\276Q\354\334>q,k\277\234\351i\277\271S\335?j\237\346>1J\352\277\221_1<\271Ga?a\014\207?\003I\360?\351\006\234\277\314\202)\277\214\315\302?\234@\244?\035\350\355\277w\3007>s\271E\276C\247\037?\177|\320\277\270f\205>`\016.\277\371\341\207\277M\206\023\277eR??n\363\277\277\223\375U\277\273\013\204\277k%\243?\022\231X?<\304c>\352\237\225?h\271\312\277~PB\275=\010C?\325\376\034?L\007\330>\332l\243=*l\372\275(\\7\277\277\265\366?y\373K=/\021\247?\354\265}\277B\253\364\273mx\242\277\267\024\250\276\030\210\313\277f\013\316\277\025\254\002\277\220W\237\275D\222\247\277P\215\345\275\351\345\344\277\335\006n\2772\256?\277\265\300\n\277\260cT\276s$B\277\330\221\235\277o\334Z\276\375\367\340?\027\302\261?\005\334\362>H\262\316\277\361\222\370?\n\252\242\277d\007\347\277#\311\220>\240J\265\277\360\253\224\277\242\233h?\273\322o=\007\031e\277\225\010\244?\332\210!\277\001\233[\276\372\344\333>f\346\354?\2621q?\003\373\255\277\253F\200?\347=t\277\325=\303\277\221\365\200>\325O\367\276cv\220?h\370 \277\233\3024<\3641\341>\016\253\366=KXK>\227W\205?oV\273\277\354:\356\276C\317\367?\352\265\362?\346\361\366?-\216\267\277D\260\311\276+2\372?\\\220\206>x\225\352>B}\306?\212\322A?\227]\347<\255V>?\311\313\033\277\n\334\315\277\\\232\026\277\0148^\277\3201\346\276\303\202I\277\247\372\265\277\360\366\312\277\261x\274\277\305/\037?\271\000X<[\026\267?\360\177\025?\034\027\311\277\376\007\276\277-\341&?\364!\354?\336At\277\305\304\373\277$\202\244?\220;\240\277\222:5\277\231\031\352\276\367V\227\277Z\376\364?\261\037d?\227\330h?F\241\373?K\213\325\276\016\356\275\277\322\364\336\277\353\371\027\277\2075\243?\372\327\257\276Q#!\277\255\033\224\276y\330\350>L\000\324\277\311\307\320\277+\224\377?\216\217\032?\020mp?Z\032\227?d2\236\277H4\257>\326\202\274?\263\255F>\320\347\371\276\0377\354?\374\201\204?\001\376+=\265\271\230?\301\306\342\273\316B\200?=]\340?M`\221\277\177\265!?\343\253\311\276L2\336?s\2637\277_c\247\276a\023\323\277g2\235?{\253\366\275 
\226\341?<\250\020?\\\335\303?\033\322\030\2771\010\210\277\006c\215\277i\200\270?\304F\320>\272\353\023?mc\234\277\t\374x\276\305\260h?I\277\260\277\361\216\344?\310C\342\277A\273\340\277\372\241O>\267c\324>\333\364\331?\251\306\326?\027\204\203\277\372R\232?pJ\031\277K\000\204\276\240\031\277\277\230k\251\277\016\255\207\277\330E\377\277mx\350?UY\256?\202s\236?8}\324>i\340\333\277\256 \357??_A?4\013\220\277G+\230>\334C\335?\022\344P\277\td\252\276\355?\317\277\332`\362\277\014y\025\274\225\363D?\237E\265\277b\345\276?\323g\206\277F\321\241>\201\373\221?\226\311\271?Y\3717\277\361\205\224?\234F\216?:\200\272?F@\216?$\214m\277\312\033\341?4\240\360\277\343\336\272?\306\253\275?\215J\260\277\362h\212>\316\313\214\276d\352o\277`m\237?\211\233\203\274\'=.\277k\000\032\276\207\022\032?\0136\231\277\000\373\236?\177\337\367?\365\365\220>\335\332{>\265\206\371?&\217\250?\365H\240\2776\214\023>\032 0\277\302\264\315\277\035\307~\277\242\n\346?i\256\252?\246\257Z\277\306\375\237?^Ti\277\202C\326?\225\250\356\277\204mM\277\3259|?\3638\235?\374B\225?eN\270?4,\331\276\235\020>\277&\351\241?\230\311:?[\232\211?+k\345\277~\311\315\277Xy\247?\326\222\231?7\363\303\277\224\357\262\277\3313\330\277\307\367\230?z\177(\2773\364\314\277\327K\334\275\310\2216\277\351\035\022?Jt]\276_\350]\277/m\213>\250\315\036?\024)\235\277\010\226\362\277\016\320>?\007\315\306\277JH\231\277\311\265B\2757\203\254?\024=\027\277\210\230\027\277\306_\316?x\315\262>\215 B\276\251\005d\277q\035\315?\347\031!\277Y\264\373\276\264\n\354\276\377<\373?\371\034\266?\370\315\343\277\003\302\240?\202!R<\361\273\362\277$\207p\277]>\253?]\223\006?P\330\300?\034\3262?\036~\215?\376\233\302\275\357k\341?E\224\247?\366\314\361\277\020/`\276i\017\313?\021\"\320?\032\245>>\267\310\332?\005\323C?\302\227\321\277\221\035\300=\036r6\277/T\353\277?\371\360?\010\262\350>U\262\216\275D\\\360\276\2150\353=j]:?@.\307\277\2401F?\331f\221>}j\353?q\303\205?\266\315\211\277\207\020R\2773~\245\277*^\035\277\024\032\223?\306\375\214\277HE\264\277B51<\301\013\350?\021\335\324\277)\343\220?\203\330\235\276\231\210q?E\270\273\277\037\360|<\367W\346?47\363\2779\234\240\276\317}\252?\240\330Z?\336Z\366\277s\343\230\273\023\222\005>y\224\377\2752V\025<=\000D\277\206\233\030\277\204#p?\366:\336?(H\245\276\212\236\310?\267}l\276\334\033\275\2764\346B>K\330Y?\033\217>\277a\200\265?9Si\276G\354D?\225;\243>\004\220\252?v\272\'?\371\342\022\277\364\307\321\277Lm\016\276\024}\344\277\tY\256>[\\\373\277\023\246V?\266\014_?\354`&>Lx\363\277`P\323?\035\313\271\276\351\300\324=[/|?\211yV<\260\223}\275e+L\275\030\'\017>6U\232?5\325\355?\327h\027?\t\207\366?+\266\036?\021P\376\277 
\305\350?\374+\232\277\250Ow?\2165\000?\233\004\206\277\022\237\266\276\226{\'\277\375\270\230\277\330*\341\277\311*\270>\274~N?\233\031\376?\274\030\177?Z\241\337?\313\032\361\277e\004\373?L\223d?H\333l=4\234I?\362\213\205\277\234\341\314?\303F\346\277RnF?\266ut\275%\316\352\277\226\320\322\277tIb?\005\262\337\277\335\203\212\277\205(\263=z\n\233\277\024\244\275?\315\271\227?\004\242\020\276\320\217\335>\017Z\000>\017\220\356\277x\005\264\2745\222\240?\331\\\253\277\3029\024\277sD\307\277\240\261\331?\rY\223?\206s1\277\3025\267\277\241\234\335\277\350e\325\277K\245o\277&\205\242\276[\017O?\246\227\360\277-m\007\277-\251\236\276\260\257\303?\210\370\037\277=\363R\276@\3654\277\374\3451\277\203\273\251?\222\304*\277\177\237\367\277e\274\360\277(0\223\277\332\317P\276\342A\r?\236K\271\277\205u\221?\267\215\241>\356\324\322\277y\344\202?\020E\353?\n\251\034>Wu\327\277\317y??\370w\235?#:\254?\017\233\365\277\221R]>\302\373O\277\001o\001\277\0140\361\277\203\034\002\277C\003\235\277\343l\222>q\037\352\277v\241\210\277\345H\365\277\202Y\205?Xs\370?\314\223\356\277\232\246\324?\272\324b\277\272A\316>2\225}\277\265\373\374?\370\356E?\227<\001\277@\230\302?\324\320\373?\227\256\023\276\037Oy\276\261\352\246?\206\205\255\276j\331\253?\035\n\345?s\325*?\n\207e?\270h\345\2773y?\277\355Z\222\277|a\337?\252\2644?02\247?#\276}?K}H?7|\377\275^+K?\330\347:\276X\020%>\'\224Z\277!eC?\023\251\330\277\226\324\343\277 7|\277\222\243\203\277\315uU>X\274\234\277\360\350\240>1\210/>{[i?X\010\241\277\202(\335\277\352\255B>\302%\201\277\351F\343\277\244\304e=\030\303\375\277\366\220\351?\034q\030\277\314\001\305\277\n\363\320\277H\321\346\277\243\002\'\277\220\223\265\272\304\254\t\277\330\216\274?\223\002\215\277\243p\350?\021\346\006\277\004KJ=\350x\357\276\032\032\305?\263\232\207\277\032\245\303?\037d\225>\352\n\245\2776\\s?\261\350@?\200-\314\277|w\256\277\244J[>\235&\030?\242X\260?\010{v\277g)\363?aJL>w\037A>\334\"\204?R\340\347?\276g\212\277\202i\306\277?\307\224\277\237\347\022?\037\032\266\277\207\301\254\276\354\266r?x\"\302\277O\254\007\277\367d\342\277h9I?\235G\201?\212f\222\277(\205\310\277\007M\007\276JE\325?Yg\273>=\300\272\276\211K\030=\t\272\353\277\003\265\323\277F\217\001\277j~\337?\014\257`\276\317\032\235\277\347}\327?\235\272\020?b\006\246\277\007\275\352\277\332\'\251\277J\267b?b\215\270\277\323\227\341\277\246R%?qV\336\277\364(\271\277u\232W?\177j\023\277\036\376\347?\264\204\207\275f\252\366\277\323p\306?\262\224e\277\211\237\364\277\355\306\010?h$\263=\320f\017\276w\271\367?\235\001I\276\3711\202\276[\000\334\277\023\002\325?\323\347\361?:\300\213?f\252\241\277\2331\242>9\254h?j\220_\277\005\217\206>XN\374?\245&\244?^Q\242\277\002\376\303\277\234\313\314?\030\375\326?\027\223\264\277\221\343\263\275\305\203\336\277\261#\240?P+\230\277=\251\253>q\377\003?\207\241\317\277p\014(\277\'A\374?\320{\227\276\350\257\243\277\264\177\343?\204\244\343\276\234w(\277\355,z\277\030\344}>\262\221\334\2777\007\320?\273R\237\277\316\312\275<z\212\312?H\250\242?\230\272\\?\303f\337>\3144\271\277\300\367\307\277\276\224\217?\205 \231>7p\206>\367\277\221\277\016A\255\277v\320\005\277\315\240\217>\357\004\305\277\221\353\215=/\251\205\277\233\271\022>\210 \207\276\274&H\277\377i\245\277y\252\335?\337\306\215\276@f\353?l\252\245?b\342\234\277\254\262\270>d\324<\275\025.\277\277\202\264(?\215\232\313\276\372q\206?(L\360\276>\027\255\275\375\350w< 
k\374\277\330\266\260?\2455\000\277,\362\222?\026\311j>Y\014\267\277\366M\235>-\341\364?\241\251\n?\311\305\234\277u\303i;\250\302\327?0}\006?\212\2052\277\241\027\376?\274h\256>8Md?7\241\253\277\325\025\320\277)\300\243\276v\276\034\277\347[$?%3\241?\243\246`\277\224\3157?\264Y\343?\322\272\311\276\243&7\277\035\236\336?\255\327\r?\370\237`\277\270\234\273?\306\367\336\277l7\312?\262\370\230?\326.\360?\n\376\347\277gl\317?\235i\256\276\306\363\204\275s\220\306\277k.\215\273\250!\352\277\013\370\242\277\343-v\277,$\237?\"\241\205?\332\266\"\276c\325\277\277\232\352\270=B\216\315>\243\256\353\276\032\352\240\277\250\034\363\277SV\302\277\256\341\202?\031\360/\275\324\024x\277\214\204\336?O\355\233?o\027\210\277RP\031\276\226\003\243?o\344\027\277i\250\240\276\222}\255\276\334H\204\277\210>q\277\\\304\231\276\335\206\340?0\302A?\230E\355>H\2261?\335\014\200?U\341\212>\331Z\207\2774\000\204\276}D\366?bi&>\370$\336?\352\225\234?\207\022\364\277\253\374\341>)\326\207?\177\202\317=?\026c\275\031\212\333?\265\027\312>-6k?\021\352\233\277,M\024?\377\347\223?-:\014>\216\036\237?\213\307\022?\230oc?\251u\311\276\031\260\322?\320RM\277\225\037N\277\222\262\242<\221]\205?1\204^?*R\327\277/\030|\277\332\3212?\341\312\342?/)\334\277\375\223\317\277`\264\005\277\005\034Y\277\241\314\247=\262\372\356?\347Cr?I\340O=\240s:?\366\223#>\336\223}?a0\261?\025j\030?\331\376\270\274\360\265\374?Y\205e\277\267\336\321>&y\343?Z\017\227?\001\247;>\256O\200\277\237\351\375\277\303\347\362?O\333\276\2772\342Y\277>\213\224\277o\304I\274\274\350\320?\010{d>\036\217\224\276\361\363r\277\3570\334>\352}\026\277;\270T?<I\342\277e~\"?\\\201\277>^\035\310\277\305\037\300\277&4\214?\nG\234?1\204\023=8\340\346?\357Q\243>\214!D>%\002\330?W \224?\3709\376\277\366\324j;b\303c\277\224\326i?bK\006>n[I>\265\334\372\276,\033\367?\244?\353\2777\031:\277\347\003\250\277|\374\324\276\200\343A?\200G\203\277U\273\016?of\343\277$\007M\277\032\351b\277j\325\337\275\260\350\375\277\227\263\240?\203I\r\276\256\001o>\245\314\355>\376\375|\277l\037\370?\'\223\246\277\350\322??\006\363\265?v\331\372\277\227c\246\277\352I\343\277,i\243?q{\265?\313L\233\277\205\033\352\276\375\331\247\277F\307\321?\352\371\306>|O\207\2771\203o\277^\255\321=\037\021\201>\\#\366>\343\367\365\277\235\261\203\277{?w?\270\204\347\277\036\310\224\277\203Z\357?\035\275\317\276&M7\276\241\373J\2777\251\250?\320\270\306?\261\027\304=o\232O\277\023R\252?OL#?p\223\311?\212\305h\277s\272\243\275\267\370\260\276\025\353p\277\271\367\212\277\251\254\336?3\217\233\277`\237\230\277*H\317\277gYA>#\016\317\277)\251\246?\213\214\327>\312\333\263?9}\355?\212\342\303?g\307\260>\301\017\270\277\016p\315?\352\030\014\277\343\245\315\2763\3409?=\334\307\277\334\3707?\210m\356?<\211\364\276IX\325\277X\0008\277\366\240\305\277\226\366\356\277\223\n\321?8\346G\276\305s\337\274\250\274a>r\222\311?\352\351\300\277\231\206\266\275\211Au?\022S\340?\263\202\276?\313\313\037\277,\236\026>\264\177\262?\033o\023\275\310\267\343\277\200 
\343\277\272\334\335\276\006\367\317?\031D\363\277\243\024\224=\261b9?\272\373\214>\022\261\241?\352\300\256>\360m\252=\363\355\275\277\356\327\213\276%7\332>\315\003\300\277\010\2221\277\324jv\277B\2425>\324\3137?\230(f?\277\367\362\277\216\270\202\276\236\365\240\277\212*6>\214E3?\036\r\237\277\001\2117\277\256\320\366?|\204\365>\301q)?\323\002\251\277\364pE\277Dl\254\276\030\251E>\014\366d=\002\001(\277\267\0175?\200e\005\276sfE?\"\310\373>R\245\364?}C`?\'\233\272?$\224\317<\256\030\335?\264y\320\277)2\322\277i\034\032\277\261\363\212?\026\017\024?\302\310t\277\350A\251\276\210\205X\277\333a]\277m((\277W/\306\276\335\0333?\311\373\374\277\210\345\216=\334X\000?\312U\244\277\273\321\276\276\0149\\\277df\341?H.r\277G\340\271?\213\205\306\276\\\262|?y&\314\276\0042\025\277\314\324J?\272}\211\277\347h\347\276<\273\361\277+@\367?#\235\320\277\000\205\247\276\251\036\205\277\213\266\217?(\330r\277\306B\322\277\270\213i?c\353\260\276\260\376\373>d\332\344?9/p\276\357pV?\271Z\035?\221D\272\277v\350s\277\335|\300?>\000\341\277H\3403\277\017\354\270?\343F>\277\365i,?8\177\332?\t\\\266\277e\205\213?\362\205\362\277\256\\\\\277\337=\227\275\n\366\363\277\253z\304\276\242\373\350?\261r\r\276gG\000\276G\250F?\360P\233\277\263\260\201?\273$]?\2124x>\204B\352\277gc\355?\262\300_\277&\304\247\273]\3612\277\300\304M\277\324t\304=\257\311\177>\373\303\346\277vV\215\275k,\234>C}\345?j\205\341\277\363\003\265?=\005\323?!|<\277\211/8?\271>\370\277\347S\215?\003\241\365\277\000\307\277=\371\203\206\275\221\3223\2772\020\217?W\300{\277\017r/>3-\017\277\"\024\341\277tw\026>\254\037\317?\321[\020?\356\002\t\275\213\212\021?\177cw?\305\362\321?\323o\267?\377\341\366=N\372\262\276\026\016\315?*\014\315\276]@p?\327\254\314\276O\007\363\275\233\017\277\277M\032\021\277\016F\313?\344\224\326\275\264S\221\276*\033\236\277\263\004\362?F\305\242\277\275\370}\277\363$\217\277\010T}\277\211\334\350=J\232\355\277L\242\331?]\327\315?\247\275\t?\213\234D?W\347\276\276\255[\231=\263\tn\277T\224\311\2770\202\362?\201,\300?\342v\274\277\241~\212\277\004r\217\277_;W\277j\353\256?%6\260?\264od\277C 
\216\277\315\361?\277\370\2549?\302\251\014\277\302C\331\277\t\311\207\277!\346\216?H\212c>w<\326\277\366.\"\275\245$\361?\346\224\261?\362\377\223>\213\013\252>Z\360\244\277\007!\251?\313>;\277\025YS\277\207B\261?\275\014\337\277\313\200\261?\365\375\362\277\034\257\265\277\310\300\360\276Y\256\363\277c\267F\276\275e\224?0\251\335>Z\007\221\277!\2353?\002\213\223\275\252\353H\277\353\350\333\276\3117\355?X\030N\277\305\324\364?\270\221\204?\004&\273>\264\000\231\277U-\241?L\220\256\277\353p)\277\351\263w?\236\'3?\277-\361\277^E\307?x\245\005?\351G\372\276jag?7\237t?6\240\311\277\206\234\035?dE\323?\272O\356?\274\342k\276\223r\225\276\007\346\235\277Y\352\305?\013y\232\275\203\266\303=\323\207\371\277N\025\372?\316\310\206\277\323\223\311\277[\212(?u\301\205>#0\001\277\311\204Y\277e\230\372\277f\267\250\277u]\341?\274\327M>e\010\256?\036\344\360?\033\0241?~\367\335?o\233\023?\300R\'?wF\321\277@\034\341\276\243+\206\276\234\177p\277m%\227?\177\236\020=X\003\266?\365\247\350\277W\221)?ZJ\025?[\366t?\340\247\302<\260\346\236>\tV\265\277\256\001\252?\017\336b?\021\001\212\277H\376}\275\r\030F=;;\330\277\025+\254?\n\336\327?\316\323G>\274\005\236\2774\207\336\2767\227\275\2774\262d\277\353KP>\2766\200\277\315\317\020\275\267\220!=^0\252?H\360\240?jI\321>A\242,\2776ms\276\344\332\243\275lh\267>\2217U?\344\217l:qI7?s\025\030?\334\037\367?\341\177\001>>\262\322\276\356M\010?\030\346l\277\247\370\252?\2443h\277\004J\372\277\350nI?\351;\317?\363\030\356?\014\3428\277\207\'\312?\261\356\336>\214\316\347?\231L\373?G\252\201\276\211\231P\277\211\366\245?\"\2768\277\"\255\200\277\010\331\004\277\274\244A>\331\257\016\277\244\017\t>\034\t\333\277\\\366\227>Tg\276>\'>\035\277\301\006\315<\274\377\224\277\367c\322?\342\343\230\277\327}\235\277\310%T?\210\205\315\275\227\300\215=\214\n\274\274\264\003\237\277\377\3263?\275+\371\277H\314\263\277\311\231\373\277\037\230\354\277\024\373N\273\310\303\356>\300\261\374>\331#:?\342\253;>\200\262\020\277\231o\005?]_\222?\267\242&\277\364\004\344?\345!\032?\256Lf\277\251e\377?@=\301\2774\032\r?\306\367\253\276\236\330\353?]\276h\276\355\336\236\276\204J\213?\212_8?81\032?z\360\267\277eX\232\277\370\021\353?\210|\004?K\365\034\276\347\200s>\244\240\345\274\354\000\215?\275]\203\277x\312\354=m\\u\275\304\201\313?\005\241\374?\"\341\263\275\252\th\277\275\307\244>)2\325\277\301g\212\276\2079\342\277d 
\316\276\241z\361\277\027\340-\276\026\036\374>\037\263Q>\002\221\217\277\035\256\321\277\033\010\002\277\177\3730?\250\232\272?\351\256\366?#\022\244?\004\372\016\276\341G\311\2777\352\201\277?n\341\277\362Z7\277.\307\246?\336W\207?q\334f?\365|\010\276\327\305\245\277\235H\315\277\030\022\376?\242\026e\277Z;\311>\276\254\376\277\350V\322\277Wr\227?L\017\213\275B\305\322?\0061\357\276\226\213\300?\262\205\302\276z\321m>\242\372\200\277\253t\262\276S\277\373?$\366\202\277X\270\010\277\314\200\233\277\311~\377\276\356\005\277\277\224i\250?\266\277\213?\203\236.?\276~L\277B\214w?yC\031?`\256\367\277\"\037&=c$\302?|\262+\276\275\374\322>P}v\277\025\267-\277\204\314\343?\270Xs>Z\n\322>\370\325\n\277u@\336\275Xj\365\277\206~_\277\270\000\327?3\271\211\276c\032\271\276P\371\317\277,\\\177\277n\025\226?o\\\320\2771f\243=E\301\251>\253\336\254?Y1\036\277\200U\234?s\336\313\277\350\353L?\362)9\277\241\031\374\277y\017\216\277#P-?|:H?\214\306\205\277\301\324\330?\320iU?\021\005\337\277\246\331\305\276\361?T\277\262\304\202\276\376t@\277\004\311\245\277a\026\331\277f0\353\276\302F\335\277\340\354\017>\031r\306?\\\204\340?\020\320\037>\257\263\264?\234#\352\276\332\365<\277L\200\322\277\314\0224?\263\330\354?\300\0050\274\275\0004\277\345\344\253>\221}\302\277]\361\207\277;\306]\276.\031&?\372\360;?\037ba?\"\241\271?\315\326b?\326\266)\277C\244\265>P\323\324>\342/\236\277*\240\014=\037H\241?\315?T?\026\227c?\2230C\277@\340\310?\207\237\356>\324v\315?\334\003\233\274\204\232\222\277\336\311\343\277\022\215\342>e@R\277OG\233?q\005\276\277\177H\361?\037:\244>\371D\035>\202\351\220>K\241\336?4\335D>vGS\277\261\t\246\277\211\370\233\277\305[\332?\355^P\276\242\314D\275\262\241\225\277g\233\342?\033\324;?S\271\254\277\3402\260?3\023\206?\300R\302\276\331\033\272\277`\037\231?S\004)\277\004\334v\277\037`\033>{~\300?.|\373\276\020\223\372?\017\314\265?\232\356\237?\004\246\334?\343\321\355?w\217B\277j.h?\241\350\252?Z\305\247?l_\254?WV\224\276Y\345\010\276u\\\367\277\256\032\363\277=7v?%\326n?N\351w>k5\333?\357j\312?\027\314\201?\303\305\\>\323\374%\2774\250\214>\222M\367\277WL\320?\024\337K?mf\255\276\222\324\034\277\344\255\316\276\036e;?\027\333\250\277\341\354w\277\340\341 \277A\204\017?t,\245\277\255=\253?$\026\276?\231@\272\276\374\3707\277\244V\250\276\300\247o?#X\237\277\023/\016\277!r\n\277\256\032\237?\n\371\271\277Pd\223\277m(\243?\245\371&\274\255\260\265\276\376V\332\277\025\231q?\314 \n?\320j\356?1\305\211\277\302\357\227\276\027\351\311?+t\374\277\376\023\377\277w)`?\216*\375=j\004\274\277\372\377\321\277\205\220\230\277\030\336\272?\033\016\226?X\017q\277t^\314?\200s\236?g\333\223?K\020v?\341u\252\277SB\022?\031J\006\277\331?\223>\007&\222?\201H\260\276\200\364\312\2778\020\301\277,\377\270\277\377f\252>\200\332\343\277\247>\037>\2379\271\277\276\313X\277Q\323\222?a%\317\277\241\304\350\277\223\211\343\277\342\230\375=I\013\361\2774\210\233>\270\360\234?\324\024\302\277\311y\201\277\005\235\230\276c\340\230\277\212\341\243?\307\243\370?]\263\331\277\212\r\264\277\301\224\002\277\265<\265\277\343\315\375\277C \317\276\364>\335\277\361\037\334?\337\016a\277\030\373\000\277Q\262\247?c\247\"?\301\223\007\277\\\002\263\277VP\206\276*\211\272\277^\013\261\276+]\035\275>h\301\277]\242\n\277g\016\361\277\312\020\334?\372\241\326\277\264\252!?\241(0\275\214\013\366\275@U\362\277\367\024%\277&\352\005\277\223z\331\277$D\352\277\364\026\316?\232\221>?\225\t\274>\264\026\001\276\333b\322?!<z\277\252h\256\276GY\242?\377\373I\277\301\032\207\2776]9?\t(\370>\371Pq\27780w?Sr\214?M\200\372?\377\003\001?\334 
\377\276K\n\317\277\225\002T\277\352\326h\277\3030\252?\225\224W\277\355\337\305\277\3312\204\276\246T\241\277\257E\006\277;<\365>\331\024\345?!\313U\277\'k\235?]m\"?\025I\202\277j\315\314>\343d\212\277M)\317?I\354\363?I\006\"?\275e\214\277\254h>>\376\327\212\2771UN\275\3408\204?\031\366\036\276k\207\226\277\257\316\364\277\007\331A\276\313!S?\313\241n?\270\275\306?3\220\340\276\205\275\267?\362\253\210\277*7\340\275(\212\217?u\375$\277\360\264\250\277c\225\362\277\t\024\027\275\271\261\252?\224\366\006\277,\311\013\276\2008\006?\021E\237\277b|\326\2776\302\234>\372(\032\276\274\242\205?\213\257\r\277$\036\364\277o}\333\277R\373\270\275U\234f\275g]\t?b\266\261\277\323\375s\275\261J\213?\037w\372\277\357\353\021\276@_\363? \323C\276\377\235\335?\360>\201>\274\016\235?a\177a>\022\242\037?u@\331\277\273\230\224?p\313\222?\211R,?M\002\216>}\004\305>\365\357\374?\331\332\211>\340\272&\277\340\367\014?g\"\033\277\315\004\207\276&\214&\277li\250>FC\233?\n\250\305?0\210\315?\353\374\316\277\3679\245\277E\302\272\277\006[\305?\354m\260?\367\201\033>(\374\317>\357 \342?_\213i?\363a\376\276\300|\032>\200D\231?\264l\000\2773\310\222\277\323D\312\277\326\004@?\321?\314\277\314/\224>\340\367\351?\233Ub>\241\367\372>{\204\360\273\246\256\203\277S\373#\277\275r\351?\302\330\240?\313V\252\277w\211\305\277\360G\351?\233\273\024\277\276\212\366\277sX\336\277\345\241\236??\313\010?;V\232\276\033\001\341\276\232\245\331\277\003\335v?\321\'\315\277\\\310n?\333\033v?\352\370\260?\275\000$\277\271y\244?O\243\217?1\250>?\227\207J?U\333\326?\010\274)\276\376&0?\275I\372?\3740\274?N\206_;\342\200|?\243\324Y\277v&\020\277I\331\357\277\351\'\253\277r\260\023?2\357\331\277\274\016e?\317\336z>v, \277\325\032\210\277@N\202\276\232\016\323\277\253<\315?\202\315W\277\223D\241\277}\233\352\277\363\330\205>\334\316\322\277yN\377\2778!!\277\312ZC\277\313\024\204?\0336\240?\332\306\210?\323R\341>\255\357\310\276\252\3300\277Li\202\277y\020\240>\237\336\373\276\303\r\177\276\326\313=?\206\204\253?\363\333\355?\344+W\277\265\366\362=\315\240\364\276\250I\330\276En\325\277wMw>\254\324\331\277\360\351\312\277S\263\352?\305\331\344?\232\275\030\275\242{\254?\252\352\225?C\217\244\277\322V\245\276L\214\037?\247\031B?\374)\005\277\240\331\313\277\266N\345=#\265I\277\370\363\364?\272\"o?\243\360\360\277\315P\\\276\351\335k\277=[\033\277\377(\255?v\n\t?\317q\351?OP\345?\214\246\363\277\344\251\361?f^:?\004!\034\277\225\347\'?\013\345\266=1\177\327?\260\n\205\277E\241\224\2771\215\334\277\205T\356?)&\353?\253\n4\275\t6!\277\205w}\277\013w\313?V\342\337\277\270A\203\276-\202\203?RzV?\311\177\375\277z\327\337\276\034A\271\277\036\213\205\277cG\206?\355H\275?\206\021\201>A\273\222\2770W\232\277\325\332\\\277\253\0218\277\372\220&\275\320\270\203>W\230\332\277\343?\362\277\343O\352?5\355\243?v\301\274\277\303*\324?\253\311\376\277\371\007\330\277\327\004\266?U\314\364\277W\274H?\351\243\305?\367\335\200\277\034W\343\2771\227\337?\005\340\t?w\343\274?\275\315\225\275\303\207\353?|Y#?\242\023\332?U\335\326?\313\361$>4\013r>\255\t2?\343\246\236=\036\031\303?6c\360\277\240\355|?\325\353\226\277\200\204\244\277\232b\323\275\375\214\025\277E1/?\rY\231\277z\336\362\277\000\036\253\276MX{?[4\027\277\210i\363?d\314)\275-\321\363\277\036U\271>\363\356\366\277q&\353\277\223X+?\227sG?\230J\337\276\3572\221>R\266\237?\317\344K>\227%\354?.\026\000=\337Y\363\277\316D\366>\014dN\277\003\265\234?\240{\343\277\205\350==\303\205#\277A\322y\277\022\320\335>\262\376\237?\262\301H\277/\035~\276\017P\007?\303(8?\233\3477?dQ\323\277\005`\240?\014Q\215>~\204z\277e\0
25\347\277\316!\366?\025q\231>\220\343\213\277>\233\n\277p\250\371\277\274\022\300?\023\027\004>\024\361\250?%\230\360?\306wb>\246S\253\277\343\016\200\275\231\315k\277\367l\235?\027Wj\275_\360\353?bP\363\277\342+\036\277\337\307\260\277\007\n)>\237 \360\275s\031\301\277\354\374|?\304Q\242\276\252H\373?\367\004\246\277\326O\372>\374O\317\277\345\010\375\277R\304\236\276\275\236$>\245\212\336?m\322/\277\317\376#\277\323\020\241\277\300\033%\277\253\271L\277\206\245\361?d\037\225?\301\017\236\277\341\335H<z3\311\276\004\223\354?\320\362\350\277\327\353L>\304\277\264?9,\351\277\037\031\334?\007\264\213>\236W\256\277\253`\206\277\352x\200\277\365\265\267?\323\216\276\276m6\245\276\272\370\260?\327\246\221\277\205\361b?ucl\277\016E\371\277\205\267f\276\363X\365\277\230\014\356\276\034\325\330\276:\315x\277\254\301\360\277\312\303\320\277\023\\\357\277\347\265t\275v\014\306?\234\300\227>UH\335?\014\356\344\276\013k\361\277*\312w>v\306\237>\251\327\233\277P]a?\262\241\355=R\263\364?&pH?6\337\251\277\270\210\270>|\226G?!\352\350?\345\246\243\277X\252\260?\303\360\233\277\037\032\216\277dy\311?\345)\262\277\201\323&?8j\262\275{\372\227\277\236\036y\277\235\207\235\277\302=\324?\357<\221?\234n\223\2776F\376\277\223\247\345?\311+\000?tD \277\374\210\224?m\267B?p\254\266? =\321\277\362!A?\203\377\204=\257E\364\277e\277\321?\201\036\n?\211n\234?.\366\322\276r@\217\276\020\306\215><\244\354\276\\\252\342?\334+\346?;\241\"?k\003\206\277\230D\356?]\227\370\277\216P\351?-U\242\277b\271\366<\277\360\274>X\373\245>&}\203?\266\257\202?\177F\212\276\277\037\214>,\305\260\275\351\006\242?\0042(>\331\242y\277E8\232?\351\351\321\275\236\245\364\277\032\273\364\277\2737\277?\247p\362?\243o\300\277\324k\231\276\344@\336\277l>\250\276\010\235\256?K\014]\275\332H\232\277\002\006\232?=0\331?\231H\237?\317\027\353?\214\254\007\277\234\263\224\277&\201D?\244\036\252\274\213~\373>\322\036\220\277\016\241\207?\340\377\220?\205{=>~\345\336\2772=\325\277\247\014\257?\200\263\010>\232\306\214\277QZ+\277\353\270\010\277\r/\252\276UU_\275\242E&\277\242j\360?{\213\252\276\356*\277\276\021=\343<f\024\270?`T\314>\351\213\345\276\207\274\350\277g\001\201?y\237\227=\213)\343\277o/f\277w\nU\276b\"\211\277.\210S?f\024\267>\252b\357\277;\203\216?_%\366\277_\037\366?\233\000\253\277HI\261?\232,\215?\227>\354\277!\367\353\277M=\357\275\355\035\272<\324e\243\277\336Q\367\275\337\234q=\304h\234\276\305X\255=\025M\203\277joi?\253\210\354\277_\314\332?\025)\363\277\242\353!?\204N\303?;[g\277\303h\230?\202Ip\276\337[l\277\3011\t\277\331\320\265?O\246\277\277\207\360\273?\310\356O>P\354\376\277\347\302X\277\207g\240\277\345\326\311\276\005e\035?U\006\036?\272\367\324\277\016\370\257>\215\321\007?\035\005\241\276\333U\313\277\370\232\300\277\335\242\320?L\017\205<\216\316\204\277H\374\365?\024^\244?g\237\277?Ad\007\277\337\374\035\276;L\364\277kP\262?\263i\377?<`\263>Ej\024?T\340L>\266\006\205?\001\261n\277\010]s?\331Ng?\207Y=>\020\022\242?\212\201\255?C\255/\277<X\332?\204\264\206?\321 ~\277\336\030\371?\331\001\215\277S\213\031?5\010\360>\245\005\320?\023\233k?t.\302? 
}\016<$[\340\276 \006\007?\263\370\262\277\243&\275>+B\221\277\036|\322?\\\364G>\177m\316\276\226H\334>\227Q9?\313*\302\277\226\370R\276\274\n\202\27781\361>,\231\370\276\371\250\362\277\256\t\332>\266\364\263?W\312\373>\356\0041?K\245%\276\312\240 \274\322|\014>\335\020\034?\372A\333\277\002\214\240>\247\251\r?\025 \301?\271\334\216?W\214^\276\253I\203\277\320\032\312>\345\362\312?W!\303\277^\367\037?Y\244\247\277Wv \277\346\245k?\311\361\205\277\017\177\007\277E\271\255\277\275\215\240\277\377\341,\276Hz\354\277\202=\346?\271U\265\277\314\030=<\337\354\257?\316\013\001>^}\363?\337{\361\276)m\332?,\335\317?]\023g?\'=\206\277\037\372f?\206\013\256?7\364\372?\334(X=\257h\266\277\245\255\363?o\030\317\277\034\033\363>\300\031\035?\016gK\277\253(\345?!.\352=\345\"\202>m-\033?\004\257\333\277\037u\352>\t\335\367\277\240?%?\3110\336\277\276\251z\276\345\373\307\276&:b\277>\274\235>\273\271\005?\377\3407?z\331\226\277Nr\364?\021\211)>$\322\237?\210\n\253=\347\n\316\276jb\335?ju\330?\023,\244\277\001\215\236?\357s\301\2775\341\313\276\312\222\327?\343\364\341\277\365\0019\277\202\201\207\277A\376\247\277\024&\377?\230p\355?\244j?\277Mz7\277%I\260\275\263\266\203?\270~\337\277\340AU\277[vz\277\201u\217?]6Z\277\370\302\n\275\235L\247?\365 \262\277?V\302\277\2226H\277\241\336\301?\002\364\225\277\314\262M?\373\262\340\275_~.\277\332\263\266?\351\034#?\323\2663\277.\020@\272\210L2\277\222/\243?\352w\312>\323\024\t\277\243\302\347\277\031\302\246?U\026\334>{\r\204\277\"\034\376?\335\316\211>\005\214\362\277\214\004\257?\0167\323?\037I\235\277\244\311l?\005%f>\274\334\007=\035\344\341?w\203\223\276\217\322\215\277\247\007z<u=\212?\251\365\273?kp\017\275\010\376\217\277\245\002\354?j\347\221>\367\270-<\303d\257\277\366\034m?\300r\256?B%\322\277\241\372\362\277\371\334\354\276\326\001\346?a\340\202\2777)\330?\351\231\343\276d,k?8\317\311\277\251\310\216\277\240\t\247?\302WF\277\021\232\264\277L\252\263?\0323\212\277\240\321\035>\332\277\374\277\025(\354?\241\256\204\275GT\341?\243\001;\277\242\023\262?\272\374\r<\231\201\332?\322\370\334>\244\354\373\277\354\030\375\277\342\023\260\277\263\000\241>\354\004\302?\3206\201\277\274s\345\276\027(\266\277\351z\237?F\222\233\277\325_\372\277\2612\370\277\205?\203?0\256P\274\3627e\277\341b\345?\363~N\277\214h\224?an\277\277$\211\240\276\364\310z>\371T\373?\200\250T\277\001\332\325?}2\236?)\277\001?F\233\261\277\024f\225?\320j\205?BE\341=\247\0321? 
C\331\275\251\347\200\277b\031\216>\374g\263>\347\317\326\277FQ\271\277\260\035\r\277ag\373\276\241\323\334\277(\023\376?4a&\2778\304\377?S6\246\277bz\247>\010\247\351\277\311\372\371\277g\010>\276\310<\217?\245<\270?E\305n\2771(\200\277jY<\2765\326\225\277N\257#\277\251\300*\277\352=*\277\201\211V\276KK$=L\306\305\277,\242\216?\374z\304\277\324\021\377?\356\241\373\276|\020\353?7E\311?\365\036\210\276\243a\255\277\004:\207\277\252\367\235?\2522`?\342\237\207\277\331\3614\276\226\236\027?\324\007_\277mk\200\277\261x\275?\013\266\224?E{\027\277\232#\333\277\323B\337?H\n\265>&\254\367\277\201\256\220?\232\370$=\377\346%\276p\243\342<os\253?.m\315\277\201\220\344\277\351\261\203\276a\3568\277dp\302?\262G[\277\373.7\276\034AD\277\323\220\364\2774J\014\277\nVb?\023\212\334=\013B\347\276\364H\313?e\373\310\277s\032\007\277\311\302\013\277\t\250\341?\353\2535=\027\375m?b\215\350?\213\221\253?\222\337\214?\010^\020?\242b\236>\237\355\270\277\213\300P=e\256\372\277;Am?\275\255\334\277\235X\321?\334[\247\276\370\306\334?rr\302?\316B~>\342\363\273\277\347C\001>\327\3560\277\324`\377\277I\220}?\363v=?@\340;?\323Hd?e\321\307\277\203\0036?[\'\215\2777\236\361\277\247\r\320?\373\351\337\277\277\372\331\277c\243\364\277\013@\'?\002\212\347?\255\177\273\276(\006W?\265\357\210?*G\006\277f\213\017?D\263\306\277\027\305\223?[\'\263?l\0171\277aL\376\274\252\313\221\277\312\002\300?\255`\235>\300\335l?\250\364Q\2765R\014\277\225\177\264\277X\350\375?\001x\376\277\"\275I?\345]\030?\010\352\311\276\272\256\002?\237\376&?\275\004\003=\247#m<\335u\305\277\270\353Q\2777E\327=g\344x\277\241;@\276\334DD?\216Z\335>.\241\267?\247\360\252\277\212\265\245\277\275\372g?B\366T\277\366\014\213\277n\204K\277.n\264\277\362\325\276?\006L\033?;\236\206?\337_\261\277\031n\373\276\264m\201?k\033y?\205\037\312?\236\342\360\275\330\362\375?\327\316\024\276\272#\220?\241\3275\277\364\263\352?_\337\305\276\261,\203\277\327\334t\276\331\370\350>\3359\346\2767>\337=\244\275\270?\276\217\002?\250\022\357?)\230\263=\323\267Y>=\223\351\276(\312<\276\224\355S=\001\322\335?\372\204\302>D2\330?\233E\265\276*N\225\277\331\024\261\277H\177&\277\345rx\277\370/+\277\340\236\220?J\322\261?s\220Q?\333&\207?$\273F\277\221\227\312\276\336\033A\277?\223\224?\2064\313\277\"\254\301\277\200&\306\277#\275g?^\272j?\027\225\213?\200\346\321?\272*\233\277\352\371\256\277\361\n\276\277V\000\302?~>\245?&\010\010?\277\353\362\276\204-C\277~\016\364\277\323\005\302<:w\223?\022\320\326?d2{?dT\357\277\272\240\376\276\t\254\200?\315\251\214?\235P\253\276@\346\365\277\023,\373\277\257I\310?\006\230\236?\267\237h?M\320\346?\301\001\322?\001W\371\277\307\211\266?Q-g>\306\242\020?\236\t\223?\013\017\346\277\214\342\316?\360hr?M^\337>\352\206\277?\007g\025\277\220\262i\277\266s\221?q\340\232?\373\316\321?\360m\205>\204\346\211\276\343P\231\277\362H\233=\033\016\266\277\263\332\341\277\031.J?$fM\277\324\235\031\277K\210\001?\"O\251?\256\210_\277\273\233\334\277\375\306\237?\000\311\231\277\250I\364\277&\2429\277i\t\234?x\353\351\277\223\030\347<\230\3422?]`\342?\242l9?I\3654?\303\035\204\277\251\016\314\277<\033\305\2770\331\300?\220*U\277!\377\347?;\332\330?\347?\377\277\256\223\367\277Z\215\213\277\215\246@\275\025\025\016\277\230\033\336\277Y\374\346\277\274\275X\276M\202\233\277\205@;\276\351\346->_\202.?B\246\227?5\253\317>\343`\222?\276\255\376?\324\226\357\277\253R>?\251D\202>\306R\330?\353\354\202\277\274wT?0\375\t?\316lk\277\271\2424\275U\367^\277\364\222\333?#\213\254\276\234Z\025?\276^\306=c\327C\277\360E\274\277Su\360\277im\211\277\235\342\305\277,/\215?T\262
\301\276\240`\351\277\230M\030\277\210\236\337?C\021\362\277\306\215~>\326&\271>i\r\347\276o\005??`;\273\277^\363\260?\361\346\343\276\276Q!\277:<P\277q\356\200\277\3623 \276n\327:\276\342\240O\276\350\241\341\276\377\201\346\277\370\032\272?~\234\333?+R\033?\306R#?\310\277]\277)o\252\276\232\250\236>\004\342\317?3\221\227>\350\327\361\277\241\350\347\277\354\340\336?2~R\277\376\000\335\277\356\352\215?<AN\277\241\261\271?\327\236\300\276\032,\201?\033hF\277sG\030\277\3517\301\276\270l\r?\372\367O\276\241\362\010?\241\375\037\276I\363\265\277\022W\316?W\263\232\276\252\260\315\277\036\354N?\351t\346\277\330|\023\277\342\215j\277\217\216\354\277q5\301\277\211\217\301?\365;X?\217\223\203\277w\014V\277\n\227X\277\307\"*>\341\225\302?\343:\253\277\000\332\270\277\271\350\230?97\315\277vX\310?6{\352?nQ\\\276\360\270\333?\034\357\343\275$\264\372\277\354\354\251>\232\205\234?Y\373\346\275\n\221\263?\340F\216?C%T\274<-\366\277xJ\343\277S?\313>\350\357\323\276\3369?=\r\216\207\277\212\374\330?u\307^\276H\215u\277\345\214d\276\005\3261?\253\203\335?\025\357\003?\245\272\262?\323_\324\276A\204\234?JYZ?\002\266\010\277\374:\310\277\274\332j\276\312\2647\277\252\366\342?[\355\"?\024\241\220?\\hW\277\374\360\003><2\275?+0J\277\337\350\307?\227\320j?1\237\313>\236\272\350?Tj\272\277\332\233\351?\3370\341?rZ\331\277\0361\253\276\010L\264>P\2326?j\325\301\276\373g\266=s\353\304?&\001\333\276mb\200?\224\\/?\023\246\325\277\'\203J\277q-o?y\265^?\312+\216\277cY\343\276\374z\365>d8\354?Z-\326>Bqg=\253+\006\277Z.\305\277\004\031\345\277`\304\311?\331,\325\277\024\305\257?X\037\266?\266\216\375\277A.\001\277\005\246\345\277\021\352{\2767\212L>\366[\340?<\214\227\276\253\206\264?\013\303\333\276E6\321?\247\"\246\277\203\274\320\277\'7u\277\274J\372?=\340\300\277\377\241\305\277o@:\276\274Q\255\277:s\335?(K\321\277\004\037\357\276P\266k\277S\247Y?\317\207\325>\\\374\n>\232^\311\277\310\227\303>\272\356\336\276\326d\331\277\277:\304?\315\335\367\277\020\311\224>\313\314\200=,p\014\277\027sk>\327\333\234\277!]\235\276\026\266\303=%\374b>,\000 \277*Z\035?r\353\226\277g\364\366\277\324h\222\276\240l\231?\027m\316?\205\231\200>\314h\360\277\335\331\342\277\266\024\203?L\261\230?;\224\354?osO\2779\234\243\277\271\273\206\277\366\035\227?\025\337\261?~B|\276Or\025?O\272\200?\364\271\203?\346d\342\277\006\303\271?\013q\251\277\222\254\350\277\343\255w?\\\007\351?\024i\226?\201\206\304\273e\312a\275\"Mu\277&\314\177?w\250\214\276\036F\217\277\230r\211\276\272\021\244>\315\210\325\277C\312\372\277\341;\224\277\276Q8?\333\'\242\277G\266\232?\201\210q\277\260\370}\277\321\234\336\277\247ke>\016\313\314?%U\354?\003\235z>\224\256o?[Z\036?\342)N?\024\250\'?[\213f>\266(\310=\022w\205\277H\373\360?\326:z\276\216D\350\277;\003x?\036\313\344\277\367\214\361\277&\354\300\277\312J\177>\225H\257?\'\2557\277d\360\345><\020\t?\020\345\343?\227\201\260\277\365\004\236\276\t\341l=\200\366\252?a\336\213\276z\272\272\277\347\177\334?>f+\277\024\314\204\2779.\242\277\367\250\016\277\245!\232?y\334\014\277J\002\343?-\237\334>@|7\277\276i\324\275\311K\357?S\203\314\276\364\267\370\277\332>\246?@0\022?\330P\256\276\025\244\354?}\323\320?uI\240>\305\200\r?\037K\232\277pU\375?\341Z 
?\330\217p<\277u\335\277\r\342\236\277\353\207N\277/2K\277V\307\320?s\276\300?\320\034*?\007\313\345\277R\001\335?\256<H\277{\037\317\277b\022\372?\010<\302?\217\005\271?\302\334\211\276\212\351\245\277\025\305\312?8\260\204?\006C\032\277\312\311\225\276\"\313\014\277\022\177\331?\006\320f\277\257M\013?\333\257\327\277\245\272!?\356!\316\277\271\037\313?$\337@\277A\031\312?\006`\272\276*\003\226>\265_\336?\242\332\342?\312\352X\277Lo(?2\344\217>\377&\214\277\204,\353\277\266\020\373?\373H\227?\241\037\220?\216\354\257?<p\017?\006\025\330?\333\222\033?\325|\374\277\343N\302?(!\260>\277\226\017\274<\243\025?^>\213\277w@\277\275aA\200?\360to?\344\324\200\277\001\201\223\277\324p\371?\246\204\355\277\317\202\273\277@=\211=>\036\340\275\304a\360?\272p\353?\376K\201>\000m\233\276\326\3617?)8\213?\377\322y\276U\236\267>\344\200\346?\327\353\241\277\216\245\360\277\327yg\277,F\027?_`\"\277\301\006\205>\365\003\354\277\355A\255\277\027\351\032\277\321`\322?$I\"\277\013\326x?\006\024`>\337\010\344=\365\353\237?\325\302\271\277\331\377\347?DC\301?W\254\245\277\001\356\343?w\306\200?\274p\230>3@\361\277K\261\340=3{\330?\370\347\212\277\014\212%?\304/\322?\370\t\203>\tz\343\277\246C\233\277<=\353>\n\'\354\275\217\257\307?\210L9?%4\230\276B\315\017\277Vs\306?<\207\241\277z\'\346\277(\212@?\016\\\200\277\251\025_\277\307\\\224?\t\216.\277\255\272\"\276\224\346 ?\023\261\330?\211\245\351\277\367\240\322?wS\234?\211\344\240?\002R\261?;2\275?)\224\335=w\313\210?{\350\265?\252MZ\274?\262\023\277B\306\232\277h\270A\277\226fh\277\240\233\231>\3331y\277\0303\235\276\3209=?\021\324\250\27733\351?\322\276\276?x\241\376\277YZP\277CZc\277\305L\330?\215{\177\275\275\220\030?F\225\250?g\321q\276\247%M?e\\\367>,\026\206\277\222\332\263\277^1\000\275\266H\375\277\376\324*?uh\350\277g\004\324>8\004\310?\201\354\311?\371\236\240?\016\332?\277.<\227?\365\356\005>w:\272\277#\265\320?\375\324%\277\023\314\276?\241\n\027\277\266\206\003?om\214\277\320\213\027\276\213\233w\277\241\010\241\275\212O\213\276\3130Q?\343\037\306?X\325\205?\317\365\357\274\212\261\302\277U\345\220\276\0339\\\277r\021\247\277\343\370\200>G\200\245?T\016-?;(\022?\203A\334\277z\247\315?\332R\376\277@\006\371\2777@\253?\265\324\336>\333\247,\277Q\356\217?\326\367\352\277)n\206\277\217\253\236?\233|\257?h\376\020\277\001\216o\275\247\241\241?\007/\214\274k0\300?>\3129\277\214\240\036?\035<\276\275\361\273&?-Q\263?\372}\252\277R\245%?\326\034%\277\347r\356?L\317v\277\263|\233<\370\332b\277\361\236\305\275\277:\310\2772\344\"\277\222\270\305?\231b\211\273\224\031\222?\271\203\374?\335\210\316?\361\371A>s{>\277\017A\304?\331\366\335\277\233\336\232?x\365\001\277\365\220`\277\006|\332\277\034\000\365=\341\024\246?\270\340\\?\350\370=?\336M\227\276\277;N?K\337\344\276\332S\325>o\355\212\275|T\367?\001\353.\276[\007??\212\226\362\277\312\327\206\277\301y\203?%\344\021?{\016\220?n\311\247?\\\335\022?L\354\233\277\022\334\303=O\r\352?\355\371\333?\355\306\034?\207\254\027\277\300\305\305>\005F\300\276\236T\201\277I\236\210\277.\327\264?\rD\302\276\202\\\374?\247\237s?\377;\320>\365xq\274\273\305\233\276e=\222\2779O\344\277!~\257?B\354\320\277\341X\367>np\356?\222\326\022\277\253\3245\275\355\347m\277/Y\353?\276\302\350<\243g\374\277\222`\351\277\03483\277\260\300\302\277\257|\255\277\245v\264?JZ\010\277-\307\"?\2166\333?\344\023T>\333\262\261>\326\007\331?\016\232\274\277\220\324\332\2778\235\346\277\215\360\203\277\2664\320\277\346^v\277\004\307\321\277\304\3771\277\303b\320\277\324\317\334\2773s\\\277|\326\356\277y\222\253?C\243\214>\326\203\335>>^\202\277\260
x\230\277\260\332\267?\373A\220?wZ\256=\323\361\017\277\252pS\277O\345\222?y\313B\277\301\231\261?\232\247\215\277\300\224/\277\"\207\261\275\300\276\313?Es\366=\216s\273?\260\234\344\276\310\361\340?\207\005F?~\341\223?\323\n\366\275/\330\363?\217\340\345?\336\211(?\332VY\277k\003\332?\"k\016>*\310\252\277\323\374\200\277\005\376U\277\266\r\251?\373\314\256\277\314A\243?\226\030g?\371\355\375\277\351\336-=\367\220^>\301$\272\277\266\372\373\277\\T\231?\'9\370?7\220\272\276\t.7\275\326\352_?~\214\242?\0074\315\277\016\010\365\277\235\246\310\277\2725\235\276\246D\325\277Y\336\330\277Q\364\256?~\200i\277{\324\303?\322\221\244\276\306\214\217?\221\333\037\277#H\373\277\363Y\232\277-q\217?\203v\311?\320\025R\277t\307\277?M\340\316?0\027\362\276\214\260\006\277\244\003\333\276?\324\274\277\361\320J?\023W-?\275\353\313?\223\251\306>G\236\037?\217\213E\277&\372\354?}\013\253?\325\244H\2771\032H?\273\257\342?\264\367\267?\036\177\021?*=\306\277bD\336?J\202i?\250V\206\276~O\207\277I\203\307?\222\250\251\275]\237\343\276\320\207\376\276\213+\210\277\356)\373?+MO>l\222\353\277i\033J\277\366q\215?Mn\t?a\242[\277$\342\227\277\217\316\364\275Yx\260?\225\317\204?\305\017a>\242\242\367>\257C>>6\226h?\034\346\313?zM\364\277\016\235\341?Z\003\215?\020v\340>\316\013\000?+4\266\275+=E\277P\357\322>\022k\333?\rK\247>)\215\270>\274O\025?63<\277\246K\221\277\356b\205\277\245\366P\277\004T\243\277m~\360?e\377\253\276\230\r\226?n1\304\277\257\315 ?\031W\355\277\227\367\323?y\321\036?\224\366A\277\021_X?\t4\263\277y\223#>\374\177-?\311\235U?\022w:\277\022\256\214\276\256\3723>?G\365>-\313\003\277\267\377F\277w|\212?\200\271\300?\000#d?S\206\030\277<1\215?\303\256\326?\026\030>\277\372\265O>\366Q\025?\216[\257?\330\316\356>\272;\262=\002\002\350?x[&?\223\373\205?\006\001\214<\3070~>F`\207?\000\241\350?-i\315?\347\020\272\276\300\3018>\324T\307?\257P\\\277l\001\227\277\215]\220=f\021\223>\031\260\340\277\031\316\377\277D,]?\277\271\376?1\341\340\277[\221\304?\343pf\277+g=\277\204\314\345?\233\245\266\276\000\277\260?\017\203\307\277\216_\252?Om4?\321c\320?\352\334o\275\271\252\024>\251\032\242\276k\266\337\277@\3537>\270\356\263\273G\372\361>\375\242\254?\214|\303>\0019\206\277\261\017c\277\324\037\225?v\371\375?\301\265\347?\236\263\355\277\222E\333\277\014\370\200?\217\356+?\265\272\032>\245*\206\276\246\033X?\350\347\r\275V*\350?G0j?\022V\237>E\377~?\206\251E\277\325\247R\277\2218\354>\370\003\023\277F\232\263\273\217\312/\277\373\221:\277\345\264\207\277H@M\277\016J\330?\306\242\255\277\0023u?\236jT?\366\035\303\276h$\352\277\361\230\313?\013\203\360?\365\264\321?u\2020>%@\205>\206\265\337\276w4\340\276\377OO\277*A\314\277f(\001\275I\320(\277\347\262\314>\342B9\277\366\257\250?\370\320O\277_\213>?x\014o?\316OD\277\244\326\005\277\035\332\273\277\274\356w?\321b\246>\000\242\271=SG\031?;\305\321>\300\212\254?\365\204\376\277^\312\371\276\003\205\213\277\\\277\214?\300?\340\277k\330`\277\263e%?\035\014\320>\026Z\200?g\343+?\224\323\246?s\354\254?\004\243\312?\317\035\215\277\020;\027\277\"\261\005\277\267B\242\277eM\366\276m\273\n<z\241\331?\\\024\320\277\303\004\363\277\343F<?\373\362\341?\352\007\226\277\224E\020?\342n\202=\275\024\307\276K0\274?\202\207=\276\342\326\352?DV\261?\0317\357\277\217\244\220?\333CO\275\350\032\350?LI\270\277p9\322\2775\212\316=a\236\020\277OEV\277*\361L\276o\033\331\276\204 
\356?\t\375\230\276QC\351\277BS\211>\203\365\243\277c\017\362?m\201\352>x\251\231\272\006d\252\277\025\330\274?\222\240\216?t\362]?~\003\351?^\200\006\277\243\355\206\277\242\310i?\373}\376?\263\006\377?\314\224\357>\324\245\234\276\023\371\356\275\026\361\253\276e|g\277\276\nO?\014\236Q?\304\021\t\277\203\020\242>\360\n\331?\371\004\310?94\337?\004f\265\277\216\306\352\276\370\026\374>M\260\277?\376P\005\277\3663V?2\253\300\275\216\3033\276\020\305\000?\360\260\360\277P\006\205?\340\362}\277\307\203\340?\251+\356>\252\203@?\031\214\337?\002\003\320\277\237\221\333?\223\221\341\277\363R \277ml?\276:<\200?a8\230?H[\342\277]\253\017\277c1\341?`\2331\276\312/*?\323g\351>\342@\335\277\323\322\315\277\352t\337\275\322\335\373\277\006\326\205?|\371\007\277}\277\006?2\270\213?\n\263!\277\221\010\270\277\231B\333?1\035n?\177 \255\277.d\361?\374.\230?\226\326\204\277\031\034s\275c\030\320?\025@a\277+\206\263?\036\314\037\277\021Et?g\230\353\277\275\037\363\277\371\244}?\371_\361\277y;\003?\306\313\316\277\365>\270\277c\367\330\277\230\255q?B|\022<\022W/\277O-\207\277\232\331\251?\271\236 >\272\376\335\277\302\272\364\277\255f\264\277D\"m?-<y?\323R\037?!\204\270>\26149?\031*n\277\335\035\225?\277\307,\277@\014\222?B7\347?\373>\034?\350v\t\277\215\305\240\277\257\202\270?\271\334\205\276\270\223\363?\203\250u?\316k\326\277\257\331\274\277\342M\233>\211}\205?\n\234\246\276V\270\275?\341\2142>\266R\325?L;\335\275\324\025N?t)\'=\321\327\374?\363\372\327?\220\305\224\277\367+w?\310i\214\277.\315V?\347Q\377\277\263v\300\276:;\204?\344\n\311>_\006\216\277\256rJ?1\353\232\275\340\324\360?\004h\307\2774\227\267?\005c \277p\3617?\322g\304?l\\\351\276\210It?\023Y\237>\254l\003\277\034\214\352?..\264?K\034\272\277\\q\217?(\016\376?\234\201n?\035a\222>I\353\263?\014L\371?\367\274\306\277\305\276\241?\004\355\221\277h\212\322?3\277s\277\234S]\277^k(\277L\234a?\004\303\232\277F\271J>\016\034\327\276P\246\237\277\222\200\211?~^->\212M\350\277\223W\252\277\356\351\004?s\355\342\276\304\235\305\277\"k\324>\035X\241\277]T\033\277W\314L?\023W\227\277E@\336?\007\214\264?\215\374\331\277R\033\267?\236\027\256\276\344)h\277\rw\257?(;\224?\316\247\275\277\021n+\276:\2511?\003\331\320\277G\313\\\277g0\315\277\360\205\324?wJ\025\276)\367r\275\004\370\231\277\001IT\276q,\330\277\262\300\302\277\236.\362?@a\333\277_\346J\276\251\220\335?>\331\207\277F:\362\277\236\030\367\277\247/\226?b\211\023?\t\010\241?\255\\T\277\366\257\241\277x\362~\277\204\364\031?\240\213\271\277\204\232\315\277z/\317\277\320\255\231\277\316<f?D\305\352\2763\367\240?\327\327\274\277\375\r\364?6j\323?}\000Y=\344\277\340\276\275I)?\016\256\225\277\255\365z?:Ns\277\025_\243?\202\016\025\277e;\231\277\372h\315\277\026\320\343?\2631\232\277rc\223\277\243c\030\275\333\r\371?\304yF\277M\365\206>\032\247\254?\263\n\312>-\013\"\277\330\004\303\277,\352\301\277\272\270F\277\312~\200\275\363Y\354\277\0055V\275\325\215\320\277\006a\377?l\250\206?\265\372\213\277J\264\351\277\361\261\360?\346(\362?\362\'\277>]E\313?\335\020-\2763\336/?\256\206\374?\001Z$>\177\367\314<\315\204\240\277\024\270\337\277\214h\336?\\\205f\277Z\332^\275\274\270\230\276\255\244\236=%\355\316\277^\343\271\277\320\342\365\277a\373\t?\265\340\251?\351C\255?\344\320\267\277\222\215!>}\324\363\277\",\025?\022v\202\277\366:\271?\366U\025\277\264EN\277\364\226\016\276\311j0\277\367\262\362\277M\303\253\277}\331\221\277\37292\277\322\364\023\275\023_\221?#\326\321>\347\222\215?z\035\037\277\305\374\001>-\342\377?\306\037\244\277\306;\341?F\324\342>15S\275\344J\241\2770\325{\277o\357\031?|\365_?\364
\376\013>L\326\335\277\271\315??z\222_\276kk\347\277Yi\n\2730\276\371?2\t\204?\365\340[?\370\306\245?B\013\204?\332\363\372?\200>\250\277\004\213\367?qt\330\277\231b\350\275\032_\303\277M,\230?Jj\001\277\005\033\251>\252\217+>\360\344\203?\255%\276>G^\357>\262\225\342\277\217oJ\277\352@+=I\2224?\214_\271\272\350\357g?\322\301\313=\310H\274\277\340\325\206\277\3728\303?\306\214\332?4\257\301\277\263e\370?C\260\247=L\021\377\277\013\342\265\276\273\035\301>b\336\225\277\347\300\343\276\240\254\334>\217\031\332?\255\360\330\277\243\315c>\240\002\216\277\344/\342?\310\0102?CC\363\276\243a\237?\212\"\346?\223K\370\275~\313\312?[\244\341?\034J\030\277\217\266\230>8\200\233?\035\331\310?\217{\207?\275Y\003\277c\341\037>8}\320\277.\240h?8Rk\277\253K\242?w\nX\277s=\373\277\263%\356\275\240l.\277\341Z\224?\300\355K>[\330\342>\000\005\\\276\037p\227\277\006\024P>+\335\265?\342>\334\277\'\353\313?A\0000?\\\262\300\276\006\017\244?\220\343\244\277\360K\004>\243\350\322\277x\240\227>\374\370\265?\232d\305\277\220x@?\365B\306\277$\362L=\313\335{\277:\321\357?\367\260\362?\177\277\324\277cO\316?\204\262\367?\016|\230\277\311\220\263?\260f\301\274\317F|\277\256\r\371?\010\336\\\276N\352\362\2774\375\336?\362\332V\277\031\340W?\216vM\275\026\013\t?\271\336\364\276\272\256\331?9q\t=b3\265?!\320\247\277q\005\'\277\004\356\022\2770\032\272?\321\206\262\277\235[\023=E\314\\?<\252\350\277\347:F?\255\027\200\276<#r?LM\263>\315\250\324?b5m\277\240\336\244?]\332\225?\226\r\277\277\314,\303\277d\200\230\276hs\013>\206\005\001?\367\226\375>c\337\026\275\317\275r\277\177\345\337?\372\376P\277T\246r?\204\232_\277\024\\:?\336\354\254\277\327\031\363?H\2221>w\256\014\277fx\320?\312\232\226?\004\351\354<\034Sq?B\222\023?5~\262\275\345~\347\277M@\377\275<w\273?\217j\271?\246\313\321?Un\231?)}\340\277\212\337s?UP\366?\317\177\217?\222\006\370>\317fr\277\213\260\356<\253\205\274?-\201\271?|\373\251\277\303[\356\277\217|\370\275Y\254\345?I\3410=\203Y\265?I\344\310?,\367\260\277\260fp?\263\033\217?0\351\274\277\347\001\272\276\273\026\020?\030!x>\276>\273>\316\020p\277\347\267O\276\341\2213?\330\177\232?\341\200\221?Y\200\327\277\243\177\234?\360/\275?lX\201?\370\034\323?_`\376\277\232|\023?\010\303\321=\344e\204?\207\010\362\276\023T_?Q8;?\277Ur?\241\036\301\277\000k\323?\2100\210\277\233\207\343\277\216Z\276?\307H\375\277e\347\307?\224G\263?\212\306\217\277\026Y\"\277\327*\375?\351\017\314\277\265$\002?K\260\356?\033\010\247?\213\177\220\277{m\270?\236\263\220>\014\211\242\277)\034\023\276\0317\337\277\206t\207\276c\001\234\277\206\326\'?\224\027T?\277\277R\277c!\344?M\250>?\307\227\302\276\354G\347>\215t\252=\007Q\350?oQ\255?\204z\252=GDb\277P#\256;/\001\003>;\271i\277=\342\231\277^L\235?p\257\341?\024+\351>\360\376\214\277zn\257?`a\263\277\216\314\200\277x^\301?w+;\277L\226\370\276\3207\320?4?\026\277?\202\216?\263\216\375?\307\225\365?_X|?&\251\333?\215\036\207\277CN_\277\320Q\034\277\\\227 
?\313\301\325\277I\027\222?p\036\010\2764h\325?\246\3030?\206_\335\277\241\326\275?G\300\224\276Q?F?Q\253\252\276i\\\351\277(\243\300?\300z\277?R\372\203\277\236\342\223\277\3106J\277\177\270\310?\374\025K\275\033\316\310\277\006\274\261\277E\272\315=\326\335\315\277[\021\205?[\356\315?\337\303\271\277\"9\356?5\362\310\274Du\265>l\276\345=+\235B?\377\376\250?\002\005m?\223\366\234?\323\314H\277*\303\271?\324e\261?\371\317\207\277\237\314\377\277\217\273e?\273\235L?\374\217\231\277\243\370\377?\305\346\013\276\320\203\350\277\335*\226\277\350\261\374\276Q\367\256\276\013b\225\277<\257\227?\325\005\335?\257\272\237?\245n\276?\274\357\322?@\351\215?g\332\214\277u\350\342=\2017\365\276\270]\211?\276c\362?v\'\017\277\313\367\246?\364\264\214\277edC\277\327t\"?\356*\224\277:S\257?\262/\324\276K\002\031?\\\035\350?(\273\207?me\320?\271u\257\277j\363\271\276\r~,?{\207\006\277\2422\272\277\t\3122?z\224\263\277\0357\252\276\027\230\001?\261\324\220\277\356\246O?\271\263+\2777\032\244?\212C\245\276\005\373W?\263I\366\2776\225<?\030\237\232\277\375`\361\277\257\231\215=\364\334a?*\316~\277]EH?\016\231\212?\0079\302\277\032\310\232?E\227\212?\220\014\224\277\304po?\201J_?qH\331?~4p\277^\352\223?t\324\214?\225m4\276\252\031<\275\231\357\363>67\211?\252\231\355\276\365=\304\277\362\330\250>\006\260X\277\317)\330?v\371\200?x\253\201\277\205=\344\275\014|\347\277S\216\201\277\014\341\334?\227\213\265>\014\003\366?\026 \225=B\315\366?\220M\177\277N\005\325?\010O\351\277h\256\223>sB \277%\033.\277\010\344\022=\373\t\226\277t\331\236?J^\350?m\233m\277H\220\005\277A\331C\2771\252\357>\306M\252\277\265\010\231\277P@\205\276\273\211\256\276\366a\337\277\027h\274\277\310\275\271\277(\010j?@x\350\277\267\356B>\255.\351\277TI\301\277\227\315\033\277\177\305\031?\320\214\232?E9\365\277j\245\374\277\305;/\277s\255\201=-\314\222?\0058\027?6g\242>3W\353>\363\213\357\2779b\350\276\317\'\n?\335|\361?\273\036\312\277\266y\371\277\324\207\326?%\323\226?\n\274\352?\275\035`\276\255\304[?\006j7\277l\214\270?\254|g>\022;E\277\2245e>\020\364E\277ur \277\240\036\300\276\267\374\270\276\305\250\257\277\312\352\227?\367\034\233>\340V\356?\n\004\241?\302?\211?\275\314\247?V\261\374\277\341t\346\277Cvh\277\234/V?\362\274\242\277\304\030\272?\303\312\373\277\202\252\366?\317G\234?\005\267;?)\216#?|K\352\277:@\241?N\275E>\330\210\377>\306\202\001\277\360\007\253?-(\300?\332\272\357=\303\005\330\277\027`\237\277\204\202\346\277u\303\256\277^h\256?<\000\030\277x\006\214?q\312\347?Y+\231?\010\242\222\277\336\211\237\277\364qx\277^\302\322>\254\000\313?\356f\234\277\241\022\372\277\345\262\274? 
O\244?kU\206\277\260\240\363?\226\263\211?\273#\307?\323\001\266?4\222<>\365\341\257\277\310\267j\277b\316y?\266 \210?\361;\025?1#\340>\002W\247\277\010\240\324\276y\235\221\277\355\220\214\277\034\226\032?\362\263\245\276\255R\217>\tg\245\277\273P\231\277\354B\206\275@\346\364\2768\037\215\277:\306\270?\247\337\212\277\317/\227?\211)\256\276\352\333\213\277\020\277\331\277~\031\263\277\036\225;?B~y\277Q\tw?\355\002\265?\361A\322?1\032\362\276\234\3715\275\\Z\024=/\254\333\277G\346\323\277T\"\273\277V\310_\277$j\274?\237\234\366?-s\216?\376\306\201?\240\375\201?\324G\314\277\331r\267\276\370\257\222>i\244\324?\324\265M?4\236\200?1\312\364\277R\251\217\277\020\\\351\277\224\023\362\277\r\264\215>$\356\311\277\314\251\261>\334M{?G\230\365\277\312o&?Q.\337?\034O\033\2773.v?\374tV\277m\257\243?I\216\n>\236d8\277r\311\363\277\241A\341\277\005\274z\277\002oO?l\232\343\277\236\r\314?\234/\271?5\221\340?I=\322\277\324+\336\276\316b\234?#\006d\275e\256m?\247\013\207\275K\003\236\276\270\317\t?\252!\344\277\334\303\324?\230i\320\277\204\010\302\277\216\227\360?\306\341\216\277G%l?c\227r\276\315\374\272?\016\020\'\276>\346\027?a\027\246\277\362\240\240?\264\370\025?\234M\035?\303B3\277\355E\265?x\025\302?\354\257\377\277|x\245\277\013\352\251>\377:\326?x\227\345\277&j\221\275@\216h\277bH\270\277/\001\255\277\t\347\322?\223&\311?\233\211\322\2778=\001\275H\3047?$D\252\277\362\177P\277\256\003\256?y\356\252\276t\\\345\277k\025\241;`\006\037\277]o\000?;\230I?\245\n\254?E\344-?/\330\230?c9\313?\344\262\375?\367\004\204\277\277aN>2\353\310>(fH<\302\237\345>\277\204\201?\264&\267?x\253\203\277<<.?\230T\361?\310\372\377\277\002s\373\277\n\277\027?\337\017U>|\317\210<C6d?\221\233\371?\213\027f\276\337\265\316\277T\275\334?\230;V\2776\376\351?;\343\312\277\343\352\377?Q\312a?\204\t\005\277\031\255\244\277 \370\311?\272\263S\277\260\215\330?Z\224\214?\005\247\332\276\302\264\303?)\257\265?o\207b?\254&v\276.\303\321\2765v\313\276>Io\277\022\324>\277P\"\031?\331E\235\277\350\035\222?^\315\226\277\374c\247\277\t\364\276\277M\325\323?r?\275?\222\335\262?#\320\312?\342\241\244\277\267\027\'?\247\200#\276\207\216W\276\024\024\261\277Zg\250\277\300\214\213>\344\233\302\277@\223\220;\021\330K?5?\263?\301\220\001\277\341\210\376?\263\032\204\277\344X\232?\312*\355\277\251\317\226?\265\320\360=\201\322\300?+\010\262\276q\301L>\231Aa\277V\362\030?\023\022{\277.\017\366\277H\365\251\276\236}F\277\233\020\313?D\266\361?I\300\014\276\027r\032?~\t\323?\312\271x\277\266\002\352\277h>t\277\265\320\033\276\243\010\351?\010\345\035>\\\t\221?J\241\371\277\177\014\367?\342~\311=?\036\325>\301\r\360\277S\200\372\275\365\340\375\277t\222\217>\224\335\004\277\327\277\300\277\206\212\265\277\240\224\231\277\275\377\320?\316\373\201\277\306\267\354\277\2442A\276\243k\364\277\031D\203\276\246\274\317\277\'\264\376\276\2569\344\277\252\330\267?=\273\357\276\022] 
\277\210m\302\277\252\227\367?&\227\035\277\376\273\340\276\325\260\231\276c\025\333\275n\216\352\277Y\213\347?\374A\333\277:9=?\2376\340?\333J\326\277\026\033r?s\032\217\277\254\320!\277\263\251\243\277\377\017\367?\033\267V\277\3135\037\277\026\ts\276\205\320\315>V4\315?\337G$?v=\254\277\302\316\256>q\206\334\277\327#\237?\312X\224=\371H??\201ZM\277:\004\202\277\201\nA\277Q\360\037?c\177\355\277\226\276\017\277\032x\n\277P;\354?q\323\031\277\025s\210\277\023\314e>\360\245\262\277\007\017\335=|\225H?5B\377\277\177\214\267>\223Y\200\277\265\004\362\277u\027\237?\211\332\222\277\261\263Y?p\336\263?t\223\217\275\340\007\266?T\275b?/f\343\277\310\361\357\277\357\332$?\206\224w?\353\370\352?\tB\010\277)l\006?@%\364?3\315\323\277n\3743?\364t\032\276Xi\204?(/T?,\236\'\2775\222\211?\3140t?\357\363\212?\320\261\246?{\227$\277\2730\242\277V.\360\275\317\375\366?\341@\177\277\365\216\247?NZ\206>\222\211\243\277\244g\357?u\307\354?vk\213\277\345\255\332?\355\274\370?V\357\351\276\314\3434\276\201\n\027>~\202\260\274\343\002\350?\240\331<?\376\371\322\277\225\305>\277\016IQ=\213\362\337?s\202.\277Q\267\266?\250\374\324?a\020\357\275&\376!\277\302\351\366\277\373\326\253>\212\374\267\277\3648\206>\302n\327?O;\240?\3642\313?\212\177\375\2773k\323\276;\355t?\010\263F?\335Qh\277\377\240\226\2758\225\274?H\320)\277\217\356\245\276\201\241e>nC\335?\n\362\237\277u}\216\276\353]\010\277\202R\234\277\346\225\232\275(\354\372>\037l\227\277kO\331>\335\036&>g\242%?\007\204\336\277(\351\242>Q7\341\277\347\234W\277b\'\222?\016\201\025??hI\276\004@M\277\243^U\277U\262\310\277Y\304:\277\256yY\276\325\244\224?\363\345o?O\255\353\277\007\236\">\215]\337\277y\257i?z\263X?\006\002\352\277\207}\306?QlB\277\002\374\273\277\250\344\024\277\347\032\274?\362\270\367\277\303v\226?\277:\014?\371K*\277q\0176\276\224~ ?\247Hu\277:\325\217\277\025\314,>z\230\361\276~n\356\276~\026\020\277\277\002%\276l\027\376\276\243&?\277e\374\267\277\000\300\311\277\240\356\243>\3747\357?|\306\236\277\303\242;?\222\364\377=3\000\326?\337\311\305\277V\260y\276{\236/?w\260\t\277\260\254V\277\336\030V\277\312\321L<\274C\313\277o\323\264\276rZ!\277\2703\324?G\'p?\263\377\303\277\331N\235?T\360\276?\2652\253?\213\344\211>\262\312\254\277\nRL\277=\275\367\277\376\276c\275\001\371\302\277\255\331\301?\312\331\221>\217\003\367\276s\214\003>\036\322\237?n\000\214:\177\371\375?\032\373\177?\374\353\253?\364Q\236\277k\000\255\277\241\340\321\276B[O?\323D\231\277p$^?\243Zp\277\226Tl\276l\304\203\277*\307\200>\321\320\376?=\243\'>V=T\277\025\337\321?\021\374-?\217.\244?\223g\357?l\374\263\277\216w\325\277-6\306?+\007\232?K\367i\277\025\307\330\275K\020\335?HCH\277 
\273h\277QKN\277N\020\246\277\301\260\334>*T\271?\234\371}\276><C\277:X\250\275\332\304\337?\021B\250\276\030\342\356?\277\235\270?\3602\302?\2357\023\277\333\232y=7\326\342?\375\n\320?\241\rl?\327\2555?\204@d?\223\224\241?#\316\317\277m\333\004\2771\325\035?\275\2562\276&x[\277\014Z\322\276\034\330B?J\200\337?Z\361\023\277%u)\277\224\255\215?<>|?4\270h?\311\207A?\030+\265?\320\306\351?3}\275\276w\352\274\276\251<\332?\035\257\225\277\254\360\272\276\t\317\275\2778\244\254?\300\334Z\277X\222\377\277\221\321\261>-\243\270\277\016\304\355?m\315i\277\342\310\267?\024\335\177?\030{\326\2770\010\314?\321\316\214?\010\345\310\276~\240\327\277\326\\\343\277\372\030l?\203W\225\277\202\035\323?\034%\361\2769\177\330?C\371^\277@O\264?\272\370\336\277K\027\"\277\037^\346>\024\030\272?x\375\225\277\307\201\303>\344W\344\277Hk\004?;\363\326?\342/\234\277\304\203\252\277\342\244\336=\355N!?\264\335\220\277\362l\231?\257\237Q?\020\275\350\277\236_!>\316X\260\277X.\346?mQ\223\277 \372\034\277\233\334\\\276#>\265?C\024j?T\242\216>5#\275\277\203\262;\277\343\340\330?\t\010\273?\212\276`?\330z\353\277\341E\203=\271\274\311?&\305\261?\023\031\207\277.\263:?b\266\217\277\334\010\206?\247*\321\277SC8\277dnO>}<\307>\266\025\352?\300\177\366\2769\\s?\275G>\2771\022\n?\037K\215\277\260\221\324\276\007\253\313?\0362\333?Wn\264\277F0\021\276U\253\300\277d\306\346?,\304\323>\352-l?C\223\243?\334\365\233\277\315\320\013\277\030L\333\277\234\232b\276\271(\021\276\263:H\277\230\245E\277$\302\211\276\270r\344?X\225z\276\216\370\214\277P;\227?\253\307T?\372\2052?|\211\224\275\010\3679?\355\265u?\351\320\370?^^U>\332\340\022?\304\372\032?\253j\264\277;\r2\276\257\005\232?\325\234\267>\037I\337?\262\237\027\277\202J\341?BU\264\277\273\303\357\277\322\213D?dp\203\274\372ig?9\337\371\276o]\250?\300\335\222>7y\276>\334d\'?6\227\205?\022\205\302\277\220<\245?\207\3152\277.\'X\277\206\363\204?\364c\321\276R\263\335?04\237?,\306n\276\"Y\320>\260\223\214?I\366\225\277Y\001\340\2777\256Q?\026\000\210\277R\363\327\276H\005\267?G]=?\3614D?\014\360\345\276\327h!\276\'\344\365?\357U\307\277\316Z\246?\240\305\266\277s@\335?E>\376?>\363\244\277\334\330\350\275\032\342\256?\rw\223\277\022q\035\277\027\241\320\277<\261\257>pY\327\277\024\222\312>rL\242>I\204$\277\0010\200?@g\365\277\000\322Z?q\337\257?\262\345\347?\037\334\230\276\341\335?>\324kf\276\201i\345\277I\016 ?ij\350\277\247\337\353\276\312\365%=\"\256v\277\306m\340\277|\010\236?\351A\246?\013R\362?\306)\244?T\353\271>BW#?\330\246\027\277\013\371\201\277\213\243\366>Ck\350>\216\016\322?\030\324\355?\334g\261\277S\016\210>\266\223\364?\3768\247\277(Q}\277\262\007\224?\346fx>\354r\240\277WLF?\2107\272?9j\224\277\013 
4>\022\232T\277\264\356D\277*\305\227\276\341Z\345;\255\323\363?\3212\367?\375\266\024\277\266\006\360\277\327JB\276\364e\024?\037<v>\022\314\321<5\202\207>2\263\336?.\nr\2777n\371\277y\021\037?\375\n\373=.\352K\277\343\365\330\277\030\214\200\277d\213\266\2771y\016?\321\374z\277,\031Y>\266\262\240\277\307\247\035\277\013\037\363\277\257\267\257?=-\246?\237\321\364?\333\337\026?\363r\215\277\357\303\312\277\202\262\317?\022\007\344\277*\230\314\277\re\207\277\243\020\267?\262R\345?\320\333\027\277\201E\214\276\241\'\340\277o[\214\277\037!\244\277z+\276\277DZs\277F\016\310?\324z\252?\252{\362?/\211\226?j\014\327>\t\200\306\277kr\211?\311:\251?\005\266\333\277\204V\215\2772\235x?w\216\253?\to\225\277\226\273\234\277X\343\233\277\2371\357?\203\241Q?\203O\334?-\\\303\277\031\207\217\275\343\353\242\276\266\332\376?\214$\300\277YZ<\277\001{\316?\204\257\344?\241\376x\276\205g\310?\214r\356>\364\236\270\277\277\3067?~_\230\277(\340\353?\001=#?yN\254\276\235\013\323\2774\312V?R\216\365\277\326! \277z,]?\3525G\277\204\324\312?\227\200\031? \300\337\276\304I\241\277\313\335\'?\"\360\331\277\317\177\320?\235ft?\36404?\251|\270\277\025\003\313\277K\215\300?\2541&?\345\202\331>F\311y?\032\361\304?h\264`\277\276\346\253\276l\231\376\277\336\373\337\27736\300\276a\242\377\277\026`\375\277^\336s?\326\203j\277C:_\277\245\275\353?\261\233\366>\232\003\215?-\317\341?\r\266.?~\304\217?\214M\356>\201Of=\302\211\351\277\320\273\326\277\305\254\353\277\335a\300\277F)\n\277\003V\271\276!\254@\277\352\373\235?9\352O?\020\205`?\360\307\245\274\007\312\370?\234c\252?A\266\305?\227\033\265\277\246/\366\277=\307\307?\311\373\320?l\205\364?\023m\276?Zug>\234\017\343?\2015;?\353\325F\277L\024t>\270k\336\277\037\033\216?H\313n\276i\327\322>L\263\226\275D5\252\277&\005J\275\037\316\263\277\327\324*?\034!\376=\363\245N?\276\2517?\275\271\361?\370\214\305\277\224b\305?l(\331?\233\356\206\277Unx\2778\3029>\225z\216?\275G\272>o\231S\277\304\017\257\276\346\223^\277\304\202\235>\337\326\260\276\257T\250\277\322\233\243\277\212\371?\277\347\345,?B{\300?\031!\302?S\016h?6\026\255<\347Y6\277\203\252f\277\314p\030\2767\335R?\032\377\347\277q\302\232?e\353\220?\3744\212?\351A\255?\336H\334\277\304\022D?\006^\216?\337\336\367?.zo\277H\036\211\277\363\344\274\277\371,\276?\"t\265?\\\373\232<L\217\342?wy\330?\\^\371?c\251\324\277j,\305\276\255<\247?\317\336\353\277\354[\255\276)\265\340?\250\257o?p\r\233?O1\366\276\370\246\033?T\332\001?\275\031(\277\001\313a\276\327CA>\t\023\277\277@\233z?a\202o\277\361\205\322\277e\204\034?\260\254~\277u\233W>\240N\220?Jw\361\277u\031\257\276\200\001\347\277\307U0>2\305\237?\320~<?\276\301\346?\272x\206\277\2279\\?\250\013\260\277\334\275\n\277\217=\345\276REy?m3x?\022\270{>\365\227\356?\343\360\353\276A\021\273?\002\017\311?\366t\243\277\271\232w?j \272\277\271\265\221?\310_\373?\2670T\277\331\274u<e7D?P<\232?\355M\362\277`<\305>\262B\356?\301\347g?\372\206\272\277\327B\321\277\017Y\324?O\244\353\277k\345\216?W$o?\350\210\333=\345\313\237?\353\213\353\2774\265\217?\037\366\346\276\r \201=\227Bj\2773\232\006\277\007\235f>n\346\272\277\2729\326?\354U\273?Q 
\224\277yT>?\231\3019\276rH\271\277\374F\324?\007?[\276b\220\320?<M]\277\207w\257?\327i\273>%\214;>\0209\362?!\317\362\277\2128\361?VG\204?\r\242\322\276\305a\357\277#\350\245\277\n=\262?j\324\203>f\336\237\277\022\332\r\277-s`\277\350\267O\276\036\324\360\277\277rU\277\246\301\354?\311\250?\277h\373\353==\201\275>\244\036\340?\274\372\264=\016\202\225?+\000Y\277\032\221\244>\263j\303>a\374\254\276q\216\302>\374I]\277\034B\276\277\340\352\263?Q\001I\277\344/\322\276\\\037\364\276uA\354=\004\377\007?\263\347\215>\377\0146?\265\231m\277\305\270\256=.^\355?\307\n\367?yOF?\364\374b>Q\233\227\277\361x~?\320\n}\276\037\375\243\277\\v\265\277\034\227\202\276\345k\370>\"5\201\277\246\344\251?\210x\252?V\001\255?\0345\221\277,H%?\316\2102\277J1\342\277\036]\222\277\320\305\232?o\024\355\277\326S\355?\340\177\217\277\310\234m\277\0220\212?\317\240/\277\023\344\350?\334\301s?\205\207\242\276\211J\220?\227B4?H\366\013?\274\315\361>\243\026\206\277\357\207\222\277\005\226\236>\240\036\216>c=)?>\364\332>P\343\305?1B}?\312)+\277F@\231>\244F\325\276\314\345q>\221\025\310?P/\202?\323\353\014?(\351|?4\014e?\317\r\274\275\223\361\213\277\352\304@\276\3227\237\277\350\267(?QL\206\277\373\213G?w\007\007\277$\232\324\277\3331\"\277\267\210\276\277\324v\263\276\204\256\305?\\7\342\277C\365I\277\312D\202?\304\035\335?>\260\247?O\226\246\276\030\203\300?\216\263$\277\376\205\177\277M\371\320\277\003G\017?e\224\355?)\266J\276\362\260\240?a\341\301>^Y\300\276\341\"\372\276\234l\247?I\302\244?\202\300\234\277C\234:\277\357L\226\277\r\323l?\321\034Q\277(\343\301?\366?\336\277\261\323\261\277I\354\314?p\244\351?\327\341F?\325\332\324\277\3107\025<\227%\177\277\244\261\356>)\372\316?l%\013\277\317]\272\277\214\350%\276\373K\343\274o\354\237\27718\364?_\262\273\275\244\364\333?\2056\263?\000\'\334\276\314F\343>|\224??o@\262?\323\371\271\276\374\t5\277R\300\326\275F\247\246\277\2371\201\277\300:\312\276\n\227\336\276P\207+\276\377q\324?\242\207\231\277\031\313v?]\n\352\277\310|\227?\370[8?o,\251\277\201\023\255\277\030Z\213?\"c\234?@\327\366\277\250\261\277\275\230\310\302\276\264\304\363?fm\224?\351)\032\277\345U9\277\264\201\271\277\022\344\307\277\310:\271\277L\304\316?I\274\310\277\270\237\202\276N\244\257\277\350\241s\277\276\r\241? 
\244\244?Hk\335\277\337&\264?\007\255\302?\303\374\374?iN\312\277\333\221\311?&-\275\276\271\033\177?|4l\277\034\236f\275M\214\275\277\216\260\350?\265\315\376?Y\370\342>\304\2323?\336\3515?\263\2619>\333|e\277v\033\311\276\332\337$?\230v\235\277\031\016\351\275\021+\234?\265\033\307\276\242\356\347?\260y\004\277<\222$\277PP\302?\3108[\277\366\301\306\277o[\205?\036\354\375?\360bp>zi\255\277$L\\\277\nO\225\277\337)\210\277\276\002Q\276\364<\200\276JDY\277C\035\267?\317\032\321\277~g\215\275\n/[\277\231\307I\277\212\037\245>U\251\226\277\027\020\271\277\206\320\261\276^i\310\276f\336\367\277r$\211\277\343y\322\276\261\355\226?\026\231\360\277\225\003\252\277\376\325\204?\326C\236\277)t\243?\223\035\000?\036\003\252\276\026\271\317?\357\023\362\277\264J\334\276\010$\274\275W\226:\277\267\347\004?\034\331\211\277(\313\206>\322$\341\277\234\263\300?G\336\005?\331u\203\277\251\350@>)6y\277j\307q?V>\367?\274\\\201\277\324\"\366>U\376\320\274)\005\203>j\367\274\2779\256\266=\025\333\366?\026\367\356\277\341\200\226\275\005Q\271>\304\000I?\017\336\341?\215\363\346\276\236s(?q\214\241\276\'\351\251\277c\227\036\276I\331\231?K\372I=z\352\213\275\026\024\361\277\210G\302?_\320\302\277\267y\207\277u2)\277\375E\235\277=\225\267\274\221?\306?\265\346\226\277\324|\002>a\365\311\277\265\242\347\277\025\373\031?\207G\311?\230\023\336\277\264\273\256\277\266\032\270>f\311\243\277@\237\217\277W\345H\276\204\261\323\277B\213\032\277&]\320\276q\226\244\277\336B\342>\232\022\252>\366\376\257?P\024\247\277\251@\226?\233\273\311>\200A\312?\352O\245?\357\277\265\277\222\341\234?\307q\343>\030l\372\277\242\316(?1\3767\277Di\277?\005\352\257?ZE\313?\270-\237\277\360\316\245?\020X\330\276d\276b\277\021\243\215\277\3233Z\276\0202\271\277t\310\314?A\331\254?\326|{>$\017\264\277{\375\306?\336\210\374>E`\237?\002\342>>\n\244\342?\215\004M?\201\370\323\276L\016\231?#\t\017=\3600\333\277\020\240\217?|\363]?T\254\214?~\344\240?\312\260\217\277\324z\376\277\264\260v?\344\223\336\277\225\260A\27724\310?F\227\263>8\222P>\343\032\202?\272T\225\277\002\250\213?\232\333f\277Dx\221\277\037\366&\277O\314\303?9\231\373\277\222\027\305\276\266Z\203\277\005$\003?\232h\222?%\215\212\276%\233\257?\203O\336\277\336W\304\277\357\007\342\277\014\326\207?\263Wa?6\330\275>\322\255\217\277qm\240\277w\335\267?\212\351\344\277!s\342\277.\211\347\277\243fc?\361\357\374>o\341\216\277\271\335\260<\r:\343\277\204\016b\2772%\214?\235\005\312\277C\301\234\276r\332x\277\261\214\370>d:\340>\026\317\234\277\202\341Z?\216\213*?\304{\334?\0368\277?\221vG?8r\327?I~\206\277\255\325x>\243\n\222\277\345\301%\277\245\313\204?\241)\340\277\265%\253\277\002;\243\276\342\027\253\277\300\372\226\277V\314f?\230M\342\276\241NM\277\013\371e?\216\026\032\277\343\377y\277\323q\330\277\203\223L?\355\360\257?S^t?{Q\372\277\306\270\230\276\226U\221>9B\345?\322\003\277?\371\222\342?\361\301\232\276\224\267\204\277\t\253\354\277\201\177\357\2756\332\373?*\270\324\277\3037\034?\317\225]\277\361J\356\276\006qz>~q\200?\234s\372\277\026\312\271?x\304\371=tI\217\277\254\235\177?p\367\322?K\351<\277\371l\300?\"A\357\277\274\366\323\276\374\327\311\277N\266\207\277g\333\t>\177q\262?\302\004$?` 
\321?\024?\313?\347\262\237\276e\331\264?\005:\366\277\256\367\317>\351\352\346?\013\221\"\277m\225\313\2761\345k>2*\r\277\317\003\245?l`\356\277A\304\351?\0061\377\277+\255{?\021\332\253\277LO\254?f\265\357\277\331\036\331?\021Zl?\331\334\203?\211\360C\277\245\273\277\277\366\001\345?\213v\261?*6\260?\202\035\331\277~^\201?\351\257\324?\352\032I?\302\337\316\277$\027\237\277\177T\344?\026\307\201?\017K\021\277-\027\204>\017T\240\277_C\235?\333b\232>\026;W\277\005\232\037?"
+ }
+ }
+ }
+}
+node {
+ name: "gen_quant_npy"
+ op: "FakeQuantWithMinMaxArgs"
+ input: "gen_quant_npy/inputs"
+ attr {
+ key: "max"
+ value {
+ f: 1.9997752904891968
+ }
+ }
+ attr {
+ key: "min"
+ value {
+ f: -1.9998407363891602
+ }
+ }
+ attr {
+ key: "narrow_range"
+ value {
+ b: false
+ }
+ }
+ attr {
+ key: "num_bits"
+ value {
+ i: 8
+ }
+ }
+}
+node {
+ name: "fake_quant_inputs"
+ op: "FakeQuantWithMinMaxArgs"
+ input: "placeholder_0"
+ attr {
+ key: "max"
+ value {
+ f: 1.9997752904891968
+ }
+ }
+ attr {
+ key: "min"
+ value {
+ f: -1.9998407363891602
+ }
+ }
+ attr {
+ key: "narrow_range"
+ value {
+ b: false
+ }
+ }
+ attr {
+ key: "num_bits"
+ value {
+ i: 8
+ }
+ }
+}
+node {
+ name: "fake_quant_wts"
+ op: "FakeQuantWithMinMaxArgs"
+ input: "const_1"
+ attr {
+ key: "max"
+ value {
+ f: 1.9936614036560059
+ }
+ }
+ attr {
+ key: "min"
+ value {
+ f: -1.9891220331192017
+ }
+ }
+ attr {
+ key: "narrow_range"
+ value {
+ b: false
+ }
+ }
+ attr {
+ key: "num_bits"
+ value {
+ i: 8
+ }
+ }
+}
+node {
+ name: "conv"
+ op: "Conv2D"
+ input: "fake_quant_inputs"
+ input: "fake_quant_wts"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "data_format"
+ value {
+ s: "NHWC"
+ }
+ }
+ attr {
+ key: "dilations"
+ value {
+ list {
+ i: 1
+ i: 1
+ i: 1
+ i: 1
+ }
+ }
+ }
+ attr {
+ key: "explicit_paddings"
+ value {
+ list {
+ }
+ }
+ }
+ attr {
+ key: "padding"
+ value {
+ s: "SAME"
+ }
+ }
+ attr {
+ key: "strides"
+ value {
+ list {
+ i: 1
+ i: 1
+ i: 1
+ i: 1
+ }
+ }
+ }
+ attr {
+ key: "use_cudnn_on_gpu"
+ value {
+ b: true
+ }
+ }
+}
+node {
+ name: "result"
+ op: "FakeQuantWithMinMaxArgs"
+ input: "conv"
+ attr {
+ key: "max"
+ value {
+ f: 10.0
+ }
+ }
+ attr {
+ key: "min"
+ value {
+ f: -10.0
+ }
+ }
+ attr {
+ key: "narrow_range"
+ value {
+ b: false
+ }
+ }
+ attr {
+ key: "num_bits"
+ value {
+ i: 8
+ }
+ }
+}
+versions {
+ producer: 498
+}
diff --git a/reference_model/CMakeLists.txt b/reference_model/CMakeLists.txt
new file mode 100644
index 0000000..0ba8afb
--- /dev/null
+++ b/reference_model/CMakeLists.txt
@@ -0,0 +1,76 @@
+cmake_minimum_required (VERSION 3.4)
+
+# Copyright (c) 2020, ARM Limited.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+project(tosa_reference_model LANGUAGES CXX)
+
+set(CMAKE_CXX_STANDARD 17)
+set(CMAKE_CXX_STANDARD_REQUIRED ON)
+
+if(CMAKE_CXX_COMPILER_ID STREQUAL GNU)
+ set(CMAKE_CXX_FLAGS "-Wall -Wno-ignored-attributes -Wno-format-truncation")
+else()
+ set(CMAKE_CXX_FLAGS "-Wall -Wno-ignored-attributes")
+endif()
+
+set(FLATBUFFERS_DIR "../thirdparty/flatbuffers/")
+set(SERIALIZATION_DIR "../serialization")
+
+set (CXX_SOURCE
+ src/main.cpp
+ src/tensor.cc
+ src/graph_node.cc
+ src/subgraph_traverser.cc
+ src/func_debug.cc
+ src/func_config.cc
+ src/ops/op_factory.cc
+ src/ops/tensor_ops.cc
+ src/ops/activation_funcs.cc
+ src/ops/ewise_binary.cc
+ src/ops/ewise_unary.cc
+ src/ops/ewise_ternary.cc
+ src/ops/comparison.cc
+ src/ops/reduction.cc
+ src/ops/data_layout.cc
+ src/ops/scatter_gather.cc
+ src/ops/image.cc
+ src/ops/type_conversion.cc
+ src/ops/data_nodes.cc
+ src/ops/custom.cc
+ src/ops/control_flow.cc
+)
+
+add_executable(tosa_reference_model ${CXX_SOURCE})
+
+target_include_directories(tosa_reference_model
+ PUBLIC
+ $<INSTALL_INTERFACE:include>
+ $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/include>
+ PRIVATE
+ ${CMAKE_CURRENT_SOURCE_DIR}/src
+ ${FLATBUFFERS_DIR}/include
+ ../thirdparty/eigen/
+ ../thirdparty/eigen/unsupported/
+ ${SERIALIZATION_DIR}
+)
+
+target_link_libraries(tosa_reference_model
+ PRIVATE
+ flatbuffers
+ tosa_serialization
+)
+
+install (TARGETS tosa_reference_model DESTINATION bin)
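+
+# Note (a sketch of the surrounding build, not defined in this file): the
+# 'flatbuffers' and 'tosa_serialization' targets linked above are assumed to be
+# provided by the enclosing build, e.g. via add_subdirectory() of
+# ${FLATBUFFERS_DIR} and ${SERIALIZATION_DIR} from a top-level CMakeLists.txt.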
diff --git a/reference_model/src/arith_util.h b/reference_model/src/arith_util.h
new file mode 100644
index 0000000..554a7a2
--- /dev/null
+++ b/reference_model/src/arith_util.h
@@ -0,0 +1,194 @@
+
+// Copyright (c) 2020, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+/*
+ * Filename: src/arith_util.h
+ * Description:
+ * arithmetic utility macros, including:
+ * fp16 (float16_t) type alias
+ * bitwise operations
+ * fixed-point arithmetic
+ * fp16 type conversion (in binary translation)
+ * fp16 arithmetic (currently emulated with fp32)
+ */
+
+#ifndef ARITH_UTIL_H
+#define ARITH_UTIL_H
+
+#include <fenv.h>
+#include <math.h>
+#define __STDC_LIMIT_MACROS //enable min/max of plain data type
+#include "func_debug.h"
+#include "inttypes.h"
+#include <cassert>
+#include <iostream>
+#include <limits>
+#include <stdint.h>
+#include <typeinfo>
+
+using namespace std;
+
+inline size_t _count_one(uint64_t val)
+{
+ size_t count = 0;
+ for (; val; count++)
+ {
+ val &= val - 1;
+ }
+ return count;
+}
+
+template <typename T>
+inline size_t _integer_log2(T val)
+{
+ size_t result = 0;
+ while (val >>= 1)
+ {
+ ++result;
+ }
+ return result;
+}
+
+template <typename T>
+inline size_t _count_leading_zeros(T val)
+{
+ size_t size = sizeof(T) * 8;
+ size_t count = 0;
+ T msb = static_cast<T>(1) << (size - 1);
+ for (size_t i = 0; i < size; i++)
+ {
+ if (!((val << i) & msb))
+ count++;
+ else
+ break;
+ }
+ return count;
+}
+
+template <typename T>
+inline size_t _count_leading_ones(T val)
+{
+ size_t size = sizeof(T) * 8;
+ size_t count = 0;
+ T msb = static_cast<T>(1) << (size - 1);
+ for (size_t i = 0; i < size; i++)
+ {
+ if ((val << i) & msb)
+ count++;
+ else
+ break;
+ }
+ return count;
+}
+
+#define MAX(a, b) ((a) > (b) ? (a) : (b))
+#define MIN(a, b) ((a) < (b) ? (a) : (b))
+// Compute ceiling of (a/b)
+#define DIV_CEIL(a, b) ((a) % (b) ? ((a) / (b) + 1) : ((a) / (b)))
+
+// Returns a mask of 1's of this size
+#define ONES_MASK(SIZE) ((uint64_t)((SIZE) >= 64 ? 0xffffffffffffffffULL : ((uint64_t)(1) << (SIZE)) - 1))
+
+// Returns a field of bits from HIGH_BIT to LOW_BIT, right-shifted
+// Inclusive on both ends; equivalent to VAL[HIGH_BIT:LOW_BIT] in Verilog
+
+#define BIT_FIELD(HIGH_BIT, LOW_BIT, VAL) (((uint64_t)(VAL) >> (LOW_BIT)) & ONES_MASK((HIGH_BIT) + 1 - (LOW_BIT)))
+
+// Returns a bit at a particular position
+#define BIT_EXTRACT(POS, VAL) (((uint64_t)(VAL) >> (POS)) & (1))
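+
+// Worked examples (a quick sketch using the macros above):
+//   ONES_MASK(4)           == 0xF
+//   BIT_FIELD(7, 4, 0xAB)  == 0xA   // upper nibble of 0xAB
+//   BIT_EXTRACT(3, 0xAB)   == 1     // bit 3 of 0b10101011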
+
+// Use Brian Kernighan's method: https://graphics.stanford.edu/~seander/bithacks.html#CountBitsSetKernighan
+// Does this need to support floating point type?
+// Not sure if static_cast is the right thing to do, try to be type safe first
+#define ONES_COUNT(VAL) (_count_one((uint64_t)(VAL)))
+
+#define SHIFT(SHF, VAL) (((SHF) > 0) ? ((VAL) << (SHF)) : ((SHF < 0) ? ((VAL) >> (-(SHF))) : (VAL)))
+#define ROUNDTO(A, B) ((A) % (B) == 0 ? (A) : ((A) / (B) + 1) * (B))
+#define ROUNDTOLOWER(A, B) (((A) / (B)) * (B))
+#define BIDIRECTIONAL_SHIFT(VAL, SHIFT) (((SHIFT) >= 0) ? ((VAL) << (SHIFT)) : ((VAL) >> (-(SHIFT))))
+#define ILOG2(VAL) (_integer_log2(VAL))
+
+// Get negative value (2's complement)
+#define NEGATIVE_8(VAL) ((uint8_t)(~(VAL) + 1))
+#define NEGATIVE_16(VAL) ((uint16_t)(~(VAL) + 1))
+#define NEGATIVE_32(VAL) ((uint32_t)(~(VAL) + 1))
+#define NEGATIVE_64(VAL) ((uint64_t)(~(VAL) + 1))
+// Convert a bit quantity to the minimum number of bytes required to hold those bits
+#define BITS_TO_BYTES(BITS) (ROUNDTO((BITS), 8) / 8)
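+
+// Worked examples (a quick sketch using the macros above):
+//   SHIFT(-2, 16)       == 4      // negative shift amounts shift right
+//   ROUNDTO(10, 8)      == 16     // round up to the next multiple of 8
+//   ROUNDTOLOWER(10, 8) == 8
+//   NEGATIVE_8(1)       == 0xff   // two's complement of 1 in 8 bits
+//   BITS_TO_BYTES(12)   == 2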
+
+// Count leading zeros/ones for 8/16/32/64-bit operands
+// (I don't see an obvious way to collapse this into a size-independent set)
+// treated as unsigned
+#define LEADING_ZEROS_64(VAL) (_count_leading_zeros((uint64_t)(VAL)))
+#define LEADING_ZEROS_32(VAL) (_count_leading_zeros((uint32_t)(VAL)))
+#define LEADING_ZEROS_16(VAL) (_count_leading_zeros((uint16_t)(VAL)))
+#define LEADING_ZEROS_8(VAL) (_count_leading_zeros((uint8_t)(VAL)))
+#define LEADING_ZEROS(VAL) (_count_leading_zeros(VAL))
+
+#define LEADING_ONES_64(VAL) _count_leading_ones((uint64_t)(VAL))
+#define LEADING_ONES_32(VAL) _count_leading_ones((uint32_t)(VAL))
+#define LEADING_ONES_16(VAL) _count_leading_ones((uint16_t)(VAL))
+#define LEADING_ONES_8(VAL) _count_leading_ones((uint8_t)(VAL))
+#define LEADING_ONES(VAL) _count_leading_ones(VAL)
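+
+// Worked examples: LEADING_ZEROS_8(0x0F) == 4, LEADING_ONES_8(0xF0) == 4
+// (the operand is cast to the stated width and treated as unsigned).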
+// math operation
+// sign-extended for signed version
+// extend different return type (8, 16, 32) + (S, U)
+// Saturate a value at a certain bitwidth, signed and unsigned versions
+// Format is as follows: SATURATE_VAL_{saturation_sign}_{return_type}
+// for example
+// SATURATE_VAL_U_8U(8,300) will return uint8_t with value of 255 (0xff)
+// SATURATE_VAL_S_32S(5,-48) will return int32_t with value of -16 (0x10)
+// note that negative values can be cast to an unsigned return type using a native uint(int) cast,
+// so SATURATE_VAL_S_8U(5,-40) will have value 0b11100000, which is 224 in uint8_t
+
+template <typename T>
+constexpr T bitmask(const uint32_t width)
+{
+ ASSERT(width <= sizeof(T) * 8);
+ return width == sizeof(T) * 8 ? static_cast<T>(std::numeric_limits<uintmax_t>::max())
+ : (static_cast<T>(1) << width) - 1;
+}
+
+template <typename T>
+constexpr T minval(const uint32_t width)
+{
+ ASSERT(width <= sizeof(T) * 8);
+ return std::is_signed<T>::value ? -(static_cast<T>(1) << (width - 1)) : 0;
+}
+
+template <typename T>
+constexpr T maxval(const uint32_t width)
+{
+ ASSERT(width <= sizeof(T) * 8);
+ return bitmask<T>(width - std::is_signed<T>::value);
+}
+
+template <typename T>
+constexpr T saturate(const uint32_t width, const intmax_t value)
+{
+ // clang-format off
+ return static_cast<T>(
+ std::min(
+ std::max(
+ value,
+ static_cast<intmax_t>(minval<T>(width))
+ ),
+ static_cast<intmax_t>(maxval<T>(width))
+ )
+ );
+ // clang-format on
+}
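+
+// Worked examples (a quick sketch, assuming the templates above):
+//   bitmask<uint8_t>(3)       == 0x07
+//   minval<int8_t>(5)         == -16     maxval<int8_t>(5)         == 15
+//   saturate<int8_t>(8, 300)  == 127     saturate<int8_t>(8, -300) == -128
+//   saturate<uint8_t>(8, 300) == 255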
+
+#endif /* ARITH_UTIL_H */
diff --git a/reference_model/src/debug_modes.def b/reference_model/src/debug_modes.def
new file mode 100644
index 0000000..51b151d
--- /dev/null
+++ b/reference_model/src/debug_modes.def
@@ -0,0 +1,20 @@
+
+// Copyright (c) 2020, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Defines the debugging printing modes
+
+DEBUG_MODE(CONFIG,0) // Configuration parsing/initialization
+DEBUG_MODE(GT,1) // Graph traverser
+DEBUG_MODE(OP,2) // Operation
diff --git a/reference_model/src/debug_types.h b/reference_model/src/debug_types.h
new file mode 100644
index 0000000..bd93f19
--- /dev/null
+++ b/reference_model/src/debug_types.h
@@ -0,0 +1,57 @@
+
+// Copyright (c) 2020, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+/*
+ * Filename: src/debug_types.h
+ * Description:
+ * Defines fundamental debugger datatypes for the functional model
+ */
+
+#ifndef DEBUG_TYPES_H_
+#define DEBUG_TYPES_H_
+
+#ifdef __cplusplus
+extern "C"
+{
+#endif
+
+ // Debug verbosity mask
+ typedef enum func_debug_verbosity_e
+ {
+ DEBUG_VERB_NONE = 0x00,
+ DEBUG_VERB_INFO = 0x01, // Informational debugging messages
+ DEBUG_VERB_IFACE = 0x02, // Interface debugging support
+ DEBUG_VERB_LOW = 0x04, // Low, medium, and high levels of debug printout
+ DEBUG_VERB_MED = 0x08,
+ DEBUG_VERB_HIGH = 0x10
+ } func_debug_verbosity_e;
+
+ // Generated debug modes enumeration
+ typedef enum func_debug_mode_e
+ {
+ DEBUG_NONE = 0x0,
+#define DEBUG_MODE(NAME, BIT) DEBUG_##NAME = (1UL << BIT),
+#include "debug_modes.def"
+#undef DEBUG_MODE
+ DEBUG_ALL = 0xffffffffffffffffUL
+ } func_debug_mode_e;
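+
+ // Expansion sketch (assuming the entries currently listed in debug_modes.def):
+ // DEBUG_MODE(GT, 1) becomes DEBUG_GT = (1UL << 1), so modes can be combined
+ // into a mask, e.g. (DEBUG_CONFIG | DEBUG_GT).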
+
+#define DEBUG_INST_ALL 0xffffffffffffffffUL
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif
diff --git a/reference_model/src/func_config.cc b/reference_model/src/func_config.cc
new file mode 100644
index 0000000..bd1ce32
--- /dev/null
+++ b/reference_model/src/func_config.cc
@@ -0,0 +1,632 @@
+
+// Copyright (c) 2020, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <ctype.h>
+#include <signal.h>
+#include <stdarg.h>
+#include <stdint.h>
+#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
+#include <sys/types.h>
+
+#include "func_config.h"
+#include "func_debug.h"
+
+#define MAX_NAME_LEN 128
+#define MAX_DESC_LEN 128
+
+#ifndef ARG_ERROR
+#define ARG_ERROR(...) \
+ fprintf(stderr, "ERROR: "); \
+ fprintf(stderr, __VA_ARGS__); \
+ fprintf(stderr, "\n"); \
+ return 1;
+#endif
+
+// Parameter base name string table
+const char* config_base_name_table[] = {
+#define DEF_UNIT_START(UNIT)
+#define DEF_UNIT_END(UNIT)
+#define DEF_OPTION(NAME, DESC, TYPE, FMT, DEFAULT) #NAME,
+#define DEF_OPTION_STR(NAME, DESC, LEN, DEFAULT) #NAME,
+#define DEF_UNIT_OPTION(UNIT, NAME, DESC, TYPE, FMT, DEFAULT) #NAME,
+#define DEF_UNIT_OPTION_STR(UNIT, NAME, DESC, LEN, DEFAULT) #NAME,
+#include "func_config.def"
+#undef DEF_UNIT_START
+#undef DEF_UNIT_END
+#undef DEF_OPTION
+#undef DEF_OPTION_STR
+#undef DEF_UNIT_OPTION_STR
+#undef DEF_UNIT_OPTION
+};
+
+// Parameter description table
+const char* config_param_desc_table[] = {
+#define DEF_UNIT_START(UNIT)
+#define DEF_UNIT_END(UNIT)
+#define DEF_OPTION(NAME, DESC, TYPE, FMT, DEFAULT) #DESC,
+#define DEF_OPTION_STR(NAME, DESC, LEN, DEFAULT) #DESC,
+#define DEF_UNIT_OPTION(UNIT, NAME, DESC, TYPE, FMT, DEFAULT) #DESC,
+#define DEF_UNIT_OPTION_STR(UNIT, NAME, DESC, LEN, DEFAULT) #DESC,
+#include "func_config.def"
+#undef DEF_UNIT_START
+#undef DEF_UNIT_END
+#undef DEF_OPTION
+#undef DEF_UNIT_OPTION
+#undef DEF_OPTION_STR
+#undef DEF_UNIT_OPTION_STR
+};
+
+// String table and enum for the option hierarchy level/sub-levels
+// (no leaf options). Attribute at the top level have "BASE" as their
+// enum value and an empty string for the value.
+const char* config_hier_str_table[] = {
+ "",
+#define DEF_UNIT_START(UNIT) #UNIT,
+#define DEF_UNIT_END(UNIT) /**/
+#define DEF_OPTION(NAME, DESC, TYPE, FMT, DEFAULT) /**/
+#define DEF_OPTION_STR(NAME, DESC, LEN, DEFAULT) /**/
+#define DEF_UNIT_OPTION(UNIT, NAME, DESC, TYPE, FMT, DEFAULT) /**/
+#define DEF_UNIT_OPTION_STR(UNIT, NAME, DESC, LEN, DEFAULT) /**/
+#include "func_config.def"
+#undef DEF_UNIT_START
+#undef DEF_UNIT_END
+#undef DEF_OPTION
+#undef DEF_UNIT_OPTION
+#undef DEF_OPTION_STR
+#undef DEF_UNIT_OPTION_STR
+};
+
+typedef enum config_hier_enum_t
+{
+ BASE,
+#define DEF_UNIT_START(UNIT) CURRENT_UNIT,
+#define DEF_UNIT_END(UNIT) /**/
+#define DEF_OPTION(NAME, DESC, TYPE, FMT, DEFAULT) /**/
+#define DEF_OPTION_STR(NAME, DESC, LEN, DEFAULT) /**/
+#define DEF_UNIT_OPTION(UNIT, NAME, DESC, TYPE, FMT, DEFAULT) /**/
+#define DEF_UNIT_OPTION_STR(UNIT, NAME, DESC, LEN, DEFAULT) /**/
+#include "func_config.def"
+#undef DEF_UNIT_START
+#undef DEF_UNIT_END
+#undef DEF_OPTION
+#undef DEF_UNIT_OPTION
+#undef DEF_OPTION_STR
+#undef DEF_UNIT_OPTION_STR
+
+ MAX_CONFIG_HIER
+} config_hier_enum_t;
+
+// Mapping from a leaf parameter index to the
+// position in the hierarchy.
+config_hier_enum_t config_hierarchy_map[] = {
+#define DEF_UNIT_START(UNIT)
+#define DEF_UNIT_END(UNIT)
+#define DEF_OPTION(NAME, DESC, TYPE, FMT, DEFAULT) BASE,
+#define DEF_OPTION_STR(NAME, DESC, LEN, DEFAULT) BASE,
+#define DEF_UNIT_OPTION(UNIT, NAME, DESC, TYPE, FMT, DEFAULT) CURRENT_UNIT,
+#define DEF_UNIT_OPTION_STR(UNIT, NAME, DESC, LEN, DEFAULT) CURRENT_UNIT,
+#include "func_config.def"
+#undef DEF_UNIT_START
+#undef DEF_UNIT_END
+#undef DEF_OPTION
+#undef DEF_UNIT_OPTION
+#undef DEF_OPTION_STR
+#undef DEF_UNIT_OPTION_STR
+};
+
+#define CONFIG_PARAMETER_COUNT (sizeof(config_hierarchy_map) / sizeof(config_hier_enum_t))
+
+// Dynamically generated at initialization
+char** config_param_str_table = nullptr;
+
+// Initialize the configuration data structures
+int func_model_init_config()
+{
+ // Initialize string table (builds the hierarchical names)
+ config_param_str_table = (char**)calloc(CONFIG_PARAMETER_COUNT, sizeof(char*));
+ ASSERT_MEM(config_param_str_table);
+
+ for (uint32_t i = 0; i < CONFIG_PARAMETER_COUNT; i++)
+ {
+ size_t len = strlen(config_base_name_table[i]) + 1;
+ if (config_hierarchy_map[i] != BASE)
+ {
+ ASSERT_MSG(config_hierarchy_map[i] <= MAX_CONFIG_HIER,
+ "Configuration parameter\'s hierarchy is out of bounds");
+ len += strlen(config_hier_str_table[config_hierarchy_map[i]]) + 1;
+ }
+ config_param_str_table[i] = (char*)calloc(len, 1);
+ ASSERT_MEM(config_param_str_table[i]);
+ ASSERT_MSG(len < MAX_NAME_LEN, "option expanded name is too long: %s", config_base_name_table[i]);
+
+ if (config_hierarchy_map[i] != BASE)
+ {
+ snprintf(config_param_str_table[i], len, "%s.%s", config_hier_str_table[config_hierarchy_map[i]],
+ config_base_name_table[i]);
+ }
+ else
+ {
+ snprintf(config_param_str_table[i], len, "%s", config_base_name_table[i]);
+ }
+ }
+
+ return 0;
+}
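+
+// Resulting names (a sketch based on the current func_config.def): top-level
+// options keep their base name (e.g. "eval", "subgraph_file"), while options
+// declared inside a DEF_UNIT_START/DEF_UNIT_END block get a dotted prefix,
+// e.g. a unit 'arch' option 'ifm_width' would be addressable as "arch.ifm_width".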
+
+int func_model_set_default_config(func_config_t* func_config)
+{
+ // Set default values in the global configuration data structure
+ bzero(func_config, sizeof(*func_config));
+
+#define DEF_UNIT_START(UNIT)
+#define DEF_UNIT_END(UNIT)
+#define DEF_OPTION(NAME, DESC, TYPE, FMT, DEFAULT) func_config->NAME = (DEFAULT);
+#define DEF_OPTION_STR(NAME, DESC, LEN, DEFAULT) strncpy(func_config->NAME, (DEFAULT), (LEN)-1);
+#define DEF_UNIT_OPTION(UNIT, NAME, DESC, TYPE, FMT, DEFAULT) func_config->UNIT.NAME = (DEFAULT);
+#define DEF_UNIT_OPTION_STR(UNIT, NAME, DESC, LEN, DEFAULT) strncpy(func_config->UNIT.NAME, (DEFAULT), (LEN)-1);
+#include "func_config.def"
+#undef DEF_UNIT_START
+#undef DEF_UNIT_END
+#undef DEF_OPTION
+#undef DEF_UNIT_OPTION
+#undef DEF_OPTION_STR
+#undef DEF_UNIT_OPTION_STR
+
+ return 0;
+}
+
+int func_model_config_cleanup()
+{
+ uint32_t i;
+
+ if (!config_param_str_table)
+ return 1;
+
+ for (i = 0; i < CONFIG_PARAMETER_COUNT; i++)
+ {
+ free(config_param_str_table[i]);
+ }
+
+ free(config_param_str_table);
+ config_param_str_table = nullptr;
+
+ return 0;
+}
+
+int func_model_config_set_option(func_config_t* func_config, const char* name, const char* value)
+{
+ // Increment an index variable on each parameter position
+ // so that we can index both the position struct through the macro and the
+ // array of parameter names through a simple array of strings.
+ int param_idx = 0;
+ char* endptr;
+
+ // TODO: does not handle strings yet. Can set magic values on FMT to
+ // choose a string copy vs strtoull
+#define DEF_UNIT_START(UNIT)
+#define DEF_UNIT_END(UNIT)
+#define DEF_OPTION(NAME, DESC, TYPE, FMT, DEFAULT) \
+ if (!strcmp(config_param_str_table[param_idx], name)) \
+ { \
+ func_config->NAME = (uint64_t)strtoll(value, &endptr, 0); \
+ if (endptr == value) \
+ { \
+ ARG_ERROR("Cannot parse option: %s = %s", name, value); \
+ } \
+ return 0; \
+ } \
+ param_idx++;
+
+#define DEF_OPTION_STR(NAME, DESC, LEN, DEFAULT) \
+ if (!strcmp(config_param_str_table[param_idx], name)) \
+ { \
+ if (strlen(value) >= LEN) \
+ { \
+ ARG_ERROR("Option value is too long: %s = %s", name, value); \
+ } \
+ strncpy(func_config->NAME, value, (LEN)-1); \
+ return 0; \
+ } \
+ param_idx++;
+
+#define DEF_UNIT_OPTION(UNIT, NAME, DESC, TYPE, FMT, DEFAULT) \
+ if (!strcmp(config_param_str_table[param_idx], name)) \
+ { \
+ func_config->UNIT.NAME = (uint64_t)strtoll(value, &endptr, 0); \
+ if (endptr == value) \
+ { \
+ ARG_ERROR("Cannot parse option: %s = %s", name, value); \
+ } \
+ return 0; \
+ } \
+ param_idx++;
+
+#define DEF_UNIT_OPTION_STR(UNIT, NAME, DESC, LEN, DEFAULT) \
+ if (!strcmp(config_param_str_table[param_idx], name)) \
+ { \
+ if (strlen(value) >= LEN) \
+ { \
+ ARG_ERROR("Option value is too long: %s = %s", name, value); \
+ } \
+ strncpy(func_config->UNIT.NAME, value, (LEN)-1); \
+ return 0; \
+ } \
+ param_idx++;
+
+#include "func_config.def"
+#undef DEF_UNIT_START
+#undef DEF_UNIT_END
+#undef DEF_OPTION
+#undef DEF_UNIT_OPTION
+#undef DEF_OPTION_STR
+#undef DEF_UNIT_OPTION_STR
+
+ // No match!
+ ARG_ERROR("Cannot find option: %s", name);
+
+ return 1;
+}
+
+int func_model_config_get_option_by_name(func_config_t* func_config, const char* name, uint64_t* val)
+{
+ // Increment an index variable on each parameter position
+ // so that we can index both the position struct through the macro and the
+ // array of parameter names through a simple array of strings.
+ int param_idx = 0;
+
+#define DEF_UNIT_START(UNIT)
+#define DEF_UNIT_END(UNIT)
+
+#define DEF_OPTION_STR(NAME, DESC, LEN, DEFAULT) param_idx++;
+
+#define DEF_UNIT_OPTION_STR(UNIT, NAME, DESC, LEN, DEFAULT) param_idx++;
+
+#define DEF_OPTION(NAME, DESC, TYPE, FMT, DEFAULT) \
+ if (!strcmp(config_param_str_table[param_idx], name)) \
+ { \
+ *val = func_config->NAME; \
+ return 0; \
+ } \
+ param_idx++;
+
+#define DEF_UNIT_OPTION(UNIT, NAME, DESC, TYPE, FMT, DEFAULT) \
+ if (!strcmp(config_param_str_table[param_idx], name)) \
+ { \
+ *val = func_config->UNIT.NAME; \
+ return 0; \
+ } \
+ param_idx++;
+
+#include "func_config.def"
+#undef DEF_UNIT_START
+#undef DEF_UNIT_END
+#undef DEF_OPTION
+#undef DEF_UNIT_OPTION
+#undef DEF_OPTION_STR
+#undef DEF_UNIT_OPTION_STR
+ // No match!
+ return 1;
+}
+int func_model_config_get_str_option_by_name(func_config_t* func_config,
+ const char* name,
+ char* value,
+ const uint32_t len)
+{
+ // Increment an index variable on each parameter position
+ // so that we can index both the position struct through the macro and the
+ // array of parameter names through a simple array of strings.
+ int param_idx = 0;
+
+#define DEF_UNIT_START(UNIT)
+#define DEF_UNIT_END(UNIT)
+#define DEF_OPTION_STR(NAME, DESC, LEN, DEFAULT) \
+ if (!strcmp(config_param_str_table[param_idx], name)) \
+ { \
+ strncpy(value, func_config->NAME, len - 1); \
+ return 0; \
+ } \
+ param_idx++;
+
+#define DEF_UNIT_OPTION_STR(UNIT, NAME, DESC, LEN, DEFAULT) \
+ if (!strcmp(config_param_str_table[param_idx], name)) \
+ { \
+ strncpy(value, func_config->UNIT.NAME, len - 1); \
+ return 0; \
+ } \
+ param_idx++;
+
+#define DEF_OPTION(NAME, DESC, TYPE, FMT, DEFAULT) param_idx++;
+
+#define DEF_UNIT_OPTION(UNIT, NAME, DESC, TYPE, FMT, DEFAULT) param_idx++;
+
+#include "func_config.def"
+#undef DEF_UNIT_START
+#undef DEF_UNIT_END
+#undef DEF_OPTION
+#undef DEF_UNIT_OPTION
+#undef DEF_OPTION_STR
+#undef DEF_UNIT_OPTION_STR
+ // No match!
+ return 1;
+}
+
+int func_config_print_config_help(FILE* out)
+{
+ fprintf(out, "%-40s %s\n", "Option", "Description");
+ fprintf(out, "%-40s %s\n", "------", "-----------");
+
+ for (uint32_t i = 0; i < CONFIG_PARAMETER_COUNT; i++)
+ {
+ fprintf(out, "-C%-40s %s\n", config_param_str_table[i], config_param_desc_table[i]);
+ }
+
+ fprintf(out, "\n");
+
+ return 0;
+}
+
+int func_model_print_config(func_config_t* func_config, FILE* out)
+{
+#define DEF_UNIT_START(UNIT)
+#define DEF_UNIT_END(UNIT)
+#define DEF_OPTION(NAME, DESC, TYPE, FMT, DEFAULT) fprintf(out, "%-40s = " FMT "\n", #NAME, func_config->NAME);
+#define DEF_UNIT_OPTION(UNIT, NAME, DESC, TYPE, FMT, DEFAULT) \
+ fprintf(out, "%-40s = " FMT "\n", #UNIT "." #NAME, func_config->UNIT.NAME);
+#define DEF_OPTION_STR(NAME, DESC, LEN, DEFAULT) fprintf(out, "%-40s = %s\n", #NAME, func_config->NAME);
+#define DEF_UNIT_OPTION_STR(UNIT, NAME, DESC, LEN, DEFAULT) \
+ fprintf(out, "%-40s = %s\n", #UNIT "." #NAME, func_config->UNIT.NAME);
+
+#define FOF_HEX "0x%llx"
+#define FOF_DEC "%" PRIu32
+#define FOF_DECU64 "%" PRIu64
+
+#include "func_config.def"
+#undef DEF_UNIT_START
+#undef DEF_UNIT_END
+#undef DEF_OPTION
+#undef DEF_UNIT_OPTION
+#undef DEF_OPTION_STR
+#undef DEF_UNIT_OPTION_STR
+
+ return 0;
+}
+
+static const char* programname;
+
+void func_model_print_debug_masks(FILE* out)
+{
+ fprintf(out, "\t List of components:\n");
+#define DEBUG_MODE(string, value) fprintf(out, "\t\t" #string "\n");
+#include "debug_modes.def"
+#undef DEBUG_MODE
+}
+
+int func_model_print_help(FILE* out)
+{
+ fprintf(out, "TOSA Reference Model help\n\n");
+
+ fprintf(out,
+ "Usage: %s [-c] [-C <name=value>] [-d <Debug Mask>] [-h] [-i <uscriptfile>] [-l <verbosity>] [-F "
+ "<flatconfig>]\n",
+ programname);
+ fprintf(out, "\t-c - Print list of config options\n");
+ fprintf(out, "\t-C <name=value> - modify config option <name> to <value>\n");
+ fprintf(out, "\t-d <Debug Mask - set component debug mask\n");
+ func_model_print_debug_masks(out);
+ fprintf(out, "\t-F <flatconfig> - parse <flatconfig> as file of config options\n");
+ fprintf(out, "\t-h - show this help message and exit\n");
+ fprintf(
+ out,
+ "\t-i <input_tensor_name>,<filename> - set input tensor <input_tensor_name> to the values from <filename>\n");
+ fprintf(out, "\t-l <verbosity> - set log verbosity\n");
+ fprintf(out, "\t-o <debuglog> - set debug log file\n");
+ fprintf(out, "\n");
+
+ func_config_print_config_help(stdout);
+
+ return 0;
+}
+
+static const char* get_arg_text(int& index, const int argc, const char** argv)
+{
+ if (strlen(argv[index]) > 2)
+ {
+ return argv[index] + 2;
+ }
+
+ if ((index + 1 == argc) || (argv[index + 1][0] == '-'))
+ {
+ fprintf(stderr, "No option value found for option %s\n", argv[index]);
+ return "";
+ }
+
+ index++;
+ return argv[index];
+}
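+
+// Usage note (descriptive, matching the logic above): both '-Cfoo=bar' and
+// '-C foo=bar' forms are accepted; the value is taken either from the rest of
+// the current argument or from the next argv entry.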
+
+// Read the command line arguments
+int func_model_parse_cmd_line(func_config_t* func_config, func_debug_t* func_debug, const int argc, const char** argv)
+{
+ int i;
+ programname = argv[0];
+ for (i = 1; i < argc; i++)
+ {
+ // All command line arguments must begin with -X where X is a recognized character
+ if (strlen(argv[i]) < 2 || argv[i][0] != '-')
+ {
+ func_model_print_help(stderr);
+ ARG_ERROR("Command line argument at position %d not valid: %s", i, argv[i]);
+ }
+
+ switch (argv[i][1])
+ {
+ // Model parameters may be overridden with the -Cname=value switch
+ case 'c':
+ func_config_print_config_help(stderr);
+ return 1;
+
+ case 'C':
+ {
+ const char *name = nullptr, *value = nullptr;
+
+ // Break the string into name and value parts
+ name = get_arg_text(i, argc, argv);
+ value = strchr(name, '=');
+
+ if (value == nullptr)
+ {
+ func_model_print_help(stderr);
+ ARG_ERROR("Cannot parse -C argument at position %d: %s", i, argv[i]);
+ }
+
+ *const_cast<char*>(value) = 0;
+
+ if (func_model_config_set_option(func_config, name, value + 1))
+ {
+ func_model_print_help(stderr);
+ ARG_ERROR("Cannot parse -C argument at position %d: %s", i, argv[i]);
+ }
+ break;
+ }
+
+ case 'd':
+ case 'D':
+ {
+ func_debug_set_mask(func_debug, get_arg_text(i, argc, argv));
+ break;
+ }
+ case 'F':
+ {
+ // Read a flat configuration file
+ if (func_model_parse_flat_config_file(func_config, get_arg_text(i, argc, argv)))
+ return 1;
+
+ break;
+ }
+ case 'h':
+ func_model_print_help(stderr);
+ return 1;
+
+ case 'i':
+ {
+ // shortcut for '-Cinput_tensor='
+ if (func_model_config_set_option(func_config, "input_tensor", get_arg_text(i, argc, argv)))
+ {
+ func_model_print_help(stderr);
+ ARG_ERROR("Cannot set input tensor config value");
+ }
+ break;
+ }
+ case 'l':
+ {
+ // Debug verbosity/logging level
+ func_debug_set_verbosity(func_debug, get_arg_text(i, argc, argv));
+ break;
+ }
+ case 'o':
+ {
+ func_debug_set_file(func_debug, get_arg_text(i, argc, argv));
+ break;
+ }
+ default:
+ func_model_print_help(stderr);
+ ARG_ERROR("Unrecognized argument at position %d: %s", i, argv[i]);
+ }
+ }
+
+ return 0;
+}
+
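+// Flat config file format sketch (the file contents below are an illustrative
+// assumption, not shipped with this checkin): one 'name=value' per line, '#'
+// starts a comment, and 'include=<file>' pulls in another config file, e.g.
+//
+//   # sample flat config
+//   subgraph_file=test.tosa
+//   eval=1
+//   include=more_options.cfg
+//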
+int func_model_parse_flat_config_file(func_config_t* func_config, const char* filename)
+{
+ const int MAX_LINE_LEN = 1024;
+
+ FILE* infile = nullptr;
+ char line_buf[MAX_LINE_LEN];
+ int line = 1;
+
+ infile = fopen(filename, "r");
+
+ if (infile == nullptr)
+ {
+ ARG_ERROR("Cannot open config file: %s\n", filename);
+ }
+
+ while (fgets(line_buf, MAX_LINE_LEN - 1, infile) != nullptr)
+ {
+ char *name = line_buf, *value = nullptr, *comment = nullptr, *ptr = nullptr;
+
+ // Remove comments
+ comment = strchr(line_buf, '#');
+
+ if (comment)
+ *comment = 0;
+
+ // Break the string into name and value parts
+ name = line_buf;
+
+ // Remove leading whitespace
+ while (*name && isspace(*name))
+ name++;
+
+ // Empty line?
+ if (*name == 0)
+ {
+ line++;
+ continue;
+ }
+
+ value = strchr(name, '=');
+
+ // Missing value
+ if (value == nullptr)
+ {
+ ARG_ERROR("Cannot parse parameter in %s at line %d: %s", filename, line, line_buf);
+ }
+
+ // Remove the =
+ *value = 0;
+ value++;
+
+ // Trim off any whitespace at the end of the value
+ ptr = value;
+ while (*ptr != 0 && !isspace(*ptr))
+ ptr++;
+ *ptr = 0;
+
+ // Include a nested file
+ if (!strcmp(name, "include"))
+ {
+ if (func_model_parse_flat_config_file(func_config, value))
+ return 1;
+ line++;
+ continue;
+ }
+
+ if (func_model_config_set_option(func_config, name, value))
+ {
+ func_model_print_help(stderr);
+ ARG_ERROR("Cannot set parameter in %s at line %d: %s", filename, line, line_buf)
+ }
+
+ line++;
+ }
+
+ fclose(infile);
+
+ return 0;
+}
diff --git a/reference_model/src/func_config.def b/reference_model/src/func_config.def
new file mode 100644
index 0000000..004cf36
--- /dev/null
+++ b/reference_model/src/func_config.def
@@ -0,0 +1,90 @@
+
+// Copyright (c) 2020, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+/*
+ * Filename: src/func_config.def
+ * Description:
+ * Defines the model parameters/options for the functional model.
+ */
+
+// Placeholder values for the Functional model Option Formatting (FOF) fields
+//
+// FOF_DEC is decimal
+// FOF_HEX is hexadecimal
+//
+// Floating point values are not supported yet, but there is no fundamental reason
+// why we can't have them.
+#ifndef FOF_DEC
+#define FOF_DEC 1
+#endif
+
+#ifndef FOF_HEX
+#define FOF_HEX 1
+#endif
+
+#ifndef FOF_STR_LEN
+#define FOF_STR_LEN 1024
+#endif
+
+// Options are defined as follows:
+// DEF_OPTION() defines a top-level option
+// Arguments:
+// option_field_name: a C-syntax field name in the struct
+// description: a short string that describes the purpose of the option (printed out with help)
+// C type: the type of the option (typically a uint64_t, uint32_t, etc)
+// Format field: the FOF_* type used to figure out how to format/print the option
+// Default value: the default value assigned to the option, if it isn't assigned by a configuration file
+// or command line override
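+//
+// For example (illustrative, mirroring an entry further below):
+//   DEF_OPTION(eval, "Evaluate the network (0/1)", uint32_t, FOF_DEC, 1)
+// declares a uint32_t field named 'eval' that prints in decimal and defaults to 1.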
+
+// To define hierarchical options (the example hierarchy here is 'cle'), use the following pattern.
+// All options within the hierarchical space must be grouped together:
+//
+
+// #define CURRENT_UNIT cle
+// DEF_UNIT_START(CURRENT_UNIT)
+// DEF_UNIT_OPTION(CURRENT_UNIT,...)
+// ...
+// DEF_UNIT_END(CURRENT_UNIT)
+// #undef CURRENT_UNIT
+//
+// The CURRENT_UNIT argument is required as a parameter in these definitions because
+// macro processing rules only allow stringification of macro parameters;
+// other tokens that are not passed in as macro parameters cannot be stringified.
+
+DEF_OPTION_STR(operator_fbs, "Flat buffer syntax file", FOF_STR_LEN, "../serialization/tosa.fbs")
+DEF_OPTION_STR(subgraph_dir, "Subgraph directory to load", FOF_STR_LEN, ".")
+DEF_OPTION_STR(subgraph_file, "Subgraph file to load", FOF_STR_LEN, "")
+DEF_OPTION_STR(input_dir, "Input directory path for dumps/files", FOF_STR_LEN, ".")
+DEF_OPTION_STR(input_tensor, "A list of pairs <name0>:<npy0>,<name1>:<npy1>", FOF_STR_LEN, "")
+DEF_OPTION_STR(output_dir, "Output directory path for output dumps/files", FOF_STR_LEN, ".")
+DEF_OPTION(eval, "Evaluate the network (0/1)", uint32_t, FOF_DEC, 1)
+DEF_OPTION(validate_only, "Validate the network, but do not read inputs or evaluate (0/1)", uint32_t, FOF_DEC, 0)
+DEF_OPTION(output_tensors, "Output tensors to a file (0/1)", uint32_t, FOF_DEC, 1)
+DEF_OPTION(tosa_profile, "Set TOSA profile (0 = Base Inference, 1 = Main Inference, 2 = Main Training)", uint32_t, FOF_DEC, 1)
+DEF_OPTION_STR(output_tensor_prefix, "Optional output tensor prefix", FOF_STR_LEN, "output_")
+DEF_OPTION(dump_intermediates, "Dump intermediate tensors (0/1)", uint32_t, FOF_DEC, 0)
+DEF_OPTION_STR(fp_format, "Floating-point number dump format string (printf-style format, e.g. 0.5)", FOF_STR_LEN, "0.5")
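+
+// Example command-line overrides (illustrative values; the -C<option>=<value> form
+// matches the hints printed by main.cpp, e.g. "Check -Csubgraph_file="):
+//   -Csubgraph_file=test.tosa -Cinput_tensor=in0:in0.npy -Ceval=1
+// where test.tosa, in0 and in0.npy are hypothetical file/tensor names.
+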
+// Example of a hierarchical option
+//#define CURRENT_UNIT arch
+//DEF_UNIT_START(arch)
+//DEF_UNIT_OPTION(arch, ifm_width, "input feature map width(x dim)", uint32_t, FOF_DEC, 10)
+//DEF_UNIT_END(CURRENT_UNIT)
+//#undef CURRENT_UNIT
+
+// START Do not delete
+// Required for keeping the FOFs clean
+#undef FOF_DEC
+#undef FOF_HEX
+// END Do not delete
diff --git a/reference_model/src/func_config.h b/reference_model/src/func_config.h
new file mode 100644
index 0000000..f941300
--- /dev/null
+++ b/reference_model/src/func_config.h
@@ -0,0 +1,55 @@
+
+// Copyright (c) 2020, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef FUNC_CONFIG_H_
+#define FUNC_CONFIG_H_
+
+// Parameter value structure
+#define DEF_UNIT_START(UNIT) \
+ struct UNIT##_t \
+ {
+#define DEF_UNIT_END(UNIT) \
+ } \
+ UNIT;
+#define DEF_OPTION(NAME, DESC, TYPE, FMT, DEFAULT) TYPE NAME;
+#define DEF_OPTION_STR(NAME, DESC, LEN, DEFAULT) char NAME[LEN];
+#define DEF_UNIT_OPTION(UNIT, NAME, DESC, TYPE, FMT, DEFAULT) TYPE NAME;
+#define DEF_UNIT_OPTION_STR(UNIT, NAME, DESC, LEN, DEFAULT) char NAME[LEN];
+struct func_config_t
+{
+#include "func_config.def"
+#undef DEF_UNIT_START
+#undef DEF_UNIT_END
+#undef DEF_OPTION
+#undef DEF_OPTION_STR
+#undef DEF_UNIT_OPTION
+#undef DEF_UNIT_OPTION_STR
+};
+
+// Forward declaration
+struct func_debug_t;
+
+int func_model_init_config();
+int func_model_set_default_config(func_config_t*);
+int func_model_config_set_option(func_config_t*, const char* name, const char* value);
+int func_model_print_config(func_config_t*, FILE* out);
+int func_model_parse_cmd_line(func_config_t*, func_debug_t* func_debug, const int argc, const char** argv);
+int func_model_parse_flat_config_file(func_config_t*, const char* filename);
+int func_model_config_cleanup();
+int func_model_config_get_str_option_by_name(func_config_t*, const char* name, char* value, const uint32_t len);
+int func_model_config_get_option_by_name(func_config_t*, const char* name, uint64_t* val);
+int func_model_print_help(FILE* out);
+
+#endif
diff --git a/reference_model/src/func_debug.cc b/reference_model/src/func_debug.cc
new file mode 100644
index 0000000..f5f045e
--- /dev/null
+++ b/reference_model/src/func_debug.cc
@@ -0,0 +1,436 @@
+
+// Copyright (c) 2020, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <ctype.h>
+#include <signal.h>
+#include <stdarg.h>
+#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
+#include <sys/types.h>
+
+#ifndef _MSC_VER
+#include <execinfo.h>
+#include <sys/prctl.h>
+#include <sys/ptrace.h>
+#include <sys/wait.h>
+#include <unistd.h>
+#endif
+
+#include "func_debug.h"
+
+#define MAX_FRAMES 100
+
+#ifndef _MSC_VER
+pid_t func_print_backtrace_helper(int num_tries, int sig);
+#endif
+
+void func_print_backtrace(FILE* out, int sig)
+{
+#ifndef _MSC_VER
+ for (int i = 0; i < 2; i++)
+ {
+ const pid_t child_pid = func_print_backtrace_helper(i, sig);
+ if (child_pid < 0)
+ {
+ perror("Backtrace generation failed on fork");
+ break;
+ }
+
+ int status = 0;
+ waitpid(child_pid, &status, 0);
+ if (WEXITSTATUS(status) == 0)
+ {
+ break;
+ }
+ }
+#endif
+}
+
+#ifndef _MSC_VER
+pid_t func_print_backtrace_helper(int num_tries, int sig)
+{
+ const pid_t child_pid = fork();
+
+    if (child_pid)
+    {
+        // Parent process (or fork failure): hand the child pid (or -1) back to the caller
+        return child_pid;
+    }
+
+ const pid_t ppid = getppid();
+
+ printf("Attaching debugger to pid %d\n", ppid);
+ // Check if we're in a debugger
+ if (ptrace(PTRACE_ATTACH, ppid, 0, 0) == 0)
+ {
+ // If we reach this point, no debugger is present
+ // Undo effects of PTRACE_ATTACH
+ waitpid(ppid, NULL, 0);
+ ptrace(PTRACE_CONT, 0, 0, 0);
+ ptrace(PTRACE_DETACH, ppid, 0, 0);
+
+ dup2(STDERR_FILENO, STDOUT_FILENO);
+
+ char parent_pid[20];
+ snprintf(parent_pid, sizeof(parent_pid), "attach %d", ppid);
+ fprintf(stdout, "Caught signal %d (%s)\n", sig, strsignal(sig));
+
+ execlp("gdb", "gdb", "--batch", "-n", "-ex",
+ // Don't print startup messages for each thread
+ "-ex", "set print thread-events off", "-ex", parent_pid,
+ // Turn off pagination
+ "-ex", "set height 0",
+ // Print a backtrace for the current thread
+ "-ex", "thread $_thread", "-ex", "bt",
+ // Print a backtrace for the main thread (uncomment the next two lines, if desired)
+ //"-ex", "thread 1",
+ //"-ex", "bt",
+                 // Print a backtrace for all threads (TMI)
+ //"-ex", "thread apply all bt",
+ NULL);
+
+ // If we reach this point, it is bad. Attempt to print an error before exiting.
+ perror("Backtrace generation failed to invoke gdb");
+ exit(1);
+ }
+
+ // Debugger present. Exit here.
+ exit(0);
+
+ return 0;
+}
+#endif
+
+void func_backtrace_signal_handler(int sig)
+{
+ func_print_backtrace(NULL, sig);
+ exit(1);
+}
+
+// Note: this overwrites other signal handlers. May want to make this
+// more friendly sometime
+void func_enable_signal_handlers()
+{
+ static const int sig_list[] = { SIGABRT, SIGSEGV, SIGILL, SIGFPE };
+
+ if (getenv("FUNC_NO_SIG_HANDLERS"))
+ {
+ return;
+ }
+
+ for (size_t i = 0; i < sizeof(sig_list) / sizeof(int); i++)
+ {
+ struct sigaction act;
+
+ bzero(&act, sizeof(act));
+ act.sa_handler = func_backtrace_signal_handler;
+
+ if (sigaction(sig_list[i], &act, NULL))
+ {
+ perror("Error calling sigaction");
+ }
+ }
+}
+
+const char* func_debug_mode_str_table[] = {
+#define DEBUG_MODE(NAME, BIT) #NAME,
+#include "debug_modes.def"
+#undef DEBUG_MODE
+};
+
+#define DEBUG_MASK_COUNT (sizeof(func_debug_mode_str_table) / sizeof(const char*))
+
+const char* func_debug_verbosity_str_table[] = { "NONE", "INFO", "IFACE", "LOW", "MED", "HIGH" };
+
+const uint32_t func_debug_verbosity_mask_table[] = { DEBUG_VERB_NONE, DEBUG_VERB_INFO, DEBUG_VERB_IFACE,
+ DEBUG_VERB_LOW, DEBUG_VERB_MED, DEBUG_VERB_HIGH };
+
+#define DEBUG_VERBOSITY_COUNT (sizeof(func_debug_verbosity_str_table) / sizeof(const char*))
+
+// Initialize the debug mode
+int func_init_debug(func_debug_t* func_debug, uint64_t inst_id)
+{
+ // Set the default debug settings
+ bzero(func_debug, sizeof(func_debug_t));
+ func_debug_set_mask(func_debug, DEBUG_NONE);
+ func_debug_set_verbosity(func_debug, DEBUG_VERB_NONE);
+ func_debug_set_inst_mask(func_debug, DEBUG_INST_ALL);
+ func_debug->func_debug_file = stderr;
+ func_debug_set_captured_warnings(func_debug, 0);
+ func_debug_set_output_unbuffered(func_debug, false);
+ func_debug->inst_id = inst_id;
+
+ return 0;
+}
+
+int func_fini_debug(func_debug_t* func_debug)
+{
+ if (func_debug->record_warnings)
+ {
+ func_debug_set_captured_warnings(func_debug, 0);
+ }
+
+#ifndef _FUNC_INCLUDE_WINDOWS_SUPPORT_H
+ if (func_debug->is_gzip && func_debug->func_debug_file)
+ {
+ pclose(func_debug->func_debug_file);
+ func_debug->func_debug_file = NULL;
+ }
+#endif
+
+ return 0;
+}
+
+int func_debug_set_file(func_debug_t* func_debug, const char* filename)
+{
+    ASSERT(filename != NULL);
+
+    int filenameLen = strlen(filename);
+
+    // Open the debug output file
+#ifndef _FUNC_INCLUDE_WINDOWS_SUPPORT_H
+ if (filenameLen > 3 && strcmp(filename + filenameLen - 3, ".gz") == 0)
+ {
+ char cmd[256];
+
+ snprintf(cmd, sizeof(cmd), "gzip > %s", filename);
+ func_debug->func_debug_file = popen(cmd, "w");
+ func_debug->is_gzip = 1;
+ }
+ else
+ {
+#else
+ {
+#endif
+ func_debug->func_debug_file = fopen(filename, "w");
+ }
+
+ if (!func_debug->func_debug_file)
+ {
+ perror(NULL);
+ FATAL_ERROR("Cannot open debug output file: %s\n", filename);
+ return 1;
+ }
+ if (func_debug->is_output_unbuffered)
+ {
+ setvbuf(func_debug->func_debug_file, nullptr, _IONBF, 0);
+ }
+
+ return 0;
+}
+
+void func_debug_set_verbosity(func_debug_t* func_debug, const char* str)
+{
+ if (!strcasecmp(str, "RESET"))
+ {
+ func_debug_set_verbosity(func_debug, DEBUG_VERB_NONE);
+ return;
+ }
+
+ for (size_t i = 0; i < DEBUG_VERBOSITY_COUNT; i++)
+ {
+ if (!strcasecmp(str, func_debug_verbosity_str_table[i]))
+ {
+ func_debug_set_verbosity(func_debug, func_debug_verbosity_mask_table[i]);
+ return;
+ }
+ }
+
+ FATAL_ERROR("Invalid debug verbosity: %s", str);
+}
+
+void func_debug_set_verbosity(func_debug_t* func_debug, const uint32_t verb)
+{
+ uint32_t new_mask = verb;
+
+ switch (verb)
+ {
+ case DEBUG_VERB_NONE:
+ new_mask = DEBUG_VERB_NONE;
+ break;
+ case DEBUG_VERB_INFO:
+ new_mask = DEBUG_VERB_INFO;
+ break;
+ case DEBUG_VERB_IFACE:
+ new_mask = DEBUG_VERB_IFACE;
+ break;
+ case DEBUG_VERB_HIGH:
+ new_mask |= DEBUG_VERB_HIGH;
+ // Intentional fallthrough
+ case DEBUG_VERB_MED:
+ new_mask |= DEBUG_VERB_MED;
+ // Intentional fallthrough
+ case DEBUG_VERB_LOW:
+ new_mask |= DEBUG_VERB_LOW;
+ new_mask |= DEBUG_VERB_INFO;
+ new_mask |= DEBUG_VERB_IFACE;
+ break;
+ }
+
+ func_debug->func_debug_verbosity = new_mask;
+}
+
+void func_debug_set_suppress_arch_error_mask(func_debug_t* func_debug, const uint32_t suppress)
+{
+ func_debug->func_suppress_arch_error_mask = suppress;
+}
+
+void func_debug_set_mask(func_debug_t* func_debug, const uint64_t mask)
+{
+ if (mask == DEBUG_NONE)
+ func_debug->func_debug_mask = mask;
+ else
+ func_debug->func_debug_mask |= mask;
+
+ // Set a minimum verbosity level
+ if (func_debug->func_debug_verbosity == DEBUG_VERB_NONE)
+ func_debug->func_debug_verbosity = DEBUG_VERB_INFO;
+}
+
+void func_debug_set_inst_mask(func_debug_t* func_debug, const char* mask)
+{
+ uint64_t val;
+
+    val = strtoull(mask, NULL, 0);
+
+ return func_debug_set_inst_mask(func_debug, val);
+}
+
+void func_debug_set_inst_mask(func_debug_t* func_debug, const uint64_t mask)
+{
+ if (mask == 0)
+ func_debug->func_debug_inst_mask = DEBUG_INST_ALL;
+ else
+ func_debug->func_debug_inst_mask = mask;
+}
+
+void func_debug_set_mask(func_debug_t* func_debug, const char* str)
+{
+ if (!strcasecmp(str, "all"))
+ {
+ func_debug_set_mask(func_debug, UINT64_MAX - 1);
+ return;
+ }
+
+ size_t i;
+ for (i = 0; i < DEBUG_MASK_COUNT; i++)
+ {
+ if (!strcasecmp(str, func_debug_mode_str_table[i]))
+ {
+ func_debug_set_mask(func_debug, 1ULL << i);
+ return;
+ }
+ }
+
+ func_debug_print_masks(stderr);
+
+ FATAL_ERROR("Invalid debug mask: %s", str);
+}
+
+void func_debug_print_masks(FILE* out)
+{
+ uint32_t i;
+
+ fprintf(out, "Available debug masks:\n");
+
+ for (i = 0; i < DEBUG_MASK_COUNT; i++)
+ {
+ fprintf(out, "[%d] %s\n", i, func_debug_mode_str_table[i]);
+ }
+}
+
+void func_debug_set_output_unbuffered(func_debug_t* func_debug, const bool is_unbuffered)
+{
+ func_debug->is_output_unbuffered = is_unbuffered;
+}
+
+// Print warnings to the debug file or optionally store them in a buffer instead
+// Note that the buffer is circular and can be overwritten if enough messages are
+// written before removing a warning from the front.
+void func_debug_warning(
+ func_debug_t* func_debug, const char* file, const char* func, const int line, const char* fmt, ...)
+{
+ va_list args;
+ va_start(args, fmt);
+
+ if (func_debug->record_warnings)
+ {
+ // Record to the circular buffer
+ uint32_t len;
+
+ len = snprintf(func_debug->warning_buffer[func_debug->warning_buffer_tail], WARNING_BUFFER_ENTRY_LENGTH,
+ "WARNING AT %s:%d %s(): ", file, line, func);
+ vsnprintf(func_debug->warning_buffer[func_debug->warning_buffer_tail] + len, WARNING_BUFFER_ENTRY_LENGTH - len,
+ fmt, args);
+ func_debug->warning_buffer_tail = (func_debug->warning_buffer_tail + 1) % WARNING_BUFFER_SIZE;
+ }
+ else
+ {
+ // Print to the debug file (e.g., stderr)
+ fprintf(func_debug->func_debug_file, "WARNING AT %s:%d %s():\n", file, line, func);
+ vfprintf(func_debug->func_debug_file, fmt, args);
+ fprintf(func_debug->func_debug_file, "\n");
+ }
+ va_end(args);
+}
+
+// Initialize the warning buffer capture
+int func_debug_set_captured_warnings(func_debug_t* func_debug, uint32_t capture)
+{
+ uint32_t i;
+ func_debug->record_warnings = capture;
+ if (capture)
+ {
+ func_debug->warning_buffer_head = 0;
+ func_debug->warning_buffer_tail = 0;
+
+ for (i = 0; i < WARNING_BUFFER_SIZE; i++)
+ {
+ func_debug->warning_buffer[i] = (char*)calloc(1, WARNING_BUFFER_ENTRY_LENGTH);
+ }
+ }
+ else
+ {
+ for (i = 0; i < WARNING_BUFFER_SIZE; i++)
+ {
+ if (func_debug->warning_buffer[i])
+ {
+ free(func_debug->warning_buffer[i]);
+ func_debug->warning_buffer[i] = NULL;
+ }
+ }
+ }
+
+ return 0;
+}
+
+int func_debug_has_captured_warning(func_debug_t* func_debug)
+{
+ if (func_debug->record_warnings && func_debug->warning_buffer_head != func_debug->warning_buffer_tail)
+ return 1;
+ else
+ return 0;
+}
+
+int func_debug_get_captured_warning(func_debug_t* func_debug, char* buf_ptr, const uint32_t buf_len)
+{
+ if (!func_debug_has_captured_warning(func_debug))
+ return 1;
+
+ strncpy(buf_ptr, func_debug->warning_buffer[func_debug->warning_buffer_head], buf_len);
+
+ func_debug->warning_buffer_head = (func_debug->warning_buffer_head + 1) % WARNING_BUFFER_SIZE;
+
+ return 0;
+}
diff --git a/reference_model/src/func_debug.h b/reference_model/src/func_debug.h
new file mode 100644
index 0000000..2d47462
--- /dev/null
+++ b/reference_model/src/func_debug.h
@@ -0,0 +1,255 @@
+
+// Copyright (c) 2020, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef FUNC_DEBUG_H
+#define FUNC_DEBUG_H
+
+#include "debug_types.h"
+#include <assert.h>
+#include <cinttypes>
+#include <signal.h>
+#include <stdio.h>
+
+void func_print_backtrace(FILE* out, int sig = SIGABRT);
+
+void func_enable_signal_handlers();
+
+// Debug content container
+#define WARNING_BUFFER_SIZE 16
+#define WARNING_BUFFER_ENTRY_LENGTH 1024
+
+// STRINGIFY2 is needed to expand the expression passed to STRINGIFY
+#define STRINGIFY2(s) #s
+#define STRINGIFY(s) STRINGIFY2(s)
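+// e.g. STRINGIFY(__LINE__) yields the current line number as a string literal, so the
+// WHERE prefix below becomes something like "@func_debug.h:148" (value is illustrative).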
+
+// If TRACED_LOG is defined, add file:line to log messages
+#if defined(TRACED_LOG)
+#define WHERE "@" __FILE__ ":" STRINGIFY(__LINE__)
+#else
+#define WHERE
+#endif
+
+#if defined(COLORIZED_LOG)
+#define COL(col, fmt) "\x1b[3" col "m" fmt "\x1b[0m"
+#define COL_FATAL(fmt) COL("1;41", fmt)
+#define COL_WARN(fmt) COL("1;43", fmt)
+#define COL_INFO(fmt) COL("2", fmt)
+#define COL_IFACE(fmt) fmt
+#define COL_LOW(fmt) COL("35", fmt)
+#define COL_MED(fmt) COL("2;33", fmt)
+#define COL_HIGH(fmt) COL("2;32", fmt)
+#else
+#define COL_FATAL(fmt) fmt
+#define COL_WARN(fmt) fmt
+#define COL_INFO(fmt) fmt
+#define COL_IFACE(fmt) fmt
+#define COL_LOW(fmt) fmt
+#define COL_MED(fmt) fmt
+#define COL_HIGH(fmt) fmt
+#endif
+
+struct func_debug_t
+{
+ uint32_t func_debug_verbosity; // What verbosity level is set? (bitmask)
+ uint64_t func_debug_mask; // Which units have debugging enabled? (bitmask)
+ uint64_t func_debug_inst_mask; // Which instances have debugging enabled (bitmask)
+ uint64_t inst_id; // The instance id for multiple model instances
+ uint32_t func_suppress_arch_error_mask; // Which architecture error should be suppressed? (bitmask)
+ FILE* func_debug_file; // Output file
+ uint32_t record_warnings;
+ char* warning_buffer[WARNING_BUFFER_SIZE];
+ uint32_t warning_buffer_head; // next unread message
+ uint32_t warning_buffer_tail; // next message to write
+ uint32_t is_gzip;
+ bool is_output_unbuffered; // should log files be opened with unbuffered I/O.
+};
+
+#ifndef ASSERT
+#define ASSERT(COND) \
+ if (!(COND)) \
+ { \
+ fprintf(stderr, COL_FATAL("ASSERTION AT %s:%d %s(): (%s)\n"), __FILE__, __LINE__, __func__, #COND); \
+ func_print_backtrace(stderr); \
+ assert(COND); \
+ }
+#endif
+
+#ifndef ASSERT_MSG
+#define ASSERT_MSG(COND, fmt, ...) \
+ if (!(COND)) \
+ { \
+ fprintf(stderr, COL_FATAL("ASSERTION AT %s:%d %s(): (%s)\n"), __FILE__, __LINE__, __func__, #COND); \
+ fprintf(stderr, COL_FATAL(fmt) "\n", ##__VA_ARGS__); \
+ func_print_backtrace(stderr); \
+ assert(COND); \
+ }
+#endif
+
+#ifndef ASSERT_MSG_NODE
+#define ASSERT_MSG_NODE(COND, fmt, ...) \
+ if (!(COND)) \
+ { \
+ fprintf(g_func_debug.func_debug_file, COL_FATAL("ASSERTION AT %s:%d %s(): (%s)\n"), __FILE__, __LINE__, \
+ __func__, #COND); \
+ fprintf(g_func_debug.func_debug_file, COL_FATAL(fmt) "\n", ##__VA_ARGS__); \
+ this->dumpNode(g_func_debug.func_debug_file); \
+ func_print_backtrace(g_func_debug.func_debug_file); \
+ assert(COND); \
+ }
+#endif
+
+// Assertion specific to allocating memory
+#ifndef ASSERT_MEM
+#define ASSERT_MEM(OBJ) \
+ if (!(OBJ)) \
+ { \
+ fprintf(stderr, COL_FATAL("ASSERTION AT %s:%d %s(): (" #OBJ "): out of memory\n"), __FILE__, __LINE__, \
+ __func__); \
+ func_print_backtrace(stderr); \
+ assert(OBJ); \
+ }
+#endif
+
+#ifndef FATAL_ERROR
+#define FATAL_ERROR(fmt, ...) \
+ fprintf(stderr, COL_FATAL("FATAL ERROR AT %s:%d %s():\n"), __FILE__, __LINE__, __func__); \
+ fprintf(stderr, COL_FATAL(fmt) "\n", ##__VA_ARGS__); \
+ func_print_backtrace(stderr); \
+ abort();
+#endif
+
+#ifndef FATAL_ERROR_NODE
+#define FATAL_ERROR_NODE(fmt, ...) \
+ fprintf(g_func_debug.func_debug_file, COL_FATAL("FATAL ERROR AT %s:%d %s():\n"), __FILE__, __LINE__, __func__); \
+ fprintf(g_func_debug.func_debug_file, COL_FATAL(fmt) "\n", ##__VA_ARGS__); \
+ this->dumpNode(g_func_debug.func_debug_file); \
+ func_print_backtrace(g_func_debug.func_debug_file); \
+ abort();
+#endif
+#ifndef SIMPLE_FATAL_ERROR
+#define SIMPLE_FATAL_ERROR(fmt, ...) \
+ fprintf(stderr, COL_FATAL(fmt) "\n", ##__VA_ARGS__); \
+ exit(1);
+#endif
+
+void func_debug_warning(
+ func_debug_t* func_debug, const char* file, const char* func, const int line, const char* fmt, ...);
+#ifndef WARNING
+#define WARNING(...) func_debug_warning(&g_func_debug, __FILE__, __func__, __LINE__, __VA_ARGS__)
+#endif
+
+#ifndef WARNING_STDERR
+#define WARNING_STDERR(fmt, ...) \
+ fprintf(stderr, COL_WARN("WARNING AT %s:%d %s():\n"), __FILE__, __LINE__, __func__); \
+ fprintf(stderr, COL_WARN(fmt) "\n", ##__VA_ARGS__);
+#endif
+
+int func_debug_set_captured_warnings(func_debug_t* func_debug, uint32_t capture);
+
+int func_debug_has_captured_warning(func_debug_t* func_debug);
+
+int func_debug_get_captured_warning(func_debug_t* func_debug, char* buf_ptr, const uint32_t buf_len);
+
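+// A minimal capture sketch (based on the implementation in func_debug.cc; the 'msg'
+// buffer name is illustrative):
+//   func_debug_set_captured_warnings(&g_func_debug, 1);   // start capturing, allocate buffers
+//   ... code that may call WARNING(...) ...
+//   char msg[WARNING_BUFFER_ENTRY_LENGTH];
+//   while (!func_debug_get_captured_warning(&g_func_debug, msg, sizeof(msg)))
+//       fprintf(stderr, "%s\n", msg);
+//   func_debug_set_captured_warnings(&g_func_debug, 0);   // stop capturing, free buffers
+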
+// Is this debug verbosity and unit level enabled?
+// Provide compiler hints that this is unlikely
+// Two versions, depending on whether DEBUG_INSTANCE_EXPR is defined in a file or not
+//
+// For .cpp files whose units have discrete instance IDs, define DEBUG_INSTANCE_EXPR to evaluate
+// to the instance ID variable. The use of this define in header files is discouraged.
+
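+// A minimal sketch (variable and message names are hypothetical), in a .cpp file:
+//   #define DEBUG_INSTANCE_EXPR instance_id   // variable holding this unit's instance ID
+//   #include "func_debug.h"
+//   ...
+//   DEBUG_INFO(GT, "evaluating %s", node_name);
+// The define must appear before this header is first included in that translation unit.
+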
+#ifdef DEBUG_INSTANCE_EXPR
+// Expression for whether the debugging verbosity + debugging unit is enabled for free-form printouts
+#ifdef DEBUG_INSTANCE_EXPR_2
+#define DEBUG_ENABLED(VERB, LEVEL) \
+ (__builtin_expect((g_func_debug.func_debug_mask == DEBUG_ALL || g_func_debug.func_debug_mask & (DEBUG_##LEVEL)) && \
+ (g_func_debug.func_debug_inst_mask & (uint64_t(1) << (DEBUG_INSTANCE_EXPR))) && \
+ (g_func_debug.func_debug_verbosity & (VERB)), \
+ 0))
+// Debug printing macro
+#define DEBUG(VERB, LEVEL, FMT, ...) \
+ if (DEBUG_ENABLED(VERB, LEVEL)) \
+ { \
+ fprintf(g_func_debug.func_debug_file, "[%d:" #LEVEL "_%02d_%02d" WHERE "]: " FMT "\n", \
+ (int)g_func_debug.inst_id, (int)(DEBUG_INSTANCE_EXPR), (int)(DEBUG_INSTANCE_EXPR_2), ##__VA_ARGS__); \
+ }
+
+// Prints just the debugging prefix for properly marking free-form printouts
+#define DEBUG_PREFIX(LEVEL) \
+ fprintf(g_func_debug.func_debug_file, "[%d" #LEVEL "_%02d_%02d" WHERE "]: ", (int)g_func_debug.inst_id, \
+ (int)(DEBUG_INSTANCE_EXPR), (int)(DEBUG_INSTANCE_EXPR_2))
+
+#else // !DEBUG_INSTANCE_EXPR_2
+
+#define DEBUG_ENABLED(VERB, LEVEL) \
+ (__builtin_expect((g_func_debug.func_debug_mask == DEBUG_ALL || g_func_debug.func_debug_mask & (DEBUG_##LEVEL)) && \
+ (g_func_debug.func_debug_inst_mask & (uint64_t(1) << (DEBUG_INSTANCE_EXPR))) && \
+ (g_func_debug.func_debug_verbosity & (VERB)), \
+ 0))
+// Debug printing macro
+#define DEBUG(VERB, LEVEL, FMT, ...) \
+ if (DEBUG_ENABLED(VERB, LEVEL)) \
+ { \
+ fprintf(g_func_debug.func_debug_file, "[%d:" #LEVEL "_%02d" WHERE "]: " FMT "\n", (int)g_func_debug.inst_id, \
+ (int)(DEBUG_INSTANCE_EXPR), ##__VA_ARGS__); \
+ }
+
+// Prints just the debugging prefix for properly marking free-form printouts
+#define DEBUG_PREFIX(LEVEL) \
+ fprintf(g_func_debug.func_debug_file, "[%d:" #LEVEL "_%02d" WHERE "]: ", (int)g_func_debug.inst_id, \
+ (int)(DEBUG_INSTANCE_EXPR))
+
+#endif // DEBUG_INSTANCE_EXPR_2
+
+#else // !DEBUG_INSTANCE_EXPR
+
+// Expression for whether the debugging verbosity + debugging unit is enabled for free-form printouts
+#define DEBUG_ENABLED(VERB, LEVEL) \
+ (__builtin_expect((g_func_debug.func_debug_mask == DEBUG_ALL || g_func_debug.func_debug_mask & (DEBUG_##LEVEL)) && \
+ (g_func_debug.func_debug_verbosity & (VERB)), \
+ 0))
+// Debug printing macro
+#define DEBUG(VERB, LEVEL, FMT, ...) \
+ if (DEBUG_ENABLED(VERB, LEVEL)) \
+ { \
+ fprintf(g_func_debug.func_debug_file, "[%d:" #LEVEL WHERE "]: " FMT "\n", (int)g_func_debug.inst_id, \
+ ##__VA_ARGS__); \
+ }
+
+// Prints just the debugging prefix for properly marking free-form printouts
+#define DEBUG_PREFIX(LEVEL) fprintf(g_func_debug.func_debug_file, "[" #LEVEL WHERE "]: ")
+
+#endif
+
+// Macros for different verbosity levels
+#define DEBUG_INFO(LEVEL, FMT, ...) DEBUG(DEBUG_VERB_INFO, LEVEL, COL_INFO(FMT), ##__VA_ARGS__)
+#define DEBUG_IFACE(LEVEL, FMT, ...) DEBUG(DEBUG_VERB_IFACE, LEVEL, COL_IFACE(FMT), ##__VA_ARGS__)
+#define DEBUG_LOW(LEVEL, FMT, ...) DEBUG(DEBUG_VERB_LOW, LEVEL, COL_LOW(FMT), ##__VA_ARGS__)
+#define DEBUG_MED(LEVEL, FMT, ...) DEBUG(DEBUG_VERB_MED, LEVEL, COL_MED(FMT), ##__VA_ARGS__)
+#define DEBUG_HIGH(LEVEL, FMT, ...) DEBUG(DEBUG_VERB_HIGH, LEVEL, COL_HIGH(FMT), ##__VA_ARGS__)
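+
+// Typical call site (mirrors the usage in main.cpp; GT is one of the debug units
+// generated from debug_modes.def):
+//   DEBUG_MED(GT, "Loading input tensor %s from filename: %s", name, filename);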
+
+int func_init_debug(func_debug_t*, uint64_t inst_id);
+int func_fini_debug(func_debug_t*);
+int func_debug_set_file(func_debug_t*, const char* filename);
+void func_debug_set_mask(func_debug_t*, const char* str);
+void func_debug_set_mask(func_debug_t*, const uint64_t mask);
+void func_debug_print_masks(FILE* out);
+void func_debug_set_verbosity(func_debug_t*, const char* str);
+void func_debug_set_verbosity(func_debug_t*, const uint32_t verb);
+void func_debug_set_suppress_arch_error_mask(func_debug_t*, const uint32_t suppress);
+void func_debug_set_inst_mask(func_debug_t*, const char* mask);
+void func_debug_set_inst_mask(func_debug_t*, const uint64_t mask);
+void func_debug_set_output_unbuffered(func_debug_t*, const bool is_unbuffered);
+
+#endif
diff --git a/reference_model/src/graph_node.cc b/reference_model/src/graph_node.cc
new file mode 100644
index 0000000..b57b9dd
--- /dev/null
+++ b/reference_model/src/graph_node.cc
@@ -0,0 +1,226 @@
+
+// Copyright (c) 2020, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "graph_node.h"
+
+using namespace TosaReference;
+using namespace Eigen;
+using namespace tosa;
+
+GraphNode::GraphNode(const Op& nodeType_, const uint64_t id_)
+{
+ nodeType = nodeType_;
+ nodeId = id_;
+ inputs.clear();
+ outputs.clear();
+ inputNames.clear();
+ outputNames.clear();
+ clearNodeMarked();
+ evalCount = 0;
+ clearOnNextNodeList();
+ setRequiredOperands(-1, -1);
+ setRequiredRank(-1);
+}
+
+GraphNode::~GraphNode()
+{}
+
+int GraphNode::addInputName(std::string& name)
+{
+ inputNames.push_back(name);
+ return 0;
+}
+
+int GraphNode::addOutputName(std::string& name)
+{
+ outputNames.push_back(name);
+ return 0;
+}
+
+int GraphNode::addInputTensor(Tensor* tens)
+{
+ ASSERT_MSG(tens, "GraphNode::addInputTensor: no tensor provided");
+ inputs.push_back(tens);
+ return 0;
+}
+
+int GraphNode::addOutputTensor(Tensor* tens)
+{
+ ASSERT_MSG(tens, "GraphNode::addOutputTensor: no tensor provided");
+ outputs.push_back(tens);
+ return 0;
+}
+
+int GraphNode::checkTensorAttributes()
+{
+ // Placeholder
+ return 0;
+}
+
+int GraphNode::eval()
+{
+ // Placeholder evaluation function
+ evalCount++;
+
+ // this should be set by derived op
+ for (auto ct : getOutputs())
+ {
+ ct->setIsValid();
+ }
+
+ return 0;
+}
+
+int GraphNode::hasAllInputsReady() const
+{
+ for (size_t i = 0; i < inputs.size(); i++)
+ {
+ if (!inputs[i]->getIsValid())
+ return false;
+ }
+
+ return true;
+}
+
+int GraphNode::hasAllOutputsReady() const
+{
+ for (size_t i = 0; i < outputs.size(); i++)
+ {
+ if (!outputs[i]->getIsValid())
+ return false;
+ }
+
+ return true;
+}
+
+int GraphNode::dumpNode(FILE* out)
+{
+ int i;
+ fprintf(out, "Node type: %s ID: %lu Eval Count: %d On next node list: %d Is marked: %d\n", EnumNamesOp()[nodeType],
+ nodeId, evalCount, onNextNodeList, isMarked);
+
+ i = 0;
+ for (Tensor* ins : inputs)
+ {
+ fprintf(out, " Input[%d] ", i++);
+ ins->dumpTensorParams(out);
+ }
+
+ i = 0;
+ for (Tensor* outs : outputs)
+ {
+ fprintf(out, " Output[%d] ", i++);
+ outs->dumpTensorParams(out);
+ }
+
+ return 0;
+}
+
+int GraphNode::dumpNode(std::ostream& out)
+{
+ int i;
+
+ out << "Node type: " << EnumNamesOp()[nodeType] << " ID: " << nodeId << " Eval count: " << evalCount
+ << " On next node list: " << onNextNodeList << " Is marked: " << isMarked << std::endl;
+
+ out << " Inputs:";
+ for (std::string& name : inputNames)
+ {
+ out << " " << name;
+ }
+ out << std::endl;
+
+ i = 0;
+ for (Tensor* ins : inputs)
+ {
+ out << " Input[" << i++ << "]: ";
+ ins->dumpTensorParams(out);
+ }
+
+ out << " Outputs:";
+ for (std::string& name : outputNames)
+ {
+ out << " " << name;
+ }
+ out << std::endl;
+
+ i = 0;
+ for (Tensor* outs : outputs)
+ {
+ out << " Output[" << i++ << "]: ";
+ outs->dumpTensorParams(out);
+ }
+ return 0;
+}
+
+int GraphNode::printNodeValidationError(const std::string& msg)
+{
+    std::cout << "Operator validation error: " << msg << std::endl;
+ dumpNode(std::cout);
+
+ return 0;
+}
+
+int GraphNode::validateRequiredOperands()
+{
+ if (requiredInputCount >= 0 && inputs.size() != (size_t)requiredInputCount)
+ {
+ printNodeValidationError(std::string(EnumNamesOp()[nodeType]) + " operator must have " +
+ std::to_string(requiredInputCount) + " input(s)");
+ return 1;
+ }
+
+ if (requiredOutputCount >= 0 && outputs.size() != (size_t)requiredOutputCount)
+ {
+ printNodeValidationError(std::string(EnumNamesOp()[nodeType]) + " operator output must have exactly " +
+ std::to_string(requiredOutputCount) + " output(s)");
+ return 1;
+ }
+
+ return 0;
+}
+
+int GraphNode::validateRequiredRank(const Tensor* t)
+{
+ if (requiredRankMin >= 0 && requiredRankMax >= 0)
+ {
+ if (t->checkRequiredRank(requiredRankMin, requiredRankMax))
+ {
+ printNodeValidationError(std::string(EnumNamesOp()[nodeType]) +
+ " operand has illegal rank=" + std::to_string(t->getRank()) + " not in range [" +
+ std::to_string(requiredRankMin) + "," + std::to_string(requiredRankMax) +
+ "]. tensorName: " + t->getName());
+ return 1;
+ }
+ else
+ {
+ return 0;
+ }
+ }
+
+ if (requiredRankMin >= 0)
+ {
+ if (t->checkRequiredRank(requiredRankMin))
+ {
+ printNodeValidationError(std::string(EnumNamesOp()[nodeType]) +
+ " operand has illegal rank=" + std::to_string(t->getRank()) + " not equal to " +
+ std::to_string(requiredRankMin) + ". tensorName: " + t->getName());
+ return 1;
+ }
+ }
+
+ return 0;
+}
diff --git a/reference_model/src/graph_node.h b/reference_model/src/graph_node.h
new file mode 100644
index 0000000..5b4a767
--- /dev/null
+++ b/reference_model/src/graph_node.h
@@ -0,0 +1,354 @@
+
+// Copyright (c) 2020, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef GRAPH_NODE_H
+#define GRAPH_NODE_H
+
+#include "attribute.h"
+#include "quant_info.h"
+#include "tensor.h"
+#include "tosa_generated.h"
+#include <iostream>
+
+#define DEF_INSTANTIATE_ONE_RANK_ONE_TYPE(OP, RANK, DTYPE) template class TosaReference::OP<RANK, DType_##DTYPE>;
+
+#define DEF_INSTANTIATE_ONE_RANK_TWO_TYPE(OP, RANK, DTYPE1, DTYPE2) \
+ template class TosaReference::OP<RANK, DType_##DTYPE1, DType_##DTYPE2>;
+
+#define DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, RANK1, RANK2, DTYPE) \
+ template class TosaReference::OP<RANK1, RANK2, DType_##DTYPE>;
+
+#define DEF_INSTANTIATE_TWO_RANK_TWO_TYPE(OP, RANK1, RANK2, DTYPE1, DTYPE2) \
+ template class TosaReference::OP<RANK1, RANK2, DType_##DTYPE1, DType_##DTYPE2>;
+
+#define DEF_INSTANTIATE_ONE_RANK_0_6(OP) \
+ template class TosaReference::OP<0>; \
+ template class TosaReference::OP<1>; \
+ template class TosaReference::OP<2>; \
+ template class TosaReference::OP<3>; \
+ template class TosaReference::OP<4>; \
+ template class TosaReference::OP<5>; \
+ template class TosaReference::OP<6>;
+
+#define DEF_INSTANTIATE_ONE_TYPE(OP, DTYPE) template class TosaReference::OP<DType_##DTYPE>;
+
+#define DEF_INSTANTIATE_TWO_TYPE(OP, DTYPE1, DTYPE2) template class TosaReference::OP<DType_##DTYPE1, DType_##DTYPE2>;
+
+#define DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OP, DTYPE) \
+ DEF_INSTANTIATE_ONE_RANK_ONE_TYPE(OP, 0, DTYPE) \
+ DEF_INSTANTIATE_ONE_RANK_ONE_TYPE(OP, 1, DTYPE) \
+ DEF_INSTANTIATE_ONE_RANK_ONE_TYPE(OP, 2, DTYPE) \
+ DEF_INSTANTIATE_ONE_RANK_ONE_TYPE(OP, 3, DTYPE) \
+ DEF_INSTANTIATE_ONE_RANK_ONE_TYPE(OP, 4, DTYPE) \
+ DEF_INSTANTIATE_ONE_RANK_ONE_TYPE(OP, 5, DTYPE) \
+ DEF_INSTANTIATE_ONE_RANK_ONE_TYPE(OP, 6, DTYPE)
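+
+// Illustrative expansion: DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, FLOAT)
+// (as used in ops/activation_funcs.cc) emits the explicit instantiations
+//   template class TosaReference::OpClamp<0, DType_FLOAT>;
+//   ...
+//   template class TosaReference::OpClamp<6, DType_FLOAT>;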
+
+#define DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OP, DTYPE) \
+ DEF_INSTANTIATE_ONE_RANK_ONE_TYPE(OP, 1, DTYPE) \
+ DEF_INSTANTIATE_ONE_RANK_ONE_TYPE(OP, 2, DTYPE) \
+ DEF_INSTANTIATE_ONE_RANK_ONE_TYPE(OP, 3, DTYPE) \
+ DEF_INSTANTIATE_ONE_RANK_ONE_TYPE(OP, 4, DTYPE) \
+ DEF_INSTANTIATE_ONE_RANK_ONE_TYPE(OP, 5, DTYPE) \
+ DEF_INSTANTIATE_ONE_RANK_ONE_TYPE(OP, 6, DTYPE)
+
+#define DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OP, DTYPE1, DTYPE2) \
+ DEF_INSTANTIATE_ONE_RANK_TWO_TYPE(OP, 0, DTYPE1, DTYPE2) \
+ DEF_INSTANTIATE_ONE_RANK_TWO_TYPE(OP, 1, DTYPE1, DTYPE2) \
+ DEF_INSTANTIATE_ONE_RANK_TWO_TYPE(OP, 2, DTYPE1, DTYPE2) \
+ DEF_INSTANTIATE_ONE_RANK_TWO_TYPE(OP, 3, DTYPE1, DTYPE2) \
+ DEF_INSTANTIATE_ONE_RANK_TWO_TYPE(OP, 4, DTYPE1, DTYPE2) \
+ DEF_INSTANTIATE_ONE_RANK_TWO_TYPE(OP, 5, DTYPE1, DTYPE2) \
+ DEF_INSTANTIATE_ONE_RANK_TWO_TYPE(OP, 6, DTYPE1, DTYPE2)
+
+#define DEF_INSTANTIATE_RESHAPE(OP, DTYPE) \
+ DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 0, 0, DTYPE) \
+ DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 0, 1, DTYPE) \
+ DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 0, 2, DTYPE) \
+ DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 0, 3, DTYPE) \
+ DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 0, 4, DTYPE) \
+ DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 0, 5, DTYPE) \
+ DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 0, 6, DTYPE) \
+ DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 1, 0, DTYPE) \
+ DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 1, 1, DTYPE) \
+ DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 1, 2, DTYPE) \
+ DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 1, 3, DTYPE) \
+ DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 1, 4, DTYPE) \
+ DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 1, 5, DTYPE) \
+ DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 1, 6, DTYPE) \
+ DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 2, 0, DTYPE) \
+ DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 2, 1, DTYPE) \
+ DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 2, 2, DTYPE) \
+ DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 2, 3, DTYPE) \
+ DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 2, 4, DTYPE) \
+ DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 2, 5, DTYPE) \
+ DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 2, 6, DTYPE) \
+ DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 3, 0, DTYPE) \
+ DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 3, 1, DTYPE) \
+ DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 3, 2, DTYPE) \
+ DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 3, 3, DTYPE) \
+ DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 3, 4, DTYPE) \
+ DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 3, 5, DTYPE) \
+ DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 3, 6, DTYPE) \
+ DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 4, 0, DTYPE) \
+ DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 4, 1, DTYPE) \
+ DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 4, 2, DTYPE) \
+ DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 4, 3, DTYPE) \
+ DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 4, 4, DTYPE) \
+ DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 4, 5, DTYPE) \
+ DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 4, 6, DTYPE) \
+ DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 5, 0, DTYPE) \
+ DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 5, 1, DTYPE) \
+ DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 5, 2, DTYPE) \
+ DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 5, 3, DTYPE) \
+ DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 5, 4, DTYPE) \
+ DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 5, 5, DTYPE) \
+ DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 5, 6, DTYPE) \
+ DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 6, 0, DTYPE) \
+ DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 6, 1, DTYPE) \
+ DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 6, 2, DTYPE) \
+ DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 6, 3, DTYPE) \
+ DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 6, 4, DTYPE) \
+ DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 6, 5, DTYPE) \
+ DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 6, 6, DTYPE)
+
+#define DEF_INSTANTIATE_GATHER(OP, DTYPE) \
+ /* gather op takes input and index rank as template argument */ \
+ /* note output rank = input rank - 1 + index rank */ \
+ /* and max rank allowed in tosa_reference is 6 */ \
+ /* so only specific input and index pair is instantiated */ \
+ DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 1, 1, DTYPE) \
+ DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 1, 2, DTYPE) \
+ DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 1, 3, DTYPE) \
+ DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 1, 4, DTYPE) \
+ DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 1, 5, DTYPE) \
+ DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 1, 6, DTYPE) \
+ DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 2, 1, DTYPE) \
+ DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 2, 2, DTYPE) \
+ DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 2, 3, DTYPE) \
+ DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 2, 4, DTYPE) \
+ DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 2, 5, DTYPE) \
+ DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 3, 1, DTYPE) \
+ DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 3, 2, DTYPE) \
+ DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 3, 3, DTYPE) \
+ DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 3, 4, DTYPE) \
+ DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 4, 1, DTYPE) \
+ DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 4, 2, DTYPE) \
+ DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 4, 3, DTYPE) \
+ DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 5, 1, DTYPE) \
+ DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 5, 2, DTYPE) \
+ DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 6, 1, DTYPE)
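+
+// Worked example of the rank constraint above: input rank 3 with index rank 4 gives
+// output rank 3 - 1 + 4 = 6 (the maximum), so the (3, 4) pairing is instantiated,
+// while (3, 5) would need output rank 7 and is therefore omitted.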
+
+#define INIT_ATTRIBUTE(ATTRIBUTE_NAME) \
+ if (auto p = dynamic_cast<Tosa##ATTRIBUTE_NAME##Attribute*>(attribute_)) \
+ { \
+ attribute = new Tosa##ATTRIBUTE_NAME##Attribute(p); \
+ ASSERT_MEM(attribute); \
+ } \
+ else \
+ { \
+ FATAL_ERROR("Can't initialize Tosa" #ATTRIBUTE_NAME "Attribute"); \
+ }
+
+#define INIT_QINFO(QINFO_NAME) \
+ if (auto p = dynamic_cast<Tosa##QINFO_NAME##QuantInfo*>(qinfo_)) \
+ { \
+ qinfo = new Tosa##QINFO_NAME##QuantInfo(p); \
+ ASSERT_MEM(qinfo); \
+ } \
+ else \
+ { \
+ qinfo = nullptr; \
+ }
+
+namespace TosaReference
+{
+
+// Nodes in the graph (e.g., tosa operators) are defined with this base
+// class.
+class GraphNode
+{
+public:
+ GraphNode(const tosa::Op& nodeType, const uint64_t id_);
+ virtual ~GraphNode();
+
+ int addInputName(std::string& name);
+ int addOutputName(std::string& name);
+
+ int addInputTensor(Tensor* tens);
+ int addOutputTensor(Tensor* tens);
+
+ // Validate that the input tensors match properly
+ // in their types, attributes, rank, etc well enough to be
+ // processed.
+ //
+ // This function should be pure virtual (eventually) in order to force
+    // derived operators to implement the check, but we'll initially
+ // provide a default function so that GraphNode can be instantiated
+ // directly for testing purposes.
+ virtual int checkTensorAttributes();
+
+    // Evaluate the node/operator
+ virtual int eval();
+
+ int hasAllInputsReady() const;
+ int hasAllOutputsReady() const;
+
+ int dumpNode(FILE* out);
+ int dumpNode(std::ostream& out);
+
+ int setNodeMarked()
+ {
+ isMarked = true;
+ return 0;
+ }
+
+ int getNodeMarked() const
+ {
+ return isMarked;
+ }
+
+ int clearNodeMarked()
+ {
+ isMarked = false;
+ return 0;
+ }
+
+ int getEvalCount() const
+ {
+ return evalCount;
+ }
+
+ uint64_t getID() const
+ {
+ return nodeId;
+ }
+
+ std::vector<std::string>& getInputNames()
+ {
+ return inputNames;
+ }
+
+ std::vector<std::string>& getOutputNames()
+ {
+ return outputNames;
+ }
+
+ std::vector<Tensor*>& getOutputs()
+ {
+ return outputs;
+ }
+
+ std::vector<Tensor*>& getInputs()
+ {
+ return inputs;
+ }
+
+ int getOnNextNodeList() const
+ {
+ return onNextNodeList;
+ }
+
+ int setOnNextNodeList()
+ {
+ onNextNodeList = true;
+ return 0;
+ }
+
+ int clearOnNextNodeList()
+ {
+ onNextNodeList = false;
+ return 0;
+ }
+
+ tosa::Op getOp() const
+ {
+ return nodeType;
+ }
+
+protected:
+ // Print out a node validation error
+ int printNodeValidationError(const std::string& msg);
+
+ int setRequiredOperands(const int in, const int out)
+ {
+ requiredInputCount = in;
+ requiredOutputCount = out;
+ return 0;
+ }
+
+ int setRequiredRank(const int min, const int max = -1)
+ {
+ if (max == -1)
+ {
+ requiredRankMin = requiredRankMax = min;
+ }
+ else
+ {
+ requiredRankMin = min;
+ requiredRankMax = max;
+ }
+
+ ASSERT_MSG(requiredRankMin <= requiredRankMax,
+ "GraphNode::setRequiredRank: requiredRankMin %d must be <= requiredRankMax %d", requiredRankMin,
+ requiredRankMax);
+
+ return 0;
+ }
+
+ int validateRequiredOperands();
+ int validateRequiredRank(const Tensor* t);
+
+ // Description of the node type (e.g., CONST, CONV2D, etc...)
+ tosa::Op nodeType;
+
+ // A list of input tensor names
+ std::vector<std::string> inputNames;
+
+ // A list of the output tensor names
+ std::vector<std::string> outputNames;
+
+ // A list of the input tensors (after names have been matched up)
+ std::vector<Tensor*> inputs;
+
+ // A list of the output tensors (after names have been matched up)
+ std::vector<Tensor*> outputs;
+
+ // Unique node ID for debugging
+ uint64_t nodeId;
+
+ // Flag used for graph analysis
+ int isMarked;
+
+ // Number of times eval() has been called for this node
+ int evalCount;
+
+ // Flag indicating that this node is ready and is on the
+ // next-node list.
+ int onNextNodeList;
+
+ // Required input/output tensor counts for node validation
+ // -1 means any number is allowed
+ int requiredInputCount;
+ int requiredOutputCount;
+
+ // Required rank ranges for input/output tensors
+ // -1 means n/a
+ int requiredRankMin;
+ int requiredRankMax;
+};
+
+}; // namespace TosaReference
+
+#endif
diff --git a/reference_model/src/main.cpp b/reference_model/src/main.cpp
new file mode 100644
index 0000000..ec2fdc9
--- /dev/null
+++ b/reference_model/src/main.cpp
@@ -0,0 +1,295 @@
+
+// Copyright (c) 2020, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <stdio.h>
+
+#include "flatbuffers/idl.h"
+#include "flatbuffers/util.h"
+#include "model_common.h"
+#include "ops/op_factory.h"
+#include "subgraph_traverser.h"
+#include "tosa_serialization_handler.h"
+#include <Eigen/CXX11/Tensor>
+#include <iostream>
+
+using namespace TosaReference;
+using namespace tosa;
+
+// Global instantiation of configuration and debug objects
+func_config_t g_func_config;
+func_debug_t g_func_debug;
+
+int readInputTensors(SubgraphTraverser& gt);
+int writeFinalTensors(SubgraphTraverser& gt);
+int loadGraph(TosaSerializationHandler& tsh);
+
+int main(int argc, const char** argv)
+{
+ // Initialize configuration and debug subsystems
+ func_model_init_config();
+ func_model_set_default_config(&g_func_config);
+ func_init_debug(&g_func_debug, 0);
+ TosaSerializationHandler tsh;
+
+ if (func_model_parse_cmd_line(&g_func_config, &g_func_debug, argc, argv))
+ {
+ return 1;
+ }
+
+ if (loadGraph(tsh))
+ {
+ SIMPLE_FATAL_ERROR("Unable to load graph");
+ }
+
+    // load json first since it's easier to debug
+ SubgraphTraverser main_gt(tsh.GetMainBlock(), &tsh);
+
+ if (main_gt.initializeGraph())
+ {
+ SIMPLE_FATAL_ERROR("Unable to initialize graph traverser: \"main\"");
+ }
+
+ if (main_gt.linkTensorsAndNodes())
+ {
+ SIMPLE_FATAL_ERROR("Failed to link tensors and nodes");
+ }
+
+ if (main_gt.validateGraph())
+ {
+ SIMPLE_FATAL_ERROR("Failed to validate graph");
+ }
+
+ if (g_func_config.validate_only)
+ {
+ goto done;
+ }
+
+ if (readInputTensors(main_gt))
+ {
+ SIMPLE_FATAL_ERROR("Unable to read input tensors");
+ }
+
+ if (g_func_config.eval)
+ {
+
+ if (main_gt.evaluateAll())
+ {
+ SIMPLE_FATAL_ERROR("Error evaluating network. Giving up.");
+ }
+
+ // make sure output tensor is evaluated and show its value
+ int num_output_tensors = main_gt.getNumOutputTensors();
+ bool all_output_valid = true;
+ for (int i = 0; i < num_output_tensors; i++)
+ {
+ const Tensor* ct = main_gt.getOutputTensor(i);
+ ASSERT_MEM(ct);
+ if (!ct->getIsValid())
+ {
+ ct->dumpTensorParams(g_func_debug.func_debug_file);
+ if (DEBUG_ENABLED(DEBUG_VERB_HIGH, GT))
+ {
+ ct->dumpTensor(g_func_debug.func_debug_file);
+ }
+ all_output_valid = false;
+ }
+ }
+ if (!all_output_valid)
+ {
+ main_gt.dumpGraph(g_func_debug.func_debug_file);
+ SIMPLE_FATAL_ERROR(
+ "SubgraphTraverser \"main\" error: Output tensors are not all valid at the end of evaluation.");
+ }
+
+ if (g_func_config.output_tensors)
+ {
+ if (writeFinalTensors(main_gt))
+ {
+ WARNING("Errors encountered in saving output tensors");
+ }
+ }
+ }
+
+done:
+ func_fini_debug(&g_func_debug);
+ func_model_config_cleanup();
+
+ return 0;
+}
+
+int loadGraph(TosaSerializationHandler& tsh)
+{
+ char graph_fullname[1024];
+
+ snprintf(graph_fullname, sizeof(graph_fullname), "%s/%s", g_func_config.subgraph_dir, g_func_config.subgraph_file);
+
+ if (strlen(graph_fullname) <= 2)
+ {
+ func_model_print_help(stderr);
+ SIMPLE_FATAL_ERROR("Missing required argument: Check -Csubgraph_file=");
+ }
+
+ const char JSON_EXT[] = ".json";
+ int is_json = 0;
+ {
+ // look for JSON file extension
+ size_t suffix_len = strlen(JSON_EXT);
+ size_t str_len = strlen(graph_fullname);
+
+ if (str_len > suffix_len && strncasecmp(graph_fullname + (str_len - suffix_len), JSON_EXT, suffix_len) == 0)
+ {
+ is_json = 1;
+ }
+ }
+
+ if (is_json)
+ {
+ if (tsh.LoadFileSchema(g_func_config.operator_fbs))
+ {
+ SIMPLE_FATAL_ERROR(
+ "\nJSON file detected. Unable to load TOSA flatbuffer schema from: %s\nCheck -Coperator_fbs=",
+ g_func_config.operator_fbs);
+ }
+
+ if (tsh.LoadFileJson(graph_fullname))
+ {
+ SIMPLE_FATAL_ERROR("\nError loading JSON graph file: %s\nCheck -Csubgraph_file= and -Csubgraph_dir=",
+ graph_fullname);
+ }
+ }
+ else
+ {
+ if (tsh.LoadFileTosaFlatbuffer(graph_fullname))
+ {
+ SIMPLE_FATAL_ERROR("\nError loading TOSA flatbuffer file: %s\nCheck -Csubgraph_file= and -Csubgraph_dir=",
+ graph_fullname);
+ }
+ }
+
+ return 0;
+}
+
+int readInputTensors(SubgraphTraverser& gt)
+{
+ int tensorCount = gt.getNumInputTensors();
+ Tensor* tensor;
+ char filename[1024];
+
+    // Assumes the npy filenames don't contain colons (:)
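+    // Expected form (names here are illustrative): "in0:in0.npy,in1:in1.npy", i.e.
+    // comma-separated <tensor name>:<npy file> pairs as described for the
+    // input_tensor option in func_config.def.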
+ std::map<std::string, std::string> input_tensor_map;
+ std::string raw_str(g_func_config.input_tensor);
+ std::string name, npy;
+ bool last_pair = false;
+
+ std::string::size_type pair_start = 0, pair_end, colons_pos;
+ do
+ {
+ pair_end = raw_str.find(',', pair_start);
+ if (pair_end == std::string::npos)
+ last_pair = true;
+
+ colons_pos = raw_str.find(':', pair_start);
+
+ name = raw_str.substr(pair_start, colons_pos - pair_start);
+ npy = raw_str.substr(colons_pos + 1, pair_end - colons_pos - 1);
+
+ // Empty strings can make it to here
+ if (name.length() == 0 || npy.length() == 0)
+ break;
+
+ input_tensor_map[name] = npy;
+
+        pair_start = pair_end + 1;    // skip the comma separator
+ } while (!last_pair);
+
+ if ((size_t)tensorCount != input_tensor_map.size())
+ {
+ WARNING("graph has %lu input placeholders, but %lu initialized", tensorCount, input_tensor_map.size());
+ return 1;
+ }
+
+ for (auto& tensor_pair : input_tensor_map)
+ {
+ tensor = gt.getInputTensorByName(tensor_pair.first);
+ if (!tensor)
+ {
+ WARNING("Unable to find input tensor %s", tensor_pair.first.c_str());
+ return 1;
+ }
+
+ snprintf(filename, sizeof(filename), "%s/%s", g_func_config.input_dir, tensor_pair.second.c_str());
+
+ DEBUG_MED(GT, "Loading input tensor %s from filename: %s", tensor->getName().c_str(), filename);
+
+ if (tensor->allocate())
+ {
+ WARNING("Fail to allocate tensor %s", tensor->getName().c_str());
+ return 1;
+ }
+
+ if (tensor->readFromNpyFile(filename))
+ {
+ WARNING("Unable to read input tensor %s from filename: %s", tensor->getName().c_str(), filename);
+ tensor->dumpTensorParams(g_func_debug.func_debug_file);
+ return 1;
+ }
+
+ // Push ready consumers to the next node list
+ for (auto gn : tensor->getConsumers())
+ {
+ if (gn->hasAllInputsReady() && !gn->getOnNextNodeList())
+ {
+ gt.addToNextNodeList(gn);
+ }
+ }
+ }
+
+ if (DEBUG_ENABLED(DEBUG_VERB_HIGH, GT))
+ {
+ gt.dumpNextNodeList(g_func_debug.func_debug_file);
+ }
+
+ return 0;
+}
+
+int writeFinalTensors(SubgraphTraverser& gt)
+{
+ int tensorCount = gt.getNumOutputTensors();
+ const Tensor* tensor;
+ char filename[1024];
+
+ for (int i = 0; i < tensorCount; i++)
+ {
+ tensor = gt.getOutputTensor(i);
+ if (!tensor)
+ {
+ WARNING("Unable to find output tensor[%d]", i);
+ return 1;
+ }
+
+ snprintf(filename, sizeof(filename), "%s/%s%s.npy", g_func_config.output_dir,
+ g_func_config.output_tensor_prefix, tensor->getName().c_str());
+
+ DEBUG_MED(GT, "Writing output tensor[%d] %s to filename: %s", i, tensor->getName().c_str(), filename);
+
+ if (tensor->writeToNpyFile(filename))
+ {
+ WARNING("Unable to write output tensor[%d] %s to filename: %s", i, tensor->getName().c_str(), filename);
+ return 1;
+ }
+ }
+
+ return 0;
+}
diff --git a/reference_model/src/model_common.h b/reference_model/src/model_common.h
new file mode 100644
index 0000000..d6dab6d
--- /dev/null
+++ b/reference_model/src/model_common.h
@@ -0,0 +1,28 @@
+
+// Copyright (c) 2020, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef MODEL_COMMON_H
+#define MODEL_COMMON_H
+
+#include <iostream>
+#include <stdio.h>
+
+#include "func_config.h"
+#include "func_debug.h"
+
+extern func_config_t g_func_config;
+extern func_debug_t g_func_debug;
+
+#endif
diff --git a/reference_model/src/ops/activation_funcs.cc b/reference_model/src/ops/activation_funcs.cc
new file mode 100644
index 0000000..bca9507
--- /dev/null
+++ b/reference_model/src/ops/activation_funcs.cc
@@ -0,0 +1,118 @@
+
+// Copyright (c) 2020, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "activation_funcs.h"
+#include "quant_util.h"
+#include "template_types.h"
+#include <cmath>
+
+using namespace TosaReference;
+using namespace Eigen;
+using namespace tosa;
+
+template <int Rank, DType Dtype>
+int OpClamp<Rank, Dtype>::register_fcn()
+{
+
+ switch (Dtype)
+ {
+ case DType_FLOAT:
+ {
+ InEigenType min = (InEigenType)attribute->min_fp();
+ InEigenType max = (InEigenType)attribute->max_fp();
+ this->fcn = [min, max](InEigenType a) -> OutEigenType { return a <= min ? min : a >= max ? max : a; };
+ }
+ break;
+ case DType_AINT8:
+ case DType_INT16:
+ {
+ InEigenType min = (InEigenType)attribute->min_int();
+ InEigenType max = (InEigenType)attribute->max_int();
+ this->fcn = [min, max](InEigenType a) -> OutEigenType { return a <= min ? min : a >= max ? max : a; };
+ }
+ break;
+ default:
+ FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
+ }
+
+ return 0;
+}
+
+template <int Rank, DType Dtype>
+int OpReluN<Rank, Dtype>::register_fcn()
+{
+
+ switch (Dtype)
+ {
+ case DType_FLOAT:
+ {
+ InEigenType N = (InEigenType)attribute->max_fp();
+ this->fcn = [N](InEigenType a) -> OutEigenType { return a >= 0 ? (a <= N ? a : N) : 0; };
+ }
+ break;
+ case DType_INT32:
+ {
+ InEigenType N = (InEigenType)attribute->max_int();
+ this->fcn = [N](InEigenType a) -> OutEigenType { return a >= 0 ? (a <= N ? a : N) : 0; };
+ }
+ break;
+ default:
+ FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
+ }
+
+ return 0;
+}
+
+template <int Rank, DType Dtype>
+int OpSigmoid<Rank, Dtype>::register_fcn()
+{
+ switch (Dtype)
+ {
+ case DType_FLOAT:
+ this->fcn = [](InEigenType a) -> OutEigenType { return (1.0 / (1.0 + (expf(-1.0 * a)))); };
+ break;
+ default:
+ FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
+ }
+
+ return 0;
+}
+
+template <int Rank, DType Dtype>
+int OpTanh<Rank, Dtype>::register_fcn()
+{
+ switch (Dtype)
+ {
+ case DType_FLOAT:
+ this->fcn = [](InEigenType a) -> OutEigenType { return tanhf(a); };
+ break;
+ default:
+ FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
+ }
+
+ return 0;
+}
+
+// template explicit instantiation
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, FLOAT);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, AINT8);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, INT16);
+
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpReluN, FLOAT);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpReluN, INT32);
+
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSigmoid, FLOAT);
+
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTanh, FLOAT);
diff --git a/reference_model/src/ops/activation_funcs.h b/reference_model/src/ops/activation_funcs.h
new file mode 100644
index 0000000..b051b9d
--- /dev/null
+++ b/reference_model/src/ops/activation_funcs.h
@@ -0,0 +1,101 @@
+
+// Copyright (c) 2020, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef OPS_ACTIVATION_FUNCS_H
+#define OPS_ACTIVATION_FUNCS_H
+
+#include "ewise_unary.h"
+#include "graph_node.h"
+
+using namespace tosa;
+
+namespace TosaReference
+{
+
+template <int Rank, DType Dtype>
+class OpClamp : public UnaryNode<Rank, Dtype>
+{
+public:
+ OpClamp(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+ : UnaryNode<Rank, Dtype>(Op_CLAMP, id_)
+ {
+ INIT_ATTRIBUTE(Clamp);
+ register_fcn();
+ }
+ static constexpr int32_t QMin = GetQMin<Dtype>::value;
+ static constexpr int32_t QMax = GetQMax<Dtype>::value;
+ using InEigenType = typename GetEigenType<Dtype>::type;
+ using OutEigenType = typename GetEigenType<Dtype>::type;
+ virtual int register_fcn();
+
+protected:
+ TosaClampAttribute* attribute;
+};
+
+template <int Rank, DType Dtype>
+class OpReluN : public UnaryNode<Rank, Dtype>
+{
+public:
+ OpReluN(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+ : UnaryNode<Rank, Dtype>(Op_RELUN, id_)
+ {
+ INIT_ATTRIBUTE(ReluN);
+ register_fcn();
+ }
+ static constexpr int32_t QMin = GetQMin<Dtype>::value;
+ static constexpr int32_t QMax = GetQMax<Dtype>::value;
+ using InEigenType = typename GetEigenType<Dtype>::type;
+ using OutEigenType = typename GetEigenType<Dtype>::type;
+ virtual int register_fcn();
+
+protected:
+ TosaReluNAttribute* attribute;
+};
+
+template <int Rank, DType Dtype>
+class OpSigmoid : public UnaryNode<Rank, Dtype>
+{
+public:
+ OpSigmoid(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+ : UnaryNode<Rank, Dtype>(Op_SIGMOID, id_)
+ {
+ register_fcn();
+ }
+ static constexpr int32_t QMin = GetQMin<Dtype>::value;
+ static constexpr int32_t QMax = GetQMax<Dtype>::value;
+ using InEigenType = typename GetEigenType<Dtype>::type;
+ using OutEigenType = typename GetEigenType<Dtype>::type;
+ virtual int register_fcn();
+};
+
+template <int Rank, DType Dtype>
+class OpTanh : public UnaryNode<Rank, Dtype>
+{
+public:
+ OpTanh(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+ : UnaryNode<Rank, Dtype>(Op_TANH, id_)
+ {
+ register_fcn();
+ }
+ static constexpr int32_t QMin = GetQMin<Dtype>::value;
+ static constexpr int32_t QMax = GetQMax<Dtype>::value;
+ using InEigenType = typename GetEigenType<Dtype>::type;
+ using OutEigenType = typename GetEigenType<Dtype>::type;
+ virtual int register_fcn();
+};
+
+}; // namespace TosaReference
+
+#endif
diff --git a/reference_model/src/ops/comparison.cc b/reference_model/src/ops/comparison.cc
new file mode 100644
index 0000000..402e152
--- /dev/null
+++ b/reference_model/src/ops/comparison.cc
@@ -0,0 +1,81 @@
+
+// Copyright (c) 2020, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "comparison.h"
+#include "arith_util.h"
+#include "quant_util.h"
+#include "template_types.h"
+
+using namespace TosaReference;
+using namespace Eigen;
+using namespace tosa;
+
+template <int Rank, DType Dtype>
+int OpEqual<Rank, Dtype>::register_fcn()
+{
+ switch (Dtype)
+ {
+ case DType_FLOAT:
+ case DType_INT32:
+ this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a == b; };
+ break;
+ default:
+ FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
+ }
+
+ return 0;
+}
+
+template <int Rank, DType Dtype>
+int OpGreater<Rank, Dtype>::register_fcn()
+{
+ switch (Dtype)
+ {
+ case DType_FLOAT:
+ case DType_INT32:
+ this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a > b; };
+ break;
+ default:
+ FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
+ }
+
+ return 0;
+}
+
+template <int Rank, DType Dtype>
+int OpGreaterEqual<Rank, Dtype>::register_fcn()
+{
+ switch (Dtype)
+ {
+ case DType_FLOAT:
+ case DType_INT32:
+ this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a >= b; };
+ break;
+ default:
+ FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
+ }
+
+ return 0;
+}
+
+// template explicit instantiation
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpEqual, FLOAT);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpEqual, INT32);
+
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpGreater, FLOAT);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpGreater, INT32);
+
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpGreaterEqual, FLOAT);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpGreaterEqual, INT32);
diff --git a/reference_model/src/ops/comparison.h b/reference_model/src/ops/comparison.h
new file mode 100644
index 0000000..e75b1a6
--- /dev/null
+++ b/reference_model/src/ops/comparison.h
@@ -0,0 +1,71 @@
+
+// Copyright (c) 2020, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef OPS_COMPARISON_H
+#define OPS_COMPARISON_H
+
+#include "ewise_binary.h"
+#include "graph_node.h"
+
+using namespace tosa;
+
+namespace TosaReference
+{
+
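+// The comparison ops reuse the element-wise BinaryNode machinery: both inputs are of Dtype,
+// the result tensor is DType_BOOL, and the comparison lambda is installed by register_fcn().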
+template <int Rank, DType Dtype>
+class OpEqual : public BinaryNode<Rank, Dtype, DType_BOOL>
+{
+public:
+ OpEqual(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+ : BinaryNode<Rank, Dtype, DType_BOOL>(Op_EQUAL, qinfo_, id_)
+ {
+ register_fcn();
+ }
+ using InEigenType = typename GetEigenType<Dtype>::type;
+ using OutEigenType = typename GetEigenType<DType_BOOL>::type;
+ virtual int register_fcn();
+};
+
+template <int Rank, DType Dtype>
+class OpGreater : public BinaryNode<Rank, Dtype, DType_BOOL>
+{
+public:
+ OpGreater(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+ : BinaryNode<Rank, Dtype, DType_BOOL>(Op_GREATER, qinfo_, id_)
+ {
+ register_fcn();
+ }
+ using InEigenType = typename GetEigenType<Dtype>::type;
+ using OutEigenType = typename GetEigenType<DType_BOOL>::type;
+ virtual int register_fcn();
+};
+
+template <int Rank, DType Dtype>
+class OpGreaterEqual : public BinaryNode<Rank, Dtype, DType_BOOL>
+{
+public:
+ OpGreaterEqual(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+        : BinaryNode<Rank, Dtype, DType_BOOL>(Op_GREATER_EQUAL, qinfo_, id_)
+ {
+ register_fcn();
+ }
+ using InEigenType = typename GetEigenType<Dtype>::type;
+ using OutEigenType = typename GetEigenType<DType_BOOL>::type;
+ virtual int register_fcn();
+};
+
+}; // namespace TosaReference
+
+#endif
diff --git a/reference_model/src/ops/control_flow.cc b/reference_model/src/ops/control_flow.cc
new file mode 100644
index 0000000..9d5db40
--- /dev/null
+++ b/reference_model/src/ops/control_flow.cc
@@ -0,0 +1,353 @@
+
+// Copyright (c) 2020, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "control_flow.h"
+#include "subgraph_traverser.h"
+
+using namespace TosaReference;
+using namespace Eigen;
+using namespace tosa;
+
+OpControlFlow::OpControlFlow(TosaSerializationHandler* tsh_, Op op_, uint64_t id_)
+ : GraphNode(op_, id_)
+{
+ tsh = tsh_;
+}
+
+OpControlFlow::~OpControlFlow()
+{}
+
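+// Evaluate a serialized basic block as its own subgraph: build a SubgraphTraverser for the
+// block, copy the caller-supplied input tensors into the block's inputs, run the block to
+// completion, and copy the block's outputs back into the caller-supplied output tensors.
+// Returns 0 on success and 1 on failure.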
+int OpControlFlow::evalBlock(TosaSerializationBasicBlock* block,
+ std::vector<TosaReference::Tensor*>& block_inputs,
+ std::vector<TosaReference::Tensor*>& block_outputs)
+{
+ std::string block_name = block->GetName();
+
+ DEBUG_MED(OP, "Evaluating block %s", block_name.c_str());
+
+ SubgraphTraverser gt(block, tsh);
+
+ if (gt.initializeGraph())
+ {
+ FATAL_ERROR("Unable to initialize graph traverser for block %s", block_name.c_str());
+ }
+
+ if (gt.linkTensorsAndNodes())
+ {
+ FATAL_ERROR("Failed to link tensors and nodes for block %s", block_name.c_str());
+ }
+
+ if (gt.validateGraph())
+ {
+ FATAL_ERROR("Failed to validate subgraph for block %s", block_name.c_str());
+ }
+
+ int num_input_tensors = gt.getNumInputTensors();
+ int num_output_tensors = gt.getNumOutputTensors();
+
+ for (size_t i = 0; i < block_inputs.size(); i++)
+ {
+ DEBUG_HIGH(OP, "Input[%ld]: %s", i, block_inputs[i]->getName().c_str());
+ }
+ for (size_t i = 0; i < block_outputs.size(); i++)
+ {
+ DEBUG_HIGH(OP, "Output[%ld]: %s", i, block_outputs[i]->getName().c_str());
+ }
+
+    ASSERT_MSG((size_t)num_input_tensors == block_inputs.size(),
+               "op block %s inputs[%lu] does not match the graph traverser's inputs[%d]", block_name.c_str(),
+               block_inputs.size(), num_input_tensors);
+    ASSERT_MSG((size_t)num_output_tensors == block_outputs.size(),
+               "op block %s outputs[%lu] does not match the graph traverser's outputs[%d]", block_name.c_str(),
+               block_outputs.size(), num_output_tensors);
+
+ // set graph traverser's input = basic block's input
+ for (int i = 0; i < num_input_tensors; i++)
+ {
+ TosaReference::Tensor* tensor = gt.getInputTensor(i);
+        ASSERT_MSG(!tensor->is_allocated(), "block %s input tensors are unexpectedly already allocated",
+                   block_name.c_str());
+
+ if (tensor->allocate())
+ {
+ WARNING("Fail to allocate tensor %s", tensor->getName().c_str());
+ return 1;
+ }
+
+ if (tensor->copyValueFrom(block_inputs[i]))
+ {
+ WARNING("Fail to copy tensor value %s -> %s", block_inputs[i]->getName().c_str(),
+ tensor->getName().c_str());
+ return 1;
+ }
+
+ // Push ready consumers to the next node list
+ for (auto gn : tensor->getConsumers())
+ {
+ if (gn->hasAllInputsReady() && !gn->getOnNextNodeList())
+ {
+ gt.addToNextNodeList(gn);
+ }
+ }
+ }
+
+ if (gt.evaluateAll())
+ {
+ FATAL_ERROR("Error evaluating network. Giving up.");
+ }
+
+ // make sure output tensor is evaluated and show its value
+ bool all_output_valid = true;
+ for (int i = 0; i < num_output_tensors; i++)
+ {
+ const TosaReference::Tensor* ct = gt.getOutputTensor(i);
+ ASSERT_MEM(ct);
+ if (!ct->getIsValid())
+ {
+ ct->dumpTensorParams(g_func_debug.func_debug_file);
+ if (DEBUG_ENABLED(DEBUG_VERB_HIGH, GT))
+ {
+ ct->dumpTensor(g_func_debug.func_debug_file);
+ }
+ all_output_valid = false;
+ }
+ }
+ if (!all_output_valid)
+ {
+ gt.dumpGraph(g_func_debug.func_debug_file);
+ FATAL_ERROR("SubgraphTraverser \"%s\" error: Output tensors are not all valid at the end of evaluation.",
+ block_name.c_str());
+ }
+
+ // set basic block's output = subgraph_traverser's output
+ for (int i = 0; i < num_output_tensors; i++)
+ {
+ TosaReference::Tensor* tensor = gt.getOutputTensor(i);
+ ASSERT_MSG(tensor->is_allocated(), "tensor %s is not allocated", tensor->getName().c_str());
+
+ if (block_outputs[i]->copyValueFrom(tensor))
+ {
+            WARNING("Fail to copy tensor value %s -> %s", tensor->getName().c_str(),
+                    block_outputs[i]->getName().c_str());
+ return 1;
+ }
+ }
+ return 0;
+}
+
+OpCondIf::OpCondIf(TosaSerializationHandler* tsh_, TosaAttributeBase* attribute_, uint64_t id_)
+ : OpControlFlow(tsh_, Op_COND_IF, id_)
+{
+ INIT_ATTRIBUTE(CondIf);
+}
+
+OpCondIf::~OpCondIf()
+{
+ if (attribute)
+ delete attribute;
+}
+
+int OpCondIf::checkTensorAttributes()
+{
+ if (getInputs().size() < 1)
+ {
+ WARNING("OpCondIf: must have at least 1 operand");
+ return 1;
+ }
+
+ if (inputs[0]->getDtype() != DType_BOOL || inputs[0]->getRank() != 0)
+ {
+ WARNING("OpCondIf: invalid tensor dtype=%s, rank=%d", EnumNamesDType()[inputs[0]->getDtype()],
+ inputs[0]->getRank());
+ return 1;
+ }
+
+ cond = dynamic_cast<TosaReference::Tensor0<bool>*>(inputs[0]);
+ ASSERT_MEM(cond);
+
+ then_block = tsh->GetBlockByName(attribute->then_branch());
+ else_block = tsh->GetBlockByName(attribute->else_branch());
+
+ if (!then_block)
+ {
+ WARNING("OpCondIf: fail to resolve then_branch %s", attribute->then_branch().c_str());
+ return 1;
+ }
+
+ if (!else_block)
+ {
+ WARNING("OpCondIf: fail to resolve else_branch %s", attribute->else_branch().c_str());
+ return 1;
+ }
+
+ return 0;
+}
+
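+// COND_IF evaluation: input 0 is the rank-0 boolean condition; the remaining inputs are
+// forwarded to either then_block or else_block, whose results become this node's outputs.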
+int OpCondIf::eval()
+{
+ bool cond_val = cond->getTensor()(0);
+ std::vector<TosaReference::Tensor*> block_inputs(getInputs().begin() + 1, getInputs().end());
+
+ if (cond_val)
+ {
+ if (evalBlock(then_block, block_inputs, getOutputs()))
+ {
+ WARNING("OpCondIf: Fail to evaluate then branch block %s", attribute->then_branch().c_str());
+ return 1;
+ }
+ }
+ else
+ {
+ if (evalBlock(else_block, block_inputs, getOutputs()))
+ {
+ WARNING("OpCondIf: Fail to evaluate else branch block %s", attribute->else_branch().c_str());
+ return 1;
+ }
+ }
+
+ return GraphNode::eval();
+}
+
+OpWhileLoop::OpWhileLoop(TosaSerializationHandler* tsh_, TosaAttributeBase* attribute_, uint64_t id_)
+ : OpControlFlow(tsh_, Op_WHILE_LOOP, id_)
+{
+ INIT_ATTRIBUTE(WhileLoop);
+}
+
+OpWhileLoop::~OpWhileLoop()
+{
+ if (attribute)
+ delete attribute;
+}
+
+int OpWhileLoop::checkTensorAttributes()
+{
+ if (getInputs().size() <= 0)
+ {
+        WARNING("OpWhileLoop: must have at least 1 operand");
+ return 1;
+ }
+
+ if (getInputs().size() != getOutputs().size())
+ {
+ WARNING("OpWhileLoop: inputs and outputs size must match");
+ return 1;
+ }
+
+ cond_block = tsh->GetBlockByName(attribute->cond_branch());
+ body_block = tsh->GetBlockByName(attribute->body_branch());
+
+ if (!cond_block)
+ {
+ WARNING("OpWhileLoop: fail to resolve cond_branch %s", attribute->cond_branch().c_str());
+ return 1;
+ }
+
+ if (!body_block)
+ {
+ WARNING("OpWhileLoop: fail to resolve body_branch %s", attribute->body_branch().c_str());
+ return 1;
+ }
+
+ if (cond_block->GetOutputs().size() != 1)
+ {
+ WARNING("OpWhileLoop: invalid cond_block output size %lu", cond_block->GetOutputs().size());
+ return 1;
+ }
+
+ TosaSerializationTensor* cond_output_tensor = cond_block->GetTensorByName(cond_block->GetOutputs()[0]);
+
+ if (!cond_output_tensor)
+ {
+ WARNING("OpWhileLoop: fail to resolve cond_block's output tensor %s", cond_block->GetOutputs()[0].c_str());
+ return 1;
+ }
+
+ if (cond_output_tensor->GetDtype() != DType_BOOL)
+ {
+ WARNING("OpWhileLoop: invalid cond_block's output tensor data type %s",
+ EnumNamesDType()[cond_output_tensor->GetDtype()]);
+ return 1;
+ }
+ if (cond_output_tensor->GetShape().size() != 0)
+ {
+ WARNING("OpWhileLoop: invalid cond_block's output rank %lu", cond_output_tensor->GetShape().size());
+ return 1;
+ }
+
+ return 0;
+}
+
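+// WHILE_LOOP evaluation: repeatedly run cond_block on the current inputs; while it yields
+// true, run body_block and copy its outputs back onto the inputs for the next iteration.
+// When the condition is false (possibly on the first check), the inputs are copied straight
+// to the outputs. Iterations are capped at MAX_WHILE_LOOP_ITERATION.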
+int OpWhileLoop::eval()
+{
+
+ TosaReference::Tensor0<bool> cond_output_ctensor(
+ std::string("cond_output"), DType_BOOL, std::vector<Usage>({ Usage_ACTIVATION }),
+ std::vector<Format>({ Format_UNKNOWN }), std::vector<int32_t>({}), false);
+
+ cond_output_ctensor.allocate();
+ std::vector<TosaReference::Tensor*> cond_block_outputs;
+ cond_block_outputs.push_back(&cond_output_ctensor);
+
+ size_t num_input_output = getInputs().size();
+ size_t eval_count = 0;
+
+ while (eval_count++ < MAX_WHILE_LOOP_ITERATION)
+ {
+ if (evalBlock(cond_block, getInputs(), cond_block_outputs))
+ {
+ WARNING("OpWhileLoop: Fail to evaluate cond block %s", attribute->cond_branch().c_str());
+ return 1;
+ }
+ bool cond_val = cond_output_ctensor.getTensor()(0);
+ DEBUG_HIGH(OP, "Conditional block value: %d", cond_val);
+
+ if (cond_val)
+ {
+ if (evalBlock(body_block, getInputs(), getOutputs()))
+ {
+ WARNING("OpWhileLoop: Fail to evaluate body block %s", attribute->body_branch().c_str());
+ return 1;
+ }
+
+ // assigning output tensors value back to input tensors value for next iteration
+ for (size_t i = 0; i < num_input_output; i++)
+ {
+ if (getInputs()[i]->copyValueFrom(getOutputs()[i]))
+ {
+ WARNING("Fail to copy tensor value %s -> %s", getOutputs()[i]->getName().c_str(),
+ getInputs()[i]->getName().c_str());
+ return 1;
+ }
+ }
+ }
+ else
+ {
+ // in last iteration or the case it never evaluates body block
+ // assign input tensors value to output tensors
+ for (size_t i = 0; i < num_input_output; i++)
+ {
+ if (getOutputs()[i]->copyValueFrom(getInputs()[i]))
+ {
+ WARNING("Fail to copy tensor value %s -> %s", getInputs()[i]->getName().c_str(),
+ getOutputs()[i]->getName().c_str());
+ return 1;
+ }
+ }
+ break;
+ }
+ }
+
+ return GraphNode::eval();
+}
diff --git a/reference_model/src/ops/control_flow.h b/reference_model/src/ops/control_flow.h
new file mode 100644
index 0000000..14c11bc
--- /dev/null
+++ b/reference_model/src/ops/control_flow.h
@@ -0,0 +1,72 @@
+
+// Copyright (c) 2020, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef OPS_CONTROL_FLOW_H
+#define OPS_CONTROL_FLOW_H
+
+#include "graph_node.h"
+
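+// Safety cap on WHILE_LOOP evaluation; OpWhileLoop::eval() stops iterating once this many
+// condition checks have been performed, even if the condition is still true.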
+#define MAX_WHILE_LOOP_ITERATION 10000
+
+namespace TosaReference
+{
+class OpControlFlow : public GraphNode
+{
+public:
+ OpControlFlow(TosaSerializationHandler* tsh_, Op op_, uint64_t id_);
+ ~OpControlFlow();
+
+ virtual int evalBlock(TosaSerializationBasicBlock* block,
+ std::vector<TosaReference::Tensor*>& block_inputs,
+ std::vector<TosaReference::Tensor*>& block_outputs);
+
+protected:
+ TosaSerializationHandler* tsh;
+};
+
+class OpCondIf : public OpControlFlow
+{
+public:
+ OpCondIf(TosaSerializationHandler* tsh_, TosaAttributeBase* attribute_, uint64_t id_);
+ virtual ~OpCondIf();
+
+ virtual int checkTensorAttributes();
+ virtual int eval();
+
+protected:
+ TosaCondIfAttribute* attribute;
+ TosaReference::Tensor0<bool>* cond;
+ TosaSerializationBasicBlock* then_block;
+ TosaSerializationBasicBlock* else_block;
+};
+
+class OpWhileLoop : public OpControlFlow
+{
+public:
+ OpWhileLoop(TosaSerializationHandler* tsh_, TosaAttributeBase* attribute_, uint64_t id_);
+ virtual ~OpWhileLoop();
+
+ virtual int checkTensorAttributes();
+ virtual int eval();
+
+protected:
+ TosaWhileLoopAttribute* attribute;
+ TosaSerializationBasicBlock* cond_block;
+ TosaSerializationBasicBlock* body_block;
+};
+
+}; // namespace TosaReference
+
+#endif
diff --git a/reference_model/src/ops/custom.cc b/reference_model/src/ops/custom.cc
new file mode 100644
index 0000000..5c4f29b
--- /dev/null
+++ b/reference_model/src/ops/custom.cc
@@ -0,0 +1,40 @@
+
+// Copyright (c) 2020, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "custom.h"
+
+using namespace TosaReference;
+using namespace Eigen;
+using namespace tosa;
+
+OpCustom::OpCustom(uint64_t id_)
+ : GraphNode(Op_CUSTOM, id_)
+{}
+
+OpCustom::~OpCustom()
+{}
+
+int OpCustom::checkTensorAttributes()
+{
+ return 0;
+}
+
+int OpCustom::eval()
+{
+ FATAL_ERROR_NODE("not supported yet");
+
+    // OpCustom evaluation is not yet implemented
+ return GraphNode::eval();
+}
diff --git a/reference_model/src/ops/custom.h b/reference_model/src/ops/custom.h
new file mode 100644
index 0000000..b1085a5
--- /dev/null
+++ b/reference_model/src/ops/custom.h
@@ -0,0 +1,38 @@
+
+// Copyright (c) 2020, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef OPS_CUSTOM_H
+#define OPS_CUSTOM_H
+
+#include "graph_node.h"
+
+using namespace tosa;
+
+namespace TosaReference
+{
+
+class OpCustom : public GraphNode
+{
+public:
+ OpCustom(uint64_t id_);
+ virtual ~OpCustom();
+
+ virtual int checkTensorAttributes();
+ virtual int eval();
+};
+
+}; // namespace TosaReference
+
+#endif
diff --git a/reference_model/src/ops/data_layout.cc b/reference_model/src/ops/data_layout.cc
new file mode 100644
index 0000000..32029b9
--- /dev/null
+++ b/reference_model/src/ops/data_layout.cc
@@ -0,0 +1,644 @@
+
+// Copyright (c) 2020, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "data_layout.h"
+#include "quant_util.h"
+
+using namespace TosaReference;
+using namespace Eigen;
+using namespace tosa;
+
+template <int Rank, DType Dtype>
+OpConcat<Rank, Dtype>::OpConcat(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+ : GraphNode(Op_CONCAT, id_)
+{
+ setRequiredOperands(2, 1);
+ setRequiredRank(1, 6);
+
+ INIT_ATTRIBUTE(Axis);
+}
+
+template <int Rank, DType Dtype>
+OpConcat<Rank, Dtype>::~OpConcat()
+{
+ if (attribute)
+ delete attribute;
+}
+
+template <int Rank, DType Dtype>
+int OpConcat<Rank, Dtype>::checkTensorAttributes()
+{
+ if (validateRequiredOperands())
+ return 1;
+
+ if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
+ {
+ return 1;
+ }
+
+ // output and input must be the same types and rank
+ // inputs[0] and inputs[1] should also match type and rank
+ if (inputs[0]->matchRankType(*outputs[0]) || inputs[1]->matchRankType(*outputs[0]))
+ {
+ printNodeValidationError("Concat operator input ranks and types must match");
+ return 1;
+ }
+
+ lhs = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
+ rhs = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[1]);
+ out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
+
+ if (attribute->axis() < 0 || (size_t)attribute->axis() >= rhs->getShape().size())
+ {
+ printNodeValidationError("Axis is beyond input tensor rank");
+ return 1;
+ }
+
+ return 0;
+}
+
+template <int Rank, DType Dtype>
+int OpConcat<Rank, Dtype>::eval()
+{
+
+ int32_t reversed_axis = Rank - 1 - attribute->axis();
+
+ for (int32_t d = 0; d < Rank; d++)
+ {
+ reverser[d] = Rank - 1 - d;
+ }
+
+ TIn lhs_reversed = lhs->getTensor().shuffle(reverser);
+ TIn rhs_reversed = rhs->getTensor().shuffle(reverser);
+
+ TIn reversed_result = lhs_reversed.concatenate(rhs_reversed, reversed_axis);
+ out->getTensor() = reversed_result.shuffle(reverser);
+ // out->getTensor() = lhs->getTensor().concatenate(rhs->getTensor(), axis);
+
+ return GraphNode::eval();
+}
+
+template <int Rank, DType Dtype>
+OpPad<Rank, Dtype>::OpPad(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+ : GraphNode(Op_PAD, id_)
+{
+ setRequiredOperands(2, 1);
+ setRequiredRank(0, 6);
+
+ INIT_QINFO(Pad);
+}
+
+template <int Rank, DType Dtype>
+OpPad<Rank, Dtype>::~OpPad()
+{
+ if (qinfo)
+ delete qinfo;
+}
+
+template <int Rank, DType Dtype>
+int OpPad<Rank, Dtype>::checkTensorAttributes()
+{
+ if (validateRequiredOperands())
+ return 1;
+
+ if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
+ {
+ return 1;
+ }
+
+ // output and input must be the same types
+ if (inputs[0]->matchRankType(*outputs[0]))
+ {
+ printNodeValidationError("Failure to match input and output type and rank");
+ return 1;
+ }
+
+ in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
+ out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
+ TosaReference::TensorTemplate<ETensor2<int32_t>>* paddings =
+ dynamic_cast<TosaReference::TensorTemplate<ETensor2<int32_t>>*>(inputs[1]);
+
+ for (int i = 0; i < Rank; i++)
+ {
+ paddings_array[i] = std::make_pair(paddings->getTensor()(i, 0), paddings->getTensor()(i, 1));
+ }
+
+ return 0;
+}
+
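+// Pad with the quantization input zero point when PadQuantInfo is present, otherwise with 0.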
+template <int Rank, DType Dtype>
+int OpPad<Rank, Dtype>::eval()
+{
+ InEigenType pad_value = 0;
+ if (this->qinfo)
+ {
+ pad_value = (InEigenType)this->qinfo->input_zp();
+ }
+
+ this->out->getTensor() = this->in->getTensor().pad(this->paddings_array, pad_value);
+
+ return GraphNode::eval();
+}
+
+template <int InRank, int OutRank, DType Dtype>
+OpReshape<InRank, OutRank, Dtype>::OpReshape(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+ : GraphNode(Op_RESHAPE, id_)
+{
+ setRequiredOperands(1, 1);
+ setRequiredRank(0, 6);
+
+ INIT_ATTRIBUTE(Reshape);
+}
+
+template <int InRank, int OutRank, DType Dtype>
+OpReshape<InRank, OutRank, Dtype>::~OpReshape()
+{
+ if (attribute)
+ delete attribute;
+}
+
+template <int InRank, int OutRank, DType Dtype>
+int OpReshape<InRank, OutRank, Dtype>::checkTensorAttributes()
+{
+ uint32_t minusOneCount = 0;
+
+ if (validateRequiredOperands())
+ return 1;
+
+ if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
+ {
+ return 1;
+ }
+
+ // output and input must be the same types
+ if (inputs[0]->matchType(*outputs[0]))
+ {
+ printNodeValidationError("OpReshape: Input and output types must match");
+ return 1;
+ }
+
+ for (uint32_t d = 0; d < OutRank; d++)
+ {
+ if (attribute->shape()[d] == -1)
+ {
+ minusOneCount++;
+ }
+ }
+
+ if (minusOneCount > 1)
+ {
+ printNodeValidationError("OpReshape: new shape has more than one -1 dimension");
+ return 1;
+ }
+
+ in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
+ out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
+
+ return 0;
+}
+
+template <int InRank, int OutRank, DType Dtype>
+int OpReshape<InRank, OutRank, Dtype>::eval()
+{
+ uint32_t remainingSize = in->getElementCount();
+
+ // If there is a -1 dimension, find the remainder in one pass over the output shape
+ for (int32_t d = 0; d < OutRank; d++)
+ {
+ if (attribute->shape()[d] != -1)
+ {
+ remainingSize = remainingSize / attribute->shape()[d];
+ }
+ }
+
+ for (int32_t d = 0; d < OutRank; d++)
+ {
+ array_shape[d] = attribute->shape()[OutRank - 1 - d];
+ out_reverser[d] = OutRank - 1 - d;
+
+ // Jam in the remainder here
+ if (array_shape[d] == -1)
+ {
+ array_shape[d] = remainingSize;
+ }
+ }
+
+ for (int32_t d = 0; d < InRank; d++)
+ {
+ in_reverser[d] = InRank - 1 - d;
+ }
+
+    // Eigen Tensors are column-major, while the reference model indexes data as row-major.
+    // Reverse the dimension order before the reshape, then reverse it again afterward.
+
+    // a rank-0 or rank-1 input doesn't go through the shuffle (rank 0 can't use .shuffle()),
+    // so handle it separately
+ TIn in_reversed;
+ if (InRank > 1)
+ {
+ in_reversed = in->getTensor().shuffle(in_reverser);
+ }
+ else
+ {
+ in_reversed = in->getTensor();
+ }
+
+ TOut in_reshaped = in_reversed.reshape(array_shape);
+
+    // likewise, a rank-0 or rank-1 output skips the final shuffle (rank 0 can't use .shuffle())
+ if (OutRank > 1)
+ {
+ out->getTensor() = in_reshaped.shuffle(out_reverser);
+ }
+ else
+ {
+ out->getTensor() = in_reshaped;
+ }
+
+ return GraphNode::eval();
+}
+
+template <int Rank, DType Dtype>
+OpReverse<Rank, Dtype>::OpReverse(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+ : GraphNode(Op_REVERSE, id_)
+{
+ setRequiredOperands(1, 1);
+ setRequiredRank(1, 6);
+
+ INIT_ATTRIBUTE(Axis);
+}
+
+template <int Rank, DType Dtype>
+OpReverse<Rank, Dtype>::~OpReverse()
+{
+ if (attribute)
+ delete attribute;
+}
+
+template <int Rank, DType Dtype>
+int OpReverse<Rank, Dtype>::checkTensorAttributes()
+{
+ if (validateRequiredOperands())
+ return 1;
+
+ if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
+ {
+ return 1;
+ }
+
+ // output and input must be the same types
+ if (inputs[0]->matchRankTypeShape(*outputs[0]))
+ {
+ printNodeValidationError("Failure to match input and output rank/type/shape");
+ return 1;
+ }
+
+ in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
+ out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
+
+ ASSERT_MEM(in && out);
+
+ if (attribute->axis() < 0 || attribute->axis() >= inputs[0]->getRank())
+ {
+        printNodeValidationError("Reverse axis must be within [0, input_rank - 1]");
+ return 1;
+ }
+
+    // build a per-dimension flag array from the reverse axis
+    // e.g. rank=4, axis=1 -> reverse_array = [false, true, false, false]
+ for (int i = 0; i < Rank; i++)
+ {
+ reverse_array[i] = false;
+ }
+ reverse_array[attribute->axis()] = true;
+
+ return 0;
+}
+
+template <int Rank, DType Dtype>
+int OpReverse<Rank, Dtype>::eval()
+{
+ out->getTensor() = in->getTensor().reverse(reverse_array);
+
+ return GraphNode::eval();
+}
+
+template <int Rank, DType Dtype>
+OpSlice<Rank, Dtype>::OpSlice(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+ : GraphNode(Op_SLICE, id_)
+{
+ setRequiredOperands(1, 1);
+ setRequiredRank(0, 6);
+
+ INIT_ATTRIBUTE(Slice);
+}
+
+template <int Rank, DType Dtype>
+OpSlice<Rank, Dtype>::~OpSlice()
+{
+ if (attribute)
+ delete attribute;
+}
+
+template <int Rank, DType Dtype>
+int OpSlice<Rank, Dtype>::checkTensorAttributes()
+{
+ if (validateRequiredOperands())
+ return 1;
+
+ if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
+ {
+ return 1;
+ }
+
+ // output and input must be the same types
+ if (inputs[0]->matchType(*outputs[0]))
+ {
+ printNodeValidationError("Failure to match input and output type");
+ return 1;
+ }
+
+ in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
+ out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
+
+ for (size_t i = 0; i < attribute->begin().size(); i++)
+ {
+ begin_array[i] = attribute->begin()[i];
+ }
+
+ for (size_t i = 0; i < attribute->size().size(); i++)
+ {
+ if (attribute->size()[i] != 0)
+ {
+ size_array[i] = attribute->size()[i];
+ }
+ else
+ {
+            // TensorFlow assigns a zero size to dimensions that are kept;
+            // Eigen expects the full size of that dimension instead
+            size_array[i] = in->getTensor().dimension(i);
+ }
+ }
+
+ return 0;
+}
+
+template <int Rank, DType Dtype>
+int OpSlice<Rank, Dtype>::eval()
+{
+ out->getTensor() = in->getTensor().slice(begin_array, size_array);
+
+ return GraphNode::eval();
+}
+
+template <int Rank, DType Dtype>
+OpTileBase<Rank, Dtype>::OpTileBase(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+ : GraphNode(Op_TILE, id_)
+{
+ setRequiredOperands(1, 1);
+ setRequiredRank(0, 6);
+
+ INIT_ATTRIBUTE(Tile);
+}
+
+template <int Rank, DType Dtype>
+OpTileBase<Rank, Dtype>::~OpTileBase()
+{
+ if (attribute)
+ delete attribute;
+}
+
+template <int Rank, DType Dtype>
+int OpTileBase<Rank, Dtype>::checkTensorAttributes()
+{
+ if (validateRequiredOperands())
+ return 1;
+
+ if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
+ {
+ return 1;
+ }
+
+ // output and input must be the same ranks and types
+ if (inputs[0]->matchRankType(*outputs[0]))
+ {
+ printNodeValidationError("Failure to match input and output rank or type");
+ return 1;
+ }
+
+ in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
+ out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
+
+ if (attribute->multiples().size() != Rank)
+ {
+ printNodeValidationError("1D list 'multiples' must have size equal to input rank");
+ return 1;
+ }
+
+ for (int32_t d = 0; d < Rank; d++)
+ {
+ if (in->getShape()[d] * attribute->multiples()[d] != out->getShape()[d])
+ {
+ printNodeValidationError("unexpected output shape");
+ return 1;
+ }
+ }
+
+ return 0;
+}
+
+template <int Rank, DType Dtype>
+int OpTile<Rank, Dtype>::eval()
+{
+ // primary template shouldn't be called
+    FATAL_ERROR_NODE("OpTile rank=%i, dtype=%s: not implemented yet", Rank, EnumNamesDType()[Dtype]);
+
+    return 1;
+}
+
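+// Rank-specialized tiling: each output index wraps back onto the input with a modulo,
+// e.g. for rank 1, in = [a, b] with multiples = [3] gives out = [a, b, a, b, a, b].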
+template <DType Dtype>
+int OpTile<1, Dtype>::eval()
+{
+ for (int32_t od0 = 0; od0 < this->out->getShape()[0]; od0++)
+ {
+ int32_t id0 = od0 % this->in->getShape()[0];
+ this->out->getTensor()(od0) = this->in->getTensor()(id0);
+ }
+
+ return GraphNode::eval();
+}
+
+template <DType Dtype>
+int OpTile<2, Dtype>::eval()
+{
+ for (int32_t od0 = 0; od0 < this->out->getShape()[0]; od0++)
+ {
+ int32_t id0 = od0 % this->in->getShape()[0];
+ for (int32_t od1 = 0; od1 < this->out->getShape()[1]; od1++)
+ {
+ int32_t id1 = od1 % this->in->getShape()[1];
+ this->out->getTensor()(od0, od1) = this->in->getTensor()(id0, id1);
+ }
+ }
+
+ return GraphNode::eval();
+}
+
+template <DType Dtype>
+int OpTile<3, Dtype>::eval()
+{
+ for (int32_t od0 = 0; od0 < this->out->getShape()[0]; od0++)
+ {
+ int32_t id0 = od0 % this->in->getShape()[0];
+ for (int32_t od1 = 0; od1 < this->out->getShape()[1]; od1++)
+ {
+ int32_t id1 = od1 % this->in->getShape()[1];
+ for (int32_t od2 = 0; od2 < this->out->getShape()[2]; od2++)
+ {
+ int32_t id2 = od2 % this->in->getShape()[2];
+ this->out->getTensor()(od0, od1, od2) = this->in->getTensor()(id0, id1, id2);
+ }
+ }
+ }
+
+ return GraphNode::eval();
+}
+
+template <DType Dtype>
+int OpTile<4, Dtype>::eval()
+{
+ for (int32_t od0 = 0; od0 < this->out->getShape()[0]; od0++)
+ {
+ int32_t id0 = od0 % this->in->getShape()[0];
+ for (int32_t od1 = 0; od1 < this->out->getShape()[1]; od1++)
+ {
+ int32_t id1 = od1 % this->in->getShape()[1];
+ for (int32_t od2 = 0; od2 < this->out->getShape()[2]; od2++)
+ {
+ int32_t id2 = od2 % this->in->getShape()[2];
+ for (int32_t od3 = 0; od3 < this->out->getShape()[3]; od3++)
+ {
+ int32_t id3 = od3 % this->in->getShape()[3];
+ this->out->getTensor()(od0, od1, od2, od3) = this->in->getTensor()(id0, id1, id2, id3);
+ }
+ }
+ }
+ }
+
+ return GraphNode::eval();
+}
+
+template <int Rank, DType Dtype>
+OpTranspose<Rank, Dtype>::OpTranspose(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+ : GraphNode(Op_TRANSPOSE, id_)
+{
+ setRequiredOperands(2, 1);
+ setRequiredRank(0, 6);
+}
+
+template <int Rank, DType Dtype>
+OpTranspose<Rank, Dtype>::~OpTranspose()
+{}
+
+template <int Rank, DType Dtype>
+int OpTranspose<Rank, Dtype>::checkTensorAttributes()
+{
+ if (validateRequiredOperands())
+ return 1;
+
+ if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
+ {
+ return 1;
+ }
+
+ // output and input must be the same types
+ if (inputs[0]->matchRankType(*outputs[0]))
+ {
+ printNodeValidationError("Failure to match input and output rank and type");
+ return 1;
+ }
+
+ if (inputs[0]->getElementCount() != outputs[0]->getElementCount())
+ {
+ printNodeValidationError("Failure to match input and output total element count");
+ return 1;
+ }
+
+ in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
+ out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
+ perm_tensor = dynamic_cast<TosaReference::TensorTemplate<ETensor1<int32_t>>*>(inputs[1]);
+
+ return 0;
+}
+
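+// Copy the permutation out of the rank-1 INT32 perm tensor and apply it with Eigen's shuffle();
+// in Eigen's convention, output dimension d takes its extent and data from input dimension
+// perm_array[d].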
+template <int Rank, DType Dtype>
+int OpTranspose<Rank, Dtype>::eval()
+{
+ for (int32_t d = 0; d < Rank; d++)
+ {
+ perm_array[d] = this->perm_tensor->getTensor().data()[d];
+ }
+
+ out->getTensor() = in->getTensor().shuffle(perm_array);
+
+ return GraphNode::eval();
+}
+
+// template explicit instantiation
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, FLOAT)
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, AINT8)
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, INT8)
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, INT16)
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, INT32)
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, BOOL)
+
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, FLOAT);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, AINT8);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, INT8);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, INT16);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, INT32);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, BOOL);
+
+DEF_INSTANTIATE_RESHAPE(OpReshape, FLOAT);
+DEF_INSTANTIATE_RESHAPE(OpReshape, AINT8);
+DEF_INSTANTIATE_RESHAPE(OpReshape, INT8);
+DEF_INSTANTIATE_RESHAPE(OpReshape, INT16);
+DEF_INSTANTIATE_RESHAPE(OpReshape, INT32);
+DEF_INSTANTIATE_RESHAPE(OpReshape, BOOL);
+
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, FLOAT);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, AINT8);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT8);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT16);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT32);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, BOOL);
+
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, FLOAT);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, AINT8);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, INT8);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, INT16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, INT32);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, BOOL);
+
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, FLOAT);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, AINT8);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, INT8);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, INT16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, INT32);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, BOOL);
+
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, FLOAT);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, AINT8);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, INT8);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, INT16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, INT32);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, BOOL);
diff --git a/reference_model/src/ops/data_layout.h b/reference_model/src/ops/data_layout.h
new file mode 100644
index 0000000..100bd6b
--- /dev/null
+++ b/reference_model/src/ops/data_layout.h
@@ -0,0 +1,216 @@
+
+// Copyright (c) 2020, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef OPS_DATA_LAYOUT_H
+#define OPS_DATA_LAYOUT_H
+
+#include "graph_node.h"
+
+using namespace tosa;
+
+namespace TosaReference
+{
+
+template <int Rank, DType Dtype>
+class OpConcat : public GraphNode
+{
+public:
+ OpConcat(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
+ virtual ~OpConcat();
+
+ virtual int checkTensorAttributes();
+ virtual int eval();
+
+ using InEigenType = typename GetEigenType<Dtype>::type;
+ using OutEigenType = typename GetEigenType<Dtype>::type;
+ using TIn = Eigen::Tensor<InEigenType, Rank>;
+ using TOut = Eigen::Tensor<OutEigenType, Rank>;
+
+protected:
+ Eigen::array<int, Rank> reverser;
+ TosaReference::TensorTemplate<TIn>* lhs;
+ TosaReference::TensorTemplate<TIn>* rhs;
+ TosaAxisAttribute* attribute;
+ TosaReference::TensorTemplate<TOut>* out;
+};
+
+template <int Rank, DType Dtype>
+class OpPad : public GraphNode
+{
+public:
+ OpPad(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
+ virtual ~OpPad();
+ virtual int checkTensorAttributes();
+ virtual int eval();
+
+ using InEigenType = typename GetEigenType<Dtype>::type;
+ using OutEigenType = typename GetEigenType<Dtype>::type;
+ using TIn = Eigen::Tensor<InEigenType, Rank>;
+ using TOut = Eigen::Tensor<OutEigenType, Rank>;
+
+protected:
+ Eigen::array<std::pair<ptrdiff_t, ptrdiff_t>, Rank> paddings_array;
+ TosaReference::TensorTemplate<TIn>* in;
+ TosaReference::TensorTemplate<TOut>* out;
+ TosaPadQuantInfo* qinfo;
+};
+
+template <int InRank, int OutRank, DType Dtype>
+class OpReshape : public GraphNode
+{
+public:
+ OpReshape(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
+ virtual ~OpReshape();
+
+ virtual int checkTensorAttributes();
+ virtual int eval();
+
+ using InEigenType = typename GetEigenType<Dtype>::type;
+ using OutEigenType = typename GetEigenType<Dtype>::type;
+ using TIn = Eigen::Tensor<InEigenType, InRank>;
+ using TOut = Eigen::Tensor<OutEigenType, OutRank>;
+
+protected:
+ Eigen::array<Eigen::Index, OutRank> array_shape;
+ Eigen::array<Eigen::Index, InRank> in_reverser;
+ Eigen::array<Eigen::Index, OutRank> out_reverser;
+ TosaReference::TensorTemplate<TIn>* in;
+ TosaReshapeAttribute* attribute;
+ TosaReference::TensorTemplate<TOut>* out;
+};
+
+template <int Rank, DType Dtype>
+class OpReverse : public GraphNode
+{
+public:
+ OpReverse(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
+ virtual ~OpReverse();
+
+ virtual int checkTensorAttributes();
+ virtual int eval();
+
+ using InEigenType = typename GetEigenType<Dtype>::type;
+ using OutEigenType = typename GetEigenType<Dtype>::type;
+ using TIn = Eigen::Tensor<InEigenType, Rank>;
+ using TOut = Eigen::Tensor<OutEigenType, Rank>;
+
+protected:
+ TosaAxisAttribute* attribute;
+ TosaReference::TensorTemplate<TIn>* in;
+ TosaReference::TensorTemplate<TOut>* out;
+ Eigen::array<bool, Rank> reverse_array;
+};
+
+template <int Rank, DType Dtype>
+class OpSlice : public GraphNode
+{
+public:
+ OpSlice(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
+ virtual ~OpSlice();
+
+ virtual int checkTensorAttributes();
+ virtual int eval();
+
+ using InEigenType = typename GetEigenType<Dtype>::type;
+ using OutEigenType = typename GetEigenType<Dtype>::type;
+ using TIn = Eigen::Tensor<InEigenType, Rank>;
+ using TOut = Eigen::Tensor<OutEigenType, Rank>;
+
+protected:
+ TosaSliceAttribute* attribute;
+ Eigen::array<Eigen::Index, Rank> begin_array;
+ Eigen::array<Eigen::Index, Rank> size_array;
+ TosaReference::TensorTemplate<TIn>* in;
+ TosaReference::TensorTemplate<TOut>* out;
+};
+
+template <int Rank, DType Dtype>
+class OpTileBase : public GraphNode
+{
+public:
+ OpTileBase(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
+ virtual ~OpTileBase();
+
+ virtual int checkTensorAttributes();
+
+ using InEigenType = typename GetEigenType<Dtype>::type;
+ using OutEigenType = typename GetEigenType<Dtype>::type;
+ using TIn = Eigen::Tensor<InEigenType, Rank>;
+ using TOut = Eigen::Tensor<OutEigenType, Rank>;
+
+protected:
+ TosaTileAttribute* attribute;
+ TosaReference::TensorTemplate<TIn>* in;
+ TosaReference::TensorTemplate<TOut>* out;
+};
+
+// primary template for op tile
+template <int Rank, DType Dtype>
+class OpTile : public OpTileBase<Rank, Dtype>
+{
+public:
+ OpTile(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+ : OpTileBase<Rank, Dtype>(attribute_, qinfo_, id_)
+ {}
+
+protected:
+ virtual int eval();
+};
+
+// partial specialization for specific rank
+#define DEF_OP_TILE_RANK(N) \
+ template <DType Dtype> \
+ class OpTile<N, Dtype> : public OpTileBase<N, Dtype> \
+ { \
+ public: \
+ OpTile(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) \
+ : OpTileBase<N, Dtype>(attribute_, qinfo_, id_) \
+ {} \
+ \
+ protected: \
+ virtual int eval(); \
+ };
+
+DEF_OP_TILE_RANK(1)
+DEF_OP_TILE_RANK(2)
+DEF_OP_TILE_RANK(3)
+DEF_OP_TILE_RANK(4)
+
+#undef DEF_OP_TILE_RANK
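+// Only ranks 1-4 are specialized above; the remaining instantiated ranks fall back to the
+// primary template, whose eval() reports the case as not implemented.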
+
+template <int Rank, DType Dtype>
+class OpTranspose : public GraphNode
+{
+public:
+ OpTranspose(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
+ virtual ~OpTranspose();
+
+ virtual int checkTensorAttributes();
+ virtual int eval();
+
+ using InEigenType = typename GetEigenType<Dtype>::type;
+ using OutEigenType = typename GetEigenType<Dtype>::type;
+ using TIn = Eigen::Tensor<InEigenType, Rank>;
+ using TOut = Eigen::Tensor<OutEigenType, Rank>;
+
+protected:
+ Eigen::array<int, Rank> perm_array;
+ TosaReference::TensorTemplate<TIn>* in;
+ TosaReference::TensorTemplate<ETensor1<int32_t>>* perm_tensor;
+ TosaReference::TensorTemplate<TOut>* out;
+};
+}; // namespace TosaReference
+
+#endif
diff --git a/reference_model/src/ops/data_nodes.cc b/reference_model/src/ops/data_nodes.cc
new file mode 100644
index 0000000..2ee4935
--- /dev/null
+++ b/reference_model/src/ops/data_nodes.cc
@@ -0,0 +1,172 @@
+
+// Copyright (c) 2020, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "data_nodes.h"
+
+using namespace TosaReference;
+using namespace Eigen;
+using namespace tosa;
+
+OpConst::OpConst(uint64_t id_)
+ : GraphNode(Op_CONST, id_)
+{
+ setRequiredOperands(0, 1);
+}
+
+OpConst::~OpConst()
+{}
+
+int OpConst::checkTensorAttributes()
+{
+ if (validateRequiredOperands())
+ return 1;
+
+ return 0;
+}
+
+int OpConst::eval()
+{
+ // Evaluation is trivial for constants
+ return GraphNode::eval();
+}
+
+OpPlaceholder::OpPlaceholder(uint64_t id_)
+ : GraphNode(Op_PLACEHOLDER, id_)
+{
+ setRequiredOperands(0, 1);
+}
+
+OpPlaceholder::~OpPlaceholder()
+{}
+
+int OpPlaceholder::checkTensorAttributes()
+{
+ if (validateRequiredOperands())
+ return 1;
+
+ return 0;
+}
+
+int OpPlaceholder::eval()
+{
+ // Evaluation is trivial for placeholders
+ return GraphNode::eval();
+}
+
+template <int Rank, DType Dtype>
+OpIdentity<Rank, Dtype>::OpIdentity(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+ : GraphNode(Op_IDENTITY, id_)
+{
+ setRequiredOperands(1, 1);
+ setRequiredRank(0, 6);
+}
+
+template <int Rank, DType Dtype>
+OpIdentity<Rank, Dtype>::~OpIdentity()
+{}
+
+template <int Rank, DType Dtype>
+int OpIdentity<Rank, Dtype>::checkTensorAttributes()
+{
+
+ if (inputs.size() != outputs.size())
+ {
+ printNodeValidationError("Input and output tensor list lengths are not equal");
+ return 1;
+ }
+
+ in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
+ out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
+
+ if (in->matchRankTypeShape(*out))
+ {
+ printNodeValidationError("Input and output tensor rank, type, or shape do not match");
+ return 1;
+ }
+
+ return 0;
+}
+
+template <int Rank, DType Dtype>
+int OpIdentity<Rank, Dtype>::eval()
+{
+ out->getTensor() = in->getTensor();
+
+ return GraphNode::eval();
+}
+
+template <int Rank, DType Dtype>
+OpIdentityN<Rank, Dtype>::OpIdentityN(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+ : GraphNode(Op_IDENTITYN, id_)
+{
+ setRequiredRank(0, 6);
+}
+
+template <int Rank, DType Dtype>
+OpIdentityN<Rank, Dtype>::~OpIdentityN()
+{}
+
+template <int Rank, DType Dtype>
+int OpIdentityN<Rank, Dtype>::checkTensorAttributes()
+{
+
+ if (inputs.size() != outputs.size())
+ {
+ printNodeValidationError("Input and output tensor list lengths are not equal");
+ return 1;
+ }
+
+ for (size_t i = 0; i < inputs.size(); i++)
+ {
+ ins.push_back(dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[i]));
+ outs.push_back(dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[i]));
+
+ if (ins[i]->matchRankTypeShape(*outs[i]))
+ {
+ printNodeValidationError("Input and output tensor rank, type, or shape do not match");
+ return 1;
+ }
+ }
+
+ return 0;
+}
+
+template <int Rank, DType Dtype>
+int OpIdentityN<Rank, Dtype>::eval()
+{
+ for (size_t i = 0; i < ins.size(); i++)
+ {
+ outs[i]->getTensor() = ins[i]->getTensor();
+ }
+
+ return GraphNode::eval();
+}
+
+// template explicit instantiation
+// note OpConst and OpPlaceholder are not templated
+
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, FLOAT);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, AINT8);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, INT8);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, INT16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, INT32);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, BOOL);
+
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentityN, FLOAT);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentityN, AINT8);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentityN, INT8);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentityN, INT16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentityN, INT32);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentityN, BOOL);
diff --git a/reference_model/src/ops/data_nodes.h b/reference_model/src/ops/data_nodes.h
new file mode 100644
index 0000000..bec4669
--- /dev/null
+++ b/reference_model/src/ops/data_nodes.h
@@ -0,0 +1,86 @@
+
+// Copyright (c) 2020, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef OPS_DATA_NODES_H
+#define OPS_DATA_NODES_H
+
+#include "graph_node.h"
+
+namespace TosaReference
+{
+
+class OpConst : public GraphNode
+{
+public:
+ OpConst(uint64_t id_);
+ virtual ~OpConst();
+
+ virtual int checkTensorAttributes();
+ virtual int eval();
+};
+
+class OpPlaceholder : public GraphNode
+{
+public:
+ OpPlaceholder(uint64_t id_);
+ virtual ~OpPlaceholder();
+
+ virtual int checkTensorAttributes();
+ virtual int eval();
+};
+
+template <int Rank, DType Dtype>
+class OpIdentity : public GraphNode
+{
+public:
+ OpIdentity(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
+ virtual ~OpIdentity();
+
+ virtual int checkTensorAttributes();
+ virtual int eval();
+
+ using InEigenType = typename GetEigenType<Dtype>::type;
+ using OutEigenType = typename GetEigenType<Dtype>::type;
+ using TIn = Eigen::Tensor<InEigenType, Rank>;
+ using TOut = Eigen::Tensor<OutEigenType, Rank>;
+
+protected:
+ TosaReference::TensorTemplate<TIn>* in;
+ TosaReference::TensorTemplate<TOut>* out;
+};
+
+template <int Rank, DType Dtype>
+class OpIdentityN : public GraphNode
+{
+public:
+ OpIdentityN(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
+ virtual ~OpIdentityN();
+
+ virtual int checkTensorAttributes();
+ virtual int eval();
+
+ using InEigenType = typename GetEigenType<Dtype>::type;
+ using OutEigenType = typename GetEigenType<Dtype>::type;
+ using TIn = Eigen::Tensor<InEigenType, Rank>;
+ using TOut = Eigen::Tensor<OutEigenType, Rank>;
+
+protected:
+ std::vector<TosaReference::TensorTemplate<TIn>*> ins;
+ std::vector<TosaReference::TensorTemplate<TOut>*> outs;
+};
+
+}; // namespace TosaReference
+
+#endif
diff --git a/reference_model/src/ops/ewise_binary.cc b/reference_model/src/ops/ewise_binary.cc
new file mode 100644
index 0000000..4d4f8b9
--- /dev/null
+++ b/reference_model/src/ops/ewise_binary.cc
@@ -0,0 +1,586 @@
+
+// Copyright (c) 2020, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "ewise_binary.h"
+#include "arith_util.h"
+#include "quant_util.h"
+#include "template_types.h"
+
+using namespace TosaReference;
+using namespace Eigen;
+using namespace tosa;
+
+template <int Rank, DType InDtype, DType OutDtype>
+BinaryNodeBase<Rank, InDtype, OutDtype>::BinaryNodeBase(const Op& op_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+ : GraphNode(op_, id_)
+{
+ setRequiredOperands(2, 1);
+ setRequiredRank(0, 6);
+
+ a_rank = b_rank = max_input_rank = -1;
+ a = b = nullptr;
+ a_rank0 = b_rank0 = nullptr;
+ result = nullptr;
+
+ fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return OutEigenType(); };
+}
+
+template <int Rank, DType InDtype, DType OutDtype>
+BinaryNodeBase<Rank, InDtype, OutDtype>::~BinaryNodeBase()
+{}
+
+template <int Rank, DType InDtype, DType OutDtype>
+int BinaryNodeBase<Rank, InDtype, OutDtype>::checkTensorAttributes()
+{
+ if (validateRequiredOperands())
+ return 1;
+
+ if (validateRequiredRank(inputs[0]) || validateRequiredRank(inputs[1]) || validateRequiredRank(outputs[0]))
+ {
+ return 1;
+ }
+
+ a_rank = inputs[0]->getRank();
+ b_rank = inputs[1]->getRank();
+ if (a_rank != 0 && b_rank != 0 && a_rank != b_rank)
+ {
+ printNodeValidationError("Binary operator input ranks must match");
+ return 1;
+ }
+
+ max_input_rank = a_rank >= b_rank ? a_rank : b_rank;
+
+ // A & B must be the same types
+ if (inputs[0]->matchType(*inputs[1]))
+ {
+ printNodeValidationError("Binary operator input types must match");
+ return 1;
+ }
+
+ // Result's geometry must match, but the type may be wider
+ if (outputs[0]->getRank() != max_input_rank)
+ {
+        printNodeValidationError("Binary operator input and output geometry must match");
+ return 1;
+ }
+
+ if (a_rank == max_input_rank)
+ {
+ a = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
+ }
+ else
+ {
+ a_rank0 = dynamic_cast<TosaReference::TensorTemplate<ETensor0<InEigenType>>*>(inputs[0]);
+ }
+
+ if (b_rank == max_input_rank)
+ {
+ b = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[1]);
+ }
+ else
+ {
+ b_rank0 = dynamic_cast<TosaReference::TensorTemplate<ETensor0<InEigenType>>*>(inputs[1]);
+ }
+
+ result = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
+
+ // either a or b can be rank0
+ // a_rank0 and b_rank0 can't be valid at the same time.
+    // if a and b are both rank 0, they are evaluated as 'a' and 'b', not as 'a_rank0' and 'b_rank0'
+ ASSERT_MEM((a || a_rank0) && (b || b_rank0) && !(a_rank0 && b_rank0) && result);
+
+ return 0;
+}
+
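+// Compute the per-dimension broadcast factors Eigen needs: where an input extent is 1 but the
+// output extent is larger, the factor is the output extent, otherwise 1. A rank-0 operand is
+// treated as shape [1, ..., 1]. For example, a_shape = [1, 3] with output [2, 3] gives
+// bcast_a = [2, 1].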
+template <int Rank, DType InDtype, DType OutDtype>
+int BinaryNodeBase<Rank, InDtype, OutDtype>::broadcast()
+{
+ auto output_shape = result->getTensor().dimensions();
+
+ std::vector<int> a_shape, b_shape;
+
+ if (a_rank == max_input_rank)
+ {
+ a_shape = a->getShape();
+ }
+ else
+ {
+ a_shape.assign(max_input_rank, 1);
+ }
+
+ if (b_rank == max_input_rank)
+ {
+ b_shape = b->getShape();
+ }
+ else
+ {
+ b_shape.assign(max_input_rank, 1);
+ }
+
+ for (int i = 0; i < max_input_rank; i++)
+ {
+ if (a_shape[i] != output_shape[i] && a_shape[i] == 1)
+ {
+ bcast_a[i] = output_shape[i];
+ }
+ else
+ {
+ bcast_a[i] = 1;
+ }
+ if (b_shape[i] != output_shape[i] && b_shape[i] == 1)
+ {
+ bcast_b[i] = output_shape[i];
+ }
+ else
+ {
+ bcast_b[i] = 1;
+ }
+ }
+
+ return 0;
+}
+
+template <int Rank, DType InDtype, DType OutDtype>
+int BinaryNode<Rank, InDtype, OutDtype>::eval()
+{
+ this->broadcast();
+
+ Eigen::array<int, Rank> reshaper;
+ reshaper.fill(1);
+ TIn ia, ib;
+
+ if (this->a_rank == this->max_input_rank)
+ {
+ ia = this->a->getTensor().broadcast(this->bcast_a);
+ }
+ else
+ {
+ ia = this->a_rank0->getTensor().reshape(reshaper).broadcast(this->bcast_a);
+ }
+
+ if (this->b_rank == this->max_input_rank)
+ {
+ ib = this->b->getTensor().broadcast(this->bcast_b);
+ }
+ else
+ {
+ ib = this->b_rank0->getTensor().reshape(reshaper).broadcast(this->bcast_b);
+ }
+
+ this->result->getTensor() = ia.binaryExpr(ib, this->fcn);
+
+ return GraphNode::eval();
+}
+
+// we still need a partial specialization for rank 0, or Eigen will trigger a static assertion
+template <DType InDtype, DType OutDtype>
+int BinaryNode<0, InDtype, OutDtype>::eval()
+{
+ this->result->getTensor() = this->a->getTensor().binaryExpr(this->b->getTensor(), this->fcn);
+
+ return GraphNode::eval();
+}
+
+template <int Rank, DType Dtype>
+int OpAdd<Rank, Dtype>::register_fcn()
+{
+ switch (InDtype)
+ {
+ case DType_FLOAT:
+ case DType_INT32:
+ this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a + b; };
+ break;
+ default:
+ FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[InDtype]);
+ }
+
+ return 0;
+}
+
+template <int Rank, DType Dtype>
+int OpArithmeticRightShift<Rank, Dtype>::register_fcn()
+{
+ int32_t num_bits = 0;
+ switch (Dtype)
+ {
+ case DType_INT8:
+ num_bits = 8;
+ break;
+ case DType_INT16:
+ num_bits = 16;
+ break;
+ case DType_INT32:
+ num_bits = 32;
+ break;
+ default:
+ FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
+ }
+
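+    // Emulate an arithmetic (sign-preserving) shift on the low num_bits bits: when the sign bit
+    // of 'a' is set, force the top 'b' result bits to 1, otherwise force them to 0 (assuming
+    // ONES_MASK(n) produces n low-order one bits). E.g. with num_bits = 8, shifting
+    // 0b1111'0000 (-16) right by 2 gives 0b1111'1100 (-4).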
+ this->fcn = [num_bits](InEigenType a, InEigenType b) -> OutEigenType {
+ uint32_t sign = a & (1 << (num_bits - 1));
+ uint32_t ones_mask = ONES_MASK(b) << (num_bits - b);
+ if (sign)
+ return ones_mask | (a >> b);
+ else
+ return (~ones_mask) & (a >> b);
+ };
+
+ return 0;
+}
+
+template <int Rank, DType Dtype>
+int OpBitwiseAnd<Rank, Dtype>::register_fcn()
+{
+ switch (Dtype)
+ {
+ case DType_AINT8:
+ case DType_INT16:
+ case DType_INT32:
+ this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a & b; };
+ break;
+ default:
+ FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
+ }
+
+ return 0;
+}
+
+template <int Rank, DType Dtype>
+int OpBitwiseOr<Rank, Dtype>::register_fcn()
+{
+ switch (Dtype)
+ {
+ case DType_AINT8:
+ case DType_INT16:
+ case DType_INT32:
+ this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a | b; };
+ break;
+ default:
+ FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
+ }
+
+ return 0;
+}
+
+template <int Rank, DType Dtype>
+int OpBitwiseXor<Rank, Dtype>::register_fcn()
+{
+ switch (Dtype)
+ {
+ case DType_AINT8:
+ case DType_INT16:
+ case DType_INT32:
+ this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a ^ b; };
+ break;
+ default:
+ FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
+ }
+
+ return 0;
+}
+
+template <int Rank, DType Dtype>
+int OpLogicalAnd<Rank, Dtype>::register_fcn()
+{
+ switch (Dtype)
+ {
+ case DType_BOOL:
+ this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a && b; };
+ break;
+ default:
+ FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
+ }
+
+ return 0;
+}
+
+template <int Rank, DType Dtype>
+int OpLogicalLeftShift<Rank, Dtype>::register_fcn()
+{
+ switch (Dtype)
+ {
+ case DType_INT8:
+ case DType_INT16:
+ case DType_INT32:
+ this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a << b; };
+ break;
+ default:
+ FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
+ }
+
+ return 0;
+}
+
+template <int Rank, DType Dtype>
+int OpLogicalRightShift<Rank, Dtype>::register_fcn()
+{
+ int32_t num_bits = 0;
+ switch (Dtype)
+ {
+ case DType_INT8:
+ num_bits = 8;
+ break;
+ case DType_INT16:
+ num_bits = 16;
+ break;
+ case DType_INT32:
+ num_bits = 32;
+ break;
+ default:
+ FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
+ }
+
+ this->fcn = [num_bits](InEigenType a, InEigenType b) -> OutEigenType {
+ uint32_t mask = ONES_MASK(num_bits) >> b;
+ return (a >> b) & mask;
+ };
+
+ return 0;
+}
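+// Worked example, assuming ONES_MASK(n) yields a value with the n low bits
+// set: for INT8 (num_bits = 8), a = -16 (0xF0) and b = 2, mask = 0xFF >> 2 =
+// 0x3F and the result is (a >> 2) & 0x3F = 0x3C = 60, i.e. the 8-bit pattern
+// shifted right with zeros filling the vacated high bits.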
+
+template <int Rank, DType Dtype>
+int OpLogicalOr<Rank, Dtype>::register_fcn()
+{
+ switch (Dtype)
+ {
+ case DType_BOOL:
+ this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a || b; };
+ break;
+ default:
+ FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
+ }
+
+ return 0;
+}
+
+template <int Rank, DType Dtype>
+int OpLogicalXor<Rank, Dtype>::register_fcn()
+{
+ switch (Dtype)
+ {
+ case DType_BOOL:
+ this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a ^ b; };
+ break;
+ default:
+ FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
+ }
+
+ return 0;
+}
+
+template <int Rank, DType Dtype>
+int OpMaximum<Rank, Dtype>::register_fcn()
+{
+ switch (Dtype)
+ {
+ case DType_FLOAT:
+ case DType_INT32:
+ this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a > b ? a : b; };
+ break;
+ default:
+ FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
+ }
+
+ return 0;
+}
+
+template <int Rank, DType Dtype>
+int OpMinimum<Rank, Dtype>::register_fcn()
+{
+ switch (Dtype)
+ {
+ case DType_FLOAT:
+ case DType_INT32:
+ this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a < b ? a : b; };
+ break;
+ default:
+ FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
+ }
+
+ return 0;
+}
+
+template <int Rank, DType InDtype, DType OutDtype>
+int OpMul<Rank, InDtype, OutDtype>::register_fcn()
+{
+ switch (InDtype)
+ {
+ case DType_FLOAT:
+ case DType_INT32:
+ this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a * b; };
+ break;
+ case DType_INT8:
+ case DType_INT16:
+ this->fcn = [this](InEigenType lhs, InEigenType rhs) -> OutEigenType {
+ OutEigenType raw_output = (OutEigenType)lhs * (OutEigenType)rhs;
+
+ OutEigenType clamped_output = std::min<OutEigenType>(QMax, std::max<OutEigenType>(raw_output, QMin));
+
+ return clamped_output;
+ };
+ break;
+ default:
+ FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[InDtype]);
+ }
+
+ return 0;
+}
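+// For the INT8/INT16 cases the product is widened to OutEigenType (INT32 in
+// the instantiations below) and clamped to [QMin, QMax]; e.g. INT16 inputs
+// lhs = 300 and rhs = 200 give raw_output = 60000, which lies inside the
+// assumed int32 limits from GetQMin/GetQMax<INT32>, so clamped_output = 60000.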
+
+template <int Rank, DType Dtype>
+int OpPow<Rank, Dtype>::register_fcn()
+{
+ switch (Dtype)
+ {
+ case DType_FLOAT:
+ this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return powf(a, b); };
+ break;
+ default:
+ FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
+ }
+
+ return 0;
+}
+
+template <int Rank, DType Dtype>
+int OpSub<Rank, Dtype>::register_fcn()
+{
+ switch (InDtype)
+ {
+ case DType_FLOAT:
+ case DType_INT32:
+ this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a - b; };
+ break;
+ default:
+ FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[InDtype]);
+ }
+
+ return 0;
+}
+
+template <int Rank>
+OpTable<Rank>::OpTable(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+ : GraphNode(Op_TABLE, id_)
+{
+ setRequiredOperands(2, 1);
+ setRequiredRank(0, 6);
+}
+
+template <int Rank>
+OpTable<Rank>::~OpTable()
+{}
+
+template <int Rank>
+int OpTable<Rank>::checkTensorAttributes()
+{
+ if (validateRequiredOperands())
+ return 1;
+
+ if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
+ {
+ return 1;
+ }
+
+ if (inputs[1]->getRank() != 1 || inputs[1]->getElementCount() != 513 || inputs[1]->getDtype() != DType_INT16)
+ {
+ FATAL_ERROR_NODE("OpTable: must have INT16 table with 513 entries");
+ return 1;
+ }
+
+ in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
+ table = dynamic_cast<TosaReference::TensorTemplate<TTable>*>(inputs[1]);
+ out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
+
+ ASSERT_MEM(in && table && out);
+
+ return 0;
+}
+
+template <int Rank>
+int OpTable<Rank>::eval()
+{
+ this->out->getTensor() = this->in->getTensor().unaryExpr([this](InEigenType in) -> OutEigenType {
+ // 1. clamp the input to the int16 range
+ int32_t input_truncated = std::min<int32_t>(std::max<int32_t>(in, QInMin), QInMax);
+
+ // 2. calculate index and interpolation fraction
+ int32_t index = (input_truncated >> 7) + (1 << (IntegerBits - 1));
+ index = std::min<int32_t>(std::max<int32_t>(index, 0), NumTableEntries - 1); // 9-bit index
+ int32_t frac = input_truncated & 0x7F; // 7-bit fraction
+
+ // 3. interpolate, generate 16.7 (23-bit) output
+ int32_t base = this->table->getTensor()(index);
+ int32_t next = this->table->getTensor()(index + 1);
+ int32_t value = (base << 7) + (next - base) * frac;
+
+ return value;
+ });
+
+ return GraphNode::eval();
+}
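+// Worked example of the lookup above, assuming GetQMin/GetQMax<INT16> give the
+// int16 extremes: for in = 4660 (0x1234), input_truncated = 4660,
+// index = (4660 >> 7) + 256 = 292 and frac = 4660 & 0x7F = 52, so the output
+// is (table[292] << 7) + (table[293] - table[292]) * 52, a 16.7 fixed-point
+// value interpolated between two neighbouring table entries.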
+
+// template explicit instantiation
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, FLOAT);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, INT32);
+
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpArithmeticRightShift, INT8);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpArithmeticRightShift, INT16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpArithmeticRightShift, INT32);
+
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseAnd, AINT8);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseAnd, INT16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseAnd, INT32);
+
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseOr, AINT8);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseOr, INT16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseOr, INT32);
+
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseXor, AINT8);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseXor, INT16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseXor, INT32);
+
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalAnd, BOOL);
+
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalLeftShift, INT8);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalLeftShift, INT16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalLeftShift, INT32);
+
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalRightShift, INT8);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalRightShift, INT16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalRightShift, INT32);
+
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalOr, BOOL);
+
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalXor, BOOL);
+
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, FLOAT);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, INT32);
+
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, FLOAT);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, INT32);
+
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, FLOAT, FLOAT);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT8, INT32);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT16, INT32);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT32, INT32);
+
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpPow, FLOAT);
+
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, FLOAT);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, INT32);
+
+DEF_INSTANTIATE_ONE_RANK_0_6(OpTable);
+
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNode, FLOAT, BOOL);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNode, INT32, BOOL);
diff --git a/reference_model/src/ops/ewise_binary.h b/reference_model/src/ops/ewise_binary.h
new file mode 100644
index 0000000..00fb3d9
--- /dev/null
+++ b/reference_model/src/ops/ewise_binary.h
@@ -0,0 +1,195 @@
+
+// Copyright (c) 2020, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef OPS_EWISE_BINARY_H
+#define OPS_EWISE_BINARY_H
+
+#include "graph_node.h"
+
+using namespace tosa;
+
+namespace TosaReference
+{
+
+// class BinaryNodeBase: holds the functionality common to all binary nodes.
+// when a binary op is created, the virtual OpXXX::register_fcn() is called
+// and 'fcn' is registered as a lambda taking the two input elements
+// class BinaryNode: a level of indirection so the template can be partially
+// specialized for rank 0. eval() called from the top level applies
+// .binaryExpr(fcn) here; this must be partially specialized or the compiler
+// fails a static assertion when trying to broadcast a rank-0 tensor
+// class OpXXX: implements the per-element lambda for each data type
+// unlike BinaryNode, this does not need partial specialization
+
+// Eigen::Tensor natively supports some binary element-wise operations
+// (e.g. cwiseMax, operator+), which may be faster since they can map onto
+// SIMD instructions. Registering a lambda and calling .binaryExpr() may
+// sacrifice some performance, but it avoids partial specialization for every
+// combination of {rankN, rank0} x {FLOAT/INT32, QU8, ...}. Revisit if this
+// becomes a performance bottleneck.
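+//
+// For example, OpAdd<Rank, DType_INT32>::register_fcn() in ewise_binary.cc
+// installs
+//     this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a + b; };
+// and BinaryNode<Rank, InDtype, OutDtype>::eval() broadcasts both inputs to
+// the output shape and applies it element-wise via
+//     this->result->getTensor() = ia.binaryExpr(ib, this->fcn);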
+template <int Rank, DType InDtype, DType OutDtype>
+class BinaryNodeBase : public GraphNode
+{
+public:
+ BinaryNodeBase(const Op& nodeType, TosaQuantInfoBase* qinfo_, const uint64_t id_);
+ virtual ~BinaryNodeBase();
+
+ virtual int checkTensorAttributes() final;
+ virtual int eval() = 0;
+ virtual int register_fcn() = 0;
+
+ using InEigenType = typename GetEigenType<InDtype>::type;
+ using OutEigenType = typename GetEigenType<OutDtype>::type;
+ using TIn = Eigen::Tensor<InEigenType, Rank>;
+ using TOut = Eigen::Tensor<OutEigenType, Rank>;
+
+protected:
+ int broadcast();
+
+protected:
+ std::function<OutEigenType(InEigenType, InEigenType)> fcn;
+ Eigen::array<int, Rank> bcast_a;
+ Eigen::array<int, Rank> bcast_b;
+ TosaReference::TensorTemplate<TIn>* a;
+ TosaReference::TensorTemplate<TIn>* b;
+ TosaReference::TensorTemplate<ETensor0<InEigenType>>* a_rank0;
+ TosaReference::TensorTemplate<ETensor0<InEigenType>>* b_rank0;
+ TosaReference::TensorTemplate<TOut>* result;
+ int a_rank;
+ int b_rank;
+ int max_input_rank;
+};
+
+// primary class
+template <int Rank, DType InDtype, DType OutDtype>
+class BinaryNode : public BinaryNodeBase<Rank, InDtype, OutDtype>
+{
+public:
+ BinaryNode(const Op& op_, TosaQuantInfoBase* qinfo_, const uint64_t id_)
+ : BinaryNodeBase<Rank, InDtype, OutDtype>(op_, qinfo_, id_)
+ {}
+ virtual ~BinaryNode()
+ {}
+
+ virtual int eval();
+
+ using InEigenType = typename GetEigenType<InDtype>::type;
+ using OutEigenType = typename GetEigenType<OutDtype>::type;
+ using TIn = Eigen::Tensor<InEigenType, Rank>;
+ using TOut = Eigen::Tensor<OutEigenType, Rank>;
+};
+
+// partial specialization for rank 0
+template <DType InDtype, DType OutDtype>
+class BinaryNode<0, InDtype, OutDtype> : public BinaryNodeBase<0, InDtype, OutDtype>
+{
+public:
+ BinaryNode(const Op& op_, TosaQuantInfoBase* qinfo_, const uint64_t id_)
+ : BinaryNodeBase<0, InDtype, OutDtype>(op_, qinfo_, id_)
+ {}
+ virtual ~BinaryNode()
+ {}
+
+ virtual int eval();
+};
+
+#define DEF_TEMPLATE_BINARY_OP_ONE_TYPE(Opname, OPNAME) \
+ template <int Rank, DType Dtype> \
+ class Op##Opname : public BinaryNode<Rank, Dtype, Dtype> \
+ { \
+ public: \
+ Op##Opname(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) \
+ : BinaryNode<Rank, Dtype, Dtype>(Op_##OPNAME, qinfo_, id_) \
+ { \
+ register_fcn(); \
+ } \
+ static constexpr DType InDtype = Dtype; \
+ static constexpr DType OutDtype = Dtype; \
+ using InEigenType = typename GetEigenType<InDtype>::type; \
+ using OutEigenType = typename GetEigenType<OutDtype>::type; \
+ virtual int register_fcn(); \
+ };
+
+#define DEF_TEMPLATE_BINARY_OP_TWO_TYPE(Opname, OPNAME) \
+ template <int Rank, DType InDtype, DType OutDtype> \
+ class Op##Opname : public BinaryNode<Rank, InDtype, OutDtype> \
+ { \
+ public: \
+ Op##Opname(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) \
+ : BinaryNode<Rank, InDtype, OutDtype>(Op_##OPNAME, qinfo_, id_) \
+ { \
+ register_fcn(); \
+ } \
+ static constexpr int32_t QMin = GetQMin<OutDtype>::value; \
+ static constexpr int32_t QMax = GetQMax<OutDtype>::value; \
+ using InEigenType = typename GetEigenType<InDtype>::type; \
+ using OutEigenType = typename GetEigenType<OutDtype>::type; \
+ virtual int register_fcn(); \
+ };
+
+DEF_TEMPLATE_BINARY_OP_ONE_TYPE(Add, ADD)
+DEF_TEMPLATE_BINARY_OP_ONE_TYPE(ArithmeticRightShift, ARITHMETIC_RIGHT_SHIFT)
+DEF_TEMPLATE_BINARY_OP_ONE_TYPE(BitwiseAnd, BITWISE_AND)
+DEF_TEMPLATE_BINARY_OP_ONE_TYPE(BitwiseOr, BITWISE_OR)
+DEF_TEMPLATE_BINARY_OP_ONE_TYPE(BitwiseXor, BITWISE_XOR)
+DEF_TEMPLATE_BINARY_OP_ONE_TYPE(LogicalAnd, LOGICAL_AND)
+DEF_TEMPLATE_BINARY_OP_ONE_TYPE(LogicalLeftShift, LOGICAL_LEFT_SHIFT)
+DEF_TEMPLATE_BINARY_OP_ONE_TYPE(LogicalRightShift, LOGICAL_RIGHT_SHIFT)
+DEF_TEMPLATE_BINARY_OP_ONE_TYPE(LogicalOr, LOGICAL_OR)
+DEF_TEMPLATE_BINARY_OP_ONE_TYPE(LogicalXor, LOGICAL_XOR)
+DEF_TEMPLATE_BINARY_OP_ONE_TYPE(Maximum, MAXIMUM)
+DEF_TEMPLATE_BINARY_OP_ONE_TYPE(Minimum, MINIMUM)
+DEF_TEMPLATE_BINARY_OP_TWO_TYPE(Mul, MUL)
+DEF_TEMPLATE_BINARY_OP_ONE_TYPE(Pow, POW)
+DEF_TEMPLATE_BINARY_OP_ONE_TYPE(Sub, SUB)
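+// For reference, DEF_TEMPLATE_BINARY_OP_ONE_TYPE(Add, ADD) above expands to
+// roughly:
+//     template <int Rank, DType Dtype>
+//     class OpAdd : public BinaryNode<Rank, Dtype, Dtype>
+//     {
+//         // constructor calls register_fcn(), defined per type in ewise_binary.cc
+//         virtual int register_fcn();
+//     };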
+
+#undef DEF_TEMPLATE_BINARY_OP_ONE_TYPE
+#undef DEF_TEMPLATE_BINARY_OP_TWO_TYPE
+
+template <int Rank>
+class OpTable : public GraphNode
+{
+public:
+ OpTable(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
+ virtual ~OpTable();
+
+ virtual int checkTensorAttributes();
+ virtual int eval();
+
+ static constexpr DType InDtype = DType_INT16;
+ static constexpr DType TableDtype = DType_INT16;
+ static constexpr DType OutDtype = DType_INT32;
+ using InEigenType = typename GetEigenType<InDtype>::type;
+ using TableEigenType = typename GetEigenType<TableDtype>::type;
+ using OutEigenType = typename GetEigenType<OutDtype>::type;
+ using TIn = Eigen::Tensor<InEigenType, Rank>;
+ using TTable = Eigen::Tensor<TableEigenType, 1>;
+ using TOut = Eigen::Tensor<OutEigenType, Rank>;
+ static constexpr int32_t IntegerBits = 9;
+ static constexpr int32_t FractionBits = 7;
+ static constexpr int32_t NumTableEntries = (1 << IntegerBits);
+ static constexpr int32_t QInMin = GetQMin<InDtype>::value;
+ static constexpr int32_t QInMax = GetQMax<InDtype>::value;
+ static constexpr int32_t QOutMin = GetQMin<OutDtype>::value;
+ static constexpr int32_t QOutMax = GetQMax<OutDtype>::value;
+
+protected:
+ TosaReference::TensorTemplate<TIn>* in;
+ TosaReference::TensorTemplate<TTable>* table;
+ TosaReference::TensorTemplate<TOut>* out;
+};
+
+}; // namespace TosaReference
+
+#endif
diff --git a/reference_model/src/ops/ewise_ternary.cc b/reference_model/src/ops/ewise_ternary.cc
new file mode 100644
index 0000000..eded0d7
--- /dev/null
+++ b/reference_model/src/ops/ewise_ternary.cc
@@ -0,0 +1,115 @@
+
+// Copyright (c) 2020, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "ewise_ternary.h"
+
+using namespace TosaReference;
+using namespace Eigen;
+using namespace tosa;
+
+template <int Rank, DType Dtype>
+OpSelectBase<Rank, Dtype>::OpSelectBase(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+ : GraphNode(Op_SELECT, id_)
+{
+ setRequiredOperands(3, 1);
+ setRequiredRank(0, 6);
+}
+
+template <int Rank, DType Dtype>
+OpSelectBase<Rank, Dtype>::~OpSelectBase()
+{}
+
+template <int Rank, DType Dtype>
+int OpSelectBase<Rank, Dtype>::checkTensorAttributes()
+{
+ if (validateRequiredOperands())
+ return 1;
+
+ if (validateRequiredRank(inputs[0]) || validateRequiredRank(inputs[1]) || validateRequiredRank(inputs[2]) ||
+ validateRequiredRank(outputs[0]))
+ {
+ return 1;
+ }
+
+ // all inputs must match the output rank; then/else values must also match the output type
+ if (inputs[0]->matchRank(*outputs[0]) || inputs[1]->matchRankType(*outputs[0]) ||
+ inputs[2]->matchRankType(*outputs[0]))
+ {
+ printNodeValidationError("Failure to match input and output rank and type");
+ return 1;
+ }
+
+ cond = dynamic_cast<TosaReference::TensorTemplate<TCond>*>(inputs[0]);
+ then_val = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[1]);
+ else_val = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[2]);
+ out = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(outputs[0]);
+
+ return 0;
+}
+
+template <int Rank, DType Dtype>
+int OpSelectBase<Rank, Dtype>::eval()
+{
+ FATAL_ERROR_NODE("shouldn't be called");
+}
+
+template <int Rank, DType Dtype>
+int OpSelect<Rank, Dtype>::broadcast()
+{
+ std::vector<int> cond_shape = this->cond->getShape();
+ std::vector<int> then_shape = this->then_val->getShape();
+ std::vector<int> else_shape = this->else_val->getShape();
+ std::vector<int> out_shape = this->out->getShape();
+
+ for (int i = 0; i < Rank; i++)
+ {
+ this->bcast_cond[i] = (cond_shape[i] == 1) ? std::max(then_shape[i], else_shape[i]) : 1;
+ this->bcast_then[i] = (then_shape[i] == 1) ? std::max(cond_shape[i], else_shape[i]) : 1;
+ this->bcast_else[i] = (else_shape[i] == 1) ? std::max(then_shape[i], cond_shape[i]) : 1;
+ ASSERT_MSG_NODE((this->bcast_cond[i] * cond_shape[i]) == out_shape[i], "SELECT broadcast invariant failed");
+ ASSERT_MSG_NODE((this->bcast_then[i] * then_shape[i]) == out_shape[i], "SELECT broadcast invariant failed");
+ ASSERT_MSG_NODE((this->bcast_else[i] * else_shape[i]) == out_shape[i], "SELECT broadcast invariant failed");
+ }
+
+ return 0;
+}
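+// Worked example of the mutual broadcast above: cond_shape = {1, 3},
+// then_shape = {2, 1}, else_shape = {2, 3} and out_shape = {2, 3} yield
+// bcast_cond = {2, 1}, bcast_then = {1, 3} and bcast_else = {1, 1}; each
+// bcast[i] * shape[i] equals out_shape[i], so the invariants hold.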
+
+template <int Rank, DType Dtype>
+int OpSelect<Rank, Dtype>::eval()
+{
+ this->broadcast();
+ this->out->getTensor() = this->cond->getTensor()
+ .broadcast(this->bcast_cond)
+ .select(this->then_val->getTensor().broadcast(this->bcast_then),
+ this->else_val->getTensor().broadcast(this->bcast_else));
+
+ return GraphNode::eval();
+}
+
+template <DType Dtype>
+int OpSelect<0, Dtype>::eval()
+{
+ this->out->getTensor() = this->cond->getTensor().select(this->then_val->getTensor(), this->else_val->getTensor());
+
+ return GraphNode::eval();
+}
+
+// template explicit instantiation
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, FLOAT);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, AINT8);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, INT8);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, INT16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, INT32);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, BOOL);
diff --git a/reference_model/src/ops/ewise_ternary.h b/reference_model/src/ops/ewise_ternary.h
new file mode 100644
index 0000000..b354247
--- /dev/null
+++ b/reference_model/src/ops/ewise_ternary.h
@@ -0,0 +1,83 @@
+
+// Copyright (c) 2020, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef OPS_TERNARY_H
+#define OPS_TERNARY_H
+
+#include "graph_node.h"
+
+using namespace tosa;
+
+namespace TosaReference
+{
+
+// The Ternary Select op has the following operands:
+// 1. Cond: rank N, type=bool
+// 2. Then_val: Rank N, type=<V>
+// 3. Else_val: Rank N, type=<V>
+// 4. Result: Rank N, type=<V>
+// Cond, Then_val and Else_val must be mutually broadcastable
+template <int Rank, DType Dtype>
+class OpSelectBase : public GraphNode
+{
+public:
+ OpSelectBase(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
+ virtual ~OpSelectBase();
+
+ virtual int checkTensorAttributes();
+ virtual int eval();
+
+ using CondEigenType = typename GetEigenType<DType_BOOL>::type;
+ using InEigenType = typename GetEigenType<Dtype>::type;
+ using TCond = Eigen::Tensor<CondEigenType, Rank>;
+ using TIn = Eigen::Tensor<InEigenType, Rank>;
+
+protected:
+ TosaReference::TensorTemplate<TCond>* cond;
+ Eigen::array<int, Rank> bcast_cond;
+ Eigen::array<int, Rank> bcast_then;
+ Eigen::array<int, Rank> bcast_else;
+ TosaReference::TensorTemplate<TIn>* then_val;
+ TosaReference::TensorTemplate<TIn>* else_val;
+ TosaReference::TensorTemplate<TIn>* out;
+};
+
+// primary class
+template <int Rank, DType Dtype>
+class OpSelect : public OpSelectBase<Rank, Dtype>
+{
+public:
+ OpSelect(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+ : OpSelectBase<Rank, Dtype>(attribute_, qinfo_, id_)
+ {}
+ virtual int eval();
+ int broadcast();
+
+ using InEigenType = typename OpSelectBase<Rank, Dtype>::InEigenType;
+};
+
+// partial specialization for rank 0
+template <DType Dtype>
+class OpSelect<0, Dtype> : public OpSelectBase<0, Dtype>
+{
+public:
+ OpSelect(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+ : OpSelectBase<0, Dtype>(attribute_, qinfo_, id_)
+ {}
+ virtual int eval();
+};
+}; // namespace TosaReference
+
+#endif
diff --git a/reference_model/src/ops/ewise_unary.cc b/reference_model/src/ops/ewise_unary.cc
new file mode 100644
index 0000000..d7bddc0
--- /dev/null
+++ b/reference_model/src/ops/ewise_unary.cc
@@ -0,0 +1,302 @@
+
+// Copyright (c) 2020, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "ewise_unary.h"
+#include "quant_util.h"
+#include "template_types.h"
+#include <cmath>
+
+using namespace TosaReference;
+using namespace Eigen;
+using namespace tosa;
+
+template <int Rank, DType Dtype>
+UnaryNode<Rank, Dtype>::UnaryNode(const Op& op_, uint64_t id_)
+ : GraphNode(op_, id_)
+{
+ setRequiredOperands(1, 1);
+ setRequiredRank(0, 6);
+
+ fcn = [](InEigenType a) -> OutEigenType { return OutEigenType(); };
+}
+
+template <int Rank, DType Dtype>
+UnaryNode<Rank, Dtype>::~UnaryNode()
+{}
+
+template <int Rank, DType Dtype>
+int UnaryNode<Rank, Dtype>::checkTensorAttributes()
+{
+ if (validateRequiredOperands())
+ return 1;
+
+ if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
+ {
+ return 1;
+ }
+
+ // output and input must have the same rank and shape
+ if (inputs[0]->matchRankSize(*outputs[0]))
+ {
+ printNodeValidationError("UnaryNode: input and output rank must match");
+ return 1;
+ }
+
+ a = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
+ result = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
+
+ ASSERT_MEM(a && result);
+
+ return 0;
+}
+
+template <int Rank, DType Dtype>
+int UnaryNode<Rank, Dtype>::eval()
+{
+ this->result->getTensor() = this->a->getTensor().unaryExpr(this->fcn);
+
+ return GraphNode::eval();
+}
+
+template <int Rank, DType Dtype>
+int OpAbs<Rank, Dtype>::register_fcn()
+{
+ switch (Dtype)
+ {
+ case DType_FLOAT:
+ case DType_INT32:
+ this->fcn = [](InEigenType a) -> OutEigenType { return a > (InEigenType)0 ? a : (-a); };
+ break;
+ default:
+ FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
+ }
+
+ return 0;
+}
+
+template <int Rank, DType Dtype>
+int OpBitwiseNot<Rank, Dtype>::register_fcn()
+{
+ switch (Dtype)
+ {
+ case DType_AINT8:
+ case DType_INT16:
+ case DType_INT32:
+ this->fcn = [](InEigenType a) -> OutEigenType { return ~a; };
+ break;
+ default:
+ FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
+ }
+
+ return 0;
+}
+
+template <int Rank, DType Dtype>
+int OpCeil<Rank, Dtype>::register_fcn()
+{
+ switch (Dtype)
+ {
+ case DType_FLOAT:
+ this->fcn = [](InEigenType a) -> OutEigenType { return ceilf(a); };
+ break;
+ default:
+ FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
+ }
+
+ return 0;
+}
+
+template <int Rank, DType Dtype>
+int OpClz<Rank, Dtype>::register_fcn()
+{
+ int32_t num_bits;
+ switch (Dtype)
+ {
+ case DType_INT32:
+ num_bits = 32;
+ break;
+ default:
+ FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
+ }
+
+ this->fcn = [num_bits](int32_t a) -> int32_t {
+ int32_t leading_zeros = 0;
+ for (int bit = num_bits - 1; bit >= 0; bit--)
+ {
+ if (((a >> bit) & 0x1) == 0)
+ {
+ leading_zeros++;
+ }
+ else
+ {
+ break;
+ }
+ }
+ return leading_zeros;
+ };
+
+ return 0;
+}
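+// Worked example: a = 0x00F00000 has bits 31..24 clear and bit 23 set, so the
+// loop counts 8 leading zeros; a = 0 never hits the break and returns 32,
+// while any negative a has bit 31 set and returns 0.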
+
+template <int Rank, DType Dtype>
+int OpExp<Rank, Dtype>::register_fcn()
+{
+ switch (Dtype)
+ {
+ case DType_FLOAT:
+ this->fcn = [](InEigenType a) -> OutEigenType { return expf(a); };
+ break;
+ default:
+ FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
+ }
+
+ return 0;
+}
+
+template <int Rank, DType Dtype>
+int OpFloor<Rank, Dtype>::register_fcn()
+{
+ switch (Dtype)
+ {
+ case DType_FLOAT:
+ this->fcn = [](InEigenType a) -> OutEigenType { return floorf(a); };
+ break;
+ default:
+ FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
+ }
+
+ return 0;
+}
+
+template <int Rank, DType Dtype>
+int OpLog<Rank, Dtype>::register_fcn()
+{
+ switch (Dtype)
+ {
+ case DType_FLOAT:
+ this->fcn = [](InEigenType a) -> OutEigenType { return logf(a); };
+ break;
+ default:
+ FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
+ }
+
+ return 0;
+}
+
+template <int Rank, DType Dtype>
+int OpLogicalNot<Rank, Dtype>::register_fcn()
+{
+ switch (Dtype)
+ {
+ case DType_BOOL:
+ this->fcn = [](InEigenType a) -> OutEigenType { return !a; };
+ break;
+ default:
+ FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
+ }
+
+ return 0;
+}
+
+template <int Rank, DType Dtype>
+int OpNegate<Rank, Dtype>::register_fcn()
+{
+ switch (Dtype)
+ {
+ case DType_FLOAT:
+ this->fcn = [](InEigenType a) -> OutEigenType {
+ InEigenType result = -(a);
+ return result;
+ };
+ break;
+ case DType_INT16:
+ case DType_INT32:
+ this->fcn = [](InEigenType a) -> OutEigenType {
+ InEigenType result = -(a);
+ return result;
+ };
+ break;
+ case DType_AINT8:
+ ASSERT(this->qinfo);
+ this->fcn = [this](InEigenType a) -> OutEigenType {
+ InEigenType result = -(a - this->qinfo->input_zp()) + this->qinfo->output_zp();
+ return result;
+ };
+ break;
+ default:
+ FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
+ }
+
+ return 0;
+}
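+// Worked example for the AINT8 case with hypothetical zero points
+// input_zp = 10 and output_zp = 20: a = 12 maps to -(12 - 10) + 20 = 18,
+// i.e. the value is negated about the input zero point and re-centered on
+// the output zero point.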
+
+template <int Rank, DType Dtype>
+int OpReciprocal<Rank, Dtype>::register_fcn()
+{
+ switch (Dtype)
+ {
+ case DType_FLOAT:
+ this->fcn = [](InEigenType a) -> OutEigenType { return 1.0 / a; };
+ break;
+ default:
+ FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
+ }
+
+ return 0;
+}
+
+template <int Rank, DType Dtype>
+int OpRsqrt<Rank, Dtype>::register_fcn()
+{
+ switch (Dtype)
+ {
+ case DType_FLOAT:
+ this->fcn = [](InEigenType a) -> OutEigenType { return 1.0 / sqrtf(a); };
+ break;
+ default:
+ FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
+ }
+
+ return 0;
+}
+
+// template explicit instantiation
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, FLOAT);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, INT32);
+
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseNot, AINT8);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseNot, INT16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseNot, INT32);
+
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpCeil, FLOAT);
+
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpClz, INT32);
+
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpExp, FLOAT);
+
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpFloor, FLOAT);
+
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLog, FLOAT);
+
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalNot, BOOL);
+
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, FLOAT);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, AINT8);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, INT16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, INT32);
+
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpRsqrt, FLOAT);
+
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpReciprocal, FLOAT);
diff --git a/reference_model/src/ops/ewise_unary.h b/reference_model/src/ops/ewise_unary.h
new file mode 100644
index 0000000..0db3cfb
--- /dev/null
+++ b/reference_model/src/ops/ewise_unary.h
@@ -0,0 +1,102 @@
+
+// Copyright (c) 2020, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef OPS_EWISE_UNARY_H
+#define OPS_EWISE_UNARY_H
+
+#include "graph_node.h"
+
+using namespace tosa;
+
+namespace TosaReference
+{
+template <int Rank, DType Dtype>
+class UnaryNode : public GraphNode
+{
+public:
+ UnaryNode(const Op& nodeType, const uint64_t id_);
+ virtual ~UnaryNode();
+
+ virtual int checkTensorAttributes() final;
+ virtual int eval() final;
+ virtual int register_fcn() = 0;
+
+ using InEigenType = typename GetEigenType<Dtype>::type;
+ using OutEigenType = typename GetEigenType<Dtype>::type;
+ using TIn = Eigen::Tensor<InEigenType, Rank>;
+ using TOut = Eigen::Tensor<OutEigenType, Rank>;
+
+protected:
+ std::function<OutEigenType(InEigenType)> fcn;
+ TosaReference::TensorTemplate<TIn>* a;
+ TosaReference::TensorTemplate<TOut>* result;
+};
+
+#define DEF_TEMPLATE_UNARY_OP(Opname, OPNAME) \
+ template <int Rank, DType Dtype> \
+ class Op##Opname : public UnaryNode<Rank, Dtype> \
+ { \
+ public: \
+ Op##Opname(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) \
+ : UnaryNode<Rank, Dtype>(Op_##OPNAME, id_) \
+ { \
+ register_fcn(); \
+ } \
+ static constexpr int32_t QMin = GetQMin<Dtype>::value; \
+ static constexpr int32_t QMax = GetQMax<Dtype>::value; \
+ using InEigenType = typename GetEigenType<Dtype>::type; \
+ using OutEigenType = typename GetEigenType<Dtype>::type; \
+ virtual int register_fcn(); \
+ };
+
+#define DEF_TEMPLATE_UNARY_OP_WITH_QUANT_INFO(Opname, OPNAME) \
+ template <int Rank, DType Dtype> \
+ class Op##Opname : public UnaryNode<Rank, Dtype> \
+ { \
+ public: \
+ Op##Opname(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) \
+ : UnaryNode<Rank, Dtype>(Op_##OPNAME, id_) \
+ { \
+ INIT_QINFO(Unary); \
+ register_fcn(); \
+ } \
+ static constexpr int32_t QMin = GetQMin<Dtype>::value; \
+ static constexpr int32_t QMax = GetQMax<Dtype>::value; \
+ using InEigenType = typename GetEigenType<Dtype>::type; \
+ using OutEigenType = typename GetEigenType<Dtype>::type; \
+ virtual int register_fcn(); \
+ \
+ protected: \
+ TosaUnaryQuantInfo* qinfo; \
+ };
+
+DEF_TEMPLATE_UNARY_OP(Abs, ABS)
+DEF_TEMPLATE_UNARY_OP(BitwiseNot, BITWISE_NOT)
+DEF_TEMPLATE_UNARY_OP(Ceil, CEIL)
+DEF_TEMPLATE_UNARY_OP(Clz, CLZ)
+DEF_TEMPLATE_UNARY_OP(Exp, EXP)
+DEF_TEMPLATE_UNARY_OP(Floor, FLOOR)
+DEF_TEMPLATE_UNARY_OP(Log, LOG)
+DEF_TEMPLATE_UNARY_OP(LogicalNot, LOGICAL_NOT)
+DEF_TEMPLATE_UNARY_OP_WITH_QUANT_INFO(Negate, NEGATE)
+DEF_TEMPLATE_UNARY_OP(Reciprocal, RECIPROCAL)
+DEF_TEMPLATE_UNARY_OP(Rsqrt, RSQRT)
+
+#undef DEF_TEMPLATE_UNARY_OP
+#undef DEF_TEMPLATE_UNARY_OP_WITH_QUANT_INFO
+
+}; // namespace TosaReference
+
+#endif
diff --git a/reference_model/src/ops/image.cc b/reference_model/src/ops/image.cc
new file mode 100644
index 0000000..d3352ce
--- /dev/null
+++ b/reference_model/src/ops/image.cc
@@ -0,0 +1,169 @@
+
+// Copyright (c) 2020, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "image.h"
+#include "arith_util.h"
+#include "quant_util.h"
+
+using namespace TosaReference;
+using namespace Eigen;
+using namespace tosa;
+
+template <DType InDtype, DType OutDtype>
+OpResize<InDtype, OutDtype>::OpResize(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+ : GraphNode(Op_RESIZE, id_)
+{
+ setRequiredOperands(1, 1);
+ setRequiredRank(4, 4);
+
+ INIT_ATTRIBUTE(Resize);
+}
+
+template <DType InDtype, DType OutDtype>
+OpResize<InDtype, OutDtype>::~OpResize()
+{
+ if (attribute)
+ delete attribute;
+}
+
+template <DType InDtype, DType OutDtype>
+int OpResize<InDtype, OutDtype>::checkTensorAttributes()
+{
+ if (validateRequiredOperands())
+ return 1;
+
+ if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
+ return 1;
+
+ output_size = this->attribute->output_size();
+ stride = this->attribute->stride();
+ offset = this->attribute->offset();
+ shift = this->attribute->shift();
+ mode = this->attribute->mode();
+
+ int output_height = outputs[0]->getShape()[1];
+ int output_width = outputs[0]->getShape()[2];
+
+ if (this->mode == ResizeMode_BILINEAR)
+ {
+ if (OutDtype != DType_INT32 && OutDtype != DType_INT48)
+ {
+ printNodeValidationError("OpResize: invalid data type for BILINEAR");
+ return 1;
+ }
+ }
+ else
+ {
+ if (OutDtype != DType_INT8 && OutDtype != DType_INT16)
+ {
+ printNodeValidationError("OpResize: invalid data type for NEAREST");
+ return 1;
+ }
+ }
+
+ if (output_size[0] != output_height || output_size[1] != output_width)
+ {
+ printNodeValidationError("OpResize: attribute output_size doesn't match output [height, width]");
+ return 1;
+ }
+
+ if (shift < 1 || shift > 11)
+ {
+ printNodeValidationError("OpResize: attribute shift should be within [1, 11]");
+ return 1;
+ }
+
+ if (stride[0] <= 0 || stride[1] <= 0)
+ {
+ printNodeValidationError("OpResize: invalid attribute stride");
+ return 1;
+ }
+
+ in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
+ out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
+
+ ASSERT_MEM(in && out);
+
+ return 0;
+}
+
+template <DType InDtype, DType OutDtype>
+int OpResize<InDtype, OutDtype>::eval()
+{
+ int in_batch = in->getShape()[0];
+ int in_height = in->getShape()[1];
+ int in_width = in->getShape()[2];
+ int in_channels = in->getShape()[3];
+
+ int out_batch = out->getShape()[0];
+ int out_height = out->getShape()[1];
+ int out_width = out->getShape()[2];
+ int out_channels = out->getShape()[3];
+
+ ASSERT_MSG_NODE(in_batch == out_batch, "OpResize: output tensor batch mismatch");
+ ASSERT_MSG_NODE(in_channels == out_channels, "OpResize: output tensor channel mismatch");
+
+ for (int b = 0; b < out_batch; b++)
+ for (int c = 0; c < out_channels; c++)
+ for (int oy = 0; oy < out_height; oy++)
+ for (int ox = 0; ox < out_width; ox++)
+ {
+ int y = oy * stride[0] + offset[0];
+ int x = ox * stride[1] + offset[1];
+
+ int iy = y >> shift;
+ int dy = y - (iy << shift);
+ int ix = x >> shift;
+ int dx = x - (ix << shift);
+
+ int iy0 = MAX(iy, 0);
+ int iy1 = MIN(iy + 1, in_height - 1);
+ int ix0 = MAX(ix, 0);
+ int ix1 = MIN(ix + 1, in_width - 1);
+
+ ASSERT_MSG(iy0 <= iy1 && ix0 <= ix1, "OpResize: invalid index (iy0, iy1, ix0, ix1)=(%d,%d,%d,%d)",
+ iy0, iy1, ix0, ix1);
+
+ InEigenType v00 = in->getTensor()(b, iy0, ix0, c);
+ InEigenType v01 = in->getTensor()(b, iy0, ix1, c);
+ InEigenType v10 = in->getTensor()(b, iy1, ix0, c);
+ InEigenType v11 = in->getTensor()(b, iy1, ix1, c);
+
+ OutEigenType acc;
+ if (mode == ResizeMode_BILINEAR)
+ {
+ acc = (OutEigenType)v00 * ((1 << shift) - dy) * ((1 << shift) - dx);
+ acc = acc + (OutEigenType)v01 * ((1 << shift) - dy) * dx;
+ acc = acc + (OutEigenType)v10 * dy * ((1 << shift) - dx);
+ acc = acc + (OutEigenType)v11 * dy * dx;
+ }
+ else
+ {
+ iy = (dy >> (shift - 1)) != 0 ? iy1 : iy0;
+ ix = (dx >> (shift - 1)) != 0 ? ix1 : ix0;
+ acc = in->getTensor()(b, iy, ix, c);
+ }
+
+ out->getTensor()(b, oy, ox, c) = acc;
+ }
+
+ return GraphNode::eval();
+}
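+// Worked example with hypothetical attributes shift = 10 (one input pixel =
+// 1024 units), stride = {512, 512} (a 2x upscale) and offset = {0, 0}: output
+// row oy = 3 gives y = 1536, iy = 1 and dy = 512, so BILINEAR weights input
+// rows iy0 = 1 and iy1 = 2 equally (512 each out of 1024), while NEAREST sees
+// dy >> (shift - 1) = 1 and picks iy1.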
+
+// template explicit instantiation
+DEF_INSTANTIATE_TWO_TYPE(OpResize, INT8, INT32);
+DEF_INSTANTIATE_TWO_TYPE(OpResize, INT8, INT8);
+DEF_INSTANTIATE_TWO_TYPE(OpResize, INT16, INT48);
+DEF_INSTANTIATE_TWO_TYPE(OpResize, INT16, INT16);
diff --git a/reference_model/src/ops/image.h b/reference_model/src/ops/image.h
new file mode 100644
index 0000000..9d15d49
--- /dev/null
+++ b/reference_model/src/ops/image.h
@@ -0,0 +1,53 @@
+
+// Copyright (c) 2020, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef OPS_IMAGE_H
+#define OPS_IMAGE_H
+
+#include "graph_node.h"
+
+using namespace tosa;
+
+namespace TosaReference
+{
+
+template <DType InDtype, DType OutDtype>
+class OpResize : public GraphNode
+{
+public:
+ OpResize(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
+ virtual ~OpResize();
+ virtual int checkTensorAttributes() final;
+ virtual int eval();
+
+ using InEigenType = typename GetEigenType<InDtype>::type;
+ using OutEigenType = typename GetEigenType<OutDtype>::type;
+ using TIn = Eigen::Tensor<InEigenType, 4>;
+ using TOut = Eigen::Tensor<OutEigenType, 4>;
+
+protected:
+ TosaResizeAttribute* attribute;
+ std::vector<int32_t> output_size;
+ std::vector<int32_t> stride;
+ std::vector<int32_t> offset;
+ int32_t shift;
+ ResizeMode mode;
+ TosaReference::TensorTemplate<TIn>* in;
+ TosaReference::TensorTemplate<TOut>* out;
+};
+
+}; // namespace TosaReference
+
+#endif
diff --git a/reference_model/src/ops/op_factory.cc b/reference_model/src/ops/op_factory.cc
new file mode 100644
index 0000000..bad0c40
--- /dev/null
+++ b/reference_model/src/ops/op_factory.cc
@@ -0,0 +1,432 @@
+
+// Copyright (c) 2020, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "op_factory.h"
+#include "activation_funcs.h"
+#include "comparison.h"
+#include "control_flow.h"
+#include "custom.h"
+#include "data_layout.h"
+#include "data_nodes.h"
+#include "ewise_binary.h"
+#include "ewise_ternary.h"
+#include "ewise_unary.h"
+#include "image.h"
+#include "reduction.h"
+#include "scatter_gather.h"
+#include "tensor_ops.h"
+#include "type_conversion.h"
+
+using namespace TosaReference;
+using namespace tosa;
+
+GraphNode* OpFactory::newOp(TosaSerializationHandler* tsh,
+ Op opType,
+ TosaAttributeBase* attribute,
+ TosaQuantInfoBase* qinfo,
+ uint64_t id,
+ DType inputDType,
+ int inputRank,
+ DType outputDType,
+ int outputRank,
+ DType weightDType,
+ int weightRank)
+{
+ switch (opType)
+ {
+ // tensor_ops
+ case Op_ARGMAX:
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, FLOAT);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, AINT8);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, INT16);
+ break;
+ case Op_AVG_POOL2D:
+ DEF_FACTORY_ONE_TYPE(OpAvgPool2d, FLOAT);
+ DEF_FACTORY_ONE_TYPE(OpAvgPool2d, AINT8);
+ DEF_FACTORY_ONE_TYPE(OpAvgPool2d, INT16);
+ break;
+ case Op_CONV2D:
+ DEF_FACTORY_TWO_TYPE(OpConv2d, FLOAT, FLOAT);
+ DEF_FACTORY_TWO_TYPE(OpConv2d, AINT8, INT4);
+ DEF_FACTORY_TWO_TYPE(OpConv2d, AINT8, INT8);
+ DEF_FACTORY_TWO_TYPE(OpConv2d, AINT8, AINT8);
+ DEF_FACTORY_TWO_TYPE(OpConv2d, INT16, INT8);
+ break;
+ case Op_DEPTHWISE_CONV2D:
+ DEF_FACTORY_TWO_TYPE(OpDepthwiseConv2d, FLOAT, FLOAT);
+ DEF_FACTORY_TWO_TYPE(OpDepthwiseConv2d, AINT8, INT4);
+ DEF_FACTORY_TWO_TYPE(OpDepthwiseConv2d, AINT8, INT8);
+ DEF_FACTORY_TWO_TYPE(OpDepthwiseConv2d, AINT8, AINT8);
+ DEF_FACTORY_TWO_TYPE(OpDepthwiseConv2d, INT16, INT8);
+ break;
+ case Op_FULLY_CONNECTED:
+ DEF_FACTORY_TWO_TYPE(OpFullyConnected, FLOAT, FLOAT);
+ DEF_FACTORY_TWO_TYPE(OpFullyConnected, AINT8, INT4);
+ DEF_FACTORY_TWO_TYPE(OpFullyConnected, AINT8, INT8);
+ DEF_FACTORY_TWO_TYPE(OpFullyConnected, AINT8, AINT8);
+ DEF_FACTORY_TWO_TYPE(OpFullyConnected, INT16, INT8);
+ break;
+ case Op_MATMUL:
+ DEF_FACTORY_ONE_TYPE(OpMatMul, FLOAT);
+ DEF_FACTORY_ONE_TYPE(OpMatMul, AINT8);
+ DEF_FACTORY_ONE_TYPE(OpMatMul, INT16);
+ break;
+ case Op_MAX_POOL2D:
+ DEF_FACTORY_ONE_TYPE(OpMaxPool2d, FLOAT);
+ DEF_FACTORY_ONE_TYPE(OpMaxPool2d, AINT8);
+ DEF_FACTORY_ONE_TYPE(OpMaxPool2d, INT16);
+ break;
+ case Op_TRANSPOSE_CONV2D:
+ DEF_FACTORY_TWO_TYPE(OpTransposeConv2d, FLOAT, FLOAT);
+ DEF_FACTORY_TWO_TYPE(OpTransposeConv2d, AINT8, INT4);
+ DEF_FACTORY_TWO_TYPE(OpTransposeConv2d, AINT8, INT8);
+ DEF_FACTORY_TWO_TYPE(OpTransposeConv2d, AINT8, AINT8);
+ DEF_FACTORY_TWO_TYPE(OpTransposeConv2d, INT16, INT8);
+ break;
+
+ // activation_funcs
+ case Op_CLAMP:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, FLOAT);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, AINT8);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, INT16);
+ break;
+ case Op_RELUN:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpReluN, FLOAT);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpReluN, INT32);
+ break;
+ case Op_SIGMOID:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSigmoid, FLOAT);
+ break;
+ case Op_TANH:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTanh, FLOAT);
+ break;
+
+ // ewise_binary
+ case Op_ADD:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, FLOAT);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, INT32);
+ break;
+ case Op_ARITHMETIC_RIGHT_SHIFT:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpArithmeticRightShift, INT8);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpArithmeticRightShift, INT16);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpArithmeticRightShift, INT32);
+ break;
+ case Op_BITWISE_AND:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseAnd, AINT8);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseAnd, INT16);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseAnd, INT32);
+ break;
+ case Op_BITWISE_OR:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseOr, AINT8);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseOr, INT16);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseOr, INT32);
+ break;
+ case Op_BITWISE_XOR:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseXor, AINT8);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseXor, INT16);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseXor, INT32);
+ break;
+ case Op_LOGICAL_AND:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalAnd, BOOL);
+ break;
+ case Op_LOGICAL_LEFT_SHIFT:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalLeftShift, INT8);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalLeftShift, INT16);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalLeftShift, INT32);
+ break;
+ case Op_LOGICAL_RIGHT_SHIFT:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalRightShift, INT8);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalRightShift, INT16);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalRightShift, INT32);
+ break;
+ case Op_LOGICAL_OR:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalOr, BOOL);
+ break;
+ case Op_LOGICAL_XOR:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalXor, BOOL);
+ break;
+ case Op_MAXIMUM:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, FLOAT);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, INT32);
+ break;
+ case Op_MINIMUM:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, FLOAT);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, INT32);
+ break;
+ case Op_MUL:
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, FLOAT, FLOAT);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT8, INT32);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT16, INT32);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT32, INT32);
+ break;
+ case Op_POW:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpPow, FLOAT);
+ break;
+ case Op_SUB:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, FLOAT);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, INT32);
+ break;
+ case Op_TABLE:
+ DEF_FACTORY_ONE_RANK_0_6(OpTable);
+ break;
+
+ // ewise_unary
+ case Op_ABS:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, FLOAT);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, INT32);
+ break;
+ case Op_BITWISE_NOT:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseNot, AINT8);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseNot, INT16);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseNot, INT32);
+ break;
+ case Op_CEIL:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpCeil, FLOAT);
+ break;
+ case Op_CLZ:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpClz, INT32);
+ break;
+ case Op_EXP:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpExp, FLOAT);
+ break;
+ case Op_FLOOR:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpFloor, FLOAT);
+ break;
+ case Op_LOG:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpLog, FLOAT);
+ break;
+ case Op_LOGICAL_NOT:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalNot, BOOL);
+ break;
+ case Op_NEGATE:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, FLOAT);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, AINT8);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, INT16);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, INT32);
+ break;
+ case Op_RECIPROCAL:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpReciprocal, FLOAT);
+ break;
+ case Op_RSQRT:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpRsqrt, FLOAT);
+ break;
+
+ // ewise_ternary
+ case Op_SELECT:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, FLOAT);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, AINT8);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, INT8);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, INT16);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, INT32);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, BOOL);
+ break;
+
+ // comparison
+ case Op_EQUAL:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpEqual, FLOAT);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpEqual, INT32);
+ break;
+ case Op_GREATER:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreater, FLOAT);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreater, INT32);
+ break;
+ case Op_GREATER_EQUAL:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreaterEqual, FLOAT);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreaterEqual, INT32);
+ break;
+
+ // reduction
+ case Op_REDUCE_ALL:
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceAll, BOOL);
+ break;
+ case Op_REDUCE_ANY:
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceAny, BOOL);
+ break;
+ case Op_REDUCE_MAX:
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, FLOAT);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, AINT8);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, INT16);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, INT32);
+ break;
+ case Op_REDUCE_MIN:
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, FLOAT);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, AINT8);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, INT16);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, INT32);
+ break;
+ case Op_REDUCE_PRODUCT:
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceProduct, FLOAT);
+ break;
+ case Op_REDUCE_SUM:
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSum, FLOAT);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSum, INT32);
+ break;
+
+ // data layout
+ case Op_CONCAT:
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, FLOAT);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, AINT8);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, INT8);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, INT16);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, INT32);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, BOOL);
+ break;
+ case Op_PAD:
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, FLOAT);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, INT32);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, AINT8);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, INT8);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, INT16);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, BOOL);
+ break;
+ case Op_RESHAPE:
+ DEF_FACTORY_RESHAPE(OpReshape, FLOAT);
+ DEF_FACTORY_RESHAPE(OpReshape, AINT8);
+ DEF_FACTORY_RESHAPE(OpReshape, INT8);
+ DEF_FACTORY_RESHAPE(OpReshape, INT16);
+ DEF_FACTORY_RESHAPE(OpReshape, INT32);
+ DEF_FACTORY_RESHAPE(OpReshape, BOOL);
+ break;
+ case Op_REVERSE:
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, FLOAT);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, AINT8);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT8);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT16);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT32);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, BOOL);
+ break;
+ case Op_SLICE:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, FLOAT);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, AINT8);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, INT8);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, INT16);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, INT32);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, BOOL);
+ break;
+ case Op_TILE:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, FLOAT);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, AINT8);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, INT8);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, INT16);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, INT32);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, BOOL);
+ break;
+ case Op_TRANSPOSE:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, BOOL);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, FLOAT);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, AINT8);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, INT8);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, INT16);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, INT32);
+ break;
+
+ // scatter_gather
+ case Op_GATHER:
+ {
+ // output.rank = input.rank - 1 + index.rank
+ int32_t index_rank = outputRank - inputRank + 1;
+ DEF_FACTORY_GATHER(OpGather, AINT8);
+ DEF_FACTORY_GATHER(OpGather, INT16);
+ DEF_FACTORY_GATHER(OpGather, INT32);
+ }
+ break;
+
+ // image
+ case Op_RESIZE:
+ DEF_FACTORY_TWO_TYPE_RESIZE(OpResize, INT8, INT32);
+ DEF_FACTORY_TWO_TYPE_RESIZE(OpResize, INT8, INT8);
+ DEF_FACTORY_TWO_TYPE_RESIZE(OpResize, INT16, INT48);
+ DEF_FACTORY_TWO_TYPE_RESIZE(OpResize, INT16, INT16);
+ break;
+
+ // data_nodes
+ case Op_CONST:
+ return new OpConst(id);
+ case Op_PLACEHOLDER:
+ return new OpPlaceholder(id);
+ case Op_IDENTITY:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, FLOAT);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, INT32);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, AINT8);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, INT8);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, INT16);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, BOOL);
+ break;
+ case Op_IDENTITYN:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentityN, FLOAT);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentityN, INT32);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentityN, AINT8);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentityN, INT8);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentityN, INT16);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentityN, BOOL);
+ break;
+
+ // type_conversion
+ case Op_CAST:
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BOOL, INT8);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BOOL, INT16);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BOOL, INT32);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, BOOL);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, INT16);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, INT32);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, FLOAT);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, BOOL);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, INT8);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, INT32);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, FLOAT);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, BOOL);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, INT8);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, INT16);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, FLOAT);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FLOAT, INT8);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FLOAT, INT16);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FLOAT, INT32);
+ break;
+ case Op_RESCALE:
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, AINT8, AINT8);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, AINT8, INT16);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, AINT8, INT32);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT16, AINT8);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT16, INT16);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT16, INT32);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT32, AINT8);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT32, INT16);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT32, INT32);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT48, AINT8);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT48, INT16);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT48, INT32);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, UINT8, AINT8);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, AINT8, UINT8);
+ break;
+
+ // custom
+ case Op_CUSTOM:
+ return new OpCustom(id);
+
+ // control_flow
+ case Op_COND_IF:
+ return new OpCondIf(tsh, attribute, id);
+ case Op_WHILE_LOOP:
+ return new OpWhileLoop(tsh, attribute, id);
+
+ // Ops not recognized
+ default:
+ goto done;
+
+ } // End of switch(opType)
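+ // Dispatch example: opType = Op_ADD with inputDType = DType_FLOAT and
+ // inputRank = 4 hits DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, FLOAT),
+ // whose rank switch returns new OpAdd<4, DType_FLOAT>(attribute, qinfo, id);
+ // any type/rank combination not listed above falls through to nullptr.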
+
+done:
+ return nullptr;
+}
diff --git a/reference_model/src/ops/op_factory.h b/reference_model/src/ops/op_factory.h
new file mode 100644
index 0000000..cde6841
--- /dev/null
+++ b/reference_model/src/ops/op_factory.h
@@ -0,0 +1,294 @@
+
+// Copyright (c) 2020, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef OPS_OP_FACTORY_H
+#define OPS_OP_FACTORY_H
+
+#include "attribute.h"
+#include "graph_node.h"
+#include "quant_info.h"
+#include "template_types.h"
+#include "tosa_serialization_handler.h"
+
+#define DEF_FACTORY_ONE_RANK_ONE_TYPE(OP, RANK, DTYPE) \
+ case RANK: \
+ return new OP<RANK, DType_##DTYPE>(attribute, qinfo, id);
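+
+// For illustration, DEF_FACTORY_ONE_RANK_ONE_TYPE(OpReduceSum, 2, FLOAT) expands to:
+//     case 2:
+//         return new OpReduceSum<2, DType_FLOAT>(attribute, qinfo, id);
+// i.e. each factory macro emits the case labels for the rank/dtype dispatch switches in OpFactory::newOp().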
+
+#define DEF_FACTORY_ONE_RANK_TWO_TYPE(OP, RANK, DTYPE1, DTYPE2) \
+ case RANK: \
+ return new OP<RANK, DType_##DTYPE1, DType_##DTYPE2>(attribute, qinfo, id);
+
+#define DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, RANK1, RANK2, DTYPE) \
+ case RANK2: \
+ return new OP<RANK1, RANK2, DType_##DTYPE>(attribute, qinfo, id);
+
+#define DEF_FACTORY_TWO_RANK_TWO_TYPE(OP, RANK1, RANK2, DTYPE1, DTYPE2) \
+ case RANK2: \
+ return new OP<RANK1, RANK2, DType_##DTYPE1, DType_##DTYPE2>(attribute, qinfo, id);
+
+#define DEF_FACTORY_ONE_RANK_0_6(OP) \
+ switch (inputRank) \
+ { \
+ case 0: \
+ return new OP<0>(attribute, qinfo, id); \
+ case 1: \
+ return new OP<1>(attribute, qinfo, id); \
+ case 2: \
+ return new OP<2>(attribute, qinfo, id); \
+ case 3: \
+ return new OP<3>(attribute, qinfo, id); \
+ case 4: \
+ return new OP<4>(attribute, qinfo, id); \
+ case 5: \
+ return new OP<5>(attribute, qinfo, id); \
+ case 6: \
+ return new OP<6>(attribute, qinfo, id); \
+ }
+
+#define DEF_FACTORY_ONE_TYPE(OP, DTYPE) \
+ if (inputDType == DType_##DTYPE) \
+ { \
+ return new OP<DType_##DTYPE>(attribute, qinfo, id); \
+ }
+
+#define DEF_FACTORY_TWO_TYPE(OP, DTYPE1, DTYPE2) \
+ if (inputDType == DType_##DTYPE1 && weightDType == DType_##DTYPE2) \
+ { \
+ return new OP<DType_##DTYPE1, DType_##DTYPE2>(attribute, qinfo, id); \
+ }
+
+#define DEF_FACTORY_TWO_TYPE_RESIZE(OP, DTYPE1, DTYPE2) \
+ if (inputDType == DType_##DTYPE1 && outputDType == DType_##DTYPE2) \
+ { \
+ return new OP<DType_##DTYPE1, DType_##DTYPE2>(attribute, qinfo, id); \
+ }
+
+#define DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OP, DTYPE) \
+ if (inputDType == DType_##DTYPE) \
+ { \
+ switch (inputRank) \
+ { \
+ DEF_FACTORY_ONE_RANK_ONE_TYPE(OP, 0, DTYPE) \
+ DEF_FACTORY_ONE_RANK_ONE_TYPE(OP, 1, DTYPE) \
+ DEF_FACTORY_ONE_RANK_ONE_TYPE(OP, 2, DTYPE) \
+ DEF_FACTORY_ONE_RANK_ONE_TYPE(OP, 3, DTYPE) \
+ DEF_FACTORY_ONE_RANK_ONE_TYPE(OP, 4, DTYPE) \
+ DEF_FACTORY_ONE_RANK_ONE_TYPE(OP, 5, DTYPE) \
+ DEF_FACTORY_ONE_RANK_ONE_TYPE(OP, 6, DTYPE) \
+ } \
+ }
+
+#define DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OP, DTYPE) \
+ if (inputDType == DType_##DTYPE) \
+ { \
+ switch (inputRank) \
+ { \
+ DEF_FACTORY_ONE_RANK_ONE_TYPE(OP, 1, DTYPE) \
+ DEF_FACTORY_ONE_RANK_ONE_TYPE(OP, 2, DTYPE) \
+ DEF_FACTORY_ONE_RANK_ONE_TYPE(OP, 3, DTYPE) \
+ DEF_FACTORY_ONE_RANK_ONE_TYPE(OP, 4, DTYPE) \
+ DEF_FACTORY_ONE_RANK_ONE_TYPE(OP, 5, DTYPE) \
+ DEF_FACTORY_ONE_RANK_ONE_TYPE(OP, 6, DTYPE) \
+ } \
+ }
+
+#define DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OP, DTYPE1, DTYPE2) \
+ if (inputDType == DType_##DTYPE1 && outputDType == DType_##DTYPE2) \
+ { \
+ switch (inputRank) \
+ { \
+ DEF_FACTORY_ONE_RANK_TWO_TYPE(OP, 0, DTYPE1, DTYPE2) \
+ DEF_FACTORY_ONE_RANK_TWO_TYPE(OP, 1, DTYPE1, DTYPE2) \
+ DEF_FACTORY_ONE_RANK_TWO_TYPE(OP, 2, DTYPE1, DTYPE2) \
+ DEF_FACTORY_ONE_RANK_TWO_TYPE(OP, 3, DTYPE1, DTYPE2) \
+ DEF_FACTORY_ONE_RANK_TWO_TYPE(OP, 4, DTYPE1, DTYPE2) \
+ DEF_FACTORY_ONE_RANK_TWO_TYPE(OP, 5, DTYPE1, DTYPE2) \
+ DEF_FACTORY_ONE_RANK_TWO_TYPE(OP, 6, DTYPE1, DTYPE2) \
+ } \
+ }
+
+#define DEF_FACTORY_RESHAPE(OP, DTYPE) \
+ if (inputDType == DType_##DTYPE && outputDType == DType_##DTYPE) \
+ { \
+ switch (inputRank) \
+ { \
+ case 0: \
+ { \
+ switch (outputRank) \
+ { \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 0, 0, DTYPE) \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 0, 1, DTYPE) \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 0, 2, DTYPE) \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 0, 3, DTYPE) \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 0, 4, DTYPE) \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 0, 5, DTYPE) \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 0, 6, DTYPE) \
+ } \
+ } \
+ case 1: \
+ { \
+ switch (outputRank) \
+ { \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 1, 0, DTYPE) \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 1, 1, DTYPE) \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 1, 2, DTYPE) \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 1, 3, DTYPE) \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 1, 4, DTYPE) \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 1, 5, DTYPE) \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 1, 6, DTYPE) \
+ } \
+ } \
+ case 2: \
+ { \
+ switch (outputRank) \
+ { \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 2, 0, DTYPE) \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 2, 1, DTYPE) \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 2, 2, DTYPE) \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 2, 3, DTYPE) \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 2, 4, DTYPE) \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 2, 5, DTYPE) \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 2, 6, DTYPE) \
+ } \
+ } \
+ case 3: \
+ { \
+ switch (outputRank) \
+ { \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 3, 0, DTYPE) \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 3, 1, DTYPE) \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 3, 2, DTYPE) \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 3, 3, DTYPE) \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 3, 4, DTYPE) \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 3, 5, DTYPE) \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 3, 6, DTYPE) \
+ } \
+ } \
+ case 4: \
+ { \
+ switch (outputRank) \
+ { \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 4, 0, DTYPE) \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 4, 1, DTYPE) \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 4, 2, DTYPE) \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 4, 3, DTYPE) \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 4, 4, DTYPE) \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 4, 5, DTYPE) \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 4, 6, DTYPE) \
+ } \
+ } \
+ case 5: \
+ { \
+ switch (outputRank) \
+ { \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 5, 0, DTYPE) \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 5, 1, DTYPE) \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 5, 2, DTYPE) \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 5, 3, DTYPE) \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 5, 4, DTYPE) \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 5, 5, DTYPE) \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 5, 6, DTYPE) \
+ } \
+ } \
+ case 6: \
+ { \
+ switch (outputRank) \
+ { \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 6, 0, DTYPE) \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 6, 1, DTYPE) \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 6, 2, DTYPE) \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 6, 3, DTYPE) \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 6, 4, DTYPE) \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 6, 5, DTYPE) \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 6, 6, DTYPE) \
+ } \
+ } \
+ } \
+ }
+
+#define DEF_FACTORY_GATHER(OP, DTYPE) \
+ if (inputDType == DType_##DTYPE && outputDType == DType_##DTYPE) \
+ { \
+ switch (inputRank) \
+ { \
+ case 1: \
+ switch (index_rank) \
+ { \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 1, 1, DTYPE); \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 1, 2, DTYPE); \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 1, 3, DTYPE); \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 1, 4, DTYPE); \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 1, 5, DTYPE); \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 1, 6, DTYPE); \
+ } \
+ case 2: \
+ switch (index_rank) \
+ { \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 2, 1, DTYPE); \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 2, 2, DTYPE); \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 2, 3, DTYPE); \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 2, 4, DTYPE); \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 2, 5, DTYPE); \
+ } \
+ case 3: \
+ switch (index_rank) \
+ { \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 3, 1, DTYPE); \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 3, 2, DTYPE); \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 3, 3, DTYPE); \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 3, 4, DTYPE); \
+ } \
+ case 4: \
+ switch (index_rank) \
+ { \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 4, 1, DTYPE); \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 4, 2, DTYPE); \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 4, 3, DTYPE); \
+ } \
+ case 5: \
+ switch (index_rank) \
+ { \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 5, 1, DTYPE); \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 5, 2, DTYPE); \
+ } \
+ case 6: \
+ switch (index_rank) \
+ { \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 6, 1, DTYPE); \
+ } \
+ } \
+ }
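+// The rank combinations above are limited so that OutRank = InRank - 1 + IndexRank never exceeds 6.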
+
+namespace TosaReference
+{
+
+class OpFactory
+{
+public:
+ static GraphNode* newOp(tosa::TosaSerializationHandler* tsh,
+ tosa::Op opType,
+ tosa::TosaAttributeBase* attribute,
+ tosa::TosaQuantInfoBase* qinfo,
+ uint64_t id,
+ tosa::DType inputDType,
+ int inputRank,
+ tosa::DType outputDType,
+ int outputRank,
+ tosa::DType weightDType,
+ int weightRank);
+};
+}; // namespace TosaReference
+
+#endif
diff --git a/reference_model/src/ops/reduction.cc b/reference_model/src/ops/reduction.cc
new file mode 100644
index 0000000..a2adfdb
--- /dev/null
+++ b/reference_model/src/ops/reduction.cc
@@ -0,0 +1,139 @@
+
+// Copyright (c) 2020, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "reduction.h"
+#include "quant_util.h"
+
+using namespace TosaReference;
+using namespace Eigen;
+using namespace tosa;
+
+template <int Rank, DType Dtype>
+ReduceNode<Rank, Dtype>::ReduceNode(const Op& op_, TosaAttributeBase* attribute_, uint64_t id_)
+ : GraphNode(op_, id_)
+{
+ setRequiredOperands(1, 1);
+ setRequiredRank(0, 4);
+
+ INIT_ATTRIBUTE(Axis);
+}
+
+template <int Rank, DType Dtype>
+ReduceNode<Rank, Dtype>::~ReduceNode()
+{
+ if (attribute)
+ delete attribute;
+}
+
+template <int Rank, DType Dtype>
+int ReduceNode<Rank, Dtype>::checkTensorAttributes()
+{
+ if (validateRequiredOperands())
+ return 1;
+
+ if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
+ {
+ return 1;
+ }
+
+ if (attribute->axis() < 0 || attribute->axis() >= inputs[0]->getRank())
+ {
+ printNodeValidationError("Reduce axis must between [0, input_rank - 1]");
+ return 1;
+ }
+
+ if (inputs[0]->matchRank(*outputs[0]))
+ {
+ printNodeValidationError("Input and output tensor ranks must match");
+ return 1;
+ }
+
+ in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
+ out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
+
+ ASSERT_MEM(in && out);
+
+ dims[0] = this->attribute->axis();
+
+ return 0;
+}
+
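+// Eigen reductions drop the reduced dimension, while the TOSA output keeps the same
+// rank with size 1 along the reduced axis, so each eval() below reshapes the Eigen
+// result back to the output tensor's dimensions.
+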
+template <int Rank, DType Dtype>
+int OpReduceAll<Rank, Dtype>::eval()
+{
+ this->out->getTensor() = this->in->getTensor().all(this->dims).reshape(this->out->getTensor().dimensions());
+
+ return GraphNode::eval();
+}
+
+template <int Rank, DType Dtype>
+int OpReduceAny<Rank, Dtype>::eval()
+{
+ this->out->getTensor() = this->in->getTensor().any(this->dims).reshape(this->out->getTensor().dimensions());
+
+ return GraphNode::eval();
+}
+
+template <int Rank, DType Dtype>
+int OpReduceMax<Rank, Dtype>::eval()
+{
+ this->out->getTensor() = this->in->getTensor().maximum(this->dims).reshape(this->out->getTensor().dimensions());
+
+ return GraphNode::eval();
+}
+
+template <int Rank, DType Dtype>
+int OpReduceMin<Rank, Dtype>::eval()
+{
+ this->out->getTensor() = this->in->getTensor().minimum(this->dims).reshape(this->out->getTensor().dimensions());
+
+ return GraphNode::eval();
+}
+
+template <int Rank, DType Dtype>
+int OpReduceProduct<Rank, Dtype>::eval()
+{
+ this->out->getTensor() = this->in->getTensor().prod(this->dims).reshape(this->out->getTensor().dimensions());
+
+ return GraphNode::eval();
+}
+
+template <int Rank, DType Dtype>
+int OpReduceSum<Rank, Dtype>::eval()
+{
+ this->out->getTensor() = this->in->getTensor().sum(this->dims).reshape(this->out->getTensor().dimensions());
+
+ return GraphNode::eval();
+}
+
+// template explicit instantiation
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceAll, BOOL);
+
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceAny, BOOL);
+
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, FLOAT);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, AINT8);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, INT16);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, INT32);
+
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, FLOAT);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, AINT8);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, INT16);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, INT32);
+
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceProduct, FLOAT);
+
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSum, FLOAT);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSum, INT32);
diff --git a/reference_model/src/ops/reduction.h b/reference_model/src/ops/reduction.h
new file mode 100644
index 0000000..cf75812
--- /dev/null
+++ b/reference_model/src/ops/reduction.h
@@ -0,0 +1,109 @@
+
+// Copyright (c) 2020, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef OPS_REDUCTION_H
+#define OPS_REDUCTION_H
+
+#include "graph_node.h"
+
+using namespace tosa;
+
+namespace TosaReference
+{
+
+template <int Rank, DType Dtype>
+class ReduceNode : public GraphNode
+{
+public:
+ ReduceNode(const Op& nodeType, TosaAttributeBase* attribute_, const uint64_t id_);
+ virtual ~ReduceNode();
+ virtual int checkTensorAttributes();
+ virtual int eval() = 0;
+
+ using InEigenType = typename GetEigenType<Dtype>::type;
+ using OutEigenType = typename GetEigenType<Dtype>::type;
+ using TIn = Eigen::Tensor<InEigenType, Rank>;
+ using TOut = Eigen::Tensor<OutEigenType, Rank>;
+
+protected:
+ Eigen::array<int, 1> dims;
+ TosaReference::TensorTemplate<TIn>* in;
+ TosaReference::TensorTemplate<TOut>* out;
+ TosaAxisAttribute* attribute;
+};
+
+template <int Rank, DType Dtype>
+class OpReduceAll : public ReduceNode<Rank, Dtype>
+{
+public:
+ OpReduceAll(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+ : ReduceNode<Rank, Dtype>(Op_REDUCE_ALL, attribute_, id_)
+ {}
+ virtual int eval();
+};
+
+template <int Rank, DType Dtype>
+class OpReduceAny : public ReduceNode<Rank, Dtype>
+{
+public:
+ OpReduceAny(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+ : ReduceNode<Rank, Dtype>(Op_REDUCE_ANY, attribute_, id_)
+ {}
+ virtual int eval();
+};
+
+template <int Rank, DType Dtype>
+class OpReduceMax : public ReduceNode<Rank, Dtype>
+{
+public:
+ OpReduceMax(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+ : ReduceNode<Rank, Dtype>(Op_REDUCE_MAX, attribute_, id_)
+ {}
+ virtual int eval();
+};
+
+template <int Rank, DType Dtype>
+class OpReduceMin : public ReduceNode<Rank, Dtype>
+{
+public:
+ OpReduceMin(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+ : ReduceNode<Rank, Dtype>(Op_REDUCE_MIN, attribute_, id_)
+ {}
+ virtual int eval();
+};
+
+template <int Rank, DType Dtype>
+class OpReduceProduct : public ReduceNode<Rank, Dtype>
+{
+public:
+ OpReduceProduct(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+ : ReduceNode<Rank, Dtype>(Op_REDUCE_PRODUCT, attribute_, id_)
+ {}
+ virtual int eval();
+};
+
+template <int Rank, DType Dtype>
+class OpReduceSum : public ReduceNode<Rank, Dtype>
+{
+public:
+ OpReduceSum(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+ : ReduceNode<Rank, Dtype>(Op_REDUCE_SUM, attribute_, id_)
+ {}
+ virtual int eval();
+};
+
+}; // namespace TosaReference
+
+#endif
diff --git a/reference_model/src/ops/scatter_gather.cc b/reference_model/src/ops/scatter_gather.cc
new file mode 100644
index 0000000..c54204a
--- /dev/null
+++ b/reference_model/src/ops/scatter_gather.cc
@@ -0,0 +1,120 @@
+
+// Copyright (c) 2020, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "scatter_gather.h"
+#include "quant_util.h"
+
+using namespace TosaReference;
+using namespace Eigen;
+using namespace tosa;
+
+template <int InRank, int IndexRank, DType Dtype>
+OpGather<InRank, IndexRank, Dtype>::OpGather(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+ : GraphNode(Op_GATHER, id_)
+{
+ setRequiredOperands(2, 1);
+ setRequiredRank(1, 6);
+
+ INIT_ATTRIBUTE(Axis);
+}
+
+template <int InRank, int IndexRank, DType Dtype>
+OpGather<InRank, IndexRank, Dtype>::~OpGather()
+{
+ if (attribute)
+ delete attribute;
+}
+
+template <int InRank, int IndexRank, DType Dtype>
+int OpGather<InRank, IndexRank, Dtype>::checkTensorAttributes()
+{
+ if (validateRequiredOperands())
+ return 1;
+
+ if (validateRequiredRank(inputs[0]) || validateRequiredRank(inputs[1]) || validateRequiredRank(outputs[0]))
+ {
+ return 1;
+ }
+
+ // output and input must have the same type
+ if (inputs[0]->matchType(*outputs[0]))
+ {
+ printNodeValidationError("Failure to match input and output type");
+ return 1;
+ }
+
+ in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
+ index = dynamic_cast<TosaReference::TensorTemplate<TIndex>*>(inputs[1]);
+ out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
+
+ ASSERT_MEM(in && index && out);
+
+ return 0;
+}
+
+template <int InRank, int IndexRank, DType Dtype>
+int OpGather<InRank, IndexRank, Dtype>::eval()
+{
+ int axis = attribute->axis();
+
+ // calculate the flattened sizes to the left and to the right of the axis
+ int left_size = 1;
+ for (int i = 0; i < axis; ++i)
+ {
+ left_size *= in->getShape()[i];
+ }
+
+ int right_size = 1;
+ for (int i = axis + 1; i < in->getRank(); ++i)
+ {
+ right_size *= in->getShape()[i];
+ }
+
+ InEigenType* input_data = in->getTensor().data();
+ int32_t* index_data = index->getTensor().data();
+ OutEigenType* output_data = out->getTensor().data();
+
+ int32_t axis_size = in->getShape()[axis];
+ int32_t index_count = index->getElementCount();
+
+ // sanity check that each index is valid;
+ // this can only be checked here since index values are not known until runtime
+ for (size_t i = 0; i < index->getElementCount(); i++)
+ {
+ if (index_data[i] >= axis_size)
+ {
+ FATAL_ERROR_NODE("OpGather: index[%lu]=%i can't exceed axis_size=%i", i, index_data[i], axis_size);
+ }
+ }
+
+ // Eigen stores tensors in column-major order,
+ // so we iterate over the dimensions to the right of the axis and over the index array,
+ // copying a contiguous block of left_size elements each time
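+ // e.g. for an input of shape [2, 3, 4] gathered along axis 1: left_size = 2,
+ // axis_size = 3, right_size = 4, and each memcpy moves 2 contiguous elements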
+ for (int right = 0; right < right_size; ++right)
+ {
+ for (int i = 0; i < index_count; ++i)
+ {
+ std::memcpy(output_data + (right * index_count + i) * left_size,
+ input_data + (right * axis_size + index_data[i]) * left_size, sizeof(InEigenType) * left_size);
+ }
+ }
+
+ return GraphNode::eval();
+}
+
+// template explicit instantiation
+DEF_INSTANTIATE_GATHER(OpGather, AINT8);
+DEF_INSTANTIATE_GATHER(OpGather, INT16);
+DEF_INSTANTIATE_GATHER(OpGather, INT32);
diff --git a/reference_model/src/ops/scatter_gather.h b/reference_model/src/ops/scatter_gather.h
new file mode 100644
index 0000000..d9b1263
--- /dev/null
+++ b/reference_model/src/ops/scatter_gather.h
@@ -0,0 +1,54 @@
+
+// Copyright (c) 2020, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef OPS_SCATTER_GATHER_H
+#define OPS_SCATTER_GATHER_H
+
+#include "graph_node.h"
+
+using namespace tosa;
+
+namespace TosaReference
+{
+
+// input and index may have different ranks;
+// OutRank is inferred statically from them
+template <int InRank, int IndexRank, DType Dtype>
+class OpGather : public GraphNode
+{
+public:
+ OpGather(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
+ virtual ~OpGather();
+
+ virtual int checkTensorAttributes();
+ virtual int eval();
+
+ static constexpr int OutRank = InRank - 1 + IndexRank;
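+ // e.g. a rank-2 index tensor gathering from a rank-3 input yields a rank-4 output (3 - 1 + 2)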
+ using InEigenType = typename GetEigenType<Dtype>::type;
+ using OutEigenType = typename GetEigenType<Dtype>::type;
+ using TIn = Eigen::Tensor<InEigenType, InRank>;
+ using TIndex = Eigen::Tensor<int32_t, IndexRank>;
+ using TOut = Eigen::Tensor<OutEigenType, OutRank>;
+
+protected:
+ TosaAxisAttribute* attribute;
+ TosaReference::TensorTemplate<TIn>* in;
+ TosaReference::TensorTemplate<TIndex>* index;
+ TosaReference::TensorTemplate<TOut>* out;
+};
+
+}; // namespace TosaReference
+
+#endif
diff --git a/reference_model/src/ops/template_types.h b/reference_model/src/ops/template_types.h
new file mode 100644
index 0000000..1859e03
--- /dev/null
+++ b/reference_model/src/ops/template_types.h
@@ -0,0 +1,277 @@
+
+// Copyright (c) 2020, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef OP_TEMPLATE_TYPES_H
+#define OP_TEMPLATE_TYPES_H
+
+#include "tosa_generated.h"
+#include <Eigen/CXX11/Tensor>
+
+using namespace tosa;
+
+namespace TosaReference
+{
+// Shorter alias templates for common Eigen::Tensor types
+template <typename T>
+using ETensor0 = Eigen::Tensor<T, 0>;
+template <typename T>
+using ETensor1 = Eigen::Tensor<T, 1>;
+template <typename T>
+using ETensor2 = Eigen::Tensor<T, 2>;
+template <typename T>
+using ETensor3 = Eigen::Tensor<T, 3>;
+template <typename T>
+using ETensor4 = Eigen::Tensor<T, 4>;
+template <typename T>
+using ETensor5 = Eigen::Tensor<T, 5>;
+template <typename T>
+using ETensor6 = Eigen::Tensor<T, 6>;
+
+// Forward declaration
+template <class T>
+class TensorTemplate;
+
+// Shortcut to hide the TensorTemplate class.
+// For example, declare Tensor1<float> to get a TensorTemplate
+// with an Eigen::Tensor<float, 1>
+template <typename T>
+using Tensor0 = TensorTemplate<ETensor0<T>>;
+template <typename T>
+using Tensor1 = TensorTemplate<ETensor1<T>>;
+template <typename T>
+using Tensor2 = TensorTemplate<ETensor2<T>>;
+template <typename T>
+using Tensor3 = TensorTemplate<ETensor3<T>>;
+template <typename T>
+using Tensor4 = TensorTemplate<ETensor4<T>>;
+template <typename T>
+using Tensor5 = TensorTemplate<ETensor5<T>>;
+template <typename T>
+using Tensor6 = TensorTemplate<ETensor6<T>>;
+
+template <DType type>
+struct GetEigenType;
+template <>
+struct GetEigenType<DType_FLOAT>
+{
+ using type = float;
+};
+template <>
+struct GetEigenType<DType_INT32>
+{
+ using type = int32_t;
+};
+template <>
+struct GetEigenType<DType_INT48>
+{
+ using type = int64_t;
+};
+template <>
+struct GetEigenType<DType_BOOL>
+{
+ using type = bool;
+};
+template <>
+struct GetEigenType<DType_AINT8>
+{
+ using type = int32_t;
+};
+template <>
+struct GetEigenType<DType_UINT8>
+{
+ using type = int32_t;
+};
+template <>
+struct GetEigenType<DType_INT4>
+{
+ using type = int32_t;
+};
+template <>
+struct GetEigenType<DType_INT8>
+{
+ using type = int32_t;
+};
+template <>
+struct GetEigenType<DType_INT16>
+{
+ using type = int32_t;
+};
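+
+// Note: all integer DTypes narrower than 32 bits are represented as int32_t above.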
+
+// Meta function to get number of bits
+template <DType T>
+struct GetNumBits
+{
+ static constexpr int32_t value = 0;
+};
+template <>
+struct GetNumBits<DType_BOOL>
+{
+ static constexpr int32_t value = 1;
+};
+template <>
+struct GetNumBits<DType_AINT8>
+{
+ static constexpr int32_t value = 8;
+};
+template <>
+struct GetNumBits<DType_UINT8>
+{
+ static constexpr int32_t value = 8;
+};
+template <>
+struct GetNumBits<DType_INT4>
+{
+ static constexpr int32_t value = 4;
+};
+template <>
+struct GetNumBits<DType_INT8>
+{
+ static constexpr int32_t value = 8;
+};
+template <>
+struct GetNumBits<DType_INT16>
+{
+ static constexpr int32_t value = 16;
+};
+template <>
+struct GetNumBits<DType_INT32>
+{
+ static constexpr int32_t value = 32;
+};
+template <>
+struct GetNumBits<DType_INT48>
+{
+ static constexpr int32_t value = 48;
+};
+
+// Meta function to get quantized min/max at compile time
+template <DType T>
+struct GetQMin
+{
+ static constexpr int64_t value = 0L;
+};
+template <>
+struct GetQMin<DType_AINT8>
+{
+ static constexpr int64_t value = -128L;
+};
+template <>
+struct GetQMin<DType_UINT8>
+{
+ static constexpr int64_t value = 0L;
+};
+template <>
+struct GetQMin<DType_INT4>
+{
+ static constexpr int64_t value = -8L;
+};
+template <>
+struct GetQMin<DType_INT8>
+{
+ static constexpr int64_t value = -128L;
+};
+template <>
+struct GetQMin<DType_INT16>
+{
+ static constexpr int64_t value = -32768L;
+};
+template <>
+struct GetQMin<DType_INT32>
+{
+ static constexpr int64_t value = -(1L << 31);
+};
+template <>
+struct GetQMin<DType_INT48>
+{
+ static constexpr int64_t value = -(1L << 47);
+};
+
+template <DType T>
+struct GetQMax
+{
+ static constexpr int64_t value = 0L;
+};
+template <>
+struct GetQMax<DType_AINT8>
+{
+ static constexpr int64_t value = 127L;
+};
+template <>
+struct GetQMax<DType_UINT8>
+{
+ static constexpr int64_t value = 255L;
+};
+template <>
+struct GetQMax<DType_INT4>
+{
+ static constexpr int64_t value = 7L;
+};
+template <>
+struct GetQMax<DType_INT8>
+{
+ static constexpr int64_t value = 127L;
+};
+template <>
+struct GetQMax<DType_INT16>
+{
+ static constexpr int64_t value = 32767L;
+};
+template <>
+struct GetQMax<DType_INT32>
+{
+ static constexpr int64_t value = (1L << 31) - 1;
+};
+template <>
+struct GetQMax<DType_INT48>
+{
+ static constexpr int64_t value = (1L << 47) - 1;
+};
+
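+
+// Meta function to get the accumulator type for an (input, weight) dtype pair,
+// e.g. AINT8 x AINT8 accumulates in INT32 and INT16 x INT16 in INT48.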
+template <DType TIn1, DType TIn2>
+struct GetAccDType;
+template <>
+struct GetAccDType<DType_AINT8, DType_AINT8>
+{
+ static constexpr DType value = DType_INT32;
+};
+template <>
+struct GetAccDType<DType_AINT8, DType_INT4>
+{
+ static constexpr DType value = DType_INT32;
+};
+template <>
+struct GetAccDType<DType_AINT8, DType_INT8>
+{
+ static constexpr DType value = DType_INT32;
+};
+template <>
+struct GetAccDType<DType_INT16, DType_INT8>
+{
+ static constexpr DType value = DType_INT48;
+};
+template <>
+struct GetAccDType<DType_INT16, DType_INT16>
+{
+ static constexpr DType value = DType_INT48;
+};
+template <>
+struct GetAccDType<DType_FLOAT, DType_FLOAT>
+{
+ static constexpr DType value = DType_FLOAT;
+};
+
+}; // namespace TosaReference
+
+#endif
diff --git a/reference_model/src/ops/tensor_ops.cc b/reference_model/src/ops/tensor_ops.cc
new file mode 100644
index 0000000..a735334
--- /dev/null
+++ b/reference_model/src/ops/tensor_ops.cc
@@ -0,0 +1,1229 @@
+
+// Copyright (c) 2020, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "tensor_ops.h"
+#include "quant_util.h"
+#include "template_types.h"
+
+using namespace TosaReference;
+using namespace Eigen;
+using namespace tosa;
+
+template <int Rank, DType Dtype>
+OpArgMax<Rank, Dtype>::OpArgMax(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+ : GraphNode(Op_ARGMAX, id_)
+{
+ setRequiredOperands(1, 1);
+ setRequiredRank(0, 6);
+
+ INIT_ATTRIBUTE(Axis);
+}
+
+template <int Rank, DType Dtype>
+OpArgMax<Rank, Dtype>::~OpArgMax()
+{
+ if (attribute)
+ delete attribute;
+}
+
+template <int Rank, DType Dtype>
+int OpArgMax<Rank, Dtype>::checkTensorAttributes()
+{
+ if (validateRequiredOperands())
+ return 1;
+
+ if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
+ {
+ return 1;
+ }
+
+ input = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
+ output = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
+
+ return 0;
+}
+
+template <int Rank, DType Dtype>
+int OpArgMax<Rank, Dtype>::eval()
+{
+ Eigen::Tensor<DenseIndex, Rank - 1> index = this->input->getTensor().argmax(attribute->axis());
+
+ this->output->getTensor() = index.unaryExpr([](DenseIndex in) -> OutEigenType { return (OutEigenType)in; });
+
+ return GraphNode::eval();
+}
+
+template <DType Dtype>
+OpAvgPool2d<Dtype>::OpAvgPool2d(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+ : GraphNode(Op_AVG_POOL2D, id_)
+{
+ setRequiredOperands(1, 1);
+ setRequiredRank(4);
+
+ INIT_ATTRIBUTE(Pool2d);
+ INIT_QINFO(Unary);
+}
+
+template <DType Dtype>
+OpAvgPool2d<Dtype>::~OpAvgPool2d()
+{
+ if (attribute)
+ delete attribute;
+}
+
+template <DType Dtype>
+int OpAvgPool2d<Dtype>::checkTensorAttributes()
+{
+ if (validateRequiredOperands())
+ return 1;
+
+ if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
+ {
+ return 1;
+ }
+
+ if (inputs[0]->matchType(*outputs[0]))
+ {
+ printNodeValidationError("OpAvgPool2d: input and output tensor type mismatch");
+ return 1;
+ }
+
+ in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
+ out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
+
+ if (!in->hasFormat(Format_NHWC))
+ {
+ printNodeValidationError("OpAvgPool2d: unsupported tensor format");
+ return 1;
+ }
+
+ if (attribute->padding().size() != 4)
+ {
+ printNodeValidationError("OpAvgPool2d: illegal size for attribute padding");
+ return 1;
+ }
+
+ if (attribute->kernel().size() != 2)
+ {
+ printNodeValidationError("OpAvgPool2d: illegal size for attribute kernel");
+ return 1;
+ }
+
+ if (attribute->stride().size() != 2)
+ {
+ printNodeValidationError("OpAvgPool2d: illegal size for attribute stride");
+ return 1;
+ }
+
+ return 0;
+}
+
+template <DType Dtype>
+ETensor1<int32_t> OpAvgPool2d<Dtype>::calculate_div_map_1d(int in_size, int out_size, int kernel_size, int stride)
+{
+ ETensor1<int32_t> result(out_size);
+
+ int32_t total_pad = (out_size - 1) * stride + kernel_size - in_size;
+ total_pad = total_pad < 0 ? 0 : total_pad;
+
+ int32_t pad_left = total_pad >> 1;
+ int32_t pad_right = total_pad - pad_left;
+
+ result.setConstant(kernel_size);
+
+ // output indices at or to the left of 'left_index', and at or to the right of
+ // (out_size - 1 - right_index), have input windows that cover padding elements
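+ // e.g. in_size = 4, out_size = 4, kernel_size = 3, stride = 1 gives pad_left = pad_right = 1,
+ // left_index = right_index = 1, and a final div_map of {2, 3, 3, 2}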
+ int32_t left_index = pad_left / stride;
+ int32_t right_index = pad_right / stride;
+
+ // ultra-small activations are not handled yet
+ ASSERT_MSG_NODE((out_size - 1 - right_index) >= left_index, "AvgPool2d: Small activations not supported yet");
+
+ // subtract the number of padding elements covered at this index
+ while (left_index >= 0)
+ {
+ result(left_index) -= (pad_left - left_index * stride);
+ left_index--;
+ }
+
+ while (right_index >= 0)
+ {
+ result(out_size - 1 - right_index) -= (pad_right - right_index * stride);
+ right_index--;
+ }
+
+ return result;
+}
+
+// assuming the input and output tensors have the same scale, as in the tflite reference,
+// so there is no need to scale the input and output
+template <DType Dtype>
+int OpAvgPool2d<Dtype>::eval()
+{
+ int in_batch = this->in->getShape()[0];
+ int in_height = this->in->getShape()[1];
+ int in_width = this->in->getShape()[2];
+ int in_channels = this->in->getShape()[3];
+
+ int out_batch = this->out->getShape()[0];
+ int out_height = this->out->getShape()[1];
+ int out_width = this->out->getShape()[2];
+ int out_channels = this->out->getShape()[3];
+
+ ASSERT_MSG_NODE(in_batch == out_batch, "OpAvgPool2d: tensor batch mismatch %d != %d", in_batch, out_batch);
+
+ int padding_top = this->attribute->padding()[0];
+ int padding_bottom = this->attribute->padding()[1];
+ int padding_left = this->attribute->padding()[2];
+ int padding_right = this->attribute->padding()[3];
+ int kernel_h = this->attribute->kernel()[0];
+ int kernel_w = this->attribute->kernel()[1];
+ int stride_h = this->attribute->stride()[0];
+ int stride_w = this->attribute->stride()[1];
+
+ DEBUG_INFO(OP,
+ "perform AvgPool2d, input.shape=[%d,%d,%d,%d], output.shape=[%d,%d,%d,%d], kernel=[%d,%d], "
+ "stride=[%d,%d], padding=[%d,%d,%d,%d]",
+ in_batch, in_height, in_width, in_channels, out_batch, out_height, out_width, out_channels, kernel_h,
+ kernel_w, stride_h, stride_w, padding_top, padding_bottom, padding_left, padding_right);
+
+ Eigen::array<Eigen::Index, 2> im2col_input_dims;
+ im2col_input_dims[0] = kernel_h * kernel_w;
+ im2col_input_dims[1] = out_batch * out_height * out_width * out_channels;
+
+ Eigen::array<Eigen::Index, 4> col2im_output_dims;
+ col2im_output_dims[0] = out_batch;
+ col2im_output_dims[1] = out_height;
+ col2im_output_dims[2] = out_width;
+ col2im_output_dims[3] = out_channels;
+
+ Eigen::array<std::pair<int32_t, int32_t>, 4> padding;
+ padding[0] = std::make_pair(0, 0);
+ padding[1] = std::make_pair(padding_top, padding_bottom);
+ padding[2] = std::make_pair(padding_left, padding_right);
+ padding[3] = std::make_pair(0, 0);
+
+ ETensor4<InEigenType> input_val = this->in->getTensor();
+ if (this->qinfo)
+ {
+ input_val = input_val - (InEigenType)this->qinfo->input_zp();
+ }
+
+ ETensor4<InEigenType> input_padded = input_val.pad(padding);
+
+ // assuming input and output have the same scale,
+ // so input and output scaling is not required
+ // TODO: check whether TOSA makes this assumption
+
+ // extract_image_patches() output [N, KH, KW, H * W, C]
+ // transpose to [KH, KW, N, H * W, C]
+ // reshape to [KH * KW, N * H * W * C]
+ ETensor2<InEigenType> input_extract_patches =
+ input_padded.extract_image_patches(kernel_h, kernel_w, stride_h, stride_w, 1, 1, Eigen::PADDING_VALID)
+ .shuffle(Eigen::array<Eigen::Index, 5>{ 1, 2, 0, 3, 4 })
+ .reshape(im2col_input_dims);
+
+ // 1D result with [N * H * W * C]
+ ETensor1<AccEigenType> out_1d(this->out->getElementCount());
+ out_1d.setZero();
+
+ // sum pool
+ for (size_t i = 0; i < this->out->getElementCount(); i++)
+ {
+ for (int32_t j = 0; j < kernel_h * kernel_w; j++)
+ {
+ out_1d(i) += (AccEigenType)input_extract_patches(j, i);
+ }
+ }
+
+ // reshape result to [N, H, W, C] and divide with div_map
+ ETensor4<AccEigenType> sum = out_1d.reshape(col2im_output_dims);
+
+ // calculate 1d height/width div_map (number of elements this pooling window covers)
+ // and outer product to get 2d div_map, then reshape/broadcast to [N, H, W, C]
+ ETensor1<int32_t> div_map_h = calculate_div_map_1d(in_height, out_height, kernel_h, stride_h);
+ ETensor1<int32_t> div_map_w = calculate_div_map_1d(in_width, out_width, kernel_w, stride_w);
+ Eigen::array<Eigen::IndexPair<Eigen::Index>, 1> contract_dims = { Eigen::IndexPair<Eigen::Index>(1, 0) };
+ Eigen::array<Eigen::Index, 4> bcast{ out_batch, 1, 1, out_channels };
+
+ ETensor4<int32_t> div_map =
+ div_map_h.reshape(Eigen::array<Eigen::Index, 2>{ out_height, 1 })
+ .contract(div_map_w.reshape(Eigen::array<Eigen::Index, 2>{ 1, out_width }), contract_dims)
+ .reshape(Eigen::array<Eigen::Index, 4>{ 1, out_height, out_width, 1 })
+ .broadcast(bcast);
+
+ if (Dtype != DType_FLOAT)
+ {
+ this->out->getTensor() = sum.binaryExpr(div_map, [](AccEigenType value, int32_t div) -> OutEigenType {
+ int32_t multiplier, shift;
+ TosaReference::QuantUtil<AccDtype>::reciprocal_scale(div, multiplier, shift);
+
+ return (OutEigenType)TosaReference::QuantUtil<AccDtype>::apply_scale(value, multiplier, shift, false);
+ });
+ this->out->getTensor() = this->out->getTensor() + (OutEigenType)(this->qinfo->output_zp());
+ this->out->getTensor() = this->out->getTensor().cwiseMax((OutEigenType)QMin);
+ this->out->getTensor() = this->out->getTensor().cwiseMin((OutEigenType)QMax);
+ }
+ else
+ {
+ this->out->getTensor() = (sum / div_map.template cast<AccEigenType>()).template cast<OutEigenType>();
+ }
+
+ return GraphNode::eval();
+}
+
+template <DType InDtype, DType WeightDtype>
+OpConv2d<InDtype, WeightDtype>::OpConv2d(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+ : GraphNode(Op_CONV2D, id_)
+{
+ setRequiredOperands(3, 1);
+ setRequiredRank(4);
+
+ INIT_ATTRIBUTE(Conv2d);
+ INIT_QINFO(Conv);
+}
+
+template <DType InDtype, DType WeightDtype>
+OpConv2d<InDtype, WeightDtype>::~OpConv2d()
+{
+ if (attribute)
+ delete attribute;
+ if (qinfo)
+ delete qinfo;
+}
+
+template <DType InDtype, DType WeightDtype>
+int OpConv2d<InDtype, WeightDtype>::checkTensorAttributes()
+{
+ if (validateRequiredOperands())
+ return 1;
+
+ if (validateRequiredRank(inputs[0]) || validateRequiredRank(inputs[1]) || validateRequiredRank(outputs[0]))
+ {
+ return 1;
+ }
+
+ // 'bias' is checked separately since it doesn't make sense to require a rank ranging from 1 to 4
+ if (inputs[2]->getRank() != 1)
+ {
+ printNodeValidationError("OpConv2d: bias tensor must be rank 1");
+ }
+
+ if (inputs[1]->getIsConst() == 0)
+ {
+ printNodeValidationError("OpConv2d: weight tensor is not const typed");
+ }
+
+ input = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
+ weight = dynamic_cast<TosaReference::TensorTemplate<TWeight>*>(inputs[1]);
+ bias = dynamic_cast<TosaReference::TensorTemplate<TBias>*>(inputs[2]);
+ output = dynamic_cast<TosaReference::TensorTemplate<TAcc>*>(outputs[0]);
+
+ if (!input->hasFormat(Format_NHWC))
+ {
+ printNodeValidationError("OpConv2d: unsupported input tensor format");
+ return 1;
+ }
+
+ if (!weight->hasFormat(Format_OHWI))
+ {
+ printNodeValidationError("OpConv2d: unsupported weight tensor format");
+ return 1;
+ }
+
+ if (attribute->padding().size() != 4)
+ {
+ printNodeValidationError("OpConv2d: illegal size for attribute padding");
+ return 1;
+ }
+
+ if (attribute->stride().size() != 2)
+ {
+ printNodeValidationError("OpConv2d: illegal size for attribute stride");
+ return 1;
+ }
+
+ if (attribute->dilation().size() != 2)
+ {
+ printNodeValidationError("OpConv2d: illegal size for attribute dilation");
+ return 1;
+ }
+
+ return 0;
+}
+
+template <DType InDtype, DType WeightDtype>
+int OpConv2d<InDtype, WeightDtype>::eval()
+{
+ int in_batch = this->input->getShape()[0];
+ int in_height = this->input->getShape()[1];
+ int in_width = this->input->getShape()[2];
+ int in_channels = this->input->getShape()[3];
+
+ int f_out_channels = this->weight->getShape()[0];
+ int f_height = this->weight->getShape()[1];
+ int f_width = this->weight->getShape()[2];
+ int f_in_channels = this->weight->getShape()[3];
+
+ int b_out_channels = this->bias->getShape()[0];
+
+ int out_batch = this->output->getShape()[0];
+ int out_height = this->output->getShape()[1];
+ int out_width = this->output->getShape()[2];
+ int out_channels = this->output->getShape()[3];
+
+ ASSERT_MSG_NODE(in_batch == out_batch, "OpConv2d: tensor batch mismatch %d != %d", in_batch, out_batch);
+ ASSERT_MSG_NODE(f_in_channels == in_channels, "OpConv2d: tensor input channel mismatch %d != %d", f_in_channels,
+ in_channels);
+ ASSERT_MSG_NODE(f_out_channels == out_channels, "OpConv2d: tensor output channel mismatch %d != %d", f_out_channels,
+ out_channels);
+ ASSERT_MSG_NODE(b_out_channels == out_channels, "OpConv2d: tensor output channel mismatch %d != %d", b_out_channels,
+ out_channels);
+
+ int padding_top = this->attribute->padding()[0];
+ int padding_bottom = this->attribute->padding()[1];
+ int padding_left = this->attribute->padding()[2];
+ int padding_right = this->attribute->padding()[3];
+ int stride_h = this->attribute->stride()[0];
+ int stride_w = this->attribute->stride()[1];
+ int dilation_h = this->attribute->dilation()[0];
+ int dilation_w = this->attribute->dilation()[1];
+
+ DEBUG_INFO(OP,
+ "perform OpConv2d, input.shape=[%d,%d,%d,%d], weight.shape=[%d,%d,%d,%d], output.shape=[%d,%d,%d,%d], "
+ "stride=[%d,%d], dilation=[%d,%d], padding=[%d,%d,%d,%d]",
+ in_batch, in_height, in_width, in_channels, f_height, f_width, f_in_channels, f_out_channels, out_batch,
+ out_height, out_width, out_channels, stride_h, stride_w, dilation_h, dilation_w, padding_top,
+ padding_bottom, padding_left, padding_right);
+
+ // GEMM-conv2d, left matrix is input, right matrix is weight
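+ // im2col reshapes the input to [N * OH * OW, KH * KW * IC] and the weight to
+ // [KH * KW * IC, OC]; the contraction below then produces [N * OH * OW, OC]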
+ Eigen::array<Eigen::Index, 2> im2col_input_dims;
+ im2col_input_dims[0] = out_batch * out_height * out_width;
+ im2col_input_dims[1] = f_height * f_width * f_in_channels;
+
+ Eigen::array<Eigen::Index, 2> im2col_weight_dims;
+ im2col_weight_dims[0] = f_height * f_width * f_in_channels;
+ im2col_weight_dims[1] = f_out_channels;
+
+ Eigen::array<Eigen::Index, 2> bias_reshaped_dims;
+ bias_reshaped_dims[0] = 1;
+ bias_reshaped_dims[1] = b_out_channels;
+
+ Eigen::array<Eigen::Index, 4> weight_zp_bcast_dims;
+ weight_zp_bcast_dims[0] = f_height;
+ weight_zp_bcast_dims[1] = f_width;
+ weight_zp_bcast_dims[2] = f_in_channels;
+
+ Eigen::array<Eigen::Index, 2> bias_bcast_dims;
+ bias_bcast_dims[0] = out_batch * out_height * out_width;
+ bias_bcast_dims[1] = 1;
+
+ Eigen::array<Eigen::Index, 4> col2im_output_dims;
+ col2im_output_dims[0] = out_batch;
+ col2im_output_dims[1] = out_height;
+ col2im_output_dims[2] = out_width;
+ col2im_output_dims[3] = out_channels;
+
+ Eigen::array<Eigen::IndexPair<Eigen::Index>, 1> contract_dims = { Eigen::IndexPair<Eigen::Index>(1, 0) };
+
+ Eigen::array<std::pair<int32_t, int32_t>, 4> padding;
+ padding[0] = std::make_pair(0, 0);
+ padding[1] = std::make_pair(padding_top, padding_bottom);
+ padding[2] = std::make_pair(padding_left, padding_right);
+ padding[3] = std::make_pair(0, 0);
+
+ TIn input_val = this->input->getTensor();
+ TWeight weight_val = this->weight->getTensor();
+ if (this->qinfo)
+ {
+ input_val = input_val - (InEigenType)this->qinfo->input_zp();
+ weight_val = weight_val - (WeightEigenType)this->qinfo->weight_zp();
+ }
+
+ ETensor4<InEigenType> input_padded = input_val.pad(padding);
+
+ // extract_image_patches() output [N, KH, KW, H * W, C]
+ // need to transpose to [N, H * W, KH, KW, C]
+ ETensor5<InEigenType> input_extract_patches =
+ input_padded
+ .extract_image_patches(f_height, f_width, stride_h, stride_w, dilation_h, dilation_w, Eigen::PADDING_VALID)
+ .shuffle(Eigen::array<Eigen::Index, 5>{ 0, 3, 1, 2, 4 });
+
+ // reshape input to [N * H * W, KH * KW * C]
+ ETensor2<InEigenType> im2col_input = input_extract_patches.reshape(im2col_input_dims);
+
+ // transpose and reshape weight from [OC, H, W, IC] to [H * W * IC, OC]
+ ETensor2<WeightEigenType> im2col_weight =
+ weight_val.shuffle(Eigen::array<Eigen::Index, 4>({ 1, 2, 3, 0 })).reshape(im2col_weight_dims);
+
+ // no need to apply the bias multiplier (* bias_scale and >> bias_shift) since tflite already scales it;
+ // the bias is reshaped from [C] to [1, C] and broadcast to [N * H * W, C]
+ ETensor2<AccEigenType> bias_2d = this->bias->getTensor().reshape(bias_reshaped_dims).broadcast(bias_bcast_dims);
+
+ // output matrix is [N * H * W, C]
+ ETensor2<AccEigenType> contracted_result =
+ im2col_input.template cast<AccEigenType>().contract(im2col_weight.template cast<AccEigenType>(), contract_dims);
+
+ // adding bias
+ ETensor2<AccEigenType> biased_output = contracted_result + bias_2d.template cast<AccEigenType>();
+
+ // reshape back to [N, H, W, C]
+ this->output->getTensor() = biased_output.reshape(col2im_output_dims);
+
+ if (AccDtype == DType_INT48)
+ {
+ this->output->getTensor() = this->output->getTensor().cwiseMax((AccEigenType)AccQMin);
+ this->output->getTensor() = this->output->getTensor().cwiseMin((AccEigenType)AccQMax);
+ }
+
+ return GraphNode::eval();
+}
+
+template <DType InDtype, DType WeightDtype>
+OpDepthwiseConv2d<InDtype, WeightDtype>::OpDepthwiseConv2d(TosaAttributeBase* attribute_,
+ TosaQuantInfoBase* qinfo_,
+ uint64_t id_)
+ : GraphNode(Op_DEPTHWISE_CONV2D, id_)
+{
+ setRequiredOperands(3, 1);
+ setRequiredRank(4);
+
+ INIT_ATTRIBUTE(Conv2d);
+ INIT_QINFO(Conv);
+}
+
+template <DType InDtype, DType WeightDtype>
+OpDepthwiseConv2d<InDtype, WeightDtype>::~OpDepthwiseConv2d()
+{
+ if (attribute)
+ delete attribute;
+ if (qinfo)
+ delete qinfo;
+}
+
+template <DType InDtype, DType WeightDtype>
+int OpDepthwiseConv2d<InDtype, WeightDtype>::checkTensorAttributes()
+{
+ if (validateRequiredOperands())
+ return 1;
+
+ if (validateRequiredRank(inputs[0]) || validateRequiredRank(inputs[1]) || validateRequiredRank(outputs[0]))
+ {
+ return 1;
+ }
+
+ // 'bias' is checked separately since it doesn't make sense to require a rank ranging from 1 to 4
+ if (inputs[2]->getRank() != 1)
+ {
+ printNodeValidationError("OpDepthwiseConv2d: bias tensor must be rank 1");
+ }
+
+ if (inputs[1]->getIsConst() == 0)
+ {
+ printNodeValidationError("OpDepthwiseConv2d: weight tensor is not const typed");
+ }
+
+ input = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
+ weight = dynamic_cast<TosaReference::TensorTemplate<TWeight>*>(inputs[1]);
+ bias = dynamic_cast<TosaReference::TensorTemplate<TBias>*>(inputs[2]);
+ output = dynamic_cast<TosaReference::TensorTemplate<TAcc>*>(outputs[0]);
+
+ if (!input->hasFormat(Format_NHWC))
+ {
+ printNodeValidationError("OpDepthwiseConv2d: unsupported input tensor format");
+ return 1;
+ }
+
+ if (!weight->hasFormat(Format_HWIM))
+ {
+ printNodeValidationError("OpDepthwiseConv2d: unsupported weight tensor format");
+ return 1;
+ }
+
+ if (attribute->padding().size() != 4)
+ {
+ printNodeValidationError("OpDepthwiseConv2d: illegal size for attribute padding");
+ return 1;
+ }
+
+ if (attribute->stride().size() != 2)
+ {
+ printNodeValidationError("OpDepthwiseConv2d: illegal size for attribute stride");
+ return 1;
+ }
+
+ if (attribute->dilation().size() != 2)
+ {
+ printNodeValidationError("OpDepthwiseConv2d: illegal size for attribute dilation");
+ return 1;
+ }
+
+ return 0;
+}
+
+template <DType InDtype, DType WeightDtype>
+int OpDepthwiseConv2d<InDtype, WeightDtype>::eval()
+{
+ int in_batch = this->input->getShape()[0];
+ int in_height = this->input->getShape()[1];
+ int in_width = this->input->getShape()[2];
+ int in_channels = this->input->getShape()[3];
+
+ int f_height = this->weight->getShape()[0];
+ int f_width = this->weight->getShape()[1];
+ int f_in_channels = this->weight->getShape()[2];
+ int f_multiplier = this->weight->getShape()[3];
+
+ int b_out_channels = this->bias->getShape()[0];
+
+ int out_batch = this->output->getShape()[0];
+ int out_height = this->output->getShape()[1];
+ int out_width = this->output->getShape()[2];
+ int out_channels = this->output->getShape()[3];
+
+ ASSERT_MSG_NODE(in_batch == out_batch, "OpDepthwiseConv2d: tensor batch mismatch %d != %d", in_batch, out_batch);
+ ASSERT_MSG_NODE(f_in_channels == in_channels, "OpDepthwiseConv2d: tensor input channel mismatch %d != %d",
+ f_in_channels, in_channels);
+ ASSERT_MSG_NODE(in_channels * f_multiplier == out_channels,
+ "OpDepthwiseConv2d: tensor output channel mismatch %d != %d", in_channels * f_multiplier,
+ out_channels);
+ ASSERT_MSG_NODE(b_out_channels == out_channels, "OpDepthwiseConv2d: tensor b_out_channels mismatch %d != %d",
+ b_out_channels, out_channels);
+
+ int padding_top = this->attribute->padding()[0];
+ int padding_bottom = this->attribute->padding()[1];
+ int padding_left = this->attribute->padding()[2];
+ int padding_right = this->attribute->padding()[3];
+ int stride_h = this->attribute->stride()[0];
+ int stride_w = this->attribute->stride()[1];
+ int dilation_h = this->attribute->dilation()[0];
+ int dilation_w = this->attribute->dilation()[1];
+
+ DEBUG_INFO(OP,
+ "perform OpDepthwiseConv2d, input.shape=[%d,%d,%d,%d], weight.shape=[%d,%d,%d,%d], "
+ "output.shape=[%d,%d,%d,%d], stride=[%d,%d], dilation=[%d,%d], padding=[%d,%d,%d,%d]",
+ in_batch, in_height, in_width, in_channels, f_height, f_width, f_in_channels, f_multiplier, out_batch,
+ out_height, out_width, out_channels, stride_h, stride_w, dilation_h, dilation_w, padding_top,
+ padding_bottom, padding_left, padding_right);
+
+ Eigen::array<std::pair<int32_t, int32_t>, 4> padding;
+ padding[0] = std::make_pair(0, 0);
+ padding[1] = std::make_pair(padding_top, padding_bottom);
+ padding[2] = std::make_pair(padding_left, padding_right);
+ padding[3] = std::make_pair(0, 0);
+
+ TIn input_val = this->input->getTensor();
+ TWeight weight_val = this->weight->getTensor();
+ if (this->qinfo)
+ {
+ input_val = input_val - (InEigenType)this->qinfo->input_zp();
+ weight_val = weight_val - (WeightEigenType)this->qinfo->weight_zp();
+ }
+
+ ETensor4<InEigenType> input_padded = input_val.pad(padding);
+
+ // GEMM doesn't fit well with DepthwiseConv2d
+ // 1. use extract_image_patches() to handle stride/dilation/padding
+ // 2. perform direct convolution
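+ // each input channel ic feeds f_multiplier output channels, written to output channel
+ // (ic * f_multiplier + cm); a single dense GEMM would need a mostly-zero block-diagonal
+ // weight matrix, so a direct loop is used instead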
+
+ // 1. extract_image_patches() output [N, KH, KW, OH * OW, IC]
+ ETensor5<InEigenType> input_extract_patches = input_padded.extract_image_patches(
+ f_height, f_width, stride_h, stride_w, dilation_h, dilation_w, Eigen::PADDING_VALID);
+
+ Eigen::array<Eigen::Index, 4> reshape_dim;
+ reshape_dim.fill(1);
+ reshape_dim[3] = b_out_channels;
+
+ Eigen::array<Eigen::Index, 4> bcast;
+ bcast[0] = out_batch;
+ bcast[1] = out_height;
+ bcast[2] = out_width;
+ bcast[3] = 1;
+
+ // initialize with bias
+ this->output->getTensor() = this->bias->getTensor().reshape(reshape_dim).broadcast(bcast);
+
+ // 2. direct depthwise convolution
+ for (int ob = 0; ob < out_batch; ob++)
+ {
+ for (int oh = 0; oh < out_height; oh++)
+ {
+ for (int ow = 0; ow < out_width; ow++)
+ {
+ for (int ic = 0; ic < in_channels; ic++)
+ {
+ for (int cm = 0; cm < f_multiplier; cm++)
+ {
+ for (int fh = 0; fh < f_height; fh++)
+ {
+ for (int fw = 0; fw < f_width; fw++)
+ {
+ this->output->getTensor()(ob, oh, ow, ic * f_multiplier + cm) +=
+ ((AccEigenType)input_extract_patches(ob, fh, fw, ow * out_height + oh, ic) *
+ (AccEigenType)weight_val(fh, fw, ic, cm));
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+
+ if (AccDtype == DType_INT48)
+ {
+ this->output->getTensor() = this->output->getTensor().cwiseMax((AccEigenType)AccQMin);
+ this->output->getTensor() = this->output->getTensor().cwiseMin((AccEigenType)AccQMax);
+ }
+
+ return GraphNode::eval();
+}
+
+template <DType InDtype, DType WeightDtype>
+OpFullyConnected<InDtype, WeightDtype>::OpFullyConnected(TosaAttributeBase* attribute_,
+ TosaQuantInfoBase* qinfo_,
+ uint64_t id_)
+ : GraphNode(Op_FULLY_CONNECTED, id_)
+{
+ setRequiredOperands(3, 1);
+ setRequiredRank(2);
+
+ INIT_QINFO(Conv);
+}
+
+template <DType InDtype, DType WeightDtype>
+OpFullyConnected<InDtype, WeightDtype>::~OpFullyConnected()
+{
+ if (qinfo)
+ delete qinfo;
+}
+
+template <DType InDtype, DType WeightDtype>
+int OpFullyConnected<InDtype, WeightDtype>::checkTensorAttributes()
+{
+ if (validateRequiredOperands())
+ return 1;
+
+ if (validateRequiredRank(inputs[0]) || validateRequiredRank(inputs[1]) || validateRequiredRank(outputs[0]))
+ {
+ return 1;
+ }
+
+ input = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
+ weight = dynamic_cast<TosaReference::TensorTemplate<TWeight>*>(inputs[1]);
+ bias = dynamic_cast<TosaReference::TensorTemplate<TBias>*>(inputs[2]);
+
+ if (input->getShape()[1] != weight->getShape()[1])
+ {
+ printNodeValidationError("OpFullyConnected operator input.shape[1] should match weight.shape[1]");
+ return 1;
+ }
+
+ if (weight->getShape()[0] != bias->getShape()[0])
+ {
+ printNodeValidationError("OpFullyConnected operator bias.shape[0] should match weight.shape[0]");
+ return 1;
+ }
+
+ output = dynamic_cast<TosaReference::TensorTemplate<TAcc>*>(outputs[0]);
+
+ return 0;
+}
+
+template <DType InDtype, DType WeightDtype>
+int OpFullyConnected<InDtype, WeightDtype>::eval()
+{
+ typedef Eigen::Tensor<int, 1>::DimensionPair DimPair;
+ Eigen::array<DimPair, 1> dims{ { DimPair(1, 0) } };
+
+ Eigen::array<Eigen::Index, 2> weight_shuffle{ 1, 0 };
+
+ Eigen::array<Eigen::Index, 2> bias_reshape;
+ bias_reshape[0] = 1;
+ bias_reshape[1] = this->bias->getShape()[0];
+
+ Eigen::array<Eigen::Index, 2> bias_bcast;
+ bias_bcast[0] = this->input->getShape()[0];
+ bias_bcast[1] = 1;
+
+ TIn input_val = this->input->getTensor();
+ TWeight weight_val = this->weight->getTensor().shuffle(weight_shuffle);
+ if (this->qinfo)
+ {
+ input_val = input_val - (InEigenType)this->qinfo->input_zp();
+ weight_val = weight_val - (WeightEigenType)this->qinfo->weight_zp();
+ }
+
+ this->output->getTensor() =
+ input_val.template cast<AccEigenType>().contract(weight_val.template cast<AccEigenType>(), dims) +
+ this->bias->getTensor().reshape(bias_reshape).broadcast(bias_bcast);
+
+ if (AccDtype == DType_INT48)
+ {
+ this->output->getTensor() = this->output->getTensor().cwiseMax((AccEigenType)AccQMin);
+ this->output->getTensor() = this->output->getTensor().cwiseMin((AccEigenType)AccQMax);
+ }
+ return GraphNode::eval();
+}
+
+template <DType Dtype>
+OpMatMul<Dtype>::OpMatMul(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+ : GraphNode(Op_MATMUL, id_)
+{
+ setRequiredOperands(2, 1);
+ setRequiredRank(2);
+
+ INIT_QINFO(MatMul);
+}
+
+template <DType Dtype>
+OpMatMul<Dtype>::~OpMatMul()
+{
+ if (qinfo)
+ delete qinfo;
+}
+
+template <DType Dtype>
+int OpMatMul<Dtype>::checkTensorAttributes()
+{
+ if (validateRequiredOperands())
+ return 1;
+
+ if (validateRequiredRank(inputs[0]) || validateRequiredRank(inputs[1]) || validateRequiredRank(outputs[0]))
+ {
+ return 1;
+ }
+
+ a = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
+ b = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[1]);
+
+ if (a->getShape()[1] != b->getShape()[0])
+ {
+ printNodeValidationError("OpMatMul operator a.shape[1] should match b.shape[0]");
+ return 1;
+ }
+
+ c = dynamic_cast<TosaReference::TensorTemplate<TAcc>*>(outputs[0]);
+
+ return 0;
+}
+
+template <DType Dtype>
+int OpMatMul<Dtype>::eval()
+{
+ typedef Eigen::Tensor<int, 1>::DimensionPair DimPair;
+ Eigen::array<DimPair, 1> dims{ { DimPair(1, 0) } };
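+ // contract a's dimension 1 with b's dimension 0, i.e. a standard [M, K] x [K, N] -> [M, N] matmul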
+
+ TIn a_val = this->a->getTensor();
+ TIn b_val = this->b->getTensor();
+ if (this->qinfo)
+ {
+ a_val = a_val - (InEigenType)this->qinfo->a_zp();
+ b_val = b_val - (InEigenType)this->qinfo->b_zp();
+ }
+
+ this->c->getTensor() = a_val.template cast<AccEigenType>().contract(b_val.template cast<AccEigenType>(), dims);
+
+ if (AccDtype == DType_INT48)
+ {
+ this->c->getTensor() = this->c->getTensor().cwiseMax((AccEigenType)AccQMin);
+ this->c->getTensor() = this->c->getTensor().cwiseMin((AccEigenType)AccQMax);
+ }
+
+ return GraphNode::eval();
+}
+
+template <DType Dtype>
+OpMaxPool2d<Dtype>::OpMaxPool2d(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+ : GraphNode(Op_MAX_POOL2D, id_)
+{
+ setRequiredOperands(1, 1);
+ setRequiredRank(4);
+
+ INIT_ATTRIBUTE(Pool2d);
+}
+
+template <DType Dtype>
+OpMaxPool2d<Dtype>::~OpMaxPool2d()
+{
+ if (attribute)
+ delete attribute;
+}
+
+template <DType Dtype>
+int OpMaxPool2d<Dtype>::checkTensorAttributes()
+{
+ if (validateRequiredOperands())
+ return 1;
+
+ if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
+ {
+ return 1;
+ }
+
+ if (inputs[0]->matchType(*outputs[0]))
+ {
+ printNodeValidationError("OpMaxPool2d: input and output tensor type mismatch");
+ return 1;
+ }
+
+ in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
+ out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
+
+ if (!in->hasFormat(Format_NHWC))
+ {
+ printNodeValidationError("OpMaxPool2d: unsupported tensor format");
+ return 1;
+ }
+
+ if (attribute->padding().size() != 4)
+ {
+ printNodeValidationError("OpMaxPool2d: illegal size for attribute padding");
+ return 1;
+ }
+
+ if (attribute->kernel().size() != 2)
+ {
+ printNodeValidationError("OpMaxPool2d: illegal size for attribute kernel");
+ return 1;
+ }
+
+ if (attribute->stride().size() != 2)
+ {
+ printNodeValidationError("OpMaxPool2d: illegal size for attribute stride");
+ return 1;
+ }
+
+ return 0;
+}
+
+template <DType Dtype>
+int OpMaxPool2d<Dtype>::eval()
+{
+ int in_batch = this->in->getShape()[0];
+ int in_height = this->in->getShape()[1];
+ int in_width = this->in->getShape()[2];
+ int in_channels = this->in->getShape()[3];
+
+ int out_batch = this->out->getShape()[0];
+ int out_height = this->out->getShape()[1];
+ int out_width = this->out->getShape()[2];
+ int out_channels = this->out->getShape()[3];
+
+ ASSERT_MSG_NODE(in_batch == out_batch, "OpMaxPool2d: tensor batch mismatch %d != %d", in_batch, out_batch);
+
+ int padding_top = this->attribute->padding()[0];
+ int padding_bottom = this->attribute->padding()[1];
+ int padding_left = this->attribute->padding()[2];
+ int padding_right = this->attribute->padding()[3];
+ int kernel_h = this->attribute->kernel()[0];
+ int kernel_w = this->attribute->kernel()[1];
+ int stride_h = this->attribute->stride()[0];
+ int stride_w = this->attribute->stride()[1];
+
+ DEBUG_INFO(OP,
+ "perform MaxPool2d, input.shape=[%d,%d,%d,%d], output.shape=[%d,%d,%d,%d], kernel=[%d,%d], "
+ "stride=[%d,%d], padding=[%d,%d,%d,%d]",
+ in_batch, in_height, in_width, in_channels, out_batch, out_height, out_width, out_channels, kernel_h,
+ kernel_w, stride_h, stride_w, padding_top, padding_bottom, padding_left, padding_right);
+
+ Eigen::array<Eigen::Index, 2> im2col_input_dims;
+ im2col_input_dims[0] = kernel_h * kernel_w;
+ im2col_input_dims[1] = out_batch * out_height * out_width * out_channels;
+
+ Eigen::array<Eigen::Index, 4> col2im_output_dims;
+ col2im_output_dims[0] = out_batch;
+ col2im_output_dims[1] = out_height;
+ col2im_output_dims[2] = out_width;
+ col2im_output_dims[3] = out_channels;
+
+ Eigen::array<std::pair<int32_t, int32_t>, 4> padding;
+ padding[0] = std::make_pair(0, 0);
+ padding[1] = std::make_pair(padding_top, padding_bottom);
+ padding[2] = std::make_pair(padding_left, padding_right);
+ padding[3] = std::make_pair(0, 0);
+
+ ETensor4<InEigenType> input_padded = this->in->getTensor().pad(padding, std::numeric_limits<InEigenType>::lowest());
+
+ // extract_image_patches() output [N, KH, KW, H * W, C]
+ // transpose to [KH, KW, N, H * W, C]
+ // reshape to [KH * KW, N * H * W * C]
+ //
+ // Set the padding value to be the most negative value that can be
+ // represented by the datatype to ensure that any padding values will be equal
+ // to or smaller than the actual maximum in the KH x KW patch.
+ ETensor2<InEigenType> input_extract_patches =
+ input_padded
+ .extract_image_patches(kernel_h, kernel_w, stride_h, stride_w, 1, 1, Eigen::PADDING_VALID,
+ std::numeric_limits<InEigenType>::lowest())
+ .shuffle(Eigen::array<Eigen::Index, 5>{ 1, 2, 0, 3, 4 })
+ .reshape(im2col_input_dims);
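+ // For illustration (values chosen arbitrarily): with a 1x4x4x1 input, a 2x2 kernel
+ // and stride 2, im2col_input_dims is [4, 4] -- each column holds the 4 values of one
+ // 2x2 patch, and the argmax below picks the largest entry per column.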
+
+ // Get the maximum of the KHxKW patches along axis 0
+ Eigen::Tensor<DenseIndex, 1> tensor_argmax = input_extract_patches.argmax(0);
+
+ // 1D result with [N * H * W * C]
+ ETensor1<OutEigenType> out_1d(this->out->getElementCount());
+
+ // indexing input_extract_patches with the argmax array gives the pooled result
+ for (size_t i = 0; i < this->out->getElementCount(); i++)
+ {
+ out_1d(i) = (OutEigenType)input_extract_patches(tensor_argmax(i), i);
+ }
+
+ // reshape result to [N, H, W, C]
+ this->out->getTensor() = out_1d.reshape(col2im_output_dims);
+
+ return GraphNode::eval();
+}
+
+template <DType InDtype, DType OutDtype>
+OpTransposeConv2d<InDtype, OutDtype>::OpTransposeConv2d(TosaAttributeBase* attribute_,
+ TosaQuantInfoBase* qinfo_,
+ uint64_t id_)
+ : GraphNode(Op_TRANSPOSE_CONV2D, id_)
+{
+ setRequiredOperands(3, 1);
+ setRequiredRank(4);
+
+ INIT_ATTRIBUTE(TransposeConv2d);
+ INIT_QINFO(Conv);
+}
+
+template <DType InDtype, DType OutDtype>
+OpTransposeConv2d<InDtype, OutDtype>::~OpTransposeConv2d()
+{
+ if (attribute)
+ delete attribute;
+ if (qinfo)
+ delete qinfo;
+}
+
+template <DType InDtype, DType OutDtype>
+int OpTransposeConv2d<InDtype, OutDtype>::checkTensorAttributes()
+{
+ if (validateRequiredOperands())
+ return 1;
+
+ if (validateRequiredRank(inputs[0]) || validateRequiredRank(inputs[1]) || validateRequiredRank(outputs[0]))
+ {
+ return 1;
+ }
+
+ if (inputs[1]->getIsConst() == 0)
+ {
+ printNodeValidationError("OpTransposeConv2d: weight tensor is not const typed");
+ }
+
+ input = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
+ weight = dynamic_cast<TosaReference::TensorTemplate<TWeight>*>(inputs[1]);
+ bias = dynamic_cast<TosaReference::TensorTemplate<TBias>*>(inputs[2]);
+ output = dynamic_cast<TosaReference::TensorTemplate<TAcc>*>(outputs[0]);
+
+ if (!input->hasFormat(Format_NHWC))
+ {
+ printNodeValidationError("OpTransposeConv2d: unsupported input tensor format");
+ return 1;
+ }
+
+ if (!weight->hasFormat(Format_OHWI))
+ {
+ printNodeValidationError("OpTransposeConv2d: unsupported weight tensor format");
+ return 1;
+ }
+
+ if (attribute->outpad().size() != 2)
+ {
+ printNodeValidationError("OpTransposeConv2d: illegal size for attribute outpad");
+ return 1;
+ }
+
+ if (attribute->stride().size() != 2)
+ {
+ printNodeValidationError("OpTransposeConv2d: illegal size for attribute stride");
+ return 1;
+ }
+
+ if (attribute->dilation().size() != 2)
+ {
+ printNodeValidationError("OpTransposeConv2d: illegal size for attribute dilation");
+ return 1;
+ }
+
+ if (attribute->output_shape().size() != 4)
+ {
+ printNodeValidationError("OpTransposeConv2d: illegal size for attribute output_shape");
+ return 1;
+ }
+
+ for (int d = 0; d < 4; d++)
+ {
+ if (attribute->output_shape()[d] != this->output->getShape()[d])
+ {
+ printNodeValidationError("OpTransposeConv2d: illegal size for attribute output_shape");
+ return 1;
+ }
+ }
+
+ return 0;
+}
+
+template <DType InDtype, DType OutDtype>
+int OpTransposeConv2d<InDtype, OutDtype>::eval()
+{
+ int in_batch = this->input->getShape()[0];
+ int in_height = this->input->getShape()[1];
+ int in_width = this->input->getShape()[2];
+ int in_channels = this->input->getShape()[3];
+
+ int f_out_channels = this->weight->getShape()[0];
+ int f_height = this->weight->getShape()[1];
+ int f_width = this->weight->getShape()[2];
+ int f_in_channels = this->weight->getShape()[3];
+
+ int b_out_channels = this->bias->getShape()[0];
+
+ int out_batch = this->output->getShape()[0];
+ int out_height = this->output->getShape()[1];
+ int out_width = this->output->getShape()[2];
+ int out_channels = this->output->getShape()[3];
+
+ int padding_top = this->attribute->outpad()[0];
+ int padding_left = this->attribute->outpad()[1];
+ int stride_h = this->attribute->stride()[0];
+ int stride_w = this->attribute->stride()[1];
+ int dilation_h = this->attribute->dilation()[0];
+ int dilation_w = this->attribute->dilation()[1];
+
+ ASSERT_MSG_NODE(in_batch == out_batch, "OpTransposeConv2d: tensor batch mismatch %d != %d", in_batch, out_batch);
+ ASSERT_MSG_NODE(f_in_channels == in_channels, "OpTransposeConv2d: tensor input channel mismatch %d != %d",
+ f_in_channels, in_channels);
+ ASSERT_MSG_NODE(f_out_channels == out_channels, "OpTransposeConv2d: tensor output channel mismatch %d != %d",
+ f_out_channels, out_channels);
+ ASSERT_MSG_NODE(b_out_channels == out_channels, "OpTransposeConv2d: tensor b_out_channels mismatch %d != %d",
+ b_out_channels, out_channels);
+
+ DEBUG_INFO(OP,
+ "perform OpTransposeConv2d, input.shape=[%d,%d,%d,%d], weight.shape=[%d,%d,%d,%d], "
+ "output.shape=[%d,%d,%d,%d], stride=[%d,%d], dilation=[%d,%d], padding=[%d,%d]",
+ in_batch, in_height, in_width, in_channels, f_out_channels, f_height, f_width, f_in_channels, out_batch,
+ out_height, out_width, out_channels, stride_h, stride_w, dilation_h, dilation_w, padding_top,
+ padding_left);
+
+ TIn input_val = this->input->getTensor();
+ TWeight weight_val = this->weight->getTensor();
+ if (this->qinfo)
+ {
+ input_val = input_val - (InEigenType)this->qinfo->input_zp();
+ weight_val = weight_val - (WeightEigenType)this->qinfo->weight_zp();
+ }
+
+ Eigen::array<Eigen::Index, 4> reshape_dim;
+ reshape_dim.fill(1);
+ reshape_dim[3] = b_out_channels;
+
+ Eigen::array<Eigen::Index, 4> bcast;
+ bcast[0] = out_batch;
+ bcast[1] = out_height;
+ bcast[2] = out_width;
+ bcast[3] = 1;
+
+ // initialize with bias
+ this->output->getTensor() = this->bias->getTensor().reshape(reshape_dim).broadcast(bcast);
+
+ int out_x_origin, out_y_origin;
+ int out_x, out_y;
+
+ // reference implementation from: tensorflow/tensorflow/lite/kernels/internal/reference/reference_ops.h
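+ // This is a scatter-style formulation: each input element at (ih, iw) is multiplied
+ // by every kernel tap and accumulated into the output at
+ // (ih * stride_h - padding_top + fh * dilation_h,
+ //  iw * stride_w - padding_left + fw * dilation_w),
+ // skipping positions that fall outside the output.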
+ for (int ob = 0; ob < out_batch; ob++)
+ {
+ for (int ih = 0; ih < in_height; ih++)
+ {
+ for (int iw = 0; iw < in_width; iw++)
+ {
+ out_x_origin = iw * stride_w - padding_left;
+ out_y_origin = ih * stride_h - padding_top;
+ for (int ic = 0; ic < in_channels; ic++)
+ {
+ for (int fh = 0; fh < f_height; fh++)
+ {
+ for (int fw = 0; fw < f_width; fw++)
+ {
+ out_x = out_x_origin + fw * dilation_w;
+ out_y = out_y_origin + fh * dilation_h;
+ for (int oc = 0; oc < out_channels; oc++)
+ {
+ if ((out_x >= 0 && out_x < out_width) && (out_y >= 0 && out_y < out_height))
+ {
+ this->output->getTensor()(ob, out_y, out_x, oc) +=
+ ((AccEigenType)input_val(ob, ih, iw, ic) *
+ (AccEigenType)weight_val(oc, fh, fw, ic));
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+
+ if (AccDtype == DType_INT48)
+ {
+ this->output->getTensor() = this->output->getTensor().cwiseMax((AccEigenType)AccQMin);
+ this->output->getTensor() = this->output->getTensor().cwiseMin((AccEigenType)AccQMax);
+ }
+
+ return GraphNode::eval();
+}
+
+// template explicit instantiation
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, FLOAT);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, AINT8);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, INT16);
+
+DEF_INSTANTIATE_ONE_TYPE(OpAvgPool2d, FLOAT)
+DEF_INSTANTIATE_ONE_TYPE(OpAvgPool2d, AINT8)
+DEF_INSTANTIATE_ONE_TYPE(OpAvgPool2d, INT16)
+
+DEF_INSTANTIATE_TWO_TYPE(OpConv2d, FLOAT, FLOAT);
+DEF_INSTANTIATE_TWO_TYPE(OpConv2d, AINT8, INT4);
+DEF_INSTANTIATE_TWO_TYPE(OpConv2d, AINT8, INT8);
+DEF_INSTANTIATE_TWO_TYPE(OpConv2d, AINT8, AINT8);
+DEF_INSTANTIATE_TWO_TYPE(OpConv2d, INT16, INT8);
+
+DEF_INSTANTIATE_TWO_TYPE(OpDepthwiseConv2d, FLOAT, FLOAT);
+DEF_INSTANTIATE_TWO_TYPE(OpDepthwiseConv2d, AINT8, INT4);
+DEF_INSTANTIATE_TWO_TYPE(OpDepthwiseConv2d, AINT8, INT8);
+DEF_INSTANTIATE_TWO_TYPE(OpDepthwiseConv2d, AINT8, AINT8);
+DEF_INSTANTIATE_TWO_TYPE(OpDepthwiseConv2d, INT16, INT8);
+
+DEF_INSTANTIATE_TWO_TYPE(OpFullyConnected, FLOAT, FLOAT);
+DEF_INSTANTIATE_TWO_TYPE(OpFullyConnected, AINT8, INT4);
+DEF_INSTANTIATE_TWO_TYPE(OpFullyConnected, AINT8, INT8);
+DEF_INSTANTIATE_TWO_TYPE(OpFullyConnected, AINT8, AINT8);
+DEF_INSTANTIATE_TWO_TYPE(OpFullyConnected, INT16, INT8);
+
+DEF_INSTANTIATE_ONE_TYPE(OpMatMul, AINT8);
+DEF_INSTANTIATE_ONE_TYPE(OpMatMul, INT16);
+DEF_INSTANTIATE_ONE_TYPE(OpMatMul, FLOAT);
+
+DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, FLOAT);
+DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, AINT8);
+DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, INT16);
+
+DEF_INSTANTIATE_TWO_TYPE(OpTransposeConv2d, FLOAT, FLOAT);
+DEF_INSTANTIATE_TWO_TYPE(OpTransposeConv2d, AINT8, INT4);
+DEF_INSTANTIATE_TWO_TYPE(OpTransposeConv2d, AINT8, INT8);
+DEF_INSTANTIATE_TWO_TYPE(OpTransposeConv2d, AINT8, AINT8);
+DEF_INSTANTIATE_TWO_TYPE(OpTransposeConv2d, INT16, INT8);
diff --git a/reference_model/src/ops/tensor_ops.h b/reference_model/src/ops/tensor_ops.h
new file mode 100644
index 0000000..26ce84b
--- /dev/null
+++ b/reference_model/src/ops/tensor_ops.h
@@ -0,0 +1,253 @@
+
+// Copyright (c) 2020, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef OPS_TENSOR_OPS_H
+#define OPS_TENSOR_OPS_H
+
+#include "graph_node.h"
+#include "quant_util.h"
+
+using namespace tosa;
+
+namespace TosaReference
+{
+
+template <int Rank, DType Dtype>
+class OpArgMax : public GraphNode
+{
+public:
+ OpArgMax(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
+ virtual ~OpArgMax();
+
+ virtual int checkTensorAttributes();
+ virtual int eval();
+
+ using InEigenType = typename GetEigenType<Dtype>::type;
+ using OutEigenType = typename GetEigenType<DType_INT32>::type;
+ using TIn = Eigen::Tensor<InEigenType, Rank>;
+ using TOut = Eigen::Tensor<OutEigenType, Rank - 1>;
+
+protected:
+ TosaAxisAttribute* attribute;
+ TosaReference::TensorTemplate<TIn>* input;
+ TosaReference::TensorTemplate<TOut>* output;
+};
+
+template <DType Dtype>
+class OpAvgPool2d : public GraphNode
+{
+public:
+ OpAvgPool2d(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
+ virtual ~OpAvgPool2d();
+
+ virtual int checkTensorAttributes();
+ virtual int eval();
+
+ static constexpr DType AccDtype = GetAccDType<Dtype, Dtype>::value;
+ using InEigenType = typename GetEigenType<Dtype>::type;
+ using AccEigenType = typename GetEigenType<AccDtype>::type;
+ using OutEigenType = typename GetEigenType<Dtype>::type;
+ using TIn = Eigen::Tensor<InEigenType, 4>;
+ using TOut = Eigen::Tensor<OutEigenType, 4>;
+
+ static constexpr int64_t QMin = GetQMin<Dtype>::value;
+ static constexpr int64_t QMax = GetQMax<Dtype>::value;
+
+protected:
+ TosaReference::TensorTemplate<TIn>* in;
+ TosaReference::TensorTemplate<TOut>* out;
+ tosa::TosaPool2dAttribute* attribute;
+ tosa::TosaUnaryQuantInfo* qinfo;
+
+protected:
+ // returns a 1D [N] tensor describing how many valid input elements are covered at each output position
+ ETensor1<int32_t> calculate_div_map_1d(int in_size, int out_size, int kernel_size, int stride);
+};
+
+template <DType InDtype, DType WeightDtype>
+class OpConv2d : public GraphNode
+{
+public:
+ OpConv2d(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
+ virtual ~OpConv2d();
+
+ virtual int checkTensorAttributes() final;
+ virtual int eval() final;
+
+ static constexpr DType AccDtype = GetAccDType<InDtype, WeightDtype>::value;
+
+ using InEigenType = typename GetEigenType<InDtype>::type;
+ using WeightEigenType = typename GetEigenType<WeightDtype>::type;
+ using AccEigenType = typename GetEigenType<AccDtype>::type;
+ using TIn = Eigen::Tensor<InEigenType, 4>;
+ using TWeight = Eigen::Tensor<WeightEigenType, 4>;
+ using TBias = Eigen::Tensor<AccEigenType, 1>;
+ using TAcc = Eigen::Tensor<AccEigenType, 4>;
+
+ static constexpr int64_t AccQMin = GetQMin<AccDtype>::value;
+ static constexpr int64_t AccQMax = GetQMax<AccDtype>::value;
+
+protected:
+ TosaReference::TensorTemplate<TIn>* input;
+ TosaReference::TensorTemplate<TWeight>* weight;
+ TosaReference::TensorTemplate<TBias>* bias;
+ TosaReference::TensorTemplate<TAcc>* output;
+ tosa::TosaConv2dAttribute* attribute;
+ tosa::TosaConvQuantInfo* qinfo;
+};
+
+template <DType InDtype, DType WeightDtype>
+class OpDepthwiseConv2d : public GraphNode
+{
+public:
+ OpDepthwiseConv2d(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
+ virtual ~OpDepthwiseConv2d();
+
+ virtual int checkTensorAttributes() final;
+ virtual int eval() final;
+
+ static constexpr DType AccDtype = GetAccDType<InDtype, WeightDtype>::value;
+
+ using InEigenType = typename GetEigenType<InDtype>::type;
+ using WeightEigenType = typename GetEigenType<WeightDtype>::type;
+ using AccEigenType = typename GetEigenType<AccDtype>::type;
+ using TIn = Eigen::Tensor<InEigenType, 4>;
+ using TWeight = Eigen::Tensor<WeightEigenType, 4>;
+ using TBias = Eigen::Tensor<AccEigenType, 1>;
+ using TAcc = Eigen::Tensor<AccEigenType, 4>;
+
+ static constexpr int64_t AccQMin = GetQMin<AccDtype>::value;
+ static constexpr int64_t AccQMax = GetQMax<AccDtype>::value;
+
+protected:
+ TosaReference::TensorTemplate<TIn>* input;
+ TosaReference::TensorTemplate<TWeight>* weight;
+ TosaReference::TensorTemplate<TBias>* bias;
+ TosaReference::TensorTemplate<TAcc>* output;
+ tosa::TosaConv2dAttribute* attribute;
+ tosa::TosaConvQuantInfo* qinfo;
+};
+
+template <DType InDtype, DType WeightDtype>
+class OpFullyConnected : public GraphNode
+{
+public:
+ OpFullyConnected(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
+ virtual ~OpFullyConnected();
+
+ virtual int checkTensorAttributes() final;
+ virtual int eval() final;
+
+ static constexpr DType AccDtype = GetAccDType<InDtype, WeightDtype>::value;
+ using InEigenType = typename GetEigenType<InDtype>::type;
+ using WeightEigenType = typename GetEigenType<WeightDtype>::type;
+ using AccEigenType = typename GetEigenType<AccDtype>::type;
+ using TIn = Eigen::Tensor<InEigenType, 2>;
+ using TWeight = Eigen::Tensor<WeightEigenType, 2>;
+ using TBias = Eigen::Tensor<AccEigenType, 1>;
+ using TAcc = Eigen::Tensor<AccEigenType, 2>;
+
+ static constexpr int64_t AccQMin = GetQMin<AccDtype>::value;
+ static constexpr int64_t AccQMax = GetQMax<AccDtype>::value;
+
+protected:
+ TosaReference::TensorTemplate<TIn>* input;
+ TosaReference::TensorTemplate<TWeight>* weight;
+ TosaReference::TensorTemplate<TBias>* bias;
+ TosaReference::TensorTemplate<TAcc>* output;
+ tosa::TosaConvQuantInfo* qinfo;
+};
+
+template <DType Dtype>
+class OpMatMul : public GraphNode
+{
+public:
+ OpMatMul(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
+ virtual ~OpMatMul();
+
+ virtual int checkTensorAttributes() final;
+ virtual int eval() final;
+
+ static constexpr DType AccDtype = GetAccDType<Dtype, Dtype>::value;
+ using InEigenType = typename GetEigenType<Dtype>::type;
+ using AccEigenType = typename GetEigenType<AccDtype>::type;
+ using TIn = Eigen::Tensor<InEigenType, 2>;
+ using TAcc = Eigen::Tensor<AccEigenType, 2>;
+ static constexpr int64_t AccQMin = GetQMin<AccDtype>::value;
+ static constexpr int64_t AccQMax = GetQMax<AccDtype>::value;
+
+protected:
+ TosaReference::TensorTemplate<TIn>* a;
+ TosaReference::TensorTemplate<TIn>* b;
+ TosaReference::TensorTemplate<TAcc>* c;
+ tosa::TosaMatMulQuantInfo* qinfo;
+};
+
+template <DType Dtype>
+class OpMaxPool2d : public GraphNode
+{
+public:
+ OpMaxPool2d(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
+ virtual ~OpMaxPool2d();
+
+ virtual int checkTensorAttributes();
+ virtual int eval();
+
+ using InEigenType = typename GetEigenType<Dtype>::type;
+ using OutEigenType = typename GetEigenType<Dtype>::type;
+ using TIn = Eigen::Tensor<InEigenType, 4>;
+ using TOut = Eigen::Tensor<OutEigenType, 4>;
+
+protected:
+ TosaReference::TensorTemplate<TIn>* in;
+ TosaReference::TensorTemplate<TOut>* out;
+ tosa::TosaPool2dAttribute* attribute;
+};
+
+template <DType InDtype, DType WeightDtype>
+class OpTransposeConv2d : public GraphNode
+{
+public:
+ OpTransposeConv2d(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
+ virtual ~OpTransposeConv2d();
+
+ virtual int checkTensorAttributes() final;
+ virtual int eval() final;
+
+ static constexpr DType AccDtype = GetAccDType<InDtype, WeightDtype>::value;
+
+ using InEigenType = typename GetEigenType<InDtype>::type;
+ using WeightEigenType = typename GetEigenType<WeightDtype>::type;
+ using AccEigenType = typename GetEigenType<AccDtype>::type;
+ using TIn = Eigen::Tensor<InEigenType, 4>;
+ using TWeight = Eigen::Tensor<WeightEigenType, 4>;
+ using TBias = Eigen::Tensor<AccEigenType, 1>;
+ using TAcc = Eigen::Tensor<AccEigenType, 4>;
+
+ static constexpr int64_t AccQMin = GetQMin<AccDtype>::value;
+ static constexpr int64_t AccQMax = GetQMax<AccDtype>::value;
+
+protected:
+ TosaReference::TensorTemplate<TIn>* input;
+ TosaReference::TensorTemplate<TWeight>* weight;
+ TosaReference::TensorTemplate<TBias>* bias;
+ TosaReference::TensorTemplate<TAcc>* output;
+ TosaTransposeConv2dAttribute* attribute;
+ TosaConvQuantInfo* qinfo;
+};
+
+}; // namespace TosaReference
+
+#endif
diff --git a/reference_model/src/ops/type_conversion.cc b/reference_model/src/ops/type_conversion.cc
new file mode 100644
index 0000000..61a19f4
--- /dev/null
+++ b/reference_model/src/ops/type_conversion.cc
@@ -0,0 +1,299 @@
+
+// Copyright (c) 2020, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "type_conversion.h"
+#include "quant_util.h"
+#include "template_types.h"
+#include <cmath>
+
+using namespace TosaReference;
+using namespace Eigen;
+using namespace tosa;
+
+template <int Rank, DType InDtype, DType OutDtype>
+OpRescale<Rank, InDtype, OutDtype>::OpRescale(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+ : GraphNode(Op_RESCALE, id_)
+{
+ setRequiredOperands(1, 1);
+ setRequiredRank(0, 6);
+ INIT_ATTRIBUTE(Rescale);
+}
+
+template <int Rank, DType InDtype, DType OutDtype>
+OpRescale<Rank, InDtype, OutDtype>::~OpRescale()
+{
+ if (attribute)
+ delete attribute;
+}
+
+template <int Rank, DType InDtype, DType OutDtype>
+int OpRescale<Rank, InDtype, OutDtype>::checkTensorAttributes()
+{
+ if (validateRequiredOperands())
+ return 1;
+
+ if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
+ {
+ return 1;
+ }
+
+ // output and input must be the same rank and size
+ if (inputs[0]->matchRankSize(*outputs[0]))
+ {
+ printNodeValidationError("OpRescale: input and output rank/size must match");
+ return 1;
+ }
+
+ in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
+ out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
+
+ ASSERT_MEM(in && out);
+
+ return 0;
+}
+
+template <int Rank, DType InDtype, DType OutDtype>
+int OpRescale<Rank, InDtype, OutDtype>::eval()
+{
+ int32_t input_zp = attribute->input_zp();
+ int32_t output_zp = attribute->output_zp();
+ std::vector<int32_t> multiplier = attribute->multiplier();
+ std::vector<int32_t> shift = attribute->shift();
+ //bool scale32 = attribute->scale32();
+ bool double_round = attribute->double_round();
+ bool per_channel = attribute->per_channel();
+
+ if (TosaReference::TypeChecker::is_symmetric(InDtype))
+ {
+ if (input_zp != 0)
+ {
+ FATAL_ERROR_NODE("input tensor is symmetric type %s but zeropoint is %d instead of 0",
+ EnumNamesDType()[InDtype], input_zp);
+ }
+ }
+
+ if (TosaReference::TypeChecker::is_symmetric(OutDtype))
+ {
+ if (output_zp != 0)
+ {
+ FATAL_ERROR_NODE("output tensor is symmetric type %s but zeropoint is %d instead of 0",
+ EnumNamesDType()[OutDtype], output_zp);
+ }
+ }
+
+ // reshape [d0, d1, ..., dn] into [d0 * d1 * ... * d(n-1), dn]
+ Eigen::array<Eigen::Index, 2> shape_2d;
+ shape_2d[0] = 1;
+ if (Rank > 0)
+ {
+ for (int i = 0; i < Rank - 1; i++)
+ {
+ shape_2d[0] *= this->in->getShape()[i];
+ }
+ shape_2d[1] = this->in->getShape()[Rank - 1];
+ }
+ else
+ {
+ shape_2d[1] = 1;
+ }
+ ETensor2<InEigenType> input_reshaped = this->in->getTensor().reshape(shape_2d);
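+ // For example (shape chosen for illustration), a rank-3 [2, 3, 4] input becomes a
+ // [6, 4] matrix here, so the per-channel path below can scale each of the 4
+ // trailing-dimension channels with its own multiplier/shift pair.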
+
+ ETensor2<OutEigenType> output_2d(shape_2d);
+
+ // TODO: pass scale32 in when 16-bit mode is implemented
+ if (per_channel)
+ {
+ ETensor2<InEigenType> curr_channel_slice_prescaled;
+ ETensor2<OutEigenType> curr_channel_slice_postscaled;
+ int32_t channel_multiplier, channel_shift;
+ Eigen::array<Eigen::Index, 2> begin, size;
+ size = Eigen::array<Eigen::Index, 2>({ shape_2d[0], 1 });
+ for (int32_t i = 0; i < shape_2d[1]; i++)
+ {
+ begin = Eigen::array<Eigen::Index, 2>({ 0, i });
+ curr_channel_slice_prescaled = input_reshaped.slice(begin, size);
+ channel_multiplier = multiplier[i];
+ channel_shift = shift[i];
+ curr_channel_slice_postscaled =
+ curr_channel_slice_prescaled.unaryExpr([input_zp, output_zp, channel_multiplier, channel_shift,
+ double_round](InEigenType in_val) -> OutEigenType {
+ InEigenType input_zp_shifted = in_val - (InEigenType)input_zp;
+ int32_t scaled = TosaReference::QuantUtil<InDtype>::apply_scale(
+ input_zp_shifted, channel_multiplier, channel_shift, double_round);
+ OutEigenType out_val = (OutEigenType)(scaled + output_zp);
+ out_val = std::max<OutEigenType>(out_val, QMin);
+ out_val = std::min<OutEigenType>(out_val, QMax);
+ return out_val;
+ });
+
+ for (int32_t j = 0; j < shape_2d[0]; j++)
+ {
+ output_2d(j, i) = curr_channel_slice_postscaled(j, 0);
+ }
+ }
+ }
+ else
+ {
+ int32_t tensor_multiplier = multiplier[0];
+ int32_t tensor_shift = shift[0];
+ output_2d = input_reshaped.unaryExpr(
+ [input_zp, output_zp, tensor_multiplier, tensor_shift, double_round](InEigenType in_val) -> OutEigenType {
+ InEigenType input_zp_shifted = in_val - (InEigenType)input_zp;
+ int32_t scaled = TosaReference::QuantUtil<InDtype>::apply_scale(input_zp_shifted, tensor_multiplier,
+ tensor_shift, double_round);
+ OutEigenType out_val = (OutEigenType)(scaled + output_zp);
+ out_val = std::max<OutEigenType>(out_val, QMin);
+ out_val = std::min<OutEigenType>(out_val, QMax);
+ return out_val;
+ });
+ }
+
+ // reshape [d0 * d1 * ... * d(n-1), dn] back to [d0, d1, ..., dn]
+ Eigen::array<Eigen::Index, Rank> output_shape;
+ for (int i = 0; i < Rank; i++)
+ {
+ output_shape[i] = this->out->getShape()[i];
+ }
+ this->out->getTensor() = output_2d.reshape(output_shape);
+
+ return GraphNode::eval();
+}
+
+template <int Rank, DType InDtype, DType OutDtype>
+OpCast<Rank, InDtype, OutDtype>::OpCast(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+ : GraphNode(Op_CAST, id_)
+{
+ setRequiredOperands(1, 1);
+ setRequiredRank(0, 6);
+}
+
+template <int Rank, DType InDtype, DType OutDtype>
+OpCast<Rank, InDtype, OutDtype>::~OpCast()
+{}
+
+template <int Rank, DType InDtype, DType OutDtype>
+int OpCast<Rank, InDtype, OutDtype>::checkTensorAttributes()
+{
+ if (validateRequiredOperands())
+ return 1;
+
+ if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
+ {
+ return 1;
+ }
+
+ // output and input must be the same rank and size
+ if (inputs[0]->matchRankSize(*outputs[0]))
+ {
+ printNodeValidationError("OpCast: input and output rank/size must match");
+ return 1;
+ }
+
+ in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
+ out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
+
+ ASSERT_MEM(in && out);
+
+ return 0;
+}
+
+template <int Rank, DType InDtype, DType OutDtype>
+int OpCast<Rank, InDtype, OutDtype>::eval()
+{
+ this->out->getTensor() = this->in->getTensor().unaryExpr(cast_helper.get_fcn());
+
+ return GraphNode::eval();
+}
+
+template <DType InDtype, DType OutDtype>
+CastHelper<InDtype, OutDtype>::CastHelper()
+{
+ fcn = [](InEigenType in) -> OutEigenType {
+ OutEigenType out = (OutEigenType)in; // implicit sign_extend() if sizeof(out_t) >= sizeof(in_t)
+ int64_t mask = (1L << OutBits) - 1;
+ out = out & mask;
+ return out;
+ };
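+ // With this masking, only the low OutBits bits survive a narrowing cast: for example
+ // (value chosen for illustration), casting 0x1F3 with OutBits == 8 yields
+ // 0x1F3 & 0xFF = 0xF3.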
+}
+
+template <DType InDtype>
+CastHelper<InDtype, DType_BOOL>::CastHelper()
+{
+ fcn = [](InEigenType in) -> bool { return (in != 0) ? true : false; };
+}
+
+template <DType OutDtype>
+CastHelper<DType_BOOL, OutDtype>::CastHelper()
+{
+ fcn = [](bool in) -> OutEigenType {
+ OutEigenType out = in ? (OutEigenType)1 : (OutEigenType)0;
+ return out;
+ };
+}
+
+template <DType InDtype>
+CastHelper<InDtype, DType_FLOAT>::CastHelper()
+{
+ fcn = [](InEigenType in) -> float {
+ float out = (OutEigenType)in; // default cast to float is round_to_nearest_float()
+ return out;
+ };
+}
+
+template <DType OutDtype>
+CastHelper<DType_FLOAT, OutDtype>::CastHelper()
+{
+ fcn = [](float in) -> OutEigenType {
+ OutEigenType out = std::round(in);
+ out = std::max<OutEigenType>(out, OutMin);
+ out = std::min<OutEigenType>(out, OutMax);
+ return out;
+ };
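+ // For example (value chosen for illustration), casting 127.6f to an 8-bit output
+ // rounds to 128 and then clamps to OutMax = 127.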
+}
+
+// template explicit instantiation
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BOOL, INT8);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BOOL, INT16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BOOL, INT32);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, BOOL);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, INT16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, INT32);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, FLOAT);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, BOOL);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, INT8);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, INT32);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, FLOAT);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, BOOL);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, INT8);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, INT16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, FLOAT);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FLOAT, INT8);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FLOAT, INT16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FLOAT, INT32);
+
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, AINT8, AINT8);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, AINT8, INT16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, AINT8, INT32);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT16, AINT8);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT16, INT16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT16, INT32);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT32, AINT8);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT32, INT16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT32, INT32);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT48, AINT8);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT48, INT16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT48, INT32);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, UINT8, AINT8);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, AINT8, UINT8);
diff --git a/reference_model/src/ops/type_conversion.h b/reference_model/src/ops/type_conversion.h
new file mode 100644
index 0000000..6ec4d6d
--- /dev/null
+++ b/reference_model/src/ops/type_conversion.h
@@ -0,0 +1,162 @@
+
+// Copyright (c) 2020, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef OPS_TYPE_CONVERSION_H
+#define OPS_TYPE_CONVERSION_H
+
+#include "graph_node.h"
+
+using namespace tosa;
+
+namespace TosaReference
+{
+template <int Rank, DType InDtype, DType OutDtype>
+class OpRescale : public GraphNode
+{
+public:
+ OpRescale(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
+ virtual ~OpRescale();
+
+ virtual int checkTensorAttributes() final;
+ virtual int eval() final;
+
+ using InEigenType = typename GetEigenType<InDtype>::type;
+ using OutEigenType = typename GetEigenType<OutDtype>::type;
+ using TIn = Eigen::Tensor<InEigenType, Rank>;
+ using TOut = Eigen::Tensor<OutEigenType, Rank>;
+
+ static constexpr int32_t QMin = GetQMin<OutDtype>::value;
+ static constexpr int32_t QMax = GetQMax<OutDtype>::value;
+
+protected:
+ TosaRescaleAttribute* attribute;
+ TosaReference::TensorTemplate<TIn>* in;
+ TosaReference::TensorTemplate<TOut>* out;
+};
+
+template <DType InDtype, DType OutDtype>
+class CastHelper
+{
+public:
+ using InEigenType = typename GetEigenType<InDtype>::type;
+ using OutEigenType = typename GetEigenType<OutDtype>::type;
+ using FcnType = std::function<OutEigenType(InEigenType)>;
+ static constexpr int32_t OutBits = GetNumBits<OutDtype>::value;
+ CastHelper();
+ const FcnType& get_fcn() const
+ {
+ return fcn;
+ }
+
+private:
+ FcnType fcn;
+};
+
+template <DType InDtype>
+class CastHelper<InDtype, DType_BOOL>
+{
+public:
+ using InEigenType = typename GetEigenType<InDtype>::type;
+ using OutEigenType = typename GetEigenType<DType_BOOL>::type;
+ using FcnType = std::function<OutEigenType(InEigenType)>;
+ CastHelper();
+ const FcnType& get_fcn() const
+ {
+ return fcn;
+ }
+
+private:
+ FcnType fcn;
+};
+
+template <DType OutDtype>
+class CastHelper<DType_BOOL, OutDtype>
+{
+public:
+ using InEigenType = typename GetEigenType<DType_BOOL>::type;
+ using OutEigenType = typename GetEigenType<OutDtype>::type;
+ using FcnType = std::function<OutEigenType(InEigenType)>;
+ static constexpr int32_t OutMin = GetQMin<OutDtype>::value;
+ static constexpr int32_t OutMax = GetQMax<OutDtype>::value;
+ CastHelper();
+ const FcnType& get_fcn() const
+ {
+ return fcn;
+ }
+
+private:
+ FcnType fcn;
+};
+
+template <DType InDtype>
+class CastHelper<InDtype, DType_FLOAT>
+{
+public:
+ using InEigenType = typename GetEigenType<InDtype>::type;
+ using OutEigenType = typename GetEigenType<DType_FLOAT>::type;
+ using FcnType = std::function<OutEigenType(InEigenType)>;
+ CastHelper();
+ const FcnType& get_fcn() const
+ {
+ return fcn;
+ }
+
+private:
+ FcnType fcn;
+};
+
+template <DType OutDtype>
+class CastHelper<DType_FLOAT, OutDtype>
+{
+public:
+ using InEigenType = typename GetEigenType<DType_FLOAT>::type;
+ using OutEigenType = typename GetEigenType<OutDtype>::type;
+ using FcnType = std::function<OutEigenType(InEigenType)>;
+ static constexpr int32_t OutMin = GetQMin<OutDtype>::value;
+ static constexpr int32_t OutMax = GetQMax<OutDtype>::value;
+ CastHelper();
+ const FcnType& get_fcn() const
+ {
+ return fcn;
+ }
+
+private:
+ FcnType fcn;
+};
+
+template <int Rank, DType InDtype, DType OutDtype>
+class OpCast : public GraphNode
+{
+public:
+ OpCast(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
+ virtual ~OpCast();
+
+ virtual int checkTensorAttributes() final;
+ virtual int eval() final;
+
+ using InEigenType = typename GetEigenType<InDtype>::type;
+ using OutEigenType = typename GetEigenType<OutDtype>::type;
+ using TIn = Eigen::Tensor<InEigenType, Rank>;
+ using TOut = Eigen::Tensor<OutEigenType, Rank>;
+
+protected:
+ CastHelper<InDtype, OutDtype> cast_helper;
+ TosaReference::TensorTemplate<TIn>* in;
+ TosaReference::TensorTemplate<TOut>* out;
+};
+
+}; // namespace TosaReference
+
+#endif
diff --git a/reference_model/src/quant_util.h b/reference_model/src/quant_util.h
new file mode 100644
index 0000000..3638b3b
--- /dev/null
+++ b/reference_model/src/quant_util.h
@@ -0,0 +1,103 @@
+
+// Copyright (c) 2020, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef TOSA_REFERENCE_QUANT_UTIL_H
+#define TOSA_REFERENCE_QUANT_UTIL_H
+
+#include "arith_util.h"
+#include "func_debug.h"
+#include "ops/template_types.h"
+#include "tosa_generated.h"
+
+using namespace tosa;
+
+namespace TosaReference
+{
+
+template <DType AccDType>
+class QuantUtil
+{
+public:
+ using T = typename GetEigenType<AccDType>::type;
+
+ static void reciprocal_scale(int32_t value,
+ // Output
+ int32_t& multiplier,
+ int32_t& shift)
+ {
+ ASSERT_MSG(value > 0, "AvgPool2d reciprocal_scale() error: # of elements should be > 0 but is %d", value);
+ uint32_t value_u32 = (uint32_t)value;
+ int32_t k = 32 - LEADING_ZEROS_32(value_u32 - 1); // (1<<k)/2 < value <= (1<<k)
+ int64_t numerator = ((1L << 30) + 1) << k;
+ multiplier = numerator / value; // (1<<30) <= multiplier < (1<<31)
+ shift = 30 + k;
+ }
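+ // Worked example (value chosen for illustration): value = 9 gives k = 4
+ // (since 8 < 9 <= 16), multiplier = ((1 << 30) + 1) * 16 / 9 = 1908874355 and
+ // shift = 34, so apply_scale(sum, multiplier, shift) approximates sum / 9.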
+
+ static int32_t apply_scale(T value, int32_t multiplier, int32_t shift, bool enabled_adjusted_rounding = true)
+ {
+ if (AccDType == DType_FLOAT)
+ {
+ return value;
+ }
+ ASSERT_MSG(multiplier >= 0, "apply_scale() error: multiplier should >= 0 but is %d", multiplier);
+ int64_t round = (shift > 0) ? (1L << (shift - 1)) : 0;
+ if (enabled_adjusted_rounding)
+ {
+ if (AccDType != DType_INT48)
+ {
+ if (shift > 31 && value >= 0)
+ round += (1L << 30);
+ if (shift > 31 && value < 0)
+ round -= (1L << 30);
+ }
+ else
+ { // input data could be int16, which leads to a 48-bit accumulator
+ ASSERT_MSG(multiplier < (1 << 15), "apply_scale() error: multiplier should <= %d in 48 bit mode",
+ (1 << 15));
+ }
+ }
+ int64_t result = (int64_t)value * multiplier + round;
+ result = result >> shift;
+ ASSERT_MSG(result >= -(1L << 31) && result < (1L << 31),
+ "apply_scale() error: scaled result exceed int32 numeric range");
+ return static_cast<int32_t>(result);
+ }
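+ // apply_scale approximates value * multiplier / 2^shift with rounding. For example
+ // (values chosen for illustration), value = 101, multiplier = 1610612736
+ // (0.75 * 2^31) and shift = 31 give (101 * 1610612736 + 2^30) >> 31 = 76,
+ // i.e. approximately 101 * 0.75.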
+};
+
+class TypeChecker
+{
+public:
+ static bool is_integer(DType dtype)
+ {
+ if (dtype == DType_INT4 || dtype == DType_INT8 || dtype == DType_AINT8 || dtype == DType_UINT8 ||
+ dtype == DType_INT16 || dtype == DType_INT32 || dtype == DType_INT48)
+ {
+ return true;
+ }
+ return false;
+ }
+ static bool is_symmetric(DType dtype)
+ {
+ if (dtype == DType_INT4 || dtype == DType_INT8 || dtype == DType_INT16 || dtype == DType_INT32 ||
+ dtype == DType_INT48)
+ {
+ return true;
+ }
+ return false;
+ }
+};
+}; // namespace TosaReference
+
+#endif
diff --git a/reference_model/src/subgraph_traverser.cc b/reference_model/src/subgraph_traverser.cc
new file mode 100644
index 0000000..789bcae
--- /dev/null
+++ b/reference_model/src/subgraph_traverser.cc
@@ -0,0 +1,649 @@
+
+// Copyright (c) 2020, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "subgraph_traverser.h"
+
+using namespace TosaReference;
+using namespace Eigen;
+using namespace tosa;
+
+SubgraphTraverser::SubgraphTraverser(TosaSerializationBasicBlock* _block, TosaSerializationHandler* _tsh)
+{
+ block = _block;
+ tsh = _tsh;
+
+ tensors.clear();
+ nodes.clear();
+ nextNodeList.clear();
+}
+
+SubgraphTraverser::~SubgraphTraverser()
+{
+ nextNodeList.clear();
+
+ for (GraphNode* n : nodes)
+ {
+ delete n;
+ }
+ nodes.clear();
+
+ for (TosaReference::Tensor* t : tensors)
+ {
+ if (t->is_allocated())
+ {
+ t->deallocate();
+ }
+ delete t;
+ }
+ tensors.clear();
+}
+
+int SubgraphTraverser::getNumInputTensors() const
+{
+ return inputTensors.size();
+}
+
+TosaReference::Tensor* SubgraphTraverser::getInputTensor(const unsigned int idx) const
+{
+ return inputTensors[idx];
+}
+
+TosaReference::Tensor* SubgraphTraverser::getInputTensorByName(const std::string name) const
+{
+ for (auto t : inputTensors)
+ {
+ if (t->getName() == name)
+ {
+ return t;
+ }
+ }
+
+ return nullptr;
+}
+
+int SubgraphTraverser::getNumOutputTensors() const
+{
+ return outputTensors.size();
+}
+
+TosaReference::Tensor* SubgraphTraverser::getOutputTensor(const unsigned int idx) const
+{
+ return outputTensors[idx];
+}
+
+TosaReference::Tensor* SubgraphTraverser::getOutputTensorByName(const std::string name) const
+{
+ for (auto t : outputTensors)
+ {
+ if (t->getName() == name)
+ {
+ return t;
+ }
+ }
+
+ return nullptr;
+}
+
+int SubgraphTraverser::initializeGraph()
+{
+ char tensor_fullname[1000];
+ int idx = 0;
+ for (auto op : block->GetOperators())
+ {
+ // translate each TosaSerializationOperator into a GraphNode
+ DType in_dtype = DType_UNKNOWN, out_dtype = DType_UNKNOWN, weight_dtype = DType_UNKNOWN;
+ uint32_t in_rank = 0, out_rank = 0, weight_rank = 0;
+ for (auto name : op->GetInputTensorNames())
+ {
+
+ TosaSerializationTensor* ts = block->GetTensorByName(name);
+ ASSERT_MSG(ts, "SubgraphTraverser: fail to get tensor %s from TosaSerializationHandler", name.c_str());
+
+ if (ts->HasUsage(Usage_WEIGHT))
+ {
+ weight_dtype = ts->GetDtype();
+ weight_rank = ts->GetShape().size();
+ }
+ else if (ts->HasUsage(Usage_INDEX))
+ {
+ // do nothing; this prevents the tensor's dtype/rank from being wrongly used as a template argument when initializing this op
+ }
+ else if (ts->HasUsage(Usage_ACTIVATION))
+ {
+ if (ts->GetShape().size() >= in_rank)
+ {
+ in_dtype = ts->GetDtype();
+ in_rank = ts->GetShape().size();
+ }
+ }
+ }
+
+ for (auto name : op->GetOutputTensorNames())
+ {
+
+ TosaSerializationTensor* ts = block->GetTensorByName(name);
+ ASSERT_MSG(ts, "SubgraphTraverser: fail to get tensor %s from TosaSerializationHandler", name.c_str());
+
+ out_dtype = ts->GetDtype();
+ out_rank = ts->GetShape().size();
+ }
+
+ DEBUG_INFO(GT, "Creating operator id_%03u, %8s, %lu input tensors, %lu output tensors", idx,
+ EnumNamesOp()[op->GetOp()], op->GetInputTensorNames().size(), op->GetOutputTensorNames().size());
+
+ GraphNode* cn = OpFactory::newOp(tsh, op->GetOp(), op->GetAttribute(), op->GetQInfo(), idx, in_dtype, in_rank,
+ out_dtype, out_rank, weight_dtype, weight_rank);
+ if (!cn)
+ {
+ if (weight_dtype == DType_UNKNOWN && weight_rank == 0)
+ {
+ fprintf(g_func_debug.func_debug_file,
+ "OpFactory could not allocate op %8s input=(%s rank %d) -> (%s rank %d)",
+ EnumNamesOp()[op->GetOp()], EnumNamesDType()[in_dtype], in_rank, EnumNamesDType()[out_dtype],
+ out_rank);
+ }
+ else
+ {
+ fprintf(g_func_debug.func_debug_file,
+ "OpFactory could not allocate op %8s input=(%s rank %d), weight=(%s rank %d) -> (%s rank %d)",
+ EnumNamesOp()[op->GetOp()], EnumNamesDType()[in_dtype], in_rank, EnumNamesDType()[weight_dtype],
+ weight_rank, EnumNamesDType()[out_dtype], out_rank);
+ }
+
+ for (auto ts : op->GetInputTensors())
+ {
+ fprintf(g_func_debug.func_debug_file, "Input: %s\n", ts->GetName().c_str());
+ }
+
+ for (auto ts : op->GetOutputTensors())
+ {
+ fprintf(g_func_debug.func_debug_file, "Output: %s\n", ts->GetName().c_str());
+ }
+ FATAL_ERROR("Unsupported operation type or rank.");
+ }
+
+ for (auto name : op->GetInputTensorNames())
+ {
+ cn->addInputName(name);
+ }
+
+ for (auto name : op->GetOutputTensorNames())
+ {
+ cn->addOutputName(name);
+ }
+
+ addNode(cn);
+
+ // if node doesn't have any inputs (i.e. CONST)
+ // it should be ready for evaluation
+ if (op->GetInputTensorNames().empty() && !cn->getOnNextNodeList())
+ {
+ addToNextNodeList(cn);
+ }
+
+ idx++;
+ }
+
+ for (auto ts : block->GetTensors())
+ {
+
+ bool is_const = false;
+ if (ts->HasUsage(Usage_WEIGHT))
+ {
+ is_const = true;
+ }
+
+ DEBUG_INFO(GT, "Creating tensor %s", ts->GetName().c_str());
+ TosaReference::Tensor* ct =
+ TensorFactory::newTensor(ts->GetName(), ts->GetDtype(), ts->GetUsage(), ts->GetFormat(), ts->GetShape(),
+ is_const, ts->GetShape().size());
+
+ if (ts->GetNpyFilePtr())
+ {
+ if (ct->allocate())
+ {
+ FATAL_ERROR("Fail to allocate Eigen tensor %s", ct->getName().c_str());
+ }
+
+ bzero(tensor_fullname, sizeof(tensor_fullname));
+ snprintf(tensor_fullname, sizeof(tensor_fullname), "%s/%s", g_func_config.subgraph_dir,
+ ts->GetNpyFilePtr()->c_str());
+ if (ct->readFromNpyFile(tensor_fullname))
+ {
+ FATAL_ERROR("Cannot read input data into graph tensor %s from block %s", ct->getName().c_str(),
+ block->GetName().c_str());
+ }
+ }
+
+ // update this->tensors
+ addTensor(ct);
+ }
+
+ DEBUG_INFO(GT, "Enumerating block %s graph inputs", block->GetName().c_str());
+ for (auto& input_name : block->GetInputs())
+ {
+ TosaReference::Tensor* ct = findTensorByName(input_name);
+ DEBUG_INFO(GT, "input tensor name=%s", input_name.c_str());
+ if (ct)
+ {
+ ct->setIsSubgraphInput();
+ inputTensors.push_back(ct);
+ }
+ else
+ {
+ FATAL_ERROR("loadGraphJson: Fail to find input tensor by name %s", input_name.c_str());
+ }
+ }
+
+ DEBUG_INFO(GT, "Enumerating block %s graph outputs", block->GetName().c_str());
+ for (auto& output_name : block->GetOutputs())
+ {
+ TosaReference::Tensor* ct = findTensorByName(output_name);
+ DEBUG_INFO(GT, "output tensor name=%s\n", output_name.c_str());
+ if (ct)
+ {
+ ct->setIsSubgraphOutput();
+ outputTensors.push_back(ct);
+ }
+ else
+ {
+ FATAL_ERROR("loadGraphJson: Fail to find output tensor by name %s", output_name.c_str());
+ }
+ }
+
+ if (DEBUG_ENABLED(DEBUG_VERB_HIGH, GT))
+ {
+ dumpNextNodeList(g_func_debug.func_debug_file);
+ }
+
+ return 0;
+}
+
+int SubgraphTraverser::isFullyEvaluated() const
+{
+ return nextNodeList.empty();
+}
+
+GraphNode* SubgraphTraverser::getNextNode()
+{
+ GraphNode* nextNode = nextNodeList.front();
+ ASSERT_MSG(nextNode, "SubgraphTraverser::getNextNode(): called with empty next node list");
+ ASSERT_MSG(nextNode->getOnNextNodeList(),
+ "SubgraphTraverser::getNextNode(): internal state error: node is not listed as being on next node list");
+
+ nextNodeList.pop_front();
+
+ nextNode->clearOnNextNodeList();
+ return nextNode;
+}
+
+int SubgraphTraverser::addToNextNodeList(GraphNode* nextNode)
+{
+ ASSERT_MSG(nextNode, "SubgraphTraverser::addToNextNodeList(): called with no node");
+ ASSERT_MSG(!nextNode->getOnNextNodeList(),
+ "SubgraphTraverser::addToNextNodeList(): internal state error: node is already on next node list");
+
+ nextNode->setOnNextNodeList();
+ nextNodeList.push_back(nextNode);
+
+ return 0;
+}
+
+int SubgraphTraverser::evaluateNextNode()
+{
+ if (isFullyEvaluated())
+ return 0;
+
+ GraphNode* currNode = getNextNode();
+
+ DEBUG_INFO(GT, "Evaluating node_%03lu, %8s, output tensor=%s", currNode->getID(), EnumNamesOp()[currNode->getOp()],
+ currNode->getOutputNames()[0].c_str());
+
+ // Sanity check for never-ending loops
+ if (currNode->getEvalCount() >= MAX_EVAL_COUNT && (currNode->getEvalCount() % MAX_EVAL_COUNT) == 0)
+ {
+ WARNING("Node %lu has been evaluated %d times. Loop suspected.", currNode->getID(), currNode->getEvalCount());
+ }
+
+ for (auto ct : currNode->getOutputs())
+ {
+ if (!ct->is_allocated())
+ if (ct->allocate())
+ {
+ FATAL_ERROR("Fail to allocate Eigen tensor %s", ct->getName().c_str());
+ }
+ }
+
+ if (currNode->eval())
+ {
+ FATAL_ERROR("Error evaluating node: %lu\n", currNode->getID());
+ }
+
+ // free an input tensor once all of its consumers have all of their outputs ready and it is not a block output
+ for (auto ct : currNode->getInputs())
+ {
+ bool in_use = false;
+ for (auto cn : ct->getConsumers())
+ {
+ if (!cn->hasAllOutputsReady())
+ {
+ in_use = true;
+ }
+ }
+ for (auto name : block->GetOutputs())
+ {
+ if (name == ct->getName())
+ {
+ in_use = true;
+ }
+ }
+ if (!in_use)
+ {
+ ct->deallocate();
+ }
+ }
+
+ // Search the output tensors of this node to see whether completing it has
+ // made any new nodes ready for evaluation
+ for (TosaReference::Tensor* tensor : currNode->getOutputs())
+ {
+ for (GraphNode* node : tensor->getConsumers())
+ {
+ if (!node->getOnNextNodeList() && node->hasAllInputsReady())
+ {
+ addToNextNodeList(node);
+ }
+ }
+ }
+
+ if (DEBUG_ENABLED(DEBUG_VERB_HIGH, GT))
+ {
+ dumpNextNodeList(g_func_debug.func_debug_file);
+ }
+
+ if (g_func_config.dump_intermediates)
+ {
+ currNode->dumpNode(g_func_debug.func_debug_file);
+ for (auto outs : currNode->getOutputs())
+ {
+ outs->dumpTensorParams(g_func_debug.func_debug_file);
+ outs->dumpTensor(g_func_debug.func_debug_file);
+ fprintf(g_func_debug.func_debug_file, "\n");
+ }
+ }
+
+ return 0;
+}
+
+int SubgraphTraverser::dumpNextNodeList(FILE* out) const
+{
+
+ // Dump next node list
+ fprintf(out, "Next node list\n");
+
+ if (nextNodeList.empty())
+ {
+ fprintf(out, "<empty>\n");
+ }
+
+ for (auto gn : nextNodeList)
+ {
+ gn->dumpNode(out);
+ }
+
+ fprintf(out, "Done.\n");
+ return 0;
+}
+
+int SubgraphTraverser::clearAllNodeMarkings()
+{
+ for (GraphNode* currNode : nodes)
+ {
+ currNode->clearNodeMarked();
+ }
+
+ return false;
+}
+
+int SubgraphTraverser::addTensor(TosaReference::Tensor* ct)
+{
+ // Enforce no duplicate tensors/tensor names
+ // O(N), but the number of tensors is small
+ for (TosaReference::Tensor* currTensor : tensors)
+ {
+ if (ct == currTensor || currTensor->getName() == ct->getName())
+ {
+ FATAL_ERROR("Error: Duplicate tensor or tensor name being added to graph: %s\n", ct->getName().c_str());
+ return 1;
+ }
+ }
+
+ tensors.push_back(ct);
+
+ if (ct->getIsSubgraphInput())
+ {
+ inputTensors.push_back(ct);
+ }
+
+ if (ct->getIsSubgraphOutput())
+ {
+ outputTensors.push_back(ct);
+ }
+
+ return 0;
+}
+int SubgraphTraverser::addNode(GraphNode* newNode)
+{
+ // Enforce no duplicate nodes
+ for (GraphNode* currNode : nodes)
+ {
+ if (currNode == newNode)
+ {
+ FATAL_ERROR("Error: duplicate node being added to graph");
+ return 1;
+ }
+ }
+
+ nodes.push_back(newNode);
+
+ return 0;
+}
+
+TosaReference::Tensor* SubgraphTraverser::findTensorByName(const std::string& name) const
+{
+ for (TosaReference::Tensor* currTensor : tensors)
+ {
+ if (currTensor->getName() == name)
+ {
+ return currTensor;
+ }
+ }
+
+ WARNING("Unable to find tensor with name: %s\n", name.c_str());
+
+ return nullptr;
+}
+
+int SubgraphTraverser::linkTensorsAndNodes()
+{
+ // Nodes have a list of input/output tensor names
+ // For each node, read this list, link up the tensors with their inputs/outputs
+ for (GraphNode* currNode : nodes)
+ {
+
+ // Link inputs/consuming nodes
+ for (std::string& name : currNode->getInputNames())
+ {
+ TosaReference::Tensor* t = findTensorByName(name);
+ if (!t)
+ {
+ FATAL_ERROR("linkTensorsAndNodes: Cannot find tensor %s in node %lu\n", name.c_str(),
+ currNode->getID());
+ return 1;
+ }
+
+ if (currNode->addInputTensor(t))
+ {
+ FATAL_ERROR("linkTensorsAndNodes: cannot link tensor %s to node %lu\n", name.c_str(),
+ currNode->getID());
+ return 1;
+ }
+
+ if (t->addConsumer(currNode))
+ {
+ FATAL_ERROR("linkTensorsAndNodes: cannot link consumer node %lu to tensor %s\n", currNode->getID(),
+ name.c_str());
+ return 1;
+ }
+ }
+
+ // Link outputs/producing nodes
+ for (std::string& name : currNode->getOutputNames())
+ {
+ TosaReference::Tensor* t = findTensorByName(name);
+ if (!t)
+ {
+ FATAL_ERROR("linkTensorsAndNodes: Cannot find tensor %s in node %lu\n", name.c_str(),
+ currNode->getID());
+ return 1;
+ }
+
+ if (currNode->addOutputTensor(t))
+ {
+ FATAL_ERROR("linkTensorsAndNodes: cannot link tensor %s to node %lu\n", name.c_str(),
+ currNode->getID());
+ return 1;
+ }
+
+ if (t->setProducer(currNode))
+ {
+ FATAL_ERROR("linkTensorsAndNodes: cannot link producer node %lu to tensor tensor %s\n",
+ currNode->getID(), name.c_str());
+ return 1;
+ }
+ }
+ }
+
+ return 0;
+}
+
+int SubgraphTraverser::validateGraph()
+{
+ // Need to make sure that:
+ // - each tensor is actually used
+ // - input and output tensors truly are just input and just output
+ // Graph building already determined that each node has found its input/output tensors
+
+ for (TosaReference::Tensor* currTensor : tensors)
+ {
+
+ if (!currTensor->getProducer() && currTensor->getConsumers().empty())
+ {
+ WARNING("Graph inconsistency: TosaReference::Tensor %s has no producers or consumers\n",
+ currTensor->getName().c_str());
+ return 1;
+ }
+
+ if (currTensor->getIsSubgraphInput())
+ {
+ if (currTensor->getProducer() && currTensor->getProducer()->getOp() != Op_PLACEHOLDER)
+ {
+ WARNING("Graph inconsistency: TosaReference::Tensor %s is a subgraph input and has a producer\n",
+ currTensor->getName().c_str());
+ return 1;
+ }
+ }
+
+ // this check is commented out because it can legitimately happen when the graph has multiple outputs
+ // for example:
+ // %0 = add(%arg0, %arg1)
+ // %1 = mul(%arg0, %0)
+ // yields(%0, %1)
+ //if (currTensor->getIsSubgraphOutput()) {
+ // if (!currTensor->getConsumers().empty()) {
+ // WARNING ("Graph inconsistency: TosaReference::Tensor %s is a subgraph output and has a consumer\n",
+ // currTensor->getName().c_str());
+ // return 1;
+ // }
+ //}
+
+ if (g_func_config.tosa_profile == 0)
+ {
+ DType dtype = currTensor->getDtype();
+
+ // Floating-point disallowed
+ if (dtype == DType_FLOAT)
+ {
+ WARNING("TOSA Base Inference profile selected: All floating point disabled, but %s tensor %s found\n",
+ EnumNamesDType()[dtype], currTensor->getName().c_str());
+ return 1;
+ }
+ }
+ else if (g_func_config.tosa_profile == 1 || g_func_config.tosa_profile == 2)
+ {
+ // Do nothing. All FP types allowed
+ // Currently no implementation difference between Main Inference and Main Training modes
+ }
+ else
+ {
+ FATAL_ERROR("TOSA profile not recognized: %d", g_func_config.tosa_profile);
+ }
+ }
+
+ for (GraphNode* currNode : nodes)
+ {
+ if (currNode->checkTensorAttributes())
+ {
+ WARNING("TosaReference::Tensor attribute check failed");
+ return 1;
+ }
+ }
+
+ if (outputTensors.size() <= 0)
+ {
+ DEBUG_MED(GT, "Graph output tensor empty");
+ return 0;
+ }
+
+ return 0;
+}
+
+int SubgraphTraverser::dumpGraph(FILE* out) const
+{
+ int i = 0;
+
+ fprintf(out, "Full graph dump:\n");
+ for (GraphNode* currNode : nodes)
+ {
+ fprintf(out, "Node [%d]: ", i++);
+ currNode->dumpNode(out);
+ }
+
+ return 0;
+}
+
+int SubgraphTraverser::evaluateAll()
+{
+ // evaluation loop
+ while (!isFullyEvaluated())
+ {
+ if (evaluateNextNode())
+ {
+ return 1;
+ }
+ }
+
+ return 0;
+}
diff --git a/reference_model/src/subgraph_traverser.h b/reference_model/src/subgraph_traverser.h
new file mode 100644
index 0000000..3f4eecf
--- /dev/null
+++ b/reference_model/src/subgraph_traverser.h
@@ -0,0 +1,90 @@
+
+// Copyright (c) 2020, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SUBGRAPH_TRAVERSER_H
+#define SUBGRAPH_TRAVERSER_H
+
+#include "model_common.h"
+
+#include "graph_node.h"
+#include "ops/op_factory.h"
+#include "tosa_serialization_handler.h"
+
+namespace TosaReference
+{
+
+class SubgraphTraverser
+{
+public:
+ SubgraphTraverser(TosaSerializationBasicBlock* block, TosaSerializationHandler* tsh);
+ ~SubgraphTraverser();
+
+ int initializeGraph();
+ int isFullyEvaluated() const;
+ int evaluateNextNode();
+ int evaluateAll();
+
+ int linkTensorsAndNodes();
+ int validateGraph();
+
+ int dumpGraph(FILE* out) const;
+ int dumpNextNodeList(FILE* out) const;
+ int clearAllNodeMarkings();
+
+ int getNumInputTensors() const;
+ Tensor* getInputTensor(const unsigned int idx) const;
+ Tensor* getInputTensorByName(const std::string name) const;
+ int getNumOutputTensors() const;
+ Tensor* getOutputTensor(const unsigned int idx) const;
+ Tensor* getOutputTensorByName(const std::string name) const;
+ int addToNextNodeList(GraphNode*);
+
+private:
+ int addTensor(Tensor* ct);
+ int addNode(GraphNode* cn);
+
+ Tensor* findTensorByName(const std::string& name) const;
+
+ GraphNode* getNextNode();
+
+ // pointer to serialization library and corresponding basic block
+ TosaSerializationBasicBlock* block;
+ TosaSerializationHandler* tsh;
+
+ // The definitive list of all tensors
+ std::vector<Tensor*> tensors;
+
+ // The subset of tensors that are also input tensors
+ std::vector<Tensor*> inputTensors;
+
+ // The subset of tensors that are also output tensors
+ std::vector<Tensor*> outputTensors;
+
+ // The definitive list of all nodes in the graph
+ std::vector<GraphNode*> nodes;
+
+ // The subset of nodes that have all of their input tensors ready, but
+ // have not yet been evaluated to produce their output tensors.
+ // With control flow, a node may appear on this list more than once during its
+ // lifetime, although the list itself should only contain unique nodes.
+ std::list<GraphNode*> nextNodeList;
+
+ // Maximum number of times to evaluate a node before warning.
+ const int MAX_EVAL_COUNT = 10000;
+};
+}; // namespace TosaReference
+
+#endif
diff --git a/reference_model/src/tensor.cc b/reference_model/src/tensor.cc
new file mode 100644
index 0000000..179484e
--- /dev/null
+++ b/reference_model/src/tensor.cc
@@ -0,0 +1,3008 @@
+
+// Copyright (c) 2020, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "tensor.h"
+#include "arith_util.h"
+
+using namespace TosaReference;
+using namespace Eigen;
+using namespace tosa;
+
+TosaReference::Tensor::Tensor(std::string tensorName_,
+ DType tensorDtype_,
+ const std::vector<Usage>& tensorUsage_,
+ const std::vector<Format>& tensorFormat_,
+ std::vector<int> shape_,
+ int isConst_)
+{
+ tensorName = std::string(tensorName_);
+ tensorDtype = tensorDtype_;
+ tensorUsage = std::vector<Usage>(tensorUsage_);
+ tensorFormat = std::vector<Format>(tensorFormat_);
+ shape = std::vector<int>(shape_);
+ isConst = isConst_;
+ producer = nullptr;
+ isValid = false;
+ consumers.clear();
+ isSubgraphInput = false;
+ isSubgraphOutput = false;
+}
+
+TosaReference::Tensor::~Tensor()
+{}
+
+int TosaReference::Tensor::setIsSubgraphInput()
+{
+ isSubgraphInput = true;
+ return 0;
+}
+
+int TosaReference::Tensor::setIsSubgraphOutput()
+{
+ isSubgraphOutput = true;
+ return 0;
+}
+
+int TosaReference::Tensor::setProducer(GraphNode* node)
+{
+ ASSERT_MSG(node, "Tensor::setProducer: no node passed in");
+ ASSERT_MSG(!producer, "Tensor::setProducer: producer node already set, tensor %s", tensorName.c_str());
+ producer = node;
+
+ return 0;
+}
+
+int TosaReference::Tensor::addConsumer(GraphNode* node)
+{
+ ASSERT_MSG(node, "Tensor::addConsumer: no node passed in");
+ consumers.push_back(node);
+
+ return 0;
+}
+
+int TosaReference::Tensor::dumpTensorParams(FILE* out) const
+{
+ fprintf(out, "Name: %s DType=%s Usage=%s isValid=%d Rank=%d Shape=%s\n", tensorName.c_str(),
+ EnumNamesDType()[getDtype()], getUsageAsString().c_str(), getIsValid(), getRank(),
+ getShapeAsString().c_str());
+
+ return 0;
+}
+
+int TosaReference::Tensor::dumpTensorParams(std::ostream& out) const
+{
+ out << "Name: " << getName() << " DType=" << EnumNamesDType()[getDtype()] << " Usage=" << getUsageAsString()
+ << " isValid=" << getIsValid() << " Rank=" << getRank() << " Shape=" << getShapeAsString() << "\n";
+
+ return 0;
+}
+
+int TosaReference::Tensor::readFromNpyFile(const char* filename)
+{
+ uint32_t elements = getElementCount();
+ float* fdatabuf = nullptr;
+ int32_t* i32databuf = nullptr;
+ int64_t* i64databuf = nullptr;
+ bool* bdatabuf = nullptr;
+ NumpyUtilities::NPError nperror;
+
+ switch (getDtype())
+ {
+ case DType_FLOAT:
+ fdatabuf = (float*)calloc(sizeof(float), elements);
+ ASSERT_MEM(fdatabuf);
+
+ nperror = NumpyUtilities::readFromNpyFile(filename, elements, fdatabuf);
+ break;
+ case DType_INT32:
+ case DType_AINT8:
+ case DType_UINT8:
+ case DType_INT4:
+ case DType_INT8:
+ case DType_INT16:
+ i32databuf = (int32_t*)calloc(sizeof(int32_t), elements);
+ ASSERT_MEM(i32databuf);
+
+ nperror = NumpyUtilities::readFromNpyFile(filename, elements, i32databuf);
+ break;
+ case DType_INT48:
+ i64databuf = (int64_t*)calloc(sizeof(int64_t), elements);
+ ASSERT_MEM(i64databuf);
+
+ nperror = NumpyUtilities::readFromNpyFile(filename, elements, i64databuf);
+ break;
+ case DType_BOOL:
+ bdatabuf = (bool*)calloc(sizeof(bool), elements);
+ ASSERT_MEM(bdatabuf);
+
+ nperror = NumpyUtilities::readFromNpyFile(filename, elements, bdatabuf);
+ break;
+ default:
+ FATAL_ERROR("unsupported tensor type=%s", EnumNamesDType()[getDtype()]);
+ }
+
+ switch (nperror)
+ {
+ case NumpyUtilities::NO_ERROR:
+ break;
+ case NumpyUtilities::FILE_NOT_FOUND:
+ FATAL_ERROR("readFromNpyFile: Cannot open file %s", filename);
+ case NumpyUtilities::FILE_IO_ERROR:
+ FATAL_ERROR("readFromNpyFile: IO error reading file: %s", filename);
+ case NumpyUtilities::FILE_TYPE_MISMATCH:
+ FATAL_ERROR("readFromNpyFile: Tensor type %s and Numpy file type mismatch for tensor %s filename %s",
+ EnumNamesDType()[getDtype()], getName().c_str(), filename);
+ case NumpyUtilities::HEADER_PARSE_ERROR:
+ FATAL_ERROR("Numpy header parsing error for file: %s", filename);
+ case NumpyUtilities::BUFFER_SIZE_MISMATCH:
+ FATAL_ERROR("Buffer size does not match numpy file size for tensor %s filename %s", getName().c_str(),
+ filename);
+ default:
+ FATAL_ERROR("Unknown error parsing Numpy file: %s", filename);
+ }
+
+ switch (getDtype())
+ {
+ case DType_FLOAT:
+ if (setTensorValueFloat(elements, fdatabuf))
+ {
+ free(fdatabuf);
+ return 1;
+ }
+ break;
+ case DType_INT32:
+ case DType_AINT8:
+ case DType_UINT8:
+ case DType_INT4:
+ case DType_INT8:
+ case DType_INT16:
+ if (setTensorValueInt32(elements, i32databuf))
+ {
+ free(i32databuf);
+ return 1;
+ }
+ break;
+ case DType_INT48:
+ if (setTensorValueInt64(elements, i64databuf))
+ {
+ free(i64databuf);
+ return 1;
+ }
+ break;
+ case DType_BOOL:
+ if (setTensorValueBool(elements, bdatabuf))
+ {
+ free(bdatabuf);
+ return 1;
+ }
+ break;
+ default:
+ FATAL_ERROR("unsupported tensor type=%s", EnumNamesDType()[getDtype()]);
+ }
+
+ setIsValid();
+
+ if (fdatabuf)
+ free(fdatabuf);
+ if (i32databuf)
+ free(i32databuf);
+ if (i64databuf)
+ free(i64databuf);
+ if (bdatabuf)
+ free(bdatabuf);
+
+ return 0;
+}
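A minimal sketch of how this reader pairs with the traverser from subgraph_traverser.h (file and tensor names are illustrative; whether the input tensor is allocated here or during graph initialization is an assumption):

#include "subgraph_traverser.h"

// Hypothetical flow, illustrative names: feed a named graph input from a
// .npy file, evaluate the graph, then write the produced output back out.
static int runExample(TosaReference::SubgraphTraverser& gt)
{
    TosaReference::Tensor* in = gt.getInputTensorByName("placeholder_0");
    // allocate() is assumed to be callable through the base Tensor class here;
    // it may instead happen inside graph initialization.
    if (!in || in->allocate() || in->readFromNpyFile("placeholder_0.npy"))
        return 1;

    if (gt.evaluateAll())
        return 1;

    TosaReference::Tensor* out = gt.getOutputTensorByName("ref_ofm");
    return (!out || out->writeToNpyFile("ref_ofm.npy")) ? 1 : 0;
}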
+
+int TosaReference::Tensor::writeToNpyFile(const char* filename) const
+{
+ float* fdatabuf = nullptr;
+ int32_t* i32databuf = nullptr;
+ int64_t* i64databuf = nullptr;
+ bool* bdatabuf = nullptr;
+ NumpyUtilities::NPError nperror;
+ int elements = getElementCount();
+
+ switch (getDtype())
+ {
+ case DType_FLOAT:
+ fdatabuf = (float*)calloc(sizeof(float), elements);
+ ASSERT_MEM(fdatabuf);
+
+ if (getTensorValueFloat(elements, fdatabuf))
+ {
+ free(fdatabuf);
+ return 1;
+ }
+
+ nperror = NumpyUtilities::writeToNpyFile(filename, shape, fdatabuf);
+
+ free(fdatabuf);
+ break;
+ case DType_INT32:
+ case DType_AINT8:
+ case DType_UINT8:
+ case DType_INT4:
+ case DType_INT8:
+ case DType_INT16:
+ i32databuf = (int32_t*)calloc(sizeof(int32_t), elements);
+ ASSERT_MEM(i32databuf);
+
+ if (getTensorValueInt32(elements, i32databuf))
+ {
+ free(i32databuf);
+ return 1;
+ }
+
+ nperror = NumpyUtilities::writeToNpyFile(filename, shape, i32databuf);
+
+ free(i32databuf);
+ break;
+ case DType_INT48:
+ i64databuf = (int64_t*)calloc(sizeof(int64_t), elements);
+ ASSERT_MEM(i64databuf);
+
+ if (getTensorValueInt64(elements, i64databuf))
+ {
+ free(i64databuf);
+ return 1;
+ }
+
+ nperror = NumpyUtilities::writeToNpyFile(filename, shape, i64databuf);
+
+ free(i64databuf);
+ break;
+ case DType_BOOL:
+ bdatabuf = (bool*)calloc(sizeof(bool), elements);
+ ASSERT_MEM(bdatabuf);
+
+ if (getTensorValueBool(elements, bdatabuf))
+ {
+ free(bdatabuf);
+ return 1;
+ }
+
+ nperror = NumpyUtilities::writeToNpyFile(filename, shape, bdatabuf);
+
+ free(bdatabuf);
+ break;
+ default:
+ FATAL_ERROR("unsupported tensor type=%s", EnumNamesDType()[getDtype()]);
+ }
+
+ switch (nperror)
+ {
+ case NumpyUtilities::NO_ERROR:
+ break;
+ case NumpyUtilities::FILE_NOT_FOUND:
+ FATAL_ERROR("writeToNpyFile: Cannot open output file %s", filename);
+ case NumpyUtilities::FILE_IO_ERROR:
+ FATAL_ERROR("writeToNpyFile: IO error writing file: %s", filename);
+ case NumpyUtilities::FILE_TYPE_MISMATCH:
+ FATAL_ERROR("writeToNpyFile: Tensor type and Numpy file type mismatch for tensor %s filename %s",
+ getName().c_str(), filename);
+ case NumpyUtilities::HEADER_PARSE_ERROR:
+ FATAL_ERROR("Numpy header parsing error for file: %s", filename);
+ case NumpyUtilities::BUFFER_SIZE_MISMATCH:
+ FATAL_ERROR("Buffer size does not match numpy file size for tensor %s filename %s", getName().c_str(),
+ filename);
+ default:
+ FATAL_ERROR("Unknown error writing Numpy file: %s", filename);
+ }
+
+ return 0;
+}
+
+template <class T>
+int TosaReference::TensorTemplate<T>::copyValueFrom(TosaReference::Tensor* src)
+{
+ FATAL_ERROR("TensorTemplate<T>::copyValueFrom should not be called. "
+ "Implement template specialization version.");
+ return 0;
+}
+
+#define DEF_CTENSOR_COPY_VALUE_FROM(RANK, TYPE) \
+ template <> \
+ int TosaReference::Tensor##RANK<TYPE>::copyValueFrom(TosaReference::Tensor* src) \
+ { \
+ TosaReference::Tensor##RANK<TYPE>* t = dynamic_cast<Tensor##RANK<TYPE>*>(src); \
+ if (!t) \
+ { \
+ WARNING("tensor %s templated class does not match %s", src->getName().c_str(), this->getName().c_str()); \
+ return 1; \
+ } \
+ \
+ uint32_t src_rank = src->getRank(); \
+ uint32_t dst_rank = this->getRank(); \
+ DType src_dtype = src->getDtype(); \
+ DType dst_dtype = this->getDtype(); \
+ bool tensor_match = true; \
+ \
+ if ((src_rank != dst_rank) || (src_dtype != dst_dtype)) \
+ { \
+ tensor_match = false; \
+ } \
+ else \
+ { \
+ for (uint32_t i = 0; i < src_rank; i++) \
+ { \
+ int src_dim = src->getShape()[i]; \
+ int dst_dim = this->getShape()[i]; \
+ if (src_dim != dst_dim) \
+ { \
+ tensor_match = false; \
+ } \
+ } \
+ } \
+ \
+ if (!tensor_match) \
+ { \
+ WARNING("source tensor %s (rank=%u, dtype=%s, shape=%s) doesn't match destination tensor %s (rank=%u, " \
+ "dtype=%s, shape=%s)", \
+ src->getName().c_str(), src_rank, EnumNamesDType()[src_dtype], src->getShapeAsString().c_str(), \
+ this->getName().c_str(), dst_rank, EnumNamesDType()[dst_dtype], this->getShapeAsString().c_str()); \
+ return 1; \
+ } \
+ \
+ this->getTensor() = t->getTensor(); \
+ return 0; \
+ }
+
+DEF_CTENSOR_COPY_VALUE_FROM(0, float)
+DEF_CTENSOR_COPY_VALUE_FROM(1, float)
+DEF_CTENSOR_COPY_VALUE_FROM(2, float)
+DEF_CTENSOR_COPY_VALUE_FROM(3, float)
+DEF_CTENSOR_COPY_VALUE_FROM(4, float)
+DEF_CTENSOR_COPY_VALUE_FROM(5, float)
+DEF_CTENSOR_COPY_VALUE_FROM(6, float)
+DEF_CTENSOR_COPY_VALUE_FROM(0, int32_t)
+DEF_CTENSOR_COPY_VALUE_FROM(1, int32_t)
+DEF_CTENSOR_COPY_VALUE_FROM(2, int32_t)
+DEF_CTENSOR_COPY_VALUE_FROM(3, int32_t)
+DEF_CTENSOR_COPY_VALUE_FROM(4, int32_t)
+DEF_CTENSOR_COPY_VALUE_FROM(5, int32_t)
+DEF_CTENSOR_COPY_VALUE_FROM(6, int32_t)
+DEF_CTENSOR_COPY_VALUE_FROM(0, int64_t)
+DEF_CTENSOR_COPY_VALUE_FROM(1, int64_t)
+DEF_CTENSOR_COPY_VALUE_FROM(2, int64_t)
+DEF_CTENSOR_COPY_VALUE_FROM(3, int64_t)
+DEF_CTENSOR_COPY_VALUE_FROM(4, int64_t)
+DEF_CTENSOR_COPY_VALUE_FROM(5, int64_t)
+DEF_CTENSOR_COPY_VALUE_FROM(6, int64_t)
+DEF_CTENSOR_COPY_VALUE_FROM(0, bool)
+DEF_CTENSOR_COPY_VALUE_FROM(1, bool)
+DEF_CTENSOR_COPY_VALUE_FROM(2, bool)
+DEF_CTENSOR_COPY_VALUE_FROM(3, bool)
+DEF_CTENSOR_COPY_VALUE_FROM(4, bool)
+DEF_CTENSOR_COPY_VALUE_FROM(5, bool)
+DEF_CTENSOR_COPY_VALUE_FROM(6, bool)
+
+#undef DEF_CTENSOR_COPY_VALUE_FROM
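The macro above stamps out a copyValueFrom() specialization for every rank (0-6) and element type, so copies can be requested through base-class pointers. A minimal usage sketch, assuming copyValueFrom() is reachable through the base Tensor class as the type-erased signature suggests:

// Copies src into dst when rank, dtype and shape all match; the generated
// specialization warns and returns 1 on any mismatch.
static int copyTensor(TosaReference::Tensor* dst, TosaReference::Tensor* src)
{
    return dst->copyValueFrom(src);
}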
+
+template <class T>
+int TosaReference::TensorTemplate<T>::setTensorValueFloat(const size_t bufLen, const float* vals)
+{
+ FATAL_ERROR("TensorTemplate<T>::setTensorValueFloat should not be called. "
+ "Implement template specialization version.");
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor0<float>::setTensorValueFloat(const size_t bufLen, const float* vals)
+{
+ ASSERT_MSG(bufLen == getElementCount(), "Total elements must match");
+
+ (*tensor)(0) = vals[0];
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor1<float>::setTensorValueFloat(const size_t bufLen, const float* vals)
+{
+ uint32_t idx = 0;
+
+ ASSERT_MSG(bufLen == getElementCount(), "Total elements must match");
+
+ for (int i0 = 0; i0 < shape[0]; i0++)
+ {
+ (*tensor)(i0) = vals[idx++];
+ }
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor2<float>::setTensorValueFloat(const size_t bufLen, const float* vals)
+{
+ uint32_t idx = 0;
+
+ ASSERT_MSG(bufLen == getElementCount(), "Total elements must match");
+
+ for (int i0 = 0; i0 < shape[0]; i0++)
+ {
+ for (int i1 = 0; i1 < shape[1]; i1++)
+ {
+ (*tensor)(i0, i1) = vals[idx++];
+ }
+ }
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor3<float>::setTensorValueFloat(const size_t bufLen, const float* vals)
+{
+ uint32_t idx = 0;
+
+ ASSERT_MSG(bufLen == getElementCount(), "Total elements must match");
+
+ for (int i0 = 0; i0 < shape[0]; i0++)
+ {
+ for (int i1 = 0; i1 < shape[1]; i1++)
+ {
+ for (int i2 = 0; i2 < shape[2]; i2++)
+ {
+ (*tensor)(i0, i1, i2) = vals[idx++];
+ }
+ }
+ }
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor4<float>::setTensorValueFloat(const size_t bufLen, const float* vals)
+{
+ uint32_t idx = 0;
+
+ ASSERT_MSG(bufLen == getElementCount(), "Total elements must match");
+
+ for (int i0 = 0; i0 < shape[0]; i0++)
+ {
+ for (int i1 = 0; i1 < shape[1]; i1++)
+ {
+ for (int i2 = 0; i2 < shape[2]; i2++)
+ {
+ for (int i3 = 0; i3 < shape[3]; i3++)
+ {
+ (*tensor)(i0, i1, i2, i3) = vals[idx++];
+ }
+ }
+ }
+ }
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor5<float>::setTensorValueFloat(const size_t bufLen, const float* vals)
+{
+ uint32_t idx = 0;
+
+ ASSERT_MSG(bufLen == getElementCount(), "Total elements must match");
+
+ for (int i0 = 0; i0 < shape[0]; i0++)
+ {
+ for (int i1 = 0; i1 < shape[1]; i1++)
+ {
+ for (int i2 = 0; i2 < shape[2]; i2++)
+ {
+ for (int i3 = 0; i3 < shape[3]; i3++)
+ {
+ for (int i4 = 0; i4 < shape[4]; i4++)
+ {
+ (*tensor)(i0, i1, i2, i3, i4) = vals[idx++];
+ }
+ }
+ }
+ }
+ }
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor6<float>::setTensorValueFloat(const size_t bufLen, const float* vals)
+{
+ uint32_t idx = 0;
+
+ ASSERT_MSG(bufLen == getElementCount(), "Total elements must match");
+
+ for (int i0 = 0; i0 < shape[0]; i0++)
+ {
+ for (int i1 = 0; i1 < shape[1]; i1++)
+ {
+ for (int i2 = 0; i2 < shape[2]; i2++)
+ {
+ for (int i3 = 0; i3 < shape[3]; i3++)
+ {
+ for (int i4 = 0; i4 < shape[4]; i4++)
+ {
+ for (int i5 = 0; i5 < shape[5]; i5++)
+ {
+ (*tensor)(i0, i1, i2, i3, i4, i5) = vals[idx++];
+ }
+ }
+ }
+ }
+ }
+ }
+ return 0;
+}
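The nested loops in these specializations advance the innermost index fastest, so vals[] is consumed in row-major (C) order, matching the flat layout of the .npy buffers read above. For the rank-4 case, the implied flat index is, as a sketch:

#include <cstddef>

// Row-major flattening performed by the nested loops above (rank-4 case):
// idx advances by 1 as the innermost index i3 advances by 1.
static std::size_t flatIndex4(int d1, int d2, int d3, int i0, int i1, int i2, int i3)
{
    return ((static_cast<std::size_t>(i0) * d1 + i1) * d2 + i2) * d3 + i3;
}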
+
+template <class T>
+int TosaReference::TensorTemplate<T>::setTensorValueInt32(const size_t bufLen, const int32_t* vals)
+{
+ FATAL_ERROR("TensorTemplate<T>::setTensorValueInt32 should not be called. "
+ "Implement template specialization version.");
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor0<int32_t>::setTensorValueInt32(const size_t bufLen, const int32_t* vals)
+{
+ ASSERT_MSG(bufLen == getElementCount(), "Total elements must match");
+
+ (*tensor)(0) = vals[0];
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor1<int32_t>::setTensorValueInt32(const size_t bufLen, const int32_t* vals)
+{
+ uint32_t idx = 0;
+
+ ASSERT_MSG(bufLen == getElementCount(), "Total elements must match");
+
+ for (int i0 = 0; i0 < shape[0]; i0++)
+ {
+ (*tensor)(i0) = vals[idx++];
+ }
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor2<int32_t>::setTensorValueInt32(const size_t bufLen, const int32_t* vals)
+{
+ uint32_t idx = 0;
+
+ ASSERT_MSG(bufLen == getElementCount(), "Total elements must match");
+
+ for (int i0 = 0; i0 < shape[0]; i0++)
+ {
+ for (int i1 = 0; i1 < shape[1]; i1++)
+ {
+ (*tensor)(i0, i1) = vals[idx++];
+ }
+ }
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor3<int32_t>::setTensorValueInt32(const size_t bufLen, const int32_t* vals)
+{
+ uint32_t idx = 0;
+
+ ASSERT_MSG(bufLen == getElementCount(), "Total elements must match");
+
+ for (int i0 = 0; i0 < shape[0]; i0++)
+ {
+ for (int i1 = 0; i1 < shape[1]; i1++)
+ {
+ for (int i2 = 0; i2 < shape[2]; i2++)
+ {
+ (*tensor)(i0, i1, i2) = vals[idx++];
+ }
+ }
+ }
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor4<int32_t>::setTensorValueInt32(const size_t bufLen, const int32_t* vals)
+{
+ uint32_t idx = 0;
+
+ ASSERT_MSG(bufLen == getElementCount(), "Total elements must match");
+
+ for (int i0 = 0; i0 < shape[0]; i0++)
+ {
+ for (int i1 = 0; i1 < shape[1]; i1++)
+ {
+ for (int i2 = 0; i2 < shape[2]; i2++)
+ {
+ for (int i3 = 0; i3 < shape[3]; i3++)
+ {
+ (*tensor)(i0, i1, i2, i3) = vals[idx++];
+ }
+ }
+ }
+ }
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor5<int32_t>::setTensorValueInt32(const size_t bufLen, const int32_t* vals)
+{
+ uint32_t idx = 0;
+
+ ASSERT_MSG(bufLen == getElementCount(), "Total elements must match");
+
+ for (int i0 = 0; i0 < shape[0]; i0++)
+ {
+ for (int i1 = 0; i1 < shape[1]; i1++)
+ {
+ for (int i2 = 0; i2 < shape[2]; i2++)
+ {
+ for (int i3 = 0; i3 < shape[3]; i3++)
+ {
+ for (int i4 = 0; i4 < shape[4]; i4++)
+ {
+ (*tensor)(i0, i1, i2, i3, i4) = vals[idx++];
+ }
+ }
+ }
+ }
+ }
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor6<int32_t>::setTensorValueInt32(const size_t bufLen, const int32_t* vals)
+{
+ uint32_t idx = 0;
+
+ ASSERT_MSG(bufLen == getElementCount(), "Total elements must match");
+
+ for (int i0 = 0; i0 < shape[0]; i0++)
+ {
+ for (int i1 = 0; i1 < shape[1]; i1++)
+ {
+ for (int i2 = 0; i2 < shape[2]; i2++)
+ {
+ for (int i3 = 0; i3 < shape[3]; i3++)
+ {
+ for (int i4 = 0; i4 < shape[4]; i4++)
+ {
+ for (int i5 = 0; i5 < shape[5]; i5++)
+ {
+ (*tensor)(i0, i1, i2, i3, i4, i5) = vals[idx++];
+ }
+ }
+ }
+ }
+ }
+ }
+ return 0;
+}
+
+template <class T>
+int TosaReference::TensorTemplate<T>::setTensorValueInt64(const size_t bufLen, const int64_t* vals)
+{
+ FATAL_ERROR("TensorTemplate<T>::setTensorValueInt64 should not be called. "
+ "Implement template specialization version.");
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor0<int64_t>::setTensorValueInt64(const size_t bufLen, const int64_t* vals)
+{
+ ASSERT_MSG(bufLen == getElementCount(), "Total elements must match");
+
+ (*tensor)(0) = vals[0];
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor1<int64_t>::setTensorValueInt64(const size_t bufLen, const int64_t* vals)
+{
+ uint32_t idx = 0;
+
+ ASSERT_MSG(bufLen == getElementCount(), "Total elements must match");
+
+ for (int i0 = 0; i0 < shape[0]; i0++)
+ {
+ (*tensor)(i0) = vals[idx++];
+ }
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor2<int64_t>::setTensorValueInt64(const size_t bufLen, const int64_t* vals)
+{
+ uint32_t idx = 0;
+
+ ASSERT_MSG(bufLen == getElementCount(), "Total elements must match");
+
+ for (int i0 = 0; i0 < shape[0]; i0++)
+ {
+ for (int i1 = 0; i1 < shape[1]; i1++)
+ {
+ (*tensor)(i0, i1) = vals[idx++];
+ }
+ }
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor3<int64_t>::setTensorValueInt64(const size_t bufLen, const int64_t* vals)
+{
+ uint32_t idx = 0;
+
+ ASSERT_MSG(bufLen == getElementCount(), "Total elements must match");
+
+ for (int i0 = 0; i0 < shape[0]; i0++)
+ {
+ for (int i1 = 0; i1 < shape[1]; i1++)
+ {
+ for (int i2 = 0; i2 < shape[2]; i2++)
+ {
+ (*tensor)(i0, i1, i2) = vals[idx++];
+ }
+ }
+ }
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor4<int64_t>::setTensorValueInt64(const size_t bufLen, const int64_t* vals)
+{
+ uint32_t idx = 0;
+
+ ASSERT_MSG(bufLen == getElementCount(), "Total elements must match");
+
+ for (int i0 = 0; i0 < shape[0]; i0++)
+ {
+ for (int i1 = 0; i1 < shape[1]; i1++)
+ {
+ for (int i2 = 0; i2 < shape[2]; i2++)
+ {
+ for (int i3 = 0; i3 < shape[3]; i3++)
+ {
+ (*tensor)(i0, i1, i2, i3) = vals[idx++];
+ }
+ }
+ }
+ }
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor5<int64_t>::setTensorValueInt64(const size_t bufLen, const int64_t* vals)
+{
+ uint32_t idx = 0;
+
+ ASSERT_MSG(bufLen == getElementCount(), "Total elements must match");
+
+ for (int i0 = 0; i0 < shape[0]; i0++)
+ {
+ for (int i1 = 0; i1 < shape[1]; i1++)
+ {
+ for (int i2 = 0; i2 < shape[2]; i2++)
+ {
+ for (int i3 = 0; i3 < shape[3]; i3++)
+ {
+ for (int i4 = 0; i4 < shape[4]; i4++)
+ {
+ (*tensor)(i0, i1, i2, i3, i4) = vals[idx++];
+ }
+ }
+ }
+ }
+ }
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor6<int64_t>::setTensorValueInt64(const size_t bufLen, const int64_t* vals)
+{
+ uint32_t idx = 0;
+
+ ASSERT_MSG(bufLen == getElementCount(), "Total elements must match");
+
+ for (int i0 = 0; i0 < shape[0]; i0++)
+ {
+ for (int i1 = 0; i1 < shape[1]; i1++)
+ {
+ for (int i2 = 0; i2 < shape[2]; i2++)
+ {
+ for (int i3 = 0; i3 < shape[3]; i3++)
+ {
+ for (int i4 = 0; i4 < shape[4]; i4++)
+ {
+ for (int i5 = 0; i5 < shape[5]; i5++)
+ {
+ (*tensor)(i0, i1, i2, i3, i4, i5) = vals[idx++];
+ }
+ }
+ }
+ }
+ }
+ }
+ return 0;
+}
+
+template <class T>
+int TosaReference::TensorTemplate<T>::setTensorValueBool(const size_t bufLen, const bool* vals)
+{
+ FATAL_ERROR("TensorTemplate<T>::setTensorValueBool should not be called. "
+ "Implement template specialization version.");
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor0<bool>::setTensorValueBool(const size_t bufLen, const bool* vals)
+{
+ ASSERT_MSG(bufLen == getElementCount(), "Total elements must match");
+
+ (*tensor)(0) = vals[0];
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor1<bool>::setTensorValueBool(const size_t bufLen, const bool* vals)
+{
+ uint32_t idx = 0;
+
+ ASSERT_MSG(bufLen == getElementCount(), "Total elements must match");
+
+ for (int i0 = 0; i0 < shape[0]; i0++)
+ {
+ (*tensor)(i0) = vals[idx++];
+ }
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor2<bool>::setTensorValueBool(const size_t bufLen, const bool* vals)
+{
+ uint32_t idx = 0;
+
+ ASSERT_MSG(bufLen == getElementCount(), "Total elements must match");
+
+ for (int i0 = 0; i0 < shape[0]; i0++)
+ {
+ for (int i1 = 0; i1 < shape[1]; i1++)
+ {
+ (*tensor)(i0, i1) = vals[idx++];
+ }
+ }
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor3<bool>::setTensorValueBool(const size_t bufLen, const bool* vals)
+{
+ uint32_t idx = 0;
+
+ ASSERT_MSG(bufLen == getElementCount(), "Total elements must match");
+
+ for (int i0 = 0; i0 < shape[0]; i0++)
+ {
+ for (int i1 = 0; i1 < shape[1]; i1++)
+ {
+ for (int i2 = 0; i2 < shape[2]; i2++)
+ {
+ (*tensor)(i0, i1, i2) = vals[idx++];
+ }
+ }
+ }
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor4<bool>::setTensorValueBool(const size_t bufLen, const bool* vals)
+{
+ uint32_t idx = 0;
+
+ ASSERT_MSG(bufLen == getElementCount(), "Total elements must match");
+
+ for (int i0 = 0; i0 < shape[0]; i0++)
+ {
+ for (int i1 = 0; i1 < shape[1]; i1++)
+ {
+ for (int i2 = 0; i2 < shape[2]; i2++)
+ {
+ for (int i3 = 0; i3 < shape[3]; i3++)
+ {
+ (*tensor)(i0, i1, i2, i3) = vals[idx++];
+ }
+ }
+ }
+ }
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor5<bool>::setTensorValueBool(const size_t bufLen, const bool* vals)
+{
+ uint32_t idx = 0;
+
+ ASSERT_MSG(bufLen == getElementCount(), "Total elements must match");
+
+ for (int i0 = 0; i0 < shape[0]; i0++)
+ {
+ for (int i1 = 0; i1 < shape[1]; i1++)
+ {
+ for (int i2 = 0; i2 < shape[2]; i2++)
+ {
+ for (int i3 = 0; i3 < shape[3]; i3++)
+ {
+ for (int i4 = 0; i4 < shape[4]; i4++)
+ {
+ (*tensor)(i0, i1, i2, i3, i4) = vals[idx++];
+ }
+ }
+ }
+ }
+ }
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor6<bool>::setTensorValueBool(const size_t bufLen, const bool* vals)
+{
+ uint32_t idx = 0;
+
+ ASSERT_MSG(bufLen == getElementCount(), "Total elements must match");
+
+ for (int i0 = 0; i0 < shape[0]; i0++)
+ {
+ for (int i1 = 0; i1 < shape[1]; i1++)
+ {
+ for (int i2 = 0; i2 < shape[2]; i2++)
+ {
+ for (int i3 = 0; i3 < shape[3]; i3++)
+ {
+ for (int i4 = 0; i4 < shape[4]; i4++)
+ {
+ for (int i5 = 0; i5 < shape[5]; i5++)
+ {
+ (*tensor)(i0, i1, i2, i3, i4, i5) = vals[idx++];
+ }
+ }
+ }
+ }
+ }
+ }
+ return 0;
+}
+
+template <class T>
+int TosaReference::TensorTemplate<T>::getTensorValueFloat(const size_t bufLen, float* vals) const
+{
+ FATAL_ERROR("TensorTemplate<T>::getTensorValueFloat should not be called. "
+ "Implement template specialization version.");
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor0<float>::getTensorValueFloat(const size_t bufLen, float* vals) const
+{
+ int totalVals = 1;
+
+ ASSERT_MSG((size_t)totalVals == bufLen, "Output buffer and tensor size do not match");
+
+ vals[0] = (*tensor)(0);
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor1<float>::getTensorValueFloat(const size_t bufLen, float* vals) const
+{
+ uint32_t idx = 0;
+ int totalVals = 1;
+
+ for (size_t i = 0; i < shape.size(); i++)
+ {
+ totalVals *= shape[i];
+ }
+
+ ASSERT_MSG((size_t)totalVals == bufLen, "Output buffer and tensor size do not match");
+
+ for (int i0 = 0; i0 < shape[0]; i0++)
+ {
+ vals[idx++] = (*tensor)(i0);
+ }
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor2<float>::getTensorValueFloat(const size_t bufLen, float* vals) const
+{
+ uint32_t idx = 0;
+ int totalVals = 1;
+
+ for (size_t i = 0; i < shape.size(); i++)
+ {
+ totalVals *= shape[i];
+ }
+
+ ASSERT_MSG((size_t)totalVals == bufLen, "Output buffer and tensor size do not match");
+
+ for (int i0 = 0; i0 < shape[0]; i0++)
+ {
+ for (int i1 = 0; i1 < shape[1]; i1++)
+ {
+ vals[idx++] = (*tensor)(i0, i1);
+ }
+ }
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor3<float>::getTensorValueFloat(const size_t bufLen, float* vals) const
+{
+ uint32_t idx = 0;
+ int totalVals = 1;
+
+ for (size_t i = 0; i < shape.size(); i++)
+ {
+ totalVals *= shape[i];
+ }
+
+ ASSERT_MSG((size_t)totalVals == bufLen, "Output buffer and tensor size do not match");
+
+ for (int i0 = 0; i0 < shape[0]; i0++)
+ {
+ for (int i1 = 0; i1 < shape[1]; i1++)
+ {
+ for (int i2 = 0; i2 < shape[2]; i2++)
+ {
+ vals[idx++] = (*tensor)(i0, i1, i2);
+ }
+ }
+ }
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor4<float>::getTensorValueFloat(const size_t bufLen, float* vals) const
+{
+ uint32_t idx = 0;
+ int totalVals = 1;
+
+ for (size_t i = 0; i < shape.size(); i++)
+ {
+ totalVals *= shape[i];
+ }
+
+ ASSERT_MSG((size_t)totalVals == bufLen, "Output buffer and tensor size do not match");
+
+ for (int i0 = 0; i0 < shape[0]; i0++)
+ {
+ for (int i1 = 0; i1 < shape[1]; i1++)
+ {
+ for (int i2 = 0; i2 < shape[2]; i2++)
+ {
+ for (int i3 = 0; i3 < shape[3]; i3++)
+ {
+ vals[idx++] = (*tensor)(i0, i1, i2, i3);
+ }
+ }
+ }
+ }
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor5<float>::getTensorValueFloat(const size_t bufLen, float* vals) const
+{
+ uint32_t idx = 0;
+ int totalVals = 1;
+
+ for (size_t i = 0; i < shape.size(); i++)
+ {
+ totalVals *= shape[i];
+ }
+
+ ASSERT_MSG((size_t)totalVals == bufLen, "Output buffer and tensor size do not match");
+
+ for (int i0 = 0; i0 < shape[0]; i0++)
+ {
+ for (int i1 = 0; i1 < shape[1]; i1++)
+ {
+ for (int i2 = 0; i2 < shape[2]; i2++)
+ {
+ for (int i3 = 0; i3 < shape[3]; i3++)
+ {
+ for (int i4 = 0; i4 < shape[4]; i4++)
+ {
+ vals[idx++] = (*tensor)(i0, i1, i2, i3, i4);
+ }
+ }
+ }
+ }
+ }
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor6<float>::getTensorValueFloat(const size_t bufLen, float* vals) const
+{
+ uint32_t idx = 0;
+ int totalVals = 1;
+
+ for (size_t i = 0; i < shape.size(); i++)
+ {
+ totalVals *= shape[i];
+ }
+
+ ASSERT_MSG((size_t)totalVals == bufLen, "Output buffer and tensor size do not match");
+
+ for (int i0 = 0; i0 < shape[0]; i0++)
+ {
+ for (int i1 = 0; i1 < shape[1]; i1++)
+ {
+ for (int i2 = 0; i2 < shape[2]; i2++)
+ {
+ for (int i3 = 0; i3 < shape[3]; i3++)
+ {
+ for (int i4 = 0; i4 < shape[4]; i4++)
+ {
+ for (int i5 = 0; i5 < shape[5]; i5++)
+ {
+ vals[idx++] = (*tensor)(i0, i1, i2, i3, i4, i5);
+ }
+ }
+ }
+ }
+ }
+ }
+ return 0;
+}
+
+template <class T>
+int TosaReference::TensorTemplate<T>::getTensorValueInt32(const size_t bufLen, int32_t* vals) const
+{
+ FATAL_ERROR("TensorTemplate<T>::getTensorValueInt32 should not be called. "
+ "Implement template specialization version.");
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor0<int32_t>::getTensorValueInt32(const size_t bufLen, int32_t* vals) const
+{
+ int totalVals = 1;
+
+ ASSERT_MSG((size_t)totalVals == bufLen, "Output buffer and tensor size do not match");
+
+ vals[0] = (*tensor)(0);
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor1<int32_t>::getTensorValueInt32(const size_t bufLen, int32_t* vals) const
+{
+ uint32_t idx = 0;
+ int totalVals = 1;
+
+ for (size_t i = 0; i < shape.size(); i++)
+ {
+ totalVals *= shape[i];
+ }
+
+ ASSERT_MSG((size_t)totalVals == bufLen, "Output buffer and tensor size do not match");
+
+ for (int i0 = 0; i0 < shape[0]; i0++)
+ {
+ vals[idx++] = (*tensor)(i0);
+ }
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor2<int32_t>::getTensorValueInt32(const size_t bufLen, int32_t* vals) const
+{
+ uint32_t idx = 0;
+ int totalVals = 1;
+
+ for (size_t i = 0; i < shape.size(); i++)
+ {
+ totalVals *= shape[i];
+ }
+
+ ASSERT_MSG((size_t)totalVals == bufLen, "Output buffer and tensor size do not match");
+
+ for (int i0 = 0; i0 < shape[0]; i0++)
+ {
+ for (int i1 = 0; i1 < shape[1]; i1++)
+ {
+ vals[idx++] = (*tensor)(i0, i1);
+ }
+ }
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor3<int32_t>::getTensorValueInt32(const size_t bufLen, int32_t* vals) const
+{
+ uint32_t idx = 0;
+ int totalVals = 1;
+
+ for (size_t i = 0; i < shape.size(); i++)
+ {
+ totalVals *= shape[i];
+ }
+
+ ASSERT_MSG((size_t)totalVals == bufLen, "Output buffer and tensor size do not match");
+
+ for (int i0 = 0; i0 < shape[0]; i0++)
+ {
+ for (int i1 = 0; i1 < shape[1]; i1++)
+ {
+ for (int i2 = 0; i2 < shape[2]; i2++)
+ {
+ vals[idx++] = (*tensor)(i0, i1, i2);
+ }
+ }
+ }
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor4<int32_t>::getTensorValueInt32(const size_t bufLen, int32_t* vals) const
+{
+ uint32_t idx = 0;
+ int totalVals = 1;
+
+ for (size_t i = 0; i < shape.size(); i++)
+ {
+ totalVals *= shape[i];
+ }
+
+ ASSERT_MSG((size_t)totalVals == bufLen, "Output buffer and tensor size do not match");
+
+ for (int i0 = 0; i0 < shape[0]; i0++)
+ {
+ for (int i1 = 0; i1 < shape[1]; i1++)
+ {
+ for (int i2 = 0; i2 < shape[2]; i2++)
+ {
+ for (int i3 = 0; i3 < shape[3]; i3++)
+ {
+ vals[idx++] = (*tensor)(i0, i1, i2, i3);
+ }
+ }
+ }
+ }
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor5<int32_t>::getTensorValueInt32(const size_t bufLen, int32_t* vals) const
+{
+ uint32_t idx = 0;
+ int totalVals = 1;
+
+ for (size_t i = 0; i < shape.size(); i++)
+ {
+ totalVals *= shape[i];
+ }
+
+ ASSERT_MSG((size_t)totalVals == bufLen, "Output buffer and tensor size do not match");
+
+ for (int i0 = 0; i0 < shape[0]; i0++)
+ {
+ for (int i1 = 0; i1 < shape[1]; i1++)
+ {
+ for (int i2 = 0; i2 < shape[2]; i2++)
+ {
+ for (int i3 = 0; i3 < shape[3]; i3++)
+ {
+ for (int i4 = 0; i4 < shape[4]; i4++)
+ {
+ vals[idx++] = (*tensor)(i0, i1, i2, i3, i4);
+ }
+ }
+ }
+ }
+ }
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor6<int32_t>::getTensorValueInt32(const size_t bufLen, int32_t* vals) const
+{
+ uint32_t idx = 0;
+ int totalVals = 1;
+
+ for (size_t i = 0; i < shape.size(); i++)
+ {
+ totalVals *= shape[i];
+ }
+
+ ASSERT_MSG((size_t)totalVals == bufLen, "Output buffer and tensor size do not match");
+
+ for (int i0 = 0; i0 < shape[0]; i0++)
+ {
+ for (int i1 = 0; i1 < shape[1]; i1++)
+ {
+ for (int i2 = 0; i2 < shape[2]; i2++)
+ {
+ for (int i3 = 0; i3 < shape[3]; i3++)
+ {
+ for (int i4 = 0; i4 < shape[4]; i4++)
+ {
+ for (int i5 = 0; i5 < shape[5]; i5++)
+ {
+ vals[idx++] = (*tensor)(i0, i1, i2, i3, i4, i5);
+ }
+ }
+ }
+ }
+ }
+ }
+ return 0;
+}
+
+template <class T>
+int TosaReference::TensorTemplate<T>::getTensorValueInt64(const size_t bufLen, int64_t* vals) const
+{
+ FATAL_ERROR("TensorTemplate<T>::getTensorValueInt64 should not be called. "
+ "Implement template specialization version.");
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor0<int64_t>::getTensorValueInt64(const size_t bufLen, int64_t* vals) const
+{
+ int totalVals = 1;
+
+ ASSERT_MSG((size_t)totalVals == bufLen, "Output buffer and tensor size do not match");
+
+ vals[0] = (*tensor)(0);
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor1<int64_t>::getTensorValueInt64(const size_t bufLen, int64_t* vals) const
+{
+ uint32_t idx = 0;
+ int totalVals = 1;
+
+ for (size_t i = 0; i < shape.size(); i++)
+ {
+ totalVals *= shape[i];
+ }
+
+ ASSERT_MSG((size_t)totalVals == bufLen, "Output buffer and tensor size do not match");
+
+ for (int i0 = 0; i0 < shape[0]; i0++)
+ {
+ vals[idx++] = (*tensor)(i0);
+ }
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor2<int64_t>::getTensorValueInt64(const size_t bufLen, int64_t* vals) const
+{
+ uint32_t idx = 0;
+ int totalVals = 1;
+
+ for (size_t i = 0; i < shape.size(); i++)
+ {
+ totalVals *= shape[i];
+ }
+
+ ASSERT_MSG((size_t)totalVals == bufLen, "Output buffer and tensor size do not match");
+
+ for (int i0 = 0; i0 < shape[0]; i0++)
+ {
+ for (int i1 = 0; i1 < shape[1]; i1++)
+ {
+ vals[idx++] = (*tensor)(i0, i1);
+ }
+ }
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor3<int64_t>::getTensorValueInt64(const size_t bufLen, int64_t* vals) const
+{
+ uint32_t idx = 0;
+ int totalVals = 1;
+
+ for (size_t i = 0; i < shape.size(); i++)
+ {
+ totalVals *= shape[i];
+ }
+
+ ASSERT_MSG((size_t)totalVals == bufLen, "Output buffer and tensor size do not match");
+
+ for (int i0 = 0; i0 < shape[0]; i0++)
+ {
+ for (int i1 = 0; i1 < shape[1]; i1++)
+ {
+ for (int i2 = 0; i2 < shape[2]; i2++)
+ {
+ vals[idx++] = (*tensor)(i0, i1, i2);
+ }
+ }
+ }
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor4<int64_t>::getTensorValueInt64(const size_t bufLen, int64_t* vals) const
+{
+ uint32_t idx = 0;
+ int totalVals = 1;
+
+ for (size_t i = 0; i < shape.size(); i++)
+ {
+ totalVals *= shape[i];
+ }
+
+ ASSERT_MSG((size_t)totalVals == bufLen, "Output buffer and tensor size do not match");
+
+ for (int i0 = 0; i0 < shape[0]; i0++)
+ {
+ for (int i1 = 0; i1 < shape[1]; i1++)
+ {
+ for (int i2 = 0; i2 < shape[2]; i2++)
+ {
+ for (int i3 = 0; i3 < shape[3]; i3++)
+ {
+ vals[idx++] = (*tensor)(i0, i1, i2, i3);
+ }
+ }
+ }
+ }
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor5<int64_t>::getTensorValueInt64(const size_t bufLen, int64_t* vals) const
+{
+ uint32_t idx = 0;
+ int totalVals = 1;
+
+ for (size_t i = 0; i < shape.size(); i++)
+ {
+ totalVals *= shape[i];
+ }
+
+ ASSERT_MSG((size_t)totalVals == bufLen, "Output buffer and tensor size do not match");
+
+ for (int i0 = 0; i0 < shape[0]; i0++)
+ {
+ for (int i1 = 0; i1 < shape[1]; i1++)
+ {
+ for (int i2 = 0; i2 < shape[2]; i2++)
+ {
+ for (int i3 = 0; i3 < shape[3]; i3++)
+ {
+ for (int i4 = 0; i4 < shape[4]; i4++)
+ {
+ vals[idx++] = (*tensor)(i0, i1, i2, i3, i4);
+ }
+ }
+ }
+ }
+ }
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor6<int64_t>::getTensorValueInt64(const size_t bufLen, int64_t* vals) const
+{
+ uint32_t idx = 0;
+ int totalVals = 1;
+
+ for (size_t i = 0; i < shape.size(); i++)
+ {
+ totalVals *= shape[i];
+ }
+
+ ASSERT_MSG((size_t)totalVals == bufLen, "Output buffer and tensor size do not match");
+
+ for (int i0 = 0; i0 < shape[0]; i0++)
+ {
+ for (int i1 = 0; i1 < shape[1]; i1++)
+ {
+ for (int i2 = 0; i2 < shape[2]; i2++)
+ {
+ for (int i3 = 0; i3 < shape[3]; i3++)
+ {
+ for (int i4 = 0; i4 < shape[4]; i4++)
+ {
+ for (int i5 = 0; i5 < shape[5]; i5++)
+ {
+ vals[idx++] = (*tensor)(i0, i1, i2, i3, i4, i5);
+ }
+ }
+ }
+ }
+ }
+ }
+ return 0;
+}
+
+template <class T>
+int TosaReference::TensorTemplate<T>::getTensorValueBool(const size_t bufLen, bool* vals) const
+{
+ FATAL_ERROR("TensorTemplate<T>::getTensorValueBool should not be called. "
+ "Implement template specialization version.");
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor0<bool>::getTensorValueBool(const size_t bufLen, bool* vals) const
+{
+ int totalVals = 1;
+
+ ASSERT_MSG((size_t)totalVals == bufLen, "Output buffer and tensor size do not match");
+
+ vals[0] = (*tensor)(0);
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor1<bool>::getTensorValueBool(const size_t bufLen, bool* vals) const
+{
+ uint32_t idx = 0;
+ int totalVals = 1;
+
+ for (size_t i = 0; i < shape.size(); i++)
+ {
+ totalVals *= shape[i];
+ }
+
+ ASSERT_MSG((size_t)totalVals == bufLen, "Output buffer and tensor size do not match");
+
+ for (int i0 = 0; i0 < shape[0]; i0++)
+ {
+ vals[idx++] = (*tensor)(i0);
+ }
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor2<bool>::getTensorValueBool(const size_t bufLen, bool* vals) const
+{
+ uint32_t idx = 0;
+ int totalVals = 1;
+
+ for (size_t i = 0; i < shape.size(); i++)
+ {
+ totalVals *= shape[i];
+ }
+
+ ASSERT_MSG((size_t)totalVals == bufLen, "Output buffer and tensor size do not match");
+
+ for (int i0 = 0; i0 < shape[0]; i0++)
+ {
+ for (int i1 = 0; i1 < shape[1]; i1++)
+ {
+ vals[idx++] = (*tensor)(i0, i1);
+ }
+ }
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor3<bool>::getTensorValueBool(const size_t bufLen, bool* vals) const
+{
+ uint32_t idx = 0;
+ int totalVals = 1;
+
+ for (size_t i = 0; i < shape.size(); i++)
+ {
+ totalVals *= shape[i];
+ }
+
+ ASSERT_MSG((size_t)totalVals == bufLen, "Output buffer and tensor size do not match");
+
+ for (int i0 = 0; i0 < shape[0]; i0++)
+ {
+ for (int i1 = 0; i1 < shape[1]; i1++)
+ {
+ for (int i2 = 0; i2 < shape[2]; i2++)
+ {
+ vals[idx++] = (*tensor)(i0, i1, i2);
+ }
+ }
+ }
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor4<bool>::getTensorValueBool(const size_t bufLen, bool* vals) const
+{
+ uint32_t idx = 0;
+ int totalVals = 1;
+
+ for (size_t i = 0; i < shape.size(); i++)
+ {
+ totalVals *= shape[i];
+ }
+
+ ASSERT_MSG((size_t)totalVals == bufLen, "Output buffer and tensor size do not match");
+
+ for (int i0 = 0; i0 < shape[0]; i0++)
+ {
+ for (int i1 = 0; i1 < shape[1]; i1++)
+ {
+ for (int i2 = 0; i2 < shape[2]; i2++)
+ {
+ for (int i3 = 0; i3 < shape[3]; i3++)
+ {
+ vals[idx++] = (*tensor)(i0, i1, i2, i3);
+ }
+ }
+ }
+ }
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor5<bool>::getTensorValueBool(const size_t bufLen, bool* vals) const
+{
+ uint32_t idx = 0;
+ int totalVals = 1;
+
+ for (size_t i = 0; i < shape.size(); i++)
+ {
+ totalVals *= shape[i];
+ }
+
+ ASSERT_MSG((size_t)totalVals == bufLen, "Output buffer and tensor size do not match");
+
+ for (int i0 = 0; i0 < shape[0]; i0++)
+ {
+ for (int i1 = 0; i1 < shape[1]; i1++)
+ {
+ for (int i2 = 0; i2 < shape[2]; i2++)
+ {
+ for (int i3 = 0; i3 < shape[3]; i3++)
+ {
+ for (int i4 = 0; i4 < shape[4]; i4++)
+ {
+ vals[idx++] = (*tensor)(i0, i1, i2, i3, i4);
+ }
+ }
+ }
+ }
+ }
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor6<bool>::getTensorValueBool(const size_t bufLen, bool* vals) const
+{
+ uint32_t idx = 0;
+ int totalVals = 1;
+
+ for (size_t i = 0; i < shape.size(); i++)
+ {
+ totalVals *= shape[i];
+ }
+
+ ASSERT_MSG((size_t)totalVals == bufLen, "Output buffer and tensor size do not match");
+
+ for (int i0 = 0; i0 < shape[0]; i0++)
+ {
+ for (int i1 = 0; i1 < shape[1]; i1++)
+ {
+ for (int i2 = 0; i2 < shape[2]; i2++)
+ {
+ for (int i3 = 0; i3 < shape[3]; i3++)
+ {
+ for (int i4 = 0; i4 < shape[4]; i4++)
+ {
+ for (int i5 = 0; i5 < shape[5]; i5++)
+ {
+ vals[idx++] = (*tensor)(i0, i1, i2, i3, i4, i5);
+ }
+ }
+ }
+ }
+ }
+ }
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor0<float>::allocate()
+{
+ ASSERT_MSG(tensor == nullptr, "Error: double allocate Eigen tensor");
+ tensor = new ETensor0<float>();
+
+ if (tensor)
+ return 0;
+ else
+ return 1;
+}
+template <>
+int TosaReference::Tensor1<float>::allocate()
+{
+ ASSERT_MSG(tensor == nullptr, "Error: double allocate Eigen tensor");
+ tensor = new ETensor1<float>(shape[0]);
+ if (tensor)
+ return 0;
+ else
+ return 1;
+}
+template <>
+int TosaReference::Tensor2<float>::allocate()
+{
+ ASSERT_MSG(tensor == nullptr, "Error: double allocate Eigen tensor");
+ tensor = new ETensor2<float>(shape[0], shape[1]);
+ if (tensor)
+ return 0;
+ else
+ return 1;
+}
+
+template <>
+int TosaReference::Tensor3<float>::allocate()
+{
+ ASSERT_MSG(tensor == nullptr, "Error: double allocate Eigen tensor");
+ tensor = new ETensor3<float>(shape[0], shape[1], shape[2]);
+ if (tensor)
+ return 0;
+ else
+ return 1;
+}
+
+template <>
+int TosaReference::Tensor4<float>::allocate()
+{
+ ASSERT_MSG(tensor == nullptr, "Error: double allocate Eigen tensor");
+ tensor = new ETensor4<float>(shape[0], shape[1], shape[2], shape[3]);
+ if (tensor)
+ return 0;
+ else
+ return 1;
+}
+
+template <>
+int TosaReference::Tensor5<float>::allocate()
+{
+ ASSERT_MSG(tensor == nullptr, "Error: double allocate Eigen tensor");
+ tensor = new ETensor5<float>(shape[0], shape[1], shape[2], shape[3], shape[4]);
+ if (tensor)
+ return 0;
+ else
+ return 1;
+}
+
+template <>
+int TosaReference::Tensor6<float>::allocate()
+{
+ ASSERT_MSG(tensor == nullptr, "Error: double allocate Eigen tensor");
+ tensor = new ETensor6<float>(shape[0], shape[1], shape[2], shape[3], shape[4], shape[5]);
+ if (tensor)
+ return 0;
+ else
+ return 1;
+}
+
+template <>
+int TosaReference::Tensor0<int32_t>::allocate()
+{
+ ASSERT_MSG(tensor == nullptr, "Error: double allocate Eigen tensor");
+ tensor = new ETensor0<int32_t>();
+ if (tensor)
+ return 0;
+ else
+ return 1;
+}
+template <>
+int TosaReference::Tensor1<int32_t>::allocate()
+{
+ ASSERT_MSG(tensor == nullptr, "Error: double allocate Eigen tensor");
+ tensor = new ETensor1<int32_t>(shape[0]);
+ if (tensor)
+ return 0;
+ else
+ return 1;
+}
+template <>
+int TosaReference::Tensor2<int32_t>::allocate()
+{
+ ASSERT_MSG(tensor == nullptr, "Error: double allocate Eigen tensor");
+ tensor = new ETensor2<int32_t>(shape[0], shape[1]);
+ if (tensor)
+ return 0;
+ else
+ return 1;
+}
+
+template <>
+int TosaReference::Tensor3<int32_t>::allocate()
+{
+ ASSERT_MSG(tensor == nullptr, "Error: double allocate Eigen tensor");
+ tensor = new ETensor3<int32_t>(shape[0], shape[1], shape[2]);
+ if (tensor)
+ return 0;
+ else
+ return 1;
+}
+
+template <>
+int TosaReference::Tensor4<int32_t>::allocate()
+{
+ ASSERT_MSG(tensor == nullptr, "Error: double allocate Eigen tensor");
+ tensor = new ETensor4<int32_t>(shape[0], shape[1], shape[2], shape[3]);
+ if (tensor)
+ return 0;
+ else
+ return 1;
+}
+
+template <>
+int TosaReference::Tensor5<int32_t>::allocate()
+{
+ ASSERT_MSG(tensor == nullptr, "Error: double allocate Eigen tensor");
+ tensor = new ETensor5<int32_t>(shape[0], shape[1], shape[2], shape[3], shape[4]);
+ if (tensor)
+ return 0;
+ else
+ return 1;
+}
+
+template <>
+int TosaReference::Tensor6<int32_t>::allocate()
+{
+ ASSERT_MSG(tensor == nullptr, "Error: double allocate Eigen tensor");
+ tensor = new ETensor6<int32_t>(shape[0], shape[1], shape[2], shape[3], shape[4], shape[5]);
+ if (tensor)
+ return 0;
+ else
+ return 1;
+}
+
+template <>
+int TosaReference::Tensor0<int64_t>::allocate()
+{
+ ASSERT_MSG(tensor == nullptr, "Error: double allocate Eigen tensor");
+ tensor = new ETensor0<int64_t>();
+ if (tensor)
+ return 0;
+ else
+ return 1;
+}
+template <>
+int TosaReference::Tensor1<int64_t>::allocate()
+{
+ ASSERT_MSG(tensor == nullptr, "Error: double allocate Eigen tensor");
+ tensor = new ETensor1<int64_t>(shape[0]);
+ if (tensor)
+ return 0;
+ else
+ return 1;
+}
+template <>
+int TosaReference::Tensor2<int64_t>::allocate()
+{
+ ASSERT_MSG(tensor == nullptr, "Error: double allocate Eigen tensor");
+ tensor = new ETensor2<int64_t>(shape[0], shape[1]);
+ if (tensor)
+ return 0;
+ else
+ return 1;
+}
+
+template <>
+int TosaReference::Tensor3<int64_t>::allocate()
+{
+ ASSERT_MSG(tensor == nullptr, "Error: double allocate Eigen tensor");
+ tensor = new ETensor3<int64_t>(shape[0], shape[1], shape[2]);
+ if (tensor)
+ return 0;
+ else
+ return 1;
+}
+
+template <>
+int TosaReference::Tensor4<int64_t>::allocate()
+{
+ ASSERT_MSG(tensor == nullptr, "Error: double allocate Eigen tensor");
+ tensor = new ETensor4<int64_t>(shape[0], shape[1], shape[2], shape[3]);
+ if (tensor)
+ return 0;
+ else
+ return 1;
+}
+
+template <>
+int TosaReference::Tensor5<int64_t>::allocate()
+{
+ ASSERT_MSG(tensor == nullptr, "Error: double allocate Eigen tensor");
+ tensor = new ETensor5<int64_t>(shape[0], shape[1], shape[2], shape[3], shape[4]);
+ if (tensor)
+ return 0;
+ else
+ return 1;
+}
+
+template <>
+int TosaReference::Tensor6<int64_t>::allocate()
+{
+ ASSERT_MSG(tensor == nullptr, "Error: double allocate Eigen tensor");
+ tensor = new ETensor6<int64_t>(shape[0], shape[1], shape[2], shape[3], shape[4], shape[5]);
+ if (tensor)
+ return 0;
+ else
+ return 1;
+}
+
+template <>
+int TosaReference::Tensor0<bool>::allocate()
+{
+ ASSERT_MSG(tensor == nullptr, "Error: double allocate Eigen tensor");
+ tensor = new ETensor0<bool>();
+ if (tensor)
+ return 0;
+ else
+ return 1;
+}
+template <>
+int TosaReference::Tensor1<bool>::allocate()
+{
+ ASSERT_MSG(tensor == nullptr, "Error: double allocate Eigen tensor");
+ tensor = new ETensor1<bool>(shape[0]);
+ if (tensor)
+ return 0;
+ else
+ return 1;
+}
+template <>
+int TosaReference::Tensor2<bool>::allocate()
+{
+ ASSERT_MSG(tensor == nullptr, "Error: double allocate Eigen tensor");
+ tensor = new ETensor2<bool>(shape[0], shape[1]);
+ if (tensor)
+ return 0;
+ else
+ return 1;
+}
+
+template <>
+int TosaReference::Tensor3<bool>::allocate()
+{
+ ASSERT_MSG(tensor == nullptr, "Error: double allocate Eigen tensor");
+ tensor = new ETensor3<bool>(shape[0], shape[1], shape[2]);
+ if (tensor)
+ return 0;
+ else
+ return 1;
+}
+
+template <>
+int TosaReference::Tensor4<bool>::allocate()
+{
+ ASSERT_MSG(tensor == nullptr, "Error: double allocate Eigen tensor");
+ tensor = new ETensor4<bool>(shape[0], shape[1], shape[2], shape[3]);
+ if (tensor)
+ return 0;
+ else
+ return 1;
+}
+
+template <>
+int TosaReference::Tensor5<bool>::allocate()
+{
+ ASSERT_MSG(tensor == nullptr, "Error: double allocate Eigen tensor");
+ tensor = new ETensor5<bool>(shape[0], shape[1], shape[2], shape[3], shape[4]);
+ if (tensor)
+ return 0;
+ else
+ return 1;
+}
+
+template <>
+int TosaReference::Tensor6<bool>::allocate()
+{
+ ASSERT_MSG(tensor == nullptr, "Error: double allocate Eigen tensor");
+ tensor = new ETensor6<bool>(shape[0], shape[1], shape[2], shape[3], shape[4], shape[5]);
+ if (tensor)
+ return 0;
+ else
+ return 1;
+}
+
+template <>
+int TosaReference::Tensor0<float>::dumpTensor(FILE* out) const
+{
+ char fp_fmt[FOF_STR_LEN];
+ snprintf(fp_fmt, FOF_STR_LEN, "[ %%%sf ]\n", g_func_config.fp_format);
+
+ if (tensor == nullptr)
+ {
+ fprintf(out, "<Not allocated>\n");
+ return 0;
+ }
+
+ fprintf(out, fp_fmt, (*tensor)(0));
+
+ return 0;
+}
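The doubled %% above is how the runtime format string is assembled: the configured fp_format is spliced into a printf specifier before dumping. A small illustration, assuming fp_format were the string "0.6" (an assumption; the actual default lives in the function config, not shown here):

#include <cstdio>

// Illustrative only: "%%" escapes to a literal '%' and "%s" splices in the
// configured precision, so fp_fmt ends up holding "[ %0.6f ]\n".
static void demoFormat()
{
    char fp_fmt[32];
    std::snprintf(fp_fmt, sizeof(fp_fmt), "[ %%%sf ]\n", "0.6");
    std::printf(fp_fmt, 1.23456789);    // prints "[ 1.234568 ]"
}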
+
+template <>
+int TosaReference::Tensor1<float>::dumpTensor(FILE* out) const
+{
+ char fp_fmt[FOF_STR_LEN];
+ snprintf(fp_fmt, FOF_STR_LEN, " %%%sf ", g_func_config.fp_format);
+
+ if (tensor == nullptr)
+ {
+ fprintf(out, "<Not allocated>\n");
+ return 0;
+ }
+
+ fprintf(out, "[");
+ for (int i0 = 0; i0 < shape[0]; i0++)
+ {
+ fprintf(out, fp_fmt, (*tensor)(i0));
+ }
+ fprintf(out, "]\n");
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor2<float>::dumpTensor(FILE* out) const
+{
+ char fp_fmt[FOF_STR_LEN];
+ snprintf(fp_fmt, FOF_STR_LEN, " %%%sf ", g_func_config.fp_format);
+
+ if (tensor == nullptr)
+ {
+ fprintf(out, "<Not allocated>\n");
+ return 0;
+ }
+
+ fprintf(out, "[");
+ for (int i0 = 0; i0 < shape[0]; i0++)
+ {
+ fprintf(out, "[");
+ for (int i1 = 0; i1 < shape[1]; i1++)
+ {
+ fprintf(out, fp_fmt, (*tensor)(i0, i1));
+ }
+ fprintf(out, "]\n");
+ }
+ fprintf(out, "]\n");
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor3<float>::dumpTensor(FILE* out) const
+{
+ char fp_fmt[FOF_STR_LEN];
+ snprintf(fp_fmt, FOF_STR_LEN, " %%%sf ", g_func_config.fp_format);
+
+ if (tensor == nullptr)
+ {
+ fprintf(out, "<Not allocated>\n");
+ return 0;
+ }
+
+ fprintf(out, "[");
+ for (int i0 = 0; i0 < shape[0]; i0++)
+ {
+ fprintf(out, "[");
+ for (int i1 = 0; i1 < shape[1]; i1++)
+ {
+ fprintf(out, "[");
+ for (int i2 = 0; i2 < shape[2]; i2++)
+ {
+ fprintf(out, fp_fmt, (*tensor)(i0, i1, i2));
+ }
+ fprintf(out, "]\n");
+ }
+ fprintf(out, "]\n");
+ }
+ fprintf(out, "]\n");
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor4<float>::dumpTensor(FILE* out) const
+{
+ char fp_fmt[FOF_STR_LEN];
+ snprintf(fp_fmt, FOF_STR_LEN, " %%%sf ", g_func_config.fp_format);
+
+ if (tensor == nullptr)
+ {
+ fprintf(out, "<Not allocated>\n");
+ return 0;
+ }
+
+ fprintf(out, "[");
+ for (int i0 = 0; i0 < shape[0]; i0++)
+ {
+ fprintf(out, "[");
+ for (int i1 = 0; i1 < shape[1]; i1++)
+ {
+ fprintf(out, "[");
+ for (int i2 = 0; i2 < shape[2]; i2++)
+ {
+ fprintf(out, "[");
+ for (int i3 = 0; i3 < shape[3]; i3++)
+ {
+ fprintf(out, fp_fmt, (*tensor)(i0, i1, i2, i3));
+ }
+ fprintf(out, "]\n");
+ }
+ fprintf(out, "]\n");
+ }
+ fprintf(out, "]\n");
+ }
+ fprintf(out, "]\n");
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor5<float>::dumpTensor(FILE* out) const
+{
+ char fp_fmt[FOF_STR_LEN];
+ snprintf(fp_fmt, FOF_STR_LEN, " %%%sf ", g_func_config.fp_format);
+
+ if (tensor == nullptr)
+ {
+ fprintf(out, "<Not allocated>\n");
+ return 0;
+ }
+
+ fprintf(out, "[");
+ for (int i0 = 0; i0 < shape[0]; i0++)
+ {
+ fprintf(out, "[");
+ for (int i1 = 0; i1 < shape[1]; i1++)
+ {
+ fprintf(out, "[");
+ for (int i2 = 0; i2 < shape[2]; i2++)
+ {
+ fprintf(out, "[");
+ for (int i3 = 0; i3 < shape[3]; i3++)
+ {
+ fprintf(out, "[");
+ for (int i4 = 0; i4 < shape[4]; i4++)
+ {
+ fprintf(out, fp_fmt, (*tensor)(i0, i1, i2, i3, i4));
+ }
+ fprintf(out, "]\n");
+ }
+ fprintf(out, "]\n");
+ }
+ fprintf(out, "]\n");
+ }
+ fprintf(out, "]\n");
+ }
+ fprintf(out, "]\n");
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor6<float>::dumpTensor(FILE* out) const
+{
+ char fp_fmt[FOF_STR_LEN];
+ snprintf(fp_fmt, FOF_STR_LEN, " %%%sf ", g_func_config.fp_format);
+
+ if (tensor == nullptr)
+ {
+ fprintf(out, "<Not allocated>\n");
+ return 0;
+ }
+
+ fprintf(out, "[");
+ for (int i0 = 0; i0 < shape[0]; i0++)
+ {
+ fprintf(out, "[");
+ for (int i1 = 0; i1 < shape[1]; i1++)
+ {
+ fprintf(out, "[");
+ for (int i2 = 0; i2 < shape[2]; i2++)
+ {
+ fprintf(out, "[");
+ for (int i3 = 0; i3 < shape[3]; i3++)
+ {
+ fprintf(out, "[");
+ for (int i4 = 0; i4 < shape[4]; i4++)
+ {
+ fprintf(out, "[");
+ for (int i5 = 0; i5 < shape[5]; i5++)
+ {
+ fprintf(out, fp_fmt, (*tensor)(i0, i1, i2, i3, i4, i5));
+ }
+ fprintf(out, "]\n");
+ }
+ fprintf(out, "]\n");
+ }
+ fprintf(out, "]\n");
+ }
+ fprintf(out, "]\n");
+ }
+ fprintf(out, "]\n");
+ }
+ fprintf(out, "]\n");
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor0<int64_t>::dumpTensor(FILE* out) const
+{
+ char i64_fmt[FOF_STR_LEN];
+ snprintf(i64_fmt, FOF_STR_LEN, "[ %%ld ]\n");
+
+ if (tensor == nullptr)
+ {
+ fprintf(out, "<Not allocated>\n");
+ return 0;
+ }
+
+ fprintf(out, i64_fmt, (*tensor)(0));
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor1<int64_t>::dumpTensor(FILE* out) const
+{
+ char i64_fmt[FOF_STR_LEN];
+ snprintf(i64_fmt, FOF_STR_LEN, " %%ld ");
+
+ if (tensor == nullptr)
+ {
+ fprintf(out, "<Not allocated>\n");
+ return 0;
+ }
+
+ fprintf(out, "[");
+ for (int i0 = 0; i0 < shape[0]; i0++)
+ {
+ fprintf(out, i64_fmt, (*tensor)(i0));
+ }
+ fprintf(out, "]\n");
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor2<int64_t>::dumpTensor(FILE* out) const
+{
+ char i64_fmt[FOF_STR_LEN];
+ snprintf(i64_fmt, FOF_STR_LEN, " %%ld ");
+
+ if (tensor == nullptr)
+ {
+ fprintf(out, "<Not allocated>\n");
+ return 0;
+ }
+
+ fprintf(out, "[");
+ for (int i0 = 0; i0 < shape[0]; i0++)
+ {
+ fprintf(out, "[");
+ for (int i1 = 0; i1 < shape[1]; i1++)
+ {
+ fprintf(out, i64_fmt, (*tensor)(i0, i1));
+ }
+ fprintf(out, "]\n");
+ }
+ fprintf(out, "]\n");
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor3<int64_t>::dumpTensor(FILE* out) const
+{
+ char i64_fmt[FOF_STR_LEN];
+ snprintf(i64_fmt, FOF_STR_LEN, " %%ld ");
+
+ if (tensor == nullptr)
+ {
+ fprintf(out, "<Not allocated>\n");
+ return 0;
+ }
+
+ fprintf(out, "[");
+ for (int i0 = 0; i0 < shape[0]; i0++)
+ {
+ fprintf(out, "[");
+ for (int i1 = 0; i1 < shape[1]; i1++)
+ {
+ fprintf(out, "[");
+ for (int i2 = 0; i2 < shape[2]; i2++)
+ {
+ fprintf(out, i64_fmt, (*tensor)(i0, i1, i2));
+ }
+ fprintf(out, "]\n");
+ }
+ fprintf(out, "]\n");
+ }
+ fprintf(out, "]\n");
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor4<int64_t>::dumpTensor(FILE* out) const
+{
+ char i64_fmt[FOF_STR_LEN];
+ snprintf(i64_fmt, FOF_STR_LEN, " %%ld ");
+
+ if (tensor == nullptr)
+ {
+ fprintf(out, "<Not allocated>\n");
+ return 0;
+ }
+
+ fprintf(out, "[");
+ for (int i0 = 0; i0 < shape[0]; i0++)
+ {
+ fprintf(out, "[");
+ for (int i1 = 0; i1 < shape[1]; i1++)
+ {
+ fprintf(out, "[");
+ for (int i2 = 0; i2 < shape[2]; i2++)
+ {
+ fprintf(out, "[");
+ for (int i3 = 0; i3 < shape[3]; i3++)
+ {
+ fprintf(out, i64_fmt, (*tensor)(i0, i1, i2, i3));
+ }
+ fprintf(out, "]\n");
+ }
+ fprintf(out, "]\n");
+ }
+ fprintf(out, "]\n");
+ }
+ fprintf(out, "]\n");
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor5<int64_t>::dumpTensor(FILE* out) const
+{
+ char i64_fmt[FOF_STR_LEN];
+ snprintf(i64_fmt, FOF_STR_LEN, " %%ld ");
+
+ if (tensor == nullptr)
+ {
+ fprintf(out, "<Not allocated>\n");
+ return 0;
+ }
+
+ fprintf(out, "[");
+ for (int i0 = 0; i0 < shape[0]; i0++)
+ {
+ fprintf(out, "[");
+ for (int i1 = 0; i1 < shape[1]; i1++)
+ {
+ fprintf(out, "[");
+ for (int i2 = 0; i2 < shape[2]; i2++)
+ {
+ fprintf(out, "[");
+ for (int i3 = 0; i3 < shape[3]; i3++)
+ {
+ fprintf(out, "[");
+ for (int i4 = 0; i4 < shape[4]; i4++)
+ {
+ fprintf(out, i64_fmt, (*tensor)(i0, i1, i2, i3, i4));
+ }
+ fprintf(out, "]\n");
+ }
+ fprintf(out, "]\n");
+ }
+ fprintf(out, "]\n");
+ }
+ fprintf(out, "]\n");
+ }
+ fprintf(out, "]\n");
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor6<int64_t>::dumpTensor(FILE* out) const
+{
+ char i64_fmt[FOF_STR_LEN];
+ snprintf(i64_fmt, FOF_STR_LEN, " %%ld ");
+
+ if (tensor == nullptr)
+ {
+ fprintf(out, "<Not allocated>\n");
+ return 0;
+ }
+
+ fprintf(out, "[");
+ for (int i0 = 0; i0 < shape[0]; i0++)
+ {
+ fprintf(out, "[");
+ for (int i1 = 0; i1 < shape[1]; i1++)
+ {
+ fprintf(out, "[");
+ for (int i2 = 0; i2 < shape[2]; i2++)
+ {
+ fprintf(out, "[");
+ for (int i3 = 0; i3 < shape[3]; i3++)
+ {
+ fprintf(out, "[");
+ for (int i4 = 0; i4 < shape[4]; i4++)
+ {
+ fprintf(out, "[");
+ for (int i5 = 0; i5 < shape[5]; i5++)
+ {
+ fprintf(out, i64_fmt, (*tensor)(i0, i1, i2, i3, i4, i5));
+ }
+ fprintf(out, "]\n");
+ }
+ fprintf(out, "]\n");
+ }
+ fprintf(out, "]\n");
+ }
+ fprintf(out, "]\n");
+ }
+ fprintf(out, "]\n");
+ }
+ fprintf(out, "]\n");
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor0<int32_t>::dumpTensor(FILE* out) const
+{
+ char i32_fmt[FOF_STR_LEN];
+ snprintf(i32_fmt, FOF_STR_LEN, "[ %%d ]\n");
+
+ if (tensor == nullptr)
+ {
+ fprintf(out, "<Not allocated>\n");
+ return 0;
+ }
+
+ fprintf(out, i32_fmt, (*tensor)(0));
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor1<int32_t>::dumpTensor(FILE* out) const
+{
+ char i32_fmt[FOF_STR_LEN];
+ snprintf(i32_fmt, FOF_STR_LEN, " %%d ");
+
+ if (tensor == nullptr)
+ {
+ fprintf(out, "<Not allocated>\n");
+ return 0;
+ }
+
+ fprintf(out, "[");
+ for (int i0 = 0; i0 < shape[0]; i0++)
+ {
+ fprintf(out, i32_fmt, (*tensor)(i0));
+ }
+ fprintf(out, "]\n");
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor2<int32_t>::dumpTensor(FILE* out) const
+{
+ char i32_fmt[FOF_STR_LEN];
+ snprintf(i32_fmt, FOF_STR_LEN, " %%d ");
+
+ if (tensor == nullptr)
+ {
+ fprintf(out, "<Not allocated>\n");
+ return 0;
+ }
+
+ fprintf(out, "[");
+ for (int i0 = 0; i0 < shape[0]; i0++)
+ {
+ fprintf(out, "[");
+ for (int i1 = 0; i1 < shape[1]; i1++)
+ {
+ fprintf(out, i32_fmt, (*tensor)(i0, i1));
+ }
+ fprintf(out, "]\n");
+ }
+ fprintf(out, "]\n");
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor3<int32_t>::dumpTensor(FILE* out) const
+{
+ char i32_fmt[FOF_STR_LEN];
+ snprintf(i32_fmt, FOF_STR_LEN, " %%d ");
+
+ if (tensor == nullptr)
+ {
+ fprintf(out, "<Not allocated>\n");
+ return 0;
+ }
+
+ fprintf(out, "[");
+ for (int i0 = 0; i0 < shape[0]; i0++)
+ {
+ fprintf(out, "[");
+ for (int i1 = 0; i1 < shape[1]; i1++)
+ {
+ fprintf(out, "[");
+ for (int i2 = 0; i2 < shape[2]; i2++)
+ {
+ fprintf(out, i32_fmt, (*tensor)(i0, i1, i2));
+ }
+ fprintf(out, "]\n");
+ }
+ fprintf(out, "]\n");
+ }
+ fprintf(out, "]\n");
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor4<int32_t>::dumpTensor(FILE* out) const
+{
+ char i32_fmt[FOF_STR_LEN];
+ snprintf(i32_fmt, FOF_STR_LEN, " %%d ");
+
+ if (tensor == nullptr)
+ {
+ fprintf(out, "<Not allocated>\n");
+ return 0;
+ }
+
+ fprintf(out, "[");
+ for (int i0 = 0; i0 < shape[0]; i0++)
+ {
+ fprintf(out, "[");
+ for (int i1 = 0; i1 < shape[1]; i1++)
+ {
+ fprintf(out, "[");
+ for (int i2 = 0; i2 < shape[2]; i2++)
+ {
+ fprintf(out, "[");
+ for (int i3 = 0; i3 < shape[3]; i3++)
+ {
+ fprintf(out, i32_fmt, (*tensor)(i0, i1, i2, i3));
+ }
+ fprintf(out, "]\n");
+ }
+ fprintf(out, "]\n");
+ }
+ fprintf(out, "]\n");
+ }
+ fprintf(out, "]\n");
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor5<int32_t>::dumpTensor(FILE* out) const
+{
+ char i32_fmt[FOF_STR_LEN];
+ snprintf(i32_fmt, FOF_STR_LEN, " %%d ");
+
+ if (tensor == nullptr)
+ {
+ fprintf(out, "<Not allocated>\n");
+ return 0;
+ }
+
+ fprintf(out, "[");
+ for (int i0 = 0; i0 < shape[0]; i0++)
+ {
+ fprintf(out, "[");
+ for (int i1 = 0; i1 < shape[1]; i1++)
+ {
+ fprintf(out, "[");
+ for (int i2 = 0; i2 < shape[2]; i2++)
+ {
+ fprintf(out, "[");
+ for (int i3 = 0; i3 < shape[3]; i3++)
+ {
+ fprintf(out, "[");
+ for (int i4 = 0; i4 < shape[4]; i4++)
+ {
+ fprintf(out, i32_fmt, (*tensor)(i0, i1, i2, i3, i4));
+ }
+ fprintf(out, "]\n");
+ }
+ fprintf(out, "]\n");
+ }
+ fprintf(out, "]\n");
+ }
+ fprintf(out, "]\n");
+ }
+ fprintf(out, "]\n");
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor6<int32_t>::dumpTensor(FILE* out) const
+{
+ char i32_fmt[FOF_STR_LEN];
+ snprintf(i32_fmt, FOF_STR_LEN, " %%d ");
+
+ if (tensor == nullptr)
+ {
+ fprintf(out, "<Not allocated>\n");
+ return 0;
+ }
+
+ fprintf(out, "[");
+ for (int i0 = 0; i0 < shape[0]; i0++)
+ {
+ fprintf(out, "[");
+ for (int i1 = 0; i1 < shape[1]; i1++)
+ {
+ fprintf(out, "[");
+ for (int i2 = 0; i2 < shape[2]; i2++)
+ {
+ fprintf(out, "[");
+ for (int i3 = 0; i3 < shape[3]; i3++)
+ {
+ fprintf(out, "[");
+ for (int i4 = 0; i4 < shape[4]; i4++)
+ {
+ fprintf(out, "[");
+ for (int i5 = 0; i5 < shape[5]; i5++)
+ {
+ fprintf(out, i32_fmt, (*tensor)(i0, i1, i2, i3, i4, i5));
+ }
+ fprintf(out, "]\n");
+ }
+ fprintf(out, "]\n");
+ }
+ fprintf(out, "]\n");
+ }
+ fprintf(out, "]\n");
+ }
+ fprintf(out, "]\n");
+ }
+ fprintf(out, "]\n");
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor0<bool>::dumpTensor(FILE* out) const
+{
+ char bool_fmt[FOF_STR_LEN];
+ snprintf(bool_fmt, FOF_STR_LEN, "[ %%s ]\n");
+
+ if (tensor == nullptr)
+ {
+ fprintf(out, "<Not allocated>\n");
+ return 0;
+ }
+
+ fprintf(out, bool_fmt, bool_to_str((*tensor)(0)));
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor1<bool>::dumpTensor(FILE* out) const
+{
+ char bool_fmt[FOF_STR_LEN];
+ snprintf(bool_fmt, FOF_STR_LEN, " %%s ");
+
+ if (tensor == nullptr)
+ {
+ fprintf(out, "<Not allocated>\n");
+ return 0;
+ }
+
+ fprintf(out, "[");
+ for (int i0 = 0; i0 < shape[0]; i0++)
+ {
+ fprintf(out, bool_fmt, bool_to_str((*tensor)(i0)));
+ }
+ fprintf(out, "]\n");
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor2<bool>::dumpTensor(FILE* out) const
+{
+ char bool_fmt[FOF_STR_LEN];
+ snprintf(bool_fmt, FOF_STR_LEN, " %%s ");
+
+ if (tensor == nullptr)
+ {
+ fprintf(out, "<Not allocated>\n");
+ return 0;
+ }
+
+ fprintf(out, "[");
+ for (int i0 = 0; i0 < shape[0]; i0++)
+ {
+ fprintf(out, "[");
+ for (int i1 = 0; i1 < shape[1]; i1++)
+ {
+ fprintf(out, bool_fmt, bool_to_str((*tensor)(i0, i1)));
+ }
+ fprintf(out, "]\n");
+ }
+ fprintf(out, "]\n");
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor3<bool>::dumpTensor(FILE* out) const
+{
+ char bool_fmt[FOF_STR_LEN];
+ snprintf(bool_fmt, FOF_STR_LEN, " %%s ");
+
+ if (tensor == nullptr)
+ {
+ fprintf(out, "<Not allocated>\n");
+ return 0;
+ }
+
+ fprintf(out, "[");
+ for (int i0 = 0; i0 < shape[0]; i0++)
+ {
+ fprintf(out, "[");
+ for (int i1 = 0; i1 < shape[1]; i1++)
+ {
+ fprintf(out, "[");
+ for (int i2 = 0; i2 < shape[2]; i2++)
+ {
+ fprintf(out, bool_fmt, bool_to_str((*tensor)(i0, i1, i2)));
+ }
+ fprintf(out, "]\n");
+ }
+ fprintf(out, "]\n");
+ }
+ fprintf(out, "]\n");
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor4<bool>::dumpTensor(FILE* out) const
+{
+ char bool_fmt[FOF_STR_LEN];
+ snprintf(bool_fmt, FOF_STR_LEN, " %%s ");
+
+ if (tensor == nullptr)
+ {
+ fprintf(out, "<Not allocated>\n");
+ return 0;
+ }
+
+ fprintf(out, "[");
+ for (int i0 = 0; i0 < shape[0]; i0++)
+ {
+ fprintf(out, "[");
+ for (int i1 = 0; i1 < shape[1]; i1++)
+ {
+ fprintf(out, "[");
+ for (int i2 = 0; i2 < shape[2]; i2++)
+ {
+ fprintf(out, "[");
+ for (int i3 = 0; i3 < shape[3]; i3++)
+ {
+ fprintf(out, bool_fmt, bool_to_str((*tensor)(i0, i1, i2, i3)));
+ }
+ fprintf(out, "]\n");
+ }
+ fprintf(out, "]\n");
+ }
+ fprintf(out, "]\n");
+ }
+ fprintf(out, "]\n");
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor5<bool>::dumpTensor(FILE* out) const
+{
+ char bool_fmt[FOF_STR_LEN];
+ snprintf(bool_fmt, FOF_STR_LEN, " %%s ");
+
+ if (tensor == nullptr)
+ {
+ fprintf(out, "<Not allocated>\n");
+ return 0;
+ }
+
+ fprintf(out, "[");
+ for (int i0 = 0; i0 < shape[0]; i0++)
+ {
+ fprintf(out, "[");
+ for (int i1 = 0; i1 < shape[1]; i1++)
+ {
+ fprintf(out, "[");
+ for (int i2 = 0; i2 < shape[2]; i2++)
+ {
+ fprintf(out, "[");
+ for (int i3 = 0; i3 < shape[3]; i3++)
+ {
+ fprintf(out, "[");
+ for (int i4 = 0; i4 < shape[4]; i4++)
+ {
+ fprintf(out, bool_fmt, bool_to_str((*tensor)(i0, i1, i2, i3, i4)));
+ }
+ fprintf(out, "]\n");
+ }
+ fprintf(out, "]\n");
+ }
+ fprintf(out, "]\n");
+ }
+ fprintf(out, "]\n");
+ }
+ fprintf(out, "]\n");
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor6<bool>::dumpTensor(FILE* out) const
+{
+ char bool_fmt[FOF_STR_LEN];
+ snprintf(bool_fmt, FOF_STR_LEN, " %%s ");
+
+ if (tensor == nullptr)
+ {
+ fprintf(out, "<Not allocated>\n");
+ return 0;
+ }
+
+ fprintf(out, "[");
+ for (int i0 = 0; i0 < shape[0]; i0++)
+ {
+ fprintf(out, "[");
+ for (int i1 = 0; i1 < shape[1]; i1++)
+ {
+ fprintf(out, "[");
+ for (int i2 = 0; i2 < shape[2]; i2++)
+ {
+ fprintf(out, "[");
+ for (int i3 = 0; i3 < shape[3]; i3++)
+ {
+ fprintf(out, "[");
+ for (int i4 = 0; i4 < shape[4]; i4++)
+ {
+ fprintf(out, "[");
+ for (int i5 = 0; i5 < shape[5]; i5++)
+ {
+ fprintf(out, bool_fmt, bool_to_str((*tensor)(i0, i1, i2, i3, i4, i5)));
+ }
+ fprintf(out, "]\n");
+ }
+ fprintf(out, "]\n");
+ }
+ fprintf(out, "]\n");
+ }
+ fprintf(out, "]\n");
+ }
+ fprintf(out, "]\n");
+ }
+ fprintf(out, "]\n");
+
+ return 0;
+}
+
+template <class T>
+int TosaReference::TensorTemplate<T>::dumpTensor(FILE* out) const
+{
+ return 0;
+}
+
+// explicit template instantiations
+template class TosaReference::TensorTemplate<Eigen::Tensor<float, 0>>;
+template class TosaReference::TensorTemplate<Eigen::Tensor<float, 1>>;
+template class TosaReference::TensorTemplate<Eigen::Tensor<float, 2>>;
+template class TosaReference::TensorTemplate<Eigen::Tensor<float, 3>>;
+template class TosaReference::TensorTemplate<Eigen::Tensor<float, 4>>;
+template class TosaReference::TensorTemplate<Eigen::Tensor<float, 5>>;
+template class TosaReference::TensorTemplate<Eigen::Tensor<float, 6>>;
+
+template class TosaReference::TensorTemplate<Eigen::Tensor<int32_t, 0>>;
+template class TosaReference::TensorTemplate<Eigen::Tensor<int32_t, 1>>;
+template class TosaReference::TensorTemplate<Eigen::Tensor<int32_t, 2>>;
+template class TosaReference::TensorTemplate<Eigen::Tensor<int32_t, 3>>;
+template class TosaReference::TensorTemplate<Eigen::Tensor<int32_t, 4>>;
+template class TosaReference::TensorTemplate<Eigen::Tensor<int32_t, 5>>;
+template class TosaReference::TensorTemplate<Eigen::Tensor<int32_t, 6>>;
+
+template class TosaReference::TensorTemplate<Eigen::Tensor<int64_t, 0>>;
+template class TosaReference::TensorTemplate<Eigen::Tensor<int64_t, 1>>;
+template class TosaReference::TensorTemplate<Eigen::Tensor<int64_t, 2>>;
+template class TosaReference::TensorTemplate<Eigen::Tensor<int64_t, 3>>;
+template class TosaReference::TensorTemplate<Eigen::Tensor<int64_t, 4>>;
+template class TosaReference::TensorTemplate<Eigen::Tensor<int64_t, 5>>;
+template class TosaReference::TensorTemplate<Eigen::Tensor<int64_t, 6>>;
+
+template class TosaReference::TensorTemplate<Eigen::Tensor<bool, 0>>;
+template class TosaReference::TensorTemplate<Eigen::Tensor<bool, 1>>;
+template class TosaReference::TensorTemplate<Eigen::Tensor<bool, 2>>;
+template class TosaReference::TensorTemplate<Eigen::Tensor<bool, 3>>;
+template class TosaReference::TensorTemplate<Eigen::Tensor<bool, 4>>;
+template class TosaReference::TensorTemplate<Eigen::Tensor<bool, 5>>;
+template class TosaReference::TensorTemplate<Eigen::Tensor<bool, 6>>;
diff --git a/reference_model/src/tensor.h b/reference_model/src/tensor.h
new file mode 100644
index 0000000..2fd37cd
--- /dev/null
+++ b/reference_model/src/tensor.h
@@ -0,0 +1,815 @@
+
+// Copyright (c) 2020, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef TOSA_REFERENCE_TENSOR_H
+#define TOSA_REFERENCE_TENSOR_H
+
+#include "model_common.h"
+#include "ops/template_types.h"
+#include "tosa_generated.h"
+#include "tosa_serialization_handler.h"
+#include <Eigen/CXX11/Tensor>
+#include <list>
+#include <vector>
+
+using namespace tosa;
+
+namespace TosaReference
+{
+class GraphNode;
+
+class Tensor
+{
+public:
+ Tensor(std::string tensorName_,
+           DType tensorDtype_,
+ const std::vector<Usage>& tensorUsage_,
+ const std::vector<Format>& tensorFormat_,
+ std::vector<int> shape_,
+ int isConst_);
+
+ virtual ~Tensor();
+
+ int setIsSubgraphInput();
+ int setIsSubgraphOutput();
+
+ int getIsSubgraphInput() const
+ {
+ return isSubgraphInput;
+ }
+
+ int getIsSubgraphOutput() const
+ {
+ return isSubgraphOutput;
+ }
+
+ int setProducer(GraphNode* node);
+ int addConsumer(GraphNode* node);
+
+ int setIsValid()
+ {
+ isValid = 1;
+ return 0;
+ }
+
+ int clearIsValid()
+ {
+ isValid = 0;
+ return 0;
+ }
+
+ int getIsValid() const
+ {
+ return isValid;
+ }
+
+ int getIsConst() const
+ {
+ return isConst;
+ }
+
+ GraphNode* getProducer()
+ {
+ return producer;
+ }
+
+ std::vector<GraphNode*>& getConsumers()
+ {
+ return consumers;
+ }
+
+ const std::string& getName() const
+ {
+ return tensorName;
+ }
+
+ const std::vector<int>& getShape() const
+ {
+ return shape;
+ }
+
+ std::string getShapeAsString() const
+ {
+ std::string shape_str("[");
+ for (auto& dim : shape)
+ {
+ shape_str += (std::to_string(dim) + ", ");
+ }
+ shape_str.append("]");
+ return shape_str;
+ }
+
+ const std::vector<Usage>& getUsage() const
+ {
+ return tensorUsage;
+ }
+
+ bool hasUsage(Usage usage) const
+ {
+ for (auto& usg : tensorUsage)
+ {
+ if (usg == usage)
+ {
+ return true;
+ }
+ }
+ return false;
+ }
+
+ std::string getUsageAsString() const
+ {
+ std::string usage_str("[");
+ for (auto& usg : tensorUsage)
+ {
+ usage_str += (std::string(EnumNamesUsage()[usg]) + ", ");
+ }
+ usage_str.append("]");
+ return usage_str;
+ }
+
+ const std::vector<Format>& getFormat() const
+ {
+ return tensorFormat;
+ }
+
+ bool hasFormat(Format format) const
+ {
+ for (auto& fmt : tensorFormat)
+ {
+ if (fmt == format)
+ {
+ return true;
+ }
+ }
+ return false;
+ }
+
+ std::string getFormatAsString() const
+ {
+ std::string format_str("[");
+ for (auto& fmt : tensorFormat)
+ {
+ format_str += (std::string(EnumNamesFormat()[fmt]) + ", ");
+ }
+ format_str.append("]");
+ return format_str;
+ }
+
+ const uint32_t getElementCount() const
+ {
+ uint32_t elements = 1;
+ for (size_t i = 0; i < shape.size(); i++)
+ elements *= shape[i];
+
+ return elements;
+ }
+
+ // Comparison of rank and type with other tensors
+ const int matchRank(const Tensor& ref) const
+ {
+ return (ref.shape.size() == shape.size()) ? 0 : 1;
+ }
+
+ const int matchType(const Tensor& ref) const
+ {
+ return (ref.tensorDtype == tensorDtype) ? 0 : 1;
+ }
+
+ const int matchRankType(const Tensor& ref) const
+ {
+ return (matchType(ref) || matchRank(ref));
+ }
+
+ const int matchRankTypeShape(const Tensor& ref, const bool broadcastOk = false) const
+ {
+ if (matchRankType(ref))
+ return 1;
+
+ for (size_t i = 0; i < shape.size(); i++)
+ {
+ if (shape[i] != ref.shape[i])
+ {
+ if (!broadcastOk ||
+ // For broadcasts, at least one operand must have size 1
+ // if they don't both match
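+                    // e.g. shapes [5, 3] and [1, 3] are compatible; [5, 3] and [2, 3] are not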
+ (broadcastOk && (shape[i] != 1 && ref.shape[i] != 1)))
+ {
+ return 1;
+ }
+ }
+ }
+
+ return 0;
+ }
+
+ // Sometimes we might want to match several semi-compatible types,
+ // so just check rank and size here
+ const int matchRankSize(const Tensor& ref) const
+ {
+ if (matchRank(ref))
+ return 1;
+
+ for (size_t i = 0; i < shape.size(); i++)
+ {
+ if (shape[i] != ref.shape[i])
+ return 1;
+ }
+
+ return 0;
+ }
+
+ // Unary check to make sure rank matches
+ const int checkRequiredRank(const int exactRank) const
+ {
+ return (shape.size() == (size_t)exactRank) ? 0 : 1;
+ }
+
+ const int checkRequiredRank(const int minRank, const int maxRank) const
+ {
+ return (shape.size() >= (size_t)minRank && shape.size() <= (size_t)maxRank) ? 0 : 1;
+ }
+
+ const int getRank() const
+ {
+ return shape.size();
+ }
+
+ const DType getDtype() const
+ {
+ return tensorDtype;
+ }
+
+ virtual int dumpTensor(FILE* out) const = 0;
+ virtual int dumpTensorParams(FILE* out) const;
+ virtual int dumpTensorParams(std::ostream& out) const;
+
+ virtual int setTensorValueFloat(const size_t bufLen, const float* vals) = 0;
+ virtual int setTensorValueInt32(const size_t bufLen, const int32_t* vals) = 0;
+ virtual int setTensorValueInt64(const size_t bufLen, const int64_t* vals) = 0;
+ virtual int setTensorValueBool(const size_t bufLen, const bool* vals) = 0;
+ virtual int getTensorValueFloat(const size_t bufLen, float* fbuf) const = 0;
+ virtual int getTensorValueInt32(const size_t bufLen, int32_t* ibuf) const = 0;
+ virtual int getTensorValueInt64(const size_t bufLen, int64_t* ibuf) const = 0;
+ virtual int getTensorValueBool(const size_t bufLen, bool* ibuf) const = 0;
+
+ virtual int readFromNpyFile(const char* filename);
+ virtual int writeToNpyFile(const char* filename) const;
+ virtual int copyValueFrom(Tensor* tensor) = 0;
+
+ const char* bool_to_str(bool in) const
+ {
+ static const char* true_str = "true";
+ static const char* false_str = "false";
+ return in ? true_str : false_str;
+ }
+
+ virtual int allocate() = 0;
+ virtual int deallocate() = 0;
+ virtual bool is_allocated() = 0;
+
+protected:
+ std::string tensorName;
+ DType tensorDtype;
+ std::vector<Usage> tensorUsage;
+ std::vector<Format> tensorFormat;
+ int isConst;
+ int isValid;
+ std::vector<int> shape;
+ int isSubgraphInput;
+ int isSubgraphOutput;
+ bool isAllocated;
+
+ GraphNode* producer;
+ std::vector<GraphNode*> consumers;
+
+ // Note: the Eigen::Tensor is not declared in Tensor
+ // Instead, the TensorTemplate class keeps the templated tensor
+ // declaration so that the graph manipulation tools are isolated
+ // from the templated tensor type.
+ //
+ // Operators need to be aware of the TensorTemplate<EigenTensor<type, rank>> type
+ // so that they can operate on the right types.
+};
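+
+// Usage sketch (illustrative only): an operator working on, say, a rank-2 float
+// tensor typically recovers the typed tensor by downcasting the generic Tensor*
+// to the templated type and calling getTensor(). Here inputTensor is just a
+// placeholder Tensor* name, e.g.
+//
+//     auto* in = dynamic_cast<Tensor2<float>*>(inputTensor);
+//     if (in)
+//     {
+//         Eigen::Tensor<float, 2>& t = in->getTensor();
+//     }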
+
+template <class T>
+class TensorTemplate : public Tensor
+{
+public:
+ TensorTemplate(std::string tensorName_,
+ DType tensorDtype_,
+ const std::vector<Usage>& tensorUsage_,
+ const std::vector<Format>& tensorFormat_,
+ std::vector<int> shape_,
+ int isConst_)
+ : Tensor(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_, isConst_)
+ {
+ tensor = nullptr;
+ }
+
+ virtual ~TensorTemplate()
+ {
+ deallocate();
+ }
+
+ virtual int allocate()
+ {
+ tensor = new T();
+ if (tensor)
+ return 0;
+ else
+ return 1;
+ }
+
+ virtual int deallocate()
+ {
+ if (tensor)
+ {
+ delete tensor;
+ }
+ tensor = nullptr;
+ return 0;
+ }
+
+ virtual bool is_allocated()
+ {
+ if (tensor)
+ {
+ return true;
+ }
+ return false;
+ }
+
+ T& getTensor()
+ {
+ return *tensor;
+ }
+
+ virtual int dumpTensor(FILE* out) const;
+
+ virtual int setTensorValueFloat(const size_t bufLen, const float* vals);
+ virtual int setTensorValueInt32(const size_t bufLen, const int32_t* vals);
+ virtual int setTensorValueInt64(const size_t bufLen, const int64_t* vals);
+ virtual int setTensorValueBool(const size_t bufLen, const bool* vals);
+ virtual int getTensorValueFloat(const size_t bufLen, float* fbuf) const;
+ virtual int getTensorValueInt32(const size_t bufLen, int32_t* ibuf) const;
+ virtual int getTensorValueInt64(const size_t bufLen, int64_t* ibuf) const;
+ virtual int getTensorValueBool(const size_t bufLen, bool* bbuf) const;
+
+ virtual int copyValueFrom(Tensor* tensor);
+
+protected:
+ T* tensor;
+};
+
+// allocate() template specializations to allocate the different tensor sizes
+// Let the compiler know here before the factory uses them, but define them in the .cc file.
+template <>
+int Tensor0<float>::allocate();
+template <>
+int Tensor1<float>::allocate();
+template <>
+int Tensor2<float>::allocate();
+template <>
+int Tensor3<float>::allocate();
+template <>
+int Tensor4<float>::allocate();
+template <>
+int Tensor5<float>::allocate();
+template <>
+int Tensor6<float>::allocate();
+
+template <>
+int Tensor0<int32_t>::allocate();
+template <>
+int Tensor1<int32_t>::allocate();
+template <>
+int Tensor2<int32_t>::allocate();
+template <>
+int Tensor3<int32_t>::allocate();
+template <>
+int Tensor4<int32_t>::allocate();
+template <>
+int Tensor5<int32_t>::allocate();
+template <>
+int Tensor6<int32_t>::allocate();
+
+template <>
+int Tensor0<int64_t>::allocate();
+template <>
+int Tensor1<int64_t>::allocate();
+template <>
+int Tensor2<int64_t>::allocate();
+template <>
+int Tensor3<int64_t>::allocate();
+template <>
+int Tensor4<int64_t>::allocate();
+template <>
+int Tensor5<int64_t>::allocate();
+template <>
+int Tensor6<int64_t>::allocate();
+
+template <>
+int Tensor0<bool>::allocate();
+template <>
+int Tensor1<bool>::allocate();
+template <>
+int Tensor2<bool>::allocate();
+template <>
+int Tensor3<bool>::allocate();
+template <>
+int Tensor4<bool>::allocate();
+template <>
+int Tensor5<bool>::allocate();
+template <>
+int Tensor6<bool>::allocate();
+
+template <>
+int Tensor0<float>::copyValueFrom(Tensor* src);
+template <>
+int Tensor1<float>::copyValueFrom(Tensor* src);
+template <>
+int Tensor2<float>::copyValueFrom(Tensor* src);
+template <>
+int Tensor3<float>::copyValueFrom(Tensor* src);
+template <>
+int Tensor4<float>::copyValueFrom(Tensor* src);
+template <>
+int Tensor5<float>::copyValueFrom(Tensor* src);
+template <>
+int Tensor6<float>::copyValueFrom(Tensor* src);
+
+template <>
+int Tensor0<int32_t>::copyValueFrom(Tensor* src);
+template <>
+int Tensor1<int32_t>::copyValueFrom(Tensor* src);
+template <>
+int Tensor2<int32_t>::copyValueFrom(Tensor* src);
+template <>
+int Tensor3<int32_t>::copyValueFrom(Tensor* src);
+template <>
+int Tensor4<int32_t>::copyValueFrom(Tensor* src);
+template <>
+int Tensor5<int32_t>::copyValueFrom(Tensor* src);
+template <>
+int Tensor6<int32_t>::copyValueFrom(Tensor* src);
+
+template <>
+int Tensor0<int64_t>::copyValueFrom(Tensor* src);
+template <>
+int Tensor1<int64_t>::copyValueFrom(Tensor* src);
+template <>
+int Tensor2<int64_t>::copyValueFrom(Tensor* src);
+template <>
+int Tensor3<int64_t>::copyValueFrom(Tensor* src);
+template <>
+int Tensor4<int64_t>::copyValueFrom(Tensor* src);
+template <>
+int Tensor5<int64_t>::copyValueFrom(Tensor* src);
+template <>
+int Tensor6<int64_t>::copyValueFrom(Tensor* src);
+
+template <>
+int Tensor0<bool>::copyValueFrom(Tensor* src);
+template <>
+int Tensor1<bool>::copyValueFrom(Tensor* src);
+template <>
+int Tensor2<bool>::copyValueFrom(Tensor* src);
+template <>
+int Tensor3<bool>::copyValueFrom(Tensor* src);
+template <>
+int Tensor4<bool>::copyValueFrom(Tensor* src);
+template <>
+int Tensor5<bool>::copyValueFrom(Tensor* src);
+template <>
+int Tensor6<bool>::copyValueFrom(Tensor* src);
+
+template <>
+int Tensor0<int32_t>::setTensorValueInt32(const size_t bufLen, const int32_t* vals);
+template <>
+int Tensor1<int32_t>::setTensorValueInt32(const size_t bufLen, const int32_t* vals);
+template <>
+int Tensor2<int32_t>::setTensorValueInt32(const size_t bufLen, const int32_t* vals);
+template <>
+int Tensor3<int32_t>::setTensorValueInt32(const size_t bufLen, const int32_t* vals);
+template <>
+int Tensor4<int32_t>::setTensorValueInt32(const size_t bufLen, const int32_t* vals);
+template <>
+int Tensor5<int32_t>::setTensorValueInt32(const size_t bufLen, const int32_t* vals);
+template <>
+int Tensor6<int32_t>::setTensorValueInt32(const size_t bufLen, const int32_t* vals);
+
+template <>
+int Tensor0<int32_t>::getTensorValueInt32(const size_t bufLen, int32_t* vals) const;
+template <>
+int Tensor1<int32_t>::getTensorValueInt32(const size_t bufLen, int32_t* vals) const;
+template <>
+int Tensor2<int32_t>::getTensorValueInt32(const size_t bufLen, int32_t* vals) const;
+template <>
+int Tensor3<int32_t>::getTensorValueInt32(const size_t bufLen, int32_t* vals) const;
+template <>
+int Tensor4<int32_t>::getTensorValueInt32(const size_t bufLen, int32_t* vals) const;
+template <>
+int Tensor5<int32_t>::getTensorValueInt32(const size_t bufLen, int32_t* vals) const;
+template <>
+int Tensor6<int32_t>::getTensorValueInt32(const size_t bufLen, int32_t* vals) const;
+
+template <>
+int Tensor0<int64_t>::setTensorValueInt64(const size_t bufLen, const int64_t* vals);
+template <>
+int Tensor1<int64_t>::setTensorValueInt64(const size_t bufLen, const int64_t* vals);
+template <>
+int Tensor2<int64_t>::setTensorValueInt64(const size_t bufLen, const int64_t* vals);
+template <>
+int Tensor3<int64_t>::setTensorValueInt64(const size_t bufLen, const int64_t* vals);
+template <>
+int Tensor4<int64_t>::setTensorValueInt64(const size_t bufLen, const int64_t* vals);
+template <>
+int Tensor5<int64_t>::setTensorValueInt64(const size_t bufLen, const int64_t* vals);
+template <>
+int Tensor6<int64_t>::setTensorValueInt64(const size_t bufLen, const int64_t* vals);
+
+template <>
+int Tensor0<int64_t>::getTensorValueInt64(const size_t bufLen, int64_t* vals) const;
+template <>
+int Tensor1<int64_t>::getTensorValueInt64(const size_t bufLen, int64_t* vals) const;
+template <>
+int Tensor2<int64_t>::getTensorValueInt64(const size_t bufLen, int64_t* vals) const;
+template <>
+int Tensor3<int64_t>::getTensorValueInt64(const size_t bufLen, int64_t* vals) const;
+template <>
+int Tensor4<int64_t>::getTensorValueInt64(const size_t bufLen, int64_t* vals) const;
+template <>
+int Tensor5<int64_t>::getTensorValueInt64(const size_t bufLen, int64_t* vals) const;
+template <>
+int Tensor6<int64_t>::getTensorValueInt64(const size_t bufLen, int64_t* vals) const;
+
+template <>
+int Tensor0<float>::setTensorValueFloat(const size_t bufLen, const float* vals);
+template <>
+int Tensor1<float>::setTensorValueFloat(const size_t bufLen, const float* vals);
+template <>
+int Tensor2<float>::setTensorValueFloat(const size_t bufLen, const float* vals);
+template <>
+int Tensor3<float>::setTensorValueFloat(const size_t bufLen, const float* vals);
+template <>
+int Tensor4<float>::setTensorValueFloat(const size_t bufLen, const float* vals);
+template <>
+int Tensor5<float>::setTensorValueFloat(const size_t bufLen, const float* vals);
+template <>
+int Tensor6<float>::setTensorValueFloat(const size_t bufLen, const float* vals);
+
+template <>
+int Tensor0<float>::getTensorValueFloat(const size_t bufLen, float* vals) const;
+template <>
+int Tensor1<float>::getTensorValueFloat(const size_t bufLen, float* vals) const;
+template <>
+int Tensor2<float>::getTensorValueFloat(const size_t bufLen, float* vals) const;
+template <>
+int Tensor3<float>::getTensorValueFloat(const size_t bufLen, float* vals) const;
+template <>
+int Tensor4<float>::getTensorValueFloat(const size_t bufLen, float* vals) const;
+template <>
+int Tensor5<float>::getTensorValueFloat(const size_t bufLen, float* vals) const;
+template <>
+int Tensor6<float>::getTensorValueFloat(const size_t bufLen, float* vals) const;
+
+template <>
+int Tensor0<bool>::setTensorValueBool(const size_t bufLen, const bool* vals);
+template <>
+int Tensor1<bool>::setTensorValueBool(const size_t bufLen, const bool* vals);
+template <>
+int Tensor2<bool>::setTensorValueBool(const size_t bufLen, const bool* vals);
+template <>
+int Tensor3<bool>::setTensorValueBool(const size_t bufLen, const bool* vals);
+template <>
+int Tensor4<bool>::setTensorValueBool(const size_t bufLen, const bool* vals);
+template <>
+int Tensor5<bool>::setTensorValueBool(const size_t bufLen, const bool* vals);
+template <>
+int Tensor6<bool>::setTensorValueBool(const size_t bufLen, const bool* vals);
+
+template <>
+int Tensor0<bool>::getTensorValueBool(const size_t bufLen, bool* vals) const;
+template <>
+int Tensor1<bool>::getTensorValueBool(const size_t bufLen, bool* vals) const;
+template <>
+int Tensor2<bool>::getTensorValueBool(const size_t bufLen, bool* vals) const;
+template <>
+int Tensor3<bool>::getTensorValueBool(const size_t bufLen, bool* vals) const;
+template <>
+int Tensor4<bool>::getTensorValueBool(const size_t bufLen, bool* vals) const;
+template <>
+int Tensor5<bool>::getTensorValueBool(const size_t bufLen, bool* vals) const;
+template <>
+int Tensor6<bool>::getTensorValueBool(const size_t bufLen, bool* vals) const;
+
+// dumpTensor() specializations for all supported element types
+template <>
+int Tensor0<float>::dumpTensor(FILE* out) const;
+template <>
+int Tensor1<float>::dumpTensor(FILE* out) const;
+template <>
+int Tensor2<float>::dumpTensor(FILE* out) const;
+template <>
+int Tensor3<float>::dumpTensor(FILE* out) const;
+template <>
+int Tensor4<float>::dumpTensor(FILE* out) const;
+template <>
+int Tensor5<float>::dumpTensor(FILE* out) const;
+template <>
+int Tensor6<float>::dumpTensor(FILE* out) const;
+template <>
+int Tensor0<int32_t>::dumpTensor(FILE* out) const;
+template <>
+int Tensor1<int32_t>::dumpTensor(FILE* out) const;
+template <>
+int Tensor2<int32_t>::dumpTensor(FILE* out) const;
+template <>
+int Tensor3<int32_t>::dumpTensor(FILE* out) const;
+template <>
+int Tensor4<int32_t>::dumpTensor(FILE* out) const;
+template <>
+int Tensor5<int32_t>::dumpTensor(FILE* out) const;
+template <>
+int Tensor6<int32_t>::dumpTensor(FILE* out) const;
+template <>
+int Tensor0<int64_t>::dumpTensor(FILE* out) const;
+template <>
+int Tensor1<int64_t>::dumpTensor(FILE* out) const;
+template <>
+int Tensor2<int64_t>::dumpTensor(FILE* out) const;
+template <>
+int Tensor3<int64_t>::dumpTensor(FILE* out) const;
+template <>
+int Tensor4<int64_t>::dumpTensor(FILE* out) const;
+template <>
+int Tensor5<int64_t>::dumpTensor(FILE* out) const;
+template <>
+int Tensor6<int64_t>::dumpTensor(FILE* out) const;
+template <>
+int Tensor0<bool>::dumpTensor(FILE* out) const;
+template <>
+int Tensor1<bool>::dumpTensor(FILE* out) const;
+template <>
+int Tensor2<bool>::dumpTensor(FILE* out) const;
+template <>
+int Tensor3<bool>::dumpTensor(FILE* out) const;
+template <>
+int Tensor4<bool>::dumpTensor(FILE* out) const;
+template <>
+int Tensor5<bool>::dumpTensor(FILE* out) const;
+template <>
+int Tensor6<bool>::dumpTensor(FILE* out) const;
+
+class TensorFactory
+{
+public:
+ static Tensor* newTensor(std::string tensorName_,
+ DType tensorDtype_,
+ const std::vector<Usage>& tensorUsage_,
+ const std::vector<Format>& tensorFormat_,
+ std::vector<int> shape_,
+ int isConst_,
+ const uint32_t rank)
+ {
+ switch (tensorDtype_)
+ {
+ case DType_FLOAT:
+ switch (rank)
+ {
+ case 0:
+ return new Tensor0<float>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
+ isConst_);
+ case 1:
+ return new Tensor1<float>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
+ isConst_);
+ case 2:
+ return new Tensor2<float>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
+ isConst_);
+ case 3:
+ return new Tensor3<float>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
+ isConst_);
+ case 4:
+ return new Tensor4<float>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
+ isConst_);
+ case 5:
+ return new Tensor5<float>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
+ isConst_);
+ case 6:
+ return new Tensor6<float>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
+ isConst_);
+ default:
+ goto done;
+ }
+ case DType_INT32:
+ case DType_AINT8:
+ case DType_UINT8:
+ case DType_INT4:
+ case DType_INT8:
+ case DType_INT16:
+ switch (rank)
+ {
+ case 0:
+ return new Tensor0<int32_t>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
+ isConst_);
+ case 1:
+ return new Tensor1<int32_t>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
+ isConst_);
+ case 2:
+ return new Tensor2<int32_t>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
+ isConst_);
+ case 3:
+ return new Tensor3<int32_t>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
+ isConst_);
+ case 4:
+ return new Tensor4<int32_t>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
+ isConst_);
+ case 5:
+ return new Tensor5<int32_t>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
+ isConst_);
+ case 6:
+ return new Tensor6<int32_t>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
+ isConst_);
+ default:
+ goto done;
+ }
+ case DType_INT48:
+ switch (rank)
+ {
+ case 0:
+ return new Tensor0<int64_t>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
+ isConst_);
+ case 1:
+ return new Tensor1<int64_t>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
+ isConst_);
+ case 2:
+ return new Tensor2<int64_t>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
+ isConst_);
+ case 3:
+ return new Tensor3<int64_t>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
+ isConst_);
+ case 4:
+ return new Tensor4<int64_t>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
+ isConst_);
+ case 5:
+ return new Tensor5<int64_t>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
+ isConst_);
+ case 6:
+ return new Tensor6<int64_t>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
+ isConst_);
+ default:
+ goto done;
+ }
+ case DType_BOOL:
+ switch (rank)
+ {
+ case 0:
+ return new Tensor0<bool>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
+ isConst_);
+ case 1:
+ return new Tensor1<bool>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
+ isConst_);
+ case 2:
+ return new Tensor2<bool>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
+ isConst_);
+ case 3:
+ return new Tensor3<bool>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
+ isConst_);
+ case 4:
+ return new Tensor4<bool>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
+ isConst_);
+ case 5:
+ return new Tensor5<bool>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
+ isConst_);
+ case 6:
+ return new Tensor6<bool>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
+ isConst_);
+ default:
+ goto done;
+ }
+ default:
+ goto done;
+ }
+
+ done:
+ FATAL_ERROR("Unsupported tensor name=%s, type=%s, rank=%d", tensorName_.c_str(), EnumNamesDType()[tensorDtype_],
+ rank);
+ }
+
+ static Tensor* newTensor(DType type, const std::vector<int> shape);
+};
+}; // namespace TosaReference
+
+#endif
diff --git a/scripts/xunit/xunit.py b/scripts/xunit/xunit.py
new file mode 100644
index 0000000..c636136
--- /dev/null
+++ b/scripts/xunit/xunit.py
@@ -0,0 +1,91 @@
+
+
+# Copyright (c) 2020, ARM Limited.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+import xml.etree.ElementTree as ET
+
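+# Minimal helpers for writing xUnit/JUnit-style XML test reports:
+# xunit_results holds suites, xunit_suite holds tests, and xunit_test records
+# the per-test name, time, failure text and skip flag; write_results()
+# serializes everything to an XML file.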
+class xunit_results():
+ def __init__(self, name='Testsuites'):
+ self.name = name
+ self.suites = []
+ def create_suite(self, name):
+ s = xunit_suite(name)
+ self.suites.append(s)
+ return s
+ def write_results(self, filename):
+ suites = ET.Element(self.name)
+ tree = ET.ElementTree(suites)
+ for s in self.suites:
+ testsuite = ET.SubElement(suites, 'testsuite', {'name' : s.name, 'errors' : '0'})
+ tests = 0
+ failures = 0
+ skip = 0
+ for t in s.tests:
+ test = ET.SubElement(testsuite, 'testcase', {'name' : t.name, 'classname' : t.classname, 'time' : t.time})
+ tests += 1
+ if t.skip:
+ skip += 1
+ ET.SubElement(test, 'skipped', {'type' : 'Skipped test'})
+ if t.fail:
+ failures += 1
+ fail = ET.SubElement(test, 'failure', {'type' : 'Test failed'})
+ fail.text = t.fail
+ if t.sysout:
+ sysout = ET.SubElement(test, 'system-out')
+ sysout.text = t.sysout
+ if t.syserr:
+ syserr = ET.SubElement(test, 'system-err')
+ syserr.text = t.syserr
+ testsuite.attrib['tests'] = str(tests)
+ testsuite.attrib['failures'] = str(failures)
+ testsuite.attrib['skip'] = str(skip)
+ tree.write(filename, 'UTF-8', True)
+
+
+class xunit_suite():
+ def __init__(self, name):
+ self.name = name
+ self.tests = []
+
+class xunit_test():
+ def __init__(self, name, classname=None):
+ self.name = name
+ if classname:
+ self.classname = classname
+ else:
+ self.classname = name
+ self.time = '0.000'
+ self.fail = None
+ self.skip = False
+ self.sysout = None
+ self.syserr = None
+ def failed(self, text):
+ self.fail = text
+ def skipped(self):
+ self.skip = True
+
+
+if __name__ == '__main__':
+ r = xunit_results()
+ s = r.create_suite('selftest')
+ for i in range(0,10):
+ t = xunit_test('atest' + str(i))
+ if i == 3:
+ t.failed('Unknown failure foo')
+ if i == 7:
+ t.skipped()
+ s.tests.append(t)
+ r.write_results('foo.xml')
diff --git a/serialization/CMakeLists.txt b/serialization/CMakeLists.txt
new file mode 100644
index 0000000..7bca824
--- /dev/null
+++ b/serialization/CMakeLists.txt
@@ -0,0 +1,32 @@
+cmake_minimum_required (VERSION 3.4)
+
+# Copyright (c) 2020, ARM Limited.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+project (tosa)
+
+set (CMAKE_CXX_STANDARD 11)
+set (CMAKE_CXX_FLAGS "-g -Wall")
+set (FLATBUFFERS_SRC_DIR "../thirdparty/flatbuffers")
+
+set (SOURCE
+ tosa_serialization_handler.cpp
+)
+
+add_library(tosa_serialization STATIC ${SOURCE})
+
+include_directories("./")
+
+target_link_libraries(tosa_serialization PRIVATE flatbuffers)
diff --git a/serialization/attribute.def b/serialization/attribute.def
new file mode 100644
index 0000000..88e8c81
--- /dev/null
+++ b/serialization/attribute.def
@@ -0,0 +1,90 @@
+
+// Copyright (c) 2020, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+/*
+ Syntax:
+  DEF_ATTRIBUTE(ATTRIBUTE_NAME, NUM_ARGS_IN_ATTRIBUTES, ARG0_TYPE, ARG0_SCALAR_OR_VECTOR, ARG0_NAME, ...)
+
+ Description:
+ ATTRIBUTE_NAME: corresponding attribute name, must match corresponding "table XXXAttribute" in tosa.fbs
+ NUM_ARGS_IN_ATTRIBUTES: number of arguments in this attribute
+ ARG0_TYPE: data type of arg0 in attribute
+ ARG0_SCALAR_OR_VECTOR: is arg0 a scalar(S) or a vector(V)
+ ARG0_NAME: name of arg0
+ ...: variadic variables for more arguments, depending on NUM_ARGS_IN_ATTRIBUTES
+*/
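+
+/*
+  Example (illustration only): the entry
+
+      DEF_ATTRIBUTE(Axis, 1,
+                    int32_t, S, axis)
+
+  is expanded by attribute.h into a TosaAxisAttribute class carrying a single
+  scalar argument, with a constructor taking an int32_t axis value and an
+  int32_t axis() accessor.
+*/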
+
+DEF_ATTRIBUTE(Pool2d, 3,
+ int32_t, V, padding,
+ int32_t, V, kernel,
+ int32_t, V, stride)
+
+DEF_ATTRIBUTE(Conv2d, 3,
+ int32_t, V, padding,
+ int32_t, V, stride,
+ int32_t, V, dilation)
+
+DEF_ATTRIBUTE(TransposeConv2d, 4,
+ int32_t, V, outpad,
+ int32_t, V, stride,
+ int32_t, V, dilation,
+ int32_t, V, output_shape)
+
+DEF_ATTRIBUTE(ReluN, 2,
+ int32_t, S, max_int,
+ float, S, max_fp)
+
+DEF_ATTRIBUTE(Axis, 1,
+ int32_t, S, axis)
+
+DEF_ATTRIBUTE(Reshape, 1,
+ int32_t, V, shape)
+
+DEF_ATTRIBUTE(Slice, 2,
+ int32_t, V, begin,
+ int32_t, V, size)
+
+DEF_ATTRIBUTE(Tile, 1,
+ int32_t, V, multiples)
+
+DEF_ATTRIBUTE(Resize, 5,
+ int32_t, V, output_size,
+ int32_t, V, stride,
+ int32_t, V, offset,
+ int32_t, S, shift,
+ ResizeMode, S, mode)
+
+DEF_ATTRIBUTE(Clamp, 4,
+ int32_t, S, min_int,
+ int32_t, S, max_int,
+ float, S, min_fp,
+ float, S, max_fp)
+
+DEF_ATTRIBUTE(Rescale, 7,
+ int32_t, S, input_zp,
+ int32_t, S, output_zp,
+ int32_t, V, multiplier,
+ int32_t, V, shift,
+ bool, S, scale32,
+ bool, S, double_round,
+ bool, S, per_channel)
+
+DEF_ATTRIBUTE(CondIf, 2,
+ string, S, then_branch,
+ string, S, else_branch)
+
+DEF_ATTRIBUTE(WhileLoop, 2,
+ string, S, cond_branch,
+ string, S, body_branch)
diff --git a/serialization/attribute.h b/serialization/attribute.h
new file mode 100644
index 0000000..2a33a8f
--- /dev/null
+++ b/serialization/attribute.h
@@ -0,0 +1,181 @@
+
+// Copyright (c) 2020, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef _TOSA_SERIALIZATION_ATTRIBUTE_H
+#define _TOSA_SERIALIZATION_ATTRIBUTE_H
+#include "flatbuffers/idl.h"
+#include "flatbuffers/util.h"
+#include "tosa_generated.h"
+
+using std::string;
+
+namespace tosa
+{
+
+class TosaAttributeBase
+{
+public:
+ virtual ~TosaAttributeBase()
+ {}
+};
+
+class TosaNoneAttribute : public TosaAttributeBase
+{
+public:
+ TosaNoneAttribute()
+ {}
+ TosaNoneAttribute(TosaNoneAttribute* p)
+ {}
+};
+
+#define DEF_ARGS_VER0_S_STR(V) _##V = p->V()->str();
+#define DEF_ARGS_VER0_S_DEFAULT(V) _##V = p->V();
+
+#define DEF_ARGS_VER0_S_int32_t(V) DEF_ARGS_VER0_S_DEFAULT(V)
+#define DEF_ARGS_VER0_S_float(V) DEF_ARGS_VER0_S_DEFAULT(V)
+#define DEF_ARGS_VER0_S_bool(V) DEF_ARGS_VER0_S_DEFAULT(V)
+#define DEF_ARGS_VER0_S_ResizeMode(V) DEF_ARGS_VER0_S_DEFAULT(V)
+#define DEF_ARGS_VER0_S_string(V) DEF_ARGS_VER0_S_STR(V)
+
+#define DEF_ARGS_VER0_S(T, V) DEF_ARGS_VER0_S_##T(V)
+#define DEF_ARGS_VER0_V(T, V) _##V = std::vector<T>(p->V()->begin(), p->V()->end());
+
+#define DEF_ARGS_VER1_S(T, V) const T& V
+#define DEF_ARGS_VER1_V(T, V) const std::vector<T>& V
+#define DEF_ARGS_VER2_S(T, V) _##V = V;
+#define DEF_ARGS_VER2_V(T, V) _##V = V;
+#define DEF_ARGS_VER3_S(T, V) \
+ T V() const \
+ { \
+ return _##V; \
+ }
+#define DEF_ARGS_VER3_V(T, V) \
+ std::vector<T> V() const \
+ { \
+ return _##V; \
+ }
+#define DEF_ARGS_VER4_S(T, V) T _##V;
+#define DEF_ARGS_VER4_V(T, V) std::vector<T> _##V;
+
+// another level of preprocessor indirection to handle ", " as function's input argument
+#define DEF_ARGS_VER1_TRUE(T, F, V) DEF_ARGS_VER1_##F(T, V)
+#define DEF_ARGS_VER1_FALSE(T, F, V) , DEF_ARGS_VER1_##F(T, V)
+
+#define DEF_ARGS_VER0(FIRST, T, F, V) DEF_ARGS_VER0_##F(T, V)
+#define DEF_ARGS_VER1(FIRST, T, F, V) DEF_ARGS_VER1_##FIRST(T, F, V)
+#define DEF_ARGS_VER2(FIRST, T, F, V) DEF_ARGS_VER2_##F(T, V)
+#define DEF_ARGS_VER3(FIRST, T, F, V) DEF_ARGS_VER3_##F(T, V)
+#define DEF_ARGS_VER4(FIRST, T, F, V) DEF_ARGS_VER4_##F(T, V)
+
+#define DEF_ARGS_0(VER, ...)
+#define DEF_ARGS_1(VER, T0, F0, V0) DEF_ARGS_##VER(TRUE, T0, F0, V0)
+#define DEF_ARGS_2(VER, T0, F0, V0, T1, F1, V1) DEF_ARGS_##VER(TRUE, T0, F0, V0) DEF_ARGS_##VER(FALSE, T1, F1, V1)
+#define DEF_ARGS_3(VER, T0, F0, V0, T1, F1, V1, T2, F2, V2) \
+ DEF_ARGS_##VER(TRUE, T0, F0, V0) DEF_ARGS_##VER(FALSE, T1, F1, V1) DEF_ARGS_##VER(FALSE, T2, F2, V2)
+#define DEF_ARGS_4(VER, T0, F0, V0, T1, F1, V1, T2, F2, V2, T3, F3, V3) \
+ DEF_ARGS_##VER(TRUE, T0, F0, V0) DEF_ARGS_##VER(FALSE, T1, F1, V1) DEF_ARGS_##VER(FALSE, T2, F2, V2) \
+ DEF_ARGS_##VER(FALSE, T3, F3, V3)
+#define DEF_ARGS_5(VER, T0, F0, V0, T1, F1, V1, T2, F2, V2, T3, F3, V3, T4, F4, V4) \
+ DEF_ARGS_##VER(TRUE, T0, F0, V0) DEF_ARGS_##VER(FALSE, T1, F1, V1) DEF_ARGS_##VER(FALSE, T2, F2, V2) \
+ DEF_ARGS_##VER(FALSE, T3, F3, V3) DEF_ARGS_##VER(FALSE, T4, F4, V4)
+
+#define DEF_ARGS_6(VER, T0, F0, V0, T1, F1, V1, T2, F2, V2, T3, F3, V3, T4, F4, V4, T5, F5, V5) \
+ DEF_ARGS_##VER(TRUE, T0, F0, V0) DEF_ARGS_##VER(FALSE, T1, F1, V1) DEF_ARGS_##VER(FALSE, T2, F2, V2) \
+ DEF_ARGS_##VER(FALSE, T3, F3, V3) DEF_ARGS_##VER(FALSE, T4, F4, V4) DEF_ARGS_##VER(FALSE, T5, F5, V5)
+
+#define DEF_ARGS_7(VER, T0, F0, V0, T1, F1, V1, T2, F2, V2, T3, F3, V3, T4, F4, V4, T5, F5, V5, T6, F6, V6) \
+ DEF_ARGS_##VER(TRUE, T0, F0, V0) DEF_ARGS_##VER(FALSE, T1, F1, V1) DEF_ARGS_##VER(FALSE, T2, F2, V2) \
+ DEF_ARGS_##VER(FALSE, T3, F3, V3) DEF_ARGS_##VER(FALSE, T4, F4, V4) DEF_ARGS_##VER(FALSE, T5, F5, V5) \
+ DEF_ARGS_##VER(FALSE, T6, F6, V6)
+
+#define DEF_VER0_VAR_DECL_PTR(NAME) const NAME* p = static_cast<const NAME*>(options);
+#define DEF_VER0_VAR_0(NAME)
+#define DEF_VER0_VAR_1(NAME) DEF_VER0_VAR_DECL_PTR(NAME)
+#define DEF_VER0_VAR_2(NAME) DEF_VER0_VAR_DECL_PTR(NAME)
+#define DEF_VER0_VAR_3(NAME) DEF_VER0_VAR_DECL_PTR(NAME)
+#define DEF_VER0_VAR_4(NAME) DEF_VER0_VAR_DECL_PTR(NAME)
+#define DEF_VER0_VAR_5(NAME) DEF_VER0_VAR_DECL_PTR(NAME)
+#define DEF_VER0_VAR_6(NAME) DEF_VER0_VAR_DECL_PTR(NAME)
+#define DEF_VER0_VAR_7(NAME) DEF_VER0_VAR_DECL_PTR(NAME)
+
+#define DEF_ATTRIBUTE(NAME, NUM_ARGS, ...) \
+ class Tosa##NAME##Attribute : public TosaAttributeBase \
+ { \
+ public: \
+ Tosa##NAME##Attribute(const TosaAttributeBase* options) \
+ { \
+ const Tosa##NAME##Attribute* p = reinterpret_cast<const Tosa##NAME##Attribute*>(options); \
+ *this = *p; \
+ } \
+ Tosa##NAME##Attribute(const Tosa##NAME##Attribute* p) \
+ { \
+ *this = *p; \
+ } \
+        Tosa##NAME##Attribute(const void* options)                             \
+        {                                                                      \
+            DEF_VER0_VAR_##NUM_ARGS(NAME##Attribute)                           \
+            DEF_ARGS_##NUM_ARGS(VER0, __VA_ARGS__)                             \
+        }                                                                      \
+        Tosa##NAME##Attribute(DEF_ARGS_##NUM_ARGS(VER1, __VA_ARGS__))          \
+ { \
+ DEF_ARGS_##NUM_ARGS(VER2, __VA_ARGS__) \
+ } \
+ virtual ~Tosa##NAME##Attribute() \
+ {} \
+        DEF_ARGS_##NUM_ARGS(VER3, __VA_ARGS__)                                 \
+    private:                                                                   \
+        DEF_ARGS_##NUM_ARGS(VER4, __VA_ARGS__)                                 \
+ };
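+
+// Illustrative expansion (a sketch, not generated output): for
+//     DEF_ATTRIBUTE(Axis, 1, int32_t, S, axis)
+// the macro above roughly produces (copy constructors and destructor omitted):
+//
+//     class TosaAxisAttribute : public TosaAttributeBase
+//     {
+//     public:
+//         TosaAxisAttribute(const void* options);    // VER0: reads axis() from the flatbuffer AxisAttribute table
+//         TosaAxisAttribute(const int32_t& axis);    // VER1/VER2: construct from a value
+//         int32_t axis() const { return _axis; }     // VER3: accessor
+//     private:
+//         int32_t _axis;                             // VER4: storage
+//     };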
+
+#include "attribute.def"
+#undef DEF_ATTRIBUTE
+#undef DEF_ARGS_0
+#undef DEF_ARGS_1
+#undef DEF_ARGS_2
+#undef DEF_ARGS_3
+#undef DEF_ARGS_4
+#undef DEF_ARGS_5
+#undef DEF_ARGS_6
+#undef DEF_ARGS_7
+#undef DEF_ARGS_VER0
+#undef DEF_ARGS_VER1
+#undef DEF_ARGS_VER2
+#undef DEF_ARGS_VER3
+#undef DEF_ARGS_VER4
+#undef DEF_ARGS_VER0_S_int32_t
+#undef DEF_ARGS_VER0_S_float
+#undef DEF_ARGS_VER0_S_bool
+#undef DEF_ARGS_VER0_S_ResizeMode
+#undef DEF_ARGS_VER0_S_string
+#undef DEF_ARGS_VER0_S_STR
+#undef DEF_ARGS_VER0_S_DEFAULT
+#undef DEF_ARGS_VER1_TRUE
+#undef DEF_ARGS_VER1_FALSE
+#undef DEF_ARGS_VER0_S
+#undef DEF_ARGS_VER0_V
+#undef DEF_ARGS_VER1_S
+#undef DEF_ARGS_VER1_V
+#undef DEF_ARGS_VER2_S
+#undef DEF_ARGS_VER2_V
+#undef DEF_ARGS_VER3_S
+#undef DEF_ARGS_VER3_V
+#undef DEF_ARGS_VER4_S
+#undef DEF_ARGS_VER4_V
+#undef DEF_VER0_VAR_0
+#undef DEF_VER0_VAR_1
+#undef DEF_VER0_VAR_2
+#undef DEF_VER0_VAR_3
+#undef DEF_VER0_VAR_4
+#undef DEF_VER0_VAR_5
+#undef DEF_VER0_VAR_6
+#undef DEF_VER0_VAR_7
+#undef DEF_VER0_VAR_DECL_PTR
+
+} // namespace tosa
+
+#endif // _TOSA_SERIALIZATION_ATTRIBUTE_H
diff --git a/serialization/operator.def b/serialization/operator.def
new file mode 100644
index 0000000..66d3784
--- /dev/null
+++ b/serialization/operator.def
@@ -0,0 +1,123 @@
+
+// Copyright (c) 2020, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+/*
+ Syntax:
+ DEF_OPERATOR(MLIR_NAME, SCHEMA_NAME, REF_IMPL_NAME, OPTIONS, QUANT_INFO)
+
+ Description:
+ MLIR_NAME: the symbolic string of this op, must match tosa_ops.td
+ SCHEMA_NAME: corresponding operator name, must match "enum Op" in serialization/tosa.fbs
+ REF_IMPL_NAME: name used internally in tosa reference implementation
+ OPTIONS: compile time constant options of this op, corresponding to operator_option.def
+  QUANT_INFO: quantization information of this op, corresponding to quant_info.def
+*/
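+
+/*
+  Example (illustration only):
+
+      DEF_OPERATOR(conv2d, CONV2D, Conv2d, Conv2d, Conv)
+
+  ties the MLIR symbol "conv2d" to the CONV2D entry of "enum Op" in tosa.fbs,
+  uses Conv2d as the reference-model name, the Conv2d compile-time option set,
+  and the Conv quantization info from quant_info.def.
+*/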
+
+
+/* tensor operators */
+DEF_OPERATOR(argmax, ARGMAX, ArgMax, Axis, None)
+DEF_OPERATOR(avg_pool2d, AVG_POOL2D, AvgPool2d, Pool2d, Unary)
+DEF_OPERATOR(conv2d, CONV2D, Conv2d, Conv2d, Conv)
+DEF_OPERATOR(conv3d, CONV3D, Conv3d, None, None)
+DEF_OPERATOR(depthwise_conv2d, DEPTHWISE_CONV2D, DepthwiseConv2d, Conv2d, Conv)
+DEF_OPERATOR(fully_connected, FULLY_CONNECTED, FullyConnected, None, Conv)
+DEF_OPERATOR(matmul, MATMUL, MatMul, None, MatMul)
+DEF_OPERATOR(max_pool2d, MAX_POOL2D, MaxPool2d, Pool2d, None)
+DEF_OPERATOR(transpose_conv2d, TRANSPOSE_CONV2D, TransposeConv2d, TransposeConv2d, Conv)
+
+/* activation */
+DEF_OPERATOR(clamp, CLAMP, Clamp, Clamp, None)
+DEF_OPERATOR(reluN, RELUN, ReluN, ReluN, None)
+DEF_OPERATOR(sigmoid, SIGMOID, Sigmoid, None, None)
+DEF_OPERATOR(tanh, TANH, Tanh, None, None)
+
+/* elementwise - binary */
+DEF_OPERATOR(add, ADD, Add, None, None)
+DEF_OPERATOR(arithmetic_right_shift, ARITHMETIC_RIGHT_SHIFT, ArithmeticRightShift, None, None)
+DEF_OPERATOR(bitwise_and, BITWISE_AND, BitwiseAnd, None, None)
+DEF_OPERATOR(bitwise_or, BITWISE_OR, BitwiseOr, None, None)
+DEF_OPERATOR(bitwise_xor, BITWISE_XOR, BitwiseXor, None, None)
+DEF_OPERATOR(logical_and, LOGICAL_AND, LogicalAnd, None, None)
+DEF_OPERATOR(logical_left_shift, LOGICAL_LEFT_SHIFT, LogicalLeftShift, None, None)
+DEF_OPERATOR(logical_right_shift, LOGICAL_RIGHT_SHIFT, LogicalRightShift, None, None)
+DEF_OPERATOR(logical_or, LOGICAL_OR, LogicalOr, None, None)
+DEF_OPERATOR(logical_xor, LOGICAL_XOR, LogicalXor, None, None)
+DEF_OPERATOR(maximum, MAXIMUM, Maximum, None, None)
+DEF_OPERATOR(minimum, MINIMUM, Minimum, None, None)
+DEF_OPERATOR(mul, MUL, Mul, None, None)
+DEF_OPERATOR(pow, POW, Pow, None, None)
+DEF_OPERATOR(sub, SUB, Sub, None, None)
+DEF_OPERATOR(table, TABLE, Table, None, None)
+
+/* elementwise - unary */
+DEF_OPERATOR(abs, ABS, Abs, None, None)
+DEF_OPERATOR(bitwise_not, BITWISE_NOT, BitwiseNot, None, None)
+DEF_OPERATOR(ceil, CEIL, Ceil, None, None)
+DEF_OPERATOR(clz, CLZ, Clz, None, None)
+DEF_OPERATOR(exp, EXP, Exp, None, None)
+DEF_OPERATOR(floor, FLOOR, Floor, None, None)
+DEF_OPERATOR(log, LOG, Log, None, None)
+DEF_OPERATOR(logical_not, LOGICAL_NOT, LogicalNot, None, None)
+DEF_OPERATOR(negate, NEGATE, Negate, None, Unary)
+DEF_OPERATOR(reciprocal, RECIPROCAL, Reciprocal, None, None)
+DEF_OPERATOR(rsqrt, RSQRT, Rsqrt, None, None)
+
+/* elementwise - ternary */
+DEF_OPERATOR(select, SELECT, Select, None, None)
+
+/* logical */
+DEF_OPERATOR(equal, EQUAL, Equal, None, None)
+DEF_OPERATOR(greater, GREATER, Greater, None, None)
+DEF_OPERATOR(greater_equal, GREATER_EQUAL, GreaterEqual, None, None)
+
+/* reduction */
+DEF_OPERATOR(reduce_any, REDUCE_ANY, ReduceAny, Reduce, None)
+DEF_OPERATOR(reduce_all, REDUCE_ALL, ReduceAll, Reduce, None)
+DEF_OPERATOR(reduce_max, REDUCE_MAX, ReduceMax, Reduce, None)
+DEF_OPERATOR(reduce_min, REDUCE_MIN, ReduceMin, Reduce, None)
+DEF_OPERATOR(reduce_prod, REDUCE_PRODUCT, ReduceProduct, Reduce, None)
+DEF_OPERATOR(reduce_sum, REDUCE_SUM, ReduceSum, Reduce, None)
+
+/* memory operation */
+DEF_OPERATOR(concat, CONCAT, Concat, Axis, None)
+DEF_OPERATOR(pad, PAD, Pad, None, Pad)
+DEF_OPERATOR(reshape, RESHAPE, Reshape, Reshape, None)
+DEF_OPERATOR(reverse, REVERSE, Reverse, Reverse, None)
+DEF_OPERATOR(slice, SLICE, Slice, Slice, None)
+DEF_OPERATOR(tile, TILE, Tile, Tile, None)
+DEF_OPERATOR(transpose, TRANSPOSE, Transpose, None, None)
+
+/* gather/scatter */
+DEF_OPERATOR(gather, GATHER, Gather, Axis, None)
+
+/* image */
+DEF_OPERATOR(resize, RESIZE, Resize, Resize, None)
+
+/* quantization */
+DEF_OPERATOR(cast, CAST, Cast, None, None)
+DEF_OPERATOR(rescale, RESCALE, Rescale, Rescale, None)
+
+/* data nodes */
+DEF_OPERATOR(const, CONST, Const, None, None)
+DEF_OPERATOR(placeholder, PLACEHOLDER, Placeholder, None, None)
+DEF_OPERATOR(identity, IDENTITY, Identity, None, None)
+DEF_OPERATOR(identityn, IDENTITYN, IdentityN, None, None)
+
+/* custom operations */
+DEF_OPERATOR(custom, CUSTOM, Custom, None, None)
+
+/* control flow operators */
+DEF_OPERATOR(cond_if, COND_IF, CondIf, CondIf, None)
+DEF_OPERATOR(while_loop, WHILE_LOOP, WhileLoop, WhileLoop, None)
diff --git a/serialization/quant_info.def b/serialization/quant_info.def
new file mode 100644
index 0000000..39dc101
--- /dev/null
+++ b/serialization/quant_info.def
@@ -0,0 +1,43 @@
+
+// Copyright (c) 2020, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+/*
+ Syntax:
+  DEF_QUANTIZATION_INFO(NAME, NUM_ARGS_IN_QINFO, ARG0_TYPE, ARG0_SCALAR_OR_VECTOR, ARG0_NAME, ...)
+
+ Description:
+ NAME: corresponding quantization info name, must match corresponding "table XXXQuantInfo" in tosa.fbs
+ NUM_ARGS_IN_QINFO: number of arguments in this quantization info
+ ARG0_TYPE: data type of arg0
+ ARG0_SCALAR_OR_VECTOR: is arg0 a scalar (S) or a vector (V)
+ ARG0_NAME: name of arg0
+ ...: variadic variables for more arguments, depending on NUM_ARGS_IN_QINFO
+*/
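+
+/*
+  Example (illustration only): the entry
+
+      DEF_QUANTIZATION_INFO(Conv, 2,
+                            int32_t, S, input_zp,
+                            int32_t, S, weight_zp)
+
+  is expanded by quant_info.h into a TosaConvQuantInfo class with
+  input_zp() and weight_zp() accessors.
+*/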
+
+
+DEF_QUANTIZATION_INFO(Unary, 2,
+ int32_t, S, input_zp,
+ int32_t, S, output_zp)
+
+DEF_QUANTIZATION_INFO(Conv, 2,
+ int32_t, S, input_zp,
+ int32_t, S, weight_zp)
+
+DEF_QUANTIZATION_INFO(MatMul, 2,
+ int32_t, S, a_zp,
+ int32_t, S, b_zp)
+
+DEF_QUANTIZATION_INFO(Pad, 1,
+ int32_t, S, input_zp)
diff --git a/serialization/quant_info.h b/serialization/quant_info.h
new file mode 100644
index 0000000..03dcab9
--- /dev/null
+++ b/serialization/quant_info.h
@@ -0,0 +1,164 @@
+
+// Copyright (c) 2020, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef _TOSA_SERIALIZATION_QUANT_INFO_H
+#define _TOSA_SERIALIZATION_QUANT_INFO_H
+#include "flatbuffers/idl.h"
+#include "flatbuffers/util.h"
+#include "tosa_generated.h"
+
+namespace tosa
+{
+
+class TosaQuantInfoBase
+{
+public:
+ virtual ~TosaQuantInfoBase()
+ {}
+};
+
+class TosaNoneQuantInfo : public TosaQuantInfoBase
+{
+public:
+ TosaNoneQuantInfo()
+ {}
+ TosaNoneQuantInfo(TosaNoneQuantInfo* p)
+ {}
+};
+
+#define DEF_ARGS_VER0_S(T, V) _##V = p->V();
+#define DEF_ARGS_VER0_V(T, V) _##V = std::vector<T>(p->V()->begin(), p->V()->end());
+#define DEF_ARGS_VER1_S(T, V) const T& V
+#define DEF_ARGS_VER1_V(T, V) const std::vector<T>& V
+#define DEF_ARGS_VER2_S(T, V) _##V = V;
+#define DEF_ARGS_VER2_V(T, V) _##V = V;
+#define DEF_ARGS_VER3_S(T, V) \
+ T V() const \
+ { \
+ return _##V; \
+ }
+#define DEF_ARGS_VER3_V(T, V) \
+ std::vector<T> V() const \
+ { \
+ return _##V; \
+ }
+#define DEF_ARGS_VER4_S(T, V) T _##V;
+#define DEF_ARGS_VER4_V(T, V) std::vector<T> _##V;
+
+// another level of preprocessor indirection to handle ", " as function's input argument
+#define DEF_ARGS_VER1_TRUE(T, F, V) DEF_ARGS_VER1_##F(T, V)
+#define DEF_ARGS_VER1_FALSE(T, F, V) , DEF_ARGS_VER1_##F(T, V)
+
+#define DEF_ARGS_VER0(FIRST, T, F, V) DEF_ARGS_VER0_##F(T, V)
+#define DEF_ARGS_VER1(FIRST, T, F, V) DEF_ARGS_VER1_##FIRST(T, F, V)
+#define DEF_ARGS_VER2(FIRST, T, F, V) DEF_ARGS_VER2_##F(T, V)
+#define DEF_ARGS_VER3(FIRST, T, F, V) DEF_ARGS_VER3_##F(T, V)
+#define DEF_ARGS_VER4(FIRST, T, F, V) DEF_ARGS_VER4_##F(T, V)
+
+#define DEF_ARGS_1(VER, T0, F0, V0) DEF_ARGS_##VER(TRUE, T0, F0, V0)
+#define DEF_ARGS_2(VER, T0, F0, V0, T1, F1, V1) DEF_ARGS_##VER(TRUE, T0, F0, V0) DEF_ARGS_##VER(FALSE, T1, F1, V1)
+#define DEF_ARGS_3(VER, T0, F0, V0, T1, F1, V1, T2, F2, V2) \
+ DEF_ARGS_##VER(TRUE, T0, F0, V0) DEF_ARGS_##VER(FALSE, T1, F1, V1) DEF_ARGS_##VER(FALSE, T2, F2, V2)
+#define DEF_ARGS_4(VER, T0, F0, V0, T1, F1, V1, T2, F2, V2, T3, F3, V3) \
+ DEF_ARGS_##VER(TRUE, T0, F0, V0) DEF_ARGS_##VER(FALSE, T1, F1, V1) DEF_ARGS_##VER(FALSE, T2, F2, V2) \
+ DEF_ARGS_##VER(FALSE, T3, F3, V3)
+#define DEF_ARGS_5(VER, T0, F0, V0, T1, F1, V1, T2, F2, V2, T3, F3, V3, T4, F4, V4) \
+ DEF_ARGS_##VER(TRUE, T0, F0, V0) DEF_ARGS_##VER(FALSE, T1, F1, V1) DEF_ARGS_##VER(FALSE, T2, F2, V2) \
+ DEF_ARGS_##VER(FALSE, T3, F3, V3) DEF_ARGS_##VER(FALSE, T4, F4, V4)
+#define DEF_ARGS_6(VER, T0, F0, V0, T1, F1, V1, T2, F2, V2, T3, F3, V3, T4, F4, V4, T5, F5, V5) \
+ DEF_ARGS_##VER(TRUE, T0, F0, V0) DEF_ARGS_##VER(FALSE, T1, F1, V1) DEF_ARGS_##VER(FALSE, T2, F2, V2) \
+ DEF_ARGS_##VER(FALSE, T3, F3, V3) DEF_ARGS_##VER(FALSE, T4, F4, V4) DEF_ARGS_##VER(FALSE, T5, F5, V5)
+#define DEF_ARGS_7(VER, T0, F0, V0, T1, F1, V1, T2, F2, V2, T3, F3, V3, T4, F4, V4, T5, F5, V5, T6, F6, V6) \
+ DEF_ARGS_##VER(TRUE, T0, F0, V0) DEF_ARGS_##VER(FALSE, T1, F1, V1) DEF_ARGS_##VER(FALSE, T2, F2, V2) \
+ DEF_ARGS_##VER(FALSE, T3, F3, V3) DEF_ARGS_##VER(FALSE, T4, F4, V4) DEF_ARGS_##VER(FALSE, T5, F5, V5) \
+ DEF_ARGS_##VER(FALSE, T6, F6, V6)
+#define DEF_ARGS_8(VER, T0, F0, V0, T1, F1, V1, T2, F2, V2, T3, F3, V3, T4, F4, V4, T5, F5, V5, T6, F6, V6, T7, F7, \
+ V7) \
+ DEF_ARGS_##VER(TRUE, T0, F0, V0) DEF_ARGS_##VER(FALSE, T1, F1, V1) DEF_ARGS_##VER(FALSE, T2, F2, V2) \
+ DEF_ARGS_##VER(FALSE, T3, F3, V3) DEF_ARGS_##VER(FALSE, T4, F4, V4) DEF_ARGS_##VER(FALSE, T5, F5, V5) \
+ DEF_ARGS_##VER(FALSE, T6, F6, V6) DEF_ARGS_##VER(FALSE, T7, F7, V7)
+#define DEF_ARGS_9(VER, T0, F0, V0, T1, F1, V1, T2, F2, V2, T3, F3, V3, T4, F4, V4, T5, F5, V5, T6, F6, V6, T7, F7, \
+ V7, T8, F8, V8) \
+ DEF_ARGS_##VER(TRUE, T0, F0, V0) DEF_ARGS_##VER(FALSE, T1, F1, V1) DEF_ARGS_##VER(FALSE, T2, F2, V2) \
+ DEF_ARGS_##VER(FALSE, T3, F3, V3) DEF_ARGS_##VER(FALSE, T4, F4, V4) DEF_ARGS_##VER(FALSE, T5, F5, V5) \
+ DEF_ARGS_##VER(FALSE, T6, F6, V6) DEF_ARGS_##VER(FALSE, T7, F7, V7) DEF_ARGS_##VER(FALSE, T8, F8, V8)
+#define DEF_ARGS_10(VER, T0, F0, V0, T1, F1, V1, T2, F2, V2, T3, F3, V3, T4, F4, V4, T5, F5, V5, T6, F6, V6, T7, F7, \
+ V7, T8, F8, V8, T9, F9, V9) \
+ DEF_ARGS_##VER(TRUE, T0, F0, V0) DEF_ARGS_##VER(FALSE, T1, F1, V1) DEF_ARGS_##VER(FALSE, T2, F2, V2) \
+ DEF_ARGS_##VER(FALSE, T3, F3, V3) DEF_ARGS_##VER(FALSE, T4, F4, V4) DEF_ARGS_##VER(FALSE, T5, F5, V5) \
+ DEF_ARGS_##VER(FALSE, T6, F6, V6) DEF_ARGS_##VER(FALSE, T7, F7, V7) DEF_ARGS_##VER(FALSE, T8, F8, V8) \
+ DEF_ARGS_##VER(FALSE, T9, F9, V9)
+
+#define DEF_QUANTIZATION_INFO(NAME, NUM_ARGS, ...) \
+ class Tosa##NAME##QuantInfo : public TosaQuantInfoBase \
+ { \
+ public: \
+ Tosa##NAME##QuantInfo(const TosaQuantInfoBase* qinfo) \
+ { \
+ const Tosa##NAME##QuantInfo* p = dynamic_cast<const Tosa##NAME##QuantInfo*>(qinfo); \
+ assert(p); \
+ *this = *p; \
+ } \
+ Tosa##NAME##QuantInfo(const Tosa##NAME##QuantInfo* p) \
+ { \
+ *this = *p; \
+ } \
+ Tosa##NAME##QuantInfo(const void* qinfo) \
+ { \
+ const NAME##QuantInfo* p = static_cast<const NAME##QuantInfo*>(qinfo); \
+ DEF_ARGS_##NUM_ARGS(VER0, __VA_ARGS__) \
+ } \
+ Tosa##NAME##QuantInfo(DEF_ARGS_##NUM_ARGS(VER1, __VA_ARGS__)) \
+ { \
+ DEF_ARGS_##NUM_ARGS(VER2, __VA_ARGS__) \
+ } \
+ virtual ~Tosa##NAME##QuantInfo() \
+ {} \
+        DEF_ARGS_##NUM_ARGS(VER3, __VA_ARGS__)                                 \
+    private:                                                                   \
+        DEF_ARGS_##NUM_ARGS(VER4, __VA_ARGS__)                                 \
+ };
+
+#include "quant_info.def"
+#undef DEF_QUANTIZATION_INFO
+#undef DEF_ARGS_1
+#undef DEF_ARGS_2
+#undef DEF_ARGS_3
+#undef DEF_ARGS_4
+#undef DEF_ARGS_5
+#undef DEF_ARGS_6
+#undef DEF_ARGS_7
+#undef DEF_ARGS_8
+#undef DEF_ARGS_9
+#undef DEF_ARGS_10
+#undef DEF_ARGS_VER0
+#undef DEF_ARGS_VER1
+#undef DEF_ARGS_VER2
+#undef DEF_ARGS_VER3
+#undef DEF_ARGS_VER4
+#undef DEF_ARGS_VER1_TRUE
+#undef DEF_ARGS_VER1_FALSE
+#undef DEF_ARGS_VER0_S
+#undef DEF_ARGS_VER0_V
+#undef DEF_ARGS_VER1_S
+#undef DEF_ARGS_VER1_V
+#undef DEF_ARGS_VER2_S
+#undef DEF_ARGS_VER2_V
+#undef DEF_ARGS_VER3_S
+#undef DEF_ARGS_VER3_V
+#undef DEF_ARGS_VER4_S
+#undef DEF_ARGS_VER4_V
+
+} // namespace tosa
+
+#endif
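
The DEF_QUANTIZATION_INFO X-macro above is expanded once for every entry in quant_info.def, stamping out one Tosa<Name>QuantInfo wrapper class per quantization-info table: the VER0 helpers copy fields out of the raw flatbuffers table, VER1/VER2 build a by-value constructor, and VER3/VER4 emit the public interface and private members. As a rough illustration only (the actual entries in quant_info.def and the accessor/member names produced by the helpers are assumptions, and the copy constructors are omitted), an entry for the convolution zero points would expand into a class shaped roughly like this:

// Hand-written sketch of the expanded shape; not literal preprocessor output.
class TosaConvQuantInfo : public TosaQuantInfoBase
{
public:
    // Construct from the raw flatbuffers ConvQuantInfo table (the void* path
    // used when deserializing).
    TosaConvQuantInfo(const void* qinfo)
    {
        const ConvQuantInfo* p = static_cast<const ConvQuantInfo*>(qinfo);
        _input_zp  = p->input_zp();
        _weight_zp = p->weight_zp();
    }
    // Construct directly from field values.
    TosaConvQuantInfo(int32_t input_zp, int32_t weight_zp)
    {
        _input_zp  = input_zp;
        _weight_zp = weight_zp;
    }
    virtual ~TosaConvQuantInfo() {}
    int32_t input_zp() const  { return _input_zp; }
    int32_t weight_zp() const { return _weight_zp; }

private:
    int32_t _input_zp;  // zero point applied to the convolution input
    int32_t _weight_zp; // zero point applied to the convolution weights
};
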
diff --git a/serialization/tosa.fbs b/serialization/tosa.fbs
new file mode 100644
index 0000000..841cf3d
--- /dev/null
+++ b/serialization/tosa.fbs
@@ -0,0 +1,318 @@
+
+// Copyright (c) 2020, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+namespace tosa;
+
+// This file identifier corresponds to the serialization format version.
+file_identifier "TOSA";
+// File extension of any written files.
+file_extension "tosa";
+
+enum DType:uint32 {
+ UNKNOWN = 0,
+ BOOL,
+ AINT8,
+ UINT8,
+ INT4,
+ INT8,
+ INT16,
+ INT32,
+ INT48,
+ FLOAT,
+}
+
+enum Format:uint32 {
+ UNKNOWN = 0,
+ NHWC,
+ NDHWC,
+ OHWI,
+ HWIM,
+ DOHWI,
+}
+
+enum Usage:uint32 {
+ UNKNOWN = 0,
+ ACTIVATION,
+ WEIGHT,
+ INDEX,
+}
+
+enum ResizeMode:uint32 {
+ UNKNOWN = 0,
+ NEAREST,
+ BILINEAR,
+}
+
+enum Op:uint32 {
+ UNKNOWN = 0,
+
+ // Tensor Operator
+ ARGMAX,
+ AVG_POOL2D,
+ CONV2D,
+ CONV3D,
+ DEPTHWISE_CONV2D,
+ FULLY_CONNECTED,
+ MATMUL,
+ MAX_POOL2D,
+ TRANSPOSE_CONV2D,
+
+ // Activation
+ CLAMP,
+ RELUN,
+ SIGMOID,
+ TANH,
+
+ // Elementwise-Binary
+ ADD,
+ ARITHMETIC_RIGHT_SHIFT,
+ BITWISE_AND,
+ BITWISE_OR,
+ BITWISE_XOR,
+ LOGICAL_AND,
+ LOGICAL_LEFT_SHIFT,
+ LOGICAL_RIGHT_SHIFT,
+ LOGICAL_OR,
+ LOGICAL_XOR,
+ MAXIMUM,
+ MINIMUM,
+ MUL,
+ POW,
+ SUB,
+ TABLE,
+
+ // Elementwise-Unary
+ ABS,
+ BITWISE_NOT,
+ CEIL,
+ CLZ,
+ EXP,
+ FLOOR,
+ LOG,
+ LOGICAL_NOT,
+ NEGATE,
+ RECIPROCAL,
+ RSQRT,
+
+ // Elementwise-Ternary
+ SELECT,
+
+ // Logical
+ EQUAL,
+ GREATER,
+ GREATER_EQUAL,
+
+ // Reduction
+ REDUCE_ANY,
+ REDUCE_ALL,
+ REDUCE_MAX,
+ REDUCE_MIN,
+ REDUCE_PRODUCT,
+ REDUCE_SUM,
+
+ // Data layout operation
+ CONCAT,
+ PAD,
+ RESHAPE,
+ REVERSE,
+ SLICE,
+ TILE,
+ TRANSPOSE,
+
+ // Gather/scatter operation
+ GATHER,
+
+ // Image
+ RESIZE,
+
+ // Type conversion
+ CAST,
+ RESCALE,
+
+ // Data Nodes
+ CONST,
+ PLACEHOLDER,
+ IDENTITY,
+ IDENTITYN,
+
+ // Custom operations
+ CUSTOM,
+
+ // Control flow operators
+ COND_IF,
+ WHILE_LOOP,
+}
+
+union Attribute {
+ Pool2dAttribute,
+ Conv2dAttribute,
+ TransposeConv2dAttribute,
+ ReluNAttribute,
+ AxisAttribute,
+ ReshapeAttribute,
+ SliceAttribute,
+ TileAttribute,
+ ResizeAttribute,
+ ClampAttribute,
+ RescaleAttribute,
+ CustomAttribute,
+ CondIfAttribute,
+ WhileLoopAttribute,
+}
+
+table Pool2dAttribute {
+ padding: [int32];
+ kernel: [int32];
+ stride: [int32];
+}
+
+table Conv2dAttribute {
+ padding: [int32];
+ stride: [int32];
+ dilation: [int32];
+}
+
+table TransposeConv2dAttribute {
+ outpad: [int32];
+ stride: [int32];
+ dilation: [int32];
+ output_shape: [int32];
+}
+
+table ReluNAttribute {
+ max_int: int32;
+ max_fp: float;
+}
+
+table AxisAttribute {
+ axis: int32;
+}
+
+table ReshapeAttribute {
+ shape: [int32];
+}
+
+table SliceAttribute {
+ begin: [int32];
+ size: [int32];
+}
+
+table TileAttribute {
+ multiples: [int32];
+}
+
+table ResizeAttribute {
+ output_size: [int32];
+ stride: [int32];
+ offset: [int32];
+ shift: int32;
+ mode: ResizeMode;
+}
+
+table ClampAttribute {
+ min_int: int32;
+ max_int: int32;
+ min_fp: float;
+ max_fp: float;
+}
+
+table RescaleAttribute {
+ input_zp: int32;
+ output_zp: int32;
+ multiplier: [int32];
+ shift: [int32];
+ scale32: bool;
+ double_round: bool;
+ per_channel: bool;
+}
+
+table CustomAttribute {
+ identifier: string;
+}
+
+table CondIfAttribute {
+ then_branch: string;
+ else_branch: string;
+}
+
+table WhileLoopAttribute {
+ cond_branch: string;
+ body_branch: string;
+}
+
+union QuantInfo {
+ UnaryQuantInfo,
+ ConvQuantInfo,
+ MatMulQuantInfo,
+ PadQuantInfo,
+}
+
+table UnaryQuantInfo {
+ input_zp: int32;
+ output_zp: int32;
+}
+
+table ConvQuantInfo {
+ input_zp: int32;
+ weight_zp: int32;
+}
+
+table MatMulQuantInfo {
+ a_zp: int32;
+ b_zp: int32;
+}
+
+table PadQuantInfo {
+ input_zp: int32;
+}
+
+table Version {
+ _major: int32 = 0;
+ _minor: int32 = 20;
+ _patch: int32 = 0;
+ _experimental: bool = false;
+}
+
+table TosaTensor {
+ name:string; // name of the tensor, used for resolving data dependencies
+ shape:[int32]; // shape of the tensor
+ type:DType; // data type of the tensor
+ usage:[Usage]; // vector of possible usages; for debugging convenience only
+ format:[Format]; // vector of possible formats; for debugging convenience only
+ npy_filename: string; // numpy array filename
+}
+
+table TosaOperator {
+ op:Op; // operator enum
+ attribute: Attribute; // operator attribute (union type; see Attribute above)
+ inputs:[string]; // list of input tensor names
+ outputs:[string]; // list of output tensor names
+ quant_info: QuantInfo; // op-based quantization information
+}
+
+table TosaBasicBlock {
+ name:string; // basic block name
+ operators:[TosaOperator]; // operators array
+ tensors:[TosaTensor]; // tensors array
+ inputs:[string]; // names of graph inputs
+ outputs:[string]; // names of graph outputs
+}
+
+table TosaGraph {
+ version: Version;
+ blocks:[TosaBasicBlock]; // basic blocks array
+}
+
+root_type TosaGraph;
diff --git a/serialization/tosa_generated.h b/serialization/tosa_generated.h
new file mode 100644
index 0000000..5bb21f3
--- /dev/null
+++ b/serialization/tosa_generated.h
@@ -0,0 +1,2605 @@
+
+// Copyright (c) 2020, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// automatically generated by the FlatBuffers compiler, do not modify
+
+
+#ifndef FLATBUFFERS_GENERATED_TOSA_TOSA_H_
+#define FLATBUFFERS_GENERATED_TOSA_TOSA_H_
+
+#include "flatbuffers/flatbuffers.h"
+
+namespace tosa {
+
+struct Pool2dAttribute;
+
+struct Conv2dAttribute;
+
+struct TransposeConv2dAttribute;
+
+struct ReluNAttribute;
+
+struct AxisAttribute;
+
+struct ReshapeAttribute;
+
+struct SliceAttribute;
+
+struct TileAttribute;
+
+struct ResizeAttribute;
+
+struct ClampAttribute;
+
+struct RescaleAttribute;
+
+struct CustomAttribute;
+
+struct CondIfAttribute;
+
+struct WhileLoopAttribute;
+
+struct UnaryQuantInfo;
+
+struct ConvQuantInfo;
+
+struct MatMulQuantInfo;
+
+struct PadQuantInfo;
+
+struct Version;
+
+struct TosaTensor;
+
+struct TosaOperator;
+
+struct TosaBasicBlock;
+
+struct TosaGraph;
+
+enum DType {
+ DType_UNKNOWN = 0,
+ DType_BOOL = 1,
+ DType_AINT8 = 2,
+ DType_UINT8 = 3,
+ DType_INT4 = 4,
+ DType_INT8 = 5,
+ DType_INT16 = 6,
+ DType_INT32 = 7,
+ DType_INT48 = 8,
+ DType_FLOAT = 9,
+ DType_MIN = DType_UNKNOWN,
+ DType_MAX = DType_FLOAT
+};
+
+inline const DType (&EnumValuesDType())[10] {
+ static const DType values[] = {
+ DType_UNKNOWN,
+ DType_BOOL,
+ DType_AINT8,
+ DType_UINT8,
+ DType_INT4,
+ DType_INT8,
+ DType_INT16,
+ DType_INT32,
+ DType_INT48,
+ DType_FLOAT
+ };
+ return values;
+}
+
+inline const char * const *EnumNamesDType() {
+ static const char * const names[] = {
+ "UNKNOWN",
+ "BOOL",
+ "AINT8",
+ "UINT8",
+ "INT4",
+ "INT8",
+ "INT16",
+ "INT32",
+ "INT48",
+ "FLOAT",
+ nullptr
+ };
+ return names;
+}
+
+inline const char *EnumNameDType(DType e) {
+ if (e < DType_UNKNOWN || e > DType_FLOAT) return "";
+ const size_t index = static_cast<size_t>(e);
+ return EnumNamesDType()[index];
+}
+
+enum Format {
+ Format_UNKNOWN = 0,
+ Format_NHWC = 1,
+ Format_NDHWC = 2,
+ Format_OHWI = 3,
+ Format_HWIM = 4,
+ Format_DOHWI = 5,
+ Format_MIN = Format_UNKNOWN,
+ Format_MAX = Format_DOHWI
+};
+
+inline const Format (&EnumValuesFormat())[6] {
+ static const Format values[] = {
+ Format_UNKNOWN,
+ Format_NHWC,
+ Format_NDHWC,
+ Format_OHWI,
+ Format_HWIM,
+ Format_DOHWI
+ };
+ return values;
+}
+
+inline const char * const *EnumNamesFormat() {
+ static const char * const names[] = {
+ "UNKNOWN",
+ "NHWC",
+ "NDHWC",
+ "OHWI",
+ "HWIM",
+ "DOHWI",
+ nullptr
+ };
+ return names;
+}
+
+inline const char *EnumNameFormat(Format e) {
+ if (e < Format_UNKNOWN || e > Format_DOHWI) return "";
+ const size_t index = static_cast<size_t>(e);
+ return EnumNamesFormat()[index];
+}
+
+enum Usage {
+ Usage_UNKNOWN = 0,
+ Usage_ACTIVATION = 1,
+ Usage_WEIGHT = 2,
+ Usage_INDEX = 3,
+ Usage_MIN = Usage_UNKNOWN,
+ Usage_MAX = Usage_INDEX
+};
+
+inline const Usage (&EnumValuesUsage())[4] {
+ static const Usage values[] = {
+ Usage_UNKNOWN,
+ Usage_ACTIVATION,
+ Usage_WEIGHT,
+ Usage_INDEX
+ };
+ return values;
+}
+
+inline const char * const *EnumNamesUsage() {
+ static const char * const names[] = {
+ "UNKNOWN",
+ "ACTIVATION",
+ "WEIGHT",
+ "INDEX",
+ nullptr
+ };
+ return names;
+}
+
+inline const char *EnumNameUsage(Usage e) {
+ if (e < Usage_UNKNOWN || e > Usage_INDEX) return "";
+ const size_t index = static_cast<size_t>(e);
+ return EnumNamesUsage()[index];
+}
+
+enum ResizeMode {
+ ResizeMode_UNKNOWN = 0,
+ ResizeMode_NEAREST = 1,
+ ResizeMode_BILINEAR = 2,
+ ResizeMode_MIN = ResizeMode_UNKNOWN,
+ ResizeMode_MAX = ResizeMode_BILINEAR
+};
+
+inline const ResizeMode (&EnumValuesResizeMode())[3] {
+ static const ResizeMode values[] = {
+ ResizeMode_UNKNOWN,
+ ResizeMode_NEAREST,
+ ResizeMode_BILINEAR
+ };
+ return values;
+}
+
+inline const char * const *EnumNamesResizeMode() {
+ static const char * const names[] = {
+ "UNKNOWN",
+ "NEAREST",
+ "BILINEAR",
+ nullptr
+ };
+ return names;
+}
+
+inline const char *EnumNameResizeMode(ResizeMode e) {
+ if (e < ResizeMode_UNKNOWN || e > ResizeMode_BILINEAR) return "";
+ const size_t index = static_cast<size_t>(e);
+ return EnumNamesResizeMode()[index];
+}
+
+enum Op {
+ Op_UNKNOWN = 0,
+ Op_ARGMAX = 1,
+ Op_AVG_POOL2D = 2,
+ Op_CONV2D = 3,
+ Op_CONV3D = 4,
+ Op_DEPTHWISE_CONV2D = 5,
+ Op_FULLY_CONNECTED = 6,
+ Op_MATMUL = 7,
+ Op_MAX_POOL2D = 8,
+ Op_TRANSPOSE_CONV2D = 9,
+ Op_CLAMP = 10,
+ Op_RELUN = 11,
+ Op_SIGMOID = 12,
+ Op_TANH = 13,
+ Op_ADD = 14,
+ Op_ARITHMETIC_RIGHT_SHIFT = 15,
+ Op_BITWISE_AND = 16,
+ Op_BITWISE_OR = 17,
+ Op_BITWISE_XOR = 18,
+ Op_LOGICAL_AND = 19,
+ Op_LOGICAL_LEFT_SHIFT = 20,
+ Op_LOGICAL_RIGHT_SHIFT = 21,
+ Op_LOGICAL_OR = 22,
+ Op_LOGICAL_XOR = 23,
+ Op_MAXIMUM = 24,
+ Op_MINIMUM = 25,
+ Op_MUL = 26,
+ Op_POW = 27,
+ Op_SUB = 28,
+ Op_TABLE = 29,
+ Op_ABS = 30,
+ Op_BITWISE_NOT = 31,
+ Op_CEIL = 32,
+ Op_CLZ = 33,
+ Op_EXP = 34,
+ Op_FLOOR = 35,
+ Op_LOG = 36,
+ Op_LOGICAL_NOT = 37,
+ Op_NEGATE = 38,
+ Op_RECIPROCAL = 39,
+ Op_RSQRT = 40,
+ Op_SELECT = 41,
+ Op_EQUAL = 42,
+ Op_GREATER = 43,
+ Op_GREATER_EQUAL = 44,
+ Op_REDUCE_ANY = 45,
+ Op_REDUCE_ALL = 46,
+ Op_REDUCE_MAX = 47,
+ Op_REDUCE_MIN = 48,
+ Op_REDUCE_PRODUCT = 49,
+ Op_REDUCE_SUM = 50,
+ Op_CONCAT = 51,
+ Op_PAD = 52,
+ Op_RESHAPE = 53,
+ Op_REVERSE = 54,
+ Op_SLICE = 55,
+ Op_TILE = 56,
+ Op_TRANSPOSE = 57,
+ Op_GATHER = 58,
+ Op_RESIZE = 59,
+ Op_CAST = 60,
+ Op_RESCALE = 61,
+ Op_CONST = 62,
+ Op_PLACEHOLDER = 63,
+ Op_IDENTITY = 64,
+ Op_IDENTITYN = 65,
+ Op_CUSTOM = 66,
+ Op_COND_IF = 67,
+ Op_WHILE_LOOP = 68,
+ Op_MIN = Op_UNKNOWN,
+ Op_MAX = Op_WHILE_LOOP
+};
+
+inline const Op (&EnumValuesOp())[69] {
+ static const Op values[] = {
+ Op_UNKNOWN,
+ Op_ARGMAX,
+ Op_AVG_POOL2D,
+ Op_CONV2D,
+ Op_CONV3D,
+ Op_DEPTHWISE_CONV2D,
+ Op_FULLY_CONNECTED,
+ Op_MATMUL,
+ Op_MAX_POOL2D,
+ Op_TRANSPOSE_CONV2D,
+ Op_CLAMP,
+ Op_RELUN,
+ Op_SIGMOID,
+ Op_TANH,
+ Op_ADD,
+ Op_ARITHMETIC_RIGHT_SHIFT,
+ Op_BITWISE_AND,
+ Op_BITWISE_OR,
+ Op_BITWISE_XOR,
+ Op_LOGICAL_AND,
+ Op_LOGICAL_LEFT_SHIFT,
+ Op_LOGICAL_RIGHT_SHIFT,
+ Op_LOGICAL_OR,
+ Op_LOGICAL_XOR,
+ Op_MAXIMUM,
+ Op_MINIMUM,
+ Op_MUL,
+ Op_POW,
+ Op_SUB,
+ Op_TABLE,
+ Op_ABS,
+ Op_BITWISE_NOT,
+ Op_CEIL,
+ Op_CLZ,
+ Op_EXP,
+ Op_FLOOR,
+ Op_LOG,
+ Op_LOGICAL_NOT,
+ Op_NEGATE,
+ Op_RECIPROCAL,
+ Op_RSQRT,
+ Op_SELECT,
+ Op_EQUAL,
+ Op_GREATER,
+ Op_GREATER_EQUAL,
+ Op_REDUCE_ANY,
+ Op_REDUCE_ALL,
+ Op_REDUCE_MAX,
+ Op_REDUCE_MIN,
+ Op_REDUCE_PRODUCT,
+ Op_REDUCE_SUM,
+ Op_CONCAT,
+ Op_PAD,
+ Op_RESHAPE,
+ Op_REVERSE,
+ Op_SLICE,
+ Op_TILE,
+ Op_TRANSPOSE,
+ Op_GATHER,
+ Op_RESIZE,
+ Op_CAST,
+ Op_RESCALE,
+ Op_CONST,
+ Op_PLACEHOLDER,
+ Op_IDENTITY,
+ Op_IDENTITYN,
+ Op_CUSTOM,
+ Op_COND_IF,
+ Op_WHILE_LOOP
+ };
+ return values;
+}
+
+inline const char * const *EnumNamesOp() {
+ static const char * const names[] = {
+ "UNKNOWN",
+ "ARGMAX",
+ "AVG_POOL2D",
+ "CONV2D",
+ "CONV3D",
+ "DEPTHWISE_CONV2D",
+ "FULLY_CONNECTED",
+ "MATMUL",
+ "MAX_POOL2D",
+ "TRANSPOSE_CONV2D",
+ "CLAMP",
+ "RELUN",
+ "SIGMOID",
+ "TANH",
+ "ADD",
+ "ARITHMETIC_RIGHT_SHIFT",
+ "BITWISE_AND",
+ "BITWISE_OR",
+ "BITWISE_XOR",
+ "LOGICAL_AND",
+ "LOGICAL_LEFT_SHIFT",
+ "LOGICAL_RIGHT_SHIFT",
+ "LOGICAL_OR",
+ "LOGICAL_XOR",
+ "MAXIMUM",
+ "MINIMUM",
+ "MUL",
+ "POW",
+ "SUB",
+ "TABLE",
+ "ABS",
+ "BITWISE_NOT",
+ "CEIL",
+ "CLZ",
+ "EXP",
+ "FLOOR",
+ "LOG",
+ "LOGICAL_NOT",
+ "NEGATE",
+ "RECIPROCAL",
+ "RSQRT",
+ "SELECT",
+ "EQUAL",
+ "GREATER",
+ "GREATER_EQUAL",
+ "REDUCE_ANY",
+ "REDUCE_ALL",
+ "REDUCE_MAX",
+ "REDUCE_MIN",
+ "REDUCE_PRODUCT",
+ "REDUCE_SUM",
+ "CONCAT",
+ "PAD",
+ "RESHAPE",
+ "REVERSE",
+ "SLICE",
+ "TILE",
+ "TRANSPOSE",
+ "GATHER",
+ "RESIZE",
+ "CAST",
+ "RESCALE",
+ "CONST",
+ "PLACEHOLDER",
+ "IDENTITY",
+ "IDENTITYN",
+ "CUSTOM",
+ "COND_IF",
+ "WHILE_LOOP",
+ nullptr
+ };
+ return names;
+}
+
+inline const char *EnumNameOp(Op e) {
+ if (e < Op_UNKNOWN || e > Op_WHILE_LOOP) return "";
+ const size_t index = static_cast<size_t>(e);
+ return EnumNamesOp()[index];
+}
+
+enum Attribute {
+ Attribute_NONE = 0,
+ Attribute_Pool2dAttribute = 1,
+ Attribute_Conv2dAttribute = 2,
+ Attribute_TransposeConv2dAttribute = 3,
+ Attribute_ReluNAttribute = 4,
+ Attribute_AxisAttribute = 5,
+ Attribute_ReshapeAttribute = 6,
+ Attribute_SliceAttribute = 7,
+ Attribute_TileAttribute = 8,
+ Attribute_ResizeAttribute = 9,
+ Attribute_ClampAttribute = 10,
+ Attribute_RescaleAttribute = 11,
+ Attribute_CustomAttribute = 12,
+ Attribute_CondIfAttribute = 13,
+ Attribute_WhileLoopAttribute = 14,
+ Attribute_MIN = Attribute_NONE,
+ Attribute_MAX = Attribute_WhileLoopAttribute
+};
+
+inline const Attribute (&EnumValuesAttribute())[15] {
+ static const Attribute values[] = {
+ Attribute_NONE,
+ Attribute_Pool2dAttribute,
+ Attribute_Conv2dAttribute,
+ Attribute_TransposeConv2dAttribute,
+ Attribute_ReluNAttribute,
+ Attribute_AxisAttribute,
+ Attribute_ReshapeAttribute,
+ Attribute_SliceAttribute,
+ Attribute_TileAttribute,
+ Attribute_ResizeAttribute,
+ Attribute_ClampAttribute,
+ Attribute_RescaleAttribute,
+ Attribute_CustomAttribute,
+ Attribute_CondIfAttribute,
+ Attribute_WhileLoopAttribute
+ };
+ return values;
+}
+
+inline const char * const *EnumNamesAttribute() {
+ static const char * const names[] = {
+ "NONE",
+ "Pool2dAttribute",
+ "Conv2dAttribute",
+ "TransposeConv2dAttribute",
+ "ReluNAttribute",
+ "AxisAttribute",
+ "ReshapeAttribute",
+ "SliceAttribute",
+ "TileAttribute",
+ "ResizeAttribute",
+ "ClampAttribute",
+ "RescaleAttribute",
+ "CustomAttribute",
+ "CondIfAttribute",
+ "WhileLoopAttribute",
+ nullptr
+ };
+ return names;
+}
+
+inline const char *EnumNameAttribute(Attribute e) {
+ if (e < Attribute_NONE || e > Attribute_WhileLoopAttribute) return "";
+ const size_t index = static_cast<size_t>(e);
+ return EnumNamesAttribute()[index];
+}
+
+template<typename T> struct AttributeTraits {
+ static const Attribute enum_value = Attribute_NONE;
+};
+
+template<> struct AttributeTraits<Pool2dAttribute> {
+ static const Attribute enum_value = Attribute_Pool2dAttribute;
+};
+
+template<> struct AttributeTraits<Conv2dAttribute> {
+ static const Attribute enum_value = Attribute_Conv2dAttribute;
+};
+
+template<> struct AttributeTraits<TransposeConv2dAttribute> {
+ static const Attribute enum_value = Attribute_TransposeConv2dAttribute;
+};
+
+template<> struct AttributeTraits<ReluNAttribute> {
+ static const Attribute enum_value = Attribute_ReluNAttribute;
+};
+
+template<> struct AttributeTraits<AxisAttribute> {
+ static const Attribute enum_value = Attribute_AxisAttribute;
+};
+
+template<> struct AttributeTraits<ReshapeAttribute> {
+ static const Attribute enum_value = Attribute_ReshapeAttribute;
+};
+
+template<> struct AttributeTraits<SliceAttribute> {
+ static const Attribute enum_value = Attribute_SliceAttribute;
+};
+
+template<> struct AttributeTraits<TileAttribute> {
+ static const Attribute enum_value = Attribute_TileAttribute;
+};
+
+template<> struct AttributeTraits<ResizeAttribute> {
+ static const Attribute enum_value = Attribute_ResizeAttribute;
+};
+
+template<> struct AttributeTraits<ClampAttribute> {
+ static const Attribute enum_value = Attribute_ClampAttribute;
+};
+
+template<> struct AttributeTraits<RescaleAttribute> {
+ static const Attribute enum_value = Attribute_RescaleAttribute;
+};
+
+template<> struct AttributeTraits<CustomAttribute> {
+ static const Attribute enum_value = Attribute_CustomAttribute;
+};
+
+template<> struct AttributeTraits<CondIfAttribute> {
+ static const Attribute enum_value = Attribute_CondIfAttribute;
+};
+
+template<> struct AttributeTraits<WhileLoopAttribute> {
+ static const Attribute enum_value = Attribute_WhileLoopAttribute;
+};
+
+bool VerifyAttribute(flatbuffers::Verifier &verifier, const void *obj, Attribute type);
+bool VerifyAttributeVector(flatbuffers::Verifier &verifier, const flatbuffers::Vector<flatbuffers::Offset<void>> *values, const flatbuffers::Vector<uint8_t> *types);
+
+enum QuantInfo {
+ QuantInfo_NONE = 0,
+ QuantInfo_UnaryQuantInfo = 1,
+ QuantInfo_ConvQuantInfo = 2,
+ QuantInfo_MatMulQuantInfo = 3,
+ QuantInfo_PadQuantInfo = 4,
+ QuantInfo_MIN = QuantInfo_NONE,
+ QuantInfo_MAX = QuantInfo_PadQuantInfo
+};
+
+inline const QuantInfo (&EnumValuesQuantInfo())[5] {
+ static const QuantInfo values[] = {
+ QuantInfo_NONE,
+ QuantInfo_UnaryQuantInfo,
+ QuantInfo_ConvQuantInfo,
+ QuantInfo_MatMulQuantInfo,
+ QuantInfo_PadQuantInfo
+ };
+ return values;
+}
+
+inline const char * const *EnumNamesQuantInfo() {
+ static const char * const names[] = {
+ "NONE",
+ "UnaryQuantInfo",
+ "ConvQuantInfo",
+ "MatMulQuantInfo",
+ "PadQuantInfo",
+ nullptr
+ };
+ return names;
+}
+
+inline const char *EnumNameQuantInfo(QuantInfo e) {
+ if (e < QuantInfo_NONE || e > QuantInfo_PadQuantInfo) return "";
+ const size_t index = static_cast<size_t>(e);
+ return EnumNamesQuantInfo()[index];
+}
+
+template<typename T> struct QuantInfoTraits {
+ static const QuantInfo enum_value = QuantInfo_NONE;
+};
+
+template<> struct QuantInfoTraits<UnaryQuantInfo> {
+ static const QuantInfo enum_value = QuantInfo_UnaryQuantInfo;
+};
+
+template<> struct QuantInfoTraits<ConvQuantInfo> {
+ static const QuantInfo enum_value = QuantInfo_ConvQuantInfo;
+};
+
+template<> struct QuantInfoTraits<MatMulQuantInfo> {
+ static const QuantInfo enum_value = QuantInfo_MatMulQuantInfo;
+};
+
+template<> struct QuantInfoTraits<PadQuantInfo> {
+ static const QuantInfo enum_value = QuantInfo_PadQuantInfo;
+};
+
+bool VerifyQuantInfo(flatbuffers::Verifier &verifier, const void *obj, QuantInfo type);
+bool VerifyQuantInfoVector(flatbuffers::Verifier &verifier, const flatbuffers::Vector<flatbuffers::Offset<void>> *values, const flatbuffers::Vector<uint8_t> *types);
+
+struct Pool2dAttribute FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_PADDING = 4,
+ VT_KERNEL = 6,
+ VT_STRIDE = 8
+ };
+ const flatbuffers::Vector<int32_t> *padding() const {
+ return GetPointer<const flatbuffers::Vector<int32_t> *>(VT_PADDING);
+ }
+ const flatbuffers::Vector<int32_t> *kernel() const {
+ return GetPointer<const flatbuffers::Vector<int32_t> *>(VT_KERNEL);
+ }
+ const flatbuffers::Vector<int32_t> *stride() const {
+ return GetPointer<const flatbuffers::Vector<int32_t> *>(VT_STRIDE);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyOffset(verifier, VT_PADDING) &&
+ verifier.VerifyVector(padding()) &&
+ VerifyOffset(verifier, VT_KERNEL) &&
+ verifier.VerifyVector(kernel()) &&
+ VerifyOffset(verifier, VT_STRIDE) &&
+ verifier.VerifyVector(stride()) &&
+ verifier.EndTable();
+ }
+};
+
+struct Pool2dAttributeBuilder {
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_padding(flatbuffers::Offset<flatbuffers::Vector<int32_t>> padding) {
+ fbb_.AddOffset(Pool2dAttribute::VT_PADDING, padding);
+ }
+ void add_kernel(flatbuffers::Offset<flatbuffers::Vector<int32_t>> kernel) {
+ fbb_.AddOffset(Pool2dAttribute::VT_KERNEL, kernel);
+ }
+ void add_stride(flatbuffers::Offset<flatbuffers::Vector<int32_t>> stride) {
+ fbb_.AddOffset(Pool2dAttribute::VT_STRIDE, stride);
+ }
+ explicit Pool2dAttributeBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ Pool2dAttributeBuilder &operator=(const Pool2dAttributeBuilder &);
+ flatbuffers::Offset<Pool2dAttribute> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<Pool2dAttribute>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<Pool2dAttribute> CreatePool2dAttribute(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ flatbuffers::Offset<flatbuffers::Vector<int32_t>> padding = 0,
+ flatbuffers::Offset<flatbuffers::Vector<int32_t>> kernel = 0,
+ flatbuffers::Offset<flatbuffers::Vector<int32_t>> stride = 0) {
+ Pool2dAttributeBuilder builder_(_fbb);
+ builder_.add_stride(stride);
+ builder_.add_kernel(kernel);
+ builder_.add_padding(padding);
+ return builder_.Finish();
+}
+
+inline flatbuffers::Offset<Pool2dAttribute> CreatePool2dAttributeDirect(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ const std::vector<int32_t> *padding = nullptr,
+ const std::vector<int32_t> *kernel = nullptr,
+ const std::vector<int32_t> *stride = nullptr) {
+ auto padding__ = padding ? _fbb.CreateVector<int32_t>(*padding) : 0;
+ auto kernel__ = kernel ? _fbb.CreateVector<int32_t>(*kernel) : 0;
+ auto stride__ = stride ? _fbb.CreateVector<int32_t>(*stride) : 0;
+ return tosa::CreatePool2dAttribute(
+ _fbb,
+ padding__,
+ kernel__,
+ stride__);
+}
+
+struct Conv2dAttribute FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_PADDING = 4,
+ VT_STRIDE = 6,
+ VT_DILATION = 8
+ };
+ const flatbuffers::Vector<int32_t> *padding() const {
+ return GetPointer<const flatbuffers::Vector<int32_t> *>(VT_PADDING);
+ }
+ const flatbuffers::Vector<int32_t> *stride() const {
+ return GetPointer<const flatbuffers::Vector<int32_t> *>(VT_STRIDE);
+ }
+ const flatbuffers::Vector<int32_t> *dilation() const {
+ return GetPointer<const flatbuffers::Vector<int32_t> *>(VT_DILATION);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyOffset(verifier, VT_PADDING) &&
+ verifier.VerifyVector(padding()) &&
+ VerifyOffset(verifier, VT_STRIDE) &&
+ verifier.VerifyVector(stride()) &&
+ VerifyOffset(verifier, VT_DILATION) &&
+ verifier.VerifyVector(dilation()) &&
+ verifier.EndTable();
+ }
+};
+
+struct Conv2dAttributeBuilder {
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_padding(flatbuffers::Offset<flatbuffers::Vector<int32_t>> padding) {
+ fbb_.AddOffset(Conv2dAttribute::VT_PADDING, padding);
+ }
+ void add_stride(flatbuffers::Offset<flatbuffers::Vector<int32_t>> stride) {
+ fbb_.AddOffset(Conv2dAttribute::VT_STRIDE, stride);
+ }
+ void add_dilation(flatbuffers::Offset<flatbuffers::Vector<int32_t>> dilation) {
+ fbb_.AddOffset(Conv2dAttribute::VT_DILATION, dilation);
+ }
+ explicit Conv2dAttributeBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ Conv2dAttributeBuilder &operator=(const Conv2dAttributeBuilder &);
+ flatbuffers::Offset<Conv2dAttribute> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<Conv2dAttribute>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<Conv2dAttribute> CreateConv2dAttribute(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ flatbuffers::Offset<flatbuffers::Vector<int32_t>> padding = 0,
+ flatbuffers::Offset<flatbuffers::Vector<int32_t>> stride = 0,
+ flatbuffers::Offset<flatbuffers::Vector<int32_t>> dilation = 0) {
+ Conv2dAttributeBuilder builder_(_fbb);
+ builder_.add_dilation(dilation);
+ builder_.add_stride(stride);
+ builder_.add_padding(padding);
+ return builder_.Finish();
+}
+
+inline flatbuffers::Offset<Conv2dAttribute> CreateConv2dAttributeDirect(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ const std::vector<int32_t> *padding = nullptr,
+ const std::vector<int32_t> *stride = nullptr,
+ const std::vector<int32_t> *dilation = nullptr) {
+ auto padding__ = padding ? _fbb.CreateVector<int32_t>(*padding) : 0;
+ auto stride__ = stride ? _fbb.CreateVector<int32_t>(*stride) : 0;
+ auto dilation__ = dilation ? _fbb.CreateVector<int32_t>(*dilation) : 0;
+ return tosa::CreateConv2dAttribute(
+ _fbb,
+ padding__,
+ stride__,
+ dilation__);
+}
+
+struct TransposeConv2dAttribute FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_OUTPAD = 4,
+ VT_STRIDE = 6,
+ VT_DILATION = 8,
+ VT_OUTPUT_SHAPE = 10
+ };
+ const flatbuffers::Vector<int32_t> *outpad() const {
+ return GetPointer<const flatbuffers::Vector<int32_t> *>(VT_OUTPAD);
+ }
+ const flatbuffers::Vector<int32_t> *stride() const {
+ return GetPointer<const flatbuffers::Vector<int32_t> *>(VT_STRIDE);
+ }
+ const flatbuffers::Vector<int32_t> *dilation() const {
+ return GetPointer<const flatbuffers::Vector<int32_t> *>(VT_DILATION);
+ }
+ const flatbuffers::Vector<int32_t> *output_shape() const {
+ return GetPointer<const flatbuffers::Vector<int32_t> *>(VT_OUTPUT_SHAPE);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyOffset(verifier, VT_OUTPAD) &&
+ verifier.VerifyVector(outpad()) &&
+ VerifyOffset(verifier, VT_STRIDE) &&
+ verifier.VerifyVector(stride()) &&
+ VerifyOffset(verifier, VT_DILATION) &&
+ verifier.VerifyVector(dilation()) &&
+ VerifyOffset(verifier, VT_OUTPUT_SHAPE) &&
+ verifier.VerifyVector(output_shape()) &&
+ verifier.EndTable();
+ }
+};
+
+struct TransposeConv2dAttributeBuilder {
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_outpad(flatbuffers::Offset<flatbuffers::Vector<int32_t>> outpad) {
+ fbb_.AddOffset(TransposeConv2dAttribute::VT_OUTPAD, outpad);
+ }
+ void add_stride(flatbuffers::Offset<flatbuffers::Vector<int32_t>> stride) {
+ fbb_.AddOffset(TransposeConv2dAttribute::VT_STRIDE, stride);
+ }
+ void add_dilation(flatbuffers::Offset<flatbuffers::Vector<int32_t>> dilation) {
+ fbb_.AddOffset(TransposeConv2dAttribute::VT_DILATION, dilation);
+ }
+ void add_output_shape(flatbuffers::Offset<flatbuffers::Vector<int32_t>> output_shape) {
+ fbb_.AddOffset(TransposeConv2dAttribute::VT_OUTPUT_SHAPE, output_shape);
+ }
+ explicit TransposeConv2dAttributeBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ TransposeConv2dAttributeBuilder &operator=(const TransposeConv2dAttributeBuilder &);
+ flatbuffers::Offset<TransposeConv2dAttribute> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<TransposeConv2dAttribute>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<TransposeConv2dAttribute> CreateTransposeConv2dAttribute(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ flatbuffers::Offset<flatbuffers::Vector<int32_t>> outpad = 0,
+ flatbuffers::Offset<flatbuffers::Vector<int32_t>> stride = 0,
+ flatbuffers::Offset<flatbuffers::Vector<int32_t>> dilation = 0,
+ flatbuffers::Offset<flatbuffers::Vector<int32_t>> output_shape = 0) {
+ TransposeConv2dAttributeBuilder builder_(_fbb);
+ builder_.add_output_shape(output_shape);
+ builder_.add_dilation(dilation);
+ builder_.add_stride(stride);
+ builder_.add_outpad(outpad);
+ return builder_.Finish();
+}
+
+inline flatbuffers::Offset<TransposeConv2dAttribute> CreateTransposeConv2dAttributeDirect(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ const std::vector<int32_t> *outpad = nullptr,
+ const std::vector<int32_t> *stride = nullptr,
+ const std::vector<int32_t> *dilation = nullptr,
+ const std::vector<int32_t> *output_shape = nullptr) {
+ auto outpad__ = outpad ? _fbb.CreateVector<int32_t>(*outpad) : 0;
+ auto stride__ = stride ? _fbb.CreateVector<int32_t>(*stride) : 0;
+ auto dilation__ = dilation ? _fbb.CreateVector<int32_t>(*dilation) : 0;
+ auto output_shape__ = output_shape ? _fbb.CreateVector<int32_t>(*output_shape) : 0;
+ return tosa::CreateTransposeConv2dAttribute(
+ _fbb,
+ outpad__,
+ stride__,
+ dilation__,
+ output_shape__);
+}
+
+struct ReluNAttribute FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_MAX_INT = 4,
+ VT_MAX_FP = 6
+ };
+ int32_t max_int() const {
+ return GetField<int32_t>(VT_MAX_INT, 0);
+ }
+ float max_fp() const {
+ return GetField<float>(VT_MAX_FP, 0.0f);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyField<int32_t>(verifier, VT_MAX_INT) &&
+ VerifyField<float>(verifier, VT_MAX_FP) &&
+ verifier.EndTable();
+ }
+};
+
+struct ReluNAttributeBuilder {
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_max_int(int32_t max_int) {
+ fbb_.AddElement<int32_t>(ReluNAttribute::VT_MAX_INT, max_int, 0);
+ }
+ void add_max_fp(float max_fp) {
+ fbb_.AddElement<float>(ReluNAttribute::VT_MAX_FP, max_fp, 0.0f);
+ }
+ explicit ReluNAttributeBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ ReluNAttributeBuilder &operator=(const ReluNAttributeBuilder &);
+ flatbuffers::Offset<ReluNAttribute> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<ReluNAttribute>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<ReluNAttribute> CreateReluNAttribute(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ int32_t max_int = 0,
+ float max_fp = 0.0f) {
+ ReluNAttributeBuilder builder_(_fbb);
+ builder_.add_max_fp(max_fp);
+ builder_.add_max_int(max_int);
+ return builder_.Finish();
+}
+
+struct AxisAttribute FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_AXIS = 4
+ };
+ int32_t axis() const {
+ return GetField<int32_t>(VT_AXIS, 0);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyField<int32_t>(verifier, VT_AXIS) &&
+ verifier.EndTable();
+ }
+};
+
+struct AxisAttributeBuilder {
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_axis(int32_t axis) {
+ fbb_.AddElement<int32_t>(AxisAttribute::VT_AXIS, axis, 0);
+ }
+ explicit AxisAttributeBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ AxisAttributeBuilder &operator=(const AxisAttributeBuilder &);
+ flatbuffers::Offset<AxisAttribute> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<AxisAttribute>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<AxisAttribute> CreateAxisAttribute(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ int32_t axis = 0) {
+ AxisAttributeBuilder builder_(_fbb);
+ builder_.add_axis(axis);
+ return builder_.Finish();
+}
+
+struct ReshapeAttribute FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_SHAPE = 4
+ };
+ const flatbuffers::Vector<int32_t> *shape() const {
+ return GetPointer<const flatbuffers::Vector<int32_t> *>(VT_SHAPE);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyOffset(verifier, VT_SHAPE) &&
+ verifier.VerifyVector(shape()) &&
+ verifier.EndTable();
+ }
+};
+
+struct ReshapeAttributeBuilder {
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_shape(flatbuffers::Offset<flatbuffers::Vector<int32_t>> shape) {
+ fbb_.AddOffset(ReshapeAttribute::VT_SHAPE, shape);
+ }
+ explicit ReshapeAttributeBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ ReshapeAttributeBuilder &operator=(const ReshapeAttributeBuilder &);
+ flatbuffers::Offset<ReshapeAttribute> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<ReshapeAttribute>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<ReshapeAttribute> CreateReshapeAttribute(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ flatbuffers::Offset<flatbuffers::Vector<int32_t>> shape = 0) {
+ ReshapeAttributeBuilder builder_(_fbb);
+ builder_.add_shape(shape);
+ return builder_.Finish();
+}
+
+inline flatbuffers::Offset<ReshapeAttribute> CreateReshapeAttributeDirect(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ const std::vector<int32_t> *shape = nullptr) {
+ auto shape__ = shape ? _fbb.CreateVector<int32_t>(*shape) : 0;
+ return tosa::CreateReshapeAttribute(
+ _fbb,
+ shape__);
+}
+
+struct SliceAttribute FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_BEGIN = 4,
+ VT_SIZE = 6
+ };
+ const flatbuffers::Vector<int32_t> *begin() const {
+ return GetPointer<const flatbuffers::Vector<int32_t> *>(VT_BEGIN);
+ }
+ const flatbuffers::Vector<int32_t> *size() const {
+ return GetPointer<const flatbuffers::Vector<int32_t> *>(VT_SIZE);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyOffset(verifier, VT_BEGIN) &&
+ verifier.VerifyVector(begin()) &&
+ VerifyOffset(verifier, VT_SIZE) &&
+ verifier.VerifyVector(size()) &&
+ verifier.EndTable();
+ }
+};
+
+struct SliceAttributeBuilder {
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_begin(flatbuffers::Offset<flatbuffers::Vector<int32_t>> begin) {
+ fbb_.AddOffset(SliceAttribute::VT_BEGIN, begin);
+ }
+ void add_size(flatbuffers::Offset<flatbuffers::Vector<int32_t>> size) {
+ fbb_.AddOffset(SliceAttribute::VT_SIZE, size);
+ }
+ explicit SliceAttributeBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ SliceAttributeBuilder &operator=(const SliceAttributeBuilder &);
+ flatbuffers::Offset<SliceAttribute> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<SliceAttribute>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<SliceAttribute> CreateSliceAttribute(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ flatbuffers::Offset<flatbuffers::Vector<int32_t>> begin = 0,
+ flatbuffers::Offset<flatbuffers::Vector<int32_t>> size = 0) {
+ SliceAttributeBuilder builder_(_fbb);
+ builder_.add_size(size);
+ builder_.add_begin(begin);
+ return builder_.Finish();
+}
+
+inline flatbuffers::Offset<SliceAttribute> CreateSliceAttributeDirect(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ const std::vector<int32_t> *begin = nullptr,
+ const std::vector<int32_t> *size = nullptr) {
+ auto begin__ = begin ? _fbb.CreateVector<int32_t>(*begin) : 0;
+ auto size__ = size ? _fbb.CreateVector<int32_t>(*size) : 0;
+ return tosa::CreateSliceAttribute(
+ _fbb,
+ begin__,
+ size__);
+}
+
+struct TileAttribute FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_MULTIPLES = 4
+ };
+ const flatbuffers::Vector<int32_t> *multiples() const {
+ return GetPointer<const flatbuffers::Vector<int32_t> *>(VT_MULTIPLES);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyOffset(verifier, VT_MULTIPLES) &&
+ verifier.VerifyVector(multiples()) &&
+ verifier.EndTable();
+ }
+};
+
+struct TileAttributeBuilder {
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_multiples(flatbuffers::Offset<flatbuffers::Vector<int32_t>> multiples) {
+ fbb_.AddOffset(TileAttribute::VT_MULTIPLES, multiples);
+ }
+ explicit TileAttributeBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ TileAttributeBuilder &operator=(const TileAttributeBuilder &);
+ flatbuffers::Offset<TileAttribute> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<TileAttribute>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<TileAttribute> CreateTileAttribute(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ flatbuffers::Offset<flatbuffers::Vector<int32_t>> multiples = 0) {
+ TileAttributeBuilder builder_(_fbb);
+ builder_.add_multiples(multiples);
+ return builder_.Finish();
+}
+
+inline flatbuffers::Offset<TileAttribute> CreateTileAttributeDirect(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ const std::vector<int32_t> *multiples = nullptr) {
+ auto multiples__ = multiples ? _fbb.CreateVector<int32_t>(*multiples) : 0;
+ return tosa::CreateTileAttribute(
+ _fbb,
+ multiples__);
+}
+
+struct ResizeAttribute FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_OUTPUT_SIZE = 4,
+ VT_STRIDE = 6,
+ VT_OFFSET = 8,
+ VT_SHIFT = 10,
+ VT_MODE = 12
+ };
+ const flatbuffers::Vector<int32_t> *output_size() const {
+ return GetPointer<const flatbuffers::Vector<int32_t> *>(VT_OUTPUT_SIZE);
+ }
+ const flatbuffers::Vector<int32_t> *stride() const {
+ return GetPointer<const flatbuffers::Vector<int32_t> *>(VT_STRIDE);
+ }
+ const flatbuffers::Vector<int32_t> *offset() const {
+ return GetPointer<const flatbuffers::Vector<int32_t> *>(VT_OFFSET);
+ }
+ int32_t shift() const {
+ return GetField<int32_t>(VT_SHIFT, 0);
+ }
+ ResizeMode mode() const {
+ return static_cast<ResizeMode>(GetField<uint32_t>(VT_MODE, 0));
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyOffset(verifier, VT_OUTPUT_SIZE) &&
+ verifier.VerifyVector(output_size()) &&
+ VerifyOffset(verifier, VT_STRIDE) &&
+ verifier.VerifyVector(stride()) &&
+ VerifyOffset(verifier, VT_OFFSET) &&
+ verifier.VerifyVector(offset()) &&
+ VerifyField<int32_t>(verifier, VT_SHIFT) &&
+ VerifyField<uint32_t>(verifier, VT_MODE) &&
+ verifier.EndTable();
+ }
+};
+
+struct ResizeAttributeBuilder {
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_output_size(flatbuffers::Offset<flatbuffers::Vector<int32_t>> output_size) {
+ fbb_.AddOffset(ResizeAttribute::VT_OUTPUT_SIZE, output_size);
+ }
+ void add_stride(flatbuffers::Offset<flatbuffers::Vector<int32_t>> stride) {
+ fbb_.AddOffset(ResizeAttribute::VT_STRIDE, stride);
+ }
+ void add_offset(flatbuffers::Offset<flatbuffers::Vector<int32_t>> offset) {
+ fbb_.AddOffset(ResizeAttribute::VT_OFFSET, offset);
+ }
+ void add_shift(int32_t shift) {
+ fbb_.AddElement<int32_t>(ResizeAttribute::VT_SHIFT, shift, 0);
+ }
+ void add_mode(ResizeMode mode) {
+ fbb_.AddElement<uint32_t>(ResizeAttribute::VT_MODE, static_cast<uint32_t>(mode), 0);
+ }
+ explicit ResizeAttributeBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ ResizeAttributeBuilder &operator=(const ResizeAttributeBuilder &);
+ flatbuffers::Offset<ResizeAttribute> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<ResizeAttribute>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<ResizeAttribute> CreateResizeAttribute(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ flatbuffers::Offset<flatbuffers::Vector<int32_t>> output_size = 0,
+ flatbuffers::Offset<flatbuffers::Vector<int32_t>> stride = 0,
+ flatbuffers::Offset<flatbuffers::Vector<int32_t>> offset = 0,
+ int32_t shift = 0,
+ ResizeMode mode = ResizeMode_UNKNOWN) {
+ ResizeAttributeBuilder builder_(_fbb);
+ builder_.add_mode(mode);
+ builder_.add_shift(shift);
+ builder_.add_offset(offset);
+ builder_.add_stride(stride);
+ builder_.add_output_size(output_size);
+ return builder_.Finish();
+}
+
+inline flatbuffers::Offset<ResizeAttribute> CreateResizeAttributeDirect(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ const std::vector<int32_t> *output_size = nullptr,
+ const std::vector<int32_t> *stride = nullptr,
+ const std::vector<int32_t> *offset = nullptr,
+ int32_t shift = 0,
+ ResizeMode mode = ResizeMode_UNKNOWN) {
+ auto output_size__ = output_size ? _fbb.CreateVector<int32_t>(*output_size) : 0;
+ auto stride__ = stride ? _fbb.CreateVector<int32_t>(*stride) : 0;
+ auto offset__ = offset ? _fbb.CreateVector<int32_t>(*offset) : 0;
+ return tosa::CreateResizeAttribute(
+ _fbb,
+ output_size__,
+ stride__,
+ offset__,
+ shift,
+ mode);
+}
+
+struct ClampAttribute FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_MIN_INT = 4,
+ VT_MAX_INT = 6,
+ VT_MIN_FP = 8,
+ VT_MAX_FP = 10
+ };
+ int32_t min_int() const {
+ return GetField<int32_t>(VT_MIN_INT, 0);
+ }
+ int32_t max_int() const {
+ return GetField<int32_t>(VT_MAX_INT, 0);
+ }
+ float min_fp() const {
+ return GetField<float>(VT_MIN_FP, 0.0f);
+ }
+ float max_fp() const {
+ return GetField<float>(VT_MAX_FP, 0.0f);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyField<int32_t>(verifier, VT_MIN_INT) &&
+ VerifyField<int32_t>(verifier, VT_MAX_INT) &&
+ VerifyField<float>(verifier, VT_MIN_FP) &&
+ VerifyField<float>(verifier, VT_MAX_FP) &&
+ verifier.EndTable();
+ }
+};
+
+struct ClampAttributeBuilder {
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_min_int(int32_t min_int) {
+ fbb_.AddElement<int32_t>(ClampAttribute::VT_MIN_INT, min_int, 0);
+ }
+ void add_max_int(int32_t max_int) {
+ fbb_.AddElement<int32_t>(ClampAttribute::VT_MAX_INT, max_int, 0);
+ }
+ void add_min_fp(float min_fp) {
+ fbb_.AddElement<float>(ClampAttribute::VT_MIN_FP, min_fp, 0.0f);
+ }
+ void add_max_fp(float max_fp) {
+ fbb_.AddElement<float>(ClampAttribute::VT_MAX_FP, max_fp, 0.0f);
+ }
+ explicit ClampAttributeBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ ClampAttributeBuilder &operator=(const ClampAttributeBuilder &);
+ flatbuffers::Offset<ClampAttribute> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<ClampAttribute>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<ClampAttribute> CreateClampAttribute(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ int32_t min_int = 0,
+ int32_t max_int = 0,
+ float min_fp = 0.0f,
+ float max_fp = 0.0f) {
+ ClampAttributeBuilder builder_(_fbb);
+ builder_.add_max_fp(max_fp);
+ builder_.add_min_fp(min_fp);
+ builder_.add_max_int(max_int);
+ builder_.add_min_int(min_int);
+ return builder_.Finish();
+}
+
+struct RescaleAttribute FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_INPUT_ZP = 4,
+ VT_OUTPUT_ZP = 6,
+ VT_MULTIPLIER = 8,
+ VT_SHIFT = 10,
+ VT_SCALE32 = 12,
+ VT_DOUBLE_ROUND = 14,
+ VT_PER_CHANNEL = 16
+ };
+ int32_t input_zp() const {
+ return GetField<int32_t>(VT_INPUT_ZP, 0);
+ }
+ int32_t output_zp() const {
+ return GetField<int32_t>(VT_OUTPUT_ZP, 0);
+ }
+ const flatbuffers::Vector<int32_t> *multiplier() const {
+ return GetPointer<const flatbuffers::Vector<int32_t> *>(VT_MULTIPLIER);
+ }
+ const flatbuffers::Vector<int32_t> *shift() const {
+ return GetPointer<const flatbuffers::Vector<int32_t> *>(VT_SHIFT);
+ }
+ bool scale32() const {
+ return GetField<uint8_t>(VT_SCALE32, 0) != 0;
+ }
+ bool double_round() const {
+ return GetField<uint8_t>(VT_DOUBLE_ROUND, 0) != 0;
+ }
+ bool per_channel() const {
+ return GetField<uint8_t>(VT_PER_CHANNEL, 0) != 0;
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyField<int32_t>(verifier, VT_INPUT_ZP) &&
+ VerifyField<int32_t>(verifier, VT_OUTPUT_ZP) &&
+ VerifyOffset(verifier, VT_MULTIPLIER) &&
+ verifier.VerifyVector(multiplier()) &&
+ VerifyOffset(verifier, VT_SHIFT) &&
+ verifier.VerifyVector(shift()) &&
+ VerifyField<uint8_t>(verifier, VT_SCALE32) &&
+ VerifyField<uint8_t>(verifier, VT_DOUBLE_ROUND) &&
+ VerifyField<uint8_t>(verifier, VT_PER_CHANNEL) &&
+ verifier.EndTable();
+ }
+};
+
+struct RescaleAttributeBuilder {
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_input_zp(int32_t input_zp) {
+ fbb_.AddElement<int32_t>(RescaleAttribute::VT_INPUT_ZP, input_zp, 0);
+ }
+ void add_output_zp(int32_t output_zp) {
+ fbb_.AddElement<int32_t>(RescaleAttribute::VT_OUTPUT_ZP, output_zp, 0);
+ }
+ void add_multiplier(flatbuffers::Offset<flatbuffers::Vector<int32_t>> multiplier) {
+ fbb_.AddOffset(RescaleAttribute::VT_MULTIPLIER, multiplier);
+ }
+ void add_shift(flatbuffers::Offset<flatbuffers::Vector<int32_t>> shift) {
+ fbb_.AddOffset(RescaleAttribute::VT_SHIFT, shift);
+ }
+ void add_scale32(bool scale32) {
+ fbb_.AddElement<uint8_t>(RescaleAttribute::VT_SCALE32, static_cast<uint8_t>(scale32), 0);
+ }
+ void add_double_round(bool double_round) {
+ fbb_.AddElement<uint8_t>(RescaleAttribute::VT_DOUBLE_ROUND, static_cast<uint8_t>(double_round), 0);
+ }
+ void add_per_channel(bool per_channel) {
+ fbb_.AddElement<uint8_t>(RescaleAttribute::VT_PER_CHANNEL, static_cast<uint8_t>(per_channel), 0);
+ }
+ explicit RescaleAttributeBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ RescaleAttributeBuilder &operator=(const RescaleAttributeBuilder &);
+ flatbuffers::Offset<RescaleAttribute> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<RescaleAttribute>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<RescaleAttribute> CreateRescaleAttribute(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ int32_t input_zp = 0,
+ int32_t output_zp = 0,
+ flatbuffers::Offset<flatbuffers::Vector<int32_t>> multiplier = 0,
+ flatbuffers::Offset<flatbuffers::Vector<int32_t>> shift = 0,
+ bool scale32 = false,
+ bool double_round = false,
+ bool per_channel = false) {
+ RescaleAttributeBuilder builder_(_fbb);
+ builder_.add_shift(shift);
+ builder_.add_multiplier(multiplier);
+ builder_.add_output_zp(output_zp);
+ builder_.add_input_zp(input_zp);
+ builder_.add_per_channel(per_channel);
+ builder_.add_double_round(double_round);
+ builder_.add_scale32(scale32);
+ return builder_.Finish();
+}
+
+inline flatbuffers::Offset<RescaleAttribute> CreateRescaleAttributeDirect(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ int32_t input_zp = 0,
+ int32_t output_zp = 0,
+ const std::vector<int32_t> *multiplier = nullptr,
+ const std::vector<int32_t> *shift = nullptr,
+ bool scale32 = false,
+ bool double_round = false,
+ bool per_channel = false) {
+ auto multiplier__ = multiplier ? _fbb.CreateVector<int32_t>(*multiplier) : 0;
+ auto shift__ = shift ? _fbb.CreateVector<int32_t>(*shift) : 0;
+ return tosa::CreateRescaleAttribute(
+ _fbb,
+ input_zp,
+ output_zp,
+ multiplier__,
+ shift__,
+ scale32,
+ double_round,
+ per_channel);
+}
+
+struct CustomAttribute FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_IDENTIFIER = 4
+ };
+ const flatbuffers::String *identifier() const {
+ return GetPointer<const flatbuffers::String *>(VT_IDENTIFIER);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyOffset(verifier, VT_IDENTIFIER) &&
+ verifier.VerifyString(identifier()) &&
+ verifier.EndTable();
+ }
+};
+
+struct CustomAttributeBuilder {
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_identifier(flatbuffers::Offset<flatbuffers::String> identifier) {
+ fbb_.AddOffset(CustomAttribute::VT_IDENTIFIER, identifier);
+ }
+ explicit CustomAttributeBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ CustomAttributeBuilder &operator=(const CustomAttributeBuilder &);
+ flatbuffers::Offset<CustomAttribute> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<CustomAttribute>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<CustomAttribute> CreateCustomAttribute(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ flatbuffers::Offset<flatbuffers::String> identifier = 0) {
+ CustomAttributeBuilder builder_(_fbb);
+ builder_.add_identifier(identifier);
+ return builder_.Finish();
+}
+
+inline flatbuffers::Offset<CustomAttribute> CreateCustomAttributeDirect(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ const char *identifier = nullptr) {
+ auto identifier__ = identifier ? _fbb.CreateString(identifier) : 0;
+ return tosa::CreateCustomAttribute(
+ _fbb,
+ identifier__);
+}
+
+struct CondIfAttribute FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_THEN_BRANCH = 4,
+ VT_ELSE_BRANCH = 6
+ };
+ const flatbuffers::String *then_branch() const {
+ return GetPointer<const flatbuffers::String *>(VT_THEN_BRANCH);
+ }
+ const flatbuffers::String *else_branch() const {
+ return GetPointer<const flatbuffers::String *>(VT_ELSE_BRANCH);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyOffset(verifier, VT_THEN_BRANCH) &&
+ verifier.VerifyString(then_branch()) &&
+ VerifyOffset(verifier, VT_ELSE_BRANCH) &&
+ verifier.VerifyString(else_branch()) &&
+ verifier.EndTable();
+ }
+};
+
+struct CondIfAttributeBuilder {
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_then_branch(flatbuffers::Offset<flatbuffers::String> then_branch) {
+ fbb_.AddOffset(CondIfAttribute::VT_THEN_BRANCH, then_branch);
+ }
+ void add_else_branch(flatbuffers::Offset<flatbuffers::String> else_branch) {
+ fbb_.AddOffset(CondIfAttribute::VT_ELSE_BRANCH, else_branch);
+ }
+ explicit CondIfAttributeBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ CondIfAttributeBuilder &operator=(const CondIfAttributeBuilder &);
+ flatbuffers::Offset<CondIfAttribute> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<CondIfAttribute>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<CondIfAttribute> CreateCondIfAttribute(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ flatbuffers::Offset<flatbuffers::String> then_branch = 0,
+ flatbuffers::Offset<flatbuffers::String> else_branch = 0) {
+ CondIfAttributeBuilder builder_(_fbb);
+ builder_.add_else_branch(else_branch);
+ builder_.add_then_branch(then_branch);
+ return builder_.Finish();
+}
+
+inline flatbuffers::Offset<CondIfAttribute> CreateCondIfAttributeDirect(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ const char *then_branch = nullptr,
+ const char *else_branch = nullptr) {
+ auto then_branch__ = then_branch ? _fbb.CreateString(then_branch) : 0;
+ auto else_branch__ = else_branch ? _fbb.CreateString(else_branch) : 0;
+ return tosa::CreateCondIfAttribute(
+ _fbb,
+ then_branch__,
+ else_branch__);
+}
+
+struct WhileLoopAttribute FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_COND_BRANCH = 4,
+ VT_BODY_BRANCH = 6
+ };
+ const flatbuffers::String *cond_branch() const {
+ return GetPointer<const flatbuffers::String *>(VT_COND_BRANCH);
+ }
+ const flatbuffers::String *body_branch() const {
+ return GetPointer<const flatbuffers::String *>(VT_BODY_BRANCH);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyOffset(verifier, VT_COND_BRANCH) &&
+ verifier.VerifyString(cond_branch()) &&
+ VerifyOffset(verifier, VT_BODY_BRANCH) &&
+ verifier.VerifyString(body_branch()) &&
+ verifier.EndTable();
+ }
+};
+
+struct WhileLoopAttributeBuilder {
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_cond_branch(flatbuffers::Offset<flatbuffers::String> cond_branch) {
+ fbb_.AddOffset(WhileLoopAttribute::VT_COND_BRANCH, cond_branch);
+ }
+ void add_body_branch(flatbuffers::Offset<flatbuffers::String> body_branch) {
+ fbb_.AddOffset(WhileLoopAttribute::VT_BODY_BRANCH, body_branch);
+ }
+ explicit WhileLoopAttributeBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ WhileLoopAttributeBuilder &operator=(const WhileLoopAttributeBuilder &);
+ flatbuffers::Offset<WhileLoopAttribute> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<WhileLoopAttribute>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<WhileLoopAttribute> CreateWhileLoopAttribute(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ flatbuffers::Offset<flatbuffers::String> cond_branch = 0,
+ flatbuffers::Offset<flatbuffers::String> body_branch = 0) {
+ WhileLoopAttributeBuilder builder_(_fbb);
+ builder_.add_body_branch(body_branch);
+ builder_.add_cond_branch(cond_branch);
+ return builder_.Finish();
+}
+
+inline flatbuffers::Offset<WhileLoopAttribute> CreateWhileLoopAttributeDirect(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ const char *cond_branch = nullptr,
+ const char *body_branch = nullptr) {
+ auto cond_branch__ = cond_branch ? _fbb.CreateString(cond_branch) : 0;
+ auto body_branch__ = body_branch ? _fbb.CreateString(body_branch) : 0;
+ return tosa::CreateWhileLoopAttribute(
+ _fbb,
+ cond_branch__,
+ body_branch__);
+}
+
+struct UnaryQuantInfo FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_INPUT_ZP = 4,
+ VT_OUTPUT_ZP = 6
+ };
+ int32_t input_zp() const {
+ return GetField<int32_t>(VT_INPUT_ZP, 0);
+ }
+ int32_t output_zp() const {
+ return GetField<int32_t>(VT_OUTPUT_ZP, 0);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyField<int32_t>(verifier, VT_INPUT_ZP) &&
+ VerifyField<int32_t>(verifier, VT_OUTPUT_ZP) &&
+ verifier.EndTable();
+ }
+};
+
+struct UnaryQuantInfoBuilder {
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_input_zp(int32_t input_zp) {
+ fbb_.AddElement<int32_t>(UnaryQuantInfo::VT_INPUT_ZP, input_zp, 0);
+ }
+ void add_output_zp(int32_t output_zp) {
+ fbb_.AddElement<int32_t>(UnaryQuantInfo::VT_OUTPUT_ZP, output_zp, 0);
+ }
+ explicit UnaryQuantInfoBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ UnaryQuantInfoBuilder &operator=(const UnaryQuantInfoBuilder &);
+ flatbuffers::Offset<UnaryQuantInfo> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<UnaryQuantInfo>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<UnaryQuantInfo> CreateUnaryQuantInfo(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ int32_t input_zp = 0,
+ int32_t output_zp = 0) {
+ UnaryQuantInfoBuilder builder_(_fbb);
+ builder_.add_output_zp(output_zp);
+ builder_.add_input_zp(input_zp);
+ return builder_.Finish();
+}
+
+struct ConvQuantInfo FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_INPUT_ZP = 4,
+ VT_WEIGHT_ZP = 6
+ };
+ int32_t input_zp() const {
+ return GetField<int32_t>(VT_INPUT_ZP, 0);
+ }
+ int32_t weight_zp() const {
+ return GetField<int32_t>(VT_WEIGHT_ZP, 0);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyField<int32_t>(verifier, VT_INPUT_ZP) &&
+ VerifyField<int32_t>(verifier, VT_WEIGHT_ZP) &&
+ verifier.EndTable();
+ }
+};
+
+struct ConvQuantInfoBuilder {
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_input_zp(int32_t input_zp) {
+ fbb_.AddElement<int32_t>(ConvQuantInfo::VT_INPUT_ZP, input_zp, 0);
+ }
+ void add_weight_zp(int32_t weight_zp) {
+ fbb_.AddElement<int32_t>(ConvQuantInfo::VT_WEIGHT_ZP, weight_zp, 0);
+ }
+ explicit ConvQuantInfoBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ ConvQuantInfoBuilder &operator=(const ConvQuantInfoBuilder &);
+ flatbuffers::Offset<ConvQuantInfo> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<ConvQuantInfo>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<ConvQuantInfo> CreateConvQuantInfo(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ int32_t input_zp = 0,
+ int32_t weight_zp = 0) {
+ ConvQuantInfoBuilder builder_(_fbb);
+ builder_.add_weight_zp(weight_zp);
+ builder_.add_input_zp(input_zp);
+ return builder_.Finish();
+}
+
+struct MatMulQuantInfo FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_A_ZP = 4,
+ VT_B_ZP = 6
+ };
+ int32_t a_zp() const {
+ return GetField<int32_t>(VT_A_ZP, 0);
+ }
+ int32_t b_zp() const {
+ return GetField<int32_t>(VT_B_ZP, 0);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyField<int32_t>(verifier, VT_A_ZP) &&
+ VerifyField<int32_t>(verifier, VT_B_ZP) &&
+ verifier.EndTable();
+ }
+};
+
+struct MatMulQuantInfoBuilder {
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_a_zp(int32_t a_zp) {
+ fbb_.AddElement<int32_t>(MatMulQuantInfo::VT_A_ZP, a_zp, 0);
+ }
+ void add_b_zp(int32_t b_zp) {
+ fbb_.AddElement<int32_t>(MatMulQuantInfo::VT_B_ZP, b_zp, 0);
+ }
+ explicit MatMulQuantInfoBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ MatMulQuantInfoBuilder &operator=(const MatMulQuantInfoBuilder &);
+ flatbuffers::Offset<MatMulQuantInfo> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<MatMulQuantInfo>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<MatMulQuantInfo> CreateMatMulQuantInfo(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ int32_t a_zp = 0,
+ int32_t b_zp = 0) {
+ MatMulQuantInfoBuilder builder_(_fbb);
+ builder_.add_b_zp(b_zp);
+ builder_.add_a_zp(a_zp);
+ return builder_.Finish();
+}
+
+struct PadQuantInfo FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_INPUT_ZP = 4
+ };
+ int32_t input_zp() const {
+ return GetField<int32_t>(VT_INPUT_ZP, 0);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyField<int32_t>(verifier, VT_INPUT_ZP) &&
+ verifier.EndTable();
+ }
+};
+
+struct PadQuantInfoBuilder {
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_input_zp(int32_t input_zp) {
+ fbb_.AddElement<int32_t>(PadQuantInfo::VT_INPUT_ZP, input_zp, 0);
+ }
+ explicit PadQuantInfoBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ PadQuantInfoBuilder &operator=(const PadQuantInfoBuilder &);
+ flatbuffers::Offset<PadQuantInfo> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<PadQuantInfo>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<PadQuantInfo> CreatePadQuantInfo(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ int32_t input_zp = 0) {
+ PadQuantInfoBuilder builder_(_fbb);
+ builder_.add_input_zp(input_zp);
+ return builder_.Finish();
+}
+
+struct Version FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT__MAJOR = 4,
+ VT__MINOR = 6,
+ VT__PATCH = 8,
+ VT__EXPERIMENTAL = 10
+ };
+ int32_t _major() const {
+ return GetField<int32_t>(VT__MAJOR, 0);
+ }
+ int32_t _minor() const {
+ return GetField<int32_t>(VT__MINOR, 20);
+ }
+ int32_t _patch() const {
+ return GetField<int32_t>(VT__PATCH, 0);
+ }
+ bool _experimental() const {
+ return GetField<uint8_t>(VT__EXPERIMENTAL, 0) != 0;
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyField<int32_t>(verifier, VT__MAJOR) &&
+ VerifyField<int32_t>(verifier, VT__MINOR) &&
+ VerifyField<int32_t>(verifier, VT__PATCH) &&
+ VerifyField<uint8_t>(verifier, VT__EXPERIMENTAL) &&
+ verifier.EndTable();
+ }
+};
+
+struct VersionBuilder {
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add__major(int32_t _major) {
+ fbb_.AddElement<int32_t>(Version::VT__MAJOR, _major, 0);
+ }
+ void add__minor(int32_t _minor) {
+ fbb_.AddElement<int32_t>(Version::VT__MINOR, _minor, 20);
+ }
+ void add__patch(int32_t _patch) {
+ fbb_.AddElement<int32_t>(Version::VT__PATCH, _patch, 0);
+ }
+ void add__experimental(bool _experimental) {
+ fbb_.AddElement<uint8_t>(Version::VT__EXPERIMENTAL, static_cast<uint8_t>(_experimental), 0);
+ }
+ explicit VersionBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ VersionBuilder &operator=(const VersionBuilder &);
+ flatbuffers::Offset<Version> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<Version>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<Version> CreateVersion(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ int32_t _major = 0,
+ int32_t _minor = 20,
+ int32_t _patch = 0,
+ bool _experimental = false) {
+ VersionBuilder builder_(_fbb);
+ builder_.add__patch(_patch);
+ builder_.add__minor(_minor);
+ builder_.add__major(_major);
+ builder_.add__experimental(_experimental);
+ return builder_.Finish();
+}
+
+struct TosaTensor FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_NAME = 4,
+ VT_SHAPE = 6,
+ VT_TYPE = 8,
+ VT_USAGE = 10,
+ VT_FORMAT = 12,
+ VT_NPY_FILENAME = 14
+ };
+ const flatbuffers::String *name() const {
+ return GetPointer<const flatbuffers::String *>(VT_NAME);
+ }
+ const flatbuffers::Vector<int32_t> *shape() const {
+ return GetPointer<const flatbuffers::Vector<int32_t> *>(VT_SHAPE);
+ }
+ DType type() const {
+ return static_cast<DType>(GetField<uint32_t>(VT_TYPE, 0));
+ }
+ const flatbuffers::Vector<uint32_t> *usage() const {
+ return GetPointer<const flatbuffers::Vector<uint32_t> *>(VT_USAGE);
+ }
+ const flatbuffers::Vector<uint32_t> *format() const {
+ return GetPointer<const flatbuffers::Vector<uint32_t> *>(VT_FORMAT);
+ }
+ const flatbuffers::String *npy_filename() const {
+ return GetPointer<const flatbuffers::String *>(VT_NPY_FILENAME);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyOffset(verifier, VT_NAME) &&
+ verifier.VerifyString(name()) &&
+ VerifyOffset(verifier, VT_SHAPE) &&
+ verifier.VerifyVector(shape()) &&
+ VerifyField<uint32_t>(verifier, VT_TYPE) &&
+ VerifyOffset(verifier, VT_USAGE) &&
+ verifier.VerifyVector(usage()) &&
+ VerifyOffset(verifier, VT_FORMAT) &&
+ verifier.VerifyVector(format()) &&
+ VerifyOffset(verifier, VT_NPY_FILENAME) &&
+ verifier.VerifyString(npy_filename()) &&
+ verifier.EndTable();
+ }
+};
+
+struct TosaTensorBuilder {
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_name(flatbuffers::Offset<flatbuffers::String> name) {
+ fbb_.AddOffset(TosaTensor::VT_NAME, name);
+ }
+ void add_shape(flatbuffers::Offset<flatbuffers::Vector<int32_t>> shape) {
+ fbb_.AddOffset(TosaTensor::VT_SHAPE, shape);
+ }
+ void add_type(DType type) {
+ fbb_.AddElement<uint32_t>(TosaTensor::VT_TYPE, static_cast<uint32_t>(type), 0);
+ }
+ void add_usage(flatbuffers::Offset<flatbuffers::Vector<uint32_t>> usage) {
+ fbb_.AddOffset(TosaTensor::VT_USAGE, usage);
+ }
+ void add_format(flatbuffers::Offset<flatbuffers::Vector<uint32_t>> format) {
+ fbb_.AddOffset(TosaTensor::VT_FORMAT, format);
+ }
+ void add_npy_filename(flatbuffers::Offset<flatbuffers::String> npy_filename) {
+ fbb_.AddOffset(TosaTensor::VT_NPY_FILENAME, npy_filename);
+ }
+ explicit TosaTensorBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ TosaTensorBuilder &operator=(const TosaTensorBuilder &);
+ flatbuffers::Offset<TosaTensor> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<TosaTensor>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<TosaTensor> CreateTosaTensor(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ flatbuffers::Offset<flatbuffers::String> name = 0,
+ flatbuffers::Offset<flatbuffers::Vector<int32_t>> shape = 0,
+ DType type = DType_UNKNOWN,
+ flatbuffers::Offset<flatbuffers::Vector<uint32_t>> usage = 0,
+ flatbuffers::Offset<flatbuffers::Vector<uint32_t>> format = 0,
+ flatbuffers::Offset<flatbuffers::String> npy_filename = 0) {
+ TosaTensorBuilder builder_(_fbb);
+ builder_.add_npy_filename(npy_filename);
+ builder_.add_format(format);
+ builder_.add_usage(usage);
+ builder_.add_type(type);
+ builder_.add_shape(shape);
+ builder_.add_name(name);
+ return builder_.Finish();
+}
+
+inline flatbuffers::Offset<TosaTensor> CreateTosaTensorDirect(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ const char *name = nullptr,
+ const std::vector<int32_t> *shape = nullptr,
+ DType type = DType_UNKNOWN,
+ const std::vector<uint32_t> *usage = nullptr,
+ const std::vector<uint32_t> *format = nullptr,
+ const char *npy_filename = nullptr) {
+ auto name__ = name ? _fbb.CreateString(name) : 0;
+ auto shape__ = shape ? _fbb.CreateVector<int32_t>(*shape) : 0;
+ auto usage__ = usage ? _fbb.CreateVector<uint32_t>(*usage) : 0;
+ auto format__ = format ? _fbb.CreateVector<uint32_t>(*format) : 0;
+ auto npy_filename__ = npy_filename ? _fbb.CreateString(npy_filename) : 0;
+ return tosa::CreateTosaTensor(
+ _fbb,
+ name__,
+ shape__,
+ type,
+ usage__,
+ format__,
+ npy_filename__);
+}
+
+struct TosaOperator FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_OP = 4,
+ VT_ATTRIBUTE_TYPE = 6,
+ VT_ATTRIBUTE = 8,
+ VT_INPUTS = 10,
+ VT_OUTPUTS = 12,
+ VT_QUANT_INFO_TYPE = 14,
+ VT_QUANT_INFO = 16
+ };
+ Op op() const {
+ return static_cast<Op>(GetField<uint32_t>(VT_OP, 0));
+ }
+ Attribute attribute_type() const {
+ return static_cast<Attribute>(GetField<uint8_t>(VT_ATTRIBUTE_TYPE, 0));
+ }
+ const void *attribute() const {
+ return GetPointer<const void *>(VT_ATTRIBUTE);
+ }
+ template<typename T> const T *attribute_as() const;
+ const Pool2dAttribute *attribute_as_Pool2dAttribute() const {
+ return attribute_type() == Attribute_Pool2dAttribute ? static_cast<const Pool2dAttribute *>(attribute()) : nullptr;
+ }
+ const Conv2dAttribute *attribute_as_Conv2dAttribute() const {
+ return attribute_type() == Attribute_Conv2dAttribute ? static_cast<const Conv2dAttribute *>(attribute()) : nullptr;
+ }
+ const TransposeConv2dAttribute *attribute_as_TransposeConv2dAttribute() const {
+ return attribute_type() == Attribute_TransposeConv2dAttribute ? static_cast<const TransposeConv2dAttribute *>(attribute()) : nullptr;
+ }
+ const ReluNAttribute *attribute_as_ReluNAttribute() const {
+ return attribute_type() == Attribute_ReluNAttribute ? static_cast<const ReluNAttribute *>(attribute()) : nullptr;
+ }
+ const AxisAttribute *attribute_as_AxisAttribute() const {
+ return attribute_type() == Attribute_AxisAttribute ? static_cast<const AxisAttribute *>(attribute()) : nullptr;
+ }
+ const ReshapeAttribute *attribute_as_ReshapeAttribute() const {
+ return attribute_type() == Attribute_ReshapeAttribute ? static_cast<const ReshapeAttribute *>(attribute()) : nullptr;
+ }
+ const SliceAttribute *attribute_as_SliceAttribute() const {
+ return attribute_type() == Attribute_SliceAttribute ? static_cast<const SliceAttribute *>(attribute()) : nullptr;
+ }
+ const TileAttribute *attribute_as_TileAttribute() const {
+ return attribute_type() == Attribute_TileAttribute ? static_cast<const TileAttribute *>(attribute()) : nullptr;
+ }
+ const ResizeAttribute *attribute_as_ResizeAttribute() const {
+ return attribute_type() == Attribute_ResizeAttribute ? static_cast<const ResizeAttribute *>(attribute()) : nullptr;
+ }
+ const ClampAttribute *attribute_as_ClampAttribute() const {
+ return attribute_type() == Attribute_ClampAttribute ? static_cast<const ClampAttribute *>(attribute()) : nullptr;
+ }
+ const RescaleAttribute *attribute_as_RescaleAttribute() const {
+ return attribute_type() == Attribute_RescaleAttribute ? static_cast<const RescaleAttribute *>(attribute()) : nullptr;
+ }
+ const CustomAttribute *attribute_as_CustomAttribute() const {
+ return attribute_type() == Attribute_CustomAttribute ? static_cast<const CustomAttribute *>(attribute()) : nullptr;
+ }
+ const CondIfAttribute *attribute_as_CondIfAttribute() const {
+ return attribute_type() == Attribute_CondIfAttribute ? static_cast<const CondIfAttribute *>(attribute()) : nullptr;
+ }
+ const WhileLoopAttribute *attribute_as_WhileLoopAttribute() const {
+ return attribute_type() == Attribute_WhileLoopAttribute ? static_cast<const WhileLoopAttribute *>(attribute()) : nullptr;
+ }
+ const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *inputs() const {
+ return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *>(VT_INPUTS);
+ }
+ const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *outputs() const {
+ return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *>(VT_OUTPUTS);
+ }
+ QuantInfo quant_info_type() const {
+ return static_cast<QuantInfo>(GetField<uint8_t>(VT_QUANT_INFO_TYPE, 0));
+ }
+ const void *quant_info() const {
+ return GetPointer<const void *>(VT_QUANT_INFO);
+ }
+ template<typename T> const T *quant_info_as() const;
+ const UnaryQuantInfo *quant_info_as_UnaryQuantInfo() const {
+ return quant_info_type() == QuantInfo_UnaryQuantInfo ? static_cast<const UnaryQuantInfo *>(quant_info()) : nullptr;
+ }
+ const ConvQuantInfo *quant_info_as_ConvQuantInfo() const {
+ return quant_info_type() == QuantInfo_ConvQuantInfo ? static_cast<const ConvQuantInfo *>(quant_info()) : nullptr;
+ }
+ const MatMulQuantInfo *quant_info_as_MatMulQuantInfo() const {
+ return quant_info_type() == QuantInfo_MatMulQuantInfo ? static_cast<const MatMulQuantInfo *>(quant_info()) : nullptr;
+ }
+ const PadQuantInfo *quant_info_as_PadQuantInfo() const {
+ return quant_info_type() == QuantInfo_PadQuantInfo ? static_cast<const PadQuantInfo *>(quant_info()) : nullptr;
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyField<uint32_t>(verifier, VT_OP) &&
+ VerifyField<uint8_t>(verifier, VT_ATTRIBUTE_TYPE) &&
+ VerifyOffset(verifier, VT_ATTRIBUTE) &&
+ VerifyAttribute(verifier, attribute(), attribute_type()) &&
+ VerifyOffset(verifier, VT_INPUTS) &&
+ verifier.VerifyVector(inputs()) &&
+ verifier.VerifyVectorOfStrings(inputs()) &&
+ VerifyOffset(verifier, VT_OUTPUTS) &&
+ verifier.VerifyVector(outputs()) &&
+ verifier.VerifyVectorOfStrings(outputs()) &&
+ VerifyField<uint8_t>(verifier, VT_QUANT_INFO_TYPE) &&
+ VerifyOffset(verifier, VT_QUANT_INFO) &&
+ VerifyQuantInfo(verifier, quant_info(), quant_info_type()) &&
+ verifier.EndTable();
+ }
+};
+
+template<> inline const Pool2dAttribute *TosaOperator::attribute_as<Pool2dAttribute>() const {
+ return attribute_as_Pool2dAttribute();
+}
+
+template<> inline const Conv2dAttribute *TosaOperator::attribute_as<Conv2dAttribute>() const {
+ return attribute_as_Conv2dAttribute();
+}
+
+template<> inline const TransposeConv2dAttribute *TosaOperator::attribute_as<TransposeConv2dAttribute>() const {
+ return attribute_as_TransposeConv2dAttribute();
+}
+
+template<> inline const ReluNAttribute *TosaOperator::attribute_as<ReluNAttribute>() const {
+ return attribute_as_ReluNAttribute();
+}
+
+template<> inline const AxisAttribute *TosaOperator::attribute_as<AxisAttribute>() const {
+ return attribute_as_AxisAttribute();
+}
+
+template<> inline const ReshapeAttribute *TosaOperator::attribute_as<ReshapeAttribute>() const {
+ return attribute_as_ReshapeAttribute();
+}
+
+template<> inline const SliceAttribute *TosaOperator::attribute_as<SliceAttribute>() const {
+ return attribute_as_SliceAttribute();
+}
+
+template<> inline const TileAttribute *TosaOperator::attribute_as<TileAttribute>() const {
+ return attribute_as_TileAttribute();
+}
+
+template<> inline const ResizeAttribute *TosaOperator::attribute_as<ResizeAttribute>() const {
+ return attribute_as_ResizeAttribute();
+}
+
+template<> inline const ClampAttribute *TosaOperator::attribute_as<ClampAttribute>() const {
+ return attribute_as_ClampAttribute();
+}
+
+template<> inline const RescaleAttribute *TosaOperator::attribute_as<RescaleAttribute>() const {
+ return attribute_as_RescaleAttribute();
+}
+
+template<> inline const CustomAttribute *TosaOperator::attribute_as<CustomAttribute>() const {
+ return attribute_as_CustomAttribute();
+}
+
+template<> inline const CondIfAttribute *TosaOperator::attribute_as<CondIfAttribute>() const {
+ return attribute_as_CondIfAttribute();
+}
+
+template<> inline const WhileLoopAttribute *TosaOperator::attribute_as<WhileLoopAttribute>() const {
+ return attribute_as_WhileLoopAttribute();
+}
+
+template<> inline const UnaryQuantInfo *TosaOperator::quant_info_as<UnaryQuantInfo>() const {
+ return quant_info_as_UnaryQuantInfo();
+}
+
+template<> inline const ConvQuantInfo *TosaOperator::quant_info_as<ConvQuantInfo>() const {
+ return quant_info_as_ConvQuantInfo();
+}
+
+template<> inline const MatMulQuantInfo *TosaOperator::quant_info_as<MatMulQuantInfo>() const {
+ return quant_info_as_MatMulQuantInfo();
+}
+
+template<> inline const PadQuantInfo *TosaOperator::quant_info_as<PadQuantInfo>() const {
+ return quant_info_as_PadQuantInfo();
+}
+
+struct TosaOperatorBuilder {
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_op(Op op) {
+ fbb_.AddElement<uint32_t>(TosaOperator::VT_OP, static_cast<uint32_t>(op), 0);
+ }
+ void add_attribute_type(Attribute attribute_type) {
+ fbb_.AddElement<uint8_t>(TosaOperator::VT_ATTRIBUTE_TYPE, static_cast<uint8_t>(attribute_type), 0);
+ }
+ void add_attribute(flatbuffers::Offset<void> attribute) {
+ fbb_.AddOffset(TosaOperator::VT_ATTRIBUTE, attribute);
+ }
+ void add_inputs(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> inputs) {
+ fbb_.AddOffset(TosaOperator::VT_INPUTS, inputs);
+ }
+ void add_outputs(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> outputs) {
+ fbb_.AddOffset(TosaOperator::VT_OUTPUTS, outputs);
+ }
+ void add_quant_info_type(QuantInfo quant_info_type) {
+ fbb_.AddElement<uint8_t>(TosaOperator::VT_QUANT_INFO_TYPE, static_cast<uint8_t>(quant_info_type), 0);
+ }
+ void add_quant_info(flatbuffers::Offset<void> quant_info) {
+ fbb_.AddOffset(TosaOperator::VT_QUANT_INFO, quant_info);
+ }
+ explicit TosaOperatorBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ TosaOperatorBuilder &operator=(const TosaOperatorBuilder &);
+ flatbuffers::Offset<TosaOperator> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<TosaOperator>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<TosaOperator> CreateTosaOperator(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ Op op = Op_UNKNOWN,
+ Attribute attribute_type = Attribute_NONE,
+ flatbuffers::Offset<void> attribute = 0,
+ flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> inputs = 0,
+ flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> outputs = 0,
+ QuantInfo quant_info_type = QuantInfo_NONE,
+ flatbuffers::Offset<void> quant_info = 0) {
+ TosaOperatorBuilder builder_(_fbb);
+ builder_.add_quant_info(quant_info);
+ builder_.add_outputs(outputs);
+ builder_.add_inputs(inputs);
+ builder_.add_attribute(attribute);
+ builder_.add_op(op);
+ builder_.add_quant_info_type(quant_info_type);
+ builder_.add_attribute_type(attribute_type);
+ return builder_.Finish();
+}
+
+inline flatbuffers::Offset<TosaOperator> CreateTosaOperatorDirect(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ Op op = Op_UNKNOWN,
+ Attribute attribute_type = Attribute_NONE,
+ flatbuffers::Offset<void> attribute = 0,
+ const std::vector<flatbuffers::Offset<flatbuffers::String>> *inputs = nullptr,
+ const std::vector<flatbuffers::Offset<flatbuffers::String>> *outputs = nullptr,
+ QuantInfo quant_info_type = QuantInfo_NONE,
+ flatbuffers::Offset<void> quant_info = 0) {
+ auto inputs__ = inputs ? _fbb.CreateVector<flatbuffers::Offset<flatbuffers::String>>(*inputs) : 0;
+ auto outputs__ = outputs ? _fbb.CreateVector<flatbuffers::Offset<flatbuffers::String>>(*outputs) : 0;
+ return tosa::CreateTosaOperator(
+ _fbb,
+ op,
+ attribute_type,
+ attribute,
+ inputs__,
+ outputs__,
+ quant_info_type,
+ quant_info);
+}
+
+struct TosaBasicBlock FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_NAME = 4,
+ VT_OPERATORS = 6,
+ VT_TENSORS = 8,
+ VT_INPUTS = 10,
+ VT_OUTPUTS = 12
+ };
+ const flatbuffers::String *name() const {
+ return GetPointer<const flatbuffers::String *>(VT_NAME);
+ }
+ const flatbuffers::Vector<flatbuffers::Offset<TosaOperator>> *operators() const {
+ return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<TosaOperator>> *>(VT_OPERATORS);
+ }
+ const flatbuffers::Vector<flatbuffers::Offset<TosaTensor>> *tensors() const {
+ return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<TosaTensor>> *>(VT_TENSORS);
+ }
+ const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *inputs() const {
+ return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *>(VT_INPUTS);
+ }
+ const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *outputs() const {
+ return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *>(VT_OUTPUTS);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyOffset(verifier, VT_NAME) &&
+ verifier.VerifyString(name()) &&
+ VerifyOffset(verifier, VT_OPERATORS) &&
+ verifier.VerifyVector(operators()) &&
+ verifier.VerifyVectorOfTables(operators()) &&
+ VerifyOffset(verifier, VT_TENSORS) &&
+ verifier.VerifyVector(tensors()) &&
+ verifier.VerifyVectorOfTables(tensors()) &&
+ VerifyOffset(verifier, VT_INPUTS) &&
+ verifier.VerifyVector(inputs()) &&
+ verifier.VerifyVectorOfStrings(inputs()) &&
+ VerifyOffset(verifier, VT_OUTPUTS) &&
+ verifier.VerifyVector(outputs()) &&
+ verifier.VerifyVectorOfStrings(outputs()) &&
+ verifier.EndTable();
+ }
+};
+
+struct TosaBasicBlockBuilder {
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_name(flatbuffers::Offset<flatbuffers::String> name) {
+ fbb_.AddOffset(TosaBasicBlock::VT_NAME, name);
+ }
+ void add_operators(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<TosaOperator>>> operators) {
+ fbb_.AddOffset(TosaBasicBlock::VT_OPERATORS, operators);
+ }
+ void add_tensors(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<TosaTensor>>> tensors) {
+ fbb_.AddOffset(TosaBasicBlock::VT_TENSORS, tensors);
+ }
+ void add_inputs(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> inputs) {
+ fbb_.AddOffset(TosaBasicBlock::VT_INPUTS, inputs);
+ }
+ void add_outputs(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> outputs) {
+ fbb_.AddOffset(TosaBasicBlock::VT_OUTPUTS, outputs);
+ }
+ explicit TosaBasicBlockBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ TosaBasicBlockBuilder &operator=(const TosaBasicBlockBuilder &);
+ flatbuffers::Offset<TosaBasicBlock> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<TosaBasicBlock>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<TosaBasicBlock> CreateTosaBasicBlock(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ flatbuffers::Offset<flatbuffers::String> name = 0,
+ flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<TosaOperator>>> operators = 0,
+ flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<TosaTensor>>> tensors = 0,
+ flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> inputs = 0,
+ flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> outputs = 0) {
+ TosaBasicBlockBuilder builder_(_fbb);
+ builder_.add_outputs(outputs);
+ builder_.add_inputs(inputs);
+ builder_.add_tensors(tensors);
+ builder_.add_operators(operators);
+ builder_.add_name(name);
+ return builder_.Finish();
+}
+
+inline flatbuffers::Offset<TosaBasicBlock> CreateTosaBasicBlockDirect(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ const char *name = nullptr,
+ const std::vector<flatbuffers::Offset<TosaOperator>> *operators = nullptr,
+ const std::vector<flatbuffers::Offset<TosaTensor>> *tensors = nullptr,
+ const std::vector<flatbuffers::Offset<flatbuffers::String>> *inputs = nullptr,
+ const std::vector<flatbuffers::Offset<flatbuffers::String>> *outputs = nullptr) {
+ auto name__ = name ? _fbb.CreateString(name) : 0;
+ auto operators__ = operators ? _fbb.CreateVector<flatbuffers::Offset<TosaOperator>>(*operators) : 0;
+ auto tensors__ = tensors ? _fbb.CreateVector<flatbuffers::Offset<TosaTensor>>(*tensors) : 0;
+ auto inputs__ = inputs ? _fbb.CreateVector<flatbuffers::Offset<flatbuffers::String>>(*inputs) : 0;
+ auto outputs__ = outputs ? _fbb.CreateVector<flatbuffers::Offset<flatbuffers::String>>(*outputs) : 0;
+ return tosa::CreateTosaBasicBlock(
+ _fbb,
+ name__,
+ operators__,
+ tensors__,
+ inputs__,
+ outputs__);
+}
+
+struct TosaGraph FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_VERSION = 4,
+ VT_BLOCKS = 6
+ };
+ const Version *version() const {
+ return GetPointer<const Version *>(VT_VERSION);
+ }
+ const flatbuffers::Vector<flatbuffers::Offset<TosaBasicBlock>> *blocks() const {
+ return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<TosaBasicBlock>> *>(VT_BLOCKS);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyOffset(verifier, VT_VERSION) &&
+ verifier.VerifyTable(version()) &&
+ VerifyOffset(verifier, VT_BLOCKS) &&
+ verifier.VerifyVector(blocks()) &&
+ verifier.VerifyVectorOfTables(blocks()) &&
+ verifier.EndTable();
+ }
+};
+
+struct TosaGraphBuilder {
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_version(flatbuffers::Offset<Version> version) {
+ fbb_.AddOffset(TosaGraph::VT_VERSION, version);
+ }
+ void add_blocks(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<TosaBasicBlock>>> blocks) {
+ fbb_.AddOffset(TosaGraph::VT_BLOCKS, blocks);
+ }
+ explicit TosaGraphBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ TosaGraphBuilder &operator=(const TosaGraphBuilder &);
+ flatbuffers::Offset<TosaGraph> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<TosaGraph>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<TosaGraph> CreateTosaGraph(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ flatbuffers::Offset<Version> version = 0,
+ flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<TosaBasicBlock>>> blocks = 0) {
+ TosaGraphBuilder builder_(_fbb);
+ builder_.add_blocks(blocks);
+ builder_.add_version(version);
+ return builder_.Finish();
+}
+
+inline flatbuffers::Offset<TosaGraph> CreateTosaGraphDirect(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ flatbuffers::Offset<Version> version = 0,
+ const std::vector<flatbuffers::Offset<TosaBasicBlock>> *blocks = nullptr) {
+ auto blocks__ = blocks ? _fbb.CreateVector<flatbuffers::Offset<TosaBasicBlock>>(*blocks) : 0;
+ return tosa::CreateTosaGraph(
+ _fbb,
+ version,
+ blocks__);
+}
+
+inline bool VerifyAttribute(flatbuffers::Verifier &verifier, const void *obj, Attribute type) {
+ switch (type) {
+ case Attribute_NONE: {
+ return true;
+ }
+ case Attribute_Pool2dAttribute: {
+ auto ptr = reinterpret_cast<const Pool2dAttribute *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case Attribute_Conv2dAttribute: {
+ auto ptr = reinterpret_cast<const Conv2dAttribute *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case Attribute_TransposeConv2dAttribute: {
+ auto ptr = reinterpret_cast<const TransposeConv2dAttribute *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case Attribute_ReluNAttribute: {
+ auto ptr = reinterpret_cast<const ReluNAttribute *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case Attribute_AxisAttribute: {
+ auto ptr = reinterpret_cast<const AxisAttribute *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case Attribute_ReshapeAttribute: {
+ auto ptr = reinterpret_cast<const ReshapeAttribute *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case Attribute_SliceAttribute: {
+ auto ptr = reinterpret_cast<const SliceAttribute *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case Attribute_TileAttribute: {
+ auto ptr = reinterpret_cast<const TileAttribute *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case Attribute_ResizeAttribute: {
+ auto ptr = reinterpret_cast<const ResizeAttribute *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case Attribute_ClampAttribute: {
+ auto ptr = reinterpret_cast<const ClampAttribute *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case Attribute_RescaleAttribute: {
+ auto ptr = reinterpret_cast<const RescaleAttribute *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case Attribute_CustomAttribute: {
+ auto ptr = reinterpret_cast<const CustomAttribute *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case Attribute_CondIfAttribute: {
+ auto ptr = reinterpret_cast<const CondIfAttribute *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case Attribute_WhileLoopAttribute: {
+ auto ptr = reinterpret_cast<const WhileLoopAttribute *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ default: return false;
+ }
+}
+
+inline bool VerifyAttributeVector(flatbuffers::Verifier &verifier, const flatbuffers::Vector<flatbuffers::Offset<void>> *values, const flatbuffers::Vector<uint8_t> *types) {
+ if (!values || !types) return !values && !types;
+ if (values->size() != types->size()) return false;
+ for (flatbuffers::uoffset_t i = 0; i < values->size(); ++i) {
+ if (!VerifyAttribute(
+ verifier, values->Get(i), types->GetEnum<Attribute>(i))) {
+ return false;
+ }
+ }
+ return true;
+}
+
+inline bool VerifyQuantInfo(flatbuffers::Verifier &verifier, const void *obj, QuantInfo type) {
+ switch (type) {
+ case QuantInfo_NONE: {
+ return true;
+ }
+ case QuantInfo_UnaryQuantInfo: {
+ auto ptr = reinterpret_cast<const UnaryQuantInfo *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case QuantInfo_ConvQuantInfo: {
+ auto ptr = reinterpret_cast<const ConvQuantInfo *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case QuantInfo_MatMulQuantInfo: {
+ auto ptr = reinterpret_cast<const MatMulQuantInfo *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case QuantInfo_PadQuantInfo: {
+ auto ptr = reinterpret_cast<const PadQuantInfo *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ default: return false;
+ }
+}
+
+inline bool VerifyQuantInfoVector(flatbuffers::Verifier &verifier, const flatbuffers::Vector<flatbuffers::Offset<void>> *values, const flatbuffers::Vector<uint8_t> *types) {
+ if (!values || !types) return !values && !types;
+ if (values->size() != types->size()) return false;
+ for (flatbuffers::uoffset_t i = 0; i < values->size(); ++i) {
+ if (!VerifyQuantInfo(
+ verifier, values->Get(i), types->GetEnum<QuantInfo>(i))) {
+ return false;
+ }
+ }
+ return true;
+}
+
+inline const tosa::TosaGraph *GetTosaGraph(const void *buf) {
+ return flatbuffers::GetRoot<tosa::TosaGraph>(buf);
+}
+
+inline const tosa::TosaGraph *GetSizePrefixedTosaGraph(const void *buf) {
+ return flatbuffers::GetSizePrefixedRoot<tosa::TosaGraph>(buf);
+}
+
+inline const char *TosaGraphIdentifier() {
+ return "TOSA";
+}
+
+inline bool TosaGraphBufferHasIdentifier(const void *buf) {
+ return flatbuffers::BufferHasIdentifier(
+ buf, TosaGraphIdentifier());
+}
+
+inline bool VerifyTosaGraphBuffer(
+ flatbuffers::Verifier &verifier) {
+ return verifier.VerifyBuffer<tosa::TosaGraph>(TosaGraphIdentifier());
+}
+
+inline bool VerifySizePrefixedTosaGraphBuffer(
+ flatbuffers::Verifier &verifier) {
+ return verifier.VerifySizePrefixedBuffer<tosa::TosaGraph>(TosaGraphIdentifier());
+}
+
+inline const char *TosaGraphExtension() {
+ return "tosa";
+}
+
+inline void FinishTosaGraphBuffer(
+ flatbuffers::FlatBufferBuilder &fbb,
+ flatbuffers::Offset<tosa::TosaGraph> root) {
+ fbb.Finish(root, TosaGraphIdentifier());
+}
+
+inline void FinishSizePrefixedTosaGraphBuffer(
+ flatbuffers::FlatBufferBuilder &fbb,
+ flatbuffers::Offset<tosa::TosaGraph> root) {
+ fbb.FinishSizePrefixed(root, TosaGraphIdentifier());
+}
+
+} // namespace tosa
+
+#endif // FLATBUFFERS_GENERATED_TOSA_TOSA_H_
diff --git a/serialization/tosa_serialization_handler.cpp b/serialization/tosa_serialization_handler.cpp
new file mode 100644
index 0000000..7fe9f47
--- /dev/null
+++ b/serialization/tosa_serialization_handler.cpp
@@ -0,0 +1,1526 @@
+
+// Copyright (c) 2020, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "tosa_serialization_handler.h"
+
+#include <iostream>
+using namespace tosa;
+
+TosaSerializationTensor::TosaSerializationTensor(const flatbuffers::String* name,
+ const flatbuffers::Vector<uint32_t>& usage,
+ const flatbuffers::Vector<int32_t>& shape,
+ DType dtype,
+ const flatbuffers::Vector<uint32_t>& format,
+ const flatbuffers::String* npy_filename)
+{
+ _dtype = dtype;
+
+    // construct empty and reserve(), so that push_back() does not append after
+    // size() default-constructed entries
+    _usage = new std::vector<Usage>();
+    _usage->reserve(usage.size());
+    for (uint32_t us : usage)
+    {
+        _usage->push_back((Usage)us);
+    }
+    assert(_usage);
+
+    _format = new std::vector<Format>();
+    _format->reserve(format.size());
+    for (uint32_t fm : format)
+    {
+        _format->push_back((Format)fm);
+    }
+    assert(_format);
+
+    _shape = new std::vector<int32_t>(shape.begin(), shape.end());
+    assert(_shape);
+
+ assert(name);
+ _name = new std::string(name->str());
+ assert(_name);
+
+ if (npy_filename)
+ {
+ _npy_filename = new std::string(npy_filename->str());
+ assert(_npy_filename);
+ }
+ else
+ {
+ _npy_filename = nullptr;
+ }
+}
+
+TosaSerializationTensor::TosaSerializationTensor(std::string name,
+ const std::vector<Usage>& usage,
+ const std::vector<int32_t>& shape,
+ DType dtype,
+ const std::vector<Format>& format,
+ const std::string* npy_filename)
+{
+
+ _dtype = dtype;
+
+ _usage = new std::vector<Usage>(usage);
+ assert(_usage);
+
+ _format = new std::vector<Format>(format);
+ assert(_format);
+
+ _shape = new std::vector<int32_t>(shape);
+ assert(_shape);
+
+ _name = new std::string(name);
+ assert(_name);
+
+ if (npy_filename)
+ {
+ _npy_filename = new std::string(*npy_filename);
+ assert(_npy_filename);
+ }
+ else
+ {
+ _npy_filename = nullptr;
+ }
+}
+
+TosaSerializationTensor::TosaSerializationTensor()
+{
+ _dtype = DType_UNKNOWN;
+
+ _usage = new std::vector<Usage>();
+ _format = new std::vector<Format>();
+ _shape = new std::vector<int32_t>();
+ _name = new std::string("UNKNOWN");
+ assert(_usage && _format && _shape && _name);
+
+ _npy_filename = nullptr;
+}
+
+TosaSerializationTensor::TosaSerializationTensor(const TosaSerializationTensor& rhs)
+{
+ _dtype = rhs._dtype;
+
+ assert(rhs._usage);
+ _usage = new std::vector<Usage>(*rhs._usage);
+ assert(_usage);
+
+ assert(rhs._format);
+ _format = new std::vector<Format>(*rhs._format);
+ assert(_format);
+
+ assert(rhs._shape);
+ _shape = new std::vector<int32_t>(*rhs._shape);
+ assert(_shape);
+
+ assert(rhs._name);
+ _name = new std::string(*rhs._name);
+ assert(_name);
+
+ if (rhs._npy_filename)
+ {
+ _npy_filename = new std::string(*rhs._npy_filename);
+ assert(_npy_filename);
+ }
+ else
+ {
+ _npy_filename = nullptr;
+ }
+}
+
+TosaSerializationTensor& TosaSerializationTensor::operator=(const TosaSerializationTensor& rhs)
+{
+    // guard against self-assignment: the deletes below would otherwise free rhs's data
+    if (this == &rhs)
+        return *this;
+
+    _dtype = rhs._dtype;
+
+ delete _usage;
+ assert(rhs._usage);
+ _usage = new std::vector<Usage>(*rhs._usage);
+ assert(_usage);
+
+ delete _format;
+ assert(rhs._format);
+ _format = new std::vector<Format>(*rhs._format);
+ assert(_format);
+
+ delete _shape;
+ assert(rhs._shape);
+ _shape = new std::vector<int32_t>(*rhs._shape);
+ assert(_shape);
+
+ delete _name;
+ assert(rhs._name);
+ _name = new std::string(*rhs._name);
+ assert(_name);
+
+ if (_npy_filename)
+ delete _npy_filename;
+
+ if (rhs._npy_filename)
+ {
+ _npy_filename = new std::string(*rhs._npy_filename);
+ }
+ else
+ {
+ _npy_filename = nullptr;
+ }
+ return *this;
+}
+
+TosaSerializationTensor::TosaSerializationTensor(TosaSerializationTensor&& rhs)
+{
+    // Start from null pointers so the swaps hand rhs valid (deletable) pointers;
+    // rhs is destroyed with whatever it receives here.
+    _usage        = nullptr;
+    _format       = nullptr;
+    _shape        = nullptr;
+    _name         = nullptr;
+    _npy_filename = nullptr;
+
+    _dtype = rhs._dtype;
+ std::swap(_format, rhs._format);
+ std::swap(_usage, rhs._usage);
+ std::swap(_shape, rhs._shape);
+ std::swap(_name, rhs._name);
+ std::swap(_npy_filename, rhs._npy_filename);
+}
+
+TosaSerializationTensor& TosaSerializationTensor::operator=(TosaSerializationTensor&& rhs)
+{
+ _dtype = rhs._dtype;
+ std::swap(_format, rhs._format);
+ std::swap(_usage, rhs._usage);
+ std::swap(_shape, rhs._shape);
+ std::swap(_name, rhs._name);
+ std::swap(_npy_filename, rhs._npy_filename);
+ return *this;
+}
+
+TosaSerializationTensor::~TosaSerializationTensor()
+{
+ delete _usage;
+ delete _format;
+ delete _shape;
+ delete _name;
+ if (_npy_filename)
+ delete _npy_filename;
+}
+
+TosaSerializationOperator::TosaSerializationOperator(Op op,
+ Attribute attribute_type,
+ const TosaAttributeBase* attribute,
+ QuantInfo qinfo_type,
+ const TosaQuantInfoBase* qinfo,
+ std::vector<std::string> input_tensor_names,
+ std::vector<std::string> output_tensor_names)
+{
+ _op = op;
+ _attribute_type = attribute_type;
+
+ switch (attribute_type)
+ {
+ case Attribute_NONE:
+ _attribute = new TosaNoneAttribute();
+ break;
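+// X-macro: each DEF_ATTRIBUTE entry in attribute.def expands to a case below that
+// constructs the matching Tosa<Name>Attribute wrapper from the generic attribute pointer.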
+#define DEF_ATTRIBUTE(NAME, ...) \
+ case Attribute_##NAME##Attribute: \
+ _attribute = new Tosa##NAME##Attribute(attribute); \
+ break;
+#include "attribute.def"
+#undef DEF_ATTRIBUTE
+ default:
+ printf("TosaSerializationOperator::TosaSerializationOperator(): Attribute %s not implemented yet\n",
+ EnumNamesAttribute()[attribute_type]);
+ assert(0);
+ }
+
+ _qinfo_type = qinfo_type;
+ switch (qinfo_type)
+ {
+ case QuantInfo_NONE:
+ _qinfo = new TosaNoneQuantInfo();
+ break;
+#define DEF_QUANTIZATION_INFO(NAME, ...) \
+ case QuantInfo_##NAME##QuantInfo: \
+ _qinfo = new Tosa##NAME##QuantInfo(qinfo); \
+ break;
+#include "quant_info.def"
+#undef DEF_QUANTIZATION_INFO
+ default:
+ printf("TosaSerializationOperator::TosaSerializationOperator(): QuantInfo %s not implemented yet\n",
+ EnumNamesQuantInfo()[qinfo_type]);
+ assert(0);
+ }
+
+ assert(_attribute && _qinfo);
+
+ _input_tensor_names = new std::vector<std::string>(input_tensor_names);
+ _output_tensor_names = new std::vector<std::string>(output_tensor_names);
+
+ assert(_input_tensor_names && _output_tensor_names);
+
+ _input_tensors = new std::vector<TosaSerializationTensor*>();
+ _output_tensors = new std::vector<TosaSerializationTensor*>();
+
+ assert(_input_tensors && _output_tensors);
+}
+
+TosaSerializationOperator::~TosaSerializationOperator()
+{
+ delete _attribute;
+ delete _qinfo;
+ delete _input_tensor_names;
+ delete _output_tensor_names;
+    // TosaSerializationTensor objects are not owned here; they are freed in the TosaSerializationBasicBlock destructor
+ delete _input_tensors;
+ delete _output_tensors;
+}
+
+TosaSerializationBasicBlock::TosaSerializationBasicBlock(std::string name,
+ std::vector<TosaSerializationOperator*> operators,
+ std::vector<TosaSerializationTensor*> tensors,
+ std::vector<std::string> inputs,
+ std::vector<std::string> outputs)
+{
+
+ _name = new std::string(name);
+ assert(_name);
+
+ _operators = new std::vector<TosaSerializationOperator*>(operators);
+ assert(_operators);
+
+ _tensors = new std::vector<TosaSerializationTensor*>(tensors);
+ assert(_tensors);
+
+ _inputs = new std::vector<std::string>(inputs);
+ assert(_inputs);
+
+ _outputs = new std::vector<std::string>(outputs);
+ assert(_outputs);
+}
+
+TosaSerializationBasicBlock::~TosaSerializationBasicBlock()
+{
+ delete _name;
+
+ // deallocate all operators
+ for (auto op : GetOperators())
+ {
+ delete op; // ~TosaSerializationOperator()
+ }
+ delete _operators;
+
+ // deallocate all tensors
+ for (auto ts : GetTensors())
+ {
+ delete ts; // ~TosaSerializationTensor()
+ }
+    delete _tensors;
+
+ delete _inputs;
+ delete _outputs;
+}
+
+TosaSerializationHandler::TosaSerializationHandler()
+{
+ _schemaLoaded = false;
+ _builder = new flatbuffers::FlatBufferBuilder();
+ _parser = new flatbuffers::Parser();
+ _blocks = new std::vector<TosaSerializationBasicBlock*>();
+
+ assert(_builder && _parser && _blocks);
+
+ SetTosaVersion();
+}
+
+TosaSerializationHandler::~TosaSerializationHandler()
+{
+ if (_version)
+ delete _version;
+ delete _builder;
+ delete _parser;
+
+ Clear(); // deallocate all basic blocks
+
+ delete _blocks;
+}
+
+tosa_err_t TosaSerializationHandler::SetTosaVersion()
+{
+    // The TOSA version is specified within the .fbs schema and is encoded as the
+    // default field values of CreateVersion(). Serialize one Version object and
+    // read it back to recover those defaults.
+    // TODO: this is costly; is there a better way to encode constants in the .fbs?
+ auto fboffset_version = CreateVersion(*_builder);
+ auto fboffset_tosa_graph = CreateTosaGraphDirect(*_builder, fboffset_version, nullptr);
+ _builder->Finish(fboffset_tosa_graph);
+ uint8_t* buf = _builder->GetBufferPointer();
+ auto fb_tosa_graph = GetTosaGraph(buf);
+ auto fb_tosa_version = fb_tosa_graph->version();
+
+ _version = new TosaVersion(fb_tosa_version->_major(), fb_tosa_version->_minor(), fb_tosa_version->_patch(),
+ fb_tosa_version->_experimental());
+
+ assert(_version);
+ return TOSA_OK;
+}
+
+tosa_err_t TosaSerializationHandler::LoadFileSchema(const char* schema_filename)
+{
+ std::string schema;
+ bool ok;
+
+ ok = flatbuffers::LoadFile(schema_filename, false, &schema);
+ if (!ok)
+ {
+ printf("Error loading schema file: %s\n", schema_filename);
+ return TOSA_FILE_ERROR;
+ }
+
+ ok = _parser->Parse(schema.c_str());
+ if (!ok)
+ {
+        printf("Error parsing TOSA schema file: %s\n", schema_filename);
+ return TOSA_FILE_ERROR;
+ }
+ _schemaLoaded = true;
+
+ return TOSA_OK;
+}
+
+tosa_err_t TosaSerializationHandler::LoadFileJson(const char* filename)
+{
+ std::string jsonfile;
+ bool ok;
+ tosa_err_t err;
+
+ if (!_schemaLoaded)
+ {
+ return TOSA_SCHEMA_MISSING;
+ }
+
+ ok = flatbuffers::LoadFile(filename, false, &jsonfile);
+ if (!ok)
+ {
+ printf("Error loading json file: %s\n", filename);
+ return TOSA_FILE_ERROR;
+ }
+
+ ok = _parser->Parse(jsonfile.c_str());
+ if (!ok)
+ {
+ printf("Error parsing json file: %s\n", filename);
+ return TOSA_FILE_ERROR;
+ }
+
+ uint8_t* buf = _parser->builder_.GetBufferPointer();
+
+ err = InitWithBuf(buf);
+ if (err != TOSA_OK)
+ {
+ return err;
+ }
+
+ return TOSA_OK;
+}
+
+tosa_err_t TosaSerializationHandler::SaveFileJson(const char* filename)
+{
+ std::string jsongen;
+ tosa_err_t err;
+
+ if (!_schemaLoaded)
+ {
+ return TOSA_SCHEMA_MISSING;
+ }
+
+ err = FreezeBuilder();
+ if (err != TOSA_OK)
+ {
+ return err;
+ }
+
+ uint8_t* buf = _builder->GetBufferPointer();
+
+ if (!GenerateText(*_parser, buf, &jsongen))
+ {
+ printf("Couldn't serialize parsed data to JSON!\n");
+ return TOSA_FILE_ERROR;
+ }
+
+ FILE* file = fopen(filename, "wb");
+
+ if (!file)
+ {
+ printf("Couldn't open output file: %s\n", filename);
+ return TOSA_FILE_ERROR;
+ }
+
+ if (fwrite(jsongen.c_str(), sizeof(char), jsongen.size(), file) != jsongen.size())
+ {
+ printf("Error writing to json output file: %s\n", filename);
+ fclose(file);
+ return TOSA_FILE_ERROR;
+ }
+
+    fclose(file);
+
+ return TOSA_OK;
+}
+
+tosa_err_t TosaSerializationHandler::LoadFileTosaFlatbuffer(const char* filename)
+{
+ std::string read_buffer;
+ tosa_err_t err;
+ uint8_t* buf;
+ bool ok;
+
+ ok = flatbuffers::LoadFile(filename, false, &read_buffer);
+ if (!ok)
+ {
+ printf("Error loading flatbuffer file: %s\n", filename);
+ return TOSA_FILE_ERROR;
+ }
+
+ buf = (uint8_t*)read_buffer.data();
+
+ err = InitWithBuf(buf);
+ if (err != TOSA_OK)
+ {
+ return err;
+ }
+
+ return TOSA_OK;
+}
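+
+// Minimal usage sketch (illustrative only; the file names here are placeholders,
+// not part of this API):
+//   TosaSerializationHandler handler;
+//   if (handler.LoadFileSchema("tosa.fbs") == TOSA_OK &&
+//       handler.LoadFileTosaFlatbuffer("graph.tosa") == TOSA_OK)
+//   {
+//       handler.SaveFileJson("graph.json");   // JSON output requires the schema
+//   }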
+
+tosa_err_t TosaSerializationHandler::SaveFileTosaFlatbuffer(const char* filename)
+{
+ tosa_err_t err;
+
+ err = FreezeBuilder();
+ if (err != TOSA_OK)
+ {
+ return err;
+ }
+
+ uint8_t* buf = _builder->GetBufferPointer();
+
+ bool ok = flatbuffers::SaveFile(filename, (const char*)buf, _builder->GetSize(), false);
+ if (!ok)
+ {
+        printf("Error saving flatbuffer file: %s\n", filename);
+ return TOSA_FILE_ERROR;
+ }
+
+ return TOSA_OK;
+}
+
+tosa_err_t TosaSerializationHandler::Clear()
+{
+ // deallocate all basic blocks
+ for (auto bb : GetBlocks())
+ {
+ delete bb;
+ }
+ _blocks->clear();
+
+ return TOSA_OK;
+}
+
+tosa_err_t TosaSerializationHandler::CheckTosaVersion(const TosaVersion& read_version)
+{
+ if ((*_version) != read_version)
+ {
+ printf("WARNING: read tosa version: %s != schema tosa version %s\n", read_version.to_string().c_str(),
+ this->_version->to_string().c_str());
+ return TOSA_VERSION_MISMATCH;
+ }
+
+ return TOSA_OK;
+}
+
+tosa_err_t TosaSerializationHandler::InitWithBuf(const uint8_t* buf)
+{
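+    // Rebuild the in-memory block/operator/tensor hierarchy from a finished
+    // TOSA flatbuffer, after checking its version against the schema version.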
+ auto fb_tosa_graph = GetTosaGraph(buf);
+ auto fb_tosa_version = fb_tosa_graph->version();
+ auto fb_tosa_blocks = fb_tosa_graph->blocks();
+
+ std::vector<std::string> operator_inputs_container;
+ std::vector<std::string> operator_outputs_container;
+
+ std::vector<TosaSerializationOperator*> block_operators_container;
+ std::vector<TosaSerializationTensor*> block_tensors_container;
+ std::vector<std::string> block_inputs_container;
+ std::vector<std::string> block_outputs_container;
+
+ TosaAttributeBase* typed_attribute = NULL;
+ TosaQuantInfoBase* typed_qinfo = NULL;
+ TosaSerializationOperator* new_operator = NULL;
+ TosaSerializationBasicBlock* new_block = NULL;
+ TosaSerializationTensor* new_tensor = NULL;
+
+ // erase container
+ Clear();
+
+ TosaVersion read_version(fb_tosa_version->_major(), fb_tosa_version->_minor(), fb_tosa_version->_patch(),
+ fb_tosa_version->_experimental());
+ tosa_err_t err = CheckTosaVersion(read_version);
+
+ if (err != TOSA_OK)
+ return err;
+
+ for (size_t i = 0; i < fb_tosa_blocks->size(); i++)
+ {
+ auto curr_block = fb_tosa_blocks->Get(i);
+
+ auto block_name = curr_block->name()->str();
+
+ auto fb_tosa_operators = curr_block->operators();
+ block_operators_container.clear();
+ for (size_t j = 0; j < fb_tosa_operators->size(); j++)
+ {
+ auto curr_operator = fb_tosa_operators->Get(j);
+
+ auto operator_op = curr_operator->op();
+ auto attribute_type = curr_operator->attribute_type();
+ auto attribute = curr_operator->attribute();
+ auto operator_qinfo_type = curr_operator->quant_info_type();
+ auto operator_qinfo = curr_operator->quant_info();
+
+ // input tensors
+ auto operator_inputs = curr_operator->inputs();
+ operator_inputs_container.clear();
+ if (operator_inputs)
+ {
+ for (size_t k = 0; k < operator_inputs->size(); k++)
+ {
+ auto curr_input = operator_inputs->Get(k);
+ operator_inputs_container.push_back(curr_input->str());
+ }
+ }
+
+ // output tensors
+ auto operator_outputs = curr_operator->outputs();
+ operator_outputs_container.clear();
+ if (operator_outputs)
+ {
+ for (size_t k = 0; k < operator_outputs->size(); k++)
+ {
+ auto curr_output = operator_outputs->Get(k);
+ operator_outputs_container.push_back(curr_output->str());
+ }
+ }
+
+ switch (attribute_type)
+ {
+ case Attribute_NONE:
+ typed_attribute = new TosaNoneAttribute();
+ break;
+#define DEF_ATTRIBUTE(NAME, ...) \
+ case Attribute_##NAME##Attribute: \
+ typed_attribute = new Tosa##NAME##Attribute(attribute); \
+ break;
+#include "attribute.def"
+#undef DEF_ATTRIBUTE
+ default:
+ printf("TosaSerializationHandler::InitWithBuf(): Attribute %s not implemented yet\n",
+ EnumNamesAttribute()[attribute_type]);
+ return TOSA_INTERNAL_ERROR;
+ }
+
+ switch (operator_qinfo_type)
+ {
+ case QuantInfo_NONE:
+ typed_qinfo = new TosaNoneQuantInfo();
+ break;
+#define DEF_QUANTIZATION_INFO(NAME, ...) \
+ case QuantInfo_##NAME##QuantInfo: \
+ typed_qinfo = new Tosa##NAME##QuantInfo(operator_qinfo); \
+ break;
+
+#include "quant_info.def"
+#undef DEF_QUANTIZATION_INFO
+ default:
+ printf("TosaSerializationHandler::InitWithBuf(): QuantInfo %s not implemented yet\n",
+ EnumNamesQuantInfo()[operator_qinfo_type]);
+ return TOSA_INTERNAL_ERROR;
+ }
+
+ new_operator =
+ new TosaSerializationOperator(operator_op, attribute_type, typed_attribute, operator_qinfo_type,
+ typed_qinfo, operator_inputs_container, operator_outputs_container);
+ if (new_operator)
+ {
+ block_operators_container.push_back(new_operator);
+ }
+ else
+ {
+ return TOSA_MEMORY_ERROR;
+ }
+
+ if (typed_attribute)
+ delete typed_attribute;
+ if (typed_qinfo)
+ delete typed_qinfo;
+ }
+
+ auto fb_tosa_tensors = curr_block->tensors();
+ block_tensors_container.clear();
+ for (size_t j = 0; j < fb_tosa_tensors->size(); j++)
+ {
+ auto curr_tensor = fb_tosa_tensors->Get(j);
+
+ auto tensor_name = curr_tensor->name();
+ auto tensor_usage = curr_tensor->usage();
+ auto tensor_shape = curr_tensor->shape();
+ auto tensor_type = curr_tensor->type();
+ auto tensor_format = curr_tensor->format();
+ auto tensor_npy_filename = curr_tensor->npy_filename();
+
+ new_tensor = new TosaSerializationTensor(tensor_name, *tensor_usage, *tensor_shape, tensor_type,
+ *tensor_format, tensor_npy_filename);
+ if (new_tensor)
+ {
+ block_tensors_container.push_back(new_tensor);
+ }
+ else
+ {
+ return TOSA_MEMORY_ERROR;
+ }
+ }
+
+ auto block_inputs = curr_block->inputs();
+ auto block_outputs = curr_block->outputs();
+
+ block_inputs_container.clear();
+ block_outputs_container.clear();
+
+ for (size_t j = 0; j < block_inputs->size(); j++)
+ {
+ auto curr_block_input = block_inputs->Get(j);
+ block_inputs_container.push_back(curr_block_input->str());
+ }
+ for (size_t j = 0; j < block_outputs->size(); j++)
+ {
+ auto curr_block_output = block_outputs->Get(j);
+ block_outputs_container.push_back(curr_block_output->str());
+ }
+
+ new_block = new TosaSerializationBasicBlock(block_name, block_operators_container, block_tensors_container,
+ block_inputs_container, block_outputs_container);
+ if (new_block)
+ {
+ this->GetBlocks().push_back(new_block);
+ }
+ else
+ {
+ return TOSA_MEMORY_ERROR;
+ }
+ }
+
+ return TOSA_OK;
+}
+
+tosa_err_t TosaSerializationHandler::FreezeBuilder()
+{
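+    // Serialize the in-memory blocks, operators and tensors back into _builder,
+    // re-creating each one as a flatbuffer offset.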
+ std::vector<flatbuffers::Offset<TosaBasicBlock>> fboffset_blocks;
+
+ std::vector<flatbuffers::Offset<TosaOperator>> fboffset_block_operators;
+ std::vector<flatbuffers::Offset<TosaTensor>> fboffset_block_tensors;
+ std::vector<flatbuffers::Offset<flatbuffers::String>> fboffset_block_inputs;
+ std::vector<flatbuffers::Offset<flatbuffers::String>> fboffset_block_outputs;
+
+ std::vector<flatbuffers::Offset<flatbuffers::String>> fboffset_operator_inputs;
+ std::vector<flatbuffers::Offset<flatbuffers::String>> fboffset_operator_outputs;
+
+ // translate TosaFlatbufferOperator to flatbuffers::Offset<TosaOperator>
+ for (auto block : GetBlocks())
+ {
+ fboffset_block_operators.clear();
+ fboffset_block_tensors.clear();
+ fboffset_block_inputs.clear();
+ fboffset_block_outputs.clear();
+
+ auto block_name = _builder->CreateString(block->GetName().c_str());
+
+ for (auto tensor_str : block->GetInputs())
+ {
+ auto tensor_name = _builder->CreateString(tensor_str.c_str());
+ fboffset_block_inputs.push_back(tensor_name);
+ }
+
+ for (auto tensor_str : block->GetOutputs())
+ {
+ auto tensor_name = _builder->CreateString(tensor_str.c_str());
+ fboffset_block_outputs.push_back(tensor_name);
+ }
+
+ auto fb_block_inputs = _builder->CreateVector(fboffset_block_inputs);
+ auto fb_block_outputs = _builder->CreateVector(fboffset_block_outputs);
+
+ for (auto op : block->GetOperators())
+ {
+ fboffset_operator_inputs.clear();
+ fboffset_operator_outputs.clear();
+
+ auto operator_op = op->GetOp();
+ auto attribute_type = op->GetAttributeType();
+
+ for (auto tensor_str : op->GetInputTensorNames())
+ {
+ auto tensor_name = _builder->CreateString(tensor_str.c_str());
+ fboffset_operator_inputs.push_back(tensor_name);
+ }
+
+ for (auto tensor_str : op->GetOutputTensorNames())
+ {
+ auto tensor_name = _builder->CreateString(tensor_str.c_str());
+ fboffset_operator_outputs.push_back(tensor_name);
+ }
+
+ auto fb_operator_inputs = _builder->CreateVector(fboffset_operator_inputs);
+ auto fb_operator_outputs = _builder->CreateVector(fboffset_operator_outputs);
+
+ flatbuffers::Offset<void> fb_attribute;
+ switch (attribute_type)
+ {
+ case Attribute_NONE:
+ fb_attribute = 0;
+ break;
+
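+// The DEF_ARGS_* helpers below expand each attribute.def entry into the argument
+// list of the generated Create<Name>Attribute() call: the _S_ variants pass scalar
+// fields (strings are re-created on the builder), the _V_ variant serializes vectors.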
+#define DEF_ARGS_S_STR(NAME, V) , _builder->CreateString(reinterpret_cast<Tosa##NAME*>(op->GetAttribute())->V().c_str())
+#define DEF_ARGS_S_DEFAULT(NAME, V) , reinterpret_cast<Tosa##NAME*>(op->GetAttribute())->V()
+
+#define DEF_ARGS_S_int32_t(NAME, V) DEF_ARGS_S_DEFAULT(NAME, V)
+#define DEF_ARGS_S_float(NAME, V) DEF_ARGS_S_DEFAULT(NAME, V)
+#define DEF_ARGS_S_bool(NAME, V) DEF_ARGS_S_DEFAULT(NAME, V)
+#define DEF_ARGS_S_ResizeMode(NAME, V) DEF_ARGS_S_DEFAULT(NAME, V)
+#define DEF_ARGS_S_string(NAME, V) DEF_ARGS_S_STR(NAME, V)
+
+#define DEF_ARGS_S(NAME, T, V) DEF_ARGS_S_##T(NAME, V)
+#define DEF_ARGS_V(NAME, T, V) , _builder->CreateVector<T>(reinterpret_cast<Tosa##NAME*>(op->GetAttribute())->V())
+
+#define DEF_ARGS_1(NAME, T0, F0, V0) DEF_ARGS_##F0(NAME, T0, V0)
+#define DEF_ARGS_2(NAME, T0, F0, V0, T1, F1, V1) DEF_ARGS_##F0(NAME, T0, V0) DEF_ARGS_##F1(NAME, T1, V1)
+#define DEF_ARGS_3(NAME, T0, F0, V0, T1, F1, V1, T2, F2, V2) \
+ DEF_ARGS_##F0(NAME, T0, V0) DEF_ARGS_##F1(NAME, T1, V1) DEF_ARGS_##F2(NAME, T2, V2)
+#define DEF_ARGS_4(NAME, T0, F0, V0, T1, F1, V1, T2, F2, V2, T3, F3, V3) \
+ DEF_ARGS_##F0(NAME, T0, V0) DEF_ARGS_##F1(NAME, T1, V1) DEF_ARGS_##F2(NAME, T2, V2) DEF_ARGS_##F3(NAME, T3, V3)
+#define DEF_ARGS_5(NAME, T0, F0, V0, T1, F1, V1, T2, F2, V2, T3, F3, V3, T4, F4, V4) \
+ DEF_ARGS_##F0(NAME, T0, V0) DEF_ARGS_##F1(NAME, T1, V1) DEF_ARGS_##F2(NAME, T2, V2) DEF_ARGS_##F3(NAME, T3, V3) \
+ DEF_ARGS_##F4(NAME, T4, V4)
+#define DEF_ARGS_6(NAME, T0, F0, V0, T1, F1, V1, T2, F2, V2, T3, F3, V3, T4, F4, V4, T5, F5, V5) \
+ DEF_ARGS_##F0(NAME, T0, V0) DEF_ARGS_##F1(NAME, T1, V1) DEF_ARGS_##F2(NAME, T2, V2) DEF_ARGS_##F3(NAME, T3, V3) \
+ DEF_ARGS_##F4(NAME, T4, V4) DEF_ARGS_##F5(NAME, T5, V5)
+#define DEF_ARGS_7(NAME, T0, F0, V0, T1, F1, V1, T2, F2, V2, T3, F3, V3, T4, F4, V4, T5, F5, V5, T6, F6, V6) \
+ DEF_ARGS_##F0(NAME, T0, V0) DEF_ARGS_##F1(NAME, T1, V1) DEF_ARGS_##F2(NAME, T2, V2) DEF_ARGS_##F3(NAME, T3, V3) \
+ DEF_ARGS_##F4(NAME, T4, V4) DEF_ARGS_##F5(NAME, T5, V5) DEF_ARGS_##F6(NAME, T6, V6)
+#define DEF_ATTRIBUTE(NAME, NUM_ARGS, ...) \
+ case Attribute_##NAME##Attribute: \
+ fb_attribute = Create##NAME##Attribute(*_builder DEF_ARGS_##NUM_ARGS(NAME##Attribute, __VA_ARGS__)).Union(); \
+ break;
+
+#include "attribute.def"
+#undef DEF_ATTRIBUTE
+#undef DEF_ARGS_1
+#undef DEF_ARGS_2
+#undef DEF_ARGS_3
+#undef DEF_ARGS_4
+#undef DEF_ARGS_5
+#undef DEF_ARGS_6
+#undef DEF_ARGS_7
+#undef DEF_ARGS_S
+#undef DEF_ARGS_V
+#undef DEF_ARGS_S_int32_t
+#undef DEF_ARGS_S_float
+#undef DEF_ARGS_S_bool
+#undef DEF_ARGS_S_ResizeMode
+#undef DEF_ARGS_S_string
+#undef DEF_ARGS_S_STR
+#undef DEF_ARGS_S_DEFAULT
+ default:
+ printf("TosaSerializationHandler::FreezeBuilder(): Attribute %s not implemented yet\n",
+ EnumNamesAttribute()[attribute_type]);
+ return TOSA_INTERNAL_ERROR;
+ }
+
+ auto qinfo_type = op->GetQInfoType();
+ flatbuffers::Offset<void> fb_operator_qinfo;
+ switch (qinfo_type)
+ {
+ case QuantInfo_NONE:
+ fb_operator_qinfo = 0;
+ break;
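+// Same X-macro expansion as for attributes above, but driven by quant_info.def to
+// build the Create<Name>QuantInfo() argument list from the operator's qinfo object.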
+#define DEF_ARGS_S(NAME, T, V) , reinterpret_cast<Tosa##NAME*>(op->GetQInfo())->V()
+#define DEF_ARGS_V(NAME, T, V) , _builder->CreateVector<T>(reinterpret_cast<Tosa##NAME*>(op->GetQInfo())->V())
+
+#define DEF_ARGS_1(NAME, T0, F0, V0) DEF_ARGS_##F0(NAME, T0, V0)
+#define DEF_ARGS_2(NAME, T0, F0, V0, T1, F1, V1) DEF_ARGS_##F0(NAME, T0, V0) DEF_ARGS_##F1(NAME, T1, V1)
+#define DEF_ARGS_3(NAME, T0, F0, V0, T1, F1, V1, T2, F2, V2) \
+ DEF_ARGS_##F0(NAME, T0, V0) DEF_ARGS_##F1(NAME, T1, V1) DEF_ARGS_##F2(NAME, T2, V2)
+#define DEF_ARGS_4(NAME, T0, F0, V0, T1, F1, V1, T2, F2, V2, T3, F3, V3) \
+ DEF_ARGS_##F0(NAME, T0, V0) DEF_ARGS_##F1(NAME, T1, V1) DEF_ARGS_##F2(NAME, T2, V2) DEF_ARGS_##F3(NAME, T3, V3)
+#define DEF_ARGS_5(NAME, T0, F0, V0, T1, F1, V1, T2, F2, V2, T3, F3, V3, T4, F4, V4) \
+ DEF_ARGS_##F0(NAME, T0, V0) DEF_ARGS_##F1(NAME, T1, V1) DEF_ARGS_##F2(NAME, T2, V2) DEF_ARGS_##F3(NAME, T3, V3) \
+ DEF_ARGS_##F4(NAME, T4, V4)
+#define DEF_ARGS_6(NAME, T0, F0, V0, T1, F1, V1, T2, F2, V2, T3, F3, V3, T4, F4, V4, T5, F5, V5) \
+ DEF_ARGS_##F0(NAME, T0, V0) DEF_ARGS_##F1(NAME, T1, V1) DEF_ARGS_##F2(NAME, T2, V2) DEF_ARGS_##F3(NAME, T3, V3) \
+ DEF_ARGS_##F4(NAME, T4, V4) DEF_ARGS_##F5(NAME, T5, V5)
+#define DEF_ARGS_7(NAME, T0, F0, V0, T1, F1, V1, T2, F2, V2, T3, F3, V3, T4, F4, V4, T5, F5, V5, T6, F6, V6) \
+ DEF_ARGS_##F0(NAME, T0, V0) DEF_ARGS_##F1(NAME, T1, V1) DEF_ARGS_##F2(NAME, T2, V2) DEF_ARGS_##F3(NAME, T3, V3) \
+ DEF_ARGS_##F4(NAME, T4, V4) DEF_ARGS_##F5(NAME, T5, V5) DEF_ARGS_##F6(NAME, T6, V6)
+#define DEF_ARGS_8(NAME, T0, F0, V0, T1, F1, V1, T2, F2, V2, T3, F3, V3, T4, F4, V4, T5, F5, V5, T6, F6, V6, T7, F7, \
+ V7) \
+ DEF_ARGS_##F0(NAME, T0, V0) DEF_ARGS_##F1(NAME, T1, V1) DEF_ARGS_##F2(NAME, T2, V2) DEF_ARGS_##F3(NAME, T3, V3) \
+ DEF_ARGS_##F4(NAME, T4, V4) DEF_ARGS_##F5(NAME, T5, V5) DEF_ARGS_##F6(NAME, T6, V6) \
+ DEF_ARGS_##F7(NAME, T7, V7)
+#define DEF_ARGS_9(NAME, T0, F0, V0, T1, F1, V1, T2, F2, V2, T3, F3, V3, T4, F4, V4, T5, F5, V5, T6, F6, V6, T7, F7, \
+ V7, T8, F8, V8) \
+ DEF_ARGS_##F0(NAME, T0, V0) DEF_ARGS_##F1(NAME, T1, V1) DEF_ARGS_##F2(NAME, T2, V2) DEF_ARGS_##F3(NAME, T3, V3) \
+ DEF_ARGS_##F4(NAME, T4, V4) DEF_ARGS_##F5(NAME, T5, V5) DEF_ARGS_##F6(NAME, T6, V6) \
+ DEF_ARGS_##F7(NAME, T7, V7) DEF_ARGS_##F8(NAME, T8, V8)
+#define DEF_ARGS_10(NAME, T0, F0, V0, T1, F1, V1, T2, F2, V2, T3, F3, V3, T4, F4, V4, T5, F5, V5, T6, F6, V6, T7, F7, \
+ V7, T8, F8, V8, T9, F9, V9) \
+ DEF_ARGS_##F0(NAME, T0, V0) DEF_ARGS_##F1(NAME, T1, V1) DEF_ARGS_##F2(NAME, T2, V2) DEF_ARGS_##F3(NAME, T3, V3) \
+ DEF_ARGS_##F4(NAME, T4, V4) DEF_ARGS_##F5(NAME, T5, V5) DEF_ARGS_##F6(NAME, T6, V6) \
+ DEF_ARGS_##F7(NAME, T7, V7) DEF_ARGS_##F8(NAME, T8, V8) DEF_ARGS_##F9(NAME, T9, V9)
+#define DEF_QUANTIZATION_INFO(NAME, NUM_ARGS, ...) \
+ case QuantInfo_##NAME##QuantInfo: \
+ fb_operator_qinfo = \
+ Create##NAME##QuantInfo(*_builder DEF_ARGS_##NUM_ARGS(NAME##QuantInfo, __VA_ARGS__)).Union(); \
+ break;
+
+#include "quant_info.def"
+#undef DEF_QUANTIZATION_INFO
+#undef DEF_ARGS_1
+#undef DEF_ARGS_2
+#undef DEF_ARGS_3
+#undef DEF_ARGS_4
+#undef DEF_ARGS_5
+#undef DEF_ARGS_6
+#undef DEF_ARGS_7
+#undef DEF_ARGS_8
+#undef DEF_ARGS_9
+#undef DEF_ARGS_10
+#undef DEF_ARGS_S
+#undef DEF_ARGS_V
+ default:
+ printf("TosaSerializationHandler::FreezeBuilder(): Attribute %s not implemented yet\n",
+ EnumNamesAttribute()[attribute_type]);
+ return TOSA_INTERNAL_ERROR;
+ }
+
+ auto fboffset_operator =
+ CreateTosaOperator(*_builder, operator_op, attribute_type, fb_attribute, fb_operator_inputs,
+ fb_operator_outputs, qinfo_type, fb_operator_qinfo);
+ fboffset_block_operators.push_back(fboffset_operator);
+ }
+
+ auto fb_block_operators = _builder->CreateVector(fboffset_block_operators);
+
+ for (auto tensor : block->GetTensors())
+ {
+
+ auto tensor_name = _builder->CreateString(tensor->GetName().c_str());
+ auto tensor_usage =
+ _builder->CreateVector(std::vector<uint32_t>(tensor->GetUsage().begin(), tensor->GetUsage().end()));
+ auto tensor_shape = _builder->CreateVector(tensor->GetShape());
+ auto tensor_dtype = tensor->GetDtype();
+ auto tensor_format =
+ _builder->CreateVector(std::vector<uint32_t>(tensor->GetFormat().begin(), tensor->GetFormat().end()));
+ flatbuffers::Offset<flatbuffers::String> tensor_npy_filename = 0;
+ if (tensor->GetNpyFilePtr())
+ tensor_npy_filename = _builder->CreateString(tensor->GetNpyFilePtr()->c_str());
+
+ auto fboffset_tensor = CreateTosaTensor(*_builder, tensor_name, tensor_shape, tensor_dtype, tensor_usage,
+ tensor_format, tensor_npy_filename);
+ fboffset_block_tensors.push_back(fboffset_tensor);
+ }
+
+ auto fb_block_tensors = _builder->CreateVector(fboffset_block_tensors);
+
+ auto fboffset_block = CreateTosaBasicBlock(*_builder, block_name, fb_block_operators, fb_block_tensors,
+ fb_block_inputs, fb_block_outputs);
+ fboffset_blocks.push_back(fboffset_block);
+ }
+
+ auto fb_blocks = _builder->CreateVector(fboffset_blocks);
+
+ auto fb_version = CreateVersion(*_builder, GetTosaVersion()->_major, GetTosaVersion()->_minor,
+ GetTosaVersion()->_patch, GetTosaVersion()->_experimental);
+
+ auto fb_graph = CreateTosaGraph(*_builder, fb_version, fb_blocks);
+ _builder->Finish(fb_graph);
+
+ return TOSA_OK;
+}
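+
+// Note on the switches above: both are driven by an X-macro pattern.  Each line
+// in attribute.def / quant_info.def names a type and its fields, and
+// DEF_ATTRIBUTE / DEF_QUANTIZATION_INFO expand it into a case that forwards the
+// fields to the corresponding generated Create*() call.  Illustrative sketch
+// for a hypothetical entry with a single scalar int32_t field (assuming
+// DEF_ARGS_S_DEFAULT mirrors the quant-info DEF_ARGS_S form above):
+//
+//   DEF_ATTRIBUTE(Axis, 1, int32_t, S, axis)
+//
+// expands roughly to:
+//
+//   case Attribute_AxisAttribute:
+//       fb_attribute = CreateAxisAttribute(*_builder,
+//           reinterpret_cast<TosaAxisAttribute*>(op->GetAttribute())->axis()).Union();
+//       break;
+//
+// Vector fields (flag V) are wrapped in _builder->CreateVector<T>(...) instead.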
+
+// Magic NUMPY header
+static const char NUMPY_HEADER_STR[] = "\x93NUMPY\x1\x0\x76\x0{";
+static const int NUMPY_HEADER_SZ = 128;
+
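+// For reference, the fixed 128-byte version 1.0 header produced and expected by
+// the helpers below looks like this (shape values are illustrative; the rest of
+// the 128 bytes is space padding ending in '\n'):
+//
+//   \x93NUMPY \x01\x00 \x76\x00 {'descr': '<f4', 'fortran_order': False, 'shape': (1, 4, 4, 4,), }
+//
+// where \x76\x00 is the little-endian dictionary length (0x0076 = 118 bytes,
+// i.e. 128 minus the 10-byte magic/version/length prefix).
+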
+NumpyUtilities::NPError NumpyUtilities::readFromNpyFile(const char* filename, const uint32_t elems, bool* databuf)
+{
+ const char dtype_str[] = "'|b1'";
+ FILE* infile = nullptr;
+ NPError rc = NO_ERROR;
+
+ assert(filename);
+ assert(databuf);
+
+ infile = fopen(filename, "rb");
+ if (!infile)
+ {
+ rc = FILE_NOT_FOUND;
+ goto done;
+ }
+
+ rc = checkNpyHeader(infile, elems, dtype_str);
+ if (rc != NO_ERROR)
+ {
+ goto done;
+ }
+
+ // Read in the data from numpy byte array to native bool
+ // array format
+ for (uint32_t i = 0; i < elems; i++)
+ {
+ int val = fgetc(infile);
+
+ if (val == EOF)
+ {
+ rc = FILE_IO_ERROR;
+ goto done;
+ }
+
+ databuf[i] = val;
+ }
+
+done:
+
+ if (infile)
+ fclose(infile);
+
+ return rc;
+}
+
+NumpyUtilities::NPError NumpyUtilities::readFromNpyFile(const char* filename, const uint32_t elems, int32_t* databuf)
+{
+ const char dtype_str[] = "'<i4'";
+ FILE* infile = nullptr;
+ NPError rc = NO_ERROR;
+
+ assert(filename);
+ assert(databuf);
+
+ infile = fopen(filename, "rb");
+ if (!infile)
+ {
+ rc = FILE_NOT_FOUND;
+ goto done;
+ }
+
+ rc = checkNpyHeader(infile, elems, dtype_str);
+ if (rc != NO_ERROR)
+ {
+ goto done;
+ }
+
+ // Now we are at the beginning of the data
+ // Parse based on the datatype and number of dimensions
+ if (fread(databuf, sizeof(int32_t), elems, infile) != elems)
+ {
+ rc = FILE_IO_ERROR;
+ goto done;
+ }
+
+done:
+
+ if (infile)
+ fclose(infile);
+
+ return rc;
+}
+
+NumpyUtilities::NPError NumpyUtilities::readFromNpyFile(const char* filename, const uint32_t elems, int64_t* databuf)
+{
+ const char dtype_str[] = "'<i8'";
+ FILE* infile = nullptr;
+ NPError rc = NO_ERROR;
+
+ assert(filename);
+ assert(databuf);
+
+ infile = fopen(filename, "rb");
+ if (!infile)
+ {
+ rc = FILE_NOT_FOUND;
+ goto done;
+ }
+
+ rc = checkNpyHeader(infile, elems, dtype_str);
+ if (rc != NO_ERROR)
+ {
+ goto done;
+ }
+
+ // Now we are at the beginning of the data
+ // Parse based on the datatype and number of dimensions
+ if (fread(databuf, sizeof(int64_t), elems, infile) != elems)
+ {
+ rc = FILE_IO_ERROR;
+ goto done;
+ }
+
+done:
+
+ if (infile)
+ fclose(infile);
+
+ return rc;
+}
+
+NumpyUtilities::NPError NumpyUtilities::readFromNpyFile(const char* filename, const uint32_t elems, float* databuf)
+{
+ const char dtype_str[] = "'<f4'";
+ FILE* infile = nullptr;
+ NPError rc = NO_ERROR;
+
+ assert(filename);
+ assert(databuf);
+
+ infile = fopen(filename, "rb");
+ if (!infile)
+ {
+ rc = FILE_NOT_FOUND;
+ goto done;
+ }
+
+ rc = checkNpyHeader(infile, elems, dtype_str);
+ if (rc != NO_ERROR)
+ {
+ goto done;
+ }
+
+ // Now we are at the beginning of the data
+ // Parse based on the datatype and number of dimensions
+ if (fread(databuf, sizeof(float), elems, infile) != elems)
+ {
+ rc = FILE_IO_ERROR;
+ goto done;
+ }
+
+done:
+
+ if (infile)
+ fclose(infile);
+
+ return rc;
+}
+
+NumpyUtilities::NPError NumpyUtilities::checkNpyHeader(FILE* infile, const uint32_t elems, const char* dtype_str)
+{
+ char buf[NUMPY_HEADER_SZ + 1];
+ char* ptr = nullptr;
+ NPError rc = NO_ERROR;
+ bool foundFormat = false;
+ bool foundOrder = false;
+ bool foundShape = false;
+ bool fortranOrder = false;
+ std::vector<int> shape;
+ uint32_t totalElems = 1;
+ char* outer_end = NULL;
+
+ assert(infile);
+ assert(elems > 0);
+
+ if (fread(buf, NUMPY_HEADER_SZ, 1, infile) != 1)
+ {
+ rc = HEADER_PARSE_ERROR;
+ goto done;
+ }
+
+ if (memcmp(buf, NUMPY_HEADER_STR, sizeof(NUMPY_HEADER_STR) - 1))
+ {
+ rc = HEADER_PARSE_ERROR;
+ goto done;
+ }
+
+ ptr = strtok_r(buf + sizeof(NUMPY_HEADER_STR) - 1, ":", &outer_end);
+
+ // Read in the data type, order, and shape
+ while (ptr && (!foundFormat || !foundOrder || !foundShape))
+ {
+
+ // End of string?
+ if (!ptr)
+ break;
+
+ // Skip whitespace
+ while (isspace(*ptr))
+ ptr++;
+
+ // Parse the dictionary field name
+ if (!strcmp(ptr, "'descr'"))
+ {
+ ptr = strtok_r(NULL, ",", &outer_end);
+ if (!ptr)
+ break;
+
+ while (isspace(*ptr))
+ ptr++;
+
+ if (strcmp(ptr, dtype_str))
+ {
+ rc = FILE_TYPE_MISMATCH;
+ goto done;
+ }
+
+ foundFormat = true;
+ }
+ else if (!strcmp(ptr, "'fortran_order'"))
+ {
+ ptr = strtok_r(NULL, ",", &outer_end);
+ if (!ptr)
+ break;
+
+ while (isspace(*ptr))
+ ptr++;
+
+ if (!strcmp(ptr, "False"))
+ {
+ fortranOrder = false;
+ }
+ else
+ {
+ rc = FILE_TYPE_MISMATCH;
+ goto done;
+ }
+
+ foundOrder = true;
+ }
+ else if (!strcmp(ptr, "'shape'"))
+ {
+
+ ptr = strtok_r(NULL, "(", &outer_end);
+ if (!ptr)
+ break;
+ ptr = strtok_r(NULL, ")", &outer_end);
+ if (!ptr)
+ break;
+
+ while (isspace(*ptr))
+ ptr++;
+
+ // The shape contains N comma-separated integers. Read up to 4.
+ char* end = NULL;
+
+ ptr = strtok_r(ptr, ",", &end);
+ for (int i = 0; i < 4; i++)
+ {
+ // Out of dimensions
+ if (!ptr)
+ break;
+
+ shape.push_back(atoi(ptr));
+ totalElems *= atoi(ptr);
+ ptr = strtok_r(NULL, ",", &end);
+ }
+
+ foundShape = true;
+ }
+ else
+ {
+ rc = HEADER_PARSE_ERROR;
+ goto done;
+ }
+
+ if (!ptr)
+ break;
+
+ ptr = strtok_r(NULL, ":", &outer_end);
+ }
+
+ if (!foundShape || !foundFormat || !foundOrder)
+ {
+ rc = HEADER_PARSE_ERROR;
+ goto done;
+ }
+
+ // Validate header
+ if (fortranOrder != false)
+ {
+ rc = FILE_TYPE_MISMATCH;
+ goto done;
+ }
+
+ if (totalElems != elems)
+ {
+ rc = BUFFER_SIZE_MISMATCH;
+ goto done;
+ }
+
+    // Go back to the beginning and read until the end of the header dictionary
+ rewind(infile);
+ int val;
+
+ do
+ {
+ val = fgetc(infile);
+ } while (val != EOF && val != '\n');
+
+done:
+
+ return rc;
+}
+
+NumpyUtilities::NPError NumpyUtilities::writeToNpyFile(const char* filename, const uint32_t elems, const bool* databuf)
+{
+ std::vector<int32_t> shape = { (int32_t)elems };
+ return writeToNpyFile(filename, shape, databuf);
+}
+
+NumpyUtilities::NPError
+ NumpyUtilities::writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const bool* databuf)
+{
+ const char dtype_str[] = "'|b1'";
+ FILE* outfile = nullptr;
+ NPError rc = NO_ERROR;
+ uint32_t totalElems = 1;
+
+ assert(filename);
+ assert(shape.size() >= 0);
+ assert(databuf);
+
+ outfile = fopen(filename, "wb");
+
+ if (!outfile)
+ {
+ rc = FILE_NOT_FOUND;
+ goto done;
+ }
+
+ for (uint32_t i = 0; i < shape.size(); i++)
+ {
+ totalElems *= shape[i];
+ }
+
+ rc = writeNpyHeader(outfile, shape, dtype_str);
+
+ // Numpy save format stores booleans as a byte array
+ // with one byte per boolean. This somewhat inefficiently
+ // remaps from system bool[] to this format.
+ for (uint32_t i = 0; i < totalElems; i++)
+ {
+ int val = databuf[i] ? 1 : 0;
+ if (fputc(val, outfile) == EOF)
+ {
+ rc = FILE_IO_ERROR;
+ goto done;
+ }
+ }
+
+done:
+
+ if (outfile)
+ fclose(outfile);
+
+ return rc;
+}
+
+NumpyUtilities::NPError
+ NumpyUtilities::writeToNpyFile(const char* filename, const uint32_t elems, const int32_t* databuf)
+{
+ std::vector<int32_t> shape = { (int32_t)elems };
+ return writeToNpyFile(filename, shape, databuf);
+}
+
+NumpyUtilities::NPError
+ NumpyUtilities::writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const int32_t* databuf)
+{
+ const char dtype_str[] = "'<i4'";
+ FILE* outfile = nullptr;
+ NPError rc = NO_ERROR;
+ uint32_t totalElems = 1;
+
+ assert(filename);
+ assert(shape.size() >= 0);
+ assert(databuf);
+
+ outfile = fopen(filename, "wb");
+
+ if (!outfile)
+ {
+ rc = FILE_NOT_FOUND;
+ goto done;
+ }
+
+ for (uint32_t i = 0; i < shape.size(); i++)
+ {
+ totalElems *= shape[i];
+ }
+
+ rc = writeNpyHeader(outfile, shape, dtype_str);
+
+ if (fwrite(databuf, sizeof(int32_t), totalElems, outfile) != totalElems)
+ {
+ rc = FILE_IO_ERROR;
+ goto done;
+ }
+
+done:
+
+ if (outfile)
+ fclose(outfile);
+
+ return rc;
+}
+
+NumpyUtilities::NPError
+ NumpyUtilities::writeToNpyFile(const char* filename, const uint32_t elems, const int64_t* databuf)
+{
+ std::vector<int32_t> shape = { (int32_t)elems };
+ return writeToNpyFile(filename, shape, databuf);
+}
+
+NumpyUtilities::NPError
+ NumpyUtilities::writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const int64_t* databuf)
+{
+ const char dtype_str[] = "'<i8'";
+ FILE* outfile = nullptr;
+ NPError rc = NO_ERROR;
+ uint32_t totalElems = 1;
+
+ assert(filename);
+ assert(shape.size() >= 0);
+ assert(databuf);
+
+ outfile = fopen(filename, "wb");
+
+ if (!outfile)
+ {
+ rc = FILE_NOT_FOUND;
+ goto done;
+ }
+
+ for (uint32_t i = 0; i < shape.size(); i++)
+ {
+ totalElems *= shape[i];
+ }
+
+ rc = writeNpyHeader(outfile, shape, dtype_str);
+
+ if (fwrite(databuf, sizeof(int64_t), totalElems, outfile) != totalElems)
+ {
+ rc = FILE_IO_ERROR;
+ goto done;
+ }
+
+done:
+
+ if (outfile)
+ fclose(outfile);
+
+ return rc;
+}
+
+NumpyUtilities::NPError NumpyUtilities::writeToNpyFile(const char* filename, const uint32_t elems, const float* databuf)
+{
+ std::vector<int32_t> shape = { (int32_t)elems };
+ return writeToNpyFile(filename, shape, databuf);
+}
+
+NumpyUtilities::NPError
+ NumpyUtilities::writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const float* databuf)
+{
+ const char dtype_str[] = "'<f4'";
+ FILE* outfile = nullptr;
+ NPError rc = NO_ERROR;
+ uint32_t totalElems = 1;
+
+ assert(filename);
+ assert(shape.size() >= 0);
+ assert(databuf);
+
+ outfile = fopen(filename, "wb");
+
+ if (!outfile)
+ {
+ rc = FILE_NOT_FOUND;
+ goto done;
+ }
+
+ for (uint32_t i = 0; i < shape.size(); i++)
+ {
+ totalElems *= shape[i];
+ }
+
+ rc = writeNpyHeader(outfile, shape, dtype_str);
+
+ if (fwrite(databuf, sizeof(float), totalElems, outfile) != totalElems)
+ {
+ rc = FILE_IO_ERROR;
+ goto done;
+ }
+
+done:
+
+ if (outfile)
+ fclose(outfile);
+
+ return rc;
+}
+
+NumpyUtilities::NPError
+ NumpyUtilities::writeNpyHeader(FILE* outfile, const std::vector<int32_t>& shape, const char* dtype_str)
+{
+ NPError rc = NO_ERROR;
+ uint32_t i;
+ char header[NUMPY_HEADER_SZ + 1];
+ int headerPos = 0;
+
+ assert(outfile);
+ assert(shape.size() >= 0);
+
+    // Space-fill the header and terminate it with a newline, per the numpy 1.0 spec
+ memset(header, 0x20, NUMPY_HEADER_SZ);
+ header[NUMPY_HEADER_SZ - 1] = '\n';
+ header[NUMPY_HEADER_SZ] = 0;
+
+ // Write out the hard-coded header. We only support a 128-byte 1.0 header
+ // for now, which should be sufficient for simple tensor types of any
+ // reasonable rank.
+ memcpy(header, NUMPY_HEADER_STR, sizeof(NUMPY_HEADER_STR) - 1);
+ headerPos += sizeof(NUMPY_HEADER_STR) - 1;
+
+    // Output the format dictionary using the caller-supplied dtype string
+ headerPos +=
+ snprintf(header + headerPos, NUMPY_HEADER_SZ - headerPos, "'descr': %s, 'fortran_order': False, 'shape': (%d,",
+ dtype_str, shape.size() > 0 ? shape[0] : 1);
+
+ // Remainder of shape array
+ for (i = 1; i < shape.size(); i++)
+ {
+ headerPos += snprintf(header + headerPos, NUMPY_HEADER_SZ - headerPos, " %d,", shape[i]);
+ }
+
+ // Close off the dictionary
+ headerPos += snprintf(header + headerPos, NUMPY_HEADER_SZ - headerPos, "), }");
+
+    // snprintf writes a terminating NUL after the dictionary; replace it with a space
+ header[headerPos] = 0x20;
+
+ if (fwrite(header, NUMPY_HEADER_SZ, 1, outfile) != 1)
+ {
+ rc = FILE_IO_ERROR;
+ goto done;
+ }
+
+done:
+
+ return rc;
+}
diff --git a/serialization/tosa_serialization_handler.h b/serialization/tosa_serialization_handler.h
new file mode 100644
index 0000000..124b8e0
--- /dev/null
+++ b/serialization/tosa_serialization_handler.h
@@ -0,0 +1,423 @@
+
+// Copyright (c) 2020, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef _TOSA_SERIALIZATION_HANDLER_H
+#define _TOSA_SERIALIZATION_HANDLER_H
+#include "attribute.h"
+#include "flatbuffers/idl.h"
+#include "flatbuffers/util.h"
+#include "quant_info.h"
+#include "tosa_generated.h"
+#include <cstdint>
+#include <memory>
+#include <string>
+#include <vector>
+
+namespace tosa
+{
+
+enum tosa_err_t
+{
+ TOSA_OK,
+ TOSA_USER_ERROR,
+ TOSA_FILE_ERROR,
+ TOSA_MEMORY_ERROR,
+ TOSA_SCHEMA_MISSING,
+ TOSA_INTERNAL_ERROR,
+ TOSA_VERSION_MISMATCH,
+ NUM_TOSA_ERROR
+};
+
+struct TosaVersion
+{
+ int32_t _major;
+ int32_t _minor;
+ int32_t _patch;
+ bool _experimental;
+
+ TosaVersion() = delete;
+ TosaVersion(int32_t major, int32_t minor, int32_t patch, bool experimental)
+ {
+ _major = major;
+ _minor = minor;
+ _patch = patch;
+ _experimental = experimental;
+ }
+
+ std::string to_string() const
+ {
+ std::string str;
+ str += std::to_string(_major) + ".";
+ str += std::to_string(_minor) + ".";
+ str += std::to_string(_patch);
+ if (_experimental)
+ str += "(experimental)";
+ return str;
+ };
+
+ bool operator==(const TosaVersion& rhs)
+ {
+ if (rhs._major == _major && rhs._minor == _minor && rhs._patch == _patch && rhs._experimental == _experimental)
+ {
+ return true;
+ }
+ return false;
+ }
+
+ bool operator!=(const TosaVersion& rhs)
+ {
+ return !((*this) == rhs);
+ }
+};
+
+class TosaSerializationHandler;
+
+class TosaSerializationTensor
+{
+public:
+ // constructor and destructor
+ TosaSerializationTensor(const flatbuffers::String* name,
+ const flatbuffers::Vector<uint32_t>& usage,
+ const flatbuffers::Vector<int32_t>& shape,
+ DType dtype,
+ const flatbuffers::Vector<uint32_t>& format,
+ const flatbuffers::String* npy_filename);
+ TosaSerializationTensor(std::string name,
+ const std::vector<Usage>& usage,
+ const std::vector<int32_t>& shape,
+ DType dtype,
+ const std::vector<Format>& format,
+ const std::string* npy_filename);
+ TosaSerializationTensor();
+ ~TosaSerializationTensor();
+
+ // copy constructor/assignment
+ TosaSerializationTensor(const TosaSerializationTensor& rhs);
+ TosaSerializationTensor& operator=(const TosaSerializationTensor& rhs);
+
+ // move constructor/assignment
+ TosaSerializationTensor(TosaSerializationTensor&& rhs);
+ TosaSerializationTensor& operator=(TosaSerializationTensor&& rhs);
+
+ // accessor
+ std::string GetName() const
+ {
+ return *_name;
+ }
+ const std::vector<int32_t>& GetShape() const
+ {
+ return *_shape;
+ }
+ DType GetDtype()
+ {
+ return _dtype;
+ }
+ bool HasFormat(Format format)
+ {
+ for (Format us : *_format)
+ {
+ if (us == format)
+ return true;
+ }
+ return false;
+ }
+ std::vector<Format>& GetFormat()
+ {
+ return *_format;
+ }
+ bool HasUsage(Usage usage)
+ {
+ for (Usage us : *_usage)
+ {
+ if (us == usage)
+ return true;
+ }
+ return false;
+ }
+ std::vector<Usage>& GetUsage()
+ {
+ return *_usage;
+ }
+ std::string* GetNpyFilePtr() const
+ {
+ return _npy_filename;
+ }
+
+ // modifier
+ void SetDtype(DType dtype)
+ {
+ _dtype = dtype;
+ }
+ void SetName(std::string name)
+ {
+ *_name = name;
+ }
+
+private:
+    DType _dtype;                 /* data type enumeration, see tosa_generated.h */
+ std::vector<Format>* _format; /* list of possible tensor format */
+ std::vector<Usage>* _usage; /* list of possible tensor usage */
+ std::vector<int32_t>* _shape; /* shape of the tensor */
+ std::string* _name; /* name of the tensor, used for solving dependency */
+    std::string* _npy_filename;   /* numpy array filename; a null pointer means no numpy data attached */
+};
+
+class TosaSerializationOperator
+{
+public:
+ // use default copy, void constructor
+ // constructor and destructor
+ TosaSerializationOperator(Op op_name,
+ Attribute attribute_type,
+ const TosaAttributeBase* attribute,
+ QuantInfo qinfo_type,
+ const TosaQuantInfoBase* qinfo,
+ std::vector<std::string> input_tensor_names,
+ std::vector<std::string> output_tensor_names);
+ ~TosaSerializationOperator();
+
+ // accessor
+ Op GetOp() const
+ {
+ return _op;
+ }
+ Attribute GetAttributeType() const
+ {
+ return _attribute_type;
+ }
+ TosaAttributeBase* GetAttribute() const
+ {
+ return _attribute;
+ }
+ QuantInfo GetQInfoType() const
+ {
+ return _qinfo_type;
+ }
+ TosaQuantInfoBase* GetQInfo() const
+ {
+ return _qinfo;
+ }
+ std::vector<std::string>& GetInputTensorNames() const
+ {
+ return *_input_tensor_names;
+ }
+ std::vector<std::string>& GetOutputTensorNames() const
+ {
+ return *_output_tensor_names;
+ }
+ std::vector<TosaSerializationTensor*>& GetInputTensors() const
+ {
+ return *_input_tensors;
+ }
+ std::vector<TosaSerializationTensor*>& GetOutputTensors() const
+ {
+ return *_output_tensors;
+ }
+
+private:
+    Op _op;                    /* operator enum, see tosa_generated.h for enumeration table */
+ Attribute _attribute_type; /* operator attribute enum, used for dynamic casting TosaAttributeBase class */
+ TosaAttributeBase* _attribute; /* real attribute class goes here */
+ QuantInfo _qinfo_type; /* QuantInfo enum */
+ TosaQuantInfoBase* _qinfo; /* base class pointer of QuantInfo */
+ std::vector<std::string>* _input_tensor_names; /* array of input tensor names */
+ std::vector<std::string>* _output_tensor_names; /* array of output tensor names */
+
+ std::vector<TosaSerializationTensor*>* _input_tensors; /* array of input TosaSerializationTensor */
+ std::vector<TosaSerializationTensor*>* _output_tensors; /* array of output TosaSerializationTensor */
+};
+
+class TosaSerializationBasicBlock
+{
+public:
+ // constructor and destructor
+ TosaSerializationBasicBlock(std::string name,
+ std::vector<TosaSerializationOperator*> operators,
+ std::vector<TosaSerializationTensor*> tensors,
+ std::vector<std::string> inputs,
+ std::vector<std::string> outputs);
+ ~TosaSerializationBasicBlock();
+
+ // accessor
+ std::string GetName() const
+ {
+ return *_name;
+ }
+ std::vector<TosaSerializationOperator*>& GetOperators()
+ {
+ return *_operators;
+ }
+ std::vector<TosaSerializationTensor*>& GetTensors()
+ {
+ return *_tensors;
+ }
+
+ TosaSerializationTensor* GetTensorByName(std::string name)
+ {
+ TosaSerializationTensor* result = nullptr;
+ for (auto tensor : GetTensors())
+ {
+ if (tensor->GetName() == name)
+ {
+ result = tensor;
+ break;
+ }
+ }
+ return result;
+ }
+
+ std::vector<std::string>& GetInputs()
+ {
+ return *_inputs;
+ }
+ std::vector<std::string>& GetOutputs()
+ {
+ return *_outputs;
+ }
+
+private:
+ std::string* _name; /* name of basic block */
+ std::vector<TosaSerializationOperator*>* _operators; /* TosaSerializationOperator list */
+ std::vector<TosaSerializationTensor*>* _tensors; /* TosaSerializationTensor list */
+ std::vector<std::string>* _inputs; /* array of string to specify block inputs */
+ std::vector<std::string>* _outputs; /* array of string to specify block outputs */
+};
+
+/*
+ * Helper class for reading and writing the TOSA ISA.
+ * Supported formats: .tosa (flatbuffer) and .json.
+ * Provides a high-level, std::vector-like interface
+ * to the internal data structures.
+ */
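+//
+// Typical usage (a minimal sketch; the file names are placeholders and the
+// exact call sequence, e.g. whether LoadFileSchema() must precede JSON output,
+// should be checked against the implementations):
+//
+//   TosaSerializationHandler handler;
+//   if (handler.LoadFileTosaFlatbuffer("model.tosa") == TOSA_OK)
+//   {
+//       for (auto op : handler.GetMainBlock()->GetOperators())
+//       {
+//           // inspect op->GetOp(), op->GetInputTensorNames(), ...
+//       }
+//       handler.SaveFileTosaFlatbuffer("model_out.tosa");
+//   }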
+class TosaSerializationHandler
+{
+public:
+ // constructor and destructor
+ TosaSerializationHandler();
+ ~TosaSerializationHandler();
+
+ // file io
+ tosa_err_t LoadFileJson(const char* filename);
+ tosa_err_t LoadFileTosaFlatbuffer(const char* filename);
+ tosa_err_t SaveFileJson(const char* filename);
+ tosa_err_t SaveFileTosaFlatbuffer(const char* filename);
+ tosa_err_t LoadFileSchema(const char* filename);
+
+ // version
+ TosaVersion* GetTosaVersion() const
+ {
+ return _version;
+ }
+
+ // accessor
+ std::vector<TosaSerializationBasicBlock*>& GetBlocks()
+ {
+ return *_blocks;
+ }
+
+ TosaSerializationBasicBlock* GetBlockByName(std::string name)
+ {
+ TosaSerializationBasicBlock* result = nullptr;
+ for (auto block : GetBlocks())
+ {
+ if (block->GetName() == name)
+ {
+ result = block;
+ break;
+ }
+ }
+ return result;
+ }
+ TosaSerializationBasicBlock* GetMainBlock()
+ {
+ TosaSerializationBasicBlock* main_block = GetBlockByName(std::string("main"));
+ assert(main_block);
+ return main_block;
+ }
+
+ std::vector<std::string>& GetInputs()
+ {
+ return GetMainBlock()->GetInputs();
+ }
+ std::vector<std::string>& GetOutputs()
+ {
+ return GetMainBlock()->GetOutputs();
+ }
+
+ bool GetSchemaLoaded() const
+ {
+ return _schemaLoaded;
+ }
+
+protected:
+ tosa_err_t Clear();
+ tosa_err_t InitWithBuf(const uint8_t* buf);
+ tosa_err_t FreezeBuilder();
+ tosa_err_t SetTosaVersion();
+ tosa_err_t CheckTosaVersion(const TosaVersion& read_version);
+
+private:
+ TosaVersion* _version; /* tosa version */
+ flatbuffers::FlatBufferBuilder* _builder; /* flatbuffer builder */
+ flatbuffers::Parser* _parser; /* flatbuffer parser, used for json parsing */
+ std::vector<TosaSerializationBasicBlock*>* _blocks; /* array structure to store all TosaSerializationBasicBlock */
+ bool _schemaLoaded; /* is the schema properly loaded? */
+};
+
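+/*
+ * Static helpers for reading and writing flat tensor buffers as numpy .npy
+ * files (version 1.0 format with a fixed 128-byte header, C ordering only;
+ * supported element types: bool, int32, int64 and float32).
+ */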
+class NumpyUtilities
+{
+public:
+ enum NPError
+ {
+ NO_ERROR = 0,
+ FILE_NOT_FOUND,
+ FILE_IO_ERROR,
+ FILE_TYPE_MISMATCH,
+ HEADER_PARSE_ERROR,
+ BUFFER_SIZE_MISMATCH,
+ };
+
+ static NPError readFromNpyFile(const char* filename, const uint32_t elems, float* buf);
+
+ static NPError readFromNpyFile(const char* filename, const uint32_t elems, int32_t* buf);
+
+ static NPError readFromNpyFile(const char* filename, const uint32_t elems, int64_t* buf);
+
+ static NPError readFromNpyFile(const char* filename, const uint32_t elems, bool* buf);
+
+ static NPError writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const bool* buf);
+
+ static NPError writeToNpyFile(const char* filename, const uint32_t elems, const bool* buf);
+
+ static NPError writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const int32_t* buf);
+
+ static NPError writeToNpyFile(const char* filename, const uint32_t elems, const int32_t* buf);
+
+ static NPError writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const int64_t* buf);
+
+ static NPError writeToNpyFile(const char* filename, const uint32_t elems, const int64_t* buf);
+
+ static NPError writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const float* buf);
+
+ static NPError writeToNpyFile(const char* filename, const uint32_t elems, const float* buf);
+
+private:
+ static NPError checkNpyHeader(FILE* infile, const uint32_t elems, const char* dtype_str);
+ static NPError writeNpyHeader(FILE* infile, const std::vector<int32_t>& shape, const char* dtype_str);
+};
+
+} // namespace tosa
+
+#endif // _TOSA_SERIALIZATION_HANDLER_H
diff --git a/thirdparty/CMakeLists.txt b/thirdparty/CMakeLists.txt
new file mode 100644
index 0000000..8c7bee3
--- /dev/null
+++ b/thirdparty/CMakeLists.txt
@@ -0,0 +1,10 @@
+cmake_minimum_required (VERSION 3.4)
+
+set(CMAKE_INSTALL_PREFIX "./thirdparty" CACHE PATH "..." FORCE)
+
+project(thirdparty LANGUAGES CXX)
+
+# Flatbuffers tests are not needed
+set(FLATBUFFERS_BUILD_TESTS OFF)
+
+add_subdirectory(flatbuffers)
diff --git a/thirdparty/eigen b/thirdparty/eigen
new file mode 160000
+Subproject 21ae2afd4edaa1b69782c67a54182d34efe43f9
diff --git a/thirdparty/flatbuffers b/thirdparty/flatbuffers
new file mode 160000
+Subproject bf9eb67ab9371755c6bcece13cadc7693bcbf26
diff --git a/verif/tosa/Attribute.py b/verif/tosa/Attribute.py
new file mode 100644
index 0000000..a4d96e0
--- /dev/null
+++ b/verif/tosa/Attribute.py
@@ -0,0 +1,36 @@
+# automatically generated by the FlatBuffers compiler, do not modify
+
+# Copyright (c) 2020, ARM Limited.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+# namespace: tosa
+
+class Attribute(object):
+ NONE = 0
+ Pool2dAttribute = 1
+ Conv2dAttribute = 2
+ TransposeConv2dAttribute = 3
+ ReluNAttribute = 4
+ AxisAttribute = 5
+ ReshapeAttribute = 6
+ SliceAttribute = 7
+ TileAttribute = 8
+ ResizeAttribute = 9
+ ClampAttribute = 10
+ RescaleAttribute = 11
+ CustomAttribute = 12
+ CondIfAttribute = 13
+ WhileLoopAttribute = 14
+
diff --git a/verif/tosa/AxisAttribute.py b/verif/tosa/AxisAttribute.py
new file mode 100644
index 0000000..d47eb81
--- /dev/null
+++ b/verif/tosa/AxisAttribute.py
@@ -0,0 +1,45 @@
+# automatically generated by the FlatBuffers compiler, do not modify
+
+# Copyright (c) 2020, ARM Limited.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+# namespace: tosa
+
+import flatbuffers
+
+class AxisAttribute(object):
+ __slots__ = ['_tab']
+
+ @classmethod
+ def GetRootAsAxisAttribute(cls, buf, offset):
+ n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
+ x = AxisAttribute()
+ x.Init(buf, n + offset)
+ return x
+
+ # AxisAttribute
+ def Init(self, buf, pos):
+ self._tab = flatbuffers.table.Table(buf, pos)
+
+ # AxisAttribute
+ def Axis(self):
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
+ if o != 0:
+ return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos)
+ return 0
+
+def AxisAttributeStart(builder): builder.StartObject(1)
+def AxisAttributeAddAxis(builder, axis): builder.PrependInt32Slot(0, axis, 0)
+def AxisAttributeEnd(builder): return builder.EndObject()
diff --git a/verif/tosa/ClampAttribute.py b/verif/tosa/ClampAttribute.py
new file mode 100644
index 0000000..ddc95cf
--- /dev/null
+++ b/verif/tosa/ClampAttribute.py
@@ -0,0 +1,69 @@
+# automatically generated by the FlatBuffers compiler, do not modify
+
+# Copyright (c) 2020, ARM Limited.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+# namespace: tosa
+
+import flatbuffers
+
+class ClampAttribute(object):
+ __slots__ = ['_tab']
+
+ @classmethod
+ def GetRootAsClampAttribute(cls, buf, offset):
+ n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
+ x = ClampAttribute()
+ x.Init(buf, n + offset)
+ return x
+
+ # ClampAttribute
+ def Init(self, buf, pos):
+ self._tab = flatbuffers.table.Table(buf, pos)
+
+ # ClampAttribute
+ def MinInt(self):
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
+ if o != 0:
+ return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos)
+ return 0
+
+ # ClampAttribute
+ def MaxInt(self):
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
+ if o != 0:
+ return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos)
+ return 0
+
+ # ClampAttribute
+ def MinFp(self):
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8))
+ if o != 0:
+ return self._tab.Get(flatbuffers.number_types.Float32Flags, o + self._tab.Pos)
+ return 0.0
+
+ # ClampAttribute
+ def MaxFp(self):
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10))
+ if o != 0:
+ return self._tab.Get(flatbuffers.number_types.Float32Flags, o + self._tab.Pos)
+ return 0.0
+
+def ClampAttributeStart(builder): builder.StartObject(4)
+def ClampAttributeAddMinInt(builder, minInt): builder.PrependInt32Slot(0, minInt, 0)
+def ClampAttributeAddMaxInt(builder, maxInt): builder.PrependInt32Slot(1, maxInt, 0)
+def ClampAttributeAddMinFp(builder, minFp): builder.PrependFloat32Slot(2, minFp, 0.0)
+def ClampAttributeAddMaxFp(builder, maxFp): builder.PrependFloat32Slot(3, maxFp, 0.0)
+def ClampAttributeEnd(builder): return builder.EndObject()
diff --git a/verif/tosa/CondIfAttribute.py b/verif/tosa/CondIfAttribute.py
new file mode 100644
index 0000000..0bf4566
--- /dev/null
+++ b/verif/tosa/CondIfAttribute.py
@@ -0,0 +1,53 @@
+# automatically generated by the FlatBuffers compiler, do not modify
+
+# Copyright (c) 2020, ARM Limited.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+# namespace: tosa
+
+import flatbuffers
+
+class CondIfAttribute(object):
+ __slots__ = ['_tab']
+
+ @classmethod
+ def GetRootAsCondIfAttribute(cls, buf, offset):
+ n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
+ x = CondIfAttribute()
+ x.Init(buf, n + offset)
+ return x
+
+ # CondIfAttribute
+ def Init(self, buf, pos):
+ self._tab = flatbuffers.table.Table(buf, pos)
+
+ # CondIfAttribute
+ def ThenBranch(self):
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
+ if o != 0:
+ return self._tab.String(o + self._tab.Pos)
+ return None
+
+ # CondIfAttribute
+ def ElseBranch(self):
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
+ if o != 0:
+ return self._tab.String(o + self._tab.Pos)
+ return None
+
+def CondIfAttributeStart(builder): builder.StartObject(2)
+def CondIfAttributeAddThenBranch(builder, thenBranch): builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(thenBranch), 0)
+def CondIfAttributeAddElseBranch(builder, elseBranch): builder.PrependUOffsetTRelativeSlot(1, flatbuffers.number_types.UOffsetTFlags.py_type(elseBranch), 0)
+def CondIfAttributeEnd(builder): return builder.EndObject()
diff --git a/verif/tosa/Conv2dAttribute.py b/verif/tosa/Conv2dAttribute.py
new file mode 100644
index 0000000..c7861a5
--- /dev/null
+++ b/verif/tosa/Conv2dAttribute.py
@@ -0,0 +1,109 @@
+# automatically generated by the FlatBuffers compiler, do not modify
+
+# Copyright (c) 2020, ARM Limited.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+# namespace: tosa
+
+import flatbuffers
+
+class Conv2dAttribute(object):
+ __slots__ = ['_tab']
+
+ @classmethod
+ def GetRootAsConv2dAttribute(cls, buf, offset):
+ n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
+ x = Conv2dAttribute()
+ x.Init(buf, n + offset)
+ return x
+
+ # Conv2dAttribute
+ def Init(self, buf, pos):
+ self._tab = flatbuffers.table.Table(buf, pos)
+
+ # Conv2dAttribute
+ def Padding(self, j):
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
+ if o != 0:
+ a = self._tab.Vector(o)
+ return self._tab.Get(flatbuffers.number_types.Int32Flags, a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4))
+ return 0
+
+ # Conv2dAttribute
+ def PaddingAsNumpy(self):
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
+ if o != 0:
+ return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Int32Flags, o)
+ return 0
+
+ # Conv2dAttribute
+ def PaddingLength(self):
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
+ if o != 0:
+ return self._tab.VectorLen(o)
+ return 0
+
+ # Conv2dAttribute
+ def Stride(self, j):
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
+ if o != 0:
+ a = self._tab.Vector(o)
+ return self._tab.Get(flatbuffers.number_types.Int32Flags, a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4))
+ return 0
+
+ # Conv2dAttribute
+ def StrideAsNumpy(self):
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
+ if o != 0:
+ return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Int32Flags, o)
+ return 0
+
+ # Conv2dAttribute
+ def StrideLength(self):
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
+ if o != 0:
+ return self._tab.VectorLen(o)
+ return 0
+
+ # Conv2dAttribute
+ def Dilation(self, j):
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8))
+ if o != 0:
+ a = self._tab.Vector(o)
+ return self._tab.Get(flatbuffers.number_types.Int32Flags, a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4))
+ return 0
+
+ # Conv2dAttribute
+ def DilationAsNumpy(self):
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8))
+ if o != 0:
+ return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Int32Flags, o)
+ return 0
+
+ # Conv2dAttribute
+ def DilationLength(self):
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8))
+ if o != 0:
+ return self._tab.VectorLen(o)
+ return 0
+
+def Conv2dAttributeStart(builder): builder.StartObject(3)
+def Conv2dAttributeAddPadding(builder, padding): builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(padding), 0)
+def Conv2dAttributeStartPaddingVector(builder, numElems): return builder.StartVector(4, numElems, 4)
+def Conv2dAttributeAddStride(builder, stride): builder.PrependUOffsetTRelativeSlot(1, flatbuffers.number_types.UOffsetTFlags.py_type(stride), 0)
+def Conv2dAttributeStartStrideVector(builder, numElems): return builder.StartVector(4, numElems, 4)
+def Conv2dAttributeAddDilation(builder, dilation): builder.PrependUOffsetTRelativeSlot(2, flatbuffers.number_types.UOffsetTFlags.py_type(dilation), 0)
+def Conv2dAttributeStartDilationVector(builder, numElems): return builder.StartVector(4, numElems, 4)
+def Conv2dAttributeEnd(builder): return builder.EndObject()
diff --git a/verif/tosa/ConvQuantInfo.py b/verif/tosa/ConvQuantInfo.py
new file mode 100644
index 0000000..a88bfa6
--- /dev/null
+++ b/verif/tosa/ConvQuantInfo.py
@@ -0,0 +1,53 @@
+# automatically generated by the FlatBuffers compiler, do not modify
+
+# Copyright (c) 2020, ARM Limited.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+# namespace: tosa
+
+import flatbuffers
+
+class ConvQuantInfo(object):
+ __slots__ = ['_tab']
+
+ @classmethod
+ def GetRootAsConvQuantInfo(cls, buf, offset):
+ n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
+ x = ConvQuantInfo()
+ x.Init(buf, n + offset)
+ return x
+
+ # ConvQuantInfo
+ def Init(self, buf, pos):
+ self._tab = flatbuffers.table.Table(buf, pos)
+
+ # ConvQuantInfo
+ def InputZp(self):
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
+ if o != 0:
+ return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos)
+ return 0
+
+ # ConvQuantInfo
+ def WeightZp(self):
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
+ if o != 0:
+ return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos)
+ return 0
+
+def ConvQuantInfoStart(builder): builder.StartObject(2)
+def ConvQuantInfoAddInputZp(builder, inputZp): builder.PrependInt32Slot(0, inputZp, 0)
+def ConvQuantInfoAddWeightZp(builder, weightZp): builder.PrependInt32Slot(1, weightZp, 0)
+def ConvQuantInfoEnd(builder): return builder.EndObject()
diff --git a/verif/tosa/CustomAttribute.py b/verif/tosa/CustomAttribute.py
new file mode 100644
index 0000000..25f6759
--- /dev/null
+++ b/verif/tosa/CustomAttribute.py
@@ -0,0 +1,45 @@
+# automatically generated by the FlatBuffers compiler, do not modify
+
+# Copyright (c) 2020, ARM Limited.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+# namespace: tosa
+
+import flatbuffers
+
+class CustomAttribute(object):
+ __slots__ = ['_tab']
+
+ @classmethod
+ def GetRootAsCustomAttribute(cls, buf, offset):
+ n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
+ x = CustomAttribute()
+ x.Init(buf, n + offset)
+ return x
+
+ # CustomAttribute
+ def Init(self, buf, pos):
+ self._tab = flatbuffers.table.Table(buf, pos)
+
+ # CustomAttribute
+ def Identifier(self):
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
+ if o != 0:
+ return self._tab.String(o + self._tab.Pos)
+ return None
+
+def CustomAttributeStart(builder): builder.StartObject(1)
+def CustomAttributeAddIdentifier(builder, identifier): builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(identifier), 0)
+def CustomAttributeEnd(builder): return builder.EndObject()
diff --git a/verif/tosa/DType.py b/verif/tosa/DType.py
new file mode 100644
index 0000000..44d9970
--- /dev/null
+++ b/verif/tosa/DType.py
@@ -0,0 +1,31 @@
+# automatically generated by the FlatBuffers compiler, do not modify
+
+# Copyright (c) 2020, ARM Limited.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+# namespace: tosa
+
+class DType(object):
+ UNKNOWN = 0
+ BOOL = 1
+ AINT8 = 2
+ UINT8 = 3
+ INT4 = 4
+ INT8 = 5
+ INT16 = 6
+ INT32 = 7
+ INT48 = 8
+ FLOAT = 9
+
diff --git a/verif/tosa/Format.py b/verif/tosa/Format.py
new file mode 100644
index 0000000..5db4f27
--- /dev/null
+++ b/verif/tosa/Format.py
@@ -0,0 +1,27 @@
+# automatically generated by the FlatBuffers compiler, do not modify
+
+# Copyright (c) 2020, ARM Limited.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+# namespace: tosa
+
+class Format(object):
+ UNKNOWN = 0
+ NHWC = 1
+ NDHWC = 2
+ OHWI = 3
+ HWIM = 4
+ DOHWI = 5
+
diff --git a/verif/tosa/MatMulQuantInfo.py b/verif/tosa/MatMulQuantInfo.py
new file mode 100644
index 0000000..b8390a9
--- /dev/null
+++ b/verif/tosa/MatMulQuantInfo.py
@@ -0,0 +1,53 @@
+# automatically generated by the FlatBuffers compiler, do not modify
+
+# Copyright (c) 2020, ARM Limited.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+# namespace: tosa
+
+import flatbuffers
+
+class MatMulQuantInfo(object):
+ __slots__ = ['_tab']
+
+ @classmethod
+ def GetRootAsMatMulQuantInfo(cls, buf, offset):
+ n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
+ x = MatMulQuantInfo()
+ x.Init(buf, n + offset)
+ return x
+
+ # MatMulQuantInfo
+ def Init(self, buf, pos):
+ self._tab = flatbuffers.table.Table(buf, pos)
+
+ # MatMulQuantInfo
+ def AZp(self):
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
+ if o != 0:
+ return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos)
+ return 0
+
+ # MatMulQuantInfo
+ def BZp(self):
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
+ if o != 0:
+ return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos)
+ return 0
+
+def MatMulQuantInfoStart(builder): builder.StartObject(2)
+def MatMulQuantInfoAddAZp(builder, aZp): builder.PrependInt32Slot(0, aZp, 0)
+def MatMulQuantInfoAddBZp(builder, bZp): builder.PrependInt32Slot(1, bZp, 0)
+def MatMulQuantInfoEnd(builder): return builder.EndObject()
diff --git a/verif/tosa/Op.py b/verif/tosa/Op.py
new file mode 100644
index 0000000..09f1364
--- /dev/null
+++ b/verif/tosa/Op.py
@@ -0,0 +1,90 @@
+# automatically generated by the FlatBuffers compiler, do not modify
+
+# Copyright (c) 2020, ARM Limited.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+# namespace: tosa
+
+class Op(object):
+ UNKNOWN = 0
+ ARGMAX = 1
+ AVG_POOL2D = 2
+ CONV2D = 3
+ CONV3D = 4
+ DEPTHWISE_CONV2D = 5
+ FULLY_CONNECTED = 6
+ MATMUL = 7
+ MAX_POOL2D = 8
+ TRANSPOSE_CONV2D = 9
+ CLAMP = 10
+ RELUN = 11
+ SIGMOID = 12
+ TANH = 13
+ ADD = 14
+ ARITHMETIC_RIGHT_SHIFT = 15
+ BITWISE_AND = 16
+ BITWISE_OR = 17
+ BITWISE_XOR = 18
+ LOGICAL_AND = 19
+ LOGICAL_LEFT_SHIFT = 20
+ LOGICAL_RIGHT_SHIFT = 21
+ LOGICAL_OR = 22
+ LOGICAL_XOR = 23
+ MAXIMUM = 24
+ MINIMUM = 25
+ MUL = 26
+ POW = 27
+ SUB = 28
+ TABLE = 29
+ ABS = 30
+ BITWISE_NOT = 31
+ CEIL = 32
+ CLZ = 33
+ EXP = 34
+ FLOOR = 35
+ LOG = 36
+ LOGICAL_NOT = 37
+ NEGATE = 38
+ RECIPROCAL = 39
+ RSQRT = 40
+ SELECT = 41
+ EQUAL = 42
+ GREATER = 43
+ GREATER_EQUAL = 44
+ REDUCE_ANY = 45
+ REDUCE_ALL = 46
+ REDUCE_MAX = 47
+ REDUCE_MIN = 48
+ REDUCE_PRODUCT = 49
+ REDUCE_SUM = 50
+ CONCAT = 51
+ PAD = 52
+ RESHAPE = 53
+ REVERSE = 54
+ SLICE = 55
+ TILE = 56
+ TRANSPOSE = 57
+ GATHER = 58
+ RESIZE = 59
+ CAST = 60
+ RESCALE = 61
+ CONST = 62
+ PLACEHOLDER = 63
+ IDENTITY = 64
+ IDENTITYN = 65
+ CUSTOM = 66
+ COND_IF = 67
+ WHILE_LOOP = 68
+
diff --git a/verif/tosa/PadQuantInfo.py b/verif/tosa/PadQuantInfo.py
new file mode 100644
index 0000000..df61926
--- /dev/null
+++ b/verif/tosa/PadQuantInfo.py
@@ -0,0 +1,45 @@
+# automatically generated by the FlatBuffers compiler, do not modify
+
+# Copyright (c) 2020, ARM Limited.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+# namespace: tosa
+
+import flatbuffers
+
+class PadQuantInfo(object):
+ __slots__ = ['_tab']
+
+ @classmethod
+ def GetRootAsPadQuantInfo(cls, buf, offset):
+ n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
+ x = PadQuantInfo()
+ x.Init(buf, n + offset)
+ return x
+
+ # PadQuantInfo
+ def Init(self, buf, pos):
+ self._tab = flatbuffers.table.Table(buf, pos)
+
+ # PadQuantInfo
+ def InputZp(self):
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
+ if o != 0:
+ return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos)
+ return 0
+
+def PadQuantInfoStart(builder): builder.StartObject(1)
+def PadQuantInfoAddInputZp(builder, inputZp): builder.PrependInt32Slot(0, inputZp, 0)
+def PadQuantInfoEnd(builder): return builder.EndObject()
diff --git a/verif/tosa/Pool2dAttribute.py b/verif/tosa/Pool2dAttribute.py
new file mode 100644
index 0000000..1520de2
--- /dev/null
+++ b/verif/tosa/Pool2dAttribute.py
@@ -0,0 +1,109 @@
+# automatically generated by the FlatBuffers compiler, do not modify
+
+# Copyright (c) 2020, ARM Limited.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+# namespace: tosa
+
+import flatbuffers
+
+class Pool2dAttribute(object):
+ __slots__ = ['_tab']
+
+ @classmethod
+ def GetRootAsPool2dAttribute(cls, buf, offset):
+ n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
+ x = Pool2dAttribute()
+ x.Init(buf, n + offset)
+ return x
+
+ # Pool2dAttribute
+ def Init(self, buf, pos):
+ self._tab = flatbuffers.table.Table(buf, pos)
+
+ # Pool2dAttribute
+ def Padding(self, j):
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
+ if o != 0:
+ a = self._tab.Vector(o)
+ return self._tab.Get(flatbuffers.number_types.Int32Flags, a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4))
+ return 0
+
+ # Pool2dAttribute
+ def PaddingAsNumpy(self):
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
+ if o != 0:
+ return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Int32Flags, o)
+ return 0
+
+ # Pool2dAttribute
+ def PaddingLength(self):
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
+ if o != 0:
+ return self._tab.VectorLen(o)
+ return 0
+
+ # Pool2dAttribute
+ def Kernel(self, j):
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
+ if o != 0:
+ a = self._tab.Vector(o)
+ return self._tab.Get(flatbuffers.number_types.Int32Flags, a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4))
+ return 0
+
+ # Pool2dAttribute
+ def KernelAsNumpy(self):
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
+ if o != 0:
+ return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Int32Flags, o)
+ return 0
+
+ # Pool2dAttribute
+ def KernelLength(self):
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
+ if o != 0:
+ return self._tab.VectorLen(o)
+ return 0
+
+ # Pool2dAttribute
+ def Stride(self, j):
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8))
+ if o != 0:
+ a = self._tab.Vector(o)
+ return self._tab.Get(flatbuffers.number_types.Int32Flags, a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4))
+ return 0
+
+ # Pool2dAttribute
+ def StrideAsNumpy(self):
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8))
+ if o != 0:
+ return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Int32Flags, o)
+ return 0
+
+ # Pool2dAttribute
+ def StrideLength(self):
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8))
+ if o != 0:
+ return self._tab.VectorLen(o)
+ return 0
+
+def Pool2dAttributeStart(builder): builder.StartObject(3)
+def Pool2dAttributeAddPadding(builder, padding): builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(padding), 0)
+def Pool2dAttributeStartPaddingVector(builder, numElems): return builder.StartVector(4, numElems, 4)
+def Pool2dAttributeAddKernel(builder, kernel): builder.PrependUOffsetTRelativeSlot(1, flatbuffers.number_types.UOffsetTFlags.py_type(kernel), 0)
+def Pool2dAttributeStartKernelVector(builder, numElems): return builder.StartVector(4, numElems, 4)
+def Pool2dAttributeAddStride(builder, stride): builder.PrependUOffsetTRelativeSlot(2, flatbuffers.number_types.UOffsetTFlags.py_type(stride), 0)
+def Pool2dAttributeStartStrideVector(builder, numElems): return builder.StartVector(4, numElems, 4)
+def Pool2dAttributeEnd(builder): return builder.EndObject()
diff --git a/verif/tosa/QuantInfo.py b/verif/tosa/QuantInfo.py
new file mode 100644
index 0000000..0544cce
--- /dev/null
+++ b/verif/tosa/QuantInfo.py
@@ -0,0 +1,26 @@
+# automatically generated by the FlatBuffers compiler, do not modify
+
+# Copyright (c) 2020, ARM Limited.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+# namespace: tosa
+
+class QuantInfo(object):
+ NONE = 0
+ UnaryQuantInfo = 1
+ ConvQuantInfo = 2
+ MatMulQuantInfo = 3
+ PadQuantInfo = 4
+
diff --git a/verif/tosa/README.md b/verif/tosa/README.md
new file mode 100644
index 0000000..de8c1f9
--- /dev/null
+++ b/verif/tosa/README.md
@@ -0,0 +1,14 @@
+TOSA FlatBuffers python serialization library
+=============================================
+
+Files in this directory are automatically generated by running:
+
+``` bash
+../build/thirdparty/flatbuffers/flatc --python ../serialization/tosa.fbs
+```
+
+from the ``verif/`` directory. ``flatc`` is compiled along with the *TOSA
+Reference Model*.
+
+*Because they are automatically generated, please do not edit the
+python files in this directory by hand.*
diff --git a/verif/tosa/ReluNAttribute.py b/verif/tosa/ReluNAttribute.py
new file mode 100644
index 0000000..e446c03
--- /dev/null
+++ b/verif/tosa/ReluNAttribute.py
@@ -0,0 +1,53 @@
+# automatically generated by the FlatBuffers compiler, do not modify
+
+# Copyright (c) 2020, ARM Limited.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+# namespace: tosa
+
+import flatbuffers
+
+class ReluNAttribute(object):
+ __slots__ = ['_tab']
+
+ @classmethod
+ def GetRootAsReluNAttribute(cls, buf, offset):
+ n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
+ x = ReluNAttribute()
+ x.Init(buf, n + offset)
+ return x
+
+ # ReluNAttribute
+ def Init(self, buf, pos):
+ self._tab = flatbuffers.table.Table(buf, pos)
+
+ # ReluNAttribute
+ def MaxInt(self):
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
+ if o != 0:
+ return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos)
+ return 0
+
+ # ReluNAttribute
+ def MaxFp(self):
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
+ if o != 0:
+ return self._tab.Get(flatbuffers.number_types.Float32Flags, o + self._tab.Pos)
+ return 0.0
+
+def ReluNAttributeStart(builder): builder.StartObject(2)
+def ReluNAttributeAddMaxInt(builder, maxInt): builder.PrependInt32Slot(0, maxInt, 0)
+def ReluNAttributeAddMaxFp(builder, maxFp): builder.PrependFloat32Slot(1, maxFp, 0.0)
+def ReluNAttributeEnd(builder): return builder.EndObject()
diff --git a/verif/tosa/RescaleAttribute.py b/verif/tosa/RescaleAttribute.py
new file mode 100644
index 0000000..0ec8c2b
--- /dev/null
+++ b/verif/tosa/RescaleAttribute.py
@@ -0,0 +1,125 @@
+# automatically generated by the FlatBuffers compiler, do not modify
+
+# Copyright (c) 2020, ARM Limited.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+# namespace: tosa
+
+import flatbuffers
+
+class RescaleAttribute(object):
+ __slots__ = ['_tab']
+
+ @classmethod
+ def GetRootAsRescaleAttribute(cls, buf, offset):
+ n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
+ x = RescaleAttribute()
+ x.Init(buf, n + offset)
+ return x
+
+ # RescaleAttribute
+ def Init(self, buf, pos):
+ self._tab = flatbuffers.table.Table(buf, pos)
+
+ # RescaleAttribute
+ def InputZp(self):
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
+ if o != 0:
+ return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos)
+ return 0
+
+ # RescaleAttribute
+ def OutputZp(self):
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
+ if o != 0:
+ return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos)
+ return 0
+
+ # RescaleAttribute
+ def Multiplier(self, j):
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8))
+ if o != 0:
+ a = self._tab.Vector(o)
+ return self._tab.Get(flatbuffers.number_types.Int32Flags, a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4))
+ return 0
+
+ # RescaleAttribute
+ def MultiplierAsNumpy(self):
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8))
+ if o != 0:
+ return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Int32Flags, o)
+ return 0
+
+ # RescaleAttribute
+ def MultiplierLength(self):
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8))
+ if o != 0:
+ return self._tab.VectorLen(o)
+ return 0
+
+ # RescaleAttribute
+ def Shift(self, j):
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10))
+ if o != 0:
+ a = self._tab.Vector(o)
+ return self._tab.Get(flatbuffers.number_types.Int32Flags, a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4))
+ return 0
+
+ # RescaleAttribute
+ def ShiftAsNumpy(self):
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10))
+ if o != 0:
+ return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Int32Flags, o)
+ return 0
+
+ # RescaleAttribute
+ def ShiftLength(self):
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10))
+ if o != 0:
+ return self._tab.VectorLen(o)
+ return 0
+
+ # RescaleAttribute
+ def Scale32(self):
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(12))
+ if o != 0:
+ return bool(self._tab.Get(flatbuffers.number_types.BoolFlags, o + self._tab.Pos))
+ return False
+
+ # RescaleAttribute
+ def DoubleRound(self):
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(14))
+ if o != 0:
+ return bool(self._tab.Get(flatbuffers.number_types.BoolFlags, o + self._tab.Pos))
+ return False
+
+ # RescaleAttribute
+ def PerChannel(self):
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(16))
+ if o != 0:
+ return bool(self._tab.Get(flatbuffers.number_types.BoolFlags, o + self._tab.Pos))
+ return False
+
+def RescaleAttributeStart(builder): builder.StartObject(7)
+def RescaleAttributeAddInputZp(builder, inputZp): builder.PrependInt32Slot(0, inputZp, 0)
+def RescaleAttributeAddOutputZp(builder, outputZp): builder.PrependInt32Slot(1, outputZp, 0)
+def RescaleAttributeAddMultiplier(builder, multiplier): builder.PrependUOffsetTRelativeSlot(2, flatbuffers.number_types.UOffsetTFlags.py_type(multiplier), 0)
+def RescaleAttributeStartMultiplierVector(builder, numElems): return builder.StartVector(4, numElems, 4)
+def RescaleAttributeAddShift(builder, shift): builder.PrependUOffsetTRelativeSlot(3, flatbuffers.number_types.UOffsetTFlags.py_type(shift), 0)
+def RescaleAttributeStartShiftVector(builder, numElems): return builder.StartVector(4, numElems, 4)
+def RescaleAttributeAddScale32(builder, scale32): builder.PrependBoolSlot(4, scale32, 0)
+def RescaleAttributeAddDoubleRound(builder, doubleRound): builder.PrependBoolSlot(5, doubleRound, 0)
+def RescaleAttributeAddPerChannel(builder, perChannel): builder.PrependBoolSlot(6, perChannel, 0)
+def RescaleAttributeEnd(builder): return builder.EndObject()
diff --git a/verif/tosa/ReshapeAttribute.py b/verif/tosa/ReshapeAttribute.py
new file mode 100644
index 0000000..2c50cef
--- /dev/null
+++ b/verif/tosa/ReshapeAttribute.py
@@ -0,0 +1,61 @@
+# automatically generated by the FlatBuffers compiler, do not modify
+
+# Copyright (c) 2020, ARM Limited.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+# namespace: tosa
+
+import flatbuffers
+
+class ReshapeAttribute(object):
+ __slots__ = ['_tab']
+
+ @classmethod
+ def GetRootAsReshapeAttribute(cls, buf, offset):
+ n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
+ x = ReshapeAttribute()
+ x.Init(buf, n + offset)
+ return x
+
+ # ReshapeAttribute
+ def Init(self, buf, pos):
+ self._tab = flatbuffers.table.Table(buf, pos)
+
+ # ReshapeAttribute
+ def Shape(self, j):
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
+ if o != 0:
+ a = self._tab.Vector(o)
+ return self._tab.Get(flatbuffers.number_types.Int32Flags, a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4))
+ return 0
+
+ # ReshapeAttribute
+ def ShapeAsNumpy(self):
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
+ if o != 0:
+ return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Int32Flags, o)
+ return 0
+
+ # ReshapeAttribute
+ def ShapeLength(self):
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
+ if o != 0:
+ return self._tab.VectorLen(o)
+ return 0
+
+def ReshapeAttributeStart(builder): builder.StartObject(1)
+def ReshapeAttributeAddShape(builder, shape): builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(shape), 0)
+def ReshapeAttributeStartShapeVector(builder, numElems): return builder.StartVector(4, numElems, 4)
+def ReshapeAttributeEnd(builder): return builder.EndObject()
diff --git a/verif/tosa/ResizeAttribute.py b/verif/tosa/ResizeAttribute.py
new file mode 100644
index 0000000..1e6941f
--- /dev/null
+++ b/verif/tosa/ResizeAttribute.py
@@ -0,0 +1,125 @@
+# automatically generated by the FlatBuffers compiler, do not modify
+
+# Copyright (c) 2020, ARM Limited.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+# namespace: tosa
+
+import flatbuffers
+
+class ResizeAttribute(object):
+ __slots__ = ['_tab']
+
+ @classmethod
+ def GetRootAsResizeAttribute(cls, buf, offset):
+ n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
+ x = ResizeAttribute()
+ x.Init(buf, n + offset)
+ return x
+
+ # ResizeAttribute
+ def Init(self, buf, pos):
+ self._tab = flatbuffers.table.Table(buf, pos)
+
+ # ResizeAttribute
+ def OutputSize(self, j):
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
+ if o != 0:
+ a = self._tab.Vector(o)
+ return self._tab.Get(flatbuffers.number_types.Int32Flags, a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4))
+ return 0
+
+ # ResizeAttribute
+ def OutputSizeAsNumpy(self):
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
+ if o != 0:
+ return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Int32Flags, o)
+ return 0
+
+ # ResizeAttribute
+ def OutputSizeLength(self):
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
+ if o != 0:
+ return self._tab.VectorLen(o)
+ return 0
+
+ # ResizeAttribute
+ def Stride(self, j):
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
+ if o != 0:
+ a = self._tab.Vector(o)
+ return self._tab.Get(flatbuffers.number_types.Int32Flags, a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4))
+ return 0
+
+ # ResizeAttribute
+ def StrideAsNumpy(self):
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
+ if o != 0:
+ return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Int32Flags, o)
+ return 0
+
+ # ResizeAttribute
+ def StrideLength(self):
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
+ if o != 0:
+ return self._tab.VectorLen(o)
+ return 0
+
+ # ResizeAttribute
+ def Offset(self, j):
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8))
+ if o != 0:
+ a = self._tab.Vector(o)
+ return self._tab.Get(flatbuffers.number_types.Int32Flags, a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4))
+ return 0
+
+ # ResizeAttribute
+ def OffsetAsNumpy(self):
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8))
+ if o != 0:
+ return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Int32Flags, o)
+ return 0
+
+ # ResizeAttribute
+ def OffsetLength(self):
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8))
+ if o != 0:
+ return self._tab.VectorLen(o)
+ return 0
+
+ # ResizeAttribute
+ def Shift(self):
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10))
+ if o != 0:
+ return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos)
+ return 0
+
+ # ResizeAttribute
+ def Mode(self):
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(12))
+ if o != 0:
+ return self._tab.Get(flatbuffers.number_types.Uint32Flags, o + self._tab.Pos)
+ return 0
+
+def ResizeAttributeStart(builder): builder.StartObject(5)
+def ResizeAttributeAddOutputSize(builder, outputSize): builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(outputSize), 0)
+def ResizeAttributeStartOutputSizeVector(builder, numElems): return builder.StartVector(4, numElems, 4)
+def ResizeAttributeAddStride(builder, stride): builder.PrependUOffsetTRelativeSlot(1, flatbuffers.number_types.UOffsetTFlags.py_type(stride), 0)
+def ResizeAttributeStartStrideVector(builder, numElems): return builder.StartVector(4, numElems, 4)
+def ResizeAttributeAddOffset(builder, offset): builder.PrependUOffsetTRelativeSlot(2, flatbuffers.number_types.UOffsetTFlags.py_type(offset), 0)
+def ResizeAttributeStartOffsetVector(builder, numElems): return builder.StartVector(4, numElems, 4)
+def ResizeAttributeAddShift(builder, shift): builder.PrependInt32Slot(3, shift, 0)
+def ResizeAttributeAddMode(builder, mode): builder.PrependUint32Slot(4, mode, 0)
+def ResizeAttributeEnd(builder): return builder.EndObject()
diff --git a/verif/tosa/ResizeMode.py b/verif/tosa/ResizeMode.py
new file mode 100644
index 0000000..02bed51
--- /dev/null
+++ b/verif/tosa/ResizeMode.py
@@ -0,0 +1,24 @@
+# automatically generated by the FlatBuffers compiler, do not modify
+
+# Copyright (c) 2020, ARM Limited.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+# namespace: tosa
+
+class ResizeMode(object):
+ UNKNOWN = 0
+ NEAREST = 1
+ BILINEAR = 2
+
diff --git a/verif/tosa/SliceAttribute.py b/verif/tosa/SliceAttribute.py
new file mode 100644
index 0000000..d156a4a
--- /dev/null
+++ b/verif/tosa/SliceAttribute.py
@@ -0,0 +1,85 @@
+# automatically generated by the FlatBuffers compiler, do not modify
+
+# Copyright (c) 2020, ARM Limited.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+# namespace: tosa
+
+import flatbuffers
+
+class SliceAttribute(object):
+ __slots__ = ['_tab']
+
+ @classmethod
+ def GetRootAsSliceAttribute(cls, buf, offset):
+ n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
+ x = SliceAttribute()
+ x.Init(buf, n + offset)
+ return x
+
+ # SliceAttribute
+ def Init(self, buf, pos):
+ self._tab = flatbuffers.table.Table(buf, pos)
+
+ # SliceAttribute
+ def Begin(self, j):
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
+ if o != 0:
+ a = self._tab.Vector(o)
+ return self._tab.Get(flatbuffers.number_types.Int32Flags, a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4))
+ return 0
+
+ # SliceAttribute
+ def BeginAsNumpy(self):
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
+ if o != 0:
+ return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Int32Flags, o)
+ return 0
+
+ # SliceAttribute
+ def BeginLength(self):
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
+ if o != 0:
+ return self._tab.VectorLen(o)
+ return 0
+
+ # SliceAttribute
+ def Size(self, j):
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
+ if o != 0:
+ a = self._tab.Vector(o)
+ return self._tab.Get(flatbuffers.number_types.Int32Flags, a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4))
+ return 0
+
+ # SliceAttribute
+ def SizeAsNumpy(self):
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
+ if o != 0:
+ return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Int32Flags, o)
+ return 0
+
+ # SliceAttribute
+ def SizeLength(self):
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
+ if o != 0:
+ return self._tab.VectorLen(o)
+ return 0
+
+def SliceAttributeStart(builder): builder.StartObject(2)
+def SliceAttributeAddBegin(builder, begin): builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(begin), 0)
+def SliceAttributeStartBeginVector(builder, numElems): return builder.StartVector(4, numElems, 4)
+def SliceAttributeAddSize(builder, size): builder.PrependUOffsetTRelativeSlot(1, flatbuffers.number_types.UOffsetTFlags.py_type(size), 0)
+def SliceAttributeStartSizeVector(builder, numElems): return builder.StartVector(4, numElems, 4)
+def SliceAttributeEnd(builder): return builder.EndObject()
diff --git a/verif/tosa/TileAttribute.py b/verif/tosa/TileAttribute.py
new file mode 100644
index 0000000..6385edd
--- /dev/null
+++ b/verif/tosa/TileAttribute.py
@@ -0,0 +1,61 @@
+# automatically generated by the FlatBuffers compiler, do not modify
+
+# Copyright (c) 2020, ARM Limited.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+# namespace: tosa
+
+import flatbuffers
+
+class TileAttribute(object):
+ __slots__ = ['_tab']
+
+ @classmethod
+ def GetRootAsTileAttribute(cls, buf, offset):
+ n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
+ x = TileAttribute()
+ x.Init(buf, n + offset)
+ return x
+
+ # TileAttribute
+ def Init(self, buf, pos):
+ self._tab = flatbuffers.table.Table(buf, pos)
+
+ # TileAttribute
+ def Multiples(self, j):
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
+ if o != 0:
+ a = self._tab.Vector(o)
+ return self._tab.Get(flatbuffers.number_types.Int32Flags, a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4))
+ return 0
+
+ # TileAttribute
+ def MultiplesAsNumpy(self):
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
+ if o != 0:
+ return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Int32Flags, o)
+ return 0
+
+ # TileAttribute
+ def MultiplesLength(self):
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
+ if o != 0:
+ return self._tab.VectorLen(o)
+ return 0
+
+def TileAttributeStart(builder): builder.StartObject(1)
+def TileAttributeAddMultiples(builder, multiples): builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(multiples), 0)
+def TileAttributeStartMultiplesVector(builder, numElems): return builder.StartVector(4, numElems, 4)
+def TileAttributeEnd(builder): return builder.EndObject()
diff --git a/verif/tosa/TosaBasicBlock.py b/verif/tosa/TosaBasicBlock.py
new file mode 100644
index 0000000..42a7379
--- /dev/null
+++ b/verif/tosa/TosaBasicBlock.py
@@ -0,0 +1,123 @@
+# automatically generated by the FlatBuffers compiler, do not modify
+
+# Copyright (c) 2020, ARM Limited.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+# namespace: tosa
+
+import flatbuffers
+
+class TosaBasicBlock(object):
+ __slots__ = ['_tab']
+
+ @classmethod
+ def GetRootAsTosaBasicBlock(cls, buf, offset):
+ n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
+ x = TosaBasicBlock()
+ x.Init(buf, n + offset)
+ return x
+
+ # TosaBasicBlock
+ def Init(self, buf, pos):
+ self._tab = flatbuffers.table.Table(buf, pos)
+
+ # TosaBasicBlock
+ def Name(self):
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
+ if o != 0:
+ return self._tab.String(o + self._tab.Pos)
+ return None
+
+ # TosaBasicBlock
+ def Operators(self, j):
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
+ if o != 0:
+ x = self._tab.Vector(o)
+ x += flatbuffers.number_types.UOffsetTFlags.py_type(j) * 4
+ x = self._tab.Indirect(x)
+ from .TosaOperator import TosaOperator
+ obj = TosaOperator()
+ obj.Init(self._tab.Bytes, x)
+ return obj
+ return None
+
+ # TosaBasicBlock
+ def OperatorsLength(self):
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
+ if o != 0:
+ return self._tab.VectorLen(o)
+ return 0
+
+ # TosaBasicBlock
+ def Tensors(self, j):
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8))
+ if o != 0:
+ x = self._tab.Vector(o)
+ x += flatbuffers.number_types.UOffsetTFlags.py_type(j) * 4
+ x = self._tab.Indirect(x)
+ from .TosaTensor import TosaTensor
+ obj = TosaTensor()
+ obj.Init(self._tab.Bytes, x)
+ return obj
+ return None
+
+ # TosaBasicBlock
+ def TensorsLength(self):
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8))
+ if o != 0:
+ return self._tab.VectorLen(o)
+ return 0
+
+ # TosaBasicBlock
+ def Inputs(self, j):
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10))
+ if o != 0:
+ a = self._tab.Vector(o)
+ return self._tab.String(a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4))
+ return ""
+
+ # TosaBasicBlock
+ def InputsLength(self):
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10))
+ if o != 0:
+ return self._tab.VectorLen(o)
+ return 0
+
+ # TosaBasicBlock
+ def Outputs(self, j):
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(12))
+ if o != 0:
+ a = self._tab.Vector(o)
+ return self._tab.String(a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4))
+ return ""
+
+ # TosaBasicBlock
+ def OutputsLength(self):
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(12))
+ if o != 0:
+ return self._tab.VectorLen(o)
+ return 0
+
+def TosaBasicBlockStart(builder): builder.StartObject(5)
+def TosaBasicBlockAddName(builder, name): builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(name), 0)
+def TosaBasicBlockAddOperators(builder, operators): builder.PrependUOffsetTRelativeSlot(1, flatbuffers.number_types.UOffsetTFlags.py_type(operators), 0)
+def TosaBasicBlockStartOperatorsVector(builder, numElems): return builder.StartVector(4, numElems, 4)
+def TosaBasicBlockAddTensors(builder, tensors): builder.PrependUOffsetTRelativeSlot(2, flatbuffers.number_types.UOffsetTFlags.py_type(tensors), 0)
+def TosaBasicBlockStartTensorsVector(builder, numElems): return builder.StartVector(4, numElems, 4)
+def TosaBasicBlockAddInputs(builder, inputs): builder.PrependUOffsetTRelativeSlot(3, flatbuffers.number_types.UOffsetTFlags.py_type(inputs), 0)
+def TosaBasicBlockStartInputsVector(builder, numElems): return builder.StartVector(4, numElems, 4)
+def TosaBasicBlockAddOutputs(builder, outputs): builder.PrependUOffsetTRelativeSlot(4, flatbuffers.number_types.UOffsetTFlags.py_type(outputs), 0)
+def TosaBasicBlockStartOutputsVector(builder, numElems): return builder.StartVector(4, numElems, 4)
+def TosaBasicBlockEnd(builder): return builder.EndObject()
diff --git a/verif/tosa/TosaGraph.py b/verif/tosa/TosaGraph.py
new file mode 100644
index 0000000..92568b9
--- /dev/null
+++ b/verif/tosa/TosaGraph.py
@@ -0,0 +1,71 @@
+# automatically generated by the FlatBuffers compiler, do not modify
+
+# Copyright (c) 2020, ARM Limited.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+# namespace: tosa
+
+import flatbuffers
+
+class TosaGraph(object):
+ __slots__ = ['_tab']
+
+ @classmethod
+ def GetRootAsTosaGraph(cls, buf, offset):
+ n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
+ x = TosaGraph()
+ x.Init(buf, n + offset)
+ return x
+
+ # TosaGraph
+ def Init(self, buf, pos):
+ self._tab = flatbuffers.table.Table(buf, pos)
+
+ # TosaGraph
+ def Version(self):
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
+ if o != 0:
+ x = self._tab.Indirect(o + self._tab.Pos)
+ from .Version import Version
+ obj = Version()
+ obj.Init(self._tab.Bytes, x)
+ return obj
+ return None
+
+ # TosaGraph
+ def Blocks(self, j):
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
+ if o != 0:
+ x = self._tab.Vector(o)
+ x += flatbuffers.number_types.UOffsetTFlags.py_type(j) * 4
+ x = self._tab.Indirect(x)
+ from .TosaBasicBlock import TosaBasicBlock
+ obj = TosaBasicBlock()
+ obj.Init(self._tab.Bytes, x)
+ return obj
+ return None
+
+ # TosaGraph
+ def BlocksLength(self):
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
+ if o != 0:
+ return self._tab.VectorLen(o)
+ return 0
+
+def TosaGraphStart(builder): builder.StartObject(2)
+def TosaGraphAddVersion(builder, version): builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(version), 0)
+def TosaGraphAddBlocks(builder, blocks): builder.PrependUOffsetTRelativeSlot(1, flatbuffers.number_types.UOffsetTFlags.py_type(blocks), 0)
+def TosaGraphStartBlocksVector(builder, numElems): return builder.StartVector(4, numElems, 4)
+def TosaGraphEnd(builder): return builder.EndObject()
diff --git a/verif/tosa/TosaOperator.py b/verif/tosa/TosaOperator.py
new file mode 100644
index 0000000..ab4a160
--- /dev/null
+++ b/verif/tosa/TosaOperator.py
@@ -0,0 +1,117 @@
+# automatically generated by the FlatBuffers compiler, do not modify
+
+# Copyright (c) 2020, ARM Limited.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+# namespace: tosa
+
+import flatbuffers
+
+class TosaOperator(object):
+ __slots__ = ['_tab']
+
+ @classmethod
+ def GetRootAsTosaOperator(cls, buf, offset):
+ n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
+ x = TosaOperator()
+ x.Init(buf, n + offset)
+ return x
+
+ # TosaOperator
+ def Init(self, buf, pos):
+ self._tab = flatbuffers.table.Table(buf, pos)
+
+ # TosaOperator
+ def Op(self):
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
+ if o != 0:
+ return self._tab.Get(flatbuffers.number_types.Uint32Flags, o + self._tab.Pos)
+ return 0
+
+ # TosaOperator
+ def AttributeType(self):
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
+ if o != 0:
+ return self._tab.Get(flatbuffers.number_types.Uint8Flags, o + self._tab.Pos)
+ return 0
+
+ # TosaOperator
+ def Attribute(self):
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8))
+ if o != 0:
+ from flatbuffers.table import Table
+ obj = Table(bytearray(), 0)
+ self._tab.Union(obj, o)
+ return obj
+ return None
+
+ # TosaOperator
+ def Inputs(self, j):
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10))
+ if o != 0:
+ a = self._tab.Vector(o)
+ return self._tab.String(a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4))
+ return ""
+
+ # TosaOperator
+ def InputsLength(self):
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10))
+ if o != 0:
+ return self._tab.VectorLen(o)
+ return 0
+
+ # TosaOperator
+ def Outputs(self, j):
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(12))
+ if o != 0:
+ a = self._tab.Vector(o)
+ return self._tab.String(a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4))
+ return ""
+
+ # TosaOperator
+ def OutputsLength(self):
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(12))
+ if o != 0:
+ return self._tab.VectorLen(o)
+ return 0
+
+ # TosaOperator
+ def QuantInfoType(self):
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(14))
+ if o != 0:
+ return self._tab.Get(flatbuffers.number_types.Uint8Flags, o + self._tab.Pos)
+ return 0
+
+ # TosaOperator
+ def QuantInfo(self):
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(16))
+ if o != 0:
+ from flatbuffers.table import Table
+ obj = Table(bytearray(), 0)
+ self._tab.Union(obj, o)
+ return obj
+ return None
+
+def TosaOperatorStart(builder): builder.StartObject(7)
+def TosaOperatorAddOp(builder, op): builder.PrependUint32Slot(0, op, 0)
+def TosaOperatorAddAttributeType(builder, attributeType): builder.PrependUint8Slot(1, attributeType, 0)
+def TosaOperatorAddAttribute(builder, attribute): builder.PrependUOffsetTRelativeSlot(2, flatbuffers.number_types.UOffsetTFlags.py_type(attribute), 0)
+def TosaOperatorAddInputs(builder, inputs): builder.PrependUOffsetTRelativeSlot(3, flatbuffers.number_types.UOffsetTFlags.py_type(inputs), 0)
+def TosaOperatorStartInputsVector(builder, numElems): return builder.StartVector(4, numElems, 4)
+def TosaOperatorAddOutputs(builder, outputs): builder.PrependUOffsetTRelativeSlot(4, flatbuffers.number_types.UOffsetTFlags.py_type(outputs), 0)
+def TosaOperatorStartOutputsVector(builder, numElems): return builder.StartVector(4, numElems, 4)
+def TosaOperatorAddQuantInfoType(builder, quantInfoType): builder.PrependUint8Slot(5, quantInfoType, 0)
+def TosaOperatorAddQuantInfo(builder, quantInfo): builder.PrependUOffsetTRelativeSlot(6, flatbuffers.number_types.UOffsetTFlags.py_type(quantInfo), 0)
+def TosaOperatorEnd(builder): return builder.EndObject()
diff --git a/verif/tosa/TosaTensor.py b/verif/tosa/TosaTensor.py
new file mode 100644
index 0000000..0b30266
--- /dev/null
+++ b/verif/tosa/TosaTensor.py
@@ -0,0 +1,133 @@
+# automatically generated by the FlatBuffers compiler, do not modify
+
+# Copyright (c) 2020, ARM Limited.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+# namespace: tosa
+
+import flatbuffers
+
+class TosaTensor(object):
+ __slots__ = ['_tab']
+
+ @classmethod
+ def GetRootAsTosaTensor(cls, buf, offset):
+ n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
+ x = TosaTensor()
+ x.Init(buf, n + offset)
+ return x
+
+ # TosaTensor
+ def Init(self, buf, pos):
+ self._tab = flatbuffers.table.Table(buf, pos)
+
+ # TosaTensor
+ def Name(self):
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
+ if o != 0:
+ return self._tab.String(o + self._tab.Pos)
+ return None
+
+ # TosaTensor
+ def Shape(self, j):
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
+ if o != 0:
+ a = self._tab.Vector(o)
+ return self._tab.Get(flatbuffers.number_types.Int32Flags, a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4))
+ return 0
+
+ # TosaTensor
+ def ShapeAsNumpy(self):
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
+ if o != 0:
+ return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Int32Flags, o)
+ return 0
+
+ # TosaTensor
+ def ShapeLength(self):
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
+ if o != 0:
+ return self._tab.VectorLen(o)
+ return 0
+
+ # TosaTensor
+ def Type(self):
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8))
+ if o != 0:
+ return self._tab.Get(flatbuffers.number_types.Uint32Flags, o + self._tab.Pos)
+ return 0
+
+ # TosaTensor
+ def Usage(self, j):
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10))
+ if o != 0:
+ a = self._tab.Vector(o)
+ return self._tab.Get(flatbuffers.number_types.Uint32Flags, a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4))
+ return 0
+
+ # TosaTensor
+ def UsageAsNumpy(self):
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10))
+ if o != 0:
+ return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Uint32Flags, o)
+ return 0
+
+ # TosaTensor
+ def UsageLength(self):
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10))
+ if o != 0:
+ return self._tab.VectorLen(o)
+ return 0
+
+ # TosaTensor
+ def Format(self, j):
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(12))
+ if o != 0:
+ a = self._tab.Vector(o)
+ return self._tab.Get(flatbuffers.number_types.Uint32Flags, a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4))
+ return 0
+
+ # TosaTensor
+ def FormatAsNumpy(self):
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(12))
+ if o != 0:
+ return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Uint32Flags, o)
+ return 0
+
+ # TosaTensor
+ def FormatLength(self):
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(12))
+ if o != 0:
+ return self._tab.VectorLen(o)
+ return 0
+
+ # TosaTensor
+ def NpyFilename(self):
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(14))
+ if o != 0:
+ return self._tab.String(o + self._tab.Pos)
+ return None
+
+def TosaTensorStart(builder): builder.StartObject(6)
+def TosaTensorAddName(builder, name): builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(name), 0)
+def TosaTensorAddShape(builder, shape): builder.PrependUOffsetTRelativeSlot(1, flatbuffers.number_types.UOffsetTFlags.py_type(shape), 0)
+def TosaTensorStartShapeVector(builder, numElems): return builder.StartVector(4, numElems, 4)
+def TosaTensorAddType(builder, type): builder.PrependUint32Slot(2, type, 0)
+def TosaTensorAddUsage(builder, usage): builder.PrependUOffsetTRelativeSlot(3, flatbuffers.number_types.UOffsetTFlags.py_type(usage), 0)
+def TosaTensorStartUsageVector(builder, numElems): return builder.StartVector(4, numElems, 4)
+def TosaTensorAddFormat(builder, format): builder.PrependUOffsetTRelativeSlot(4, flatbuffers.number_types.UOffsetTFlags.py_type(format), 0)
+def TosaTensorStartFormatVector(builder, numElems): return builder.StartVector(4, numElems, 4)
+def TosaTensorAddNpyFilename(builder, npyFilename): builder.PrependUOffsetTRelativeSlot(5, flatbuffers.number_types.UOffsetTFlags.py_type(npyFilename), 0)
+def TosaTensorEnd(builder): return builder.EndObject()
diff --git a/verif/tosa/TransposeConv2dAttribute.py b/verif/tosa/TransposeConv2dAttribute.py
new file mode 100644
index 0000000..043d8e8
--- /dev/null
+++ b/verif/tosa/TransposeConv2dAttribute.py
@@ -0,0 +1,133 @@
+# automatically generated by the FlatBuffers compiler, do not modify
+
+# Copyright (c) 2020, ARM Limited.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+# namespace: tosa
+
+import flatbuffers
+
+class TransposeConv2dAttribute(object):
+ __slots__ = ['_tab']
+
+ @classmethod
+ def GetRootAsTransposeConv2dAttribute(cls, buf, offset):
+ n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
+ x = TransposeConv2dAttribute()
+ x.Init(buf, n + offset)
+ return x
+
+ # TransposeConv2dAttribute
+ def Init(self, buf, pos):
+ self._tab = flatbuffers.table.Table(buf, pos)
+
+ # TransposeConv2dAttribute
+ def Outpad(self, j):
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
+ if o != 0:
+ a = self._tab.Vector(o)
+ return self._tab.Get(flatbuffers.number_types.Int32Flags, a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4))
+ return 0
+
+ # TransposeConv2dAttribute
+ def OutpadAsNumpy(self):
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
+ if o != 0:
+ return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Int32Flags, o)
+ return 0
+
+ # TransposeConv2dAttribute
+ def OutpadLength(self):
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
+ if o != 0:
+ return self._tab.VectorLen(o)
+ return 0
+
+ # TransposeConv2dAttribute
+ def Stride(self, j):
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
+ if o != 0:
+ a = self._tab.Vector(o)
+ return self._tab.Get(flatbuffers.number_types.Int32Flags, a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4))
+ return 0
+
+ # TransposeConv2dAttribute
+ def StrideAsNumpy(self):
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
+ if o != 0:
+ return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Int32Flags, o)
+ return 0
+
+ # TransposeConv2dAttribute
+ def StrideLength(self):
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
+ if o != 0:
+ return self._tab.VectorLen(o)
+ return 0
+
+ # TransposeConv2dAttribute
+ def Dilation(self, j):
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8))
+ if o != 0:
+ a = self._tab.Vector(o)
+ return self._tab.Get(flatbuffers.number_types.Int32Flags, a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4))
+ return 0
+
+ # TransposeConv2dAttribute
+ def DilationAsNumpy(self):
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8))
+ if o != 0:
+ return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Int32Flags, o)
+ return 0
+
+ # TransposeConv2dAttribute
+ def DilationLength(self):
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8))
+ if o != 0:
+ return self._tab.VectorLen(o)
+ return 0
+
+ # TransposeConv2dAttribute
+ def OutputShape(self, j):
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10))
+ if o != 0:
+ a = self._tab.Vector(o)
+ return self._tab.Get(flatbuffers.number_types.Int32Flags, a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4))
+ return 0
+
+ # TransposeConv2dAttribute
+ def OutputShapeAsNumpy(self):
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10))
+ if o != 0:
+ return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Int32Flags, o)
+ return 0
+
+ # TransposeConv2dAttribute
+ def OutputShapeLength(self):
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10))
+ if o != 0:
+ return self._tab.VectorLen(o)
+ return 0
+
+def TransposeConv2dAttributeStart(builder): builder.StartObject(4)
+def TransposeConv2dAttributeAddOutpad(builder, outpad): builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(outpad), 0)
+def TransposeConv2dAttributeStartOutpadVector(builder, numElems): return builder.StartVector(4, numElems, 4)
+def TransposeConv2dAttributeAddStride(builder, stride): builder.PrependUOffsetTRelativeSlot(1, flatbuffers.number_types.UOffsetTFlags.py_type(stride), 0)
+def TransposeConv2dAttributeStartStrideVector(builder, numElems): return builder.StartVector(4, numElems, 4)
+def TransposeConv2dAttributeAddDilation(builder, dilation): builder.PrependUOffsetTRelativeSlot(2, flatbuffers.number_types.UOffsetTFlags.py_type(dilation), 0)
+def TransposeConv2dAttributeStartDilationVector(builder, numElems): return builder.StartVector(4, numElems, 4)
+def TransposeConv2dAttributeAddOutputShape(builder, outputShape): builder.PrependUOffsetTRelativeSlot(3, flatbuffers.number_types.UOffsetTFlags.py_type(outputShape), 0)
+def TransposeConv2dAttributeStartOutputShapeVector(builder, numElems): return builder.StartVector(4, numElems, 4)
+def TransposeConv2dAttributeEnd(builder): return builder.EndObject()
diff --git a/verif/tosa/UnaryQuantInfo.py b/verif/tosa/UnaryQuantInfo.py
new file mode 100644
index 0000000..9ae0214
--- /dev/null
+++ b/verif/tosa/UnaryQuantInfo.py
@@ -0,0 +1,53 @@
+# automatically generated by the FlatBuffers compiler, do not modify
+
+# Copyright (c) 2020, ARM Limited.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+# namespace: tosa
+
+import flatbuffers
+
+class UnaryQuantInfo(object):
+ __slots__ = ['_tab']
+
+ @classmethod
+ def GetRootAsUnaryQuantInfo(cls, buf, offset):
+ n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
+ x = UnaryQuantInfo()
+ x.Init(buf, n + offset)
+ return x
+
+ # UnaryQuantInfo
+ def Init(self, buf, pos):
+ self._tab = flatbuffers.table.Table(buf, pos)
+
+ # UnaryQuantInfo
+ def InputZp(self):
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
+ if o != 0:
+ return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos)
+ return 0
+
+ # UnaryQuantInfo
+ def OutputZp(self):
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
+ if o != 0:
+ return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos)
+ return 0
+
+def UnaryQuantInfoStart(builder): builder.StartObject(2)
+def UnaryQuantInfoAddInputZp(builder, inputZp): builder.PrependInt32Slot(0, inputZp, 0)
+def UnaryQuantInfoAddOutputZp(builder, outputZp): builder.PrependInt32Slot(1, outputZp, 0)
+def UnaryQuantInfoEnd(builder): return builder.EndObject()
diff --git a/verif/tosa/Usage.py b/verif/tosa/Usage.py
new file mode 100644
index 0000000..4c42daa
--- /dev/null
+++ b/verif/tosa/Usage.py
@@ -0,0 +1,25 @@
+# automatically generated by the FlatBuffers compiler, do not modify
+
+# Copyright (c) 2020, ARM Limited.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+# namespace: tosa
+
+class Usage(object):
+ UNKNOWN = 0
+ ACTIVATION = 1
+ WEIGHT = 2
+ INDEX = 3
+
diff --git a/verif/tosa/Version.py b/verif/tosa/Version.py
new file mode 100644
index 0000000..ddfdb2d
--- /dev/null
+++ b/verif/tosa/Version.py
@@ -0,0 +1,69 @@
+# automatically generated by the FlatBuffers compiler, do not modify
+
+# Copyright (c) 2020, ARM Limited.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+# namespace: tosa
+
+import flatbuffers
+
+class Version(object):
+ __slots__ = ['_tab']
+
+ @classmethod
+ def GetRootAsVersion(cls, buf, offset):
+ n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
+ x = Version()
+ x.Init(buf, n + offset)
+ return x
+
+ # Version
+ def Init(self, buf, pos):
+ self._tab = flatbuffers.table.Table(buf, pos)
+
+ # Version
+ def _major(self):
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
+ if o != 0:
+ return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos)
+ return 0
+
+ # Version
+ def _minor(self):
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
+ if o != 0:
+ return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos)
+ return 20
+
+ # Version
+ def _patch(self):
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8))
+ if o != 0:
+ return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos)
+ return 0
+
+ # Version
+ def _experimental(self):
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10))
+ if o != 0:
+ return bool(self._tab.Get(flatbuffers.number_types.BoolFlags, o + self._tab.Pos))
+ return False
+
+def VersionStart(builder): builder.StartObject(4)
+def VersionAdd_major(builder, Major): builder.PrependInt32Slot(0, Major, 0)
+def VersionAdd_minor(builder, Minor): builder.PrependInt32Slot(1, Minor, 20)
+def VersionAdd_patch(builder, Patch): builder.PrependInt32Slot(2, Patch, 0)
+def VersionAdd_experimental(builder, Experimental): builder.PrependBoolSlot(3, Experimental, 0)
+def VersionEnd(builder): return builder.EndObject()
diff --git a/verif/tosa/WhileLoopAttribute.py b/verif/tosa/WhileLoopAttribute.py
new file mode 100644
index 0000000..c37977f
--- /dev/null
+++ b/verif/tosa/WhileLoopAttribute.py
@@ -0,0 +1,53 @@
+# automatically generated by the FlatBuffers compiler, do not modify
+
+# Copyright (c) 2020, ARM Limited.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+# namespace: tosa
+
+import flatbuffers
+
+class WhileLoopAttribute(object):
+ __slots__ = ['_tab']
+
+ @classmethod
+ def GetRootAsWhileLoopAttribute(cls, buf, offset):
+ n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
+ x = WhileLoopAttribute()
+ x.Init(buf, n + offset)
+ return x
+
+ # WhileLoopAttribute
+ def Init(self, buf, pos):
+ self._tab = flatbuffers.table.Table(buf, pos)
+
+ # WhileLoopAttribute
+ def CondBranch(self):
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
+ if o != 0:
+ return self._tab.String(o + self._tab.Pos)
+ return None
+
+ # WhileLoopAttribute
+ def BodyBranch(self):
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
+ if o != 0:
+ return self._tab.String(o + self._tab.Pos)
+ return None
+
+def WhileLoopAttributeStart(builder): builder.StartObject(2)
+def WhileLoopAttributeAddCondBranch(builder, condBranch): builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(condBranch), 0)
+def WhileLoopAttributeAddBodyBranch(builder, bodyBranch): builder.PrependUOffsetTRelativeSlot(1, flatbuffers.number_types.UOffsetTFlags.py_type(bodyBranch), 0)
+def WhileLoopAttributeEnd(builder): return builder.EndObject()
diff --git a/verif/tosa/__init__.py b/verif/tosa/__init__.py
new file mode 100644
index 0000000..ee1ab30
--- /dev/null
+++ b/verif/tosa/__init__.py
@@ -0,0 +1,15 @@
+
+# Copyright (c) 2020, ARM Limited.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
diff --git a/verif/tosa_ref_run.py b/verif/tosa_ref_run.py
new file mode 100644
index 0000000..99f504b
--- /dev/null
+++ b/verif/tosa_ref_run.py
@@ -0,0 +1,66 @@
+
+# Copyright (c) 2020, ARM Limited.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import json
+import shlex
+import subprocess
+from tosa_test_runner import TosaTestRunner, run_sh_command
+
+class TosaRefRunner(TosaTestRunner):
+ def __init__(self, args, runnerArgs, testDir):
+ super().__init__(args, runnerArgs, testDir)
+
+ def runModel(self):
+ # Build up the TOSA reference command line
+ # Uses arguments from the argParser args, not the runnerArgs
+ args = self.args
+
+ ref_cmd = [ args.ref_model_path,
+ '-Csubgraph_file={}'.format(self.testDesc['tosa_file']),
+ '-Csubgraph_dir={}'.format(self.testDir),
+ '-Cinput_dir={}'.format(self.testDir),
+ '-Coutput_dir={}'.format(self.testDir),
+ '-Coutput_tensor_prefix=ref-', # Naming agreement with TosaSerializer
+ ]
+
+ # Build up input tensor_name/filename list
+ inputTensors = []
+ for i in range(len(self.testDesc['ifm_placeholder'])):
+ inputTensors.append('{}:{}'.format(self.testDesc['ifm_placeholder'][i], self.testDesc['ifm_file'][i]))
+
+ ref_cmd.append('-Cinput_tensor={}'.format(','.join(inputTensors)))
+
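+        # Illustrative shape of the resulting command (paths and names are hypothetical):
+        #   <ref_model_path> -Csubgraph_file=test.tosa -Csubgraph_dir=<testDir> \
+        #       -Cinput_dir=<testDir> -Coutput_dir=<testDir> -Coutput_tensor_prefix=ref- \
+        #       -Cinput_tensor=placeholder_0:placeholder_0.npy
+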
+ if args.ref_debug:
+ ref_cmd.extend(['-dALL', '-l{}'.format(args.ref_debug)])
+
+ if args.ref_intermediates:
+ ref_cmd.extend(['-Ddump_intermediates=1'])
+
+ expectedFailure = self.testDesc['expected_failure']
+
+ try:
+ run_sh_command(self.args, ref_cmd)
+ if expectedFailure:
+ result = TosaTestRunner.Result.UNEXPECTED_PASS
+ else:
+ result = TosaTestRunner.Result.EXPECTED_PASS
+ except Exception as e:
+ if expectedFailure:
+ result = TosaTestRunner.Result.EXPECTED_FAILURE
+            else:
+                # A run that failed but was not expected to fail should not count as a
+                # pass (assumes Result.UNEXPECTED_FAILURE is defined on TosaTestRunner).
+                result = TosaTestRunner.Result.UNEXPECTED_FAILURE
+
+ return result
diff --git a/verif/tosa_serializer.py b/verif/tosa_serializer.py
new file mode 100644
index 0000000..7ba68c3
--- /dev/null
+++ b/verif/tosa_serializer.py
@@ -0,0 +1,718 @@
+#!/usr/bin/env python3
+
+# Copyright (c) 2020, ARM Limited.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import flatbuffers
+import numpy as np
+from enum import Enum, IntEnum, unique
+from tosa import TosaGraph, TosaBasicBlock, TosaTensor, TosaOperator, DType, Format, Usage, Op, ResizeMode, Version
+import tosa
+import os
+import json
+
+# With the way flatc generates its python types, there is no programmatic way
+# to get string names for the integer types. Manually maintain a string table
+# here.
+DTypeNames = [ 'UNKNOWN',
+ 'BOOL',
+ 'AINT8',
+ 'UINT8',
+ 'INT4',
+ 'INT8',
+ 'INT16',
+ 'INT32',
+ 'INT48',
+ 'FLOAT' ]
+
+def dtype_str_to_val(name):
+
+ for i in range(len(DTypeNames)):
+ if name.casefold() == DTypeNames[i].casefold():
+ return i
+ raise Exception('Unable to parse DType name {}'.format(name))
+
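+# Example: dtype_str_to_val('int8') returns the index of 'INT8' in DTypeNames,
+# which is assumed to match the generated DType enum value.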
+
+class TosaSerializerUnion:
+ '''This class handles encapsulating and serializing union types into flatbuffers'''
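+    # How this is used (based on the subclasses below): a subclass sets
+    # self.optFcns to the generated (Start, End) function pair and appends
+    # (AddFcn, value) tuples to the typed lists; serialize() builds strings and
+    # int vectors first, then replays the Add calls between Start and End and
+    # returns the offset of the finished table.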
+ def __init__(self):
+
+ # A tuple of the start and end functions. Set by the options constructors below
+ self.optFcns = None
+
+ # The type from the tosa.Options enumeration. Set by the options constructors below.
+ self.utype = None
+
+ # Each of these lists is a tuple of the add function and the
+ # value being added. Set by the options constructors below.
+ self.ints = []
+ self.bools = []
+ self.floats = []
+ self.strings = []
+ self.intvecs = []
+
+ def serialize(self, builder):
+
+ # We have to build strings and vectors first
+ strList = []
+ intVecList = []
+
+ for fcn, val in self.strings:
+ strList.append((fcn, builder.CreateString(val)))
+
+ for fcn, val in self.intvecs:
+ intVecList.append((fcn, TosaSerializer.serializeInt32Vec(builder, val)))
+
+ startFcn, endFcn = self.optFcns
+
+ # Then serialize the options object from the list of primitives and
+ # other serialized values
+ startFcn(builder)
+ for fcn, val in self.ints:
+ fcn(builder, val)
+
+ for fcn, val in self.bools:
+ fcn(builder, val)
+
+ for fcn, val in self.floats:
+ fcn(builder, val)
+
+ for fcn, val in strList:
+ fcn(builder, val)
+
+ for fcn, val in intVecList:
+ fcn(builder, val)
+
+ return endFcn(builder)
+
+class TosaSerializerAttribute(TosaSerializerUnion):
+ '''This class handles encapsulating all of the enumerated types for attributes'''
+
+ def __init__(self):
+ super().__init__()
+
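+    # Usage sketch (values are illustrative): a = TosaSerializerAttribute();
+    # a.Pool2dAttribute(kernel=[2, 2], stride=[2, 2], padding=[0, 0, 0, 0])
+    # selects the union type, then a.serialize(builder) writes the attribute table.
+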
+ def Pool2dAttribute(self, kernel, stride, padding):
+ from tosa import Pool2dAttribute as a, Attribute
+
+ self.utype = Attribute.Attribute().Pool2dAttribute
+
+ self.optFcns = (a.Pool2dAttributeStart, a.Pool2dAttributeEnd)
+ self.intvecs.append((a.Pool2dAttributeAddPadding,
+ padding))
+ self.intvecs.append((a.Pool2dAttributeAddKernel,
+ kernel))
+ self.intvecs.append((a.Pool2dAttributeAddStride,
+ stride))
+
+ def Conv2dAttribute(self, padding, stride, dilation):
+ from tosa import Conv2dAttribute as a, Attribute
+
+ self.utype = Attribute.Attribute().Conv2dAttribute
+ self.optFcns = (a.Conv2dAttributeStart, a.Conv2dAttributeEnd)
+
+ self.intvecs.append((a.Conv2dAttributeAddPadding,
+ padding))
+ self.intvecs.append((a.Conv2dAttributeAddStride,
+ stride))
+ self.intvecs.append((a.Conv2dAttributeAddDilation,
+ dilation))
+
+ def TransposeConv2DAttribute(self, outpad, stride, dilation, output_shape):
+ from tosa import TransposeConv2dAttribute as a, Attribute
+
+ self.utype = Attribute.Attribute().TransposeConv2dAttribute
+ self.optFcns = (a.TransposeConv2dAttributeStart, a.TransposeConv2dAttributeEnd)
+
+ self.intvecs.append((a.TransposeConv2dAttributeAddOutpad,
+ outpad))
+ self.intvecs.append((a.TransposeConv2dAttributeAddStride,
+ stride))
+ self.intvecs.append((a.TransposeConv2dAttributeAddDilation,
+ dilation))
+ self.intvecs.append((a.TransposeConv2dAttributeAddOutputShape,
+ output_shape))
+
+ def ReluNAttribute(self, maxint, maxfp):
+ from tosa import ReluNAttribute as a, Attribute
+
+ self.utype = Attribute.Attribute().ReluNAttribute
+ self.optFcns = (a.ReluNAttributeStart, a.ReluNAttributeEnd)
+
+ self.ints.append((a.ReluNAttributeAddMaxInt, maxint))
+ self.ints.append((a.ReluNAttributeAddMaxFp, maxfp))
+
+
+ def AxisAttribute(self, axis):
+ from tosa import AxisAttribute as a, Attribute
+
+ self.utype = Attribute.Attribute().AxisAttribute
+ self.optFcns = (a.AxisAttributeStart, a.AxisAttributeEnd)
+
+ self.ints.append((a.AxisAttributeAddAxis,
+ axis))
+
+ def ReshapeAttribute(self, shape):
+ from tosa import ReshapeAttribute as a, Attribute
+
+ self.utype = Attribute.Attribute().ReshapeAttribute
+ self.optFcns = (a.ReshapeAttributeStart, a.ReshapeAttributeEnd)
+
+ self.intvecs.append((a.ReshapeAttributeAddShape,
+ shape))
+
+ def SliceAttribute(self, begin, size):
+ from tosa import SliceAttribute as a, Attribute
+
+ self.utype = Attribute.Attribute().SliceAttribute
+ self.optFcns = (a.SliceAttributeStart, a.SliceAttributeEnd)
+
+ self.intvecs.append((a.SliceAttributeAddBegin,
+ begin))
+ self.intvecs.append((a.SliceAttributeAddSize,
+ size))
+
+ def TileAttribute(self, multiples):
+ from tosa import TileAttribute as a, Attribute
+
+ self.utype = Attribute.Attribute().TileAttribute
+ self.optFcns = (a.TileAttributeStart, a.TileAttributeEnd)
+
+ self.intvecs.append((a.TileAttributeAddMultiples,
+ multiples))
+
+ def ResizeAttribute(self, output_size, stride, offset, shift, mode):
+ from tosa import ResizeAttribute as a, Attribute
+
+ self.utype = Attribute.Attribute().ResizeAttribute
+ self.optFcns = (a.ResizeAttributeStart, a.ResizeAttributeEnd)
+
+ self.intvecs.append((a.ResizeAttributeAddOutputSize,
+ output_size))
+ self.intvecs.append((a.ResizeAttributeAddStride,
+ stride))
+ self.intvecs.append((a.ResizeAttributeAddOffset,
+ offset))
+ self.ints.append((a.ResizeAttributeAddShift,
+ shift))
+ self.ints.append((a.ResizeAttributeAddMode,
+ mode))
+
+ def ClampAttribute(self, minint, maxint, minfp, maxfp):
+ from tosa import ClampAttribute as a, Attribute
+
+ self.utype = Attribute.Attribute().ClampAttribute
+ self.optFcns = (a.ClampAttributeStart, a.ClampAttributeEnd)
+
+ self.ints.append((a.ClampAttributeAddMinInt,
+ minint))
+ self.ints.append((a.ClampAttributeAddMaxInt,
+ maxint))
+
+ self.ints.append((a.ClampAttributeAddMinFp,
+ minfp))
+ self.ints.append((a.ClampAttributeAddMaxFp,
+ maxfp))
+
+ def RescaleAttribute(self, input_zp, output_zp, multiplier, shift, scale32, double_round, per_channel):
+ from tosa import RescaleAttribute as a, Attribute
+
+ self.utype = Attribute.Attribute().RescaleAttribute
+ self.optFcns = (a.RescaleAttributeStart, a.RescaleAttributeEnd)
+
+ self.ints.append((a.RescaleAttributeAddInputZp,
+ input_zp))
+ self.ints.append((a.RescaleAttributeAddOutputZp,
+ output_zp))
+ self.intvecs.append((a.RescaleAttributeAddMultiplier,
+ multiplier))
+ self.intvecs.append((a.RescaleAttributeAddShift,
+ shift))
+ self.bools.append((a.RescaleAttributeAddScale32,
+ scale32))
+ self.bools.append((a.RescaleAttributeAddDoubleRound,
+ double_round))
+ self.bools.append((a.RescaleAttributeAddPerChannel,
+ per_channel))
+
+ def CustomAttribute(self, identifier):
+ from tosa import CustomAttribute as a, Attribute
+
+ self.utype = Attribute.Attribute().CustomAttribute
+ self.optFcns = (a.CustomAttributeStart, a.CustomAttributeEnd)
+
+ self.strings.append((a.CustomAttributeAddIdentifier,
+ identifier))
+
+ def CondIfAttribute(self, then_branch, else_branch):
+ from tosa import CondIfAttribute as a, Attribute
+
+ self.utype = Attribute.Attribute().CondIfAttribute
+ self.optFcns = (a.CondIfAttributeStart, a.CondIfAttributeEnd)
+
+ self.strings.append((a.CondIfAttributeAddThenBranch,
+ then_branch))
+ self.strings.append((a.CondIfAttributeAddElseBranch,
+ else_branch))
+
+ def WhileLoopAttribute(self, cond_branch, body_branch):
+ from tosa import WhileLoopAttribute as a, Attribute
+
+ self.utype = Attribute.Attribute().WhileLoopAttribute
+ self.optFcns = (a.WhileLoopAttributeStart, a.WhileLoopAttributeEnd)
+
+ self.strings.append((a.WhileLoopAttributeAddCondBranch,
+ cond_branch))
+ self.strings.append((a.WhileLoopAttributeAddBodyBranch,
+ body_branch))
+
+class TosaSerializerQuantInfo(TosaSerializerUnion):
+ '''This class handles encapsulating all of the enumerated types for quantinfo types'''
+ def __init__(self):
+ super().__init__()
+
+ def ConvQuantInfo(self, input_zp, weight_zp):
+ from tosa import ConvQuantInfo as q, QuantInfo
+
+ self.utype = QuantInfo.QuantInfo().ConvQuantInfo
+ self.optFcns = (q.ConvQuantInfoStart, q.ConvQuantInfoEnd)
+ self.ints.append((q.ConvQuantInfoAddInputZp, input_zp))
+ self.ints.append((q.ConvQuantInfoAddWeightZp, weight_zp))
+
+ def UnaryQuantInfo(self, input_zp, output_zp):
+ from tosa import UnaryQuantInfo as q, QuantInfo
+
+ self.utype = QuantInfo.QuantInfo().UnaryQuantInfo
+ self.optFcns = (q.UnaryQuantInfoStart, q.UnaryQuantInfoEnd)
+ self.ints.append((q.UnaryQuantInfoAddInputZp, input_zp))
+ self.ints.append((q.UnaryQuantInfoAddOutputZp, output_zp))
+
+ def MatMulQuantInfo(self, a_zp, b_zp):
+ from tosa import MatMulQuantInfo as q, QuantInfo
+
+ self.utype = QuantInfo.QuantInfo().MatMulQuantInfo
+ self.optFcns = (q.MatMulQuantInfoStart, q.MatMulQuantInfoEnd)
+ self.ints.append((q.MatMulQuantInfoAddAZp, a_zp))
+ self.ints.append((q.MatMulQuantInfoAddBZp, b_zp))
+
+ def PadQuantInfo(self, input_zp):
+ from tosa import PadQuantInfo as q, QuantInfo
+
+ self.utype = QuantInfo.QuantInfo().PadQuantInfo
+ self.optFcns = (q.PadQuantInfoStart, q.PadQuantInfoEnd)
+ self.ints.append((q.PadQuantInfoAddInputZp, input_zp))
+
+class TosaSerializerTensor:
+ def __init__(self, name, shape, dtype, usage, dformat, filename = None, placeholderFilename = None):
+ self.name = name
+
+ if isinstance(shape, np.ndarray):
+ shape = shape.astype(int).tolist()
+ shape = list(map(int, shape))
+
+ self.shape = shape
+ self.dtype = dtype
+ self.usage = TosaSerializer.toList(usage)
+ self.dformat = TosaSerializer.toList(dformat)
+
+ # Filename for const tensors. This gets written to the .tosa serialization
+ self.filename = filename
+
+ # Filename for placeholder tensors. These get generated by the test generation
+ # process and are written to disk, but are considered input tensors by the network
+ # so they do not appear in the TOSA serialization. However, if we want to form a unit
+ # test around these input tensors, we can get the filename from here.
+ self.placeholderFilename = placeholderFilename
+
+ def __str__(self):
+ out = 'TosaSerializerTensor name: {} shape: {} dtype: {} Usage: {} format {} filename: {}'.format(
+ self.name, self.shape, DTypeNames[self.dtype], self.usage, self.dformat, self.filename)
+ return out
+
+ def addUsage(self, usage):
+ self.usage.append(usage)
+
+ def addFormat(self, format):
+ self.dformat.append(format)
+
+ def setDtype(self, dtype):
+ self.dtype = dtype
+
+ def merge(self, name, shape, dtype, usage, dformat, filename = None):
+ # Merge in additional usage/formats to the list
+ found = 0
+ for i in self.usage:
+ if i == usage:
+ found = 1
+ break
+ if not found:
+ self.usage.append(usage)
+
+ found = 0
+ for i in self.dformat:
+ if i == dformat:
+ found = 1
+ break
+ if not found:
+ self.dformat.append(dformat)
+
+ def serialize(self, builder):
+ fb_name = builder.CreateString(self.name)
+ if self.filename:
+ fb_filename = builder.CreateString(self.filename)
+ fb_shapes = TosaSerializer.serializeInt32Vec(builder, self.shape)
+ fb_usage = TosaSerializer.serializeInt32Vec(builder, self.usage)
+ fb_dformat = TosaSerializer.serializeInt32Vec(builder, self.dformat)
+
+ TosaTensor.TosaTensorStart(builder)
+ TosaTensor.TosaTensorAddName(builder, fb_name)
+ TosaTensor.TosaTensorAddShape(builder, fb_shapes)
+ TosaTensor.TosaTensorAddType(builder, self.dtype)
+ TosaTensor.TosaTensorAddUsage(builder, fb_usage)
+ TosaTensor.TosaTensorAddFormat(builder, fb_dformat)
+ if self.filename:
+ TosaTensor.TosaTensorAddNpyFilename(builder, fb_filename)
+
+ return TosaTensor.TosaTensorEnd(builder)
+
+class TosaSerializerOperator:
+ def __init__(self, op, inputs, outputs, attributes = None, quantInfo = None):
+ self.op = op
+ self.attributes = attributes
+ self.inputs = TosaSerializer.toList(inputs)
+ self.outputs = TosaSerializer.toList(outputs)
+ self.quantInfo = quantInfo
+
+ def __str__(self):
+ out = 'Op {}\n----\n'.format(self.op)
+
+ for i in self.inputs:
+ out = out + ' Input: {}\n'.format(i)
+ for o in self.outputs:
+ out = out + ' Output: {}\n'.format(o)
+
+ return out
+
+ def serialize(self, builder):
+ fb_inputs = TosaSerializer.serializeStrVec(builder, self.inputs, TosaOperator.TosaOperatorStartInputsVector)
+ fb_outputs = TosaSerializer.serializeStrVec(builder, self.outputs, TosaOperator.TosaOperatorStartOutputsVector)
+ # Need to serialize quant_info and attributes enums still
+ if self.attributes is not None:
+ fb_attributes = self.attributes.serialize(builder)
+
+ if self.quantInfo is not None:
+ fb_qinfo = self.quantInfo.serialize(builder)
+
+ TosaOperator.TosaOperatorStart(builder)
+ TosaOperator.TosaOperatorAddOp(builder, self.op)
+ TosaOperator.TosaOperatorAddInputs(builder, fb_inputs)
+ TosaOperator.TosaOperatorAddOutputs(builder, fb_outputs)
+ if self.attributes is not None:
+ TosaOperator.TosaOperatorAddAttributeType(builder, self.attributes.utype)
+ TosaOperator.TosaOperatorAddAttribute(builder, fb_attributes)
+ if self.quantInfo is not None:
+ TosaOperator.TosaOperatorAddQuantInfoType(builder, self.quantInfo.utype)
+ TosaOperator.TosaOperatorAddQuantInfo(builder, fb_qinfo)
+
+ return TosaOperator.TosaOperatorEnd(builder)
+
+class TosaSerializerBasicBlock:
+ def __init__(self, name):
+ self.name = name
+ self.operators = []
+
+ # Dict assures uniqueness, but allows us to look up by name
+ self.tensors = dict()
+
+ self.inputs = []
+ self.outputs = []
+
+ def addTensor(self, name, shape, dtype, usage, dformat, filename = None, placeholderFilename = None):
+ try:
+ # Someone already added this tensor.
+ # We may have to add more usages and formats
+ tens = self.tensors[name]
+ tens.merge(name, shape, dtype, usage, dformat, filename)
+ except KeyError:
+ self.tensors[name] = TosaSerializerTensor(name, shape, dtype, usage, dformat, filename, placeholderFilename)
+
+ return self.tensors[name]
+
+ def addInput(self, name):
+ self.inputs.append(name)
+
+ def addOutput(self, name):
+ self.outputs.append(name)
+
+ def addOperator(self, op, inputs, outputs, attributes = None, quant_info = None):
+ self.operators.append(TosaSerializerOperator(op, inputs, outputs, attributes, quant_info))
+
+ def serialize(self, builder):
+ fb_name = builder.CreateString(self.name)
+ fbv_inputs = TosaSerializer.serializeStrVec(builder, list(self.inputs), TosaBasicBlock.TosaBasicBlockStartInputsVector)
+ fbv_outputs = TosaSerializer.serializeStrVec(builder, list(self.outputs), TosaBasicBlock.TosaBasicBlockStartOutputsVector)
+ fbv_tensors = TosaSerializer.serializeObjVec(builder, list(self.tensors.values()), TosaBasicBlock.TosaBasicBlockStartTensorsVector)
+ fbv_operators = TosaSerializer.serializeObjVec(builder, self.operators, TosaBasicBlock.TosaBasicBlockStartOperatorsVector)
+
+ TosaBasicBlock.TosaBasicBlockStart(builder)
+ TosaBasicBlock.TosaBasicBlockAddName(builder, fb_name)
+ TosaBasicBlock.TosaBasicBlockAddInputs(builder, fbv_inputs)
+ TosaBasicBlock.TosaBasicBlockAddOutputs(builder, fbv_outputs)
+ TosaBasicBlock.TosaBasicBlockAddTensors(builder, fbv_tensors)
+ TosaBasicBlock.TosaBasicBlockAddOperators(builder, fbv_operators)
+ return TosaBasicBlock.TosaBasicBlockEnd(builder)
+
+@unique
+class TensorDir(IntEnum):
+ PLACEHOLDER = 0
+ CONST = 1
+ INTERMEDIATE = 2
+ RESULT = 3
+
+class TosaSerializer:
+ def __init__(self, pathPrefix):
+
+ # Get the global TOSA version if not already defined
+ try:
+ TOSA_VERSION
+ except NameError:
+ TosaSerializer.setTosaVersion()
+
+ self.builder = flatbuffers.Builder(0)
+
+ self.basicBlocks = []
+ self.startBasicBlock('main')
+ self.pathPrefix = pathPrefix
+
+ # Indices used for adding/naming tensors
+ self.currInputIdx = 0
+ self.currConstIdx = 0
+ self.currLayerIdx = 1
+ self.currResultIdx = 0
+
+ # Is this an illegal test that is expected to fail?
+ self.expectedFailure = False
+ self.expectedFailureDesc = ''
+
+ def __str__(self):
+ out = ''
+ for bb in self.basicBlocks:
+ out = out + bb.__str__()
+ return out
+
+ def addPlaceholder(self, shape, dtype, usage, dformat, vals):
+ if not self.currBasicBlock:
+ raise Exception('addPlaceholder called without a valid basic block')
+
+ name = 'input-{}'.format(self.currInputIdx)
+ filename = '{}.npy'.format(name)
+ self.currInputIdx = self.currInputIdx + 1
+
+ tens = self.currBasicBlock.addTensor(name, shape, dtype, usage, dformat, None, filename)
+ # This is always an input to the block
+ self.currBasicBlock.addInput(name)
+ # Add the operator now
+ self.currBasicBlock.addOperator(tosa.Op.Op().PLACEHOLDER, [], name)
+
+ if vals is not None:
+ np.save(os.path.join(self.pathPrefix, filename), vals, False)
+
+ return tens
+
+ def addConst(self, shape, dtype, usage, dformat, vals):
+ if not self.currBasicBlock:
+ raise Exception('addConst called without a valid basic block')
+
+ name = 'const-{}'.format(self.currConstIdx)
+ filename = '{}.npy'.format(name)
+ self.currConstIdx = self.currConstIdx + 1
+
+ tens = self.currBasicBlock.addTensor(name, shape, dtype, usage, dformat, filename)
+ # Add the operator now
+ self.currBasicBlock.addOperator(tosa.Op.Op().CONST, [], name)
+
+ if vals is not None:
+ np.save(os.path.join(self.pathPrefix, filename), vals, False)
+ return tens
+
+ def addIntermediate(self, shape, dtype, usage, dformat):
+
+ if not self.currBasicBlock:
+ raise Exception('addIntermediate called without a valid basic block')
+
+ name = 'layer-{}'.format(self.currLayerIdx)
+ filename = None # No file, so no filename
+ self.currLayerIdx = self.currLayerIdx + 1
+
+ tens = self.currBasicBlock.addTensor(name, shape, dtype, usage, dformat, filename)
+
+ return tens
+
+ def addInputTensor(self, tensor):
+ self.currBasicBlock.addOperator(tosa.Op.Op().PLACEHOLDER, [], tensor.name)
+ self.currBasicBlock.addTensor(tensor.name, tensor.shape, tensor.dtype, tensor.usage, tensor.dformat)
+ self.currBasicBlock.addInput(tensor.name)
+
+ def addOutputTensor(self, tensor):
+ self.currBasicBlock.addOutput(tensor.name)
+
+ def addOutput(self, shape, dtype, usage, dformat):
+ if not self.currBasicBlock:
+ raise Exception('addOutput called without a valid basic block')
+
+ name = 'result-{}'.format(self.currResultIdx)
+ self.currResultIdx = self.currResultIdx + 1
+
+ tens = self.currBasicBlock.addTensor(name, shape, dtype, usage, dformat, None)
+ self.currBasicBlock.addOutput(name)
+ return tens
+
+ def addOperator(self, op, inputs, outputs, attributes = None, quant_info = None):
+
+ if op == tosa.Op.Op().PLACEHOLDER or \
+ op == tosa.Op.Op().CONST:
+ raise Exception('Use addPlaceholder() or addConst() to add PLACEHOLDER and CONST ops')
+
+ return self.currBasicBlock.addOperator(op, inputs, outputs, attributes, quant_info)
+
+ def setExpectedFailure(self, desc='', val=True):
+ self.expectedFailure = val
+ self.expectedFailureDesc = desc
+
+ def serialize(self):
+
+ builder = self.builder
+
+ Version.VersionStart(builder)
+ Version.VersionAdd_major(builder, TOSA_VERSION[0])
+ Version.VersionAdd_minor(builder, TOSA_VERSION[1])
+ Version.VersionAdd_patch(builder, TOSA_VERSION[2])
+ Version.VersionAdd_experimental(builder, TOSA_VERSION[3])
+ version = Version.VersionEnd(builder)
+
+ fbv_bb = TosaSerializer.serializeObjVec(builder, self.basicBlocks, TosaGraph.TosaGraphStartBlocksVector)
+
+ TosaGraph.TosaGraphStart(builder)
+ TosaGraph.TosaGraphAddVersion(builder, version)
+ TosaGraph.TosaGraphAddBlocks(builder, fbv_bb)
+ graph = TosaGraph.TosaGraphEnd(builder)
+
+ self.builder.Finish(graph)
+ return self.builder.Output()
+
+ def writeJson(self, tosa_filename):
+ '''Write a json test description file so that it is fairly easy to pick up
+ the test and generate commands for third-party tools'''
+ test_desc = dict()
+
+ test_desc['tosa_file'] = tosa_filename
+ ifm_name = []
+ ifm_shape = []
+ ifm_file = []
+ ofm_name = []
+ ofm_file = []
+ ofm_shape = []
+
+ for b in self.basicBlocks:
+ if b.name == 'main':
+ for i in b.inputs:
+ ifm_name.append(i)
+ ifm_shape.append(b.tensors[i].shape)
+ ifm_file.append(b.tensors[i].placeholderFilename)
+ for o in b.outputs:
+ ofm_name.append(o)
+ ofm_shape.append(b.tensors[o].shape)
+ # Make up an OFM filename here. One isn't generated until the reference tool is
+ # run, so any name is a good name
+ ofm_file.append('ref-{}.npy'.format(o))
+
+ test_desc['ifm_placeholder'] = ifm_name
+ test_desc['ifm_file'] = ifm_file
+ test_desc['ifm_shape'] = ifm_shape
+ test_desc['ofm_name'] = ofm_name
+ test_desc['ofm_shape'] = ofm_shape
+ test_desc['ofm_file'] = ofm_file
+ test_desc['expected_failure'] = self.expectedFailure
+ if self.expectedFailureDesc:
+ test_desc['expected_failure_desc'] = self.expectedFailureDesc
+
+ return json.dumps(test_desc, indent=' ')
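+
+ # Sketch of the emitted JSON for a single-input, single-output 'main' block
+ # (field values below are hypothetical):
+ # {
+ #   "tosa_file": "test.tosa",
+ #   "ifm_placeholder": ["input-0"],
+ #   "ifm_file": ["input-0.npy"],
+ #   "ifm_shape": [[1, 4, 4, 4]],
+ #   "ofm_name": ["result-0"],
+ #   "ofm_shape": [[1, 4, 4, 4]],
+ #   "ofm_file": ["ref-result-0.npy"],
+ #   "expected_failure": false
+ # }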
+
+ def startBasicBlock(self, name):
+ self.currBasicBlock = TosaSerializerBasicBlock(name)
+ self.basicBlocks.append(self.currBasicBlock)
+
+ @staticmethod
+ def serializeStrVec(builder, vec, start_fcn):
+ fb_strs = [builder.CreateString(i) for i in vec]
+ start_fcn(builder, len(fb_strs))
+ for s in fb_strs[::-1]:
+ builder.PrependUOffsetTRelative(s)
+ return builder.EndVector(len(fb_strs))
+
+ @staticmethod
+ def serializeInt32Vec(builder, vec):
+ builder.StartVector(4, len(vec), 4)
+ for v in vec[::-1]:
+ builder.PrependInt32(v)
+ return builder.EndVector(len(vec))
+
+ @staticmethod
+ def serializeObjVec(builder, vec, start_fcn):
+ serialized_vec = []
+ for v in vec[::-1]:
+ serialized_vec.append(v.serialize(builder))
+
+ start_fcn(builder, len(vec))
+ for v in serialized_vec:
+ builder.PrependUOffsetTRelative(v)
+ return builder.EndVector(len(vec))
+
+ @staticmethod
+ def toList(val):
+ if isinstance(val, list):
+ return val
+ else:
+ return [val]
+
+ @staticmethod
+ def setTosaVersion():
+ # Create a dummy flatbuffers file with the default version information
+ # There does not appear to be a better way to get a constant from a
+ # flatbuffer schema file
+ builder = flatbuffers.Builder(0)
+ Version.VersionStart(builder)
+ ver = Version.VersionEnd(builder)
+ TosaGraph.TosaGraphStart(builder)
+ TosaGraph.TosaGraphAddVersion(builder, ver)
+ gr = TosaGraph.TosaGraphEnd(builder)
+ builder.Finish(gr)
+
+ out = builder.Output()
+
+ root = TosaGraph.TosaGraph.GetRootAsTosaGraph(out, 0)
+
+ # Store the version as a global variable so that it only needs to be
+ # generated once per process.
+ global TOSA_VERSION
+ TOSA_VERSION = [root.Version()._major(),
+ root.Version()._minor(),
+ root.Version()._patch(),
+ root.Version()._experimental() ]
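+
+# Rough end-to-end usage sketch for this module (an assumed workflow, shown as
+# a comment only; dtype, usage and op enum values come from the flatc-generated
+# tosa package):
+#
+#   ser = TosaSerializer('out_dir')
+#   a = ser.addPlaceholder([1, 8], dtype, usage, [], input_vals)  # writes input-0.npy
+#   r = ser.addOutput([1, 8], dtype, usage, [])
+#   attr = TosaSerializerAttribute()
+#   attr.AxisAttribute(1)
+#   ser.addOperator(some_op, [a.name], [r.name], attr)
+#   with open('test.tosa', 'wb') as f:
+#       f.write(ser.serialize())
+#   with open('desc.json', 'w') as f:
+#       f.write(ser.writeJson('test.tosa'))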
diff --git a/verif/tosa_test_gen.py b/verif/tosa_test_gen.py
new file mode 100644
index 0000000..dc2d803
--- /dev/null
+++ b/verif/tosa_test_gen.py
@@ -0,0 +1,2301 @@
+#!/usr/bin/env python3
+
+# Copyright (c) 2020, ARM Limited.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import numpy as np
+import argparse
+import sys
+import re
+import os
+import subprocess
+import shlex
+import json
+import glob
+import math
+import queue
+import threading
+import traceback
+import math
+
+from enum import IntEnum, Enum, unique
+
+import tosa_serializer as ts
+from tosa_serializer import *
+import tosa
+
+# Convenience aliases for the flatc-generated types that should be enums, but aren't
+DType = tosa.DType.DType()
+Usage = tosa.Usage.Usage()
+Format = tosa.Format.Format()
+Op = tosa.Op.Op()
+ResizeMode = tosa.ResizeMode.ResizeMode()
+
+class TosaQuantGen:
+ '''QuantizedInfo random generator helper functions. Specify with 'qgen' in the operator definition'''
+ def __init__(self):
+ pass
+
+ @staticmethod
+ def needsQinfo(op, dtype):
+ if dtype == DType.AINT8 or dtype == DType.INT8:
+ return True
+ return False
+
+ @staticmethod
+ def qgUnary(testGen, op, dtype):
+ qinfo = ts.TosaSerializerQuantInfo()
+ if TosaQuantGen.needsQinfo(op, dtype):
+ qinfo.UnaryQuantInfo(testGen.randInt(), testGen.randInt())
+ else:
+ qinfo.UnaryQuantInfo(0, 0)
+ return qinfo
+
+ @staticmethod
+ def qgConv(testGen, op, dtype):
+ qinfo = ts.TosaSerializerQuantInfo()
+ if TosaQuantGen.needsQinfo(op, dtype):
+ qinfo.ConvQuantInfo(testGen.randInt(), testGen.randInt())
+ else:
+ qinfo.ConvQuantInfo(0, 0)
+ return qinfo
+
+ @staticmethod
+ def qgMatmul(testGen, op, dtype):
+ qinfo = ts.TosaSerializerQuantInfo()
+ if TosaQuantGen.needsQinfo(op, dtype):
+ qinfo.MatMulQuantInfo(testGen.randInt(), testGen.randInt())
+ else:
+ qinfo.MatMulQuantInfo(0, 0)
+ return qinfo
+
+ @staticmethod
+ def qgPad(testGen, op, dtype):
+ qinfo = ts.TosaSerializerQuantInfo()
+ if TosaQuantGen.needsQinfo(op, dtype):
+ qinfo.PadQuantInfo(testGen.randInt())
+ else:
+ qinfo.PadQuantInfo(0)
+ return qinfo
+
+ @staticmethod
+ def computeMultiplierAndShift(scaleFp, scale32):
+ # Derived from computeMultiplierAndShiftTosaScale32
+ # Provide a floating-point scaling factor and the scale32 parameter
+ # to compute the multiplier and shift
+
+ if scale32:
+ scaleBits = 31
+ else:
+ scaleBits = 15
+
+ m, shift = math.frexp(scaleFp)
+
+ if scaleFp < 0.0:
+ m = -m
+
+ multiplier = round(m * (1 << scaleBits))
+ assert(multiplier <= (1 << scaleBits))
+
+ if multiplier == (1 << scaleBits):
+ multiplier = multiplier // 2
+ shift = shift + 1
+
+ shift = (-shift) + scaleBits
+ #print('scalefp {} scaleBits {} m {} mult {} shift {}'.format(scaleFp, scaleBits, m, multiplier, shift))
+
+ assert(multiplier <= (1 << scaleBits))
+ assert(shift >= 0 and shift <= 63)
+
+ return multiplier, shift
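+
+ # Worked example (illustrative): for scaleFp = 0.75 with scale32 = True,
+ # frexp() returns m = 0.75, shift = 0, so
+ #   multiplier = round(0.75 * 2**31) = 1610612736
+ #   shift      = (-0) + 31          = 31
+ # and applying (x * multiplier) >> shift scales x by roughly 0.75.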
+
+
+class TosaTensorGen():
+ ''' Tensor generators create a shape list for the placeholder and const tensor
+ data operands for the operator. The actual random data is generated separately for each test.'''
+ def __init__(self):
+ pass
+
+ @staticmethod
+ def tgBasic(testGen, opName, rank):
+ pl, const = opName['operands']
+ shape = testGen.makeShape(rank)
+
+ shape_list = []
+ for i in range(pl + const):
+ shape_list.append(shape.copy())
+
+ return shape_list
+
+ @staticmethod
+ def tgNHWC(testGen, opName, rank):
+ pl, const = opName['operands']
+
+ assert(rank == 4)
+
+ shape = testGen.makeShape(rank)
+
+ # Constrict the batch size?
+ if testGen.args.max_batch_size:
+ shape[0] = (shape[0] % testGen.args.max_batch_size) + 1
+
+ shape_list = []
+ for i in range(pl + const):
+ shape_list.append(shape.copy())
+
+ return shape_list
+
+ @staticmethod
+ def tgBroadcastFuzz(testGen, op, rank):
+ shape = testGen.makeShape(rank)
+
+ pl, const = op['operands']
+
+ shape_list = []
+
+ # Choose one of the inputs to broadcast
+ bcast_idx = testGen.randInt(0, pl + const)
+ for i in range(pl + const):
+ shape_bcast = shape.copy()
+
+ # If the chosen input, pick a random index to broadcast
+ if i == bcast_idx:
+ fuzz_idx = testGen.randInt(0, rank)
+ shape_bcast[fuzz_idx] = 1
+
+ shape_list.append(shape_bcast)
+
+ return shape_list
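+
+ # Example (illustrative): for a rank-3 base shape [2, 5, 7] and a two-operand
+ # op, picking operand 1 and fuzz index 2 yields [[2, 5, 7], [2, 5, 1]], so
+ # exactly one operand ends up with a broadcastable size-1 dimension.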
+
+ @staticmethod
+ def tgConv2D(testGen, op, rank):
+ pl, const = op['operands']
+
+ assert(rank == 4)
+
+ # IFM dimensions are NHWC
+ ifm_shape = testGen.makeShape(rank)
+
+ # Constrict the batch size?
+ if testGen.args.max_batch_size:
+ ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1
+
+ # Get the filter height/width from the operator parameters
+ filter_hw = op['filter']
+
+ # Generate a random OFM depth
+ ofm_depth = testGen.makeShape(1)[0]
+
+ # The filter dimensions are OHWI
+ filter_shape = np.asarray([ofm_depth, filter_hw[0], filter_hw[1], ifm_shape[3]])
+
+ # The bias is OC
+ bias_shape = np.asarray([ofm_depth])
+
+ return [ifm_shape, filter_shape, bias_shape]
+
+ @staticmethod
+ def tgTransposeConv2D(testGen, op, rank):
+ pl, const = op['operands']
+
+ assert(rank == 4)
+
+ # IFM dimensions are NHWC
+ ifm_shape = testGen.makeShape(rank)
+
+ # Constrict the batch size?
+ if testGen.args.max_batch_size:
+ ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1
+
+ # Get the filter height/width from the operator parameters
+ filter_hw = op['filter']
+
+ # Generate a random OFM depth
+ ofm_depth = testGen.makeShape(1)[0]
+
+ # The filter dimensions are OHWI
+ filter_shape = np.asarray([ofm_depth, filter_hw[0], filter_hw[1], ifm_shape[3]])
+
+ return [ifm_shape, filter_shape]
+
+ @staticmethod
+ def tgDepthwiseConv2D(testGen, op, rank):
+ pl, const = op['operands']
+
+ assert(rank == 4)
+ assert(pl == 1 and const == 2)
+
+ # IFM dimensions are NHWC
+ ifm_shape = testGen.makeShape(rank)
+
+ # Constrict the batch size?
+ if testGen.args.max_batch_size:
+ ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1
+
+ # Get the filter height/width from the operator parameters
+ # Filter is KH, KW, C, M
+ filter_hw = op['filter']
+
+ # Generate a random channel multiplier M, but don't let it get too big because
+ # the output depth is M * C
+ filter_m = (testGen.makeShape(1)[0] % (testGen.args.tensor_shape_range[1] // 4)) + 1
+
+ # The filter dimensions are HWCM
+ filter_shape = np.asarray([filter_hw[0], filter_hw[1], ifm_shape[3], filter_m])
+
+ # The bias is M * C
+ bias_shape = np.asarray([ifm_shape[3] * filter_m])
+
+ return [ifm_shape, filter_shape, bias_shape]
+
+ @staticmethod
+ def tgFullyConnected(testGen, op, rank):
+ pl, const = op['operands']
+
+ assert(rank == 2)
+ assert(pl == 2 and const == 0)
+
+ input_shape = testGen.makeShape(rank)
+ filter_oc = testGen.makeShape(1)[0]
+ filter_shape = np.asarray([filter_oc, input_shape[1]])
+
+ bias_shape = np.asarray([filter_oc])
+
+ return [input_shape, filter_shape, bias_shape]
+
+ @staticmethod
+ def tgMatmul(testGen, op, rank):
+ pl, const = op['operands']
+
+ assert(rank == 2)
+ assert(pl == 2 and const == 0)
+
+ a_shape = testGen.makeShape(rank)
+ b_oc = testGen.makeShape(1)[0]
+ b_shape = np.asarray([a_shape[1], b_oc])
+
+ return [a_shape, b_shape]
+
+class TosaArgGen:
+ '''Argument generators create exhaustive or random lists of attributes for operators that take
+ attributes or other parameters. The return value is a list of (descriptive_name, [arglist])
+ tuples where the descriptive_name is appended to the test name and the arglist is expanded
+ as arguments to the operator build function.'''
+ def __init__(self):
+ pass
+
+ @staticmethod
+ def agNone(testGen, opName, shapeList, dtype):
+ '''A trivial argument generator for operators that don't take any
+ non-tensor arguments'''
+ return [('', [])]
+
+ @staticmethod
+ def agAxis(testGen, opName, shapeList, dtype):
+ '''Build the axis argument for operators that take a single axis'''
+ axes = []
+
+ shape = shapeList[0]
+
+ for a in range(0, len(shape)):
+ axes.append(('axis_{}'.format(a), [a]))
+ return axes
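+
+ # For a rank-3 first shape this returns (illustrative):
+ #   [('axis_0', [0]), ('axis_1', [1]), ('axis_2', [2])]
+ # i.e. one test per axis, with the axis index handed to the build function.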
+
+ @staticmethod
+ def agConv2D(testGen, opName, shapeList, dtype):
+ arg_list = []
+
+ ifm_shape = shapeList[0]
+ filter_shape = shapeList[1]
+
+ # Must be rank 4
+ assert(len(ifm_shape) == 4)
+ assert(len(filter_shape) == 4)
+
+ maxStride = testGen.args.max_conv_stride
+ maxPadding = testGen.args.max_conv_padding + 1
+ maxDilation = testGen.args.max_conv_dilation
+
+ # Strides, padding, dilations
+ for stride in range(0, maxStride ** 2):
+ for padding in range(0, (maxPadding) ** 4):
+ for dilation in range(0, maxDilation ** 2):
+
+ s = [stride // maxStride + 1,
+ stride % maxStride + 1]
+ # Decode the padding index into 4 per-edge values (mixed-radix, base maxPadding)
+ p = [(padding // (maxPadding ** 3)) % maxPadding,
+ (padding // (maxPadding ** 2)) % maxPadding,
+ (padding // maxPadding) % maxPadding,
+ padding % maxPadding]
+ d = [ dilation // maxDilation + 1,
+ dilation % maxDilation + 1]
+
+ # 4 padding parameters for regular conv2d
+ arg_list.append(('st{}{}_pad{}{}{}{}_dilat{}{}'.format(s[0], s[1],
+ p[0], p[1], p[2], p[3],
+ d[0], d[1]),
+ [ s, p, d ]))
+ return arg_list
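+
+ # Decode sketch (illustrative): with maxStride = 2, maxPadding = 2 and
+ # maxDilation = 1, stride index 3 gives s = [2, 2], padding index 5 gives
+ # p = [0, 1, 0, 1] and the only dilation index gives d = [1, 1], so the test
+ # name suffix is 'st22_pad0101_dilat11'.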
+
+ @staticmethod
+ def agTransposeConv2D(testGen, opName, shapeList, dtype):
+ arg_list = []
+
+ ifm_shape = shapeList[0]
+ filter_shape = shapeList[1]
+
+ # Must be rank 4
+ assert(len(ifm_shape) == 4)
+ assert(len(filter_shape) == 4)
+
+ maxStride = testGen.args.max_conv_stride
+ maxPadding = testGen.args.max_conv_padding + 1
+ maxDilation = testGen.args.max_conv_dilation
+
+ # Strides, padding, dilations
+ for stride in range(0, maxStride ** 2):
+ for out_padding in range(0, (maxPadding) ** 2):
+ for dilation in range(0, maxDilation ** 2):
+
+ s = [stride // maxStride + 1,
+ stride % maxStride + 1]
+ p = [(out_padding // (maxPadding * 1)) % maxPadding,
+ out_padding % maxPadding]
+ d = [ dilation // maxDilation + 1,
+ dilation % maxDilation + 1]
+
+ oh = (ifm_shape[1] - filter_shape[1] - (filter_shape[1] - 1) * (d[0] - 1) + \
+ 2 * p[0]) // s[0] + 1
+
+ ow = (ifm_shape[2] - filter_shape[2] - (filter_shape[2] - 1) * (d[1] - 1) + \
+ 2 * p[1]) // s[1] + 1
+
+ # Output shape
+ os = [ ifm_shape[0], oh, ow, filter_shape[0] ]
+
+ arg_list.append(('st{}{}_outpad{}{}_dilat{}{}_os{}x{}x{}x{}'.format(s[0], s[1],
+ p[0], p[1],
+ d[0], d[1],
+ os[0], os[1], os[2], os[3]),
+ [ s, p, d, os ]))
+
+ return arg_list
+
+ @staticmethod
+ def agPad(testGen, opName, shapeList, dtype):
+ arg_list = []
+ rank = len(shapeList[0])
+
+ # Test combinations of 0/1 padding on each side of each dimension. Only the
+ # first rank**2 bitmask values are used here to keep the test count down;
+ # covering the full space (and >1 padding) would need some revision.
+ for v in range(rank ** 2):
+
+ # Create a flat padding array
+ paddings = np.zeros((rank * 2), dtype=np.int32)
+
+ # Fill in the 1's
+ for r in (range(rank * 2)):
+ if (v >> r) & 1:
+ paddings[r] = 1
+
+ # Reshape back to a 2D array
+ paddings = paddings.reshape((rank, 2))
+
+ arg_list.append(('pad{0:b}'.format(v), [ paddings ]))
+
+ return arg_list
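+
+ # Bitmask sketch (illustrative): for a rank-2 input, v = 6 (0b110) sets bits 1
+ # and 2, giving the flat array [0, 1, 1, 0] and the reshaped padding
+ #   [[0, 1],
+ #    [1, 0]]
+ # i.e. one unit of padding after dim 0 and before dim 1, named 'pad110'.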
+
+ @staticmethod
+ def agPooling(testGen, opName, shapeList, dtype):
+ arg_list = []
+
+ shape = shapeList[0]
+ assert(len(shape) == 4)
+
+ maxStride = testGen.args.max_pooling_stride
+ maxKernel = testGen.args.max_pooling_kernel
+ maxPadding = testGen.args.max_pooling_padding + 1
+
+ for kernel in range(0, maxKernel ** 2):
+ for stride in range(0, maxStride ** 2):
+ for padding in range(0, maxPadding ** 4):
+ s = [stride // maxStride + 1,
+ stride % maxStride + 1]
+ k = [(kernel // maxKernel) + 2,
+ (kernel % maxKernel) + 2]
+ # Decode the padding index into 4 per-edge values (mixed-radix, base maxPadding)
+ p = [(padding // (maxPadding ** 3)) % maxPadding,
+ (padding // (maxPadding ** 2)) % maxPadding,
+ (padding // maxPadding) % maxPadding,
+ padding % maxPadding]
+
+ arg_list.append(('st{}{}_kern{}{}_pad{}{}{}{}'.format(s[0], s[1],
+ k[0], k[1],
+ p[0], p[1], p[2], p[3]),
+ [k, s, p]))
+ return arg_list
+
+ @staticmethod
+ def agCast(testGen, opName, shapeList, inDtype):
+ arg_list = []
+
+ # Enumerate the output types here
+ if inDtype == DType.INT8:
+ dtypeList = [ DType.BOOL, DType.INT16, DType.INT32, DType.FLOAT ]
+ elif inDtype == DType.INT16:
+ dtypeList = [ DType.BOOL, DType.INT8, DType.INT32, DType.FLOAT ]
+ elif inDtype == DType.INT32:
+ dtypeList = [ DType.BOOL, DType.INT8, DType.INT16, DType.FLOAT ]
+ elif inDtype == DType.BOOL:
+ dtypeList = [ DType.INT8, DType.INT16, DType.INT32 ]
+ elif inDtype == DType.FLOAT:
+ dtypeList = [ DType.INT8, DType.INT16, DType.INT32 ]
+ else:
+ raise Exception('Unexpected input dtype: {}'.format(inDtype))
+
+ for dtype in dtypeList:
+ arg_list.append(('out{}'.format(DTypeNames[dtype]), [dtype]))
+
+ return arg_list
+
+ @staticmethod
+ def agRescale(testGen, opName, shapeList, inDtype):
+ arg_list = []
+
+ # Enumerate the output types here
+ for dtype in [ DType.AINT8, DType.INT16, DType.INT32 ]:
+ for scale32 in [ False, True ]:
+ for double_round in [ False, True ]:
+ for per_channel in [ False, True ]:
+
+ if inDtype == DType.INT48 and scale32:
+ # Illegal condition. Must be scale32=False
+ continue
+
+ arg_list.append(('out{}_sc{}_dr{}_pc{}'.format(DTypeNames[dtype], int(scale32), int(double_round), int(per_channel)),
+ [dtype, scale32, double_round, per_channel]))
+
+ return arg_list
+
+ # Helper function for reshape. Gets some factors of a larger number.
+ @staticmethod
+ def getFactors(val, start=1):
+ factors = []
+
+ for i in range(start, int(np.sqrt(val))):
+ if (val % i) == 0:
+ factors.append(i)
+
+ return factors
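+
+ # e.g. getFactors(24) returns [1, 2, 3]; only divisors strictly below
+ # int(np.sqrt(val)) are collected, which is enough for the reshape
+ # permutations below.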
+
+ @staticmethod
+ def agReshape(testGen, opName, shapeList, dtype):
+ arg_list = []
+
+ origShape = shapeList[0]
+
+ totalElements = 1
+ for s in origShape:
+ totalElements *= s
+
+ # This code is NOT fast. Fortunately, the numbers are fairly small.
+ factors = TosaArgGen.getFactors(totalElements)
+
+ for p in range(testGen.args.num_rand_permutations):
+ newRank = testGen.randInt(1, 6)
+ newShape = []
+ if (len(factors) < newRank):
+ continue
+
+ remainingElements = totalElements
+ shuffledFactors = testGen.rng.permutation(factors)
+ for i in range(newRank - 1):
+ # pick newRank - 1 factors; the final dimension absorbs the remainder
+ newShape.append(shuffledFactors[0])
+ remainingElements = remainingElements // shuffledFactors[0]
+ shuffledFactors = testGen.rng.permutation(TosaArgGen.getFactors(remainingElements))
+ newShape.append(remainingElements)
+
+ # Toss in a -1 sometimes
+ minusOne = testGen.randInt(0, newRank * 4)
+ if minusOne < newRank:
+ newShape[minusOne] = -1
+
+ arg_list.append(('perm{}_rank{}'.format(p, newRank), [newShape]))
+
+ return arg_list
+
+
+ @staticmethod
+ def agTranspose(testGen, opName, shapeList, dtype):
+ arg_list = []
+
+ ifm_shape = shapeList[0]
+
+ perms = range(len(ifm_shape))
+ for p in range(testGen.args.num_rand_permutations):
+ perms = np.int32(testGen.rng.permutation(perms)).tolist()
+
+ # Avoid duplicates
+ found = False
+ for name, other_perm in arg_list:
+ if other_perm[0] == perms:
+ found = True
+ break
+
+ if not found:
+ arg_list.append(('perm{}'.format(p), [perms]))
+
+ return arg_list
+
+ @staticmethod
+ def agSlice(testGen, opName, shapeList, dtype):
+ arg_list = []
+
+ ifm_shape = shapeList[0]
+ rank = len(ifm_shape)
+
+ for p in range(testGen.args.num_rand_permutations):
+ begin = []
+ size = []
+
+ valid=True
+
+ for i in range(rank):
+ if ifm_shape[i] > 1:
+ begin.append(testGen.randInt(0, ifm_shape[i]))
+ size.append(testGen.randInt(0, ifm_shape[i] - begin[i]))
+
+ # Invalid slice size?
+ if size[i] == 0:
+ valid = False
+ else:
+ begin.append(0)
+ size.append(1)
+
+ if valid:
+ arg_list.append(('perm{}'.format(p), [begin, size]))
+ return arg_list
+
+ @staticmethod
+ def agTile(testGen, opName, shapeList, dtype):
+ arg_list = []
+
+ ifm_shape = shapeList[0]
+ rank = len(ifm_shape)
+
+ for p in range(testGen.args.num_rand_permutations):
+
+ # Pick a few random, but small multiple values
+ # because otherwise this has a tendency to generate
+ # enormous tensors
+ multiples = []
+ for i in range(rank):
+ multiples.append(testGen.randInt(1, 4))
+
+ arg_list.append(('perm{}'.format(p), [multiples]))
+
+ return arg_list
+
+ @staticmethod
+ def agResize(testGen, opName, shapeList, dtype):
+ arg_list = []
+
+ ifm_shape = shapeList[0]
+
+ for m in [ResizeMode.NEAREST, ResizeMode.BILINEAR]:
+
+ # Exclude illegal {mode, type} configurations. Pick legal output types
+ if m == ResizeMode.NEAREST and dtype == DType.INT8:
+ outputDTypeList = [ DType.INT32 ]
+ elif m == ResizeMode.NEAREST and dtype == DType.INT16:
+ outputDTypeList = [ DType.INT16 ]
+ elif m == ResizeMode.BILINEAR and dtype == DType.INT8:
+ outputDTypeList = [ DType.INT8 ]
+ elif m == ResizeMode.BILINEAR and dtype == DType.INT16:
+ outputDTypeList = [ DType.INT48 ]
+ else:
+ continue
+
+ for outputDType in outputDTypeList:
+ for perm in range(testGen.args.num_rand_permutations):
+
+ # Randomly generate legal output dimensions and shift
+ # and then compute the stride and offset based on them
+ output_dims = [ testGen.randInt(), testGen.randInt() ]
+
+ shift = testGen.randInt(1, 11)
+
+ stride = [ (ifm_shape[1] << shift) // output_dims[0],
+ (ifm_shape[2] << shift) // output_dims[1] ]
+
+ offset = [ testGen.randInt(-stride[0], (ifm_shape[1] << shift) - (output_dims[0] - 1) * stride[0]),
+ testGen.randInt(-stride[1], (ifm_shape[2] << shift) - (output_dims[1] - 1) * stride[1]) ]
+
+ arg_list.append(('mode{}_shift{}_odim{}x{}_out{}_st{}x{}_off{}x{}'.format(m, shift, output_dims[0], output_dims[1],
+ testGen.typeStr(outputDType), stride[0], stride[1],
+ offset[0], offset[1]),
+ [m, stride, offset, shift, output_dims, outputDType]))
+
+ return arg_list
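+
+ # Fixed-point sketch (illustrative): with ifm height 16, shift = 8 and a
+ # randomly chosen output height of 32, stride[0] = (16 << 8) // 32 = 128,
+ # i.e. 0.5 input rows per output row in shift-bit fixed point, and offset[0]
+ # is then drawn from [-128, (16 << 8) - 31 * 128) = [-128, 128).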
+
+ @staticmethod
+ def agCondIf(testGen, opName, shapeList, dtype):
+ # CondIf generates the condition values here.
+ # Convert to tensors in the build function, along with the
+ # then and else blocks
+ arg_list = []
+
+ for c in [False, True]:
+ arg_list.append(('cond{}'.format(int(c)), [ c ]))
+
+ return arg_list
+
+ @staticmethod
+ def agWhileLoop(testGen, opName, shapeList, dtype):
+ # While loop: 0 iterations, 1, more than 1
+ arg_list = []
+
+ for iter in [0, 1, 4]:
+ arg_list.append(('iter{}'.format(iter), [ iter ]))
+
+ return arg_list
+
+class TosaTestGen:
+ def __init__(self, args):
+ self.args = args
+ self.basePath = args.output_dir
+ self.random_seed = args.random_seed
+ self.ser = None
+ self.rng = np.random.default_rng(self.random_seed)
+ self.createDynamicOpLists()
+ self.initOpListDefaults()
+ self.quantGen = TosaQuantGen()
+ # Force makeShape to do a specific starting shape
+ self.targetted_shape = None
+
+ def createSerializer(self, opName, testPath):
+ self.testPath = os.path.join(opName, testPath)
+
+ fullPath = os.path.join(self.basePath, self.testPath)
+ os.makedirs(fullPath, exist_ok=True)
+ self.ser = ts.TosaSerializer(fullPath)
+
+ def getSerializer(self):
+ return self.ser
+
+ def serialize(self, testName):
+ with open(os.path.join(self.basePath, self.testPath, '{}.tosa'.format(testName)), 'wb') as fd:
+ fd.write(self.ser.serialize())
+
+ with open(os.path.join(self.basePath, self.testPath, 'desc.json'), 'w') as fd:
+ fd.write(self.ser.writeJson('{}.tosa'.format(testName)))
+
+ def getRandTensor(self, shape, dtype):
+ RAND_SHIFT_FACTOR = 0.5
+ RAND_SCALE_FACTOR = 4.0
+
+ if dtype == DType.BOOL:
+ return np.bool_(self.rng.choice(a=[False, True], size=shape))
+ elif dtype == DType.AINT8:
+ return np.int32(self.rng.integers(low=0, high=256, size=shape))
+ elif dtype == DType.INT4:
+ return np.int32(self.rng.integers(low=-7, high=8, size=shape))
+ elif dtype == DType.INT8:
+ return np.int32(self.rng.integers(low=-127, high=128, size=shape))
+ elif dtype == DType.INT16:
+ return np.int32(self.rng.integers(low=-32768, high=32768, size=shape))
+ elif dtype == DType.INT32:
+ return np.int32(self.rng.integers(low=-(1 << 31), high=(1 << 31), size=shape))
+ elif dtype == DType.INT48:
+ return np.int64(self.rng.integers(low=-(1 << 47), high=(1 << 47), size=shape))
+ elif dtype == DType.FLOAT:
+ # Shift then scale the [0, 1) samples to roughly [-2, 2)
+ return np.float32((self.rng.random(size=shape) - RAND_SHIFT_FACTOR) * RAND_SCALE_FACTOR)
+ else:
+ raise Exception('Unrecognized Dtype: {}'.format(dtype))
+
+ def buildPlaceholderTensors(self, shape_list, dtype):
+ placeholders = []
+
+ for shape in shape_list:
+ arr = self.getRandTensor(shape, dtype)
+ placeholders.append(self.ser.addPlaceholder(shape, dtype, Usage.ACTIVATION, [], arr))
+
+ return placeholders
+
+ def buildConstTensors(self, shape_list, dtype):
+ consts = []
+
+ for shape in shape_list:
+ arr = self.getRandTensor(shape, dtype)
+ consts.append(self.ser.addConst(shape, dtype, Usage.ACTIVATION, [], arr))
+
+ return consts
+
+ def makeShape(self, rank):
+ if self.targetted_shape:
+ return np.int32(self.targetted_shape)
+ return np.int32(self.rng.integers(low=self.args.tensor_shape_range[0],
+ high=self.args.tensor_shape_range[1],
+ size=rank))
+
+ def setTargetShape(self, shape):
+ self.targetted_shape = shape
+
+ def randInt(self, low=0, high=256):
+ return np.int32(self.rng.integers(low=low, high=high, size=1))[0]
+
+ def getRandNumberDType(self, dtype):
+ if dtype == DType.FLOAT:
+ return self.rng.random()
+ elif dtype == DType.BOOL:
+ return self.rng.choice([False, True])
+ elif dtype == DType.INT4:
+ low, high = (-7, 8)
+ elif dtype == DType.AINT8:
+ low, high = (0, 256)
+ elif dtype == DType.INT8:
+ low, high = (-127, 128)
+ elif dtype == DType.INT16:
+ low, high = (-32768, 32768)
+ elif dtype == DType.INT32:
+ low, high = (-(1<<31), (1<<31))
+ elif dtype == DType.INT48:
+ low, high = (-(1<<47), (1<<47))
+ # Special size
+ return np.int64(self.rng.integers(low, high, size=1))[0]
+ else:
+ raise Exception('Unknown dtype: {}'.format(dtype))
+
+ return np.int32(self.rng.integers(low, high, size=1))[0]
+
+ def shapeStr(self, shape):
+
+ sStr = []
+ # Convert to strings
+ for i in shape:
+ sStr.append(str(i))
+
+ return 'x'.join(sStr)
+
+ def typeStr(self, t):
+ if t == DType.BOOL:
+ return 'b'
+ elif t == DType.AINT8:
+ return 'a8'
+ elif t == DType.INT4:
+ return 'i4'
+ elif t == DType.INT8:
+ return 'i8'
+ elif t == DType.INT16:
+ return 'i16'
+ elif t == DType.INT32:
+ return 'i32'
+ elif t == DType.INT48:
+ return 'i48'
+ elif t == DType.FLOAT:
+ return 'float'
+ else:
+ raise Exception('Unknown dtype, cannot convert to string: {}'.format(t))
+
+ def typeWidth(self, t):
+ ''' Get the datatype width for integer types'''
+ if t == DType.AINT8:
+ return 8
+ elif t == DType.UINT8:
+ return 8
+ elif t == DType.INT4:
+ return 4
+ elif t == DType.INT8:
+ return 8
+ elif t == DType.INT16:
+ return 16
+ elif t == DType.INT32:
+ return 32
+ elif t == DType.INT48:
+ return 48
+ else:
+ raise Exception('Unknown dtype, cannot determine width: {}'.format(t))
+
+ # Operator test build functions
+ # Each takes the tensors created from the generated shape list plus the
+ # expanded argument list from the argument generator, adds the operator
+ # (and any attribute/quantinfo) to the serializer, and returns the result tensor
+
+
+ def build_unary(self, op, a, qinfo = None):
+ result_tens = OutputShaper.unaryOp(self.ser, a)
+ self.ser.addOperator(op, [a.name], [result_tens.name], None, qinfo)
+ return result_tens
+
+ def build_binary_broadcast(self, op, a, b):
+ result_tens = OutputShaper.binaryBroadcastOp(self.ser, a, b)
+ self.ser.addOperator(op, [a.name, b.name], [result_tens.name])
+ return result_tens
+
+ def build_binary_nonbroadcast(self, op, a, b):
+ result_tens = OutputShaper.binaryNonBroadcastOp(self.ser, a, b)
+ self.ser.addOperator(op, [a.name, b.name], [result_tens.name])
+ return result_tens
+
+ def build_mul(self, op, a, b):
+ result_tens = OutputShaper.binaryBroadcastOp(self.ser, a, b)
+
+ # Special for multiply:
+ # Force the result to INT32 for INT types
+ if a.dtype != DType.FLOAT:
+ result_tens.setDtype(DType.INT32)
+
+ self.ser.addOperator(op, [a.name, b.name], [result_tens.name])
+ return result_tens
+
+ def build_table(self, op, a):
+ # Constant size, random values
+ table_arr = self.getRandTensor([513], DType.INT16)
+ table_tens = self.ser.addConst(table_arr.shape, DType.INT16, Usage.INDEX, [], table_arr)
+
+ result_tens = OutputShaper.tableOp(self.ser, a, table_tens)
+ self.ser.addOperator(op, [a.name, table_tens.name], [result_tens.name], None)
+
+ return result_tens
+
+ def build_select(self, op, cond, a, b):
+
+ # Replace the cond tensor with a boolean tensor since it probably
+ # has the wrong dtype
+ t = self.buildPlaceholderTensors([cond.shape], DType.BOOL)
+ cond = t[0]
+
+ result_tens = OutputShaper.selectOp(self.ser, cond, a, b)
+ self.ser.addOperator(op, [cond.name, a.name, b.name], [result_tens.name])
+
+ return result_tens
+
+ def build_comparison(self, op, a, b):
+ result_tens = OutputShaper.binaryComparisonOp(self.ser, a, b)
+ self.ser.addOperator(op, [a.name, b.name], [result_tens.name])
+ return result_tens
+
+ def build_argmax(self, op, a, axis):
+ result_tens = OutputShaper.argmaxOp(self.ser, a, axis)
+
+ attr = ts.TosaSerializerAttribute()
+ attr.AxisAttribute(axis)
+
+ self.ser.addOperator(op, [a.name], [result_tens.name], attr)
+ return result_tens
+
+ def build_pool2d(self, op, input, kernel, stride, pad, qinfo = None):
+ result_tens = OutputShaper.pool2dOp(self.ser, input, kernel, stride, pad)
+
+ attr = ts.TosaSerializerAttribute()
+ attr.Pool2dAttribute(kernel, stride, pad)
+ input.addFormat(Format.NHWC)
+
+ self.ser.addOperator(op, [input.name], [result_tens.name], attr, qinfo)
+ return result_tens
+
+ def build_conv2d(self, op, ifm, filter, bias, strides, padding, dilations, qinfo):
+ assert(len(padding) == 4)
+ result_tens = OutputShaper.conv2dOp(self.ser, ifm, filter, strides, padding, dilations)
+
+ attr = ts.TosaSerializerAttribute()
+ attr.Conv2dAttribute(padding, strides, dilations)
+
+ ifm.addFormat(Format.NHWC)
+ # Update the filter ordering
+ filter.addUsage(Usage.WEIGHT)
+ filter.addFormat(Format.OHWI)
+
+ self.ser.addOperator(op, [ifm.name, filter.name, bias.name], [result_tens.name], attr, qinfo)
+ return result_tens
+
+ def build_transpose_conv2d(self, op, ifm, filter, stride, outpad, dilation, output_shape, qinfo):
+ assert(len(outpad) == 2)
+ result_tens = OutputShaper.transposeConv2DOp(self.ser, ifm, output_shape)
+
+ attr = ts.TosaSerializerAttribute()
+ attr.TransposeConv2DAttribute(outpad, stride, dilation, output_shape)
+
+ ifm.addFormat(Format.NHWC)
+ # Update the filter ordering
+ filter.addUsage(Usage.WEIGHT)
+ filter.addFormat(Format.OHWI)
+
+ # Create bias here since the acc_t depends on (but isn't the same as) the input dtype
+ # The bias is OC
+ if ifm.dtype == DType.AINT8 or ifm.dtype == DType.INT8:
+ bias_type = DType.INT32
+ elif ifm.dtype == DType.INT16:
+ bias_type = DType.INT48
+ elif ifm.dtype == DType.FLOAT:
+ bias_type = DType.FLOAT
+ else:
+ raise Exception('Unsupported dtype for transpose_conv2d: {}'.format(ifm.dtype))
+
+ bias_arr = self.getRandTensor([filter.shape[0]], bias_type)
+ bias_tens = self.ser.addConst([filter.shape[0]], bias_type, [], [], bias_arr)
+
+ self.ser.addOperator(op, [ifm.name, filter.name, bias_tens.name], [result_tens.name], attr, qinfo)
+ return result_tens
+
+ def build_depthwise_conv2d(self, op, ifm, filter, bias, strides, padding, dilations, qinfo):
+ result_tens = OutputShaper.depthwiseConv2dOp(self.ser, ifm, filter, strides, padding, dilations)
+
+ attr = ts.TosaSerializerAttribute()
+ attr.Conv2dAttribute(padding, strides, dilations)
+
+ ifm.addFormat(Format.NHWC)
+ filter.addUsage(Usage.WEIGHT)
+ filter.addFormat(Format.HWIM)
+
+ self.ser.addOperator(op, [ifm.name, filter.name, bias.name], [result_tens.name], attr, qinfo)
+ return result_tens
+
+ def build_fully_connected(self, op, ifm, filter, bias, qinfo):
+ result_tens = OutputShaper.fullyConnectedOp(self.ser, ifm, filter)
+
+ filter.addUsage(Usage.WEIGHT)
+ self.ser.addOperator(op, [ifm.name, filter.name, bias.name], [result_tens.name], None, qinfo)
+ return result_tens
+
+ def build_matmul(self, op, a, b, qinfo):
+ result_tens = OutputShaper.matmulOp(self.ser, a, b)
+ self.ser.addOperator(op, [a.name, b.name], [result_tens.name], None, qinfo)
+ return result_tens
+
+ def build_reduce(self, op, a, axis):
+ result_tens = OutputShaper.reduceOp(self.ser, a, axis)
+
+ attr = ts.TosaSerializerAttribute()
+ attr.AxisAttribute(axis)
+
+ self.ser.addOperator(op, [a.name], [result_tens.name], attr)
+ return result_tens
+
+ def build_clamp(self, op, a):
+ result_tens = OutputShaper.unaryOp(self.ser, a)
+
+ attr = ts.TosaSerializerAttribute()
+
+ # Get two random ints
+ v = [self.randInt(), self.randInt()]
+
+ if a.dtype == DType.FLOAT:
+ attr.ClampAttribute(0, 0, min(v), max(v))
+ else:
+ attr.ClampAttribute(min(v), max(v), 0, 0)
+
+ self.ser.addOperator(op, [a.name], [result_tens.name], attr)
+ return result_tens
+
+ def build_leaky_relu(self, op, a):
+ result_tens = OutputShaper.unaryOp(self.ser, a)
+ attr = ts.TosaSerializerAttribute()
+
+ attr.LeakyReluAttribute(self.getRandNumberDType(DType.FLOAT))
+
+ self.ser.addOperator(op, [a.name], [result_tens.name], attr)
+ return result_tens
+
+ # Needs an additional type/input
+ def build_prelu(self, op, a):
+ result_tens = OutputShaper.unaryOp(self.ser, a)
+
+ self.ser.addOperator(op, [a.name], [result_tens.name])
+ return result_tens
+
+ def build_relun(self, op, a):
+ result_tens = OutputShaper.unaryOp(self.ser, a)
+
+ attr = ts.TosaSerializerAttribute()
+
+ if a.dtype == DType.FLOAT:
+ attr.ReluNAttribute(0, self.getRandNumberDType(a.dtype))
+ else:
+ attr.ReluNAttribute(self.getRandNumberDType(a.dtype), 0)
+
+ self.ser.addOperator(op, [a.name], [result_tens.name], attr)
+ return result_tens
+
+ def build_sigmoid(self, op, a):
+ result_tens = OutputShaper.unaryOp(self.ser, a)
+ self.ser.addOperator(op, [a.name], [result_tens.name])
+ return result_tens
+
+ def build_tanh(self, op, a):
+ result_tens = OutputShaper.unaryOp(self.ser, a)
+ self.ser.addOperator(op, [a.name], [result_tens.name])
+ return result_tens
+
+ def build_concat(self, op, a, b, axis):
+ result_tens = OutputShaper.concatOp(self.ser, a, b, axis)
+
+ attr = ts.TosaSerializerAttribute()
+ attr.AxisAttribute(axis)
+
+ self.ser.addOperator(op, [a.name, b.name], [result_tens.name], attr)
+ return result_tens
+
+ def build_pad(self, op, a, padding, qinfo):
+ result_tens = OutputShaper.padOp(self.ser, a, padding)
+
+ # Need to turn the padding array into a TOSA tensor here.
+ # This is one of the few tensor operands that does not get
+ # randomly generated
+ padding_tens = self.ser.addConst(padding.shape, DType.INT32, [], [], padding)
+
+ self.ser.addOperator(op, [a.name, padding_tens.name], [result_tens.name], None, qinfo)
+ return result_tens
+
+ def build_reshape(self, op, a, newShape):
+ result_tens = OutputShaper.reshapeOp(self.ser, a, newShape)
+
+ attr = ts.TosaSerializerAttribute()
+ attr.ReshapeAttribute(newShape)
+
+ self.ser.addOperator(op, [a.name], [result_tens.name], attr)
+ return result_tens
+
+ def build_reverse(self, op, a, axis):
+ result_tens = OutputShaper.unaryOp(self.ser, a)
+
+ attr = ts.TosaSerializerAttribute()
+ attr.AxisAttribute(axis)
+
+ self.ser.addOperator(op, [a.name], [result_tens.name], attr)
+ return result_tens
+
+ def build_transpose(self, op, a, perms):
+ result_tens = OutputShaper.transposeOp(self.ser, a, perms)
+
+ perms_tens = self.ser.addConst([len(perms)], DType.INT32, Usage.ACTIVATION, [], np.int32(perms))
+
+ self.ser.addOperator(op, [a.name, perms_tens.name], [result_tens.name])
+ return result_tens
+
+ def build_slice(self, op, a, begin, size):
+ result_tens = OutputShaper.sliceOp(self.ser, a, begin, size)
+
+ attr = ts.TosaSerializerAttribute()
+ attr.SliceAttribute(begin, size)
+
+ self.ser.addOperator(op, [a.name], [result_tens.name], attr)
+ return result_tens
+
+ def build_tile(self, op, a, multiples):
+ result_tens = OutputShaper.tileOp(self.ser, a, multiples)
+
+ attr = ts.TosaSerializerAttribute()
+ attr.TileAttribute(multiples)
+
+ self.ser.addOperator(op, [a.name], [result_tens.name], attr)
+ return result_tens
+
+
+ def build_gather(self, op, values, axis):
+
+ # Create a new indices tensor here with data that doesn't exceed
+ # the dimensions of the values tensor
+
+ max_val = values.shape[axis]
+ indicies_arr = np.int32(self.rng.integers(low=0, high=max_val, size=[self.randInt(1, max_val + 1)]))
+ indicies = self.ser.addConst(indicies_arr.shape, DType.INT32, Usage.INDEX, [], indicies_arr)
+
+ result_tens = OutputShaper.gatherOp(self.ser, values, indicies, axis)
+
+ attr = ts.TosaSerializerAttribute()
+ attr.AxisAttribute(axis)
+
+ self.ser.addOperator(op, [values.name, indicies.name], [result_tens.name], attr)
+
+ return result_tens
+
+ def build_resize(self, op, input, mode, stride, offset, shift, output_dims, output_dtype):
+ result_tens = OutputShaper.resizeOp(self.ser, input, mode, stride, offset, shift, output_dims, output_dtype)
+
+ attr = ts.TosaSerializerAttribute()
+ attr.ResizeAttribute(output_dims, stride, offset, shift, mode)
+
+ self.ser.addOperator(op, [input.name], [result_tens.name], attr)
+ return result_tens
+
+ def build_identityn(self, op, val, val2):
+
+ result_tens = OutputShaper.unaryOp(self.ser, val)
+ result_tens2 = OutputShaper.unaryOp(self.ser, val2)
+ self.ser.addOperator(op, [val.name, val2.name], [result_tens.name, result_tens2.name])
+ return result_tens
+
+ def build_placeholder(self, op, val):
+ # Add an identity op to avoid warning in the reference model
+ return self.build_unary(Op.IDENTITY, val)
+
+ # Type Conversion
+ def build_cast(self, op, val, out_dtype):
+ result_tens = OutputShaper.typeConversionOp(self.ser, val, out_dtype)
+ self.ser.addOperator(op, [val.name], [result_tens.name])
+ return result_tens
+
+ def build_rescale(self, op, val, out_dtype, scale32, double_round, per_channel):
+ result_tens = OutputShaper.typeConversionOp(self.ser, val, out_dtype)
+
+ if per_channel:
+ nc = val.shape[-1]
+ else:
+ nc = 1
+
+ in_type_width = self.typeWidth(val.dtype)
+ out_type_width = self.typeWidth(out_dtype)
+
+ if val.dtype == DType.AINT8:
+ input_zp = self.randInt()
+ in_type_width = in_type_width + 1
+ else:
+ input_zp = 0
+
+ if out_dtype == DType.AINT8:
+ output_zp = self.randInt()
+ out_type_width = out_type_width + 1
+ else:
+ output_zp = 0
+
+ # Calculate scale based on:
+ # scale = a * (2^output_width) / (2^input_width)
+
+ a = np.float32(self.rng.random(size=[nc]))
+ scale_arr = a * np.float32((1 << out_type_width) / (1 << in_type_width))
+
+ if scale32:
+ # Cap the scaling at 2^31 - 1 for scale32
+ scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), (1 << 31) - 1)
+ else:
+ # Cap the scaling at 2^15 - 1 for scale16
+ scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), 32767.0)
+
+ #print('{} {} -> {}'.format(out_type_width, in_type_width, scale_arr))
+
+ multiplier_arr = np.int32(np.zeros(shape=[nc]))
+ shift_arr = np.int32(np.zeros(shape=[nc]))
+
+ for i in range(nc):
+ multiplier_arr[i], shift_arr[i] = TosaQuantGen.computeMultiplierAndShift(scale_arr[i], scale32)
+
+ #print('multiplier {} shift {} inzp {} outzp {}'.format(multiplier_arr, shift_arr, input_zp, output_zp))
+
+ attr = ts.TosaSerializerAttribute()
+ attr.RescaleAttribute(input_zp,
+ output_zp,
+ multiplier_arr,
+ shift_arr,
+ scale32,
+ double_round,
+ per_channel)
+
+ self.ser.addOperator(op, [val.name], [result_tens.name], attr)
+ return result_tens
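+
+ # Scale sketch (illustrative): rescaling AINT8 -> INT32 uses
+ # in_type_width = 8 + 1 = 9 (zero point present) and out_type_width = 32, so
+ #   scale = a * 2**32 / 2**9 = a * 2**23   with a drawn from [0, 1),
+ # which is clipped and turned into per-channel (multiplier, shift) pairs by
+ # TosaQuantGen.computeMultiplierAndShift().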
+
+ def build_cond_if_const(self, op, then_tens, else_tens, cond):
+ # For cond_if with constants, we're supplied with then/else tensors that we ignore
+ # (except for the generated shape) and the condition. Build Then/Else blocks
+ # and fill them with const nodes for the body.
+
+ # Condition tensor
+ cond_tens = self.ser.addConst([], DType.BOOL, Usage.ACTIVATION, [], [cond])
+
+ # Make then/else tensors
+ out_shape = then_tens.shape
+ then_arr = np.int32(self.rng.integers(0, 255, size=out_shape))
+ else_arr = np.int32(self.rng.integers(0, 255, size=out_shape))
+
+ # And the result tensor based on any of the outputs
+ result_tens = self.ser.addOutput(out_shape, DType.INT32, Usage.ACTIVATION, [])
+
+ # Create the attribute with the names of the then/else blocks
+ then_block = 'THEN_BLOCK'
+ else_block = 'ELSE_BLOCK'
+ attr = ts.TosaSerializerAttribute()
+ attr.CondIfAttribute(then_block, else_block)
+
+ # Finally, build the op and the two blocks
+ self.ser.addOperator(op, [cond_tens.name], [result_tens.name], attr)
+
+ self.ser.startBasicBlock(then_block)
+ # Build the actual then/else tensors inside their blocks
+ then_tens = self.ser.addConst(out_shape, DType.INT32, Usage.ACTIVATION, [], then_arr)
+ self.ser.addOutputTensor(then_tens)
+
+ self.ser.startBasicBlock(else_block)
+ else_tens = self.ser.addConst(out_shape, DType.INT32, Usage.ACTIVATION, [], else_arr)
+ self.ser.addOutputTensor(else_tens)
+
+ return result_tens
+
+ def build_cond_if_binary(self, op, a, b, cond):
+ # For cond_if with a binary op in the then/else blocks, take a and b and
+ # alternately add or subtract them based on the condition
+
+ # Condition tensor
+ cond_tens = self.ser.addConst([], DType.BOOL, Usage.ACTIVATION, [], [cond])
+
+ result_tens = self.ser.addOutput(a.shape, a.dtype, Usage.ACTIVATION, [])
+ self.ser.currBasicBlock.addOutput(result_tens.name)
+
+ # Create the attribute with the names of the then/else blocks
+ then_block = 'THEN_BLOCK'
+ else_block = 'ELSE_BLOCK'
+ attr = ts.TosaSerializerAttribute()
+ attr.CondIfAttribute(then_block, else_block)
+
+ # Finally, build the op and the two blocks
+ self.ser.addOperator(op, [cond_tens.name, a.name, b.name], [result_tens.name], attr)
+
+ self.ser.startBasicBlock(then_block)
+ self.ser.addInputTensor(a)
+ self.ser.addInputTensor(b)
+ then_tens = self.ser.addOutput(a.shape, a.dtype, a.usage, a.dformat)
+ self.ser.addOperator(Op.ADD, [a.name, b.name], [then_tens.name])
+
+ self.ser.startBasicBlock(else_block)
+ self.ser.addInputTensor(a)
+ self.ser.addInputTensor(b)
+ else_tens = self.ser.addOutput(a.shape, a.dtype, a.usage, a.dformat)
+ self.ser.addOperator(Op.SUB, [a.name, b.name], [else_tens.name])
+
+ return result_tens
+
+ def build_while_loop(self, op, a, iter_val):
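+ # Builds a while_loop that counts iter down to zero, adding 'a' into the accumulator
+ # on every pass, so the returned accumulator holds 'a' summed iter_val times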
+ iter = self.ser.addPlaceholder([], DType.INT32, Usage.ACTIVATION, [], [np.int32(iter_val)])
+
+ cond_block = 'COND_BLOCK'
+ body_block = 'BODY_BLOCK'
+
+ attr = ts.TosaSerializerAttribute()
+ attr.WhileLoopAttribute(cond_block, body_block)
+
+ # Accumulator tensor
+ #acc = self.ser.addOutput(a.shape, a.dtype, a.usage, a.dformat)
+ acc_init_val = np.int32(np.zeros(a.shape))
+ acc = self.ser.addPlaceholder(a.shape, a.dtype, a.usage, a.dformat, acc_init_val)
+
+ # Intermediate/output tensors for everything going through the loop
+ iter_out = self.ser.addIntermediate(iter.shape, iter.dtype, iter.usage, iter.dformat)
+ a_out = self.ser.addIntermediate(a.shape, a.dtype, a.usage, a.dformat)
+ acc_out = self.ser.addIntermediate(acc.shape, acc.dtype, acc.usage, acc.dformat)
+
+ # While_loop operator
+ self.ser.addOperator(op,
+ [iter.name, a.name, acc.name],
+ [iter_out.name, a_out.name, acc_out.name], attr)
+
+ # COND block (input: iter, output: cond_tens )
+ self.ser.startBasicBlock(cond_block)
+ self.ser.addInputTensor(iter)
+ self.ser.addInputTensor(a)
+ self.ser.addInputTensor(acc)
+ zero_tens = self.ser.addConst([], DType.INT32, [], [], [np.int32(0)])
+ cond_tens = self.ser.addOutput([], DType.BOOL, [], [])
+ self.ser.addOperator(Op.GREATER, [iter.name, zero_tens.name],
+ [cond_tens.name])
+
+ # BODY block (input: a, acc, iter, output: a, acc, iter)
+ # Note that local intermediate tensors need to be declared here for the outputs
+ self.ser.startBasicBlock(body_block)
+ self.ser.addInputTensor(iter)
+ self.ser.addInputTensor(a)
+ self.ser.addInputTensor(acc)
+ one_tens = self.ser.addConst([], DType.INT32, [], [], [np.int32(1)])
+ iter_body_out = self.ser.addIntermediate(iter.shape, iter.dtype, iter.usage, iter.dformat)
+ acc_body_out = self.ser.addIntermediate(acc.shape, acc.dtype, acc.usage, acc.dformat)
+ self.ser.addOperator(Op.ADD, [a.name, acc.name], [acc_body_out.name])
+ self.ser.addOperator(Op.SUB, [iter.name, one_tens.name], [iter_body_out.name])
+ self.ser.addOutputTensor(iter_body_out)
+ self.ser.addOutputTensor(a)
+ self.ser.addOutputTensor(acc_body_out)
+
+ return acc_out
+
+
+ def genOpTestList(self, opName, shapeFilter=[None], rankFilter=None, dtypeFilter=None):
+
+ try:
+ op = self.TOSA_OP_LIST[opName]
+ except KeyError as e:
+ raise Exception('Cannot find op with name {}'.format(opName))
+
+ # Initialize a new random number generator
+ self.rng = np.random.default_rng(self.random_seed)
+
+ build_fcn, tgen_fcn, agen_fcn = op['build_fcn']
+
+ # Generate the lists of arguments
+ rmin, rmax = op['rank']
+
+ # Test list consists of a tuple of:
+ # (opName, testNameStr, dtype, shapeList, argumentsList)
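+ # e.g. one entry might be ('add', 'add_1x4x4x4_f32', DType.FLOAT, [[1,4,4,4], [1,4,4,4]], [])
+ # (illustrative only; the exact name depends on shapeStr()/typeStr() formatting)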
+ testList = []
+
+ if not shapeFilter:
+ shapeFilter = [None]
+
+ for r in range(rmin, rmax + 1):
+
+ # Filter out the rank?
+ if rankFilter is not None and r not in rankFilter:
+ continue
+
+ for t in op['types']:
+
+ # Filter tests based on dtype?
+ if dtypeFilter is not None:
+ if t not in dtypeFilter:
+ continue
+
+ # Create the placeholder and const tensors
+ for shape in shapeFilter:
+ # A None shape chooses a random shape of a given rank
+
+ # Filter out by rank
+ if shape is not None and len(shape) != r:
+ continue
+
+ self.setTargetShape(shape)
+ shapeList = tgen_fcn(self, op, r)
+
+ shapeStr = self.shapeStr(shapeList[0])
+ typeStr = self.typeStr(t)
+
+ # The argument list consists of (str, []) tuples: a name string and the build function argument list
+ argList = []
+ if agen_fcn:
+ argList = agen_fcn(self, opName, shapeList, t)
+ else:
+ argList = [('', [])]
+
+ for argStr, args in argList:
+ if argStr:
+ testStr = '{}_{}_{}_{}'.format(opName, shapeStr, typeStr, argStr)
+ else:
+ testStr = '{}_{}_{}'.format(opName, shapeStr, typeStr)
+
+ testList.append((opName, testStr, t, shapeList, args))
+
+ return testList
+
+ def serializeTest(self, opName, testStr, dtype, shapeList, testArgs):
+ try:
+ op = self.TOSA_OP_LIST[opName]
+ except KeyError as e:
+ raise Exception('Cannot find op with name {}'.format(opName))
+
+ # Create a serializer
+ self.createSerializer(opName, testStr)
+
+ build_fcn, tgen_fcn, agen_fcn = op['build_fcn']
+ pCount, cCount = op['operands']
+
+ try:
+ qgen = op['qgen']
+ except KeyError:
+ qgen = None
+
+ # Build the random tensor operands and the test
+ tens = []
+ tens.extend(self.buildPlaceholderTensors(shapeList[0:pCount], dtype))
+ tens.extend(self.buildConstTensors(shapeList[pCount:], dtype))
+
+ if qgen is not None:
+ qinfo = qgen(self, op, dtype)
+ else:
+ qinfo = None
+
+ try:
+ if qinfo is not None:
+ resultName = build_fcn(self, op['op'], *tens, *testArgs, qinfo)
+ else:
+ resultName = build_fcn(self, op['op'], *tens, *testArgs)
+ except TypeError as e:
+ print('build_fcn: {}\nTensors: {}\nArgs: {}\n'.format(build_fcn, tens, testArgs))
+ raise e
+
+ # Save the serialized test
+ self.serialize('test')
+
+ def createDynamicOpLists(self):
+
+ # Dynamically create op lists for convolutions with a list of kernel sizes
+ KERNELS = [ [1, 1], [2, 2], [3, 3], [5, 5], [3, 1], [1, 3] ]
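+ # e.g. the [3, 3] entry creates 'conv2d_3x3', 'depthwise_conv2d_3x3' and
+ # 'transpose_conv2d_3x3' entries cloned from the corresponding *_TEMPLATE ops below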
+
+ for k in KERNELS:
+ testName = 'conv2d_{}x{}'.format(k[0], k[1])
+ self.TOSA_OP_LIST[testName] = self.TOSA_OP_LIST['conv2d_TEMPLATE'].copy()
+ self.TOSA_OP_LIST[testName]['filter'] = k
+ self.TOSA_OP_LIST[testName]['template'] = False
+
+ testName = 'depthwise_conv2d_{}x{}'.format(k[0], k[1])
+ self.TOSA_OP_LIST[testName] = self.TOSA_OP_LIST['depthwise_conv2d_TEMPLATE'].copy()
+ self.TOSA_OP_LIST[testName]['filter'] = k
+ self.TOSA_OP_LIST[testName]['template'] = False
+
+ testName = 'transpose_conv2d_{}x{}'.format(k[0], k[1])
+ self.TOSA_OP_LIST[testName] = self.TOSA_OP_LIST['transpose_conv2d_TEMPLATE'].copy()
+ self.TOSA_OP_LIST[testName]['filter'] = k
+ self.TOSA_OP_LIST[testName]['template'] = False
+
+ # Delete any templates after having created any dynamic ops
+ # This is a two-pass operation because it's bad practice to delete
+ # keys from dictionaries while iterating
+ keyList = []
+ for k in self.TOSA_OP_LIST:
+ try:
+ if self.TOSA_OP_LIST[k]['template'] == True:
+ keyList.append(k)
+ continue
+ except KeyError:
+ pass
+
+ for k in keyList:
+ del self.TOSA_OP_LIST[k]
+
+ def initOpListDefaults(self):
+ '''Fill in default fields for ops if they aren't already specified.
+ Look for missing required fields (data structure linting).'''
+ for op in self.TOSA_OP_LIST:
+
+ # Required fields
+ try:
+ pl, c = self.TOSA_OP_LIST[op]['operands']
+ except (KeyError, ValueError, TypeError):
+ raise Exception('Op {} is missing a valid operand tuple in TOSA_OP_LIST'.format(op))
+
+ try:
+ fcn, tgen, arggen = self.TOSA_OP_LIST[op]['build_fcn']
+ except (KeyError, ValueError, TypeError):
+ raise Exception('Op {} is missing a valid build_fcn tuple in TOSA_OP_LIST'.format(op))
+
+ try:
+ types = self.TOSA_OP_LIST[op]['types']
+ except KeyError as e:
+ raise Exception('Op {} is missing a valid type list in TOSA_OP_LIST'.format(op))
+
+ try:
+ opcode = self.TOSA_OP_LIST[op]['op']
+ except KeyError as e:
+ raise Exception('Op {} is missing the Op field in TOSA_OP_LIST'.format(op))
+
+ # Put in default rank range, if missing
+ try:
+ rank = self.TOSA_OP_LIST[op]['rank']
+ except KeyError:
+ self.TOSA_OP_LIST[op]['rank'] = self.DEFAULT_RANK_RANGE
+
+ # Tensor operator list
+ # 'op': op name
+ # 'operands': tuple of (placeholder, const) operands
+ # 'rank': optional, inclusive (min, max) tuple restricting tensor rank; defaults to (1, 4) if not specified
+ # 'build_fcn': tuple of (operator build function, TensorGen function, ArgGen function)
+ # 'types': array of datatypes to be tested
+ TYPE_FP = [ DType.FLOAT ]
+
+ # Types with an aint8
+ TYPE_INT = [ DType.AINT8, DType.INT16, DType.INT32 ] # Most operators support AINT8 instead of INT8, excludes INT4
+ TYPE_INT_FP = [ DType.AINT8, DType.INT16, DType.INT32, DType.FLOAT ] # Most operators support AINT8 instead of INT8, excludes INT4
+
+ # Types with an int8
+ TYPE_PURE_INT = [ DType.INT8, DType.INT16, DType.INT32 ] # Note: excludes INT4
+ TYPE_PURE_INT_FP = [ DType.INT8, DType.INT16, DType.INT32, DType.FLOAT ] # Note: excludes INT4
+ TYPE_BOOL = [ DType.BOOL ]
+ TYPE_FI32 = [ DType.FLOAT, DType.INT32 ]
+ TYPE_FIB = [ DType.FLOAT, DType.AINT8, DType.INT8, DType.INT16, DType.INT32, DType.BOOL ]
+ TYPE_FI16 = [ DType.FLOAT, DType.INT16 ]
+
+ TYPE_NARROW_INT_FP = [ DType.AINT8, DType.INT16, DType.FLOAT ]
+
+ DEFAULT_RANK_RANGE = (1, 4)
+
+ TOSA_OP_LIST = {
+ # Binary ops
+ 'add':
+ { 'op': Op.ADD,
+ 'operands': (2, 0),
+ 'build_fcn': (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
+ 'types': TYPE_FI32 },
+
+ 'arithmetic_right_shift':
+ { 'op': Op.ARITHMETIC_RIGHT_SHIFT,
+ 'operands': (2, 0),
+ 'build_fcn': (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
+ 'types': TYPE_PURE_INT },
+
+ 'bitwise_and':
+ { 'op': Op.BITWISE_AND,
+ 'operands': (2, 0),
+ 'build_fcn': (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
+ 'types': TYPE_INT },
+
+ 'bitwise_or':
+ { 'op': Op.BITWISE_OR,
+ 'operands': (2, 0),
+ 'build_fcn': (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
+ 'types': TYPE_INT },
+
+ 'bitwise_xor':
+ { 'op': Op.BITWISE_XOR,
+ 'operands': (2, 0),
+ 'build_fcn': (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
+ 'types': TYPE_INT },
+
+ 'logical_and':
+ { 'op': Op.LOGICAL_AND,
+ 'operands': (2, 0),
+ 'build_fcn': (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
+ 'types': TYPE_BOOL },
+
+ 'logical_left_shift':
+ { 'op': Op.LOGICAL_LEFT_SHIFT,
+ 'operands': (2, 0),
+ 'build_fcn': (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
+ 'types': TYPE_PURE_INT },
+
+ 'logical_right_shift':
+ { 'op': Op.LOGICAL_RIGHT_SHIFT,
+ 'operands': (2, 0),
+ 'build_fcn': (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
+ 'types': TYPE_PURE_INT },
+
+ 'logical_or':
+ { 'op': Op.LOGICAL_OR,
+ 'operands': (2, 0),
+ 'build_fcn': (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
+ 'types': TYPE_BOOL },
+
+ 'logical_xor':
+ { 'op': Op.LOGICAL_XOR,
+ 'operands': (2, 0),
+ 'build_fcn': (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
+ 'types': TYPE_BOOL },
+
+ 'max':
+ { 'op': Op.MAXIMUM,
+ 'operands': (2, 0),
+ 'build_fcn': (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
+ 'types': TYPE_FI32 },
+
+ 'min':
+ { 'op': Op.MINIMUM,
+ 'operands': (2, 0),
+ 'build_fcn': (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
+ 'types': TYPE_FI32 },
+
+ 'mul':
+ { 'op': Op.MUL,
+ 'operands': (2, 0),
+ 'build_fcn': (build_mul, TosaTensorGen.tgBroadcastFuzz, None),
+ 'types': TYPE_PURE_INT_FP },
+
+ 'pow':
+ { 'op': Op.POW,
+ 'operands': (2, 0),
+ 'build_fcn': (build_binary_broadcast, TosaTensorGen.tgBasic, None),
+ 'types': TYPE_FP },
+
+ 'sub':
+ { 'op': Op.SUB,
+ 'operands': (2, 0),
+ 'build_fcn': (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
+ 'types': TYPE_FI32 },
+
+ 'table':
+ { 'op': Op.TABLE,
+ # Use the automatic generation functions to create the input array
+ # but create the table tensor in the build function, as it may be
+ # a different type from the input
+ 'operands': (1, 0),
+ 'build_fcn': (build_table, TosaTensorGen.tgBasic, None),
+ 'types': [ DType.INT16 ] },
+
+ 'argmax':
+ { 'op': Op.ARGMAX,
+ 'operands': (1, 0),
+ 'build_fcn': (build_argmax, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
+ 'types': TYPE_FP },
+
+ # Templated operator. Filled in by createDynamicOpLists
+ 'conv2d_TEMPLATE':
+ { 'op': Op.CONV2D,
+ 'operands': (1, 2),
+ 'rank': (4, 4),
+ 'build_fcn': (build_conv2d, TosaTensorGen.tgConv2D, TosaArgGen.agConv2D),
+ 'qgen': TosaQuantGen.qgConv,
+ 'types': TYPE_FP,
+ 'template': True },
+
+ # Templated operator. Filled in by createDynamicOpLists
+ 'depthwise_conv2d_TEMPLATE':
+ { 'op': Op.DEPTHWISE_CONV2D,
+ 'operands': (1, 2),
+ 'filter': [1, 1],
+ 'rank': (4, 4),
+ 'build_fcn': (build_depthwise_conv2d, TosaTensorGen.tgDepthwiseConv2D, TosaArgGen.agConv2D),
+ 'qgen': TosaQuantGen.qgConv,
+ 'types': TYPE_FP,
+ 'template': True },
+
+ # Templated operator. Filled in by createDynamicOpLists
+ 'transpose_conv2d_TEMPLATE':
+ { 'op': Op.TRANSPOSE_CONV2D,
+ 'operands': (1, 1),
+ 'rank': (4, 4),
+ 'build_fcn': (build_transpose_conv2d, TosaTensorGen.tgTransposeConv2D, TosaArgGen.agTransposeConv2D),
+ 'qgen': TosaQuantGen.qgConv,
+ 'types': TYPE_FP,
+ 'template': True },
+
+ 'fully_connected':
+ { 'op': Op.FULLY_CONNECTED,
+ 'operands': (2, 0),
+ 'rank': (2, 2),
+ 'build_fcn': (build_fully_connected, TosaTensorGen.tgFullyConnected, None),
+ 'qgen': TosaQuantGen.qgConv,
+ 'types': TYPE_FP },
+
+ 'matmul':
+ { 'op': Op.MATMUL,
+ 'operands': (2, 0),
+ 'rank': (2, 2),
+ 'build_fcn': (build_matmul, TosaTensorGen.tgMatmul, None),
+ 'qgen': TosaQuantGen.qgMatmul,
+ 'types': TYPE_NARROW_INT_FP },
+
+ # Unary operators
+ 'abs':
+ { 'op': Op.ABS,
+ 'operands': (1, 0),
+ 'build_fcn': (build_unary, TosaTensorGen.tgBasic, None),
+ 'types': TYPE_FI32 },
+
+ 'bitwise_not':
+ { 'op': Op.BITWISE_NOT,
+ 'operands': (1, 0),
+ 'build_fcn': (build_unary, TosaTensorGen.tgBasic, None),
+ 'types': TYPE_INT },
+
+ 'ceil':
+ { 'op': Op.CEIL,
+ 'operands': (1, 0),
+ 'build_fcn': (build_unary, TosaTensorGen.tgBasic, None),
+ 'types': TYPE_FP },
+
+ 'clz':
+ { 'op': Op.CLZ,
+ 'operands': (1, 0),
+ 'build_fcn': (build_unary, TosaTensorGen.tgBasic, None),
+ 'types': [ DType.INT32 ] },
+
+ 'exp':
+ { 'op': Op.EXP,
+ 'operands': (1, 0),
+ 'build_fcn': (build_unary, TosaTensorGen.tgBasic, None),
+ 'types': TYPE_FP },
+
+ 'floor':
+ { 'op': Op.FLOOR,
+ 'operands': (1, 0),
+ 'build_fcn': (build_unary, TosaTensorGen.tgBasic, None),
+ 'types': TYPE_FP },
+
+ 'log':
+ { 'op': Op.LOG,
+ 'operands': (1, 0),
+ 'build_fcn': (build_unary, TosaTensorGen.tgBasic, None),
+ 'types': TYPE_FP },
+
+ 'logical_not':
+ { 'op': Op.LOGICAL_NOT,
+ 'operands': (1, 0),
+ 'build_fcn': (build_unary, TosaTensorGen.tgBasic, None),
+ 'types': TYPE_BOOL },
+
+ 'negate':
+ { 'op': Op.NEGATE,
+ 'operands': (1, 0),
+ 'build_fcn': (build_unary, TosaTensorGen.tgBasic, None),
+ 'qgen': TosaQuantGen.qgUnary,
+ 'types': TYPE_INT_FP },
+
+ 'reciprocal':
+ { 'op': Op.RECIPROCAL,
+ 'operands': (1, 0),
+ 'build_fcn': (build_unary, TosaTensorGen.tgBasic, None),
+ 'types': TYPE_FP },
+
+ 'rsqrt':
+ { 'op': Op.RSQRT,
+ 'operands': (1, 0),
+ 'build_fcn': (build_unary, TosaTensorGen.tgBasic, None),
+ 'types': TYPE_FP },
+
+ # Ternary operators
+ 'select':
+ { 'op': Op.SELECT,
+ 'operands': (3, 0),
+ 'build_fcn': (build_select, TosaTensorGen.tgBroadcastFuzz, None),
+ 'types': TYPE_FIB },
+
+ # Comparison operators
+ 'equal':
+ { 'op': Op.EQUAL,
+ 'operands': (2, 0),
+ 'build_fcn': (build_comparison, TosaTensorGen.tgBroadcastFuzz, None),
+ 'types': TYPE_FI32 },
+
+ 'greater_equal':
+ { 'op': Op.GREATER_EQUAL,
+ 'operands': (2, 0),
+ 'build_fcn': (build_comparison, TosaTensorGen.tgBroadcastFuzz, None),
+ 'types': TYPE_FI32 },
+
+ 'greater':
+ { 'op': Op.GREATER,
+ 'operands': (2, 0),
+ 'build_fcn': (build_comparison, TosaTensorGen.tgBroadcastFuzz, None),
+ 'types': TYPE_FI32 },
+
+ # Pooling operators
+ 'avg_pool2d':
+ { 'op': Op.AVG_POOL2D,
+ 'operands': (1, 0),
+ 'rank': (4, 4),
+ 'build_fcn': (build_pool2d, TosaTensorGen.tgNHWC, TosaArgGen.agPooling),
+ 'qgen': TosaQuantGen.qgUnary,
+ 'types': TYPE_NARROW_INT_FP },
+
+
+ 'max_pool2d':
+ { 'op': Op.MAX_POOL2D,
+ 'operands': (1, 0),
+ 'rank': (4, 4),
+ 'build_fcn': (build_pool2d, TosaTensorGen.tgNHWC, TosaArgGen.agPooling),
+ 'types': TYPE_NARROW_INT_FP },
+
+ # Reduce operators
+ 'reduce_any':
+ { 'op': Op.REDUCE_ANY,
+ 'operands': (1, 0),
+ 'build_fcn': (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
+ 'types': TYPE_BOOL },
+
+ 'reduce_all':
+ { 'op': Op.REDUCE_ALL,
+ 'operands': (1, 0),
+ 'build_fcn': (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
+ 'types': TYPE_BOOL },
+
+ 'reduce_max':
+ { 'op': Op.REDUCE_MAX,
+ 'operands': (1, 0),
+ 'build_fcn': (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
+ 'types': TYPE_INT_FP },
+
+ 'reduce_min':
+ { 'op': Op.REDUCE_MIN,
+ 'operands': (1, 0),
+ 'build_fcn': (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
+ 'types': TYPE_INT_FP },
+
+ 'reduce_product':
+ { 'op': Op.REDUCE_PRODUCT,
+ 'operands': (1, 0),
+ 'build_fcn': (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
+ 'types': TYPE_FP },
+
+ 'reduce_sum':
+ { 'op': Op.REDUCE_SUM,
+ 'operands': (1, 0),
+ 'build_fcn': (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
+ 'types': TYPE_FI32 },
+
+ # Activation functions
+ 'clamp':
+ { 'op': Op.CLAMP,
+ 'operands': (1, 0),
+ 'build_fcn': (build_clamp, TosaTensorGen.tgBasic, None),
+ 'types': TYPE_NARROW_INT_FP },
+
+ 'relun':
+ { 'op': Op.RELUN,
+ 'operands': (1, 0),
+ 'build_fcn': (build_relun, TosaTensorGen.tgBasic, None),
+ 'types': TYPE_FI32 },
+
+ 'sigmoid':
+ { 'op': Op.SIGMOID,
+ 'operands': (1, 0),
+ 'build_fcn': (build_sigmoid, TosaTensorGen.tgBasic, None),
+ 'types': TYPE_FP },
+
+ 'tanh':
+ { 'op': Op.TANH,
+ 'operands': (1, 0),
+ 'build_fcn': (build_tanh, TosaTensorGen.tgBasic, None),
+ 'types': TYPE_FP },
+
+ # Data layout operators
+ 'concat':
+ { 'op': Op.CONCAT,
+ 'operands': (2, 0),
+ 'build_fcn': (build_concat, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
+ 'types': TYPE_FIB },
+
+ 'pad':
+ { 'op': Op.PAD,
+ 'operands': (1, 0),
+ 'build_fcn': (build_pad, TosaTensorGen.tgBasic, TosaArgGen.agPad),
+ 'qgen': TosaQuantGen.qgPad,
+ 'types': TYPE_FIB },
+
+ 'reshape':
+ { 'op': Op.RESHAPE,
+ 'operands': (1, 0),
+ 'build_fcn': (build_reshape, TosaTensorGen.tgBasic, TosaArgGen.agReshape),
+ 'types': TYPE_FIB },
+
+ 'reverse':
+ { 'op': Op.REVERSE,
+ 'operands': (1, 0),
+ 'build_fcn': (build_reverse, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
+ 'types': TYPE_FIB },
+
+ 'slice':
+ { 'op': Op.SLICE,
+ 'operands': (1, 0),
+ 'build_fcn': (build_slice, TosaTensorGen.tgBasic, TosaArgGen.agSlice),
+ 'types': TYPE_FIB },
+
+ 'tile':
+ { 'op': Op.TILE,
+ 'operands': (1, 0),
+ 'build_fcn': (build_tile, TosaTensorGen.tgBasic, TosaArgGen.agTile),
+ 'types': TYPE_FIB },
+
+ 'transpose':
+ { 'op': Op.TRANSPOSE,
+ 'operands': (1, 0),
+ 'rank': (2, 4), # Do not allow transpose on rank=1
+ 'build_fcn': (build_transpose, TosaTensorGen.tgBasic, TosaArgGen.agTranspose),
+ 'types': TYPE_FIB },
+
+ # Scatter/Gather
+ 'gather':
+ { 'op': Op.GATHER,
+ 'operands': (1, 0),
+ 'build_fcn': (build_gather, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
+ 'types': TYPE_INT },
+
+
+ # Image operations
+ 'resize':
+ { 'op': Op.RESIZE,
+ 'operands': (1, 0),
+ 'rank': (4, 4),
+ 'build_fcn': ( build_resize, TosaTensorGen.tgNHWC, TosaArgGen.agResize),
+ 'types': [ DType.INT8, DType.INT16 ] },
+
+
+ # Data nodes
+ 'placeholder':
+ { 'op': Op.PLACEHOLDER,
+ 'operands': (1, 0),
+ 'build_fcn': ( build_placeholder, TosaTensorGen.tgBasic, None),
+ 'types': TYPE_FIB },
+
+ 'const':
+ { 'op': Op.CONST,
+ 'operands': (1, 0),
+ 'build_fcn': ( build_placeholder, TosaTensorGen.tgBasic, None),
+ 'types': TYPE_FIB },
+
+
+ 'identity':
+ { 'op': Op.IDENTITY,
+ 'operands': (1, 0),
+ 'build_fcn': ( build_unary, TosaTensorGen.tgBasic, None),
+ 'types': TYPE_FIB },
+
+
+ 'identityn':
+ { 'op': Op.IDENTITYN,
+ 'operands': (2, 0),
+ 'build_fcn': ( build_identityn, TosaTensorGen.tgBasic, None),
+ 'types': TYPE_FIB },
+
+ # Type conversion
+ 'cast':
+ { 'op': Op.CAST,
+ 'operands': (1, 0),
+ 'build_fcn': ( build_cast, TosaTensorGen.tgBasic, TosaArgGen.agCast ),
+ 'types': [ DType.FLOAT, DType.INT8, DType.INT16, DType.INT32, DType.BOOL ] },
+
+ 'rescale':
+ { 'op': Op.RESCALE,
+ 'operands': (1, 0),
+ 'build_fcn': ( build_rescale, TosaTensorGen.tgBasic, TosaArgGen.agRescale ),
+ 'types': [ DType.AINT8, DType.INT16, DType.INT32, DType.INT48 ] },
+
+ # Custom
+ # Not implemented.
+
+ # Control flow
+
+ # Two variants of cond_if: one that generates one of two constant tensors (no
+ # inputs to the basic blocks, one output) and another that either adds or subtracts two tensors
+ # (two inputs to the basic blocks, one output)
+ 'cond_if_const':
+ { 'op': Op.COND_IF,
+ 'operands': (0, 2),
+ 'build_fcn': ( build_cond_if_const, TosaTensorGen.tgBasic, TosaArgGen.agCondIf ),
+ 'types': [ DType.BOOL ] },
+
+ 'cond_if_binary':
+ { 'op': Op.COND_IF,
+ 'operands': (2, 0),
+ 'build_fcn': ( build_cond_if_binary, TosaTensorGen.tgBasic, TosaArgGen.agCondIf ),
+ 'types': TYPE_FI32 },
+
+ # while_loop
+ 'while_loop':
+ { 'op': Op.WHILE_LOOP,
+ 'operands': (0, 1),
+ 'build_fcn': ( build_while_loop, TosaTensorGen.tgBasic, TosaArgGen.agWhileLoop ),
+ 'types': [DType.INT32] },
+
+
+ }
+
+class OutputShaper:
+ # Methods in this class compute the expected output shape and datatype
+ # for common classes of operations
+ def __init__(self):
+ pass
+
+ # These methods return arguments that can be used for
+ # creating a new output tensor
+ @staticmethod
+ def binaryBroadcastOp(ser, a, b):
+ assert(len(a.shape) == len(b.shape))
+ assert(a.dtype == b.dtype)
+
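+ # Dimensions of size 1 in 'a' take the corresponding dimension from 'b',
+ # e.g. a.shape = [1, 4], b.shape = [3, 4] -> output shape [3, 4]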
+ shape = []
+ for i in range(len(a.shape)):
+ if a.shape[i] == 1:
+ shape.append(b.shape[i])
+ else:
+ shape.append(a.shape[i])
+
+ return ser.addOutput(shape, a.dtype, a.usage, a.dformat)
+
+ @staticmethod
+ def binaryNonBroadcastOp(ser, a, b):
+ assert(len(a.shape) == len(b.shape))
+ assert(a.dtype == b.dtype)
+
+ shape = []
+ for i in range(len(a.shape)):
+ assert(a.shape[i] == b.shape[i])
+ shape.append(a.shape[i])
+
+ return ser.addOutput(shape, a.dtype, a.usage, a.dformat)
+
+ @staticmethod
+ def unaryOp(ser, a):
+ return ser.addOutput(a.shape, a.dtype, a.usage, a.dformat)
+
+ @staticmethod
+ def selectOp(ser, cond, a, b):
+ assert(len(a.shape) == len(b.shape) and len(a.shape) == len(cond.shape))
+ assert(a.dtype == b.dtype)
+
+ shape = []
+ for i in range(len(a.shape)):
+ shape.append(max(cond.shape[i], a.shape[i], b.shape[i]))
+
+ return ser.addOutput(shape, a.dtype, a.usage, a.dformat)
+
+ @staticmethod
+ def binaryComparisonOp(ser, a, b):
+ assert(len(a.shape) == len(b.shape))
+ assert(a.dtype == b.dtype)
+
+ # Do broadcast
+ shape = []
+ for i in range(len(a.shape)):
+ if a.shape[i] == 1:
+ shape.append(b.shape[i])
+ else:
+ shape.append(a.shape[i])
+
+ # Force the output type to bool
+ return ser.addOutput(shape, DType.BOOL, a.usage, a.dformat)
+
+ @staticmethod
+ def reduceOp(ser, a, axis):
+
+ shape = a.shape.copy()
+
+ shape[axis] = 1
+
+ return ser.addOutput(shape, a.dtype, a.usage, a.dformat)
+
+ @staticmethod
+ def argmaxOp(ser, a, axis):
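+ # The axis dimension is removed, e.g. a.shape = [2, 3, 4], axis = 1 -> output shape [2, 4]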
+ shape = a.shape.copy()
+ del shape[axis]
+ return ser.addOutput(shape, DType.INT32, a.usage, a.dformat)
+
+ @staticmethod
+ def conv2dOp(ser, ifm, filter, strides, padding, dilations):
+
+ # IFM: NHWC
+ # Filter: OHWI
+ # OFM: NHWC
+
+ if len(padding) == 2:
+ # Expand padding to 4 parameters in the case of transpose_conv2d
+ # From H,W to T,B,L,R
+ padding = [padding[0], padding[0], padding[1], padding[1]]
+
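+ # e.g. ifm [1,32,32,8], filter [16,3,3,8] (OHWI), strides [1,1], padding [1,1,1,1],
+ # dilations [1,1] -> h = w = (32 - 3 - 0 + 1 + 1) // 1 + 1 = 32, ofm [1,32,32,16]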
+ h = (ifm.shape[1] - filter.shape[1] - (filter.shape[1] - 1) * (dilations[0] - 1) + \
+ padding[0] + padding[1]) // strides[0] + 1
+
+ w = (ifm.shape[2] - filter.shape[2] - (filter.shape[2] - 1) * (dilations[1] - 1) + \
+ padding[2] + padding[3]) // strides[1] + 1
+
+ if h <= 0 or w <= 0:
+ # Invalid test parameters?
+ h = 0
+ w = 0
+ ser.setExpectedFailure(True, 'Invalid combination of conv2d parameters')
+
+ ofm_shape = [ifm.shape[0], h, w, filter.shape[0]]
+
+ if ifm.dtype == DType.AINT8 or ifm.dtype == DType.INT8:
+ out_dtype = DType.INT32
+ elif ifm.dtype == DType.INT16:
+ out_dtype = DType.INT48
+ elif ifm.dtype == DType.FLOAT:
+ out_dtype = DType.FLOAT
+ else:
+ raise Exception('Unsupported input dtype: {}'.format(ifm.dtype))
+
+ return ser.addOutput(ofm_shape, out_dtype, ifm.usage, ifm.dformat)
+
+ @staticmethod
+ def depthwiseConv2dOp(ser, ifm, filter, strides, padding, dilations):
+ # IFM: NHWC
+ # Filter: HWCM
+ # OFM: NHW C*M
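+ # e.g. ifm [1,32,32,8], filter [3,3,8,2] (HWCM) -> ofm channels = 8 * 2 = 16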
+ h = (ifm.shape[1] - filter.shape[0] - (filter.shape[0] - 1) * (dilations[0] - 1) + \
+ padding[0] + padding[1]) // strides[0] + 1
+
+ w = (ifm.shape[2] - filter.shape[1] - (filter.shape[1] - 1) * (dilations[1] - 1) + \
+ padding[2] + padding[3]) // strides[1] + 1
+
+ if h <= 0 or w <= 0:
+ # Invalid test parameters?
+ h = 0
+ w = 0
+ ser.setExpectedFailure(True, 'Invalid combination of depthwise_conv2d parameters')
+
+ ofm_shape = [ifm.shape[0], h, w, filter.shape[2] * filter.shape[3]]
+
+ if ifm.dtype == DType.AINT8 or ifm.dtype == DType.INT8:
+ out_dtype = DType.INT32
+ elif ifm.dtype == DType.INT16:
+ out_dtype = DType.INT48
+ elif ifm.dtype == DType.FLOAT:
+ out_dtype = DType.FLOAT
+ else:
+ raise Exception('Unsupported input dtype: {}'.format(ifm.dtype))
+
+ return ser.addOutput(ofm_shape, out_dtype, ifm.usage, ifm.dformat)
+
+
+ @staticmethod
+ def pool2dOp(ser, ifm, kernel, stride, pad):
+ # input: NHWC
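+ # e.g. ifm [1,32,32,8], kernel [2,2], stride [2,2], pad [0,0,0,0]
+ # -> h = w = (32 + 0 + 0 + 2 - 2) // 2 = 16, ofm [1,16,16,8]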
+ h = (ifm.shape[1] + pad[0] + pad[1] + stride[0] - kernel[0]) // stride[0]
+ w = (ifm.shape[2] + pad[2] + pad[3] + stride[1] - kernel[1]) // stride[1]
+
+ if h <= 0 or w <= 0:
+ # Invalid test parameters?
+ h = 0
+ w = 0
+ ser.setExpectedFailure(True, 'Invalid combination of pooling parameters')
+
+ ofm_shape = [ifm.shape[0], h, w, ifm.shape[3]]
+ return ser.addOutput(ofm_shape, ifm.dtype, ifm.usage, ifm.dformat)
+
+ @staticmethod
+ def fullyConnectedOp(ser, input, filter):
+ # input: N, IC
+ # filter: OC, IC
+ # output: N, OC
+
+ output_shape = [input.shape[0], filter.shape[0]]
+
+ if input.dtype == DType.AINT8 or input.dtype == DType.INT8:
+ out_dtype = DType.INT32
+ elif input.dtype == DType.INT16:
+ out_dtype = DType.INT48
+ elif input.dtype == DType.FLOAT:
+ out_dtype = DType.FLOAT
+ else:
+ raise Exception('Unsupported input dtype: {}'.format(input.dtype))
+
+ return ser.addOutput(output_shape, out_dtype, input.usage, input.dformat)
+
+ @staticmethod
+ def matmulOp(ser, a, b):
+ # a: M, K
+ # b: K, N
+ # out: M, N
+
+ output_shape = [a.shape[0], b.shape[1]]
+
+
+ if a.dtype == DType.AINT8:
+ out_dtype = DType.INT32
+ elif a.dtype == DType.INT16:
+ out_dtype = DType.INT48
+ elif a.dtype == DType.FLOAT:
+ out_dtype = DType.FLOAT
+ else:
+ raise Exception('Unsupported input dtype for matmul: {}'.format(a.dtype))
+
+ return ser.addOutput(output_shape, out_dtype, a.usage, a.dformat)
+
+ @staticmethod
+ def concatOp(ser, a, b, axis):
+
+ output_shape = a.shape.copy()
+ output_shape[axis] = a.shape[axis] + b.shape[axis]
+
+ return ser.addOutput(output_shape, a.dtype, a.usage, a.dformat)
+
+ @staticmethod
+ def padOp(ser, a, padding):
+
+ output_shape = a.shape.copy()
+
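+ # e.g. a.shape = [3, 4] with padding [[1, 1], [2, 2]] -> output shape [5, 8]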
+ for i in range(len(output_shape)):
+ output_shape[i] = padding[i][0] + padding[i][1] + output_shape[i]
+
+ return ser.addOutput(output_shape, a.dtype, a.usage, a.dformat)
+
+ @staticmethod
+ def reshapeOp(ser, a, shape):
+ output_shape = shape.copy()
+
+ totalElements = 1
+ for i in a.shape:
+ totalElements *= i
+
+ # If there are any -1 elements, figure out what that dimension must be
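+ # e.g. a.shape = [2, 3, 4] (24 elements) reshaped to [-1, 6]: the -1 becomes 24 // 6 = 4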
+ totalOutputElements = 1
+ for i in output_shape:
+ if i != -1:
+ totalOutputElements *= i
+
+ # And fill it in
+ for i in range(len(output_shape)):
+ if output_shape[i] == -1:
+ output_shape[i] = totalElements // totalOutputElements
+
+ return ser.addOutput(output_shape, a.dtype, a.usage, a.dformat)
+
+ @staticmethod
+ def sliceOp(ser, a, begin, size):
+
+ output_shape = size.copy()
+ return ser.addOutput(output_shape, a.dtype, a.usage, a.dformat)
+
+ @staticmethod
+ def tileOp(ser, a, multiples):
+
+ output_shape = a.shape.copy()
+ assert(len(multiples) == len(output_shape))
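+ # e.g. a.shape = [2, 3] with multiples [2, 4] -> output shape [4, 12]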
+
+ for i in range(len(output_shape)):
+ output_shape[i] = a.shape[i] * multiples[i]
+
+ return ser.addOutput(output_shape, a.dtype, a.usage, a.dformat)
+
+ @staticmethod
+ def transposeOp(ser, a, perms):
+ output_shape = a.shape.copy()
+ assert(len(perms) == len(output_shape))
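+ # e.g. a.shape = [1, 2, 3] with perms [2, 0, 1] -> output shape [3, 1, 2]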
+
+ for i in range(len(output_shape)):
+ output_shape[i] = a.shape[perms[i]]
+
+ return ser.addOutput(output_shape, a.dtype, a.usage, a.dformat)
+
+ @staticmethod
+ def gatherOp(ser, values, indicies, axis):
+ # The output shape is the values shape with the axis dimension replaced by the number of indices
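+ # e.g. values.shape = [4, 5, 6], indicies.shape = [3], axis = 1 -> output shape [4, 3, 6]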
+ output_shape = [*values.shape[0:axis], indicies.shape[0], *values.shape[axis+1:]]
+
+ return ser.addOutput(output_shape, values.dtype, indicies.usage, indicies.dformat)
+
+ @staticmethod
+ def tableOp(ser, input, table):
+ # Same shape as the input; the output element type here is INT32
+ return ser.addOutput(input.shape, DType.INT32, input.usage, input.dformat)
+
+ @staticmethod
+ def resizeOp(ser, input, mode, stride, offset, shift, output_dims, output_dtype):
+
+ output_dims = [input.shape[0], output_dims[0], output_dims[1], input.shape[3]]
+
+ if stride[0] <= 0 or stride[1] <= 0:
+ ser.setExpectedFailure(True, 'Negative or zero stride')
+
+ return ser.addOutput(output_dims, output_dtype, input.usage, input.dformat)
+
+ @staticmethod
+ def typeConversionOp(ser, val, out_dtype):
+ return ser.addOutput(val.shape, out_dtype, val.usage, val.dformat)
+
+ @staticmethod
+ def transposeConv2DOp(ser, ifm, output_shape):
+ if ifm.dtype == DType.AINT8 or ifm.dtype == DType.INT8:
+ out_dtype = DType.INT32
+ elif ifm.dtype == DType.INT16:
+ out_dtype = DType.INT48
+ elif ifm.dtype == DType.FLOAT:
+ out_dtype = DType.FLOAT
+ else:
+ raise Exception('Unsupported input dtype: {}'.format(ifm.dtype))
+
+ if output_shape[1] <= 0 or output_shape[2] <= 0:
+ ser.setExpectedFailure(True, 'Negative output shape')
+
+ return ser.addOutput(output_shape, out_dtype, ifm.usage, ifm.dformat)
diff --git a/verif/tosa_test_runner.py b/verif/tosa_test_runner.py
new file mode 100644
index 0000000..6549192
--- /dev/null
+++ b/verif/tosa_test_runner.py
@@ -0,0 +1,63 @@
+import os
+
+# Copyright (c) 2020, ARM Limited.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import shlex
+import subprocess
+from enum import IntEnum, unique
+
+def run_sh_command(args, full_cmd, capture_output=False):
+ '''Utility function to run an external command. Optionally return captured stdout/stderr'''
+
+ # Quote the command line for printing
+ full_cmd_esc = [ shlex.quote(x) for x in full_cmd ]
+
+ if args.verbose:
+ print('### Running {}'.format(' '.join(full_cmd_esc)))
+
+ if capture_output:
+ rc = subprocess.run(full_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+ if rc.returncode != 0:
+ print(rc.stdout.decode('utf-8'))
+ print(rc.stderr.decode('utf-8'))
+ raise Exception('Error running command: {}.\n{}'.format(' '.join(full_cmd_esc), rc.stderr.decode('utf-8')))
+ return (rc.stdout, rc.stderr)
+ else:
+ rc = subprocess.run(full_cmd)
+ if rc.returncode != 0:
+ raise Exception('Error running command: {}'.format(' '.join(full_cmd_esc)))
+
+class TosaTestRunner:
+
+ def __init__(self, args, runnerArgs, testDir):
+
+ self.args = args
+ self.runnerArgs = runnerArgs
+ self.testDir = testDir
+
+ # Load the json test file
+ with open(os.path.join(testDir, 'desc.json'), 'r') as fd:
+ self.testDesc = json.load(fd)
+
+ def runModel(self):
+ pass
+
+ class Result(IntEnum):
+ EXPECTED_PASS = 0
+ EXPECTED_FAILURE = 1
+ UNEXPECTED_PASS = 2
+ UNEXPECTED_FAILURE = 3
+ INTERNAL_ERROR = 4
diff --git a/verif/tosa_verif_build_tests.py b/verif/tosa_verif_build_tests.py
new file mode 100755
index 0000000..19eb2f4
--- /dev/null
+++ b/verif/tosa_verif_build_tests.py
@@ -0,0 +1,136 @@
+#!/usr/bin/env python3
+
+# Copyright (c) 2020, ARM Limited.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import argparse
+import sys
+import re
+import os
+import subprocess
+import shlex
+import json
+import glob
+import math
+import queue
+import threading
+import traceback
+
+
+from enum import IntEnum, Enum, unique
+from datetime import datetime
+
+# Include the ../shared directory in PYTHONPATH
+parent_dir = os.path.dirname(os.path.realpath(__file__))
+sys.path.append(os.path.join(parent_dir, '..', 'scripts'))
+sys.path.append(os.path.join(parent_dir, '..', 'scripts', 'xunit'))
+import xunit
+from tosa_serializer import *
+from tosa_test_gen import TosaTestGen
+import tosa
+
+# Used for parsing a comma-separated list of integers in a string
+# to an actual list of integers
+def str_to_list(in_s):
+ '''Converts a comma-separated list of string integers to a python list of ints'''
+ lst = in_s.split(',')
+ out_list = []
+ for i in lst:
+ out_list.append(int(i))
+ return out_list
+
+def auto_int(x):
+ '''Converts hex/dec argument values to an int'''
+ return int(x, 0)
+
+def parseArgs():
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument('-o', dest='output_dir', type=str, default='vtest',
+ help='Test output directory')
+
+ parser.add_argument('--seed', dest='random_seed', default=42, type=int,
+ help='Random seed for test generation')
+
+ parser.add_argument('--filter', dest='filter', default='', type=str,
+ help='Filter operator test names by this expression')
+
+ parser.add_argument('-v', '--verbose', dest='verbose', action='count',
+ help='Verbose operation')
+
+ # Constraints on tests
+ parser.add_argument('--tensor-dim-range', dest='tensor_shape_range', default='1,64',
+ type=lambda x: str_to_list(x),
+ help='Min,Max range of tensor shapes')
+
+ parser.add_argument('--max-batch-size', dest='max_batch_size', default=1, type=int,
+ help='Maximum batch size for NHWC tests')
+
+ parser.add_argument('--max-conv-padding', dest='max_conv_padding', default=1, type=int,
+ help='Maximum padding for Conv tests')
+
+ parser.add_argument('--max-conv-dilation', dest='max_conv_dilation', default=2, type=int,
+ help='Maximum dilation for Conv tests')
+
+ parser.add_argument('--max-conv-stride', dest='max_conv_stride', default=2, type=int,
+ help='Maximum stride for Conv tests')
+
+ parser.add_argument('--max-pooling-padding', dest='max_pooling_padding', default=1, type=int,
+ help='Maximum padding for pooling tests')
+
+ parser.add_argument('--max-pooling-stride', dest='max_pooling_stride', default=2, type=int,
+ help='Maximum stride for pooling tests')
+
+ parser.add_argument('--max-pooling-kernel', dest='max_pooling_kernel', default=2, type=int,
+ help='Maximum kernel size for pooling tests')
+
+ parser.add_argument('--num-rand-permutations', dest='num_rand_permutations', default=6, type=int,
+ help='Number of random permutations for a given shape/rank for randomly-sampled parameter spaces')
+
+ # Targeting a specific shape/rank/dtype
+ parser.add_argument('--target-shape', dest='target_shapes', action='append', default=[], type=lambda x: str_to_list(x),
+ help='Create tests with a particular input tensor shape, e.g., 1,4,4,8 (may be repeated for tests that require multiple input shapes)')
+
+ parser.add_argument('--target-rank', dest='target_ranks', action='append', default=None, type=lambda x: auto_int(x),
+ help='Create tests with a particular input tensor rank')
+
+ parser.add_argument('--target-dtype', dest='target_dtypes', action='append', default=None, type=lambda x: dtype_str_to_val(x),
+ help='Create test with a particular DType (may be repeated)')
+
+ args = parser.parse_args()
+
+ return args
+
+def main():
+
+
+ args = parseArgs()
+
+ ttg = TosaTestGen(args)
+
+ testList = []
+ for op in ttg.TOSA_OP_LIST:
+ if re.match(args.filter + '.*', op):
+ testList.extend(ttg.genOpTestList(op, shapeFilter=args.target_shapes, rankFilter=args.target_ranks, dtypeFilter=args.target_dtypes))
+
+ print('{} matching tests'.format(len(testList)))
+ for opName, testStr, dtype, shapeList, testArgs in testList:
+ print(testStr)
+ ttg.serializeTest(opName, testStr, dtype, shapeList, testArgs)
+ print('Done creating {} tests'.format(len(testList)))
+
+
+if __name__ == '__main__':
+ exit(main())
diff --git a/verif/tosa_verif_run_ref.py b/verif/tosa_verif_run_ref.py
new file mode 100755
index 0000000..2284e35
--- /dev/null
+++ b/verif/tosa_verif_run_ref.py
@@ -0,0 +1,198 @@
+#!/usr/bin/env python3
+
+# Copyright (c) 2020, ARM Limited.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import argparse
+import sys
+import re
+import os
+import subprocess
+import shlex
+import json
+import glob
+import math
+import queue
+import threading
+import traceback
+import importlib
+
+
+from enum import IntEnum, Enum, unique
+from datetime import datetime
+
+# Include the ../shared directory in PYTHONPATH
+parent_dir = os.path.dirname(os.path.realpath(__file__))
+sys.path.append(os.path.join(parent_dir, '..', 'scripts'))
+sys.path.append(os.path.join(parent_dir, '..', 'scripts', 'xunit'))
+import xunit
+import tosa
+from tosa_test_gen import TosaTestGen
+from tosa_test_runner import TosaTestRunner
+
+no_color_printing = False
+#from run_tf_unit_test import LogColors, print_color, run_sh_command
+
+def parseArgs():
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument('-t', '--test', dest='test', type=str, nargs='+',
+ help='Test(s) to run')
+ parser.add_argument('--seed', dest='random_seed', default=42, type=int,
+ help='Random seed for test generation')
+ parser.add_argument('--ref-model-path', dest='ref_model_path',
+ default='build/reference_model/tosa_reference_model', type=str,
+ help='Path to reference model executable')
+ parser.add_argument('--ref-debug', dest='ref_debug', default='', type=str,
+ help='Reference debug flag (low, med, high)')
+ parser.add_argument('--ref-intermediates', dest='ref_intermediates', default=0, type=int,
+ help='Reference model dumps intermediate tensors')
+ parser.add_argument('-v', '--verbose', dest='verbose', action='count',
+ help='Verbose operation')
+ parser.add_argument('-j', '--jobs', dest='jobs', type=int, default=1,
+ help='Number of parallel jobs')
+ parser.add_argument('--sut-module', '-s', dest='sut_module', type=str, nargs='+', default=['tosa_ref_run'],
+ help='System under test module to load (derives from TosaTestRunner). May be repeated')
+ parser.add_argument('--sut-module-args', dest='sut_module_args', type=str, nargs='+', default=[],
+ help='System under test module arguments. Use sutmodulename:argvalue to pass an argument. May be repeated.')
+ parser.add_argument('--xunit-file', dest='xunit_file', type=str, default='result.xml',
+ help='XUnit output file')
+
+ args = parser.parse_args()
+
+ # Autodetect CPU count
+ if args.jobs <= 0:
+ args.jobs = os.cpu_count()
+
+ return args
+
+def workerThread(task_queue, runnerList, args, result_queue):
+ while True:
+ try:
+ test = task_queue.get(block=False)
+ except queue.Empty:
+ break
+
+ if test is None:
+ break
+
+ msg = ''
+ start_time = datetime.now()
+ try:
+
+ for runnerModule, runnerArgs in runnerList:
+ if args.verbose:
+ print('Running runner {} with test {}'.format(runnerModule.__name__, test))
+ runner = runnerModule.TosaRefRunner(args, runnerArgs, test)
+ try:
+ rc = runner.runModel()
+ except Exception as e:
+ rc = TosaTestRunner.Result.INTERNAL_ERROR
+ except Exception as e:
+ print('Internal regression error: {}'.format(e))
+ print(''.join(traceback.format_exception(etype=type(e), value=e, tb=e.__traceback__)))
+ rc = TosaTestRunner.Result.INTERNAL_ERROR
+
+ end_time = datetime.now()
+
+ result_queue.put((test, rc, msg, end_time - start_time))
+ task_queue.task_done()
+
+ return True
+
+def loadRefModules(args):
+ # Returns a tuple of (runner_module, [argument list])
+ runnerList = []
+ for r in args.sut_module:
+ if args.verbose:
+ print('Loading module {}'.format(r))
+
+ runner = importlib.import_module(r)
+
+ # Look for arguments associated with this runner
+ runnerArgPrefix = '{}:'.format(r)
+ runnerArgList = []
+ for a in args.sut_module_args:
+ if a.startswith(runnerArgPrefix):
+ runnerArgList.append(a[len(runnerArgPrefix):])
+ runnerList.append((runner, runnerArgList))
+
+ return runnerList
+
+def main():
+ args = parseArgs()
+
+ runnerList = loadRefModules(args)
+
+ threads = []
+ taskQueue = queue.Queue()
+ resultQueue = queue.Queue()
+
+ for t in args.test:
+ taskQueue.put((t))
+
+ print('Running {} tests '.format(taskQueue.qsize()))
+
+ for i in range(args.jobs):
+ t = threading.Thread(target=workerThread, args=(taskQueue, runnerList, args, resultQueue))
+ t.setDaemon(True)
+ t.start()
+ threads.append(t)
+
+ taskQueue.join()
+
+ resultList = []
+ results = [0] * len(TosaTestRunner.Result)
+
+ while True:
+ try:
+ test, rc, msg, time_delta = resultQueue.get(block=False)
+ except queue.Empty:
+ break
+
+ resultList.append((test, rc, msg, time_delta))
+ results[rc] = results[rc] + 1
+
+ xunit_result = xunit.xunit_results('Regressions')
+ xunit_suite = xunit_result.create_suite('Unit tests')
+
+ # Sort by test name
+ for test, rc, msg, time_delta in sorted(resultList, key=lambda tup: tup[0]):
+ test_name = test
+ xt = xunit.xunit_test(test_name, 'reference')
+
+ xt.time = str(float(time_delta.seconds) + (float(time_delta.microseconds) * 1e-6))
+
+ if rc == TosaTestRunner.Result.EXPECTED_PASS or rc == TosaTestRunner.Result.EXPECTED_FAILURE:
+ if args.verbose:
+ print('{} {}'.format(rc.name, test_name))
+ else:
+ xt.failed(msg)
+ print('{} {}'.format(rc.name, test_name))
+
+ xunit_suite.tests.append(xt)
+ resultQueue.task_done()
+
+ xunit_result.write_results(args.xunit_file)
+
+ print('Totals: ', end='')
+ for result in TosaTestRunner.Result:
+ print('{} {}, '.format(results[result], result.name.lower()), end ='')
+ print()
+
+ return 0
+
+if __name__ == '__main__':
+ exit(main())