From 80a022fd103b26a03a04e0565c4d263f73d950b8 Mon Sep 17 00:00:00 2001 From: Kevin Cheng Date: Mon, 15 Nov 2021 17:07:37 -0800 Subject: First commit of tosa serialize passes Signed-off-by: Kevin Cheng Change-Id: I1551017706f6e8af604792f48cdeb49b4da7ef0d --- .clang-format | 1 + .gitmodules | 3 + CMakeLists.txt | 33 + LICENSE.txt | 219 +++++ README.md | 49 ++ include/SerializationPasses.h | 34 + include/SerializationPasses.td | 21 + include/operator.def | 117 +++ src/TosaSerialize.cpp | 1764 ++++++++++++++++++++++++++++++++++++++++ third_party/serialization_lib | 1 + 10 files changed, 2242 insertions(+) create mode 100644 .clang-format create mode 100644 .gitmodules create mode 100644 CMakeLists.txt create mode 100644 LICENSE.txt create mode 100644 README.md create mode 100644 include/SerializationPasses.h create mode 100644 include/SerializationPasses.td create mode 100644 include/operator.def create mode 100644 src/TosaSerialize.cpp create mode 160000 third_party/serialization_lib diff --git a/.clang-format b/.clang-format new file mode 100644 index 0000000..9b3aa8b --- /dev/null +++ b/.clang-format @@ -0,0 +1 @@ +BasedOnStyle: LLVM diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 0000000..5d7eb1e --- /dev/null +++ b/.gitmodules @@ -0,0 +1,3 @@ +[submodule "third_party/serialization_lib"] + path = third_party/serialization_lib + url = https://review.mlplatform.org/tosa/serialization_lib diff --git a/CMakeLists.txt b/CMakeLists.txt new file mode 100644 index 0000000..968d73b --- /dev/null +++ b/CMakeLists.txt @@ -0,0 +1,33 @@ +# TOSA serialization MLIR passes + +cmake_minimum_required(VERSION 3.13.4) +project(MlirTosaPasses) + +set(CMAKE_CXX_STANDARD 14 CACHE STRING "C++ standard to conform to") +set(CMAKE_CXX_STANDARD_REQUIRED YES) + +set(CMAKE_VERBOSE_MAKEFILE ON) + +# TOSA MLIR->Flatbuffers serialization pass + +include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include) 
+include_directories(${PROJECT_SOURCE_DIR}/third_party/serialization_lib/include) +include_directories(${PROJECT_SOURCE_DIR}/third_party/serialization_lib/third_party/flatbuffers/include) + +set(LLVM_TARGET_DEFINITIONS include/SerializationPasses.td) +mlir_tablegen(include/SerializationPasses.h.inc -gen-pass-decls -name TosaSerialization) +add_public_tablegen_target(tosa_serialization_passes_inc_gen) + +# Compile the TOSA serialization_lib +add_subdirectory(third_party/serialization_lib) + +add_mlir_library(tosa_serialize + src/TosaSerialize.cpp + + DEPENDS + mlir-headers + tosa_serialization_passes_inc_gen + + LINK_LIBS PRIVATE + tosa_serialization_lib +) diff --git a/LICENSE.txt b/LICENSE.txt new file mode 100644 index 0000000..45d76b4 --- /dev/null +++ b/LICENSE.txt @@ -0,0 +1,219 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. 
+ + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." 
+ + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. 
You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. 
+ + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. 
In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. 
+ + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + + +---- LLVM Exceptions to the Apache 2.0 License ---- + +As an exception, if, as a result of your compiling your source code, portions +of this Software are embedded into an Object form of such source code, you +may redistribute such embedded portions in such Object form without complying +with the conditions of Sections 4(a), 4(b) and 4(d) of the License. + +In addition, if you combine or link compiled forms of this Software with +software that is licensed under the GPLv2 ("Combined Software") and if a +court of competent jurisdiction determines that the patent provision (Section +3), the indemnity provision (Section 9) or other Section of the License +conflicts with the conditions of the GPLv2, you may retroactively and +prospectively choose to deem waived or otherwise exclude such Section(s) of +the License, but only in their entirety and only with respect to the Combined +Software. + diff --git a/README.md b/README.md new file mode 100644 index 0000000..46039b7 --- /dev/null +++ b/README.md @@ -0,0 +1,49 @@ +TOSA MLIR Translator +========================== + +# Introduction + +The *TOSA MLIR Translator* repository implements translators between the TOSA MLIR +dialect and serialized representations. + +The current implementation supports serialization from MLIR form to flatbuffers. +A deserializer from flatbuffers to MLIR form is in development. 
+ +# Dependencies + +##TOSA serialization library + +The library includes a FlatBuffers schema and a C++ API for reading and writing a TOSA +graph as a flatbuffer. + +# Compiling + +This repository does not currently build standalone. It must be included within another +MLIR repository with a pass manager registering the passes implemented within this +repository. + +The included CMake rules can be used to add this repository as a submodule. +The include/SerializationPasses.h enables MLIR pass registration inclusion. + +If target "tosa_serialize" is linked correctly, you should able to see "--tosa-serialize" +and "--tosa-serialize-json" options available in your MLIR pass manager/MLIR optimizer. + +# Usage + +To serialize a TOSA MLIR graph to TOSA flatbuffer binary file: + + \ --tosa-serialize \ \ + --tosa-flatbuffer-filename \ + +To serialize a TOSA MLIR graph to TOSA flatbuffer JSON file: + + \ --tosa-serialize \ \ + --tosa-flatbuffer-schema \ \ + --tosa-flatbuffer-filename \ + + where \ is provided within the serialization library + submodule in third_party/serialization_lib/schema/tosa.fbs + +# License + +The *TOSA MLIR Translator* is licensed under Apache-2.0 with LLVM Exceptions. diff --git a/include/SerializationPasses.h b/include/SerializationPasses.h new file mode 100644 index 0000000..0991f87 --- /dev/null +++ b/include/SerializationPasses.h @@ -0,0 +1,34 @@ + +// Copyright (c) 2020-2021, ARM Limited. +// +// Licensed under the Apache License, Version 2.0 with LLVM Exceptions +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// https://llvm.org/LICENSE.txt +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef INCLUDE_SERIALIZATION_PASSES_H +#define INCLUDE_SERIALIZATION_PASSES_H + +#include + +#include "mlir/Pass/Pass.h" // from @llvm-project + +namespace mlir { +namespace tosa { + +std::unique_ptr> createTosaSerializePass(); + +#define GEN_PASS_REGISTRATION +#include "SerializationPasses.h.inc" + +} // namespace tosa +} // namespace mlir + +#endif // INCLUDE_SERIALIZATION_PASSES_H diff --git a/include/SerializationPasses.td b/include/SerializationPasses.td new file mode 100644 index 0000000..6df996e --- /dev/null +++ b/include/SerializationPasses.td @@ -0,0 +1,21 @@ +// Copyright (c) 2020-2021, ARM Limited. +// +// Licensed under the Apache License, Version 2.0 with LLVM Exceptions +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// https://llvm.org/LICENSE.txt +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +include "mlir/Pass/PassBase.td" + +def TosaSerializationPass : Pass<"tosa-serialize", "FuncOp"> { + let summary = "Generate TOSA flatbuffer serialized form"; + let constructor = "createTosaSerializePass()"; +} + diff --git a/include/operator.def b/include/operator.def new file mode 100644 index 0000000..85bb5c9 --- /dev/null +++ b/include/operator.def @@ -0,0 +1,117 @@ + +// Copyright (c) 2020-2021, ARM Limited. +// +// Licensed under the Apache License, Version 2.0 with LLVM Exceptions +// (the "License"); you may not use this file except in compliance with +// the License. 
You may obtain a copy of the License at +// +// https://llvm.org/LICENSE.txt +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +/* + Syntax: + DEF_OPERATOR(MLIR_OP) + + Description: + MLIR_OP: the operator class type, must match mlir/include/mlir/Dialect/Tosa/IR/tosa_ops.td in llvm-project repo +*/ + +/* tensor operators */ +DEF_OPERATOR(ArgMax) +DEF_OPERATOR(AvgPool2d) +DEF_OPERATOR(Conv2D) +DEF_OPERATOR(Conv3D) +DEF_OPERATOR(DepthwiseConv2D) +DEF_OPERATOR(FullyConnected) +DEF_OPERATOR(MatMul) +DEF_OPERATOR(MaxPool2d) +DEF_OPERATOR(TransposeConv2D) + +/* activation */ +DEF_OPERATOR(Clamp) +DEF_OPERATOR(Sigmoid) +DEF_OPERATOR(Tanh) + +/* elementwise - binary */ +DEF_OPERATOR(Add) +DEF_OPERATOR(ArithmeticRightShift) +DEF_OPERATOR(BitwiseAnd) +DEF_OPERATOR(BitwiseOr) +DEF_OPERATOR(BitwiseXor) +DEF_OPERATOR(Div) +DEF_OPERATOR(LogicalAnd) +DEF_OPERATOR(LogicalLeftShift) +DEF_OPERATOR(LogicalRightShift) +DEF_OPERATOR(LogicalOr) +DEF_OPERATOR(LogicalXor) +DEF_OPERATOR(Maximum) +DEF_OPERATOR(Minimum) +DEF_OPERATOR(Mul) +DEF_OPERATOR(Pow) +DEF_OPERATOR(Sub) +DEF_OPERATOR(Table) + +/* elementwise - unary */ +DEF_OPERATOR(Abs) +DEF_OPERATOR(BitwiseNot) +DEF_OPERATOR(Ceil) +DEF_OPERATOR(Clz) +DEF_OPERATOR(Exp) +DEF_OPERATOR(Floor) +DEF_OPERATOR(Log) +DEF_OPERATOR(LogicalNot) +DEF_OPERATOR(Negate) +DEF_OPERATOR(Reciprocal) +DEF_OPERATOR(Rsqrt) + +/* elementwise - ternary */ +DEF_OPERATOR(Select) + +/* logical */ +DEF_OPERATOR(Equal) +DEF_OPERATOR(Greater) +DEF_OPERATOR(GreaterEqual) + +/* reduction */ +DEF_OPERATOR(ReduceAny) +DEF_OPERATOR(ReduceAll) +DEF_OPERATOR(ReduceMax) +DEF_OPERATOR(ReduceMin) +DEF_OPERATOR(ReduceProd) +DEF_OPERATOR(ReduceSum) + +/* memory operation */ +DEF_OPERATOR(Concat) 
+DEF_OPERATOR(Pad) +DEF_OPERATOR(Reshape) +DEF_OPERATOR(Reverse) +DEF_OPERATOR(Slice) +DEF_OPERATOR(Tile) +DEF_OPERATOR(Transpose) + +/* gather/scatter */ +DEF_OPERATOR(Gather) +DEF_OPERATOR(Scatter) + +/* image */ +DEF_OPERATOR(Resize) + +/* quantization */ +DEF_OPERATOR(Cast) +DEF_OPERATOR(Rescale) + +/* data nodes */ +DEF_OPERATOR(Const) +DEF_OPERATOR(Identity) + +/* custom operations */ +DEF_OPERATOR(Custom) + +/* control flow operators */ +DEF_OPERATOR(If) +DEF_OPERATOR(While) diff --git a/src/TosaSerialize.cpp b/src/TosaSerialize.cpp new file mode 100644 index 0000000..5699ffe --- /dev/null +++ b/src/TosaSerialize.cpp @@ -0,0 +1,1764 @@ + +// Copyright (c) 2020-2021, ARM Limited. +// +// Licensed under the Apache License, Version 2.0 with LLVM Exceptions +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// https://llvm.org/LICENSE.txt +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// TOSA flatbuffer generation + +#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/Dialect/Tosa/IR/TosaOps.h" // from @llvm-project +#include "mlir/IR/Matchers.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "tosa_serialization_handler.h" +#include +#include + +// The namespace might be confusing here. 
We have mlir::tosa:: defined in MLIR +// and tosa:: defined in serialization library +// TODO: align the namespace +using namespace tosa; + +namespace cl = llvm::cl; + +llvm::cl::opt tosa_flatbuffer_filename( + "tosa-flatbuffer-filename", llvm::cl::desc(""), + llvm::cl::init("tosa_dump.tosa"), llvm::cl::value_desc("filename")); + +llvm::cl::opt tosa_flatbuffer_schema( + "tosa-flatbuffer-schema", llvm::cl::desc(""), + llvm::cl::init(""), llvm::cl::value_desc("filename")); + +// Specialize mlir::Value for std::hash and std::equal_to to be able to +// build std::unordered_map +namespace std { + +template <> struct hash { + std::size_t operator()(const mlir::Value &val) const { + return static_cast(mlir::hash_value(val)); + } +}; + +template <> struct equal_to { + bool operator()(const mlir::Value &lhs, const mlir::Value &rhs) const { + return (lhs == rhs); + } +}; + +} // namespace std + +ResizeMode ResizeModeStr2Enum(const std::string &mode_str) { + if (mode_str == "NEAREST_NEIGHBOR") + return ResizeMode_NEAREST; + else if (mode_str == "BILINEAR") + return ResizeMode_BILINEAR; + else + return ResizeMode_UNKNOWN; +} + +DType Type2DType(mlir::Type element_type) { + if (element_type.isF64() || element_type.isF32() || element_type.isF16() || + element_type.isBF16()) { + return DType_FLOAT; + } else if (element_type.isUnsignedInteger(8)) { + return DType_UINT8; + } else if (element_type.isInteger(4)) { + return DType_INT8; + } else if (element_type.isInteger(8)) { + return DType_INT8; + } else if (element_type.isInteger(16)) { + return DType_INT16; + } else if (element_type.isInteger(32)) { + return DType_INT32; + } else if (element_type.isInteger(48)) { + return DType_INT48; + } + // boolean in MLIR treated as integer with bitwidth 1 + else if (element_type.isInteger(1)) { + return DType_BOOL; + } + return DType_UNKNOWN; +} + +int GetQuantizedParameter(mlir::Type type, std::vector &scale, + std::vector &zeropoint, + int32_t &quantized_dimension, int64_t &quant_min, + 
int64_t &quant_max) { + if (auto qtype = type.dyn_cast()) { + scale.push_back(qtype.getScale()); + zeropoint.push_back(qtype.getZeroPoint()); + quantized_dimension = 0; + + quant_min = qtype.getStorageTypeMin(); + quant_max = qtype.getStorageTypeMax(); + } else if (auto qtype = + type.dyn_cast()) { + scale.assign(qtype.getScales().begin(), qtype.getScales().end()); + zeropoint.assign(qtype.getZeroPoints().begin(), + qtype.getZeroPoints().end()); + quantized_dimension = qtype.getQuantizedDimension(); + + quant_min = qtype.getStorageTypeMin(); + quant_max = qtype.getStorageTypeMax(); + } else { + return 1; + } + + return 0; +} + +TosaQuantInfoBase * +GetUnaryQuantInfo(mlir::tosa::UnaryOpQuantizationAttr quant_info) { + int32_t input_zp = quant_info.input_zp().getInt(); + int32_t output_zp = quant_info.output_zp().getInt(); + + TosaQuantInfoBase *qinfo = new TosaUnaryQuantInfo(input_zp, output_zp); + + return qinfo; +} + +TosaQuantInfoBase * +GetConvQuantInfo(mlir::tosa::ConvOpQuantizationAttr quant_info) { + int32_t input_zp = quant_info.input_zp().getInt(); + int32_t weight_zp = quant_info.weight_zp().getInt(); + + TosaQuantInfoBase *qinfo = new TosaConvQuantInfo(input_zp, weight_zp); + + return qinfo; +} + +TosaQuantInfoBase * +GetPadQuantInfo(mlir::tosa::PadOpQuantizationAttr quant_info) { + int32_t input_zp = quant_info.input_zp().getInt(); + + TosaQuantInfoBase *qinfo = new TosaPadQuantInfo(input_zp); + + return qinfo; +} + +TosaQuantInfoBase * +GetMatMulQuantInfo(mlir::tosa::MatMulOpQuantizationAttr quant_info) { + int32_t a_zp = quant_info.a_zp().getInt(); + int32_t b_zp = quant_info.b_zp().getInt(); + + TosaQuantInfoBase *qinfo = new TosaMatMulQuantInfo(a_zp, b_zp); + + return qinfo; +} + +class TosaSerializationBlockBuilder; + +class TosaSerializationOperatorBuilder { +public: + TosaSerializationOperatorBuilder( + TosaSerializationBlockBuilder *_block_builder) + : block_builder(_block_builder) {} + template + TosaSerializationOperator *build(mlir::Operation 
&op) const; + +private: + std::string GetTensorName(mlir::Value val) const; + TosaSerializationOperator *BuildPoolOpFromMlirOp(mlir::Operation &op, + Op opcode) const; + TosaSerializationOperator *BuildEwiseBinaryOpFromMlirOp(mlir::Operation &op, + Op opcode) const; + TosaSerializationOperator *BuildEwiseUnaryOpFromMlirOp(mlir::Operation &op, + Op opcode) const; + TosaSerializationOperator *BuildReductionOpFromMlirOp(mlir::Operation &op, + Op opcode) const; + TosaSerializationBlockBuilder *block_builder; +}; + +// This builder assumes each region only has only one block +class TosaSerializationBlockBuilder { +public: + friend class TosaSerializationOperatorBuilder; + TosaSerializationBlockBuilder(TosaSerializationBasicBlock *_block, + TosaSerializationHandler *_tsh, + mlir::Region *_region) + : block(_block), tsh(_tsh), region(_region) {} + + mlir::LogicalResult + BuildAllOpsInRegion(std::vector &return_values); + TosaSerializationBasicBlock *GetBlock() { return block; } + TosaSerializationHandler *GetTsh() { return tsh; } + +private: + TosaSerializationOperator *BuildTosaSerializationOperator( + const TosaSerializationOperatorBuilder &op_builder, mlir::Operation &op); + TosaSerializationTensor * + BuildTosaSerializationTensor(mlir::Value val, const std::string &name); + + TosaSerializationBasicBlock *block; + TosaSerializationHandler *tsh; + mlir::Region *region; + std::unordered_map tensor_map; +}; + +std::string +TosaSerializationOperatorBuilder::GetTensorName(mlir::Value val) const { + if (block_builder->tensor_map.find(val) == block_builder->tensor_map.end()) { + llvm::errs() << "ERROR: Failed to get mlir::Value from tensor_map"; + assert(0); + } + return block_builder->tensor_map[val]; +} + +// Main template to catch unimplemented translation. 
+template <class T>
+TosaSerializationOperator *
+TosaSerializationOperatorBuilder::build(mlir::Operation &op) const {
+  llvm::errs() << "Translation of operator " << op.getName().getStringRef()
+               << " is not implemented yet\n";
+  return nullptr;
+}
+
+/* Start translating TOSA operator */
+
+#define ASSERT_VECTOR_LENGTH(VECTOR, LENGTH)                                   \
+  if (VECTOR.size() != LENGTH) {                                               \
+    std::string msg;                                                           \
+    msg = std::string(#VECTOR) + " is [" + std::to_string(VECTOR.size()) +     \
+          "], but expected to be [" + std::to_string(LENGTH) + "]\n";          \
+    op.emitOpError(msg.c_str());                                               \
+    return nullptr;                                                            \
+  }
+
+TosaSerializationOperator *
+TosaSerializationOperatorBuilder::BuildPoolOpFromMlirOp(mlir::Operation &op,
+                                                        Op opcode) const {
+  std::vector<int> pad, stride, kernel;
+
+  auto pad_attr = op.getAttr("pad").dyn_cast<mlir::ArrayAttr>().getValue();
+  for (auto &int_attr : pad_attr) {
+    pad.push_back(int_attr.dyn_cast<mlir::IntegerAttr>().getInt());
+  }
+  ASSERT_VECTOR_LENGTH(pad, 4);
+
+  auto stride_attr =
+      op.getAttr("stride").dyn_cast<mlir::ArrayAttr>().getValue();
+  for (auto &int_attr : stride_attr) {
+    stride.push_back(int_attr.dyn_cast<mlir::IntegerAttr>().getInt());
+  }
+  ASSERT_VECTOR_LENGTH(stride, 2);
+
+  auto kernel_attr =
+      op.getAttr("kernel").dyn_cast<mlir::ArrayAttr>().getValue();
+  for (auto &int_attr : kernel_attr) {
+    kernel.push_back(int_attr.dyn_cast<mlir::IntegerAttr>().getInt());
+  }
+  ASSERT_VECTOR_LENGTH(kernel, 2);
+
+  std::string input_name = GetTensorName(op.getOperand(0));
+  std::string output_name = GetTensorName(op.getResult(0));
+
+  TosaPoolAttribute attribute(pad, kernel, stride);
+  auto quant_info = op.getAttrOfType<mlir::tosa::UnaryOpQuantizationAttr>(
+      "quantization_info");
+  QuantInfo qinfo_type;
+  TosaQuantInfoBase *qinfo;
+  if (quant_info) {
+    qinfo_type = QuantInfo_UnaryQuantInfo;
+    qinfo = GetUnaryQuantInfo(quant_info);
+  } else {
+    qinfo_type = QuantInfo_NONE;
+    qinfo = new TosaNoneQuantInfo();
+  }
+
+  TosaSerializationOperator *tyop = new TosaSerializationOperator(
+      opcode, Attribute_PoolAttribute, &attribute, qinfo_type, qinfo,
+      std::vector<std::string>{input_name},
+      std::vector<std::string>{output_name});
+
+  delete qinfo;
+
+  return tyop;
+}
+
+TosaSerializationOperator *
+TosaSerializationOperatorBuilder::BuildEwiseBinaryOpFromMlirOp(
+    mlir::Operation &op, Op opcode) const {
+  std::string input0_name = GetTensorName(op.getOperand(0));
+  std::string input1_name = GetTensorName(op.getOperand(1));
+  std::string output_name = GetTensorName(op.getResult(0));
+
+  TosaSerializationOperator *tyop = new TosaSerializationOperator(
+      opcode, Attribute_NONE, nullptr, QuantInfo_NONE, nullptr,
+      std::vector<std::string>{input0_name, input1_name},
+      std::vector<std::string>{output_name});
+
+  return tyop;
+}
+
+TosaSerializationOperator *
+TosaSerializationOperatorBuilder::BuildEwiseUnaryOpFromMlirOp(
+    mlir::Operation &op, Op opcode) const {
+  std::string input_name = GetTensorName(op.getOperand(0));
+  std::string output_name = GetTensorName(op.getResult(0));
+
+  TosaSerializationOperator *tyop = new TosaSerializationOperator(
+      opcode, Attribute_NONE, nullptr, QuantInfo_NONE, nullptr,
+      std::vector<std::string>{input_name},
+      std::vector<std::string>{output_name});
+
+  return tyop;
+}
+
+TosaSerializationOperator *
+TosaSerializationOperatorBuilder::BuildReductionOpFromMlirOp(
+    mlir::Operation &op, Op opcode) const {
+  std::string input_name = GetTensorName(op.getOperand(0));
+  std::string output_name = GetTensorName(op.getResult(0));
+
+  int32_t axis = op.getAttr("axis").dyn_cast<mlir::IntegerAttr>().getInt();
+  TosaAxisAttribute attribute(axis);
+
+  TosaSerializationOperator *tyop = new TosaSerializationOperator(
+      opcode, Attribute_AxisAttribute, &attribute, QuantInfo_NONE, nullptr,
+      std::vector<std::string>{input_name},
+      std::vector<std::string>{output_name});
+
+  return tyop;
+}
+
+#define BUILD_OP_POOL2D(MLIR_OP_NAME, SCHEMA_OP_NAME)                          \
+  template <>                                                                  \
+  TosaSerializationOperator *                                                  \
+  TosaSerializationOperatorBuilder::build<mlir::tosa::MLIR_OP_NAME##Op>(       \
+      mlir::Operation & op) const {                                            \
+    return BuildPoolOpFromMlirOp(op, Op_##SCHEMA_OP_NAME);                     \
+  }
+
+#define BUILD_OP_ELEMENTWISE_BINARY(MLIR_OP_NAME, SCHEMA_OP_NAME)              \
+  template <>                                                                  \
+  TosaSerializationOperator *                                                  \
+  TosaSerializationOperatorBuilder::build<mlir::tosa::MLIR_OP_NAME##Op>(       \
+      mlir::Operation & op) const {                                            \
+    return BuildEwiseBinaryOpFromMlirOp(op, Op_##SCHEMA_OP_NAME);              \
+  }
+
+#define BUILD_OP_ELEMENTWISE_UNARY(MLIR_OP_NAME, SCHEMA_OP_NAME)               \
+  template <>                                                                  \
+  TosaSerializationOperator *                                                  \
+  TosaSerializationOperatorBuilder::build<mlir::tosa::MLIR_OP_NAME##Op>(       \
+      mlir::Operation & op) const {                                            \
+    return BuildEwiseUnaryOpFromMlirOp(op, Op_##SCHEMA_OP_NAME);               \
+  }
+
+#define BUILD_OP_REDUCTION(MLIR_OP_NAME, SCHEMA_OP_NAME)                       \
+  template <>                                                                  \
+  TosaSerializationOperator *                                                  \
+  TosaSerializationOperatorBuilder::build<mlir::tosa::MLIR_OP_NAME##Op>(       \
+      mlir::Operation & op) const {                                            \
+    return BuildReductionOpFromMlirOp(op, Op_##SCHEMA_OP_NAME);                \
+  }
+
+BUILD_OP_POOL2D(MaxPool2d, MAX_POOL2D)
+BUILD_OP_POOL2D(AvgPool2d, AVG_POOL2D)
+
+BUILD_OP_ELEMENTWISE_BINARY(Add, ADD)
+BUILD_OP_ELEMENTWISE_BINARY(BitwiseAnd, BITWISE_AND)
+BUILD_OP_ELEMENTWISE_BINARY(BitwiseXor, BITWISE_XOR)
+BUILD_OP_ELEMENTWISE_BINARY(BitwiseOr, BITWISE_OR)
+BUILD_OP_ELEMENTWISE_BINARY(Div, INTDIV)
+BUILD_OP_ELEMENTWISE_BINARY(LogicalAnd, LOGICAL_AND)
+BUILD_OP_ELEMENTWISE_BINARY(LogicalLeftShift, LOGICAL_LEFT_SHIFT)
+BUILD_OP_ELEMENTWISE_BINARY(LogicalRightShift, LOGICAL_RIGHT_SHIFT)
+BUILD_OP_ELEMENTWISE_BINARY(LogicalOr, LOGICAL_OR)
+BUILD_OP_ELEMENTWISE_BINARY(LogicalXor, LOGICAL_XOR)
+BUILD_OP_ELEMENTWISE_BINARY(Maximum, MAXIMUM)
+BUILD_OP_ELEMENTWISE_BINARY(Minimum, MINIMUM)
+BUILD_OP_ELEMENTWISE_BINARY(Pow, POW)
+BUILD_OP_ELEMENTWISE_BINARY(Sub, SUB)
+
+BUILD_OP_ELEMENTWISE_UNARY(Abs, ABS)
+BUILD_OP_ELEMENTWISE_UNARY(BitwiseNot, BITWISE_NOT)
+BUILD_OP_ELEMENTWISE_UNARY(Ceil, CEIL)
+BUILD_OP_ELEMENTWISE_UNARY(Clz, CLZ)
+BUILD_OP_ELEMENTWISE_UNARY(Exp, EXP)
+BUILD_OP_ELEMENTWISE_UNARY(Floor, FLOOR)
+BUILD_OP_ELEMENTWISE_UNARY(Log, LOG)
+BUILD_OP_ELEMENTWISE_UNARY(LogicalNot, LOGICAL_NOT)
+BUILD_OP_ELEMENTWISE_UNARY(Reciprocal, RECIPROCAL)
+BUILD_OP_ELEMENTWISE_UNARY(Rsqrt, RSQRT)
+
+BUILD_OP_REDUCTION(ReduceAny, REDUCE_ANY)
+BUILD_OP_REDUCTION(ReduceAll, REDUCE_ALL)
+BUILD_OP_REDUCTION(ReduceMax, REDUCE_MAX) +BUILD_OP_REDUCTION(ReduceMin, REDUCE_MIN) +BUILD_OP_REDUCTION(ReduceProd, REDUCE_PRODUCT) +BUILD_OP_REDUCTION(ReduceSum, REDUCE_SUM) + +BUILD_OP_ELEMENTWISE_BINARY(Equal, EQUAL) +BUILD_OP_ELEMENTWISE_BINARY(Greater, GREATER) +BUILD_OP_ELEMENTWISE_BINARY(GreaterEqual, GREATER_EQUAL) + +BUILD_OP_ELEMENTWISE_UNARY(Sigmoid, SIGMOID) +BUILD_OP_ELEMENTWISE_UNARY(Tanh, TANH) +BUILD_OP_ELEMENTWISE_UNARY(Identity, IDENTITY) +BUILD_OP_ELEMENTWISE_UNARY(Cast, CAST) + +template <> +TosaSerializationOperator * +TosaSerializationOperatorBuilder::build( + mlir::Operation &op) const { + std::string output_name = GetTensorName(op.getResult(0)); + TosaSerializationTensor *ts = + block_builder->GetBlock()->GetTensorByName(output_name); + if (!ts) { + op.emitOpError( + "ERROR: serialization tensor must be built before building operator"); + return nullptr; + } + +#if 0 + // Gracefully handle constants of "constant unit" type which have no value + // by creating a numpy value of 0. 
+ auto unit_val = op.getAttr(llvm::StringRef("value")).dyn_cast(); + + if (unit_val) + { + std::vector data = { 0.0 }; + type = DType_FLOAT; + TosaSerializationHandler::ConvertF32toU8(data, u8_data); + } +#endif + + // Update tensor.data array with Const value attribute + std::vector u8_data; + DType type = ts->GetDtype(); + if (type == DType_FLOAT) { + std::vector data; + auto dense_attr = op.getAttr(llvm::StringRef("value")) + .dyn_cast(); + auto val_attr = + op.getAttr(llvm::StringRef("value")).dyn_cast(); + + if (dense_attr) { + for (auto val : dense_attr.getValues()) { + data.push_back(val); + } + } else if (val_attr) { + data.push_back((float)val_attr.getValueAsDouble()); + } else { + op.emitOpError("Unknown const attribute"); + return nullptr; + } + TosaSerializationHandler::ConvertF32toU8(data, u8_data); + } else if (type == DType_INT8) { + std::vector data; + auto dense_attr = op.getAttr(llvm::StringRef("value")) + .dyn_cast(); + auto val_attr = + op.getAttr(llvm::StringRef("value")).dyn_cast(); + + if (dense_attr) { + for (auto val : dense_attr.getValues()) { + data.push_back(val); + } + } else if (val_attr) { + data.push_back(val_attr.getInt()); + } else { + op.emitOpError("Unknown const attribute"); + return nullptr; + } + TosaSerializationHandler::ConvertI8toU8(data, u8_data); + } else if (type == DType_INT16) { + std::vector data; + auto dense_attr = op.getAttr(llvm::StringRef("value")) + .dyn_cast(); + auto val_attr = + op.getAttr(llvm::StringRef("value")).dyn_cast(); + + if (dense_attr) { + for (auto val : dense_attr.getValues()) { + data.push_back(val); + } + } else if (val_attr) { + data.push_back(val_attr.getInt()); + } else { + op.emitOpError("Unknown const attribute"); + return nullptr; + } + TosaSerializationHandler::ConvertI16toU8(data, u8_data); + } else if (type == DType_INT32) { + std::vector data; + auto dense_attr = op.getAttr(llvm::StringRef("value")) + .dyn_cast(); + auto val_attr = + op.getAttr(llvm::StringRef("value")).dyn_cast(); + 
+ if (dense_attr) { + for (auto val : dense_attr.getValues()) { + data.push_back(val); + } + } else if (val_attr) { + data.push_back(val_attr.getInt()); + } else { + op.emitOpError("Unknown const attribute"); + return nullptr; + } + TosaSerializationHandler::ConvertI32toU8(data, u8_data); + } else if (type == DType_INT48) { + std::vector data; + auto dense_attr = op.getAttr(llvm::StringRef("value")) + .dyn_cast(); + auto val_attr = + op.getAttr(llvm::StringRef("value")).dyn_cast(); + + if (dense_attr) { + for (auto val : dense_attr.getValues()) { + data.push_back(val); + } + } else if (val_attr) { + data.push_back(val_attr.getInt()); + } else { + op.emitOpError("Unknown const attribute"); + return nullptr; + } + TosaSerializationHandler::ConvertI48toU8(data, u8_data); + } else if (type == DType_BOOL) { + std::vector data; + + auto dense_attr = op.getAttr(llvm::StringRef("value")) + .dyn_cast(); + auto val_attr = + op.getAttr(llvm::StringRef("value")).dyn_cast(); + + if (dense_attr) { + for (auto val : dense_attr.getValues()) { + data.push_back(val); + } + } else if (val_attr) { + data.push_back(val_attr.getValue()); + } else { + op.emitOpError("Unknown const attribute"); + return nullptr; + } + + TosaSerializationHandler::ConvertBooltoU8(data, u8_data); + } else { + op.emitOpError("Unknown element type of const attribute"); + return nullptr; + } + ts->SetData(u8_data); + + TosaSerializationOperator *tyop = new TosaSerializationOperator( + Op_CONST, Attribute_NONE, nullptr, QuantInfo_NONE, nullptr, + std::vector{}, std::vector{output_name}); + + return tyop; +} + +template <> +TosaSerializationOperator * +TosaSerializationOperatorBuilder::build( + mlir::Operation &op) const { + std::vector pad, stride, dilation; + + auto pad_attr = op.getAttr("pad").dyn_cast().getValue(); + for (auto &int_attr : pad_attr) { + pad.push_back(int_attr.dyn_cast().getInt()); + } + ASSERT_VECTOR_LENGTH(pad, 4); + + auto stride_attr = + op.getAttr("stride").dyn_cast().getValue(); + for 
(auto &int_attr : stride_attr) { + stride.push_back(int_attr.dyn_cast().getInt()); + } + ASSERT_VECTOR_LENGTH(stride, 2); + + auto dilation_attr = + op.getAttr("dilation").dyn_cast().getValue(); + for (auto &int_attr : dilation_attr) { + dilation.push_back(int_attr.dyn_cast().getInt()); + } + ASSERT_VECTOR_LENGTH(dilation, 2); + + std::string input0_name = GetTensorName(op.getOperand(0)); + std::string input1_name = GetTensorName(op.getOperand(1)); + std::string input2_name = GetTensorName(op.getOperand(2)); + std::string output_name = GetTensorName(op.getResult(0)); + + TosaConvAttribute attribute(pad, stride, dilation); + + auto quant_info = + op.getAttrOfType("quantization_info"); + QuantInfo qinfo_type; + TosaQuantInfoBase *qinfo; + if (quant_info) { + qinfo_type = QuantInfo_ConvQuantInfo; + qinfo = GetConvQuantInfo(quant_info); + } else { + qinfo_type = QuantInfo_NONE; + qinfo = new TosaNoneQuantInfo(); + } + + TosaSerializationOperator *tyop = new TosaSerializationOperator( + Op_CONV2D, Attribute_ConvAttribute, &attribute, qinfo_type, qinfo, + std::vector{input0_name, input1_name, input2_name}, + std::vector{output_name}); + + delete qinfo; + + return tyop; +} + +template <> +TosaSerializationOperator * +TosaSerializationOperatorBuilder::build( + mlir::Operation &op) const { + std::vector pad, stride, dilation; + + auto pad_attr = op.getAttr("pad").dyn_cast().getValue(); + for (auto &int_attr : pad_attr) { + pad.push_back(int_attr.dyn_cast().getInt()); + } + ASSERT_VECTOR_LENGTH(pad, 4); + + auto stride_attr = + op.getAttr("stride").dyn_cast().getValue(); + for (auto &int_attr : stride_attr) { + stride.push_back(int_attr.dyn_cast().getInt()); + } + ASSERT_VECTOR_LENGTH(stride, 2); + + auto dilation_attr = + op.getAttr("dilation").dyn_cast().getValue(); + for (auto &int_attr : dilation_attr) { + dilation.push_back(int_attr.dyn_cast().getInt()); + } + ASSERT_VECTOR_LENGTH(dilation, 2); + + std::string input0_name = GetTensorName(op.getOperand(0)); + std::string 
input1_name = GetTensorName(op.getOperand(1)); + std::string input2_name = GetTensorName(op.getOperand(2)); + std::string output_name = GetTensorName(op.getResult(0)); + + TosaConvAttribute attribute(pad, stride, dilation); + + auto quant_info = + op.getAttrOfType("quantization_info"); + QuantInfo qinfo_type; + TosaQuantInfoBase *qinfo; + if (quant_info) { + qinfo_type = QuantInfo_ConvQuantInfo; + qinfo = GetConvQuantInfo(quant_info); + } else { + qinfo_type = QuantInfo_NONE; + qinfo = new TosaNoneQuantInfo(); + } + + TosaSerializationOperator *tyop = new TosaSerializationOperator( + Op_DEPTHWISE_CONV2D, Attribute_ConvAttribute, &attribute, qinfo_type, + qinfo, std::vector{input0_name, input1_name, input2_name}, + std::vector{output_name}); + + delete qinfo; + + return tyop; +} + +template <> +TosaSerializationOperator * +TosaSerializationOperatorBuilder::build( + mlir::Operation &op) const { + std::vector outpad, stride, dilation, output_shape; + + auto outpad_attr = + op.getAttr("out_pad").dyn_cast().getValue(); + for (auto &int_attr : outpad_attr) { + outpad.push_back(int_attr.dyn_cast().getInt()); + } + ASSERT_VECTOR_LENGTH(outpad, 2); + + auto stride_attr = + op.getAttr("stride").dyn_cast().getValue(); + for (auto &int_attr : stride_attr) { + stride.push_back(int_attr.dyn_cast().getInt()); + } + ASSERT_VECTOR_LENGTH(stride, 2); + + auto dilation_attr = + op.getAttr("dilation").dyn_cast().getValue(); + for (auto &int_attr : dilation_attr) { + dilation.push_back(int_attr.dyn_cast().getInt()); + } + ASSERT_VECTOR_LENGTH(dilation, 2); + + auto output_shape_attr = + op.getAttr("out_shape").dyn_cast().getValue(); + for (auto &int_attr : output_shape_attr) { + output_shape.push_back(int_attr.dyn_cast().getInt()); + } + ASSERT_VECTOR_LENGTH(output_shape, 4); + + std::string input0_name = GetTensorName(op.getOperand(0)); + std::string input1_name = GetTensorName(op.getOperand(1)); + std::string input2_name = GetTensorName(op.getOperand(2)); + std::string output_name = 
GetTensorName(op.getResult(0)); + + TosaTransposeConvAttribute attribute(outpad, stride, dilation, output_shape); + + auto quant_info = + op.getAttrOfType("quantization_info"); + QuantInfo qinfo_type; + TosaQuantInfoBase *qinfo; + if (quant_info) { + qinfo_type = QuantInfo_ConvQuantInfo; + qinfo = GetConvQuantInfo(quant_info); + } else { + qinfo_type = QuantInfo_NONE; + qinfo = new TosaNoneQuantInfo(); + } + + TosaSerializationOperator *tyop = new TosaSerializationOperator( + Op_TRANSPOSE_CONV2D, Attribute_TransposeConvAttribute, &attribute, + qinfo_type, qinfo, + std::vector{input0_name, input1_name, input2_name}, + std::vector{output_name}); + + delete qinfo; + + return tyop; +} + +template <> +TosaSerializationOperator * +TosaSerializationOperatorBuilder::build( + mlir::Operation &op) const { + std::string input0_name = GetTensorName(op.getOperand(0)); + std::string input1_name = GetTensorName(op.getOperand(1)); + std::string input2_name = GetTensorName(op.getOperand(2)); + std::string output_name = GetTensorName(op.getResult(0)); + + auto quant_info = + op.getAttrOfType("quantization_info"); + QuantInfo qinfo_type; + TosaQuantInfoBase *qinfo; + if (quant_info) { + qinfo_type = QuantInfo_ConvQuantInfo; + qinfo = GetConvQuantInfo(quant_info); + } else { + qinfo_type = QuantInfo_NONE; + qinfo = new TosaNoneQuantInfo(); + } + + TosaSerializationOperator *tyop = new TosaSerializationOperator( + Op_FULLY_CONNECTED, Attribute_NONE, nullptr, qinfo_type, qinfo, + std::vector{input0_name, input1_name, input2_name}, + std::vector{output_name}); + + delete qinfo; + + return tyop; +} + +template <> +TosaSerializationOperator * +TosaSerializationOperatorBuilder::build( + mlir::Operation &op) const { + std::string input0_name = GetTensorName(op.getOperand(0)); + std::string input1_name = GetTensorName(op.getOperand(1)); + std::string output_name = GetTensorName(op.getResult(0)); + + auto quant_info = op.getAttrOfType( + "quantization_info"); + QuantInfo qinfo_type; + 
TosaQuantInfoBase *qinfo; + if (quant_info) { + qinfo_type = QuantInfo_MatMulQuantInfo; + qinfo = GetMatMulQuantInfo(quant_info); + } else { + qinfo_type = QuantInfo_NONE; + qinfo = new TosaNoneQuantInfo(); + } + + TosaSerializationOperator *tyop = new TosaSerializationOperator( + Op_MATMUL, Attribute_NONE, nullptr, qinfo_type, qinfo, + std::vector{input0_name, input1_name}, + std::vector{output_name}); + + delete qinfo; + + return tyop; +} + +template <> +TosaSerializationOperator * +TosaSerializationOperatorBuilder::build( + mlir::Operation &op) const { + std::string input0_name = GetTensorName(op.getOperand(0)); + std::string input1_name = GetTensorName(op.getOperand(1)); + std::string input2_name = GetTensorName(op.getOperand(2)); + std::string output_name = GetTensorName(op.getResult(0)); + + TosaSerializationOperator *tyop = new TosaSerializationOperator( + Op_SELECT, Attribute_NONE, nullptr, QuantInfo_NONE, nullptr, + std::vector{input0_name, input1_name, input2_name}, + std::vector{output_name}); + + return tyop; +} + +template <> +TosaSerializationOperator * +TosaSerializationOperatorBuilder::build( + mlir::Operation &op) const { + int32_t min_int = + op.getAttr("min_int").dyn_cast().getInt(); + int32_t max_int = + op.getAttr("max_int").dyn_cast().getInt(); + float min_fp = op.getAttr("min_fp") + .dyn_cast() + .getValue() + .convertToFloat(); + float max_fp = op.getAttr("max_fp") + .dyn_cast() + .getValue() + .convertToFloat(); + + std::string input_name = GetTensorName(op.getOperand(0)); + std::string output_name = GetTensorName(op.getResult(0)); + + TosaClampAttribute attribute(min_int, max_int, min_fp, max_fp); + + TosaSerializationOperator *tyop = new TosaSerializationOperator( + Op_CLAMP, Attribute_ClampAttribute, &attribute, QuantInfo_NONE, nullptr, + std::vector{input_name}, + std::vector{output_name}); + + return tyop; +} + +template <> +TosaSerializationOperator * +TosaSerializationOperatorBuilder::build( + mlir::Operation &op) const { + int32_t 
axis = op.getAttr("axis").dyn_cast().getInt(); + + std::string input_name = GetTensorName(op.getOperand(0)); + std::string output_name = GetTensorName(op.getResult(0)); + + TosaAxisAttribute attribute(axis); + + TosaSerializationOperator *tyop = new TosaSerializationOperator( + Op_ARGMAX, Attribute_AxisAttribute, &attribute, QuantInfo_NONE, nullptr, + std::vector{input_name}, + std::vector{output_name}); + + return tyop; +} + +template <> +TosaSerializationOperator * +TosaSerializationOperatorBuilder::build( + mlir::Operation &op) const { + int32_t axis = op.getAttr("axis").dyn_cast().getInt(); + + std::vector inputs; + for (uint32_t i = 0; i < op.getNumOperands(); i++) { + std::string input_name = GetTensorName(op.getOperand(i)); + inputs.push_back(input_name); + } + + std::string output_name = GetTensorName(op.getResult(0)); + + TosaAxisAttribute attribute(axis); + + TosaSerializationOperator *tyop = new TosaSerializationOperator( + Op_CONCAT, Attribute_AxisAttribute, &attribute, QuantInfo_NONE, nullptr, + inputs, std::vector{output_name}); + + return tyop; +} + +template <> +TosaSerializationOperator * +TosaSerializationOperatorBuilder::build( + mlir::Operation &op) const { + std::string input_name = GetTensorName(op.getOperand(0)); + std::string output_name = GetTensorName(op.getResult(0)); + + auto quant_info = op.getAttrOfType( + "quantization_info"); + QuantInfo qinfo_type; + TosaQuantInfoBase *qinfo; + if (quant_info) { + qinfo_type = QuantInfo_UnaryQuantInfo; + qinfo = GetUnaryQuantInfo(quant_info); + } else { + qinfo_type = QuantInfo_NONE; + qinfo = new TosaNoneQuantInfo(); + } + + TosaSerializationOperator *tyop = new TosaSerializationOperator( + Op_NEGATE, Attribute_NONE, nullptr, qinfo_type, qinfo, + std::vector{input_name}, + std::vector{output_name}); + return tyop; +} + +template <> +TosaSerializationOperator * +TosaSerializationOperatorBuilder::build( + mlir::Operation &op) const { + std::string input_name = GetTensorName(op.getOperand(0)); + 
std::string output_name = GetTensorName(op.getResult(0)); + + std::vector shape; + auto shape_attr = + op.getAttr("new_shape").dyn_cast().getValue(); + for (auto &int_attr : shape_attr) { + shape.push_back(int_attr.dyn_cast().getInt()); + } + + TosaReshapeAttribute attribute(shape); + + TosaSerializationOperator *tyop = new TosaSerializationOperator( + Op_RESHAPE, Attribute_ReshapeAttribute, &attribute, QuantInfo_NONE, + nullptr, std::vector{input_name}, + std::vector{output_name}); + + return tyop; +} + +template <> +TosaSerializationOperator * +TosaSerializationOperatorBuilder::build( + mlir::Operation &op) const { + std::string input_name = GetTensorName(op.getOperand(0)); + std::string output_name = GetTensorName(op.getResult(0)); + + // Match padding tensor as compile-time constant attribute + // TODO: fix when MLIR dialect changes + mlir::ElementsAttr paddings_elems; + if (!matchPattern(op.getOperand(1), m_Constant(&paddings_elems))) + return nullptr; + + std::vector paddings; + for (int32_t val : paddings_elems.getValues()) { + paddings.push_back(val); + } + + TosaPadAttribute attribute(paddings, 0 /* pad_const_int */, + 0.0f /* pad_const_fp */); + + auto quant_info = + op.getAttrOfType("quantization_info"); + QuantInfo qinfo_type; + TosaQuantInfoBase *qinfo; + if (quant_info) { + qinfo_type = QuantInfo_PadQuantInfo; + qinfo = GetPadQuantInfo(quant_info); + } else { + qinfo_type = QuantInfo_NONE; + qinfo = new TosaNoneQuantInfo(); + } + + TosaSerializationOperator *tyop = new TosaSerializationOperator( + Op_PAD, Attribute_PadAttribute, &attribute, qinfo_type, qinfo, + std::vector{input_name}, + std::vector{output_name}); + return tyop; +} + +template <> +TosaSerializationOperator * +TosaSerializationOperatorBuilder::build( + mlir::Operation &op) const { + std::string input_name = GetTensorName(op.getOperand(0)); + std::string output_name = GetTensorName(op.getResult(0)); + + // Match perm tensor as compile-time constant attribute + // TODO: fix when MLIR 
dialect changes + mlir::ElementsAttr perm_elems; + if (!matchPattern(op.getOperand(1), m_Constant(&perm_elems))) + return nullptr; + + std::vector perm; + for (int32_t i = 0; i < perm_elems.getNumElements(); i++) { + perm.push_back(perm_elems.getValue(i).getInt()); + } + + TosaTransposeAttribute attribute(perm); + + TosaSerializationOperator *tyop = new TosaSerializationOperator( + Op_TRANSPOSE, Attribute_TransposeAttribute, &attribute, QuantInfo_NONE, + nullptr, std::vector{input_name}, + std::vector{output_name}); + + return tyop; +} + +template <> +TosaSerializationOperator * +TosaSerializationOperatorBuilder::build( + mlir::Operation &op) const { + std::vector start, size; + auto begin_attr = op.getAttr("start").dyn_cast().getValue(); + auto size_attr = op.getAttr("size").dyn_cast().getValue(); + + for (auto &int_attr : begin_attr) { + start.push_back(int_attr.dyn_cast().getInt()); + } + + for (auto &int_attr : size_attr) { + size.push_back(int_attr.dyn_cast().getInt()); + } + + TosaSliceAttribute attribute(start, size); + + std::string input_name = GetTensorName(op.getOperand(0)); + std::string output_name = GetTensorName(op.getResult(0)); + + TosaSerializationOperator *tyop = new TosaSerializationOperator( + Op_SLICE, Attribute_SliceAttribute, &attribute, QuantInfo_NONE, nullptr, + std::vector{input_name}, + std::vector{output_name}); + + return tyop; +} + +template <> +TosaSerializationOperator * +TosaSerializationOperatorBuilder::build( + mlir::Operation &op) const { + std::string input_name = GetTensorName(op.getOperand(0)); + std::string output_name = GetTensorName(op.getResult(0)); + + std::vector multiples; + auto multiples_attr = + op.getAttr("multiples").dyn_cast().getValue(); + for (auto &int_attr : multiples_attr) { + multiples.push_back(int_attr.dyn_cast().getInt()); + } + + TosaTileAttribute attribute(multiples); + + TosaSerializationOperator *tyop = new TosaSerializationOperator( + Op_TILE, Attribute_TileAttribute, &attribute, QuantInfo_NONE, 
nullptr, + std::vector{input_name}, + std::vector{output_name}); + + return tyop; +} + +template <> +TosaSerializationOperator * +TosaSerializationOperatorBuilder::build( + mlir::Operation &op) const { + std::string input0_name = GetTensorName(op.getOperand(0)); + std::string input1_name = GetTensorName(op.getOperand(1)); + std::string output_name = GetTensorName(op.getResult(0)); + + TosaSerializationOperator *tyop = new TosaSerializationOperator( + Op_GATHER, Attribute_NONE, nullptr, QuantInfo_NONE, nullptr, + std::vector{input0_name, input1_name}, + std::vector{output_name}); + + return tyop; +} + +template <> +TosaSerializationOperator * +TosaSerializationOperatorBuilder::build( + mlir::Operation &op) const { + std::string input0_name = GetTensorName(op.getOperand(0)); + std::string input1_name = GetTensorName(op.getOperand(1)); + std::string input2_name = GetTensorName(op.getOperand(2)); + std::string output_name = GetTensorName(op.getResult(0)); + + TosaSerializationOperator *tyop = new TosaSerializationOperator( + Op_SCATTER, Attribute_NONE, nullptr, QuantInfo_NONE, nullptr, + std::vector{input0_name, input1_name, input2_name}, + std::vector{output_name}); + + return tyop; +} + +template <> +TosaSerializationOperator * +TosaSerializationOperatorBuilder::build( + mlir::Operation &op) const { + std::string input_name = GetTensorName(op.getOperand(0)); + std::string output_name = GetTensorName(op.getResult(0)); + + std::vector output_size; + auto output_size_attr = + op.getAttr("output_size").dyn_cast().getValue(); + for (auto &int_attr : output_size_attr) { + output_size.push_back(int_attr.dyn_cast().getInt()); + } + ASSERT_VECTOR_LENGTH(output_size, 2); + + std::vector stride; + auto stride_attr = + op.getAttr("stride").dyn_cast().getValue(); + for (auto &int_attr : stride_attr) { + stride.push_back(int_attr.dyn_cast().getInt()); + } + ASSERT_VECTOR_LENGTH(stride, 2); + + std::vector offset; + auto offset_attr = + op.getAttr("offset").dyn_cast().getValue(); + 
for (auto &int_attr : offset_attr) { + offset.push_back(int_attr.dyn_cast().getInt()); + } + ASSERT_VECTOR_LENGTH(offset, 2); + + int32_t shift = op.getAttr("shift").dyn_cast().getInt(); + + std::vector stride_fp; + auto stride_fp_attr = + op.getAttr("stride_fp").dyn_cast().getValue(); + for (auto &fp_attr : stride_fp_attr) { + stride_fp.push_back(fp_attr.dyn_cast().getValueAsDouble()); + } + ASSERT_VECTOR_LENGTH(stride_fp, 2); + + std::vector offset_fp; + auto offset_fp_attr = + op.getAttr("offset_fp").dyn_cast().getValue(); + for (auto &fp_attr : offset_fp_attr) { + offset_fp.push_back(fp_attr.dyn_cast().getValueAsDouble()); + } + ASSERT_VECTOR_LENGTH(offset_fp, 2); + + auto mode_str = + op.getAttr("mode").dyn_cast().getValue().str(); + ResizeMode mode = ResizeModeStr2Enum(mode_str); + + TosaResizeAttribute attribute(output_size, stride, offset, shift, stride_fp, + offset_fp, mode); + + TosaSerializationOperator *tyop = new TosaSerializationOperator( + Op_RESIZE, Attribute_ResizeAttribute, &attribute, QuantInfo_NONE, nullptr, + std::vector{input_name}, + std::vector{output_name}); + + return tyop; +} + +template <> +TosaSerializationOperator * +TosaSerializationOperatorBuilder::build( + mlir::Operation &op) const { + std::string input_name = GetTensorName(op.getOperand(0)); + std::string output_name = GetTensorName(op.getResult(0)); + + int32_t axis = op.getAttr("axis").dyn_cast().getInt(); + + TosaAxisAttribute attribute(axis); + + TosaSerializationOperator *tyop = new TosaSerializationOperator( + Op_REVERSE, Attribute_AxisAttribute, &attribute, QuantInfo_NONE, nullptr, + std::vector{input_name}, + std::vector{output_name}); + + return tyop; +} + +template <> +TosaSerializationOperator * +TosaSerializationOperatorBuilder::build( + mlir::Operation &op) const { + std::string input0_name = GetTensorName(op.getOperand(0)); + std::string input1_name = GetTensorName(op.getOperand(1)); + std::string output_name = GetTensorName(op.getResult(0)); + + int32_t shift = 
op.getAttr("shift").dyn_cast().getInt(); + + TosaMulAttribute attribute(shift); + + TosaSerializationOperator *tyop = new TosaSerializationOperator( + Op_MUL, Attribute_MulAttribute, &attribute, QuantInfo_NONE, nullptr, + std::vector{input0_name, input1_name}, + std::vector{output_name}); + + return tyop; +} + +template <> +TosaSerializationOperator * +TosaSerializationOperatorBuilder::build( + mlir::Operation &op) const { + std::string input0_name = GetTensorName(op.getOperand(0)); + std::string input1_name = GetTensorName(op.getOperand(1)); + std::string output_name = GetTensorName(op.getResult(0)); + + bool round = op.getAttr("round").dyn_cast().getValue(); + + TosaArithmeticRightShiftAttribute attribute(round); + + TosaSerializationOperator *tyop = new TosaSerializationOperator( + Op_ARITHMETIC_RIGHT_SHIFT, Attribute_ArithmeticRightShiftAttribute, + &attribute, QuantInfo_NONE, nullptr, + std::vector{input0_name, input1_name}, + std::vector{output_name}); + + return tyop; +} + +template <> +TosaSerializationOperator * +TosaSerializationOperatorBuilder::build( + mlir::Operation &op) const { + std::string input_name = GetTensorName(op.getOperand(0)); + std::string output_name = GetTensorName(op.getResult(0)); + + // Match table tensor as compile-time constant attribute + // TODO: fix when MLIR dialect changes + mlir::ElementsAttr table_elems; + if (!matchPattern(op.getOperand(1), m_Constant(&table_elems))) + return nullptr; + + std::vector table; + for (int32_t i = 0; i < table_elems.getNumElements(); i++) { + table.push_back(table_elems.getValue(i).getInt()); + } + + TosaTableAttribute attribute(table); + + TosaSerializationOperator *tyop = new TosaSerializationOperator( + Op_TABLE, Attribute_TableAttribute, &attribute, QuantInfo_NONE, nullptr, + std::vector{input_name}, + std::vector{output_name}); + + return tyop; +} + +template <> +TosaSerializationOperator * +TosaSerializationOperatorBuilder::build( + mlir::Operation &op) const { + int32_t input_zp = + 
op.getAttr("input_zp").dyn_cast().getInt(); + int32_t output_zp = + op.getAttr("output_zp").dyn_cast().getInt(); + bool scale32 = op.getAttr("scale32").dyn_cast().getValue(); + bool double_round = + op.getAttr("double_round").dyn_cast().getValue(); + bool per_channel = + op.getAttr("per_channel").dyn_cast().getValue(); + + std::vector multiplier, shift; + auto multiplier_attr = + op.getAttr("multiplier").dyn_cast().getValue(); + auto shift_attr = op.getAttr("shift").dyn_cast().getValue(); + + for (auto &int_attr : multiplier_attr) { + multiplier.push_back(int_attr.dyn_cast().getInt()); + } + + for (auto &int_attr : shift_attr) { + shift.push_back(int_attr.dyn_cast().getInt()); + } + + std::string input_name = GetTensorName(op.getOperand(0)); + std::string output_name = GetTensorName(op.getResult(0)); + + TosaRescaleAttribute attribute(input_zp, output_zp, multiplier, shift, + scale32, double_round, per_channel); + + TosaSerializationOperator *tyop = new TosaSerializationOperator( + Op_RESCALE, Attribute_RescaleAttribute, &attribute, QuantInfo_NONE, + nullptr, std::vector{input_name}, + std::vector{output_name}); + + return tyop; +} + +template <> +TosaSerializationOperator * +TosaSerializationOperatorBuilder::build( + mlir::Operation &op) const { + std::string input_name = GetTensorName(op.getOperand(0)); + std::string output_name = GetTensorName(op.getResult(0)); + + TosaSerializationOperator *tyop = new TosaSerializationOperator( + Op_CUSTOM, Attribute_NONE, nullptr, QuantInfo_NONE, nullptr, + std::vector{input_name}, + std::vector{output_name}); + + return tyop; +} + +template <> +TosaSerializationOperator * +TosaSerializationOperatorBuilder::build( + mlir::Operation &op) const { + std::vector input_names, output_names; + + mlir::Region &then_region = op.getRegion(0); + mlir::Region &else_region = op.getRegion(1); + std::vector then_yields, else_yields; + TosaSerializationBasicBlock *then_block = nullptr; + TosaSerializationBasicBlock *else_block = nullptr; + + 
// Building then branch block + std::string then_block_name = + "bb" + std::to_string(block_builder->GetTsh()->GetBlocks().size()); + then_block = new TosaSerializationBasicBlock( + then_block_name, std::vector(), + std::vector(), std::vector(), + std::vector()); + assert(then_block); + block_builder->GetTsh()->GetBlocks().push_back(then_block); + + TosaSerializationBlockBuilder then_block_builder( + then_block, block_builder->GetTsh(), &then_region); + if (then_block_builder.BuildAllOpsInRegion(then_yields).failed()) { + return nullptr; + } + if (then_yields.size() != op.getNumResults()) { + op.emitOpError("BuildOpCondIf: then_region yield.size() doesn't match " + "cond_if's output size"); + return nullptr; + } + + // Building else branch block + std::string else_block_name = + "bb" + std::to_string(block_builder->GetTsh()->GetBlocks().size()); + else_block = new TosaSerializationBasicBlock( + else_block_name, std::vector(), + std::vector(), std::vector(), + std::vector()); + assert(else_block); + block_builder->GetTsh()->GetBlocks().push_back(else_block); + + TosaSerializationBlockBuilder else_block_builder( + else_block, block_builder->GetTsh(), &else_region); + if (else_block_builder.BuildAllOpsInRegion(else_yields).failed()) { + return nullptr; + } + if (else_yields.size() != op.getNumResults()) { + op.emitOpError("BuildOpCondIf: else_region yield.size() doesn't match " + "cond_if's output size"); + return nullptr; + } + + TosaCondIfAttribute attribute(then_block->GetName(), else_block->GetName()); + + for (size_t i = 0; i < op.getNumOperands(); i++) { + std::string input_name = GetTensorName(op.getOperand(i)); + input_names.push_back(input_name); + } + + for (size_t i = 0; i < op.getNumResults(); i++) { + std::string output_name = GetTensorName(op.getResult(i)); + output_names.push_back(output_name); + } + + TosaSerializationOperator *tyop = new TosaSerializationOperator( + Op_COND_IF, Attribute_CondIfAttribute, &attribute, QuantInfo_NONE, + nullptr, 
input_names, output_names); + + return tyop; +} + +template <> +TosaSerializationOperator * +TosaSerializationOperatorBuilder::build( + mlir::Operation &op) const { + std::vector input_names, output_names; + + mlir::Region &cond_region = op.getRegion(0); + mlir::Region &body_region = op.getRegion(1); + std::vector cond_yields, body_yields; + TosaSerializationBasicBlock *cond_block = nullptr; + TosaSerializationBasicBlock *body_block = nullptr; + + // Building cond branch block + std::string cond_block_name = + "bb" + std::to_string(block_builder->GetTsh()->GetBlocks().size()); + cond_block = new TosaSerializationBasicBlock( + cond_block_name, std::vector(), + std::vector(), std::vector(), + std::vector()); + assert(cond_block); + block_builder->GetTsh()->GetBlocks().push_back(cond_block); + + TosaSerializationBlockBuilder cond_block_builder( + cond_block, block_builder->GetTsh(), &cond_region); + if (cond_block_builder.BuildAllOpsInRegion(cond_yields).failed()) { + return nullptr; + } + if (cond_yields.size() != 1) { + op.emitOpError("BuildOpWhileLoop: cond_region yield.size() is not 1"); + return nullptr; + } + + // Building body branch block + std::string body_block_name = + "bb" + std::to_string(block_builder->GetTsh()->GetBlocks().size()); + body_block = new TosaSerializationBasicBlock( + body_block_name, std::vector(), + std::vector(), std::vector(), + std::vector()); + assert(body_block); + block_builder->GetTsh()->GetBlocks().push_back(body_block); + + TosaSerializationBlockBuilder body_block_builder( + body_block, block_builder->GetTsh(), &body_region); + if (body_block_builder.BuildAllOpsInRegion(body_yields).failed()) { + return nullptr; + } + if (body_yields.size() != op.getNumResults()) { + op.emitOpError("BuildOpWhileLoop: body_region yield.size() doesn't " + "match while_loop's output size"); + return nullptr; + } + + TosaWhileLoopAttribute attribute(cond_block->GetName(), + body_block->GetName()); + + for (size_t i = 0; i < op.getNumOperands(); i++) 
{ + std::string input_name = GetTensorName(op.getOperand(i)); + input_names.push_back(input_name); + } + + for (size_t i = 0; i < op.getNumResults(); i++) { + std::string output_name = GetTensorName(op.getResult(i)); + output_names.push_back(output_name); + } + + TosaSerializationOperator *tyop = new TosaSerializationOperator( + Op_WHILE_LOOP, Attribute_WhileLoopAttribute, &attribute, QuantInfo_NONE, + nullptr, input_names, output_names); + + return tyop; +} + +/* End translating TOSA operator */ + +mlir::LogicalResult TosaSerializationBlockBuilder::BuildAllOpsInRegion( + std::vector &return_values) { + TosaSerializationOperator *ser_operator = nullptr; + TosaSerializationTensor *ser_tensor = nullptr; + size_t num_blocks_in_region = 0; + static int input_tensor_index = 0; + static int intermediate_tensor_index = 0; + static int output_tensor_index = 0; + TosaSerializationOperatorBuilder op_builder(this); + + for (auto &bb : region->getBlocks()) { + num_blocks_in_region++; + + if (num_blocks_in_region > 1) { + llvm::errs() << "Invalid MLIR: multiple blocks in a region\n"; + return mlir::failure(); + } + + // We always have one block for each region right now + assert(bb.isEntryBlock()); + + // Specify block input tensor name + for (auto args : bb.getArguments()) { + std::string block_input_name = + "TosaInput_" + std::to_string(input_tensor_index++); + block->GetInputs().push_back(block_input_name); + tensor_map[args] = block_input_name; + } + + // Build tensor_map + for (auto &op : bb) { + if (!(llvm::isa(op) || + llvm::isa(op) || + llvm::isa(op))) { + for (uint32_t i = 0; i < op.getNumResults(); i++) { + std::string intermediate_tensor_name = + "layer_" + std::to_string(intermediate_tensor_index++); + tensor_map[op.getResult(i)] = intermediate_tensor_name; + } + } else { + if (llvm::isa(op)) + continue; + // Override return tensor name + for (auto val : op.getOperands()) { + // Workaround to skip mlir::tensor::CastOp before return + mlir::Operation 
*val_defining_op = val.getDefiningOp(); + if (llvm::isa(*val_defining_op)) + val = val_defining_op->getOperand(0); + + // Sanity check. This mlir::Value should be built in map since graph + // is DAG + if (tensor_map.find(val) == tensor_map.end()) { + llvm::errs() << "ERROR: Can't find built mlir::Value key.\n"; + return mlir::failure(); + } + std::string output_name = + "TosaOutput_" + std::to_string(output_tensor_index++); + tensor_map[val] = output_name; + block->GetOutputs().push_back(output_name); + return_values.push_back(val); + } + } + } + + // Build tensor + for (auto pair : tensor_map) { + ser_tensor = BuildTosaSerializationTensor(pair.first /* val */, + pair.second /* name */); + if (!ser_tensor) { + llvm::errs() << "ERROR: Failed to build TosaSerializationTensor\n"; + return mlir::failure(); + } + block->GetTensors().push_back(ser_tensor); + } + + // Build operator + for (auto &op : bb) { + if (llvm::isa(op) || llvm::isa(op) || + llvm::isa(op)) + continue; + ser_operator = BuildTosaSerializationOperator(op_builder, op); + if (!ser_operator) { + llvm::errs() << "ERROR: Failed to build TosaSerializationOperator\n"; + return mlir::failure(); + } + block->GetOperators().push_back(ser_operator); + } + } + + return mlir::success(); +} + +TosaSerializationOperator * +TosaSerializationBlockBuilder::BuildTosaSerializationOperator( + const TosaSerializationOperatorBuilder &op_builder, mlir::Operation &op) { + std::string full_op_name = op.getName().getStringRef().str(); + TosaSerializationOperator *target_operator = nullptr; + + if (false) { + } +#define DEF_OPERATOR(MLIR_OP) \ + else if (llvm::isa(op)) { \ + target_operator = op_builder.build(op); \ + } +#include "operator.def" +#undef DEF_OPERATOR + else { + llvm::errs() << "unsupported tosa operator " << op.getName().getStringRef() + << "\n"; + } + + if (!target_operator) { + llvm::errs() << op.getName().getStringRef() + << " operator hasn't been translated to flatbuffer, skipped\n"; + return nullptr; + } + + 
// Sanity check the number of inputs/outputs of TOSA dialect matches the + // number of TOSA flatbuffer + if (op.getNumOperands() != target_operator->GetInputTensorNames().size()) { + llvm::errs() << "WARNING. MLIR operator has " << op.getNumOperands() + << " input tensors != Flatbuffer " + "operator has " + << target_operator->GetInputTensorNames().size() + << " input tensors\n"; + } + if (op.getNumResults() != target_operator->GetOutputTensorNames().size()) { + llvm::errs() << "WARNING. MLIR operator has " << op.getNumResults() + << " output tensors != Flatbuffer " + "operator has " + << target_operator->GetOutputTensorNames().size() + << " output tensors\n"; + } + + return target_operator; +} + +TosaSerializationTensor * +TosaSerializationBlockBuilder::BuildTosaSerializationTensor( + mlir::Value val, const std::string &name) { + // If tensor already created before, use that tensor directly, create a new + // one otherwise + TosaSerializationTensor *ts = block->GetTensorByName(name); + if (ts) { + return nullptr; + } + + mlir::RankedTensorType tensor = + val.getType().dyn_cast(); + std::vector shape(tensor.getShape().begin(), + tensor.getShape().end()); + DType type = Type2DType(tensor.getElementType()); + + ts = new TosaSerializationTensor(name, shape, type, std::vector()); + + return ts; +} + +mlir::LogicalResult translate2FlatBuffer(mlir::FuncOp &func, + TosaSerializationHandler &tsh) { + TosaSerializationBasicBlock *main_block; + + mlir::Region *main_region = func.getCallableRegion(); + std::vector main_returns; + + if (!main_region) { + llvm::errs() << "Invalid MLIR: doesn't have valid \"main\" region\n"; + return mlir::failure(); + } + + if (!tsh.GetBlocks().empty()) { + llvm::errs() << "Internal Error: TosaSerializationHandler's block list " + "must be empty\n"; + return mlir::failure(); + } + + main_block = new TosaSerializationBasicBlock( + std::string("main"), std::vector(), + std::vector(), std::vector(), + std::vector()); + assert(main_block); + 
tsh.GetBlocks().push_back(main_block); + + TosaSerializationBlockBuilder block_builder(main_block, &tsh, main_region); + if (block_builder.BuildAllOpsInRegion(main_returns).failed()) { + return mlir::failure(); + } + + if (main_returns.empty()) { + llvm::errs() << "Warning: graph doesn't have return values\n"; + } + + return mlir::success(); +} + +mlir::LogicalResult dumpTosaFlatbuffer(mlir::FuncOp &func) { + tosa::TosaSerializationHandler tsh; + + std::string tosa_flatbuffer_directory_fullpath; + if (translate2FlatBuffer(func, tsh).failed()) { + llvm::errs() << "Fail to translate TOSA MLIR to flatbuffer\n"; + return mlir::failure(); + } + + if (tsh.SaveFileTosaFlatbuffer(tosa_flatbuffer_filename.c_str())) { + llvm::errs() << "Fail to save flatbuffer " << tosa_flatbuffer_filename + << "\n"; + return mlir::failure(); + } + return mlir::success(); +} + +mlir::LogicalResult dumpTosaJSON(mlir::FuncOp &func) { + tosa::TosaSerializationHandler tsh; + + const char *tosa_schema = tosa_flatbuffer_schema.c_str(); + + if (!tosa_schema) { + llvm::errs() << "Flatbuffer schema not defined\n"; + return mlir::failure(); + } + + if (tsh.LoadFileSchema(tosa_schema)) { + llvm::errs() << "Error loading tosa schema file: " << tosa_schema << "\n"; + return mlir::failure(); + } + + std::string tosa_flatbuffer_directory_fullpath; + if (translate2FlatBuffer(func, tsh).failed()) { + llvm::errs() << "Fail to translate TOSA MLIR to flatbuffer\n"; + return mlir::failure(); + } + + if (tsh.SaveFileJson(tosa_flatbuffer_filename.c_str())) { + llvm::errs() << "Fail to save flatbuffer " << tosa_flatbuffer_filename + << "\n"; + return mlir::failure(); + } + + return mlir::success(); +} + +namespace mlir { + +namespace tosa { + +namespace { + +class TosaSerialize : public PassWrapper { +public: + TosaSerialize() = default; + + StringRef getArgument() const final { + // This is the argument used to refer to the pass in + // the textual format (on the commandline for example). 
+ return "tosa-serialize"; + } + StringRef getDescription() const final { + // This is a brief description of the pass. + return "Run the TOSA serialization (flatbuffer generation) pass"; + } + + void runOnFunction() override { + auto function = getFunction(); + + if (dumpTosaFlatbuffer(function).failed()) { + llvm::errs() << "Failed to generate TOSA flatbuffer...\n"; + return signalPassFailure(); + } + } +}; + +class TosaSerializeJSON : public PassWrapper { +public: + TosaSerializeJSON() = default; + + StringRef getArgument() const final { + // This is the argument used to refer to the pass in + // the textual format (on the commandline for example). + return "tosa-serialize-json"; + } + StringRef getDescription() const final { + // This is a brief description of the pass. + return "Run the TOSA serialization (JSON generation) pass"; + } + + void runOnFunction() override { + auto function = getFunction(); + + if (dumpTosaJSON(function).failed()) { + llvm::errs() << "Failed to generate TOSA JSON...\n"; + return signalPassFailure(); + } + } +}; + +} // anonymous namespace + +// Creates an instance of the TOSA flatbuffer generation pass +std::unique_ptr> createTosaSerializePass() { + return std::make_unique(); +} + +std::unique_ptr> createTosaSerializeJSONPass() { + return std::make_unique(); +} + +static PassRegistration pass([] { + return createTosaSerializePass(); +}); + +static PassRegistration passJSON([] { + return createTosaSerializeJSONPass(); +}); + +} // namespace tosa +} // namespace mlir diff --git a/third_party/serialization_lib b/third_party/serialization_lib new file mode 160000 index 0000000..545a508 --- /dev/null +++ b/third_party/serialization_lib @@ -0,0 +1 @@ +Subproject commit 545a508429afe1d22760563d252839e13ecd12a3 -- cgit v1.2.1