aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorKevin Cheng <kevin.cheng@arm.com>2021-11-15 17:07:37 -0800
committerKevin Cheng <kevin.cheng@arm.com>2021-11-15 17:07:55 -0800
commit80a022fd103b26a03a04e0565c4d263f73d950b8 (patch)
tree6fd26a5210cf2fa6650610077ac530680e8c4717
parente351a65ce85511dea24056554722d661dc7fee42 (diff)
downloadtosa_mlir_translator-80a022fd103b26a03a04e0565c4d263f73d950b8.tar.gz
First commit of tosa serialize passes
Signed-off-by: Kevin Cheng <kevin.cheng@arm.com> Change-Id: I1551017706f6e8af604792f48cdeb49b4da7ef0d
-rw-r--r--.clang-format1
-rw-r--r--.gitmodules3
-rw-r--r--CMakeLists.txt33
-rw-r--r--LICENSE.txt219
-rw-r--r--README.md49
-rw-r--r--include/SerializationPasses.h34
-rw-r--r--include/SerializationPasses.td21
-rw-r--r--include/operator.def117
-rw-r--r--src/TosaSerialize.cpp1764
m---------third_party/serialization_lib0
10 files changed, 2241 insertions, 0 deletions
diff --git a/.clang-format b/.clang-format
new file mode 100644
index 0000000..9b3aa8b
--- /dev/null
+++ b/.clang-format
@@ -0,0 +1 @@
+BasedOnStyle: LLVM
diff --git a/.gitmodules b/.gitmodules
new file mode 100644
index 0000000..5d7eb1e
--- /dev/null
+++ b/.gitmodules
@@ -0,0 +1,3 @@
+[submodule "third_party/serialization_lib"]
+ path = third_party/serialization_lib
+ url = https://review.mlplatform.org/tosa/serialization_lib
diff --git a/CMakeLists.txt b/CMakeLists.txt
new file mode 100644
index 0000000..968d73b
--- /dev/null
+++ b/CMakeLists.txt
@@ -0,0 +1,33 @@
+# TOSA serialization MLIR passes
+
+cmake_minimum_required(VERSION 3.13.4)
+project(MlirTosaPasses)
+
+set(CMAKE_CXX_STANDARD 14 CACHE STRING "C++ standard to conform to")
+set(CMAKE_CXX_STANDARD_REQUIRED YES)
+
+set(CMAKE_VERBOSE_MAKEFILE ON)
+
+# TOSA MLIR->Flatbuffers serialization pass
+
+include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include)
+include_directories(${PROJECT_SOURCE_DIR}/third_party/serialization_lib/include)
+include_directories(${PROJECT_SOURCE_DIR}/third_party/serialization_lib/third_party/flatbuffers/include)
+
+set(LLVM_TARGET_DEFINITIONS include/SerializationPasses.td)
+mlir_tablegen(include/SerializationPasses.h.inc -gen-pass-decls -name TosaSerialization)
+add_public_tablegen_target(tosa_serialization_passes_inc_gen)
+
+# Compile the TOSA serialization_lib
+add_subdirectory(third_party/serialization_lib)
+
+add_mlir_library(tosa_serialize
+ src/TosaSerialize.cpp
+
+ DEPENDS
+ mlir-headers
+ tosa_serialization_passes_inc_gen
+
+ LINK_LIBS PRIVATE
+ tosa_serialization_lib
+)
diff --git a/LICENSE.txt b/LICENSE.txt
new file mode 100644
index 0000000..45d76b4
--- /dev/null
+++ b/LICENSE.txt
@@ -0,0 +1,219 @@
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+
+
+---- LLVM Exceptions to the Apache 2.0 License ----
+
+As an exception, if, as a result of your compiling your source code, portions
+of this Software are embedded into an Object form of such source code, you
+may redistribute such embedded portions in such Object form without complying
+with the conditions of Sections 4(a), 4(b) and 4(d) of the License.
+
+In addition, if you combine or link compiled forms of this Software with
+software that is licensed under the GPLv2 ("Combined Software") and if a
+court of competent jurisdiction determines that the patent provision (Section
+3), the indemnity provision (Section 9) or other Section of the License
+conflicts with the conditions of the GPLv2, you may retroactively and
+prospectively choose to deem waived or otherwise exclude such Section(s) of
+the License, but only in their entirety and only with respect to the Combined
+Software.
+
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..46039b7
--- /dev/null
+++ b/README.md
@@ -0,0 +1,49 @@
+TOSA MLIR Translator
+==========================
+
+# Introduction
+
+The *TOSA MLIR Translator* repository implements translators between the TOSA MLIR
+dialect and serialized representations.
+
+The current implementation supports serialization from MLIR form to flatbuffers.
+A deserializer from flatbuffers to MLIR form is in development.
+
+# Dependencies
+
+## TOSA serialization library
+<https://review.mlplatform.org/plugins/gitiles/tosa/serialization_lib>
+The library includes a FlatBuffers schema and a C++ API for reading and writing a TOSA
+graph as a flatbuffer.
+
+# Compiling
+
+This repository does not currently build standalone. It must be included within another
+MLIR repository with a pass manager registering the passes implemented within this
+repository.
+
+The included CMake rules can be used to add this repository as a submodule.
+Include include/SerializationPasses.h to register these passes with your MLIR pass manager.
+
+If the "tosa_serialize" target is linked correctly, you should be able to see the "--tosa-serialize"
+and "--tosa-serialize-json" options available in your MLIR pass manager/MLIR optimizer.
+
+# Usage
+
+To serialize a TOSA MLIR graph to TOSA flatbuffer binary file:
+
+ \<YOUR_MLIR_OPTIMIZER\> --tosa-serialize \<TOSA_MLIR_GRAPH\> \
+ --tosa-flatbuffer-filename \<TOSA_FLATBUFFER_FILENAME\>
+
+To serialize a TOSA MLIR graph to TOSA flatbuffer JSON file:
+
+    \<YOUR_MLIR_OPTIMIZER\> --tosa-serialize-json \<TOSA_MLIR_GRAPH\> \
+        --tosa-flatbuffer-schema \<PATH_TO_TOSA_FLATBUFFER_SCHEMA\> \
+        --tosa-flatbuffer-filename \<TOSA_FLATBUFFER_FILENAME\>
+
+ where \<PATH_TO_TOSA_FLATBUFFER_SCHEMA\> is provided within the serialization library
+ submodule in third_party/serialization_lib/schema/tosa.fbs
+
+# License
+
+The *TOSA MLIR Translator* is licensed under Apache-2.0 with LLVM Exceptions.
diff --git a/include/SerializationPasses.h b/include/SerializationPasses.h
new file mode 100644
index 0000000..0991f87
--- /dev/null
+++ b/include/SerializationPasses.h
@@ -0,0 +1,34 @@
+
+// Copyright (c) 2020-2021, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 with LLVM Exceptions
+// (the "License"); you may not use this file except in compliance with
+// the License. You may obtain a copy of the License at
+//
+// https://llvm.org/LICENSE.txt
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef INCLUDE_SERIALIZATION_PASSES_H
+#define INCLUDE_SERIALIZATION_PASSES_H
+
+#include <memory>
+
+#include "mlir/Pass/Pass.h" // from @llvm-project
+
+namespace mlir {
+namespace tosa {
+
+std::unique_ptr<OperationPass<FuncOp>> createTosaSerializePass();
+
+#define GEN_PASS_REGISTRATION
+#include "SerializationPasses.h.inc"
+
+} // namespace tosa
+} // namespace mlir
+
+#endif // INCLUDE_SERIALIZATION_PASSES_H
diff --git a/include/SerializationPasses.td b/include/SerializationPasses.td
new file mode 100644
index 0000000..6df996e
--- /dev/null
+++ b/include/SerializationPasses.td
@@ -0,0 +1,21 @@
+// Copyright (c) 2020-2021, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 with LLVM Exceptions
+// (the "License"); you may not use this file except in compliance with
+// the License. You may obtain a copy of the License at
+//
+// https://llvm.org/LICENSE.txt
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+include "mlir/Pass/PassBase.td"
+
+def TosaSerializationPass : Pass<"tosa-serialize", "FuncOp"> {
+ let summary = "Generate TOSA flatbuffer serialized form";
+ let constructor = "createTosaSerializePass()";
+}
+
diff --git a/include/operator.def b/include/operator.def
new file mode 100644
index 0000000..85bb5c9
--- /dev/null
+++ b/include/operator.def
@@ -0,0 +1,117 @@
+
+// Copyright (c) 2020-2021, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 with LLVM Exceptions
+// (the "License"); you may not use this file except in compliance with
+// the License. You may obtain a copy of the License at
+//
+// https://llvm.org/LICENSE.txt
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+/*
+ Syntax:
+ DEF_OPERATOR(MLIR_OP)
+
+ Description:
+ MLIR_OP: the operator class type, must match mlir/include/mlir/Dialect/Tosa/IR/tosa_ops.td in llvm-project repo
+*/
+
+/* tensor operators */
+DEF_OPERATOR(ArgMax)
+DEF_OPERATOR(AvgPool2d)
+DEF_OPERATOR(Conv2D)
+DEF_OPERATOR(Conv3D)
+DEF_OPERATOR(DepthwiseConv2D)
+DEF_OPERATOR(FullyConnected)
+DEF_OPERATOR(MatMul)
+DEF_OPERATOR(MaxPool2d)
+DEF_OPERATOR(TransposeConv2D)
+
+/* activation */
+DEF_OPERATOR(Clamp)
+DEF_OPERATOR(Sigmoid)
+DEF_OPERATOR(Tanh)
+
+/* elementwise - binary */
+DEF_OPERATOR(Add)
+DEF_OPERATOR(ArithmeticRightShift)
+DEF_OPERATOR(BitwiseAnd)
+DEF_OPERATOR(BitwiseOr)
+DEF_OPERATOR(BitwiseXor)
+DEF_OPERATOR(Div)
+DEF_OPERATOR(LogicalAnd)
+DEF_OPERATOR(LogicalLeftShift)
+DEF_OPERATOR(LogicalRightShift)
+DEF_OPERATOR(LogicalOr)
+DEF_OPERATOR(LogicalXor)
+DEF_OPERATOR(Maximum)
+DEF_OPERATOR(Minimum)
+DEF_OPERATOR(Mul)
+DEF_OPERATOR(Pow)
+DEF_OPERATOR(Sub)
+DEF_OPERATOR(Table)
+
+/* elementwise - unary */
+DEF_OPERATOR(Abs)
+DEF_OPERATOR(BitwiseNot)
+DEF_OPERATOR(Ceil)
+DEF_OPERATOR(Clz)
+DEF_OPERATOR(Exp)
+DEF_OPERATOR(Floor)
+DEF_OPERATOR(Log)
+DEF_OPERATOR(LogicalNot)
+DEF_OPERATOR(Negate)
+DEF_OPERATOR(Reciprocal)
+DEF_OPERATOR(Rsqrt)
+
+/* elementwise - ternary */
+DEF_OPERATOR(Select)
+
+/* logical */
+DEF_OPERATOR(Equal)
+DEF_OPERATOR(Greater)
+DEF_OPERATOR(GreaterEqual)
+
+/* reduction */
+DEF_OPERATOR(ReduceAny)
+DEF_OPERATOR(ReduceAll)
+DEF_OPERATOR(ReduceMax)
+DEF_OPERATOR(ReduceMin)
+DEF_OPERATOR(ReduceProd)
+DEF_OPERATOR(ReduceSum)
+
+/* memory operation */
+DEF_OPERATOR(Concat)
+DEF_OPERATOR(Pad)
+DEF_OPERATOR(Reshape)
+DEF_OPERATOR(Reverse)
+DEF_OPERATOR(Slice)
+DEF_OPERATOR(Tile)
+DEF_OPERATOR(Transpose)
+
+/* gather/scatter */
+DEF_OPERATOR(Gather)
+DEF_OPERATOR(Scatter)
+
+/* image */
+DEF_OPERATOR(Resize)
+
+/* quantization */
+DEF_OPERATOR(Cast)
+DEF_OPERATOR(Rescale)
+
+/* data nodes */
+DEF_OPERATOR(Const)
+DEF_OPERATOR(Identity)
+
+/* custom operations */
+DEF_OPERATOR(Custom)
+
+/* control flow operators */
+DEF_OPERATOR(If)
+DEF_OPERATOR(While)
diff --git a/src/TosaSerialize.cpp b/src/TosaSerialize.cpp
new file mode 100644
index 0000000..5699ffe
--- /dev/null
+++ b/src/TosaSerialize.cpp
@@ -0,0 +1,1764 @@
+
+// Copyright (c) 2020-2021, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 with LLVM Exceptions
+// (the "License"); you may not use this file except in compliance with
+// the License. You may obtain a copy of the License at
+//
+// https://llvm.org/LICENSE.txt
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// TOSA flatbuffer generation
+
+#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project
+#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
+#include "mlir/Dialect/Tosa/IR/TosaOps.h" // from @llvm-project
+#include "mlir/IR/Matchers.h" // from @llvm-project
+#include "mlir/Pass/Pass.h" // from @llvm-project
+#include "mlir/Support/LogicalResult.h" // from @llvm-project
+#include "tosa_serialization_handler.h"
+#include <functional>
+#include <unordered_map>
+
+// The namespace might be confusing here. We have mlir::tosa:: defined in MLIR
+// and tosa:: defined in serialization library
+// TODO: align the namespace
+using namespace tosa;
+
+namespace cl = llvm::cl;
+
+llvm::cl::opt<std::string> tosa_flatbuffer_filename(
+ "tosa-flatbuffer-filename", llvm::cl::desc("<tosa flatbuffer filename>"),
+ llvm::cl::init("tosa_dump.tosa"), llvm::cl::value_desc("filename"));
+
+llvm::cl::opt<std::string> tosa_flatbuffer_schema(
+ "tosa-flatbuffer-schema", llvm::cl::desc("<tosa flatbuffer schema file>"),
+ llvm::cl::init(""), llvm::cl::value_desc("filename"));
+
+// Specialize mlir::Value for std::hash<T> and std::equal_to<T> to be able to
+// build std::unordered_map<mlir::Value, ...>
+namespace std {
+
+template <> struct hash<mlir::Value> {
+ std::size_t operator()(const mlir::Value &val) const {
+ return static_cast<std::size_t>(mlir::hash_value(val));
+ }
+};
+
+template <> struct equal_to<mlir::Value> {
+ bool operator()(const mlir::Value &lhs, const mlir::Value &rhs) const {
+ return (lhs == rhs);
+ }
+};
+
+} // namespace std
+
+ResizeMode ResizeModeStr2Enum(const std::string &mode_str) {
+ if (mode_str == "NEAREST_NEIGHBOR")
+ return ResizeMode_NEAREST;
+ else if (mode_str == "BILINEAR")
+ return ResizeMode_BILINEAR;
+ else
+ return ResizeMode_UNKNOWN;
+}
+
+DType Type2DType(mlir::Type element_type) {
+ if (element_type.isF64() || element_type.isF32() || element_type.isF16() ||
+ element_type.isBF16()) {
+ return DType_FLOAT;
+ } else if (element_type.isUnsignedInteger(8)) {
+ return DType_UINT8;
+ } else if (element_type.isInteger(4)) {
+ return DType_INT8;
+ } else if (element_type.isInteger(8)) {
+ return DType_INT8;
+ } else if (element_type.isInteger(16)) {
+ return DType_INT16;
+ } else if (element_type.isInteger(32)) {
+ return DType_INT32;
+ } else if (element_type.isInteger(48)) {
+ return DType_INT48;
+ }
+ // boolean in MLIR treated as integer with bitwidth 1
+ else if (element_type.isInteger(1)) {
+ return DType_BOOL;
+ }
+ return DType_UNKNOWN;
+}
+
+int GetQuantizedParameter(mlir::Type type, std::vector<float> &scale,
+ std::vector<int32_t> &zeropoint,
+ int32_t &quantized_dimension, int64_t &quant_min,
+ int64_t &quant_max) {
+ if (auto qtype = type.dyn_cast<mlir::quant::UniformQuantizedType>()) {
+ scale.push_back(qtype.getScale());
+ zeropoint.push_back(qtype.getZeroPoint());
+ quantized_dimension = 0;
+
+ quant_min = qtype.getStorageTypeMin();
+ quant_max = qtype.getStorageTypeMax();
+ } else if (auto qtype =
+ type.dyn_cast<mlir::quant::UniformQuantizedPerAxisType>()) {
+ scale.assign(qtype.getScales().begin(), qtype.getScales().end());
+ zeropoint.assign(qtype.getZeroPoints().begin(),
+ qtype.getZeroPoints().end());
+ quantized_dimension = qtype.getQuantizedDimension();
+
+ quant_min = qtype.getStorageTypeMin();
+ quant_max = qtype.getStorageTypeMax();
+ } else {
+ return 1;
+ }
+
+ return 0;
+}
+
+TosaQuantInfoBase *
+GetUnaryQuantInfo(mlir::tosa::UnaryOpQuantizationAttr quant_info) {
+ int32_t input_zp = quant_info.input_zp().getInt();
+ int32_t output_zp = quant_info.output_zp().getInt();
+
+ TosaQuantInfoBase *qinfo = new TosaUnaryQuantInfo(input_zp, output_zp);
+
+ return qinfo;
+}
+
+TosaQuantInfoBase *
+GetConvQuantInfo(mlir::tosa::ConvOpQuantizationAttr quant_info) {
+ int32_t input_zp = quant_info.input_zp().getInt();
+ int32_t weight_zp = quant_info.weight_zp().getInt();
+
+ TosaQuantInfoBase *qinfo = new TosaConvQuantInfo(input_zp, weight_zp);
+
+ return qinfo;
+}
+
+TosaQuantInfoBase *
+GetPadQuantInfo(mlir::tosa::PadOpQuantizationAttr quant_info) {
+ int32_t input_zp = quant_info.input_zp().getInt();
+
+ TosaQuantInfoBase *qinfo = new TosaPadQuantInfo(input_zp);
+
+ return qinfo;
+}
+
+TosaQuantInfoBase *
+GetMatMulQuantInfo(mlir::tosa::MatMulOpQuantizationAttr quant_info) {
+ int32_t a_zp = quant_info.a_zp().getInt();
+ int32_t b_zp = quant_info.b_zp().getInt();
+
+ TosaQuantInfoBase *qinfo = new TosaMatMulQuantInfo(a_zp, b_zp);
+
+ return qinfo;
+}
+
+class TosaSerializationBlockBuilder;
+
+class TosaSerializationOperatorBuilder {
+public:
+ TosaSerializationOperatorBuilder(
+ TosaSerializationBlockBuilder *_block_builder)
+ : block_builder(_block_builder) {}
+ template <typename T>
+ TosaSerializationOperator *build(mlir::Operation &op) const;
+
+private:
+ std::string GetTensorName(mlir::Value val) const;
+ TosaSerializationOperator *BuildPoolOpFromMlirOp(mlir::Operation &op,
+ Op opcode) const;
+ TosaSerializationOperator *BuildEwiseBinaryOpFromMlirOp(mlir::Operation &op,
+ Op opcode) const;
+ TosaSerializationOperator *BuildEwiseUnaryOpFromMlirOp(mlir::Operation &op,
+ Op opcode) const;
+ TosaSerializationOperator *BuildReductionOpFromMlirOp(mlir::Operation &op,
+ Op opcode) const;
+ TosaSerializationBlockBuilder *block_builder;
+};
+
+// This builder assumes each region only has only one block
+class TosaSerializationBlockBuilder {
+public:
+ friend class TosaSerializationOperatorBuilder;
+ TosaSerializationBlockBuilder(TosaSerializationBasicBlock *_block,
+ TosaSerializationHandler *_tsh,
+ mlir::Region *_region)
+ : block(_block), tsh(_tsh), region(_region) {}
+
+ mlir::LogicalResult
+ BuildAllOpsInRegion(std::vector<mlir::Value> &return_values);
+ TosaSerializationBasicBlock *GetBlock() { return block; }
+ TosaSerializationHandler *GetTsh() { return tsh; }
+
+private:
+ TosaSerializationOperator *BuildTosaSerializationOperator(
+ const TosaSerializationOperatorBuilder &op_builder, mlir::Operation &op);
+ TosaSerializationTensor *
+ BuildTosaSerializationTensor(mlir::Value val, const std::string &name);
+
+ TosaSerializationBasicBlock *block;
+ TosaSerializationHandler *tsh;
+ mlir::Region *region;
+ std::unordered_map<mlir::Value, std::string> tensor_map;
+};
+
+std::string
+TosaSerializationOperatorBuilder::GetTensorName(mlir::Value val) const {
+ if (block_builder->tensor_map.find(val) == block_builder->tensor_map.end()) {
+ llvm::errs() << "ERROR: Failed to get mlir::Value from tensor_map";
+ assert(0);
+ }
+ return block_builder->tensor_map[val];
+}
+
+// Main template to catch unimplemented translation.
+template <typename T>
+TosaSerializationOperator *
+TosaSerializationOperatorBuilder::build(mlir::Operation &op) const {
+ llvm::errs() << "Translation of operator " << op.getName().getStringRef()
+ << " is not implemented yet\n";
+ return nullptr;
+}
+
+/* Start translating TOSA operator */
+
+#define ASSERT_VECTOR_LENGTH(VECTOR, LENGTH) \
+ if (VECTOR.size() != LENGTH) { \
+ std::string msg; \
+ msg = std::string(#VECTOR) + " is [" + std::to_string(VECTOR.size()) + \
+ "], but expected to be [" + std::to_string(LENGTH) + "]\n"; \
+ op.emitOpError(msg.c_str()); \
+ return nullptr; \
+ }
+
+TosaSerializationOperator *
+TosaSerializationOperatorBuilder::BuildPoolOpFromMlirOp(mlir::Operation &op,
+ Op opcode) const {
+ std::vector<int> pad, stride, kernel;
+
+ auto pad_attr = op.getAttr("pad").dyn_cast<mlir::ArrayAttr>().getValue();
+ for (auto &int_attr : pad_attr) {
+ pad.push_back(int_attr.dyn_cast<mlir::IntegerAttr>().getInt());
+ }
+ ASSERT_VECTOR_LENGTH(pad, 4);
+
+ auto stride_attr =
+ op.getAttr("stride").dyn_cast<mlir::ArrayAttr>().getValue();
+ for (auto &int_attr : stride_attr) {
+ stride.push_back(int_attr.dyn_cast<mlir::IntegerAttr>().getInt());
+ }
+ ASSERT_VECTOR_LENGTH(stride, 2);
+
+ auto kernel_attr =
+ op.getAttr("kernel").dyn_cast<mlir::ArrayAttr>().getValue();
+ for (auto &int_attr : kernel_attr) {
+ kernel.push_back(int_attr.dyn_cast<mlir::IntegerAttr>().getInt());
+ }
+ ASSERT_VECTOR_LENGTH(kernel, 2);
+
+ std::string input_name = GetTensorName(op.getOperand(0));
+ std::string output_name = GetTensorName(op.getResult(0));
+
+ TosaPoolAttribute attribute(pad, kernel, stride);
+ auto quant_info = op.getAttrOfType<mlir::tosa::UnaryOpQuantizationAttr>(
+ "quantization_info");
+ QuantInfo qinfo_type;
+ TosaQuantInfoBase *qinfo;
+ if (quant_info) {
+ qinfo_type = QuantInfo_UnaryQuantInfo;
+ qinfo = GetUnaryQuantInfo(quant_info);
+ } else {
+ qinfo_type = QuantInfo_NONE;
+ qinfo = new TosaNoneQuantInfo();
+ }
+
+ TosaSerializationOperator *tyop = new TosaSerializationOperator(
+ opcode, Attribute_PoolAttribute, &attribute, qinfo_type, qinfo,
+ std::vector<std::string>{input_name},
+ std::vector<std::string>{output_name});
+
+ delete qinfo;
+
+ return tyop;
+}
+
+TosaSerializationOperator *
+TosaSerializationOperatorBuilder::BuildEwiseBinaryOpFromMlirOp(
+ mlir::Operation &op, Op opcode) const {
+ std::string input0_name = GetTensorName(op.getOperand(0));
+ std::string input1_name = GetTensorName(op.getOperand(1));
+ std::string output_name = GetTensorName(op.getResult(0));
+
+ TosaSerializationOperator *tyop = new TosaSerializationOperator(
+ opcode, Attribute_NONE, nullptr, QuantInfo_NONE, nullptr,
+ std::vector<std::string>{input0_name, input1_name},
+ std::vector<std::string>{output_name});
+
+ return tyop;
+}
+
+TosaSerializationOperator *
+TosaSerializationOperatorBuilder::BuildEwiseUnaryOpFromMlirOp(
+ mlir::Operation &op, Op opcode) const {
+ std::string input_name = GetTensorName(op.getOperand(0));
+ std::string output_name = GetTensorName(op.getResult(0));
+
+ TosaSerializationOperator *tyop = new TosaSerializationOperator(
+ opcode, Attribute_NONE, nullptr, QuantInfo_NONE, nullptr,
+ std::vector<std::string>{input_name},
+ std::vector<std::string>{output_name});
+
+ return tyop;
+}
+
+// Shared builder for reduction ops: all TOSA reductions carry a single
+// "axis" integer attribute and take one input / produce one output.
+TosaSerializationOperator *
+TosaSerializationOperatorBuilder::BuildReductionOpFromMlirOp(
+    mlir::Operation &op, Op opcode) const {
+  const int32_t axis =
+      op.getAttr("axis").dyn_cast<mlir::IntegerAttr>().getInt();
+  TosaAxisAttribute axis_attribute(axis);
+
+  return new TosaSerializationOperator(
+      opcode, Attribute_AxisAttribute, &axis_attribute, QuantInfo_NONE,
+      nullptr, std::vector<std::string>{GetTensorName(op.getOperand(0))},
+      std::vector<std::string>{GetTensorName(op.getResult(0))});
+}
+
+// Boilerplate generators: each macro expands to a full specialization of
+// the templated build<>() entry point for one mlir::tosa op class, which
+// simply forwards to the shared category builder above with the matching
+// serialization-schema opcode.
+
+// 2D pooling ops: kernel/stride/pad attributes + optional unary qinfo.
+#define BUILD_OP_POOL2D(MLIR_OP_NAME, SCHEMA_OP_NAME)                          \
+  template <>                                                                  \
+  TosaSerializationOperator *                                                  \
+  TosaSerializationOperatorBuilder::build<mlir::tosa::MLIR_OP_NAME##Op>(       \
+      mlir::Operation & op) const {                                            \
+    return BuildPoolOpFromMlirOp(op, Op_##SCHEMA_OP_NAME);                     \
+  }
+
+// Elementwise binary ops: two inputs, no attribute, no qinfo.
+#define BUILD_OP_ELEMENTWISE_BINARY(MLIR_OP_NAME, SCHEMA_OP_NAME)              \
+  template <>                                                                  \
+  TosaSerializationOperator *                                                  \
+  TosaSerializationOperatorBuilder::build<mlir::tosa::MLIR_OP_NAME##Op>(       \
+      mlir::Operation & op) const {                                            \
+    return BuildEwiseBinaryOpFromMlirOp(op, Op_##SCHEMA_OP_NAME);              \
+  }
+
+// Elementwise unary ops: one input, no attribute, no qinfo.
+#define BUILD_OP_ELEMENTWISE_UNARY(MLIR_OP_NAME, SCHEMA_OP_NAME)               \
+  template <>                                                                  \
+  TosaSerializationOperator *                                                  \
+  TosaSerializationOperatorBuilder::build<mlir::tosa::MLIR_OP_NAME##Op>(       \
+      mlir::Operation & op) const {                                            \
+    return BuildEwiseUnaryOpFromMlirOp(op, Op_##SCHEMA_OP_NAME);               \
+  }
+
+// Reduction ops: one input plus an "axis" attribute.
+#define BUILD_OP_REDUCTION(MLIR_OP_NAME, SCHEMA_OP_NAME)                       \
+  template <>                                                                  \
+  TosaSerializationOperator *                                                  \
+  TosaSerializationOperatorBuilder::build<mlir::tosa::MLIR_OP_NAME##Op>(       \
+      mlir::Operation & op) const {                                            \
+    return BuildReductionOpFromMlirOp(op, Op_##SCHEMA_OP_NAME);                \
+  }
+
+// Instantiate the build<>() specializations for every op that maps 1:1
+// onto a shared category builder.  The second macro argument is the
+// serialization-schema opcode the MLIR op is emitted as.
+BUILD_OP_POOL2D(MaxPool2d, MAX_POOL2D)
+BUILD_OP_POOL2D(AvgPool2d, AVG_POOL2D)
+
+BUILD_OP_ELEMENTWISE_BINARY(Add, ADD)
+BUILD_OP_ELEMENTWISE_BINARY(BitwiseAnd, BITWISE_AND)
+BUILD_OP_ELEMENTWISE_BINARY(BitwiseXor, BITWISE_XOR)
+BUILD_OP_ELEMENTWISE_BINARY(BitwiseOr, BITWISE_OR)
+// Note: tosa.div is serialized as the INTDIV schema opcode.
+BUILD_OP_ELEMENTWISE_BINARY(Div, INTDIV)
+BUILD_OP_ELEMENTWISE_BINARY(LogicalAnd, LOGICAL_AND)
+BUILD_OP_ELEMENTWISE_BINARY(LogicalLeftShift, LOGICAL_LEFT_SHIFT)
+BUILD_OP_ELEMENTWISE_BINARY(LogicalRightShift, LOGICAL_RIGHT_SHIFT)
+BUILD_OP_ELEMENTWISE_BINARY(LogicalOr, LOGICAL_OR)
+BUILD_OP_ELEMENTWISE_BINARY(LogicalXor, LOGICAL_XOR)
+BUILD_OP_ELEMENTWISE_BINARY(Maximum, MAXIMUM)
+BUILD_OP_ELEMENTWISE_BINARY(Minimum, MINIMUM)
+BUILD_OP_ELEMENTWISE_BINARY(Pow, POW)
+BUILD_OP_ELEMENTWISE_BINARY(Sub, SUB)
+
+BUILD_OP_ELEMENTWISE_UNARY(Abs, ABS)
+BUILD_OP_ELEMENTWISE_UNARY(BitwiseNot, BITWISE_NOT)
+BUILD_OP_ELEMENTWISE_UNARY(Ceil, CEIL)
+BUILD_OP_ELEMENTWISE_UNARY(Clz, CLZ)
+BUILD_OP_ELEMENTWISE_UNARY(Exp, EXP)
+BUILD_OP_ELEMENTWISE_UNARY(Floor, FLOOR)
+BUILD_OP_ELEMENTWISE_UNARY(Log, LOG)
+BUILD_OP_ELEMENTWISE_UNARY(LogicalNot, LOGICAL_NOT)
+BUILD_OP_ELEMENTWISE_UNARY(Reciprocal, RECIPROCAL)
+BUILD_OP_ELEMENTWISE_UNARY(Rsqrt, RSQRT)
+
+BUILD_OP_REDUCTION(ReduceAny, REDUCE_ANY)
+BUILD_OP_REDUCTION(ReduceAll, REDUCE_ALL)
+BUILD_OP_REDUCTION(ReduceMax, REDUCE_MAX)
+BUILD_OP_REDUCTION(ReduceMin, REDUCE_MIN)
+BUILD_OP_REDUCTION(ReduceProd, REDUCE_PRODUCT)
+BUILD_OP_REDUCTION(ReduceSum, REDUCE_SUM)
+
+// Comparison ops serialize as plain elementwise binary operators.
+BUILD_OP_ELEMENTWISE_BINARY(Equal, EQUAL)
+BUILD_OP_ELEMENTWISE_BINARY(Greater, GREATER)
+BUILD_OP_ELEMENTWISE_BINARY(GreaterEqual, GREATER_EQUAL)
+
+// Attribute-free activation / data ops serialize as elementwise unary.
+BUILD_OP_ELEMENTWISE_UNARY(Sigmoid, SIGMOID)
+BUILD_OP_ELEMENTWISE_UNARY(Tanh, TANH)
+BUILD_OP_ELEMENTWISE_UNARY(Identity, IDENTITY)
+BUILD_OP_ELEMENTWISE_UNARY(Cast, CAST)
+
+// Serialize tosa.const into a CONST operator and copy the constant's
+// "value" attribute into the previously-registered serialization tensor.
+//
+// The result tensor must already exist in the current block (tensors are
+// built before operators); its dtype selects the conversion from the MLIR
+// attribute (dense elements, or a scalar splat) to the flatbuffer uint8
+// payload.  Emits an op error and returns nullptr on a missing tensor,
+// an unrecognized attribute kind, or an unsupported dtype.
+template <>
+TosaSerializationOperator *
+TosaSerializationOperatorBuilder::build<mlir::tosa::ConstOp>(
+    mlir::Operation &op) const {
+  std::string output_name = GetTensorName(op.getResult(0));
+  TosaSerializationTensor *ts =
+      block_builder->GetBlock()->GetTensorByName(output_name);
+  if (!ts) {
+    op.emitOpError(
+        "ERROR: serialization tensor must be built before building operator");
+    return nullptr;
+  }
+
+  // Look up the payload attribute once; each dtype branch below tries the
+  // dense-elements form first, then falls back to a scalar attribute.
+  mlir::Attribute value_attr = op.getAttr(llvm::StringRef("value"));
+
+  // Update tensor.data array with Const value attribute
+  std::vector<uint8_t> u8_data;
+  DType type = ts->GetDtype();
+  if (type == DType_FLOAT) {
+    std::vector<float> data;
+    auto dense_attr = value_attr.dyn_cast<mlir::DenseElementsAttr>();
+    auto val_attr = value_attr.dyn_cast<mlir::FloatAttr>();
+
+    if (dense_attr) {
+      for (auto val : dense_attr.getValues<float>()) {
+        data.push_back(val);
+      }
+    } else if (val_attr) {
+      data.push_back((float)val_attr.getValueAsDouble());
+    } else {
+      op.emitOpError("Unknown const attribute");
+      return nullptr;
+    }
+    TosaSerializationHandler::ConvertF32toU8(data, u8_data);
+  } else if (type == DType_INT8) {
+    std::vector<int8_t> data;
+    auto dense_attr = value_attr.dyn_cast<mlir::DenseElementsAttr>();
+    auto val_attr = value_attr.dyn_cast<mlir::IntegerAttr>();
+
+    if (dense_attr) {
+      for (auto val : dense_attr.getValues<int8_t>()) {
+        data.push_back(val);
+      }
+    } else if (val_attr) {
+      data.push_back(val_attr.getInt());
+    } else {
+      op.emitOpError("Unknown const attribute");
+      return nullptr;
+    }
+    TosaSerializationHandler::ConvertI8toU8(data, u8_data);
+  } else if (type == DType_INT16) {
+    std::vector<int16_t> data;
+    auto dense_attr = value_attr.dyn_cast<mlir::DenseElementsAttr>();
+    auto val_attr = value_attr.dyn_cast<mlir::IntegerAttr>();
+
+    if (dense_attr) {
+      for (auto val : dense_attr.getValues<int16_t>()) {
+        data.push_back(val);
+      }
+    } else if (val_attr) {
+      data.push_back(val_attr.getInt());
+    } else {
+      op.emitOpError("Unknown const attribute");
+      return nullptr;
+    }
+    TosaSerializationHandler::ConvertI16toU8(data, u8_data);
+  } else if (type == DType_INT32) {
+    std::vector<int32_t> data;
+    auto dense_attr = value_attr.dyn_cast<mlir::DenseElementsAttr>();
+    auto val_attr = value_attr.dyn_cast<mlir::IntegerAttr>();
+
+    if (dense_attr) {
+      for (auto val : dense_attr.getValues<int32_t>()) {
+        data.push_back(val);
+      }
+    } else if (val_attr) {
+      data.push_back(val_attr.getInt());
+    } else {
+      op.emitOpError("Unknown const attribute");
+      return nullptr;
+    }
+    TosaSerializationHandler::ConvertI32toU8(data, u8_data);
+  } else if (type == DType_INT48) {
+    // INT48 is carried in int64 storage; the converter packs 6 bytes each.
+    std::vector<int64_t> data;
+    auto dense_attr = value_attr.dyn_cast<mlir::DenseElementsAttr>();
+    auto val_attr = value_attr.dyn_cast<mlir::IntegerAttr>();
+
+    if (dense_attr) {
+      for (auto val : dense_attr.getValues<int64_t>()) {
+        data.push_back(val);
+      }
+    } else if (val_attr) {
+      data.push_back(val_attr.getInt());
+    } else {
+      op.emitOpError("Unknown const attribute");
+      return nullptr;
+    }
+    TosaSerializationHandler::ConvertI48toU8(data, u8_data);
+  } else if (type == DType_BOOL) {
+    std::vector<bool> data;
+
+    auto dense_attr = value_attr.dyn_cast<mlir::DenseElementsAttr>();
+    auto val_attr = value_attr.dyn_cast<mlir::BoolAttr>();
+
+    if (dense_attr) {
+      for (auto val : dense_attr.getValues<bool>()) {
+        data.push_back(val);
+      }
+    } else if (val_attr) {
+      data.push_back(val_attr.getValue());
+    } else {
+      op.emitOpError("Unknown const attribute");
+      return nullptr;
+    }
+
+    TosaSerializationHandler::ConvertBooltoU8(data, u8_data);
+  } else {
+    op.emitOpError("Unknown element type of const attribute");
+    return nullptr;
+  }
+  ts->SetData(u8_data);
+
+  // CONST has no inputs; the payload lives on the tensor, not the operator.
+  TosaSerializationOperator *tyop = new TosaSerializationOperator(
+      Op_CONST, Attribute_NONE, nullptr, QuantInfo_NONE, nullptr,
+      std::vector<std::string>{}, std::vector<std::string>{output_name});
+
+  return tyop;
+}
+
+// Serialize tosa.conv2d into a CONV2D operator.
+// Inputs are operands 0..2 (input, weight, bias per the TOSA spec);
+// attributes are pad[4], stride[2], dilation[2], plus optional conv-style
+// quantization info (zero points).
+template <>
+TosaSerializationOperator *
+TosaSerializationOperatorBuilder::build<mlir::tosa::Conv2DOp>(
+    mlir::Operation &op) const {
+  std::vector<int> pad, stride, dilation;
+
+  // Unpack the integer array attributes; lengths are checked against the
+  // expected sizes (pad has 4 entries, stride/dilation have 2 each).
+  auto pad_attr = op.getAttr("pad").dyn_cast<mlir::ArrayAttr>().getValue();
+  for (auto &int_attr : pad_attr) {
+    pad.push_back(int_attr.dyn_cast<mlir::IntegerAttr>().getInt());
+  }
+  ASSERT_VECTOR_LENGTH(pad, 4);
+
+  auto stride_attr =
+      op.getAttr("stride").dyn_cast<mlir::ArrayAttr>().getValue();
+  for (auto &int_attr : stride_attr) {
+    stride.push_back(int_attr.dyn_cast<mlir::IntegerAttr>().getInt());
+  }
+  ASSERT_VECTOR_LENGTH(stride, 2);
+
+  auto dilation_attr =
+      op.getAttr("dilation").dyn_cast<mlir::ArrayAttr>().getValue();
+  for (auto &int_attr : dilation_attr) {
+    dilation.push_back(int_attr.dyn_cast<mlir::IntegerAttr>().getInt());
+  }
+  ASSERT_VECTOR_LENGTH(dilation, 2);
+
+  std::string input0_name = GetTensorName(op.getOperand(0));
+  std::string input1_name = GetTensorName(op.getOperand(1));
+  std::string input2_name = GetTensorName(op.getOperand(2));
+  std::string output_name = GetTensorName(op.getResult(0));
+
+  TosaConvAttribute attribute(pad, stride, dilation);
+
+  auto quant_info =
+      op.getAttrOfType<mlir::tosa::ConvOpQuantizationAttr>("quantization_info");
+  QuantInfo qinfo_type;
+  TosaQuantInfoBase *qinfo;
+  if (quant_info) {
+    qinfo_type = QuantInfo_ConvQuantInfo;
+    qinfo = GetConvQuantInfo(quant_info);
+  } else {
+    qinfo_type = QuantInfo_NONE;
+    qinfo = new TosaNoneQuantInfo();
+  }
+
+  TosaSerializationOperator *tyop = new TosaSerializationOperator(
+      Op_CONV2D, Attribute_ConvAttribute, &attribute, qinfo_type, qinfo,
+      std::vector<std::string>{input0_name, input1_name, input2_name},
+      std::vector<std::string>{output_name});
+
+  // The serialization operator copies quant info; release the temporary.
+  delete qinfo;
+
+  return tyop;
+}
+
+// Serialize tosa.depthwise_conv2d into a DEPTHWISE_CONV2D operator.
+// Identical attribute layout to CONV2D: pad[4], stride[2], dilation[2],
+// optional conv-style quantization info; operands 0..2 are the inputs.
+template <>
+TosaSerializationOperator *
+TosaSerializationOperatorBuilder::build<mlir::tosa::DepthwiseConv2DOp>(
+    mlir::Operation &op) const {
+  std::vector<int> pad, stride, dilation;
+
+  auto pad_attr = op.getAttr("pad").dyn_cast<mlir::ArrayAttr>().getValue();
+  for (auto &int_attr : pad_attr) {
+    pad.push_back(int_attr.dyn_cast<mlir::IntegerAttr>().getInt());
+  }
+  ASSERT_VECTOR_LENGTH(pad, 4);
+
+  auto stride_attr =
+      op.getAttr("stride").dyn_cast<mlir::ArrayAttr>().getValue();
+  for (auto &int_attr : stride_attr) {
+    stride.push_back(int_attr.dyn_cast<mlir::IntegerAttr>().getInt());
+  }
+  ASSERT_VECTOR_LENGTH(stride, 2);
+
+  auto dilation_attr =
+      op.getAttr("dilation").dyn_cast<mlir::ArrayAttr>().getValue();
+  for (auto &int_attr : dilation_attr) {
+    dilation.push_back(int_attr.dyn_cast<mlir::IntegerAttr>().getInt());
+  }
+  ASSERT_VECTOR_LENGTH(dilation, 2);
+
+  std::string input0_name = GetTensorName(op.getOperand(0));
+  std::string input1_name = GetTensorName(op.getOperand(1));
+  std::string input2_name = GetTensorName(op.getOperand(2));
+  std::string output_name = GetTensorName(op.getResult(0));
+
+  TosaConvAttribute attribute(pad, stride, dilation);
+
+  auto quant_info =
+      op.getAttrOfType<mlir::tosa::ConvOpQuantizationAttr>("quantization_info");
+  QuantInfo qinfo_type;
+  TosaQuantInfoBase *qinfo;
+  if (quant_info) {
+    qinfo_type = QuantInfo_ConvQuantInfo;
+    qinfo = GetConvQuantInfo(quant_info);
+  } else {
+    qinfo_type = QuantInfo_NONE;
+    qinfo = new TosaNoneQuantInfo();
+  }
+
+  TosaSerializationOperator *tyop = new TosaSerializationOperator(
+      Op_DEPTHWISE_CONV2D, Attribute_ConvAttribute, &attribute, qinfo_type,
+      qinfo, std::vector<std::string>{input0_name, input1_name, input2_name},
+      std::vector<std::string>{output_name});
+
+  // The serialization operator copies quant info; release the temporary.
+  delete qinfo;
+
+  return tyop;
+}
+
+// Serialize tosa.transpose_conv2d into a TRANSPOSE_CONV2D operator.
+// Attributes: out_pad[2], stride[2], dilation[2] and out_shape[4] (the
+// full output tensor shape); optional conv-style quantization info.
+template <>
+TosaSerializationOperator *
+TosaSerializationOperatorBuilder::build<mlir::tosa::TransposeConv2DOp>(
+    mlir::Operation &op) const {
+  std::vector<int> outpad, stride, dilation, output_shape;
+
+  auto outpad_attr =
+      op.getAttr("out_pad").dyn_cast<mlir::ArrayAttr>().getValue();
+  for (auto &int_attr : outpad_attr) {
+    outpad.push_back(int_attr.dyn_cast<mlir::IntegerAttr>().getInt());
+  }
+  ASSERT_VECTOR_LENGTH(outpad, 2);
+
+  auto stride_attr =
+      op.getAttr("stride").dyn_cast<mlir::ArrayAttr>().getValue();
+  for (auto &int_attr : stride_attr) {
+    stride.push_back(int_attr.dyn_cast<mlir::IntegerAttr>().getInt());
+  }
+  ASSERT_VECTOR_LENGTH(stride, 2);
+
+  auto dilation_attr =
+      op.getAttr("dilation").dyn_cast<mlir::ArrayAttr>().getValue();
+  for (auto &int_attr : dilation_attr) {
+    dilation.push_back(int_attr.dyn_cast<mlir::IntegerAttr>().getInt());
+  }
+  ASSERT_VECTOR_LENGTH(dilation, 2);
+
+  auto output_shape_attr =
+      op.getAttr("out_shape").dyn_cast<mlir::ArrayAttr>().getValue();
+  for (auto &int_attr : output_shape_attr) {
+    output_shape.push_back(int_attr.dyn_cast<mlir::IntegerAttr>().getInt());
+  }
+  ASSERT_VECTOR_LENGTH(output_shape, 4);
+
+  std::string input0_name = GetTensorName(op.getOperand(0));
+  std::string input1_name = GetTensorName(op.getOperand(1));
+  std::string input2_name = GetTensorName(op.getOperand(2));
+  std::string output_name = GetTensorName(op.getResult(0));
+
+  TosaTransposeConvAttribute attribute(outpad, stride, dilation, output_shape);
+
+  auto quant_info =
+      op.getAttrOfType<mlir::tosa::ConvOpQuantizationAttr>("quantization_info");
+  QuantInfo qinfo_type;
+  TosaQuantInfoBase *qinfo;
+  if (quant_info) {
+    qinfo_type = QuantInfo_ConvQuantInfo;
+    qinfo = GetConvQuantInfo(quant_info);
+  } else {
+    qinfo_type = QuantInfo_NONE;
+    qinfo = new TosaNoneQuantInfo();
+  }
+
+  TosaSerializationOperator *tyop = new TosaSerializationOperator(
+      Op_TRANSPOSE_CONV2D, Attribute_TransposeConvAttribute, &attribute,
+      qinfo_type, qinfo,
+      std::vector<std::string>{input0_name, input1_name, input2_name},
+      std::vector<std::string>{output_name});
+
+  // The serialization operator copies quant info; release the temporary.
+  delete qinfo;
+
+  return tyop;
+}
+
+// FULLY_CONNECTED: operands 0..2 become the serialized inputs; no
+// attribute is attached, only optional conv-style quantization info.
+template <>
+TosaSerializationOperator *
+TosaSerializationOperatorBuilder::build<mlir::tosa::FullyConnectedOp>(
+    mlir::Operation &op) const {
+  auto quant_info =
+      op.getAttrOfType<mlir::tosa::ConvOpQuantizationAttr>("quantization_info");
+
+  QuantInfo qinfo_type = QuantInfo_NONE;
+  TosaQuantInfoBase *qinfo = nullptr;
+  if (quant_info) {
+    qinfo_type = QuantInfo_ConvQuantInfo;
+    qinfo = GetConvQuantInfo(quant_info);
+  } else {
+    qinfo = new TosaNoneQuantInfo();
+  }
+
+  std::vector<std::string> input_names{GetTensorName(op.getOperand(0)),
+                                       GetTensorName(op.getOperand(1)),
+                                       GetTensorName(op.getOperand(2))};
+
+  TosaSerializationOperator *result = new TosaSerializationOperator(
+      Op_FULLY_CONNECTED, Attribute_NONE, nullptr, qinfo_type, qinfo,
+      input_names,
+      std::vector<std::string>{GetTensorName(op.getResult(0))});
+
+  // The operator copies quant info, so our temporary is freed here.
+  delete qinfo;
+
+  return result;
+}
+
+// MATMUL: two inputs, no attribute, optional matmul quantization info.
+template <>
+TosaSerializationOperator *
+TosaSerializationOperatorBuilder::build<mlir::tosa::MatMulOp>(
+    mlir::Operation &op) const {
+  auto quant_info = op.getAttrOfType<mlir::tosa::MatMulOpQuantizationAttr>(
+      "quantization_info");
+
+  QuantInfo qinfo_type = QuantInfo_NONE;
+  TosaQuantInfoBase *qinfo = nullptr;
+  if (quant_info) {
+    qinfo_type = QuantInfo_MatMulQuantInfo;
+    qinfo = GetMatMulQuantInfo(quant_info);
+  } else {
+    qinfo = new TosaNoneQuantInfo();
+  }
+
+  TosaSerializationOperator *result = new TosaSerializationOperator(
+      Op_MATMUL, Attribute_NONE, nullptr, qinfo_type, qinfo,
+      std::vector<std::string>{GetTensorName(op.getOperand(0)),
+                               GetTensorName(op.getOperand(1))},
+      std::vector<std::string>{GetTensorName(op.getResult(0))});
+
+  // The operator copies quant info, so our temporary is freed here.
+  delete qinfo;
+
+  return result;
+}
+
+// SELECT: three inputs (operands 0..2), no attribute, no quant info.
+template <>
+TosaSerializationOperator *
+TosaSerializationOperatorBuilder::build<mlir::tosa::SelectOp>(
+    mlir::Operation &op) const {
+  std::vector<std::string> input_names{GetTensorName(op.getOperand(0)),
+                                       GetTensorName(op.getOperand(1)),
+                                       GetTensorName(op.getOperand(2))};
+
+  return new TosaSerializationOperator(
+      Op_SELECT, Attribute_NONE, nullptr, QuantInfo_NONE, nullptr,
+      input_names,
+      std::vector<std::string>{GetTensorName(op.getResult(0))});
+}
+
+// CLAMP: carries both integer and floating-point bounds; the consumer
+// uses whichever pair matches the tensor element type.
+template <>
+TosaSerializationOperator *
+TosaSerializationOperatorBuilder::build<mlir::tosa::ClampOp>(
+    mlir::Operation &op) const {
+  const int32_t clamp_min_int =
+      op.getAttr("min_int").dyn_cast<mlir::IntegerAttr>().getInt();
+  const int32_t clamp_max_int =
+      op.getAttr("max_int").dyn_cast<mlir::IntegerAttr>().getInt();
+  const float clamp_min_fp = op.getAttr("min_fp")
+                                 .dyn_cast<mlir::FloatAttr>()
+                                 .getValue()
+                                 .convertToFloat();
+  const float clamp_max_fp = op.getAttr("max_fp")
+                                 .dyn_cast<mlir::FloatAttr>()
+                                 .getValue()
+                                 .convertToFloat();
+
+  TosaClampAttribute attribute(clamp_min_int, clamp_max_int, clamp_min_fp,
+                               clamp_max_fp);
+
+  return new TosaSerializationOperator(
+      Op_CLAMP, Attribute_ClampAttribute, &attribute, QuantInfo_NONE, nullptr,
+      std::vector<std::string>{GetTensorName(op.getOperand(0))},
+      std::vector<std::string>{GetTensorName(op.getResult(0))});
+}
+
+// ARGMAX: single "axis" attribute, one input, one output.
+template <>
+TosaSerializationOperator *
+TosaSerializationOperatorBuilder::build<mlir::tosa::ArgMaxOp>(
+    mlir::Operation &op) const {
+  const int32_t axis =
+      op.getAttr("axis").dyn_cast<mlir::IntegerAttr>().getInt();
+  TosaAxisAttribute attribute(axis);
+
+  return new TosaSerializationOperator(
+      Op_ARGMAX, Attribute_AxisAttribute, &attribute, QuantInfo_NONE, nullptr,
+      std::vector<std::string>{GetTensorName(op.getOperand(0))},
+      std::vector<std::string>{GetTensorName(op.getResult(0))});
+}
+
+// CONCAT: variadic inputs joined along a single "axis" attribute.
+template <>
+TosaSerializationOperator *
+TosaSerializationOperatorBuilder::build<mlir::tosa::ConcatOp>(
+    mlir::Operation &op) const {
+  const uint32_t num_inputs = op.getNumOperands();
+  std::vector<std::string> input_names;
+  input_names.reserve(num_inputs);
+  for (uint32_t idx = 0; idx < num_inputs; idx++) {
+    input_names.push_back(GetTensorName(op.getOperand(idx)));
+  }
+
+  const int32_t axis =
+      op.getAttr("axis").dyn_cast<mlir::IntegerAttr>().getInt();
+  TosaAxisAttribute attribute(axis);
+
+  return new TosaSerializationOperator(
+      Op_CONCAT, Attribute_AxisAttribute, &attribute, QuantInfo_NONE, nullptr,
+      input_names,
+      std::vector<std::string>{GetTensorName(op.getResult(0))});
+}
+
+// Serialize tosa.negate into a NEGATE operator.  Unlike the ops routed
+// through BuildEwiseUnaryOpFromMlirOp, NEGATE may carry unary
+// quantization info ("quantization_info" attribute).
+template <>
+TosaSerializationOperator *
+TosaSerializationOperatorBuilder::build<mlir::tosa::NegateOp>(
+    mlir::Operation &op) const {
+  std::string input_name = GetTensorName(op.getOperand(0));
+  std::string output_name = GetTensorName(op.getResult(0));
+
+  auto quant_info = op.getAttrOfType<mlir::tosa::UnaryOpQuantizationAttr>(
+      "quantization_info");
+  QuantInfo qinfo_type;
+  TosaQuantInfoBase *qinfo;
+  if (quant_info) {
+    qinfo_type = QuantInfo_UnaryQuantInfo;
+    qinfo = GetUnaryQuantInfo(quant_info);
+  } else {
+    qinfo_type = QuantInfo_NONE;
+    qinfo = new TosaNoneQuantInfo();
+  }
+
+  TosaSerializationOperator *tyop = new TosaSerializationOperator(
+      Op_NEGATE, Attribute_NONE, nullptr, qinfo_type, qinfo,
+      std::vector<std::string>{input_name},
+      std::vector<std::string>{output_name});
+
+  // The serialization operator copies quant info; free the temporary here
+  // (it was previously leaked, unlike every other quantized-op builder).
+  delete qinfo;
+
+  return tyop;
+}
+
+// RESHAPE: the target shape comes from the "new_shape" array attribute.
+template <>
+TosaSerializationOperator *
+TosaSerializationOperatorBuilder::build<mlir::tosa::ReshapeOp>(
+    mlir::Operation &op) const {
+  std::vector<int> new_shape;
+  for (auto &dim_attr :
+       op.getAttr("new_shape").dyn_cast<mlir::ArrayAttr>().getValue()) {
+    new_shape.push_back(dim_attr.dyn_cast<mlir::IntegerAttr>().getInt());
+  }
+
+  TosaReshapeAttribute attribute(new_shape);
+
+  return new TosaSerializationOperator(
+      Op_RESHAPE, Attribute_ReshapeAttribute, &attribute, QuantInfo_NONE,
+      nullptr, std::vector<std::string>{GetTensorName(op.getOperand(0))},
+      std::vector<std::string>{GetTensorName(op.getResult(0))});
+}
+
+// Serialize tosa.pad into a PAD operator.
+// The padding operand (operand 1) must resolve to a compile-time constant
+// or the op is rejected (nullptr).  The pad constant is fixed at 0 / 0.0f.
+// TODO: fix when MLIR dialect changes
+template <>
+TosaSerializationOperator *
+TosaSerializationOperatorBuilder::build<mlir::tosa::PadOp>(
+    mlir::Operation &op) const {
+  std::string input_name = GetTensorName(op.getOperand(0));
+  std::string output_name = GetTensorName(op.getResult(0));
+
+  // Match padding tensor as compile-time constant attribute
+  mlir::ElementsAttr paddings_elems;
+  if (!matchPattern(op.getOperand(1), m_Constant(&paddings_elems)))
+    return nullptr;
+
+  std::vector<int> paddings;
+  for (int32_t val : paddings_elems.getValues<int32_t>()) {
+    paddings.push_back(val);
+  }
+
+  TosaPadAttribute attribute(paddings, 0 /* pad_const_int */,
+                             0.0f /* pad_const_fp */);
+
+  auto quant_info =
+      op.getAttrOfType<mlir::tosa::PadOpQuantizationAttr>("quantization_info");
+  QuantInfo qinfo_type;
+  TosaQuantInfoBase *qinfo;
+  if (quant_info) {
+    qinfo_type = QuantInfo_PadQuantInfo;
+    qinfo = GetPadQuantInfo(quant_info);
+  } else {
+    qinfo_type = QuantInfo_NONE;
+    qinfo = new TosaNoneQuantInfo();
+  }
+
+  TosaSerializationOperator *tyop = new TosaSerializationOperator(
+      Op_PAD, Attribute_PadAttribute, &attribute, qinfo_type, qinfo,
+      std::vector<std::string>{input_name},
+      std::vector<std::string>{output_name});
+
+  // The serialization operator copies quant info; free the temporary here
+  // (it was previously leaked, unlike the other quantized-op builders).
+  delete qinfo;
+
+  return tyop;
+}
+
+// TRANSPOSE: the permutation operand must be a compile-time constant;
+// the op is rejected (nullptr) otherwise.
+// TODO: fix when MLIR dialect changes
+template <>
+TosaSerializationOperator *
+TosaSerializationOperatorBuilder::build<mlir::tosa::TransposeOp>(
+    mlir::Operation &op) const {
+  mlir::ElementsAttr perm_elems;
+  if (!matchPattern(op.getOperand(1), m_Constant(&perm_elems)))
+    return nullptr;
+
+  std::vector<int> perm;
+  const int32_t num_perm = perm_elems.getNumElements();
+  for (int32_t idx = 0; idx < num_perm; idx++) {
+    perm.push_back(perm_elems.getValue<mlir::IntegerAttr>(idx).getInt());
+  }
+
+  TosaTransposeAttribute attribute(perm);
+
+  return new TosaSerializationOperator(
+      Op_TRANSPOSE, Attribute_TransposeAttribute, &attribute, QuantInfo_NONE,
+      nullptr, std::vector<std::string>{GetTensorName(op.getOperand(0))},
+      std::vector<std::string>{GetTensorName(op.getResult(0))});
+}
+
+// SLICE: static "start"/"size" array attributes define the window.
+template <>
+TosaSerializationOperator *
+TosaSerializationOperatorBuilder::build<mlir::tosa::SliceOp>(
+    mlir::Operation &op) const {
+  std::vector<int> start, size;
+  for (auto &int_attr :
+       op.getAttr("start").dyn_cast<mlir::ArrayAttr>().getValue()) {
+    start.push_back(int_attr.dyn_cast<mlir::IntegerAttr>().getInt());
+  }
+  for (auto &int_attr :
+       op.getAttr("size").dyn_cast<mlir::ArrayAttr>().getValue()) {
+    size.push_back(int_attr.dyn_cast<mlir::IntegerAttr>().getInt());
+  }
+
+  TosaSliceAttribute attribute(start, size);
+
+  return new TosaSerializationOperator(
+      Op_SLICE, Attribute_SliceAttribute, &attribute, QuantInfo_NONE, nullptr,
+      std::vector<std::string>{GetTensorName(op.getOperand(0))},
+      std::vector<std::string>{GetTensorName(op.getResult(0))});
+}
+
+// TILE: replication counts come from the "multiples" array attribute.
+template <>
+TosaSerializationOperator *
+TosaSerializationOperatorBuilder::build<mlir::tosa::TileOp>(
+    mlir::Operation &op) const {
+  std::vector<int> multiples;
+  for (auto &int_attr :
+       op.getAttr("multiples").dyn_cast<mlir::ArrayAttr>().getValue()) {
+    multiples.push_back(int_attr.dyn_cast<mlir::IntegerAttr>().getInt());
+  }
+
+  TosaTileAttribute attribute(multiples);
+
+  return new TosaSerializationOperator(
+      Op_TILE, Attribute_TileAttribute, &attribute, QuantInfo_NONE, nullptr,
+      std::vector<std::string>{GetTensorName(op.getOperand(0))},
+      std::vector<std::string>{GetTensorName(op.getResult(0))});
+}
+
+// GATHER: two inputs (operands 0 and 1), no attribute, no quant info.
+template <>
+TosaSerializationOperator *
+TosaSerializationOperatorBuilder::build<mlir::tosa::GatherOp>(
+    mlir::Operation &op) const {
+  std::vector<std::string> input_names{GetTensorName(op.getOperand(0)),
+                                       GetTensorName(op.getOperand(1))};
+
+  return new TosaSerializationOperator(
+      Op_GATHER, Attribute_NONE, nullptr, QuantInfo_NONE, nullptr,
+      input_names,
+      std::vector<std::string>{GetTensorName(op.getResult(0))});
+}
+
+// SCATTER: three inputs (operands 0..2), no attribute, no quant info.
+template <>
+TosaSerializationOperator *
+TosaSerializationOperatorBuilder::build<mlir::tosa::ScatterOp>(
+    mlir::Operation &op) const {
+  std::vector<std::string> input_names{GetTensorName(op.getOperand(0)),
+                                       GetTensorName(op.getOperand(1)),
+                                       GetTensorName(op.getOperand(2))};
+
+  return new TosaSerializationOperator(
+      Op_SCATTER, Attribute_NONE, nullptr, QuantInfo_NONE, nullptr,
+      input_names,
+      std::vector<std::string>{GetTensorName(op.getResult(0))});
+}
+
+// Serialize tosa.resize into a RESIZE operator.
+// Carries both the integer form (output_size/stride/offset/shift) and the
+// floating-point form (stride_fp/offset_fp) of the resize parameters plus
+// a string "mode" attribute translated into the schema's ResizeMode enum.
+template <>
+TosaSerializationOperator *
+TosaSerializationOperatorBuilder::build<mlir::tosa::ResizeOp>(
+    mlir::Operation &op) const {
+  std::string input_name = GetTensorName(op.getOperand(0));
+  std::string output_name = GetTensorName(op.getResult(0));
+
+  // Each array attribute below is expected to hold exactly two entries
+  // (height, width); lengths are verified by ASSERT_VECTOR_LENGTH.
+  std::vector<int> output_size;
+  auto output_size_attr =
+      op.getAttr("output_size").dyn_cast<mlir::ArrayAttr>().getValue();
+  for (auto &int_attr : output_size_attr) {
+    output_size.push_back(int_attr.dyn_cast<mlir::IntegerAttr>().getInt());
+  }
+  ASSERT_VECTOR_LENGTH(output_size, 2);
+
+  std::vector<int> stride;
+  auto stride_attr =
+      op.getAttr("stride").dyn_cast<mlir::ArrayAttr>().getValue();
+  for (auto &int_attr : stride_attr) {
+    stride.push_back(int_attr.dyn_cast<mlir::IntegerAttr>().getInt());
+  }
+  ASSERT_VECTOR_LENGTH(stride, 2);
+
+  std::vector<int> offset;
+  auto offset_attr =
+      op.getAttr("offset").dyn_cast<mlir::ArrayAttr>().getValue();
+  for (auto &int_attr : offset_attr) {
+    offset.push_back(int_attr.dyn_cast<mlir::IntegerAttr>().getInt());
+  }
+  ASSERT_VECTOR_LENGTH(offset, 2);
+
+  int32_t shift = op.getAttr("shift").dyn_cast<mlir::IntegerAttr>().getInt();
+
+  // Floating-point stride/offset; values are narrowed double -> float.
+  std::vector<float> stride_fp;
+  auto stride_fp_attr =
+      op.getAttr("stride_fp").dyn_cast<mlir::ArrayAttr>().getValue();
+  for (auto &fp_attr : stride_fp_attr) {
+    stride_fp.push_back(fp_attr.dyn_cast<mlir::FloatAttr>().getValueAsDouble());
+  }
+  ASSERT_VECTOR_LENGTH(stride_fp, 2);
+
+  std::vector<float> offset_fp;
+  auto offset_fp_attr =
+      op.getAttr("offset_fp").dyn_cast<mlir::ArrayAttr>().getValue();
+  for (auto &fp_attr : offset_fp_attr) {
+    offset_fp.push_back(fp_attr.dyn_cast<mlir::FloatAttr>().getValueAsDouble());
+  }
+  ASSERT_VECTOR_LENGTH(offset_fp, 2);
+
+  // Translate the textual mode attribute to the schema enum.
+  auto mode_str =
+      op.getAttr("mode").dyn_cast<mlir::StringAttr>().getValue().str();
+  ResizeMode mode = ResizeModeStr2Enum(mode_str);
+
+  TosaResizeAttribute attribute(output_size, stride, offset, shift, stride_fp,
+                                offset_fp, mode);
+
+  TosaSerializationOperator *tyop = new TosaSerializationOperator(
+      Op_RESIZE, Attribute_ResizeAttribute, &attribute, QuantInfo_NONE, nullptr,
+      std::vector<std::string>{input_name},
+      std::vector<std::string>{output_name});
+
+  return tyop;
+}
+
+// REVERSE: single "axis" attribute, one input, one output.
+template <>
+TosaSerializationOperator *
+TosaSerializationOperatorBuilder::build<mlir::tosa::ReverseOp>(
+    mlir::Operation &op) const {
+  const int32_t axis =
+      op.getAttr("axis").dyn_cast<mlir::IntegerAttr>().getInt();
+  TosaAxisAttribute attribute(axis);
+
+  return new TosaSerializationOperator(
+      Op_REVERSE, Attribute_AxisAttribute, &attribute, QuantInfo_NONE,
+      nullptr, std::vector<std::string>{GetTensorName(op.getOperand(0))},
+      std::vector<std::string>{GetTensorName(op.getResult(0))});
+}
+
+// MUL: elementwise multiply carrying a "shift" attribute.
+template <>
+TosaSerializationOperator *
+TosaSerializationOperatorBuilder::build<mlir::tosa::MulOp>(
+    mlir::Operation &op) const {
+  const int32_t shift =
+      op.getAttr("shift").dyn_cast<mlir::IntegerAttr>().getInt();
+  TosaMulAttribute attribute(shift);
+
+  return new TosaSerializationOperator(
+      Op_MUL, Attribute_MulAttribute, &attribute, QuantInfo_NONE, nullptr,
+      std::vector<std::string>{GetTensorName(op.getOperand(0)),
+                               GetTensorName(op.getOperand(1))},
+      std::vector<std::string>{GetTensorName(op.getResult(0))});
+}
+
+// ARITHMETIC_RIGHT_SHIFT: binary shift with a boolean "round" attribute.
+template <>
+TosaSerializationOperator *
+TosaSerializationOperatorBuilder::build<mlir::tosa::ArithmeticRightShiftOp>(
+    mlir::Operation &op) const {
+  const bool round =
+      op.getAttr("round").dyn_cast<mlir::BoolAttr>().getValue();
+  TosaArithmeticRightShiftAttribute attribute(round);
+
+  return new TosaSerializationOperator(
+      Op_ARITHMETIC_RIGHT_SHIFT, Attribute_ArithmeticRightShiftAttribute,
+      &attribute, QuantInfo_NONE, nullptr,
+      std::vector<std::string>{GetTensorName(op.getOperand(0)),
+                               GetTensorName(op.getOperand(1))},
+      std::vector<std::string>{GetTensorName(op.getResult(0))});
+}
+
+// TABLE: the lookup-table operand must be a compile-time constant; the
+// op is rejected (nullptr) otherwise.
+// TODO: fix when MLIR dialect changes
+template <>
+TosaSerializationOperator *
+TosaSerializationOperatorBuilder::build<mlir::tosa::TableOp>(
+    mlir::Operation &op) const {
+  mlir::ElementsAttr table_elems;
+  if (!matchPattern(op.getOperand(1), m_Constant(&table_elems)))
+    return nullptr;
+
+  std::vector<int> table;
+  const int32_t num_entries = table_elems.getNumElements();
+  for (int32_t idx = 0; idx < num_entries; idx++) {
+    table.push_back(table_elems.getValue<mlir::IntegerAttr>(idx).getInt());
+  }
+
+  TosaTableAttribute attribute(table);
+
+  return new TosaSerializationOperator(
+      Op_TABLE, Attribute_TableAttribute, &attribute, QuantInfo_NONE, nullptr,
+      std::vector<std::string>{GetTensorName(op.getOperand(0))},
+      std::vector<std::string>{GetTensorName(op.getResult(0))});
+}
+
+// Serialize tosa.rescale into a RESCALE operator.
+// Collects the zero points, the scale32/double_round/per_channel flags and
+// the per-channel multiplier/shift arrays into a TosaRescaleAttribute.
+template <>
+TosaSerializationOperator *
+TosaSerializationOperatorBuilder::build<mlir::tosa::RescaleOp>(
+    mlir::Operation &op) const {
+  int32_t input_zp =
+      op.getAttr("input_zp").dyn_cast<mlir::IntegerAttr>().getInt();
+  int32_t output_zp =
+      op.getAttr("output_zp").dyn_cast<mlir::IntegerAttr>().getInt();
+  bool scale32 = op.getAttr("scale32").dyn_cast<mlir::BoolAttr>().getValue();
+  bool double_round =
+      op.getAttr("double_round").dyn_cast<mlir::BoolAttr>().getValue();
+  bool per_channel =
+      op.getAttr("per_channel").dyn_cast<mlir::BoolAttr>().getValue();
+
+  // multiplier[] and shift[] are parallel arrays (one entry per channel
+  // when per_channel is set); no length check is applied here.
+  std::vector<int> multiplier, shift;
+  auto multiplier_attr =
+      op.getAttr("multiplier").dyn_cast<mlir::ArrayAttr>().getValue();
+  auto shift_attr = op.getAttr("shift").dyn_cast<mlir::ArrayAttr>().getValue();
+
+  for (auto &int_attr : multiplier_attr) {
+    multiplier.push_back(int_attr.dyn_cast<mlir::IntegerAttr>().getInt());
+  }
+
+  for (auto &int_attr : shift_attr) {
+    shift.push_back(int_attr.dyn_cast<mlir::IntegerAttr>().getInt());
+  }
+
+  std::string input_name = GetTensorName(op.getOperand(0));
+  std::string output_name = GetTensorName(op.getResult(0));
+
+  TosaRescaleAttribute attribute(input_zp, output_zp, multiplier, shift,
+                                 scale32, double_round, per_channel);
+
+  TosaSerializationOperator *tyop = new TosaSerializationOperator(
+      Op_RESCALE, Attribute_RescaleAttribute, &attribute, QuantInfo_NONE,
+      nullptr, std::vector<std::string>{input_name},
+      std::vector<std::string>{output_name});
+
+  return tyop;
+}
+
+// Serialize a tosa.custom op. Only the first operand and first result tensor
+// names are recorded; Op_CUSTOM is emitted with no attribute payload.
+template <>
+TosaSerializationOperator *
+TosaSerializationOperatorBuilder::build<mlir::tosa::CustomOp>(
+ mlir::Operation &op) const {
+ std::string input_name = GetTensorName(op.getOperand(0));
+ std::string output_name = GetTensorName(op.getResult(0));
+
+ // NOTE(review): any identifier/implementation attributes on the custom op
+ // are dropped here, and extra operands/results beyond the first are not
+ // serialized -- confirm this matches the flatbuffer schema for CUSTOM.
+ TosaSerializationOperator *tyop = new TosaSerializationOperator(
+ Op_CUSTOM, Attribute_NONE, nullptr, QuantInfo_NONE, nullptr,
+ std::vector<std::string>{input_name},
+ std::vector<std::string>{output_name});
+
+ return tyop;
+}
+
+// Serialize tosa.cond_if: builds one serialization basic block per region
+// (then / else), records the two block names in a TosaCondIfAttribute, and
+// wires the op's operands/results through by tensor name.
+template <>
+TosaSerializationOperator *
+TosaSerializationOperatorBuilder::build<mlir::tosa::IfOp>(
+ mlir::Operation &op) const {
+ std::vector<std::string> input_names, output_names;
+
+ mlir::Region &then_region = op.getRegion(0);
+ mlir::Region &else_region = op.getRegion(1);
+ std::vector<mlir::Value> then_yields, else_yields;
+ TosaSerializationBasicBlock *then_block = nullptr;
+ TosaSerializationBasicBlock *else_block = nullptr;
+
+ // Building then branch block
+ // Block names come from the handler-wide block count, so they are unique
+ // across the whole serialized graph.
+ std::string then_block_name =
+ "bb" + std::to_string(block_builder->GetTsh()->GetBlocks().size());
+ then_block = new TosaSerializationBasicBlock(
+ then_block_name, std::vector<TosaSerializationOperator *>(),
+ std::vector<TosaSerializationTensor *>(), std::vector<std::string>(),
+ std::vector<std::string>());
+ assert(then_block);
+ // Register the (still empty) block before recursing so nested regions see
+ // an up-to-date block count; the handler owns the block from here on.
+ block_builder->GetTsh()->GetBlocks().push_back(then_block);
+
+ TosaSerializationBlockBuilder then_block_builder(
+ then_block, block_builder->GetTsh(), &then_region);
+ if (then_block_builder.BuildAllOpsInRegion(then_yields).failed()) {
+ return nullptr;
+ }
+ // The then branch must yield exactly one value per cond_if result.
+ if (then_yields.size() != op.getNumResults()) {
+ op.emitOpError("BuildOpCondIf: then_region yield.size() doesn't match "
+ "cond_if's output size");
+ return nullptr;
+ }
+
+ // Building else branch block
+ std::string else_block_name =
+ "bb" + std::to_string(block_builder->GetTsh()->GetBlocks().size());
+ else_block = new TosaSerializationBasicBlock(
+ else_block_name, std::vector<TosaSerializationTensor *>(),
+ std::vector<TosaSerializationTensor *>(), std::vector<std::string>(),
+ std::vector<std::string>());
+ assert(else_block);
+ block_builder->GetTsh()->GetBlocks().push_back(else_block);
+
+ TosaSerializationBlockBuilder else_block_builder(
+ else_block, block_builder->GetTsh(), &else_region);
+ if (else_block_builder.BuildAllOpsInRegion(else_yields).failed()) {
+ return nullptr;
+ }
+ if (else_yields.size() != op.getNumResults()) {
+ op.emitOpError("BuildOpCondIf: else_region yield.size() doesn't match "
+ "cond_if's output size");
+ return nullptr;
+ }
+
+ // The attribute only stores the branch block *names*; consumers look the
+ // blocks up in the handler by name.
+ TosaCondIfAttribute attribute(then_block->GetName(), else_block->GetName());
+
+ for (size_t i = 0; i < op.getNumOperands(); i++) {
+ std::string input_name = GetTensorName(op.getOperand(i));
+ input_names.push_back(input_name);
+ }
+
+ for (size_t i = 0; i < op.getNumResults(); i++) {
+ std::string output_name = GetTensorName(op.getResult(i));
+ output_names.push_back(output_name);
+ }
+
+ TosaSerializationOperator *tyop = new TosaSerializationOperator(
+ Op_COND_IF, Attribute_CondIfAttribute, &attribute, QuantInfo_NONE,
+ nullptr, input_names, output_names);
+
+ return tyop;
+}
+
+// Serialize tosa.while_loop: builds one serialization basic block for the
+// cond region and one for the body region, records their names in a
+// TosaWhileLoopAttribute, and wires operands/results through by tensor name.
+template <>
+TosaSerializationOperator *
+TosaSerializationOperatorBuilder::build<mlir::tosa::WhileOp>(
+ mlir::Operation &op) const {
+ std::vector<std::string> input_names, output_names;
+
+ mlir::Region &cond_region = op.getRegion(0);
+ mlir::Region &body_region = op.getRegion(1);
+ std::vector<mlir::Value> cond_yields, body_yields;
+ TosaSerializationBasicBlock *cond_block = nullptr;
+ TosaSerializationBasicBlock *body_block = nullptr;
+
+ // Building cond branch block
+ // Block names come from the handler-wide block count, so they are unique
+ // across the whole serialized graph.
+ std::string cond_block_name =
+ "bb" + std::to_string(block_builder->GetTsh()->GetBlocks().size());
+ cond_block = new TosaSerializationBasicBlock(
+ cond_block_name, std::vector<TosaSerializationOperator *>(),
+ std::vector<TosaSerializationTensor *>(), std::vector<std::string>(),
+ std::vector<std::string>());
+ assert(cond_block);
+ // Register the (still empty) block before recursing so nested regions see
+ // an up-to-date block count; the handler owns the block from here on.
+ block_builder->GetTsh()->GetBlocks().push_back(cond_block);
+
+ TosaSerializationBlockBuilder cond_block_builder(
+ cond_block, block_builder->GetTsh(), &cond_region);
+ if (cond_block_builder.BuildAllOpsInRegion(cond_yields).failed()) {
+ return nullptr;
+ }
+ // The cond region must yield exactly one (boolean) value.
+ if (cond_yields.size() != 1) {
+ op.emitOpError("BuildOpWhileLoop: cond_region yield.size() is not 1");
+ return nullptr;
+ }
+
+ // Building body branch block
+ std::string body_block_name =
+ "bb" + std::to_string(block_builder->GetTsh()->GetBlocks().size());
+ body_block = new TosaSerializationBasicBlock(
+ body_block_name, std::vector<TosaSerializationOperator *>(),
+ std::vector<TosaSerializationTensor *>(), std::vector<std::string>(),
+ std::vector<std::string>());
+ assert(body_block);
+ block_builder->GetTsh()->GetBlocks().push_back(body_block);
+
+ TosaSerializationBlockBuilder body_block_builder(
+ body_block, block_builder->GetTsh(), &body_region);
+ if (body_block_builder.BuildAllOpsInRegion(body_yields).failed()) {
+ return nullptr;
+ }
+ // The body must yield one value per loop-carried result.
+ if (body_yields.size() != op.getNumResults()) {
+ op.emitOpError("BuildOpWhileLoop: body_region yield.size() doesn't "
+ "match while_loop's output size");
+ return nullptr;
+ }
+
+ // The attribute only stores the block *names*; consumers look the blocks
+ // up in the handler by name.
+ TosaWhileLoopAttribute attribute(cond_block->GetName(),
+ body_block->GetName());
+
+ for (size_t i = 0; i < op.getNumOperands(); i++) {
+ std::string input_name = GetTensorName(op.getOperand(i));
+ input_names.push_back(input_name);
+ }
+
+ for (size_t i = 0; i < op.getNumResults(); i++) {
+ std::string output_name = GetTensorName(op.getResult(i));
+ output_names.push_back(output_name);
+ }
+
+ TosaSerializationOperator *tyop = new TosaSerializationOperator(
+ Op_WHILE_LOOP, Attribute_WhileLoopAttribute, &attribute, QuantInfo_NONE,
+ nullptr, input_names, output_names);
+
+ return tyop;
+}
+
+/* End translating TOSA operator */
+
+// Walk the single basic block of `region`: assign a serialized tensor name
+// to every mlir::Value (inputs, intermediates, outputs), materialize a
+// TosaSerializationTensor for each, then translate every op into a
+// TosaSerializationOperator.  Values yielded/returned by the region are
+// collected into `return_values`.
+mlir::LogicalResult TosaSerializationBlockBuilder::BuildAllOpsInRegion(
+    std::vector<mlir::Value> &return_values) {
+  TosaSerializationOperator *ser_operator = nullptr;
+  TosaSerializationTensor *ser_tensor = nullptr;
+  size_t num_blocks_in_region = 0;
+  // These counters are deliberately static so tensor names stay unique
+  // across every region serialized by this process (cond/body/then/else
+  // blocks share the same pool); they are not reset per call.
+  static int input_tensor_index = 0;
+  static int intermediate_tensor_index = 0;
+  static int output_tensor_index = 0;
+  TosaSerializationOperatorBuilder op_builder(this);
+
+  for (auto &bb : region->getBlocks()) {
+    num_blocks_in_region++;
+
+    // TOSA serialization only supports single-block regions.
+    if (num_blocks_in_region > 1) {
+      llvm::errs() << "Invalid MLIR: multiple blocks in a region\n";
+      return mlir::failure();
+    }
+
+    // We always have one block for each region right now
+    assert(bb.isEntryBlock());
+
+    // Specify block input tensor name
+    for (auto args : bb.getArguments()) {
+      std::string block_input_name =
+          "TosaInput_" + std::to_string(input_tensor_index++);
+      block->GetInputs().push_back(block_input_name);
+      tensor_map[args] = block_input_name;
+    }
+
+    // Build tensor_map: name every op result; terminator operands become
+    // block outputs.  tensor::CastOp is transparent and gets no tensor.
+    for (auto &op : bb) {
+      if (!(llvm::isa<mlir::tosa::YieldOp>(op) ||
+            llvm::isa<mlir::ReturnOp>(op) ||
+            llvm::isa<mlir::tensor::CastOp>(op))) {
+        for (uint32_t i = 0; i < op.getNumResults(); i++) {
+          std::string intermediate_tensor_name =
+              "layer_" + std::to_string(intermediate_tensor_index++);
+          tensor_map[op.getResult(i)] = intermediate_tensor_name;
+        }
+      } else {
+        if (llvm::isa<mlir::tensor::CastOp>(op))
+          continue;
+        // Override return tensor name
+        for (auto val : op.getOperands()) {
+          // Workaround to skip mlir::tensor::CastOp before return.
+          // BUGFIX: getDefiningOp() returns nullptr when `val` is a block
+          // argument returned/yielded directly; guard before dereferencing.
+          mlir::Operation *val_defining_op = val.getDefiningOp();
+          if (val_defining_op &&
+              llvm::isa<mlir::tensor::CastOp>(*val_defining_op))
+            val = val_defining_op->getOperand(0);
+
+          // Sanity check. This mlir::Value should be built in map since graph
+          // is DAG
+          if (tensor_map.find(val) == tensor_map.end()) {
+            llvm::errs() << "ERROR: Can't find built mlir::Value key.\n";
+            return mlir::failure();
+          }
+          std::string output_name =
+              "TosaOutput_" + std::to_string(output_tensor_index++);
+          tensor_map[val] = output_name;
+          block->GetOutputs().push_back(output_name);
+          return_values.push_back(val);
+        }
+      }
+    }
+
+    // Build one serialization tensor per named value.
+    for (auto pair : tensor_map) {
+      ser_tensor = BuildTosaSerializationTensor(pair.first /* val */,
+                                                pair.second /* name */);
+      if (!ser_tensor) {
+        llvm::errs() << "ERROR: Failed to build TosaSerializationTensor\n";
+        return mlir::failure();
+      }
+      block->GetTensors().push_back(ser_tensor);
+    }
+
+    // Build operators, skipping terminators and transparent casts.
+    for (auto &op : bb) {
+      if (llvm::isa<mlir::tosa::YieldOp>(op) || llvm::isa<mlir::ReturnOp>(op) ||
+          llvm::isa<mlir::tensor::CastOp>(op))
+        continue;
+      ser_operator = BuildTosaSerializationOperator(op_builder, op);
+      if (!ser_operator) {
+        llvm::errs() << "ERROR: Failed to build TosaSerializationOperator\n";
+        return mlir::failure();
+      }
+      block->GetOperators().push_back(ser_operator);
+    }
+  }
+
+  return mlir::success();
+}
+
+// Dispatch `op` to the matching templated build<>() method via the X-macro
+// list in operator.def.  Returns a heap-allocated operator the caller owns,
+// or nullptr when the op is unsupported or its translation failed.
+TosaSerializationOperator *
+TosaSerializationBlockBuilder::BuildTosaSerializationOperator(
+    const TosaSerializationOperatorBuilder &op_builder, mlir::Operation &op) {
+  TosaSerializationOperator *target_operator = nullptr;
+
+  // operator.def expands DEF_OPERATOR into an else-if chain; the empty
+  // if (false) anchors the first macro-generated "else if".
+  if (false) {
+  }
+#define DEF_OPERATOR(MLIR_OP)                                                  \
+  else if (llvm::isa<mlir::tosa::MLIR_OP##Op>(op)) {                           \
+    target_operator = op_builder.build<mlir::tosa::MLIR_OP##Op>(op);           \
+  }
+#include "operator.def"
+#undef DEF_OPERATOR
+  else {
+    llvm::errs() << "unsupported tosa operator " << op.getName().getStringRef()
+                 << "\n";
+  }
+
+  if (!target_operator) {
+    llvm::errs() << op.getName().getStringRef()
+                 << " operator hasn't been translated to flatbuffer, skipped\n";
+    return nullptr;
+  }
+
+  // Sanity check the number of inputs/outputs of TOSA dialect matches the
+  // number of TOSA flatbuffer (warn only; not treated as an error).
+  if (op.getNumOperands() != target_operator->GetInputTensorNames().size()) {
+    llvm::errs() << "WARNING. MLIR operator has " << op.getNumOperands()
+                 << " input tensors != Flatbuffer "
+                    "operator has "
+                 << target_operator->GetInputTensorNames().size()
+                 << " input tensors\n";
+  }
+  if (op.getNumResults() != target_operator->GetOutputTensorNames().size()) {
+    llvm::errs() << "WARNING. MLIR operator has " << op.getNumResults()
+                 << " output tensors != Flatbuffer "
+                    "operator has "
+                 << target_operator->GetOutputTensorNames().size()
+                 << " output tensors\n";
+  }
+
+  return target_operator;
+}
+
+// Create a serialization tensor for `val` under `name`.  Returns a
+// heap-allocated tensor the caller owns, or nullptr when the tensor cannot
+// (or need not) be created.
+TosaSerializationTensor *
+TosaSerializationBlockBuilder::BuildTosaSerializationTensor(
+    mlir::Value val, const std::string &name) {
+  // If tensor already created before, bail out.  NOTE(review): returning
+  // nullptr here makes the caller report a hard failure, which contradicts
+  // the "use that tensor directly" intent -- confirm duplicate names are
+  // actually impossible with the per-process name counters.
+  TosaSerializationTensor *ts = block->GetTensorByName(name);
+  if (ts) {
+    return nullptr;
+  }
+
+  // BUGFIX: only ranked tensor types carry a static shape; previously an
+  // unranked or non-tensor type would null-dereference below.
+  mlir::RankedTensorType tensor =
+      val.getType().dyn_cast<mlir::RankedTensorType>();
+  if (!tensor) {
+    llvm::errs() << "ERROR: tensor " << name
+                 << " is not a ranked tensor type\n";
+    return nullptr;
+  }
+
+  std::vector<int32_t> shape(tensor.getShape().begin(),
+                             tensor.getShape().end());
+  DType type = Type2DType(tensor.getElementType());
+
+  // Constant data (if any) is attached elsewhere; start with an empty buffer.
+  ts = new TosaSerializationTensor(name, shape, type, std::vector<uint8_t>());
+
+  return ts;
+}
+
+// Serialize the callable region of `func` into `tsh` as a single basic
+// block named "main".  `tsh` must arrive with an empty block list; on
+// success it owns the newly created block.
+mlir::LogicalResult translate2FlatBuffer(mlir::FuncOp &func,
+                                         TosaSerializationHandler &tsh) {
+  mlir::Region *region = func.getCallableRegion();
+  std::vector<mlir::Value> returned_values;
+
+  if (!region) {
+    llvm::errs() << "Invalid MLIR: doesn't have valid \"main\" region\n";
+    return mlir::failure();
+  }
+
+  if (!tsh.GetBlocks().empty()) {
+    llvm::errs() << "Internal Error: TosaSerializationHandler's block list "
+                    "must be empty\n";
+    return mlir::failure();
+  }
+
+  // Ownership of the block transfers to the handler immediately.
+  TosaSerializationBasicBlock *entry_block = new TosaSerializationBasicBlock(
+      std::string("main"), std::vector<TosaSerializationOperator *>(),
+      std::vector<TosaSerializationTensor *>(), std::vector<std::string>(),
+      std::vector<std::string>());
+  assert(entry_block);
+  tsh.GetBlocks().push_back(entry_block);
+
+  TosaSerializationBlockBuilder builder(entry_block, &tsh, region);
+  if (builder.BuildAllOpsInRegion(returned_values).failed()) {
+    return mlir::failure();
+  }
+
+  // A graph without results is legal but suspicious; warn and continue.
+  if (returned_values.empty()) {
+    llvm::errs() << "Warning: graph doesn't have return values\n";
+  }
+
+  return mlir::success();
+}
+
+// Serialize `func` and write the result to the flatbuffer file named by the
+// tosa_flatbuffer_filename pass option.
+mlir::LogicalResult dumpTosaFlatbuffer(mlir::FuncOp &func) {
+  tosa::TosaSerializationHandler tsh;
+
+  // (Removed unused local `tosa_flatbuffer_directory_fullpath`.)
+  if (translate2FlatBuffer(func, tsh).failed()) {
+    llvm::errs() << "Fail to translate TOSA MLIR to flatbuffer\n";
+    return mlir::failure();
+  }
+
+  // SaveFileTosaFlatbuffer returns nonzero on failure.
+  if (tsh.SaveFileTosaFlatbuffer(tosa_flatbuffer_filename.c_str())) {
+    llvm::errs() << "Fail to save flatbuffer " << tosa_flatbuffer_filename
+                 << "\n";
+    return mlir::failure();
+  }
+  return mlir::success();
+}
+
+// Serialize `func` and write it out as JSON, using the flatbuffer schema
+// file named by the tosa_flatbuffer_schema pass option.
+mlir::LogicalResult dumpTosaJSON(mlir::FuncOp &func) {
+  tosa::TosaSerializationHandler tsh;
+
+  // BUGFIX: std::string::c_str() never returns nullptr, so the previous
+  // `if (!tosa_schema)` check was dead; test the option string itself.
+  if (tosa_flatbuffer_schema.empty()) {
+    llvm::errs() << "Flatbuffer schema not defined\n";
+    return mlir::failure();
+  }
+  const char *tosa_schema = tosa_flatbuffer_schema.c_str();
+
+  if (tsh.LoadFileSchema(tosa_schema)) {
+    llvm::errs() << "Error loading tosa schema file: " << tosa_schema << "\n";
+    return mlir::failure();
+  }
+
+  // (Removed unused local `tosa_flatbuffer_directory_fullpath`.)
+  if (translate2FlatBuffer(func, tsh).failed()) {
+    llvm::errs() << "Fail to translate TOSA MLIR to flatbuffer\n";
+    return mlir::failure();
+  }
+
+  if (tsh.SaveFileJson(tosa_flatbuffer_filename.c_str())) {
+    llvm::errs() << "Fail to save flatbuffer " << tosa_flatbuffer_filename
+                 << "\n";
+    return mlir::failure();
+  }
+
+  return mlir::success();
+}
+
+namespace mlir {
+
+namespace tosa {
+
+namespace {
+
+// Function pass that serializes the current function to a TOSA flatbuffer
+// file (output path comes from the tosa_flatbuffer_filename option).
+class TosaSerialize : public PassWrapper<TosaSerialize, FunctionPass> {
+public:
+ TosaSerialize() = default;
+
+ StringRef getArgument() const final {
+ // This is the argument used to refer to the pass in
+ // the textual format (on the commandline for example).
+ return "tosa-serialize";
+ }
+ StringRef getDescription() const final {
+ // This is a brief description of the pass.
+ return "Run the TOSA serialization (flatbuffer generation) pass";
+ }
+
+ // Serialize the function; any failure aborts the pass pipeline.
+ void runOnFunction() override {
+ auto function = getFunction();
+
+ if (dumpTosaFlatbuffer(function).failed()) {
+ llvm::errs() << "Failed to generate TOSA flatbuffer...\n";
+ return signalPassFailure();
+ }
+ }
+};
+
+// Function pass that serializes the current function to JSON via the TOSA
+// flatbuffer schema (schema path comes from the tosa_flatbuffer_schema
+// option, output path from tosa_flatbuffer_filename).
+class TosaSerializeJSON : public PassWrapper<TosaSerializeJSON, FunctionPass> {
+public:
+ TosaSerializeJSON() = default;
+
+ StringRef getArgument() const final {
+ // This is the argument used to refer to the pass in
+ // the textual format (on the commandline for example).
+ return "tosa-serialize-json";
+ }
+ StringRef getDescription() const final {
+ // This is a brief description of the pass.
+ return "Run the TOSA serialization (JSON generation) pass";
+ }
+
+ // Serialize the function; any failure aborts the pass pipeline.
+ void runOnFunction() override {
+ auto function = getFunction();
+
+ if (dumpTosaJSON(function).failed()) {
+ llvm::errs() << "Failed to generate TOSA JSON...\n";
+ return signalPassFailure();
+ }
+ }
+};
+
+} // anonymous namespace
+
+// Creates an instance of the TOSA flatbuffer generation pass
+std::unique_ptr<OperationPass<FuncOp>> createTosaSerializePass() {
+ return std::make_unique<TosaSerialize>();
+}
+
+// Creates an instance of the TOSA JSON generation pass
+std::unique_ptr<OperationPass<FuncOp>> createTosaSerializeJSONPass() {
+ return std::make_unique<TosaSerializeJSON>();
+}
+
+// Static registrations make both passes available to mlir-opt-style drivers
+// at load time.
+static PassRegistration<TosaSerialize> pass([] {
+ return createTosaSerializePass();
+});
+
+static PassRegistration<TosaSerializeJSON> passJSON([] {
+ return createTosaSerializeJSONPass();
+});
+
+} // namespace tosa
+} // namespace mlir
diff --git a/third_party/serialization_lib b/third_party/serialization_lib
new file mode 160000
+Subproject 545a508429afe1d22760563d252839e13ecd12a