diff options
-rw-r--r-- | .pre-commit-config.yaml | 15 | ||||
-rw-r--r-- | CMakeLists.txt | 8 | ||||
-rw-r--r-- | include/DeserializationPasses.h | 46 | ||||
-rw-r--r-- | include/DeserializationPasses.td | 25 | ||||
-rw-r--r-- | include/SerializationPasses.h | 6 | ||||
-rw-r--r-- | include/SerializationPasses.td | 2 | ||||
-rw-r--r-- | include/operator.def | 22 | ||||
-rw-r--r-- | include/schema_operator.def | 103 | ||||
-rw-r--r-- | src/TosaDeserialize.cpp | 2128 | ||||
-rw-r--r-- | src/TosaSerialize.cpp | 1735 | ||||
m--------- | third_party/serialization_lib | 0 |
11 files changed, 3494 insertions, 596 deletions
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..e52e423 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,15 @@ +# Copyright (c) 2023 Arm Limited. +# SPDX-License-Identifier: Apache-2.0 + +# See https://pre-commit.com for more information +# See https://pre-commit.com/hooks.html for more hooks +repos: +- repo: local + hooks: + - id: clang-format + name: clang-format + exclude: build|third_party + language: system + entry: clang-format + types: ["c++"] + args: ["-i"] diff --git a/CMakeLists.txt b/CMakeLists.txt index 968d73b..6bc3e6b 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -3,7 +3,6 @@ cmake_minimum_required(VERSION 3.13.4) project(MlirTosaPasses) -set(CMAKE_CXX_STANDARD 14 CACHE STRING "C++ standard to conform to") set(CMAKE_CXX_STANDARD_REQUIRED YES) set(CMAKE_VERBOSE_MAKEFILE ON) @@ -12,21 +11,28 @@ set(CMAKE_VERBOSE_MAKEFILE ON) include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include) include_directories(${PROJECT_SOURCE_DIR}/third_party/serialization_lib/include) +include_directories(${PROJECT_SOURCE_DIR}/third_party/serialization_lib/third_party/half/include) include_directories(${PROJECT_SOURCE_DIR}/third_party/serialization_lib/third_party/flatbuffers/include) set(LLVM_TARGET_DEFINITIONS include/SerializationPasses.td) mlir_tablegen(include/SerializationPasses.h.inc -gen-pass-decls -name TosaSerialization) add_public_tablegen_target(tosa_serialization_passes_inc_gen) +set(LLVM_TARGET_DEFINITIONS include/DeserializationPasses.td) +mlir_tablegen(include/DeserializationPasses.h.inc -gen-pass-decls -name TosaDeserialization) +add_public_tablegen_target(tosa_deserialization_passes_inc_gen) + # Compile the TOSA serialization_lib add_subdirectory(third_party/serialization_lib) add_mlir_library(tosa_serialize src/TosaSerialize.cpp + src/TosaDeserialize.cpp DEPENDS mlir-headers tosa_serialization_passes_inc_gen + tosa_deserialization_passes_inc_gen LINK_LIBS PRIVATE tosa_serialization_lib diff --git 
a/include/DeserializationPasses.h b/include/DeserializationPasses.h new file mode 100644 index 0000000..1a38814 --- /dev/null +++ b/include/DeserializationPasses.h @@ -0,0 +1,46 @@ + +// Copyright (c) 2023, ARM Limited. +// +// Licensed under the Apache License, Version 2.0 with LLVM Exceptions +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// https://llvm.org/LICENSE.txt +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef INCLUDE_DESERIALIZATION_PASSES_H +#define INCLUDE_DESERIALIZATION_PASSES_H + +#include <memory> + +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/OwningOpRef.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project + +namespace mlir { +namespace tosa { + +std::unique_ptr<Pass> createTosaDeserializePass(); +std::unique_ptr<Pass> createTosaDeserializeJSONPass(); + +// deserializes a tosa file and return an mlir module +// if file_is_fbs is true, then treat file_name as a tosa flatbuffer file +// otherwise, treat file_name as a tosa json file +mlir::OwningOpRef<mlir::ModuleOp> +BuildMlirFromTosaFile(const char *file_name, mlir::MLIRContext *context, + bool file_is_fbs = true); + +#define GEN_PASS_REGISTRATION +#define GEN_PASS_CLASSES +#include "include/DeserializationPasses.h.inc" + +} // namespace tosa +} // namespace mlir + +#endif // INCLUDE_DESERIALIZATION_PASSES_H diff --git a/include/DeserializationPasses.td b/include/DeserializationPasses.td new file mode 100644 index 0000000..999f0b4 --- /dev/null +++ b/include/DeserializationPasses.td @@ -0,0 +1,25 @@ +// 
Copyright (c) 2023, ARM Limited. +// +// Licensed under the Apache License, Version 2.0 with LLVM Exceptions +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// https://llvm.org/LICENSE.txt +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +include "mlir/Pass/PassBase.td" + +def TosaDeserializationPass : Pass<"tosa-deserialize", "func::FuncOp"> { + let summary = "Deserialize TOSA flatbuffer. Clear original MLIR graph and generate TOSA MLIR"; + let constructor = "createTosaDeserializePass()"; +} + +def TosaDeserializationJSONPass : Pass<"tosa-deserialize-json", "func::FuncOp"> { + let summary = "Deserialize TOSA flatbuffer JSON form. 
Clear original MLIR graph and generate TOSA MLIR"; + let constructor = "createTosaDeserializeJSONPass()"; +} diff --git a/include/SerializationPasses.h b/include/SerializationPasses.h index 66c6d80..c769b15 100644 --- a/include/SerializationPasses.h +++ b/include/SerializationPasses.h @@ -19,16 +19,18 @@ #include <memory> #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" +#include "mlir/Pass/Pass.h" // from @llvm-project namespace mlir { namespace tosa { -std::unique_ptr<Pass> createTosaSerializePass(); +std::unique_ptr<OperationPass<ModuleOp>> createTosaSerializePass(); std::unique_ptr<Pass> createTosaSerializeJSONPass(); #define GEN_PASS_REGISTRATION #define GEN_PASS_CLASSES +#define GEN_PASS_DECL_TOSASERIALIZEPASS #include "include/SerializationPasses.h.inc" } // namespace tosa diff --git a/include/SerializationPasses.td b/include/SerializationPasses.td index 3bdeb1b..9cfc204 100644 --- a/include/SerializationPasses.td +++ b/include/SerializationPasses.td @@ -14,7 +14,7 @@ include "mlir/Pass/PassBase.td" -def TosaSerializationPass : Pass<"tosa-serialize", "func::FuncOp"> { +def TosaSerializationPass : Pass<"tosa-serialize", "mlir::ModuleOp"> { let summary = "Generate TOSA flatbuffer serialized form"; let constructor = "createTosaSerializePass()"; } diff --git a/include/operator.def b/include/operator.def index 85bb5c9..0bd0d08 100644 --- a/include/operator.def +++ b/include/operator.def @@ -1,8 +1,8 @@ -// Copyright (c) 2020-2021, ARM Limited. +// Copyright (c) 2020-2024, ARM Limited. // -// Licensed under the Apache License, Version 2.0 with LLVM Exceptions -// (the "License"); you may not use this file except in compliance with +// Licensed under the Apache License, Version 2.0 with LLVM Exceptions +// (the "License"); you may not use this file except in compliance with // the License. 
You may obtain a copy of the License at // // https://llvm.org/LICENSE.txt @@ -27,12 +27,15 @@ DEF_OPERATOR(AvgPool2d) DEF_OPERATOR(Conv2D) DEF_OPERATOR(Conv3D) DEF_OPERATOR(DepthwiseConv2D) +DEF_OPERATOR(FFT2d) DEF_OPERATOR(FullyConnected) DEF_OPERATOR(MatMul) DEF_OPERATOR(MaxPool2d) +DEF_OPERATOR(RFFT2d) DEF_OPERATOR(TransposeConv2D) /* activation */ +DEF_OPERATOR(Erf) DEF_OPERATOR(Clamp) DEF_OPERATOR(Sigmoid) DEF_OPERATOR(Tanh) @@ -43,7 +46,7 @@ DEF_OPERATOR(ArithmeticRightShift) DEF_OPERATOR(BitwiseAnd) DEF_OPERATOR(BitwiseOr) DEF_OPERATOR(BitwiseXor) -DEF_OPERATOR(Div) +DEF_OPERATOR(IntDiv) DEF_OPERATOR(LogicalAnd) DEF_OPERATOR(LogicalLeftShift) DEF_OPERATOR(LogicalRightShift) @@ -61,6 +64,7 @@ DEF_OPERATOR(Abs) DEF_OPERATOR(BitwiseNot) DEF_OPERATOR(Ceil) DEF_OPERATOR(Clz) +DEF_OPERATOR(Cos) DEF_OPERATOR(Exp) DEF_OPERATOR(Floor) DEF_OPERATOR(Log) @@ -68,6 +72,7 @@ DEF_OPERATOR(LogicalNot) DEF_OPERATOR(Negate) DEF_OPERATOR(Reciprocal) DEF_OPERATOR(Rsqrt) +DEF_OPERATOR(Sin) /* elementwise - ternary */ DEF_OPERATOR(Select) @@ -88,6 +93,7 @@ DEF_OPERATOR(ReduceSum) /* memory operation */ DEF_OPERATOR(Concat) DEF_OPERATOR(Pad) +DEF_OPERATOR(Dim) DEF_OPERATOR(Reshape) DEF_OPERATOR(Reverse) DEF_OPERATOR(Slice) @@ -115,3 +121,11 @@ DEF_OPERATOR(Custom) /* control flow operators */ DEF_OPERATOR(If) DEF_OPERATOR(While) + +/* shape operators */ +DEF_OPERATOR(ConstShape) +DEF_OPERATOR(ConcatShape) +DEF_OPERATOR(AddShape) +DEF_OPERATOR(SubShape) +DEF_OPERATOR(MulShape) +DEF_OPERATOR(DivShape) diff --git a/include/schema_operator.def b/include/schema_operator.def new file mode 100644 index 0000000..02b639e --- /dev/null +++ b/include/schema_operator.def @@ -0,0 +1,103 @@ +// Copyright (c) 2023-2024, ARM Limited. +// +// Licensed under the Apache License, Version 2.0 with LLVM Exceptions +// (the "License"); you may not use this file except in compliance with +// the License. 
You may obtain a copy of the License at +// +// https://llvm.org/LICENSE.txt +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +/* + Syntax: + DEF_SCHEMA_OPERATOR(SCHEMA_OP_NAME) + + Description: + SCHEMA_OP_NAME: the schema operator names, must match Op names in schema/tosa.fbs in serialization_lib repo +*/ + +/* schema operators */ +DEF_SCHEMA_OPERATOR(ARGMAX) +DEF_SCHEMA_OPERATOR(AVG_POOL2D) +DEF_SCHEMA_OPERATOR(CONV2D) +DEF_SCHEMA_OPERATOR(CONV3D) +DEF_SCHEMA_OPERATOR(DEPTHWISE_CONV2D) +DEF_SCHEMA_OPERATOR(FULLY_CONNECTED) +DEF_SCHEMA_OPERATOR(MATMUL) +DEF_SCHEMA_OPERATOR(MAX_POOL2D) +DEF_SCHEMA_OPERATOR(TRANSPOSE_CONV2D) +DEF_SCHEMA_OPERATOR(CLAMP) +DEF_SCHEMA_OPERATOR(RESERVED) +DEF_SCHEMA_OPERATOR(ERF) +DEF_SCHEMA_OPERATOR(SIGMOID) +DEF_SCHEMA_OPERATOR(TANH) +DEF_SCHEMA_OPERATOR(ADD) +DEF_SCHEMA_OPERATOR(ARITHMETIC_RIGHT_SHIFT) +DEF_SCHEMA_OPERATOR(BITWISE_AND) +DEF_SCHEMA_OPERATOR(BITWISE_OR) +DEF_SCHEMA_OPERATOR(BITWISE_XOR) +DEF_SCHEMA_OPERATOR(INTDIV) +DEF_SCHEMA_OPERATOR(LOGICAL_AND) +DEF_SCHEMA_OPERATOR(LOGICAL_LEFT_SHIFT) +DEF_SCHEMA_OPERATOR(LOGICAL_RIGHT_SHIFT) +DEF_SCHEMA_OPERATOR(LOGICAL_OR) +DEF_SCHEMA_OPERATOR(LOGICAL_XOR) +DEF_SCHEMA_OPERATOR(MAXIMUM) +DEF_SCHEMA_OPERATOR(MINIMUM) +DEF_SCHEMA_OPERATOR(MUL) +DEF_SCHEMA_OPERATOR(POW) +DEF_SCHEMA_OPERATOR(SUB) +DEF_SCHEMA_OPERATOR(TABLE) +DEF_SCHEMA_OPERATOR(ABS) +DEF_SCHEMA_OPERATOR(BITWISE_NOT) +DEF_SCHEMA_OPERATOR(CEIL) +DEF_SCHEMA_OPERATOR(CLZ) +DEF_SCHEMA_OPERATOR(EXP) +DEF_SCHEMA_OPERATOR(FLOOR) +DEF_SCHEMA_OPERATOR(LOG) +DEF_SCHEMA_OPERATOR(LOGICAL_NOT) +DEF_SCHEMA_OPERATOR(NEGATE) +DEF_SCHEMA_OPERATOR(RECIPROCAL) +DEF_SCHEMA_OPERATOR(RSQRT) +DEF_SCHEMA_OPERATOR(SELECT) +DEF_SCHEMA_OPERATOR(EQUAL) 
+DEF_SCHEMA_OPERATOR(GREATER) +DEF_SCHEMA_OPERATOR(GREATER_EQUAL) +DEF_SCHEMA_OPERATOR(REDUCE_ANY) +DEF_SCHEMA_OPERATOR(REDUCE_ALL) +DEF_SCHEMA_OPERATOR(REDUCE_MAX) +DEF_SCHEMA_OPERATOR(REDUCE_MIN) +DEF_SCHEMA_OPERATOR(REDUCE_PRODUCT) +DEF_SCHEMA_OPERATOR(REDUCE_SUM) +DEF_SCHEMA_OPERATOR(CONCAT) +DEF_SCHEMA_OPERATOR(PAD) +DEF_SCHEMA_OPERATOR(DIM) +DEF_SCHEMA_OPERATOR(RESHAPE) +DEF_SCHEMA_OPERATOR(REVERSE) +DEF_SCHEMA_OPERATOR(SLICE) +DEF_SCHEMA_OPERATOR(TILE) +DEF_SCHEMA_OPERATOR(TRANSPOSE) +DEF_SCHEMA_OPERATOR(GATHER) +DEF_SCHEMA_OPERATOR(SCATTER) +DEF_SCHEMA_OPERATOR(RESIZE) +DEF_SCHEMA_OPERATOR(CAST) +DEF_SCHEMA_OPERATOR(RESCALE) +DEF_SCHEMA_OPERATOR(CONST) +DEF_SCHEMA_OPERATOR(IDENTITY) +DEF_SCHEMA_OPERATOR(CUSTOM) +DEF_SCHEMA_OPERATOR(COND_IF) +DEF_SCHEMA_OPERATOR(WHILE_LOOP) +DEF_SCHEMA_OPERATOR(FFT2D) +DEF_SCHEMA_OPERATOR(RFFT2D) +DEF_SCHEMA_OPERATOR(CONST_SHAPE) +DEF_SCHEMA_OPERATOR(CONCAT_SHAPE) +DEF_SCHEMA_OPERATOR(ADD_SHAPE) +DEF_SCHEMA_OPERATOR(SUB_SHAPE) +DEF_SCHEMA_OPERATOR(MUL_SHAPE) +DEF_SCHEMA_OPERATOR(DIV_SHAPE) +DEF_SCHEMA_OPERATOR(COS) +DEF_SCHEMA_OPERATOR(SIN)
\ No newline at end of file diff --git a/src/TosaDeserialize.cpp b/src/TosaDeserialize.cpp new file mode 100644 index 0000000..215d760 --- /dev/null +++ b/src/TosaDeserialize.cpp @@ -0,0 +1,2128 @@ + +// Copyright (c) 2023-2024, ARM Limited. +// +// Licensed under the Apache License, Version 2.0 with LLVM Exceptions +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// https://llvm.org/LICENSE.txt +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// TOSA MLIR deserialize passes + +#include "include/DeserializationPasses.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project +#include "mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project +#include "mlir/Dialect/Tosa/IR/TosaOps.h" // from @llvm-project +#include "mlir/IR/Matchers.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "tosa_serialization_handler.h" +#include <functional> +#include <queue> +#include <unordered_map> +#include <unordered_set> +#include <vector> + +// The namespace might be confusing here. 
We have mlir::tosa:: defined in MLIR +// and tosa:: defined in serialization library +// TODO: align the namespace +using namespace tosa; + +namespace cl = llvm::cl; + +llvm::cl::opt<std::string> tosa_deserialize_filename( + "tosa-deserialize-filename", llvm::cl::desc("<tosa flatbuffer filename>"), + llvm::cl::init("tosa_dump.tosa"), llvm::cl::value_desc("filename")); + +llvm::cl::opt<std::string> tosa_deserialize_schema( + "tosa-deserialize-schema", llvm::cl::desc("<tosa flatbuffer schema file>"), + llvm::cl::init(""), llvm::cl::value_desc("filename")); + +const std::string kDefaultExportedName = "tosa_deserialized"; +const std::string kDefaultInputPrefix = "input_"; +const std::string kDefaultOutputPrefix = "output_"; +const std::string kDefaultFBSDescription = "Tosa FBS Converted"; +const std::string kDefaultJSONDescription = "Tosa JSON Converted"; +const std::string kMainFunctionName = "main"; + +namespace { + +// a global map from flatbuffer variable names to serialized tensors +std::unordered_map<std::string, TosaSerializationTensor *> variable_tensor_map; + +void RegisterVariableTensor(TosaSerializationTensor *ts) { + assert(ts->GetVariable()); + // insert variable tensor ts only if not already present + variable_tensor_map.insert({ts->GetName(), ts}); +} + +bool IsVariableTensor(const std::string flatbuffer_tensor_name) { + return variable_tensor_map.count(flatbuffer_tensor_name); +} + +// return the variable name corresponding to flatbuffer_tensor_name +const std::string GetVariableTensorName(TosaSerializationTensor *ts) { + assert(ts->GetVariable()); + const auto name = ts->GetVariableName(); + if (name == "") { + // for legacy flatbuffers which may not have variable_name fields + return ts->GetName(); + } + return name; +} + +// return the variable name corresponding to flatbuffer_tensor_name +const std::string +GetVariableTensorName(const std::string flatbuffer_tensor_name) { + if (!IsVariableTensor(flatbuffer_tensor_name)) { + llvm::errs() << "ERROR: 
Variable tensor " << flatbuffer_tensor_name + << " is not found in variable_tensor_map"; + return ""; + } + return GetVariableTensorName(variable_tensor_map[flatbuffer_tensor_name]); +} + +bool IsVariableReadOp(TosaSerializationOperator *op) { + return (op->GetOp() == tosa::Op::Op_IDENTITY) && + IsVariableTensor(op->GetInputTensorNames()[0]); +} + +bool IsVariableWriteOp(TosaSerializationOperator *op) { + return (op->GetOp() == tosa::Op::Op_IDENTITY) && + IsVariableTensor(op->GetOutputTensorNames()[0]); +} + +// construct tensor type from dtype and shape of TosaSerializationTensor +mlir::LogicalResult BuildTensorType(mlir::OpBuilder *op_builder, + TosaSerializationTensor *ts, + mlir::RankedTensorType &type) { + mlir::Type element_type; + switch (ts->GetDtype()) { + case DType_BOOL: + element_type = op_builder->getI1Type(); + break; + case DType_UINT8: + element_type = op_builder->getIntegerType(8, false); + break; + case DType_INT4: + element_type = op_builder->getI4Type(); + break; + case DType_INT8: + element_type = op_builder->getI8Type(); + break; + case DType_INT16: + element_type = op_builder->getIntegerType(16); + break; + case DType_INT32: + element_type = op_builder->getI32Type(); + break; + case DType_INT48: + element_type = op_builder->getIntegerType(48); + break; + case DType_FP32: + element_type = op_builder->getF32Type(); + break; + case DType_UINT16: + element_type = op_builder->getIntegerType(16, false); + break; + case DType_FP16: + element_type = op_builder->getF16Type(); + break; + case DType_BF16: + element_type = op_builder->getBF16Type(); + break; + case DType_FP8E4M3: + element_type = op_builder->getFloat8E4M3FNType(); + break; + case DType_FP8E5M2: + element_type = op_builder->getFloat8E5M2Type(); + break; + case DType_SHAPE: + llvm::errs() + << "ERROR: Cannot construct RankedTensorType out of tosa.shape type \n"; + return mlir::failure(); + default: + llvm::errs() << "ERROR: unknown type " << EnumNamesDType()[ts->GetDtype()] + << "\n"; + 
return mlir::failure(); + } + llvm::SmallVector<int64_t> shape; + for (auto dim : ts->GetShape()) { + if (dim > 0) { + shape.push_back(dim); + } else { + // dynamic dim + shape.push_back(mlir::ShapedType::kDynamic); + } + } + type = mlir::RankedTensorType::get(llvm::ArrayRef(shape), element_type); + return mlir::success(); +} + +mlir::DenseElementsAttr GetConstAttr(const std::vector<uint8_t> &data, + const mlir::RankedTensorType &output_type, + uint32_t out_size) { + auto element_type = output_type.getElementType(); + if (element_type.isF32()) { + // for FP32, value attributes are stored as FP32 values + std::vector<float> float_data; + TosaSerializationHandler::ConvertU8toF32(data, out_size, float_data); + return mlir::DenseElementsAttr::get(output_type, + llvm::ArrayRef(float_data)); + } + if (element_type.isBF16()) { + mlir::SmallVector<mlir::APFloat> bf16_data; + for (uint32_t i = 0; i < out_size; i++) { + uint64_t byte0 = data[i * sizeof(int16_t)]; + uint64_t byte1 = data[i * sizeof(int16_t) + 1]; + uint64_t bits = byte0 + (byte1 << 8); + mlir::APInt bf16_bits(16, bits); + mlir::APFloat bf16(mlir::APFloat::BFloat(), bf16_bits); + bf16_data.push_back(bf16); + } + return mlir::DenseElementsAttr::get(output_type, bf16_data); + } + if (element_type.isFloat8E4M3FN()) { + mlir::SmallVector<mlir::APFloat> f8_data; + for (uint32_t i = 0; i < out_size; i++) { + mlir::APInt f8_bits(8, static_cast<uint64_t>(data[i])); + mlir::APFloat f8(mlir::APFloat::Float8E4M3FN(), f8_bits); + f8_data.push_back(f8); + } + return mlir::DenseElementsAttr::get(output_type, f8_data); + } + if (element_type.isFloat8E5M2()) { + mlir::SmallVector<mlir::APFloat> f8_data; + for (uint32_t i = 0; i < out_size; i++) { + mlir::APInt f8_bits(8, static_cast<uint64_t>(data[i])); + mlir::APFloat f8(mlir::APFloat::Float8E5M2(), f8_bits); + f8_data.push_back(f8); + } + return mlir::DenseElementsAttr::get(output_type, f8_data); + } + if (element_type.isInteger(4)) { + std::vector<int8_t> int4_data; + 
TosaSerializationHandler::ConvertU8toI4(data, out_size, int4_data); + return mlir::DenseElementsAttr::get(output_type, llvm::ArrayRef(int4_data)); + } + if (element_type.isInteger(8)) { + std::vector<int8_t> int8_data; + TosaSerializationHandler::ConvertU8toI8(data, out_size, int8_data); + return mlir::DenseElementsAttr::get(output_type, llvm::ArrayRef(int8_data)); + } + if (element_type.isInteger(16)) { + std::vector<int16_t> int16_data; + TosaSerializationHandler::ConvertU8toI16(data, out_size, int16_data); + return mlir::DenseElementsAttr::get(output_type, + llvm::ArrayRef(int16_data)); + } + if (element_type.isInteger(32)) { + std::vector<int32_t> int32_data; + TosaSerializationHandler::ConvertU8toI32(data, out_size, int32_data); + return mlir::DenseElementsAttr::get(output_type, + llvm::ArrayRef(int32_data)); + } + if (element_type.isInteger(48)) { + std::vector<int64_t> int48_data; + TosaSerializationHandler::ConvertU8toI48(data, out_size, int48_data); + std::vector<mlir::APInt> apint_data; + for (const auto v : int48_data) { + mlir::APInt apint_value(48, static_cast<uint64_t>(v), + /* isSigned = */ false); + apint_data.push_back(apint_value); + } + return mlir::DenseElementsAttr::get(output_type, + llvm::ArrayRef(apint_data)); + } + if (element_type.isInteger(1)) { + std::vector<bool> bool_data; + TosaSerializationHandler::ConvertU8toBool(data, out_size, bool_data); + llvm::SmallVector<bool> bool_values(bool_data.begin(), bool_data.end()); + return mlir::DenseElementsAttr::get(output_type, bool_values); + } + if (element_type.isF16()) { + std::vector<half_float::half> half_data; + TosaSerializationHandler::ConvertU8toF16(data, out_size, half_data); + return mlir::DenseElementsAttr::get(output_type, llvm::ArrayRef(half_data)); + } + + return nullptr; +} + +mlir::DenseElementsAttr +ConstructConstAttr(const mlir::RankedTensorType &output_type, + TosaSerializationTensor *ts, const std::string &op_name) { + // compute output data size + uint32_t out_size = 1; + 
for (const auto dim : ts->GetShape()) { + out_size *= dim; + } + auto attr = GetConstAttr(ts->GetData(), output_type, out_size); + if (!attr) { + llvm::errs() << "ERROR: " << op_name + << " contains unsupported element type\n"; + } + return attr; +} + +mlir::LogicalResult ConstructVariableOps(mlir::ModuleOp &module) { + if (variable_tensor_map.empty()) { + return mlir::success(); + } + auto loc = module.getLoc(); + auto op_builder = mlir::OpBuilder(module.getBodyRegion()); + for (auto [flatbuffer_name, ts] : variable_tensor_map) { + auto name = GetVariableTensorName(ts); + mlir::RankedTensorType type; + if (BuildTensorType(&op_builder, ts, type).failed()) { + return mlir::failure(); + } + + mlir::Attribute value_attr = nullptr; + if (!ts->GetData().empty()) { + value_attr = ConstructConstAttr(type, ts, name); + } + op_builder.create<mlir::tosa::VariableOp>(loc, llvm::StringRef(name), type, + value_attr); + } + + return mlir::success(); +} + +template <class T> +mlir::DenseElementsAttr BuildDenseI8ElementsAttr(mlir::OpBuilder *op_builder, + const std::vector<T> &values) { + llvm::SmallVector<int8_t> vec; + for (auto val : values) { + vec.push_back(val); + } + auto type = mlir::RankedTensorType::get({static_cast<int64_t>(vec.size())}, + op_builder->getI8Type()); + return mlir::DenseElementsAttr::get(type, llvm::ArrayRef(vec)); +} + +template <class T> +mlir::DenseElementsAttr +BuildDenseI16ElementsAttr(mlir::OpBuilder *op_builder, + const std::vector<T> &values) { + llvm::SmallVector<int16_t> vec; + for (auto val : values) { + vec.push_back(val); + } + auto type = mlir::RankedTensorType::get({static_cast<int64_t>(vec.size())}, + op_builder->getI16Type()); + return mlir::DenseElementsAttr::get(type, llvm::ArrayRef(vec)); +} + +template <class T> +mlir::DenseElementsAttr +BuildDenseI32ElementsAttr(mlir::OpBuilder *op_builder, + mlir::RankedTensorType &type, + const std::vector<T> &values) { + llvm::SmallVector<int32_t> vec; + for (auto val : values) { + 
vec.push_back(val); + } + return mlir::DenseElementsAttr::get(type, llvm::ArrayRef(vec)); +} + +template <class T> +mlir::DenseI8ArrayAttr BuildDenseI8ArrayAttr(mlir::OpBuilder *op_builder, + const std::vector<T> &values) { + std::vector<int8_t> vec; + for (auto val : values) { + vec.push_back(val); + } + return op_builder->getDenseI8ArrayAttr(vec); +} + +template <class T> +mlir::DenseI32ArrayAttr BuildDenseI32ArrayAttr(mlir::OpBuilder *op_builder, + const std::vector<T> &values) { + std::vector<int32_t> vec; + for (auto val : values) { + vec.push_back(val); + } + return op_builder->getDenseI32ArrayAttr(vec); +} + +template <class T> +mlir::DenseI64ArrayAttr BuildDenseI64ArrayAttr(mlir::OpBuilder *op_builder, + const std::vector<T> &values) { + std::vector<int64_t> vec; + for (auto val : values) { + vec.push_back(val); + } + return op_builder->getDenseI64ArrayAttr(vec); +} + +const std::string ResizeEnum2Str(const tosa::ResizeMode &mode) { + if (mode == ResizeMode_NEAREST) { + return "NEAREST_NEIGHBOR"; + } else if (mode == ResizeMode_BILINEAR) { + return "BILINEAR"; + } + return ""; +} + +// this is a counter part to Type2AccDType +mlir::Type AccDType2Type(mlir::OpBuilder *op_builder, DType dtype) { + // def Tosa_AccType : AnyTypeOf<[I<32>, I<48>, F16, F32]>; + if (dtype == DType_INT32) { + return op_builder->getI32Type(); + } else if (dtype == DType_INT48) { + return op_builder->getIntegerType(48); + } else if (dtype == DType_FP32) { + return op_builder->getF32Type(); + } else if (dtype == DType_FP16) { + return op_builder->getF16Type(); + } else { + // unknown acc type + // for now, default to F32 + return op_builder->getF32Type(); + } +} + +} // namespace + +class TosaMlirRegionBuilder; +class TosaMlirBlockBuilder; + +class TosaMlirOperatorBuilder { +public: + TosaMlirOperatorBuilder( + mlir::OpBuilder *_op_builder, TosaSerializationBasicBlock *_ser_block, + mlir::Block *_block, mlir::Location _loc, + TosaMlirBlockBuilder *_block_builder, + 
std::unordered_map<std::string, mlir::Value> *_tensor_map, + std::unordered_map<std::string, mlir::RankedTensorType> *_tensor_type_map, + std::unordered_map<std::string, mlir::tosa::shapeType> *_shape_type_map) + : op_builder(_op_builder), ser_block(_ser_block), block(_block), + loc(_loc), block_builder(_block_builder), tensor_map(_tensor_map), + tensor_type_map(_tensor_type_map), shape_type_map(_shape_type_map) {} + + template <Op OPCODE> + std::vector<mlir::Value> build(TosaSerializationOperator *op) const; + + std::vector<mlir::Value> BuildVariableOp(TosaSerializationOperator *op) const; + + std::vector<mlir::Value> + BuildVariableReadOp(TosaSerializationOperator *op) const; + + void BuildVariableWriteOp(TosaSerializationOperator *op) const; + + std::string get_string(TosaSerializationOperator *op) const { + std::string op_string; + op_string += "operator opcode="; + op_string += EnumNamesOp()[op->GetOp()]; + op_string += ", input=["; + for (auto ts : op->GetInputTensorNames()) { + op_string += (ts + " "); + } + op_string += "], output=["; + for (auto ts : op->GetOutputTensorNames()) { + op_string += (ts + " "); + } + op_string += "]"; + return op_string; + } + + TosaSerializationHandler *GetTsh() const; + TosaMlirRegionBuilder *GetRegionBuilder() const; + +private: + template <class MLIR_OP> + std::vector<mlir::Value> + BuildEwiseUnaryOp(TosaSerializationOperator *op) const; + + template <class MLIR_OP> + std::vector<mlir::Value> + BuildEwiseBinaryOp(TosaSerializationOperator *op) const; + + template <class MLIR_OP> + std::vector<mlir::Value> + BuildEwiseBinaryShapeOp(TosaSerializationOperator *op) const; + + template <class MLIR_OP> + std::vector<mlir::Value> + BuildReductionOp(TosaSerializationOperator *op) const; + + template <class T> + mlir::Value BuildConstShape(mlir::OpBuilder *op_builder, mlir::Location loc, + const std::vector<T> &values) const; + + template <class MLIR_OP> + std::vector<mlir::Value> BuildConvOp(TosaSerializationOperator *op) const; + 
+ mlir::OpBuilder *op_builder; + TosaSerializationBasicBlock *ser_block; + mlir::Block *block; + mlir::Location loc; + TosaMlirBlockBuilder *block_builder; + std::unordered_map<std::string, mlir::Value> *tensor_map; + std::unordered_map<std::string, mlir::RankedTensorType> *tensor_type_map; + std::unordered_map<std::string, mlir::tosa::shapeType> *shape_type_map; +}; + +// Main template to catch unimplemented translation +template <Op OPCODE> +std::vector<mlir::Value> +TosaMlirOperatorBuilder::build(TosaSerializationOperator *op) const { + llvm::errs() << "ERROR: " << get_string(op) + << " translation hasn't been implemented\n"; + return {}; +} + +// BUILD_OP_POOL2D(MaxPool2d, MAX_POOL2D) +template <> +std::vector<mlir::Value> TosaMlirOperatorBuilder::build<Op_MAX_POOL2D>( + TosaSerializationOperator *op) const { + mlir::Value input_val = tensor_map->at(op->GetInputTensorNames()[0]); + mlir::RankedTensorType output_type = + tensor_type_map->at(op->GetOutputTensorNames()[0]); + assert(op->GetAttributeType() == + Attribute_PoolAttribute); // double check attribute type + TosaPoolAttribute *attr = + static_cast<TosaPoolAttribute *>(op->GetAttribute()); + mlir::DenseI64ArrayAttr kernel = + BuildDenseI64ArrayAttr(op_builder, attr->kernel()); + mlir::DenseI64ArrayAttr stride = + BuildDenseI64ArrayAttr(op_builder, attr->stride()); + mlir::DenseI64ArrayAttr pad = BuildDenseI64ArrayAttr(op_builder, attr->pad()); + int32_t input_zp = attr->input_zp(); + int32_t output_zp = attr->output_zp(); + assert(input_zp == 0 && output_zp == 0); + + mlir::Operation *mlir_op = op_builder->create<mlir::tosa::MaxPool2dOp>( + loc, output_type, input_val, kernel, stride, pad); + block->push_back(mlir_op); + return std::vector<mlir::Value>({mlir_op->getResult(0)}); +} + +// BUILD_OP_POOL2D(AvgPool2d, AVG_POOL2D) +template <> +std::vector<mlir::Value> TosaMlirOperatorBuilder::build<Op_AVG_POOL2D>( + TosaSerializationOperator *op) const { + mlir::Value input_val = 
tensor_map->at(op->GetInputTensorNames()[0]); + mlir::RankedTensorType output_type = + tensor_type_map->at(op->GetOutputTensorNames()[0]); + assert(op->GetAttributeType() == + Attribute_PoolAttribute); // double check attribute type + TosaPoolAttribute *attr = + static_cast<TosaPoolAttribute *>(op->GetAttribute()); + mlir::DenseI64ArrayAttr kernel = + BuildDenseI64ArrayAttr(op_builder, attr->kernel()); + mlir::DenseI64ArrayAttr stride = + BuildDenseI64ArrayAttr(op_builder, attr->stride()); + mlir::DenseI64ArrayAttr pad = BuildDenseI64ArrayAttr(op_builder, attr->pad()); + auto acc_attr = + mlir::TypeAttr::get(AccDType2Type(op_builder, attr->acc_type())); + + int32_t input_zp = attr->input_zp(); + int32_t output_zp = attr->output_zp(); + mlir::Operation *mlir_op; + if (input_zp == 0 && output_zp == 0) { + mlir_op = op_builder->create<mlir::tosa::AvgPool2dOp>( + loc, output_type, input_val, kernel, stride, pad, acc_attr); + } else { + auto input_zp_attr = op_builder->getI32IntegerAttr(input_zp); + auto output_zp_attr = op_builder->getI32IntegerAttr(output_zp); + mlir_op = op_builder->create<mlir::tosa::AvgPool2dOp>( + loc, output_type, input_val, kernel, stride, pad, acc_attr, + input_zp_attr, output_zp_attr); + } + block->push_back(mlir_op); + return std::vector<mlir::Value>({mlir_op->getResult(0)}); +} + +std::vector<mlir::Value> TosaMlirOperatorBuilder::BuildVariableReadOp( + TosaSerializationOperator *op) const { + auto input_tensor_name = op->GetInputTensorNames()[0]; + auto output_tensor_name = op->GetOutputTensorNames()[0]; + + assert(IsVariableTensor(input_tensor_name)); + + auto variable_name = GetVariableTensorName(input_tensor_name); + mlir::RankedTensorType output_type = tensor_type_map->at(output_tensor_name); + assert(op->GetAttributeType() == + Attribute_NONE); // double check that there is no attribute + mlir::Operation *mlir_op = op_builder->create<mlir::tosa::VariableReadOp>( + loc, output_type, llvm::StringRef(variable_name)); + 
block->push_back(mlir_op); + return std::vector<mlir::Value>({mlir_op->getResult(0)}); +} + +void TosaMlirOperatorBuilder::BuildVariableWriteOp( + TosaSerializationOperator *op) const { + auto input_tensor_name = op->GetInputTensorNames()[0]; + auto output_tensor_name = op->GetOutputTensorNames()[0]; + + assert(IsVariableTensor(output_tensor_name)); + auto variable_name = GetVariableTensorName(output_tensor_name); + mlir::Value input_val = tensor_map->at(input_tensor_name); + + mlir::Operation *mlir_op = op_builder->create<mlir::tosa::VariableWriteOp>( + loc, llvm::StringRef(variable_name), input_val); + block->push_back(mlir_op); +} + +template <class MLIR_OP> +std::vector<mlir::Value> TosaMlirOperatorBuilder::BuildEwiseUnaryOp( + TosaSerializationOperator *op) const { + mlir::Value input_val = tensor_map->at(op->GetInputTensorNames()[0]); + mlir::RankedTensorType output_type = + tensor_type_map->at(op->GetOutputTensorNames()[0]); + assert(op->GetAttributeType() == + Attribute_NONE); // double check that there is no attribute + + mlir::Operation *mlir_op = + op_builder->create<MLIR_OP>(loc, output_type, input_val); + block->push_back(mlir_op); + return std::vector<mlir::Value>({mlir_op->getResult(0)}); +} + +template <class MLIR_OP> +std::vector<mlir::Value> TosaMlirOperatorBuilder::BuildEwiseBinaryOp( + TosaSerializationOperator *op) const { + mlir::Value input0_val = tensor_map->at(op->GetInputTensorNames()[0]); + mlir::Value input1_val = tensor_map->at(op->GetInputTensorNames()[1]); + mlir::RankedTensorType output_type = + tensor_type_map->at(op->GetOutputTensorNames()[0]); + assert(op->GetAttributeType() == + Attribute_NONE); // double check that there is no attribute + + mlir::Operation *mlir_op = + op_builder->create<MLIR_OP>(loc, output_type, input0_val, input1_val); + block->push_back(mlir_op); + return std::vector<mlir::Value>({mlir_op->getResult(0)}); +} + +template <class MLIR_OP> +std::vector<mlir::Value> 
TosaMlirOperatorBuilder::BuildEwiseBinaryShapeOp(
    TosaSerializationOperator *op) const {
  mlir::Value input0_val = tensor_map->at(op->GetInputTensorNames()[0]);
  mlir::Value input1_val = tensor_map->at(op->GetInputTensorNames()[1]);
  mlir::tosa::shapeType output_type =
      shape_type_map->at(op->GetOutputTensorNames()[0]);
  assert(op->GetAttributeType() ==
         Attribute_NONE); // double check that there is no attribute

  mlir::Operation *mlir_op =
      op_builder->create<MLIR_OP>(loc, output_type, input0_val, input1_val);
  block->push_back(mlir_op);
  return std::vector<mlir::Value>({mlir_op->getResult(0)});
}

// Deserializes an operator with one tensor input plus an AxisAttribute
// (reductions, argmax, reverse) into the templated MLIR op; the axis is
// emitted as an I32 integer attribute.
template <class MLIR_OP>
std::vector<mlir::Value>
TosaMlirOperatorBuilder::BuildReductionOp(TosaSerializationOperator *op) const {
  mlir::Value input_val = tensor_map->at(op->GetInputTensorNames()[0]);
  mlir::RankedTensorType output_type =
      tensor_type_map->at(op->GetOutputTensorNames()[0]);
  assert(op->GetAttributeType() ==
         Attribute_AxisAttribute); // double check attribute type

  TosaAxisAttribute *attr =
      static_cast<TosaAxisAttribute *>(op->GetAttribute());
  auto axis = op_builder->getI32IntegerAttr(attr->axis());

  mlir::Operation *mlir_op =
      op_builder->create<MLIR_OP>(loc, output_type, input_val, axis);
  block->push_back(mlir_op);
  return std::vector<mlir::Value>({mlir_op->getResult(0)});
}

// The BUILD_OP_* macros stamp out the explicit specializations of
// TosaMlirOperatorBuilder::build<> that forward to the shared templated
// builders above.  No comments inside the macro bodies: text before a
// line-continuation backslash would splice the next line into the comment.
#define BUILD_OP_ELEMENTWISE_UNARY(MLIR_OP_NAME, SCHEMA_OP_NAME)               \
  template <>                                                                  \
  std::vector<mlir::Value>                                                     \
  TosaMlirOperatorBuilder::build<Op_##SCHEMA_OP_NAME>(                         \
      TosaSerializationOperator * op) const {                                  \
    return BuildEwiseUnaryOp<mlir::tosa::MLIR_OP_NAME##Op>(op);                \
  }

#define BUILD_OP_ELEMENTWISE_BINARY(MLIR_OP_NAME, SCHEMA_OP_NAME)              \
  template <>                                                                  \
  std::vector<mlir::Value>                                                     \
  TosaMlirOperatorBuilder::build<Op_##SCHEMA_OP_NAME>(                         \
      TosaSerializationOperator * op) const {                                  \
    return BuildEwiseBinaryOp<mlir::tosa::MLIR_OP_NAME##Op>(op);               \
  }

#define BUILD_OP_ELEMENTWISE_BINARY_SHAPE(MLIR_OP_NAME, SCHEMA_OP_NAME)        \
  template <>                                                                  \
  std::vector<mlir::Value>                                                     \
  TosaMlirOperatorBuilder::build<Op_##SCHEMA_OP_NAME>(                         \
      TosaSerializationOperator * op) const {                                  \
    return BuildEwiseBinaryShapeOp<mlir::tosa::MLIR_OP_NAME##Op>(op);          \
  }

#define BUILD_OP_REDUCTION(MLIR_OP_NAME, SCHEMA_OP_NAME)                       \
  template <>                                                                  \
  std::vector<mlir::Value>                                                     \
  TosaMlirOperatorBuilder::build<Op_##SCHEMA_OP_NAME>(                         \
      TosaSerializationOperator * op) const {                                  \
    return BuildReductionOp<mlir::tosa::MLIR_OP_NAME##Op>(op);                 \
  }

// BUILD_OP_POOL2D(MaxPool2d, MAX_POOL2D)
// BUILD_OP_POOL2D(AvgPool2d, AVG_POOL2D)

// Elementwise binary tensor ops.
BUILD_OP_ELEMENTWISE_BINARY(Add, ADD)
BUILD_OP_ELEMENTWISE_BINARY(BitwiseAnd, BITWISE_AND)
BUILD_OP_ELEMENTWISE_BINARY(BitwiseXor, BITWISE_XOR)
BUILD_OP_ELEMENTWISE_BINARY(BitwiseOr, BITWISE_OR)
BUILD_OP_ELEMENTWISE_BINARY(IntDiv, INTDIV)
BUILD_OP_ELEMENTWISE_BINARY(LogicalAnd, LOGICAL_AND)
BUILD_OP_ELEMENTWISE_BINARY(LogicalLeftShift, LOGICAL_LEFT_SHIFT)
BUILD_OP_ELEMENTWISE_BINARY(LogicalRightShift, LOGICAL_RIGHT_SHIFT)
BUILD_OP_ELEMENTWISE_BINARY(LogicalOr, LOGICAL_OR)
BUILD_OP_ELEMENTWISE_BINARY(LogicalXor, LOGICAL_XOR)
BUILD_OP_ELEMENTWISE_BINARY(Maximum, MAXIMUM)
BUILD_OP_ELEMENTWISE_BINARY(Minimum, MINIMUM)
BUILD_OP_ELEMENTWISE_BINARY(Pow, POW)
BUILD_OP_ELEMENTWISE_BINARY(Sub, SUB)

// Elementwise unary tensor ops.
BUILD_OP_ELEMENTWISE_UNARY(Abs, ABS)
BUILD_OP_ELEMENTWISE_UNARY(BitwiseNot, BITWISE_NOT)
BUILD_OP_ELEMENTWISE_UNARY(Ceil, CEIL)
BUILD_OP_ELEMENTWISE_UNARY(Clz, CLZ)
BUILD_OP_ELEMENTWISE_UNARY(Cos, COS)
BUILD_OP_ELEMENTWISE_UNARY(Exp, EXP)
BUILD_OP_ELEMENTWISE_UNARY(Floor, FLOOR)
BUILD_OP_ELEMENTWISE_UNARY(Log, LOG)
BUILD_OP_ELEMENTWISE_UNARY(LogicalNot, LOGICAL_NOT)
BUILD_OP_ELEMENTWISE_UNARY(Reciprocal, RECIPROCAL)
BUILD_OP_ELEMENTWISE_UNARY(Rsqrt, RSQRT)
BUILD_OP_ELEMENTWISE_UNARY(Sin, SIN)

// Axis-attribute reductions.
BUILD_OP_REDUCTION(ReduceAny, REDUCE_ANY)
BUILD_OP_REDUCTION(ReduceAll, REDUCE_ALL)
BUILD_OP_REDUCTION(ReduceMax, REDUCE_MAX)
BUILD_OP_REDUCTION(ReduceMin, REDUCE_MIN)
BUILD_OP_REDUCTION(ReduceProd, REDUCE_PRODUCT)
BUILD_OP_REDUCTION(ReduceSum, REDUCE_SUM)

// Comparison ops reuse the elementwise-binary builder.
BUILD_OP_ELEMENTWISE_BINARY(Equal, EQUAL)
BUILD_OP_ELEMENTWISE_BINARY(Greater, GREATER)
BUILD_OP_ELEMENTWISE_BINARY(GreaterEqual, GREATER_EQUAL)

BUILD_OP_ELEMENTWISE_UNARY(Erf, ERF)
BUILD_OP_ELEMENTWISE_UNARY(Sigmoid, SIGMOID)
BUILD_OP_ELEMENTWISE_UNARY(Tanh, TANH)
BUILD_OP_ELEMENTWISE_UNARY(Identity, IDENTITY)
BUILD_OP_ELEMENTWISE_UNARY(Cast, CAST)

// Elementwise binary ops on !tosa.shape operands.
BUILD_OP_ELEMENTWISE_BINARY_SHAPE(AddShape, ADD_SHAPE)
BUILD_OP_ELEMENTWISE_BINARY_SHAPE(SubShape, SUB_SHAPE)
BUILD_OP_ELEMENTWISE_BINARY_SHAPE(MulShape, MUL_SHAPE)
BUILD_OP_ELEMENTWISE_BINARY_SHAPE(DivShape, DIV_SHAPE)

// CONST: materialize the serialized tensor data as a tosa.const.
// Returns an empty vector when the constant attribute cannot be built.
template <>
std::vector<mlir::Value>
TosaMlirOperatorBuilder::build<Op_CONST>(TosaSerializationOperator *op) const {
  const auto &output_name = op->GetOutputTensorNames()[0];
  mlir::RankedTensorType output_type = tensor_type_map->at(output_name);
  TosaSerializationTensor *ts = ser_block->GetTensorByName(output_name);
  auto value_attr = ConstructConstAttr(output_type, ts, get_string(op));
  if (!value_attr) {
    return {};
  }
  mlir::Operation *mlir_op =
      op_builder->create<mlir::tosa::ConstOp>(loc, output_type, value_attr);
  block->push_back(mlir_op);
  return std::vector<mlir::Value>({mlir_op->getResult(0)});
}

// Emits a tosa.const_shape holding `values`; the shape type's rank equals
// values.size().  NOTE: the op_builder/loc parameters intentionally shadow
// the members of the same name.
template <class T>
mlir::Value
TosaMlirOperatorBuilder::BuildConstShape(mlir::OpBuilder *op_builder,
                                         mlir::Location loc,
                                         const std::vector<T> &values) const {
  std::vector<int64_t> vec;
  for (auto val : values) {
    vec.push_back(val);
  }
  auto attr = op_builder->getIndexTensorAttr(vec);
  auto type = mlir::tosa::shapeType::get(op_builder->getContext(),
                                         /* rank = */ vec.size());
  mlir::Operation *mlir_op =
      op_builder->create<mlir::tosa::ConstShapeOp>(loc, type, attr);
  block->push_back(mlir_op);
  return mlir_op->getResult(0);
}

// CONST_SHAPE: decode the serialized u8 payload into i64 shape values.
template <>
std::vector<mlir::Value> TosaMlirOperatorBuilder::build<Op_CONST_SHAPE>(
    TosaSerializationOperator *op) const {
  const auto &output_name = op->GetOutputTensorNames()[0];
  mlir::tosa::shapeType output_type = shape_type_map->at(output_name);
  TosaSerializationTensor *ts = ser_block->GetTensorByName(output_name);

  const auto &data = ts->GetData();

  std::vector<int64_t> i64_data;
  TosaSerializationHandler::ConvertU8toI64(data, output_type.getRank(),
                                           i64_data);
  mlir::Value result = BuildConstShape(op_builder, loc, i64_data);
  return std::vector<mlir::Value>({result});
}

// Shared builder for CONV2D / CONV3D / DEPTHWISE_CONV2D: three tensor
// operands (input, weight, bias) plus pad/stride/dilation, accumulator
// type, zero points and local_bound from the ConvAttribute.
template <class MLIR_OP>
std::vector<mlir::Value>
TosaMlirOperatorBuilder::BuildConvOp(TosaSerializationOperator *op) const {
  mlir::Value input0_val = tensor_map->at(op->GetInputTensorNames()[0]);
  mlir::Value input1_val = tensor_map->at(op->GetInputTensorNames()[1]);
  mlir::Value input2_val = tensor_map->at(op->GetInputTensorNames()[2]);
  mlir::RankedTensorType output_type =
      tensor_type_map->at(op->GetOutputTensorNames()[0]);

  assert(op->GetAttributeType() ==
         Attribute_ConvAttribute); // double check attribute type
  TosaConvAttribute *attr =
      static_cast<TosaConvAttribute *>(op->GetAttribute());
  mlir::DenseI64ArrayAttr pad = BuildDenseI64ArrayAttr(op_builder, attr->pad());
  mlir::DenseI64ArrayAttr stride =
      BuildDenseI64ArrayAttr(op_builder, attr->stride());
  mlir::DenseI64ArrayAttr dilation =
      BuildDenseI64ArrayAttr(op_builder, attr->dilation());
  auto input_zp = attr->input_zp();
  auto weight_zp = attr->weight_zp();
  bool local_bound = attr->local_bound();
  auto acc_type = AccDType2Type(op_builder, attr->acc_type());

  // input_zp/weight_zp is not allowed for float type
  mlir::Operation *mlir_op;
  if (output_type.getElementType().isa<mlir::FloatType>()) {
    assert(input_zp == 0 && weight_zp == 0);
  }

  auto input_zp_attr = op_builder->getI32IntegerAttr(input_zp);
  auto weight_zp_attr = op_builder->getI32IntegerAttr(weight_zp);
  mlir_op = op_builder->create<MLIR_OP>(
      loc, output_type, input0_val, input1_val, input2_val, pad, stride,
      dilation, acc_type,
      input_zp_attr, weight_zp_attr, local_bound);

  block->push_back(mlir_op);
  return std::vector<mlir::Value>({mlir_op->getResult(0)});
}

// Stamp out build<> specializations that forward to BuildConvOp.
#define BUILD_OP_CONV(MLIR_OP_NAME, SCHEMA_OP_NAME)                            \
  template <>                                                                  \
  std::vector<mlir::Value>                                                     \
  TosaMlirOperatorBuilder::build<Op_##SCHEMA_OP_NAME>(                         \
      TosaSerializationOperator * op) const {                                  \
    return BuildConvOp<mlir::tosa::MLIR_OP_NAME##Op>(op);                      \
  }

BUILD_OP_CONV(Conv2D, CONV2D)
BUILD_OP_CONV(Conv3D, CONV3D)
BUILD_OP_CONV(DepthwiseConv2D, DEPTHWISE_CONV2D)

// TRANSPOSE_CONV2D: like BuildConvOp but with out_pad and no dilation.
template <>
std::vector<mlir::Value> TosaMlirOperatorBuilder::build<Op_TRANSPOSE_CONV2D>(
    TosaSerializationOperator *op) const {
  mlir::Value input0_val = tensor_map->at(op->GetInputTensorNames()[0]);
  mlir::Value input1_val = tensor_map->at(op->GetInputTensorNames()[1]);
  mlir::Value input2_val = tensor_map->at(op->GetInputTensorNames()[2]);
  mlir::RankedTensorType output_type =
      tensor_type_map->at(op->GetOutputTensorNames()[0]);

  assert(op->GetAttributeType() ==
         Attribute_TransposeConvAttribute); // double check attribute type
  TosaTransposeConvAttribute *attr =
      static_cast<TosaTransposeConvAttribute *>(op->GetAttribute());
  mlir::DenseI64ArrayAttr out_pad =
      BuildDenseI64ArrayAttr(op_builder, attr->out_pad());
  mlir::DenseI64ArrayAttr stride =
      BuildDenseI64ArrayAttr(op_builder, attr->stride());
  auto input_zp = attr->input_zp();
  auto weight_zp = attr->weight_zp();
  bool local_bound = attr->local_bound();
  auto acc_type = AccDType2Type(op_builder, attr->acc_type());

  // input_zp/weight_zp is not allowed for float type
  mlir::Operation *mlir_op;
  if (output_type.getElementType().isa<mlir::FloatType>()) {
    assert(input_zp == 0 && weight_zp == 0);
  }

  auto input_zp_attr = op_builder->getI32IntegerAttr(input_zp);
  auto weight_zp_attr = op_builder->getI32IntegerAttr(weight_zp);

  mlir_op = op_builder->create<mlir::tosa::TransposeConv2DOp>(
      loc, output_type, input0_val, input1_val, input2_val, out_pad, stride,
      acc_type, input_zp_attr, weight_zp_attr, local_bound);

  block->push_back(mlir_op);
  return std::vector<mlir::Value>({mlir_op->getResult(0)});
}

// FULLY_CONNECTED: zero points are only passed for the quantized
// (non-float) form of the op.
template <>
std::vector<mlir::Value> TosaMlirOperatorBuilder::build<Op_FULLY_CONNECTED>(
    TosaSerializationOperator *op) const {
  mlir::Value input0_val = tensor_map->at(op->GetInputTensorNames()[0]);
  mlir::Value input1_val = tensor_map->at(op->GetInputTensorNames()[1]);
  mlir::Value input2_val = tensor_map->at(op->GetInputTensorNames()[2]);
  mlir::RankedTensorType output_type =
      tensor_type_map->at(op->GetOutputTensorNames()[0]);

  assert(op->GetAttributeType() ==
         Attribute_FullyConnectedAttribute); // double check attribute type
  TosaFullyConnectedAttribute *attr =
      static_cast<TosaFullyConnectedAttribute *>(op->GetAttribute());
  auto input_zp = attr->input_zp();
  auto weight_zp = attr->weight_zp();

  // input_zp/weight_zp is not allowed for float type
  mlir::Operation *mlir_op;
  if (output_type.getElementType().isa<mlir::FloatType>()) {
    assert(input_zp == 0 && weight_zp == 0);
    mlir_op = op_builder->create<mlir::tosa::FullyConnectedOp>(
        loc, output_type, input0_val, input1_val, input2_val);
  } else {
    auto input_zp_attr = op_builder->getI32IntegerAttr(input_zp);
    auto weight_zp_attr = op_builder->getI32IntegerAttr(weight_zp);
    mlir_op = op_builder->create<mlir::tosa::FullyConnectedOp>(
        loc, output_type, input0_val, input1_val, input2_val, input_zp_attr,
        weight_zp_attr);
  }
  block->push_back(mlir_op);
  return std::vector<mlir::Value>({mlir_op->getResult(0)});
}

// MATMUL: a_zp/b_zp attributes are only emitted when non-zero.
template <>
std::vector<mlir::Value>
TosaMlirOperatorBuilder::build<Op_MATMUL>(TosaSerializationOperator *op) const {
  mlir::Value input0_val = tensor_map->at(op->GetInputTensorNames()[0]);
  mlir::Value input1_val = tensor_map->at(op->GetInputTensorNames()[1]);
  mlir::RankedTensorType output_type =
      tensor_type_map->at(op->GetOutputTensorNames()[0]);

  assert(op->GetAttributeType() ==
         Attribute_MatMulAttribute); // double check attribute type
  TosaMatMulAttribute *attr =
      static_cast<TosaMatMulAttribute *>(op->GetAttribute());
  auto A_zp = attr->a_zp();
  auto B_zp = attr->b_zp();

  mlir::Operation *mlir_op;
  if (A_zp == 0 && B_zp == 0) {
    mlir_op = op_builder->create<mlir::tosa::MatMulOp>(loc, output_type,
                                                       input0_val, input1_val);
  } else {
    auto a_zp_attr = op_builder->getI32IntegerAttr(A_zp);
    auto b_zp_attr = op_builder->getI32IntegerAttr(B_zp);
    mlir_op = op_builder->create<mlir::tosa::MatMulOp>(
        loc, output_type, input0_val, input1_val, a_zp_attr, b_zp_attr);
  }
  block->push_back(mlir_op);
  return std::vector<mlir::Value>({mlir_op->getResult(0)});
}

// SELECT: (pred, on_true, on_false) -> output; no attribute.
template <>
std::vector<mlir::Value>
TosaMlirOperatorBuilder::build<Op_SELECT>(TosaSerializationOperator *op) const {
  mlir::Value input0_val = tensor_map->at(op->GetInputTensorNames()[0]);
  mlir::Value input1_val = tensor_map->at(op->GetInputTensorNames()[1]);
  mlir::Value input2_val = tensor_map->at(op->GetInputTensorNames()[2]);
  mlir::RankedTensorType output_type =
      tensor_type_map->at(op->GetOutputTensorNames()[0]);

  assert(op->GetAttributeType() ==
         Attribute_NONE); // double check that there is no attribute

  mlir::Operation *mlir_op = op_builder->create<mlir::tosa::SelectOp>(
      loc, output_type, input0_val, input1_val, input2_val);
  block->push_back(mlir_op);
  return std::vector<mlir::Value>({mlir_op->getResult(0)});
}

// CLAMP: min/max arrive as serialized bytes; decode one scalar of the
// element type (storage type for quantized tensors) for each bound and
// attach them as typed min/max attributes.
template <>
std::vector<mlir::Value>
TosaMlirOperatorBuilder::build<Op_CLAMP>(TosaSerializationOperator *op) const {
  mlir::Value input_val = tensor_map->at(op->GetInputTensorNames()[0]);
  mlir::RankedTensorType output_type =
      tensor_type_map->at(op->GetOutputTensorNames()[0]);

  assert(op->GetAttributeType() ==
         Attribute_ClampAttribute); // double check attribute type
  TosaClampAttribute *attr =
      static_cast<TosaClampAttribute *>(op->GetAttribute());

  mlir::Type element_type =
      llvm::cast<mlir::ShapedType>(input_val.getType()).getElementType();
  if (auto quantType =
          llvm::dyn_cast<mlir::quant::UniformQuantizedType>(element_type)) {
    element_type = quantType.getStorageType();
  }

  auto element_const_type = mlir::RankedTensorType::get({1}, element_type);
  auto min_values_attr = GetConstAttr(attr->min_val(), element_const_type, 1);
  auto max_values_attr = GetConstAttr(attr->max_val(), element_const_type, 1);

  mlir::Attribute min_val_attr, max_val_attr;
  if (element_type.isa<mlir::FloatType>()) {
    min_val_attr = op_builder->getFloatAttr(
        element_type, min_values_attr.getValues<mlir::APFloat>()[0]);
    max_val_attr = op_builder->getFloatAttr(
        element_type, max_values_attr.getValues<mlir::APFloat>()[0]);
  } else {
    min_val_attr = op_builder->getIntegerAttr(
        element_type, min_values_attr.getValues<mlir::APInt>()[0]);
    max_val_attr = op_builder->getIntegerAttr(
        element_type, max_values_attr.getValues<mlir::APInt>()[0]);
  }

  auto mlir_op = op_builder->create<mlir::tosa::ClampOp>(
      loc, output_type, input_val, min_val_attr, max_val_attr);
  block->push_back(mlir_op);
  return std::vector<mlir::Value>({mlir_op->getResult(0)});
}

// ArgMax has a single input and a single I32 axis attribute
// (BuildReductionOp emits the axis via getI32IntegerAttr).
BUILD_OP_REDUCTION(ArgMax, ARGMAX)

// CONCAT: variadic inputs concatenated along `axis`.
template <>
std::vector<mlir::Value>
TosaMlirOperatorBuilder::build<Op_CONCAT>(TosaSerializationOperator *op) const {
  mlir::RankedTensorType output_type =
      tensor_type_map->at(op->GetOutputTensorNames()[0]);

  llvm::SmallVector<mlir::Value> input_values;
  for (auto &input_name : op->GetInputTensorNames()) {
    mlir::Value input_val = tensor_map->at(input_name);
    input_values.push_back(input_val);
  }

  assert(op->GetAttributeType() ==
         Attribute_AxisAttribute); // double check attribute type
  TosaAxisAttribute *attr =
      static_cast<TosaAxisAttribute *>(op->GetAttribute());
  auto axis = op_builder->getI32IntegerAttr(attr->axis());

  mlir::Operation *mlir_op =
      op_builder->create<mlir::tosa::ConcatOp>(loc, output_type, input_values,
                                               axis);
  block->push_back(mlir_op);
  return std::vector<mlir::Value>({mlir_op->getResult(0)});
}

// CONCAT_SHAPE: variadic !tosa.shape concatenation; no attribute.
template <>
std::vector<mlir::Value> TosaMlirOperatorBuilder::build<Op_CONCAT_SHAPE>(
    TosaSerializationOperator *op) const {
  mlir::tosa::shapeType output_type =
      shape_type_map->at(op->GetOutputTensorNames()[0]);

  llvm::SmallVector<mlir::Value> input_values;
  for (auto &input_name : op->GetInputTensorNames()) {
    mlir::Value input_val = tensor_map->at(input_name);
    input_values.push_back(input_val);
  }

  mlir::Operation *mlir_op = op_builder->create<mlir::tosa::ConcatShapeOp>(
      loc, output_type, input_values);
  block->push_back(mlir_op);
  return std::vector<mlir::Value>({mlir_op->getResult(0)});
}

// NEGATE: input/output zero points are only emitted when non-zero.
template <>
std::vector<mlir::Value>
TosaMlirOperatorBuilder::build<Op_NEGATE>(TosaSerializationOperator *op) const {
  mlir::Value input_val = tensor_map->at(op->GetInputTensorNames()[0]);
  mlir::RankedTensorType output_type =
      tensor_type_map->at(op->GetOutputTensorNames()[0]);

  assert(op->GetAttributeType() ==
         Attribute_NegateAttribute); // double check attribute type
  TosaNegateAttribute *attr =
      static_cast<TosaNegateAttribute *>(op->GetAttribute());

  auto input_zp = attr->input1_zp();
  auto output_zp = attr->output_zp();

  mlir::Operation *mlir_op;
  if (input_zp == 0 && output_zp == 0) {
    mlir_op =
        op_builder->create<mlir::tosa::NegateOp>(loc, output_type, input_val);
  } else {
    auto input_zp_attr = op_builder->getI32IntegerAttr(input_zp);
    auto output_zp_attr = op_builder->getI32IntegerAttr(output_zp);
    mlir_op = op_builder->create<mlir::tosa::NegateOp>(
        loc, output_type, input_val, input_zp_attr, output_zp_attr);
  }

  block->push_back(mlir_op);
  return std::vector<mlir::Value>({mlir_op->getResult(0)});
}

// RESHAPE: the target shape is the second (shape-typed) operand.
template <>
std::vector<mlir::Value> TosaMlirOperatorBuilder::build<Op_RESHAPE>(
    TosaSerializationOperator *op) const {
  mlir::Value input_val = tensor_map->at(op->GetInputTensorNames()[0]);
  mlir::RankedTensorType output_type =
      tensor_type_map->at(op->GetOutputTensorNames()[0]);
  mlir::Value shape_val = tensor_map->at(op->GetInputTensorNames()[1]);

  mlir::Operation *mlir_op = op_builder->create<mlir::tosa::ReshapeOp>(
      loc, output_type, input_val, shape_val);
  block->push_back(mlir_op);
  return std::vector<mlir::Value>({mlir_op->getResult(0)});
}

// PAD: a pad_const operand is only materialized when the serialized
// pad_const bytes contain a non-zero value.
template <>
std::vector<mlir::Value>
TosaMlirOperatorBuilder::build<Op_PAD>(TosaSerializationOperator *op) const {
  mlir::Value input_val = tensor_map->at(op->GetInputTensorNames()[0]);
  mlir::Value padding_val = tensor_map->at(op->GetInputTensorNames()[1]);
  // NOTE(review): input_type is unused below — confirm whether it can be
  // dropped.
  mlir::RankedTensorType input_type =
      tensor_type_map->at(op->GetInputTensorNames()[0]);
  mlir::RankedTensorType output_type =
      tensor_type_map->at(op->GetOutputTensorNames()[0]);
  const auto element_type =
      input_val.getType().cast<mlir::ShapedType>().getElementType();

  assert(op->GetAttributeType() ==
         Attribute_PadAttribute); // double check attribute type
  TosaPadAttribute *attr = static_cast<TosaPadAttribute *>(op->GetAttribute());
  const auto &pad_const_u8_data = attr->pad_const();

  // check for any value in pad_const_u8_data
  bool has_pad_const = false;
  for (auto v : pad_const_u8_data) {
    if (v != 0) {
      has_pad_const = true;
      break;
    }
  }
  if (!has_pad_const) {
    // handle the cases where no explicit pad_const input.
    auto mlir_op = op_builder->create<mlir::tosa::PadOp>(
        loc, output_type, input_val, padding_val);
    block->push_back(mlir_op);
    return std::vector<mlir::Value>({mlir_op->getResult(0)});
  }

  // has pad const - create a const op for pad_const input
  auto pad_const_type = mlir::RankedTensorType::get({}, element_type);
  auto pad_const_attr = GetConstAttr(pad_const_u8_data, pad_const_type, 1);

  auto pad_const_op = op_builder->create<mlir::tosa::ConstOp>(
      loc, pad_const_type, pad_const_attr);

  block->push_back(pad_const_op);
  mlir::Value pad_const_value = pad_const_op->getResult(0);

  auto mlir_op = op_builder->create<mlir::tosa::PadOp>(
      loc, output_type, input_val, padding_val, pad_const_value);

  block->push_back(mlir_op);
  return std::vector<mlir::Value>({mlir_op->getResult(0)});
}

// DIM: single input plus an axis attribute.
template <>
std::vector<mlir::Value>
TosaMlirOperatorBuilder::build<Op_DIM>(TosaSerializationOperator *op) const {
  mlir::Value input_val = tensor_map->at(op->GetInputTensorNames()[0]);
  mlir::RankedTensorType output_type =
      tensor_type_map->at(op->GetOutputTensorNames()[0]);

  assert(op->GetAttributeType() ==
         Attribute_AxisAttribute); // double check attribute type
  TosaAxisAttribute *attr =
      static_cast<TosaAxisAttribute *>(op->GetAttribute());
  auto axis = op_builder->getI32IntegerAttr(attr->axis());

  mlir::Operation *mlir_op =
      op_builder->create<mlir::tosa::DimOp>(loc, output_type, input_val, axis);
  block->push_back(mlir_op);
  return std::vector<mlir::Value>({mlir_op->getResult(0)});
}

// TRANSPOSE: the serialized perms attribute becomes an i32 const tensor
// operand feeding tosa.transpose.
template <>
std::vector<mlir::Value> TosaMlirOperatorBuilder::build<Op_TRANSPOSE>(
    TosaSerializationOperator *op) const {
  mlir::Value input_val = tensor_map->at(op->GetInputTensorNames()[0]);
  mlir::RankedTensorType output_type =
      tensor_type_map->at(op->GetOutputTensorNames()[0]);

  assert(op->GetAttributeType() ==
         Attribute_TransposeAttribute); // double check attribute type
  TosaTransposeAttribute *attr =
      static_cast<TosaTransposeAttribute *>(op->GetAttribute());
  // make a constant op from attr->perms values, of type: { shape = { perms.size
  // }, element_type = I32 }
  const auto perms_values = attr->perms();
  auto const_type = mlir::RankedTensorType::get(
      {static_cast<int64_t>(perms_values.size())}, op_builder->getI32Type());
  mlir::DenseElementsAttr const_attr =
      BuildDenseI32ElementsAttr(op_builder, const_type, perms_values);
  mlir::Operation *mlir_const_op =
      op_builder->create<mlir::tosa::ConstOp>(loc, const_type, const_attr);
  auto perms_val = mlir_const_op->getResult(0);

  mlir::Operation *mlir_op = op_builder->create<mlir::tosa::TransposeOp>(
      loc, output_type, input_val, perms_val);

  block->push_back(mlir_const_op);
  block->push_back(mlir_op);
  return std::vector<mlir::Value>({mlir_op->getResult(0)});
}

// SLICE: start/size are shape-typed operands.
template <>
std::vector<mlir::Value>
TosaMlirOperatorBuilder::build<Op_SLICE>(TosaSerializationOperator *op) const {
  mlir::Value input_val = tensor_map->at(op->GetInputTensorNames()[0]);
  mlir::RankedTensorType output_type =
      tensor_type_map->at(op->GetOutputTensorNames()[0]);

  mlir::Value start = tensor_map->at(op->GetInputTensorNames()[1]);
  mlir::Value size = tensor_map->at(op->GetInputTensorNames()[2]);

  mlir::Operation *mlir_op = op_builder->create<mlir::tosa::SliceOp>(
      loc, output_type, input_val, start, size);
  block->push_back(mlir_op);
  return std::vector<mlir::Value>({mlir_op->getResult(0)});
}

// TILE: multiples is the second operand; no attribute.
template <>
std::vector<mlir::Value>
TosaMlirOperatorBuilder::build<Op_TILE>(TosaSerializationOperator *op) const {
  mlir::Value input_val = tensor_map->at(op->GetInputTensorNames()[0]);
  mlir::Value multiples = tensor_map->at(op->GetInputTensorNames()[1]);
  mlir::RankedTensorType output_type =
      tensor_type_map->at(op->GetOutputTensorNames()[0]);

  assert(op->GetAttributeType() ==
         Attribute_NONE); // double check attribute type

  mlir::Operation *mlir_op = op_builder->create<mlir::tosa::TileOp>(
      loc,
      output_type, input_val, multiples);
  block->push_back(mlir_op);
  return std::vector<mlir::Value>({mlir_op->getResult(0)});
}

// Gather is a binary op
BUILD_OP_ELEMENTWISE_BINARY(Gather, GATHER)

// SCATTER: three tensor operands; no attribute.
template <>
std::vector<mlir::Value> TosaMlirOperatorBuilder::build<Op_SCATTER>(
    TosaSerializationOperator *op) const {
  mlir::Value input0_val = tensor_map->at(op->GetInputTensorNames()[0]);
  mlir::Value input1_val = tensor_map->at(op->GetInputTensorNames()[1]);
  mlir::Value input2_val = tensor_map->at(op->GetInputTensorNames()[2]);
  mlir::RankedTensorType output_type =
      tensor_type_map->at(op->GetOutputTensorNames()[0]);
  assert(op->GetAttributeType() ==
         Attribute_NONE); // double check that there is no attribute

  mlir::Operation *mlir_op = op_builder->create<mlir::tosa::ScatterOp>(
      loc, output_type, input0_val, input1_val, input2_val);
  block->push_back(mlir_op);
  return std::vector<mlir::Value>({mlir_op->getResult(0)});
}

// RESIZE: scale/offset/border are operands; mode becomes a string attr.
template <>
std::vector<mlir::Value>
TosaMlirOperatorBuilder::build<Op_RESIZE>(TosaSerializationOperator *op) const {
  mlir::Value input_val = tensor_map->at(op->GetInputTensorNames()[0]);
  mlir::Value scale_val = tensor_map->at(op->GetInputTensorNames()[1]);
  mlir::Value offset_val = tensor_map->at(op->GetInputTensorNames()[2]);
  mlir::Value border_val = tensor_map->at(op->GetInputTensorNames()[3]);
  mlir::RankedTensorType output_type =
      tensor_type_map->at(op->GetOutputTensorNames()[0]);

  assert(op->GetAttributeType() ==
         Attribute_ResizeAttribute); // double check attribute type
  TosaResizeAttribute *attr =
      static_cast<TosaResizeAttribute *>(op->GetAttribute());

  auto mode = op_builder->getStringAttr(ResizeEnum2Str(attr->mode()));

  mlir::Operation *mlir_op = op_builder->create<mlir::tosa::ResizeOp>(
      loc, output_type, input_val, scale_val, offset_val, border_val, mode);
  block->push_back(mlir_op);
  return std::vector<mlir::Value>({mlir_op->getResult(0)});
}

// Reverse has single input, and single I32 axis attribute
// (BuildReductionOp emits the axis via getI32IntegerAttr).
BUILD_OP_REDUCTION(Reverse, REVERSE)

// MUL: shift is the third operand.
template <>
std::vector<mlir::Value>
TosaMlirOperatorBuilder::build<Op_MUL>(TosaSerializationOperator *op) const {
  mlir::Value input0_val = tensor_map->at(op->GetInputTensorNames()[0]);
  mlir::Value input1_val = tensor_map->at(op->GetInputTensorNames()[1]);
  mlir::Value shift_val = tensor_map->at(op->GetInputTensorNames()[2]);

  mlir::RankedTensorType output_type =
      tensor_type_map->at(op->GetOutputTensorNames()[0]);

  assert(op->GetAttributeType() ==
         Attribute_MulAttribute); // double check attribute type

  mlir::Operation *mlir_op = op_builder->create<mlir::tosa::MulOp>(
      loc, output_type, input0_val, input1_val, shift_val);
  block->push_back(mlir_op);
  return std::vector<mlir::Value>({mlir_op->getResult(0)});
}

// ARITHMETIC_RIGHT_SHIFT: binary op plus a `round` bool attribute.
template <>
std::vector<mlir::Value>
TosaMlirOperatorBuilder::build<Op_ARITHMETIC_RIGHT_SHIFT>(
    TosaSerializationOperator *op) const {
  mlir::Value input0_val = tensor_map->at(op->GetInputTensorNames()[0]);
  mlir::Value input1_val = tensor_map->at(op->GetInputTensorNames()[1]);
  mlir::RankedTensorType output_type =
      tensor_type_map->at(op->GetOutputTensorNames()[0]);

  assert(
      op->GetAttributeType() ==
      Attribute_ArithmeticRightShiftAttribute); // double check attribute type
  TosaArithmeticRightShiftAttribute *attr =
      static_cast<TosaArithmeticRightShiftAttribute *>(op->GetAttribute());

  auto round = op_builder->getBoolAttr(attr->round());

  mlir::Operation *mlir_op =
      op_builder->create<mlir::tosa::ArithmeticRightShiftOp>(
          loc, output_type, input0_val, input1_val, round);
  block->push_back(mlir_op);
  return std::vector<mlir::Value>({mlir_op->getResult(0)});
}

// TABLE: the table content becomes an i8 (signed-8 mode) or i16
// (signed-16 mode) const operand, selected by the input element type.
template <>
std::vector<mlir::Value>
TosaMlirOperatorBuilder::build<Op_TABLE>(TosaSerializationOperator *op) const {
  mlir::Value input_val = tensor_map->at(op->GetInputTensorNames()[0]);
  mlir::RankedTensorType output_type =
      tensor_type_map->at(op->GetOutputTensorNames()[0]);

  assert(op->GetAttributeType() ==
         Attribute_TableAttribute); // double check attribute type
  TosaTableAttribute *attr =
      static_cast<TosaTableAttribute *>(op->GetAttribute());

  // create a const op for table value attribute
  const auto table_values = attr->table();
  mlir::RankedTensorType const_type;
  mlir::DenseElementsAttr const_attr;
  const auto input_element_type =
      input_val.getType().cast<mlir::ShapedType>().getElementType();
  if (input_element_type.isInteger(8)) {
    // table is signed 8 mode
    const_type = mlir::RankedTensorType::get(
        {static_cast<int64_t>(table_values.size())}, op_builder->getI8Type());
    const_attr = BuildDenseI8ElementsAttr(op_builder, table_values);
  } else {
    // table is signed 16 mode
    const_type = mlir::RankedTensorType::get(
        {static_cast<int64_t>(table_values.size())}, op_builder->getI16Type());
    const_attr = BuildDenseI16ElementsAttr(op_builder, table_values);
  }
  mlir::Operation *mlir_const_op =
      op_builder->create<mlir::tosa::ConstOp>(loc, const_type, const_attr);
  auto table_value = mlir_const_op->getResult(0);

  mlir::Operation *mlir_op = op_builder->create<mlir::tosa::TableOp>(
      loc, output_type, input_val, table_value);
  block->push_back(mlir_const_op);
  block->push_back(mlir_op);
  return std::vector<mlir::Value>({mlir_op->getResult(0)});
}

// RESCALE: multiplier/shift are operands; zero points and the
// rounding/signedness flags come from the attribute.
template <>
std::vector<mlir::Value> TosaMlirOperatorBuilder::build<Op_RESCALE>(
    TosaSerializationOperator *op) const {
  mlir::Value input_val = tensor_map->at(op->GetInputTensorNames()[0]);
  mlir::Value multiplier_val = tensor_map->at(op->GetInputTensorNames()[1]);
  mlir::Value shift_val = tensor_map->at(op->GetInputTensorNames()[2]);
  mlir::RankedTensorType output_type =
      tensor_type_map->at(op->GetOutputTensorNames()[0]);

  assert(op->GetAttributeType() ==
         Attribute_RescaleAttribute); // double check attribute type
  TosaRescaleAttribute *attr =
      static_cast<TosaRescaleAttribute *>(op->GetAttribute());

  auto input_zp =
op_builder->getI32IntegerAttr(attr->input_zp()); + auto output_zp = op_builder->getI32IntegerAttr(attr->output_zp()); + + auto scale32 = op_builder->getBoolAttr(attr->scale32()); + auto double_round = op_builder->getBoolAttr(attr->double_round()); + auto per_channel = op_builder->getBoolAttr(attr->per_channel()); + + auto input_unsigned = op_builder->getBoolAttr(attr->input_unsigned()); + auto output_unsigned = op_builder->getBoolAttr(attr->output_unsigned()); + + mlir::Operation *mlir_op = op_builder->create<mlir::tosa::RescaleOp>( + loc, output_type, input_val, multiplier_val, shift_val, input_zp, + output_zp, scale32, double_round, per_channel, input_unsigned, + output_unsigned); + block->push_back(mlir_op); + return std::vector<mlir::Value>({mlir_op->getResult(0)}); +} + +template <> +std::vector<mlir::Value> +TosaMlirOperatorBuilder::build<Op_CUSTOM>(TosaSerializationOperator *op) const { + mlir::Value input_val = tensor_map->at(op->GetInputTensorNames()[0]); + mlir::RankedTensorType output_type = + tensor_type_map->at(op->GetOutputTensorNames()[0]); + + assert(op->GetAttributeType() == + Attribute_CustomAttribute); // double check attribute type + TosaCustomAttribute *attr = + static_cast<TosaCustomAttribute *>(op->GetAttribute()); + + auto operator_name = op_builder->getStringAttr(attr->operator_name()); + auto domain_name = op_builder->getStringAttr(attr->domain_name()); + std::string impl_str; + impl_str.resize(attr->implementation_attrs().size() + 1); + int idx = 0; + for (auto c : attr->implementation_attrs()) { + impl_str[idx++] = c; + } + auto impl = op_builder->getStringAttr(impl_str); + + mlir::Operation *mlir_op = op_builder->create<mlir::tosa::CustomOp>( + loc, output_type, operator_name, domain_name, impl, input_val); + block->push_back(mlir_op); + return std::vector<mlir::Value>({mlir_op->getResult(0)}); +} + +template <> +std::vector<mlir::Value> +TosaMlirOperatorBuilder::build<Op_RFFT2D>(TosaSerializationOperator *op) const { + mlir::Value 
      input_val = tensor_map->at(op->GetInputTensorNames()[0]);
  mlir::RankedTensorType output0_type =
      tensor_type_map->at(op->GetOutputTensorNames()[0]);
  mlir::RankedTensorType output1_type =
      tensor_type_map->at(op->GetOutputTensorNames()[1]);
  assert(op->GetAttributeType() ==
         Attribute_RFFTAttribute); // double check attribute type
  TosaRFFTAttribute *attr =
      static_cast<TosaRFFTAttribute *>(op->GetAttribute());

  bool local_bound = attr->local_bound();

  mlir::Operation *mlir_op = op_builder->create<mlir::tosa::RFFT2dOp>(
      loc, output0_type, output1_type, input_val, local_bound);
  block->push_back(mlir_op);
  return std::vector<mlir::Value>(
      {mlir_op->getResult(0), mlir_op->getResult(1)});
}

// FFT2D: two inputs (real, imag) and two results; inverse/local_bound
// flags from the attribute.
template <>
std::vector<mlir::Value>
TosaMlirOperatorBuilder::build<Op_FFT2D>(TosaSerializationOperator *op) const {
  mlir::Value input0_val = tensor_map->at(op->GetInputTensorNames()[0]);
  mlir::Value input1_val = tensor_map->at(op->GetInputTensorNames()[1]);
  mlir::RankedTensorType output0_type =
      tensor_type_map->at(op->GetOutputTensorNames()[0]);
  mlir::RankedTensorType output1_type =
      tensor_type_map->at(op->GetOutputTensorNames()[1]);

  assert(op->GetAttributeType() == Attribute_FFTAttribute);
  TosaFFTAttribute *attr = static_cast<TosaFFTAttribute *>(op->GetAttribute());
  auto inverse = op_builder->getBoolAttr(attr->inverse());
  auto local_bound = op_builder->getBoolAttr(attr->local_bound());

  mlir::Operation *mlir_op = op_builder->create<mlir::tosa::FFT2dOp>(
      loc, output0_type, output1_type, input0_val, input1_val, inverse,
      local_bound);
  block->push_back(mlir_op);
  return std::vector<mlir::Value>(
      {mlir_op->getResult(0), mlir_op->getResult(1)});
}

// Builds the MLIR blocks of one serialized region.  A non-isolated child
// region starts from a copy of its parent's tensor_map so values defined
// in the enclosing scope remain resolvable.
class TosaMlirRegionBuilder {
public:
  TosaMlirRegionBuilder(TosaSerializationRegion *_ser_region,
                        TosaSerializationHandler *_tsh, mlir::Region *_region,
                        mlir::OpBuilder *_op_builder, mlir::Location _loc,
                        TosaMlirRegionBuilder *_parent_value_scope = nullptr)
      : ser_region(_ser_region), tsh(_tsh), region(_region),
        op_builder(_op_builder), loc(_loc) {
    if (_parent_value_scope) {
      // inherit parent_value_scope's tensor_map
      for (auto &kv : _parent_value_scope->GetTensorMap()) {
        tensor_map.insert(kv);
      }
    }
  }

  mlir::LogicalResult
  BuildAllBlocksInRegion(std::vector<mlir::Value> &return_values);

  mlir::OpBuilder *GetOpBuilder() { return op_builder; }
  mlir::Location GetLocation() { return loc; }
  std::unordered_map<std::string, mlir::Value> &GetTensorMap() {
    return tensor_map;
  }
  TosaSerializationHandler *GetTsh() const { return tsh; }

private:
  mlir::Region *region;
  TosaSerializationRegion *ser_region;
  TosaSerializationHandler *tsh;
  mlir::OpBuilder *op_builder;
  mlir::Location loc;
  // tensor name -> SSA value for every tensor materialized in this region
  std::unordered_map<std::string, mlir::Value> tensor_map;
};

// Builds all operators of one serialized basic block; builder, location
// and tensor map are delegated to the owning region builder.
class TosaMlirBlockBuilder {
public:
  TosaMlirBlockBuilder(TosaSerializationBasicBlock *_ser_block,
                       TosaMlirRegionBuilder *_region_builder,
                       mlir::Block *_block)
      : ser_block(_ser_block), region_builder(_region_builder), block(_block) {}

  mlir::LogicalResult
  BuildAllOpsInBlock(std::vector<mlir::Value> &return_values);

  mlir::OpBuilder *GetOpBuilder() { return region_builder->GetOpBuilder(); }
  mlir::Location GetLocation() { return region_builder->GetLocation(); }
  std::unordered_map<std::string, mlir::Value> &GetTensorMap() {
    return region_builder->GetTensorMap();
  }

  TosaSerializationHandler *GetTsh() const { return region_builder->GetTsh(); }
  TosaMlirRegionBuilder *GetRegionBuilder() const { return region_builder; }

private:
  TosaSerializationBasicBlock *ser_block;
  TosaMlirRegionBuilder *region_builder;
  mlir::Block *block;
  // per-block caches of tensor types keyed by tensor name
  std::unordered_map<std::string, mlir::RankedTensorType> tensor_type_map;
  std::unordered_map<std::string, mlir::tosa::shapeType> shape_type_map;
  std::unordered_set<std::string> unranked_tensors;
};

TosaSerializationHandler *TosaMlirOperatorBuilder::GetTsh() const {
  return
block_builder->GetTsh(); +} + +TosaMlirRegionBuilder *TosaMlirOperatorBuilder::GetRegionBuilder() const { + return block_builder->GetRegionBuilder(); +} + +// build control flow ops: + +namespace { + +mlir::LogicalResult +BuildRegion(TosaSerializationRegion *ser_region, TosaSerializationHandler *tsh, + mlir::Region *mlir_region, mlir::OpBuilder *op_builder, + mlir::Location loc, std::vector<mlir::Value> &return_values, + bool isolated_from_above = false, + TosaMlirRegionBuilder *parent_region_builder = nullptr) { + TosaMlirRegionBuilder *parent_value_scope = + isolated_from_above ? nullptr : parent_region_builder; + TosaMlirRegionBuilder region_builder(ser_region, tsh, mlir_region, op_builder, + loc, parent_value_scope); + return region_builder.BuildAllBlocksInRegion(return_values); +} + +} // namespace + +template <> +std::vector<mlir::Value> TosaMlirOperatorBuilder::build<Op_COND_IF>( + TosaSerializationOperator *op) const { + mlir::Value cond_val = tensor_map->at(op->GetInputTensorNames().at(0)); + std::vector<mlir::Value> input_values; + for (auto idx = 1u; idx < op->GetInputTensorNames().size(); idx++) { + input_values.push_back(tensor_map->at(op->GetInputTensorNames().at(idx))); + } + std::vector<mlir::Type> output_types; + for (auto &name : op->GetInputTensorNames()) { + output_types.push_back(tensor_type_map->at(name)); + } + + assert(op->GetAttributeType() == + Attribute_CondIfAttribute); // double check attribute type + TosaCondIfAttribute *attr = + static_cast<TosaCondIfAttribute *>(op->GetAttribute()); + auto ser_then_region = GetTsh()->GetRegionByName(attr->then_graph()); + auto ser_else_region = GetTsh()->GetRegionByName(attr->else_graph()); + + if (!ser_then_region || !ser_else_region) { + llvm::errs() << "ERROR: " << get_string(op) + << " region serialization hasn't been implemented\n"; + return {}; + } + + mlir::Operation *mlir_op = op_builder->create<mlir::tosa::IfOp>( + loc, output_types, cond_val, input_values); + + const bool 
isolated_from_above = + mlir_op->hasTrait<mlir::OpTrait::IsIsolatedFromAbove>(); + mlir::Region &then_region = mlir_op->getRegion(0); + mlir::Region &else_region = mlir_op->getRegion(1); + + auto curr_region_builder = GetRegionBuilder(); + + std::vector<mlir::Value> then_returns, else_returns; + + if (BuildRegion(ser_then_region, GetTsh(), &then_region, op_builder, loc, + then_returns, isolated_from_above, curr_region_builder) + .failed()) { + return {}; + } + if (then_returns.size() != mlir_op->getNumResults()) { + llvm::errs() + << "ERROR: " << get_string(op) + << " then_region yield.size() doesn't match cond_if's output size\n"; + return {}; + } + + if (BuildRegion(ser_else_region, GetTsh(), &else_region, op_builder, loc, + else_returns, isolated_from_above, curr_region_builder) + .failed()) { + return {}; + } + if (else_returns.size() != mlir_op->getNumResults()) { + llvm::errs() + << "ERROR: " << get_string(op) + << " else_region yield.size() doesn't match cond_if's output size\n"; + return {}; + } + + block->push_back(mlir_op); + return std::vector<mlir::Value>(mlir_op->getResults().begin(), + mlir_op->getResults().end()); +} + +template <> +std::vector<mlir::Value> TosaMlirOperatorBuilder::build<Op_WHILE_LOOP>( + TosaSerializationOperator *op) const { + std::vector<mlir::Value> input_values; + for (auto idx = 0u; idx < op->GetInputTensorNames().size(); idx++) { + input_values.push_back(tensor_map->at(op->GetInputTensorNames().at(idx))); + } + std::vector<mlir::Type> output_types; + for (auto &name : op->GetInputTensorNames()) { + output_types.push_back(tensor_type_map->at(name)); + } + assert(op->GetAttributeType() == + Attribute_WhileLoopAttribute); // double check attribute type + TosaWhileLoopAttribute *attr = + static_cast<TosaWhileLoopAttribute *>(op->GetAttribute()); + auto ser_cond_region = GetTsh()->GetRegionByName(attr->cond_graph()); + auto ser_body_region = GetTsh()->GetRegionByName(attr->body_graph()); + + mlir::Operation *mlir_op = + 
op_builder->create<mlir::tosa::WhileOp>(loc, output_types, input_values); + + const bool isolated_from_above = + mlir_op->hasTrait<mlir::OpTrait::IsIsolatedFromAbove>(); + + mlir::Region &cond_region = mlir_op->getRegion(0); + mlir::Region &body_region = mlir_op->getRegion(1); + + auto curr_region_builder = GetRegionBuilder(); + + std::vector<mlir::Value> cond_returns, body_returns; + + if (BuildRegion(ser_cond_region, GetTsh(), &cond_region, op_builder, loc, + cond_returns, isolated_from_above, curr_region_builder) + .failed()) { + return {}; + } + if (cond_returns.size() != 1) { + llvm::errs() << "ERROR: " << get_string(op) + << " cond_region yield.size() is not 1\n"; + return {}; + } + + if (BuildRegion(ser_body_region, GetTsh(), &body_region, op_builder, loc, + body_returns, isolated_from_above, curr_region_builder) + .failed()) { + return {}; + } + if (body_returns.size() != mlir_op->getNumResults()) { + llvm::errs() + << "ERROR: " << get_string(op) + << " body_region yield.size() doesn't match while_loop's output size\n"; + return {}; + } + + block->push_back(mlir_op); + return std::vector<mlir::Value>(mlir_op->getResults().begin(), + mlir_op->getResults().end()); +} + +mlir::LogicalResult TosaMlirBlockBuilder::BuildAllOpsInBlock( + std::vector<mlir::Value> &return_values) { + block->clear(); + auto loc = GetLocation(); + auto op_builder = GetOpBuilder(); + auto &tensor_map = GetTensorMap(); + + std::unordered_set<TosaSerializationOperator *> operator_built; + std::queue<TosaSerializationOperator *> operator_queue; + + TosaMlirOperatorBuilder tosa_op_builder(op_builder, ser_block, block, loc, + this, &tensor_map, &tensor_type_map, + &shape_type_map); + + for (auto ts : ser_block->GetTensors()) { + if (ts->GetVariable()) { + RegisterVariableTensor(ts); + } + const auto &ts_name = ts->GetName(); + if (ts->GetDtype() == DType::DType_SHAPE) { + // ts is tosa.shape type + auto shape_rank = ts->GetShape()[0]; + shape_type_map[ts_name] = + 
mlir::tosa::shapeType::get(op_builder->getContext(), shape_rank); + continue; + } + mlir::RankedTensorType type; + if (BuildTensorType(op_builder, ts, type).failed()) { + return mlir::failure(); + } + tensor_type_map[ts_name] = type; + if (ts->GetIsUnranked()) { + assert(ts->GetShape().empty()); // unranked tensors should have shape = {} + unranked_tensors.insert(ts_name); + } + } + + // Initialize tensor_map/operator_queue based on block input arguments + for (const std::string &block_input_name : ser_block->GetInputs()) { + mlir::Type type = tensor_type_map[block_input_name]; + if (unranked_tensors.count(block_input_name)) { + // recast type as unranked tensor type + auto element_type = type.cast<mlir::RankedTensorType>().getElementType(); + type = mlir::UnrankedTensorType::get(element_type); + } + auto input_value = block->addArgument(type, loc); + if (tensor_map.count(block_input_name)) { + llvm::errs() << "ERROR: block input tensor " << block_input_name + << " already exists\n"; + return mlir::failure(); + } + tensor_map[block_input_name] = input_value; + } + + for (auto op : ser_block->GetOperators()) { + + // skip if operator has been built + if (operator_built.count(op)) { + // this happens when same input appears twice or more in operator, eg, + // concat(%0, %0) + continue; + } + operator_built.insert(op); + + std::vector<mlir::Value> output_values; + if (IsVariableReadOp(op)) { + output_values = tosa_op_builder.BuildVariableReadOp(op); + } else if (IsVariableWriteOp(op)) { + tosa_op_builder.BuildVariableWriteOp(op); + } +#define DEF_SCHEMA_OPERATOR(SCHEMA_OP_NAME) \ + else if (op->GetOp() == Op_##SCHEMA_OP_NAME) { \ + output_values = tosa_op_builder.build<Op_##SCHEMA_OP_NAME>(op); \ + } +#include "schema_operator.def" +#undef DEF_SCHEMA_OPERATOR + else { + llvm::errs() << "ERROR: unsupported opcode=" << EnumNamesOp()[op->GetOp()] + << "\n"; + return mlir::failure(); + } + + if (IsVariableWriteOp(op)) { + // the sanity checking below does not apply for 
variable write op because + // it has no output tensors whereas the original identity op has + continue; + } + + // Sanity check if number of built mlir::Value is expected + if (op->GetOutputTensorNames().size() != output_values.size()) { + llvm::errs() << "ERROR: number of built mlir::Value is not matching " + "number of operator output tensor\n"; + return mlir::failure(); + } + + for (size_t i = 0; i < output_values.size(); i++) { + // Sanity check tensor hasn't been built + std::string op_output_name = op->GetOutputTensorNames()[i]; + if (tensor_map.count(op_output_name)) { + llvm::errs() << "ERROR: tensor " << op_output_name + << " is already built\n"; + return mlir::failure(); + } + tensor_map[op_output_name] = output_values[i]; + } + } + + // Construct return values + std::vector<mlir::Value> return_operands; + for (const auto &output_name : ser_block->GetOutputs()) { + // Sanity check if terminator mlir::Value is built + if (!tensor_map.count(output_name)) { + llvm::errs() << "ERROR: terminator mlir::Value " << output_name + << " is not built in block " << ser_block->GetName() << "\n"; + return mlir::failure(); + } + mlir::Value output_value = tensor_map.at(output_name); + return_operands.push_back(output_value); + return_values.push_back(output_value); + } + mlir::Operation *terminator_op; + auto parent_op = block->getParentOp(); + if (mlir::isa<mlir::func::FuncOp>(parent_op)) { + terminator_op = + op_builder->create<mlir::func::ReturnOp>(loc, return_operands); + } else { + terminator_op = + op_builder->create<mlir::tosa::YieldOp>(loc, return_operands); + } + block->push_back(terminator_op); + + // need topological sorting? 
+ + return mlir::success(); +} + +mlir::LogicalResult TosaMlirRegionBuilder::BuildAllBlocksInRegion( + std::vector<mlir::Value> &return_values) { + for (auto &ser_block : ser_region->GetBlocks()) { + auto &block = region->emplaceBlock(); + TosaMlirBlockBuilder block_builder(ser_block, this, &block); + + if (block_builder.BuildAllOpsInBlock(return_values).failed()) { + return mlir::failure(); + } + + if (return_values.empty()) { + llvm::errs() << "Warning: graph doesn't have return values\n"; + } + } + return mlir::success(); +} + +mlir::LogicalResult buildTosaMlir(mlir::func::FuncOp &func, + mlir::MLIRContext &context, + tosa::TosaSerializationHandler &tsh, + std::vector<mlir::Value> &main_returns) { + + mlir::Region *main_region = func.getCallableRegion(); + if (!main_region) { + llvm::errs() << "Invalid MLIR: doesn't have valid \"main\" region\n"; + return mlir::failure(); + } + + TosaSerializationRegion *ser_main_region = tsh.GetRegions().front(); + + auto loc = func.getLoc(); + + main_region->takeBody(*main_region); // empty old func body + auto op_builder = mlir::OpBuilder(func.getBody()); + + if (BuildRegion(ser_main_region, &tsh, main_region, &op_builder, loc, + main_returns) + .failed()) { + return mlir::failure(); + } + + if (main_returns.empty()) { + llvm::errs() << "Warning: graph doesn't have return values\n"; + } + + return mlir::success(); +} + +// Load Tosa Schema into TosaSerializationHandler, required for JSON save/load +mlir::LogicalResult loadTosaSchema(tosa::TosaSerializationHandler &tsh) { + const char *tosa_schema = tosa_deserialize_schema.c_str(); + + if (!tosa_schema) { + llvm::errs() << "Flatbuffer schema not defined\n"; + return mlir::failure(); + } + + if (tsh.LoadFileSchema(tosa_schema)) { + llvm::errs() << "Error loading tosa schema file: " << tosa_schema << "\n"; + return mlir::failure(); + } + return mlir::success(); +} + +namespace { + +mlir::NamedAttribute DefaultEntryFuncitonAttr(mlir::Builder &builder, + bool is_input, int count) 
{ + std::string names; + for (int i = 0; i < count; i++) { + std::string name = kDefaultExportedName + "_"; + name += (is_input ? kDefaultInputPrefix : kDefaultOutputPrefix); + name += std::to_string(i) + ":0"; + if (i > 0) { + names += ","; + } + names += name; + } + return builder.getNamedAttr((is_input ? "inputs" : "outputs"), + builder.getStringAttr(names)); +} + +// erase all ops in block except for FuncOp +void ClearNonFuncOps(mlir::Block *block) { + std::vector<mlir::Operation *> to_delete; + for (auto &op : block->getOperations()) { + if (!mlir::isa<mlir::func::FuncOp>(op)) { + to_delete.push_back(&op); + } + } + for (mlir::Operation *op : to_delete) { + op->erase(); + } +} + +// erase function attrs and empty function region's body +void ResetFunction(mlir::func::FuncOp &function, mlir::MLIRContext &context) { + function->setAttrs(mlir::DictionaryAttr::get(&context, {})); + mlir::Region *main_region = function.getCallableRegion(); + main_region->takeBody(*main_region); +} + +// replace attrs and body of @a to_function and its parent module +// by @a from_module and its "main" function +mlir::LogicalResult CloneIntoModuleAndFunction( + mlir::MLIRContext &context, mlir::func::FuncOp &to_function, + mlir::ModuleOp &to_module, mlir::func::FuncOp &from_function, + mlir::ModuleOp &from_module) { + auto from_block = from_function.getOperation()->getBlock(); + auto to_block = to_function.getOperation()->getBlock(); + ClearNonFuncOps(to_block); + // copy all attrs from new_module to module + to_module->setAttrs(from_module->getAttrDictionary()); + // erase attrs and body of function + ResetFunction(to_function, context); + // clone new_func attrs and region into function + mlir::IRMapping mapping; + from_function.cloneInto(to_function, mapping); + + // copy variable ops in from_block to to_block + // collect variable ops in from_block in reverse order + std::vector<mlir::Operation *> variable_ops; + for (mlir::Operation &op : *from_block) { + if 
(mlir::isa<mlir::tosa::VariableOp>(op)) { + variable_ops.push_back(&op); + } + } + auto cloneOptions = + mlir::Operation::CloneOptions::all().cloneRegions(false).cloneOperands( + false); + for (auto iter = variable_ops.rbegin(); iter != variable_ops.rend(); iter++) { + auto op = *iter; + to_block->push_front(op->clone(mapping, cloneOptions)); + } + return mlir::success(); +} + +} // namespace + +namespace mlir { + +namespace tosa { + +mlir::OwningOpRef<mlir::ModuleOp> +BuildMlirFromTosaFile(const char *file_name, mlir::MLIRContext *context, + bool file_is_fbs) { + TosaSerializationHandler tsh; + if (file_is_fbs) { + if (tsh.LoadFileTosaFlatbuffer(file_name)) { + llvm::errs() << "Fail to load TOSA file " << file_name << "\n"; + return nullptr; + } + } else { + // must load tosa schema before loading json file + if (loadTosaSchema(tsh).failed()) { + return nullptr; + } + if (tsh.LoadFileJson(file_name)) { + llvm::errs() << "Fail to load TOSA JSON file " << file_name << "\n"; + return nullptr; + } + } + + // create new module + auto base_loc = mlir::FileLineColLoc::get(context, file_name, 0, 0); + auto module = mlir::ModuleOp::create(base_loc); + + // set module attributes + const auto &tosa_version = tsh.GetVersion().to_string(); + std::string tosa_description = + file_is_fbs ? 
kDefaultFBSDescription : kDefaultJSONDescription; + auto builder = mlir::Builder(context); + module->setAttr("tosa.fbs_version", builder.getStringAttr(tosa_version)); + module->setAttr("tosa.description", builder.getStringAttr(tosa_description)); + module->setAttr("tf_saved_model.semantics", mlir::UnitAttr::get(context)); + + // construct function with input and return types + llvm::SmallVector<mlir::Type, 2> ret_types; + llvm::SmallVector<mlir::Type, 4> input_types; + auto func_type = builder.getFunctionType(input_types, ret_types); + auto func_loc = + mlir::NameLoc::get(builder.getStringAttr(kMainFunctionName), base_loc); + auto func = mlir::func::FuncOp::create(func_loc, kMainFunctionName, func_type, + /* attrs= */ {}); + func.addEntryBlock(); + + // deserialize tosa fbs into function + std::vector<mlir::Value> main_returns; + if (buildTosaMlir(func, *context, tsh, main_returns).failed()) { + llvm::errs() << "Failed to deserialize flatbuffer " + << tosa_deserialize_filename << "\n"; + return nullptr; + } + auto main_args = func.getCallableRegion()->getArguments(); + // extract function input types + for (auto arg : main_args) { + input_types.push_back(arg.getType()); + } + // extract function return types + for (auto ret : main_returns) { + ret_types.push_back(ret.getType()); + } + // set function type with full input and return types + func_type = builder.getFunctionType(input_types, ret_types); + func.setType(func_type); + + // set function attributes + llvm::SmallVector<mlir::NamedAttribute, 2> attributes; + if (!input_types.empty()) { + attributes.push_back(DefaultEntryFuncitonAttr( + builder, /* is_input = */ true, /* count = */ input_types.size())); + for (int i = 0; i < input_types.size(); i++) { + std::string input_i = kDefaultInputPrefix + std::to_string(i); + func.setArgAttr(i, "tf_saved_model.index_path", + mlir::ArrayAttr::get( + context, {mlir::StringAttr::get(context, input_i)})); + } + } + if (!ret_types.empty()) { + 
attributes.push_back(DefaultEntryFuncitonAttr( + builder, /* is_input = */ false, /* count = */ ret_types.size())); + for (int i = 0; i < ret_types.size(); i++) { + std::string output_i = kDefaultOutputPrefix + std::to_string(i); + func.setResultAttr( + i, "tf_saved_model.index_path", + mlir::ArrayAttr::get(context, + {mlir::StringAttr::get(context, output_i)})); + } + } + func->setAttr("tf.entry_function", builder.getDictionaryAttr(attributes)); + func->setAttr( + "tf_saved_model.exported_names", + mlir::ArrayAttr::get( + context, {mlir::StringAttr::get(context, kDefaultExportedName)})); + + // deserialize variable ops in the new module just before adding func op + if (ConstructVariableOps(module).failed()) { + return nullptr; + } + + // add func to module + module.push_back(std::move(func)); + return mlir::OwningOpRef<mlir::ModuleOp>(module); +} + +namespace { + +class TosaDeserialize : public TosaDeserializationPassBase<TosaDeserialize> { +public: + void runOnOperation() final { + auto function = getOperation(); + auto &context = getContext(); + + auto new_module_ref = BuildMlirFromTosaFile( + tosa_deserialize_filename.c_str(), &context, /* file_is_fbs = */ true); + if (!new_module_ref) { + return signalPassFailure(); + } + + mlir::ModuleOp new_module = *new_module_ref; + auto builder = mlir::Builder(&context); + auto module = function->getParentOfType<mlir::ModuleOp>(); + auto new_function = new_module.lookupSymbol<mlir::func::FuncOp>( + builder.getStringAttr(kMainFunctionName)); + if (!new_function) { + llvm::errs() << "Failed to find main function in deserialized module\n"; + return signalPassFailure(); + } + if (CloneIntoModuleAndFunction(context, + /* to_function = */ function, + /* to_module = */ module, + /* from_function = */ new_function, + /* from_module = */ new_module) + .failed()) { + return signalPassFailure(); + } + } +}; + +class TosaDeserializeJSON + : public TosaDeserializationJSONPassBase<TosaDeserializeJSON> { +public: + void runOnOperation() 
final { + auto function = getOperation(); + auto &context = getContext(); + + auto new_module_ref = BuildMlirFromTosaFile( + tosa_deserialize_filename.c_str(), &context, /* file_is_fbs = */ false); + if (!new_module_ref) { + return signalPassFailure(); + } + + mlir::ModuleOp new_module = *new_module_ref; + auto builder = mlir::Builder(&context); + auto module = function->getParentOfType<mlir::ModuleOp>(); + auto new_function = new_module.lookupSymbol<mlir::func::FuncOp>( + builder.getStringAttr(kMainFunctionName)); + if (!new_function) { + llvm::errs() << "Failed to find main function in deserialized module\n"; + return signalPassFailure(); + } + if (CloneIntoModuleAndFunction(context, + /* to_function = */ function, + /* to_module = */ module, + /* from_function = */ new_function, + /* from_module = */ new_module) + .failed()) { + return signalPassFailure(); + } + } +}; + +} // anonymous namespace + +// Creates an instance of the TOSA flatbuffer deserialization pass +std::unique_ptr<Pass> createTosaDeserializePass() { + return std::make_unique<TosaDeserialize>(); +} + +std::unique_ptr<Pass> createTosaDeserializeJSONPass() { + return std::make_unique<TosaDeserializeJSON>(); +} + +static PassRegistration<TosaDeserialize> passDeserialize([] { + return createTosaDeserializePass(); +}); + +static PassRegistration<TosaDeserializeJSON> passDeserializeJSON([] { + return createTosaDeserializeJSONPass(); +}); + +} // namespace tosa +} // namespace mlir diff --git a/src/TosaSerialize.cpp b/src/TosaSerialize.cpp index a6ea5ff..c3a9878 100644 --- a/src/TosaSerialize.cpp +++ b/src/TosaSerialize.cpp @@ -1,5 +1,5 @@ -// Copyright (c) 2020-2022, ARM Limited. +// Copyright (c) 2020-2024, ARM Limited. 
// // Licensed under the Apache License, Version 2.0 with LLVM Exceptions // (the "License"); you may not use this file except in compliance with @@ -21,9 +21,13 @@ #include "mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project #include "mlir/Dialect/Tosa/IR/TosaOps.h" // from @llvm-project #include "mlir/IR/Matchers.h" // from @llvm-project -#include "mlir/Pass/Pass.h" // from @llvm-project -#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/IR/Operation.h" +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Support/TypeID.h" +#include "mlir/Transforms/DialectConversion.h" // from @llvm-project #include "tosa_serialization_handler.h" +#include <algorithm> #include <functional> #include <map> #include <unordered_map> @@ -61,7 +65,7 @@ template <> struct equal_to<mlir::Value> { } // namespace std -ResizeMode ResizeModeStr2Enum(const std::string &mode_str) { +static ResizeMode ResizeModeStr2Enum(const std::string &mode_str) { if (mode_str == "NEAREST_NEIGHBOR") return ResizeMode_NEAREST; else if (mode_str == "BILINEAR") @@ -70,14 +74,21 @@ ResizeMode ResizeModeStr2Enum(const std::string &mode_str) { return ResizeMode_UNKNOWN; } -DType Type2DType(mlir::Type element_type) { - if (element_type.isF64() || element_type.isF32() || element_type.isF16() || - element_type.isBF16()) { - return DType_FLOAT; +static DType Type2DType(mlir::Type element_type) { + if (element_type.isF64() || element_type.isF32()) { + return DType_FP32; + } else if (element_type.isFloat8E5M2()) { + return DType_FP8E5M2; + } else if (element_type.isFloat8E4M3FN()) { + return DType_FP8E4M3; + } else if (element_type.isF16()) { + return DType_FP16; + } else if (element_type.isBF16()) { + return DType_BF16; } else if (element_type.isUnsignedInteger(8)) { return DType_UINT8; } else if (element_type.isInteger(4)) { - return DType_INT8; + return DType_INT4; } else if (element_type.isInteger(8)) { 
return DType_INT8; } else if (element_type.isInteger(16)) { @@ -94,18 +105,79 @@ DType Type2DType(mlir::Type element_type) { return DType_UNKNOWN; } +static DType Type2AccDType(mlir::Type element_type) { + // def Tosa_AccType : AnyTypeOf<[I<32>, I<48>, F16, F32]>; + if (element_type.isF32()) { + return DType_FP32; + } else if (element_type.isF16()) { + return DType_FP16; + } else if (element_type.isInteger(32)) { + return DType_INT32; + } else if (element_type.isInteger(48)) { + return DType_INT48; + } + return DType_UNKNOWN; +} class TosaSerializationBlockBuilder; +class TosaSerializationRegionBuilder; + +std::unordered_map<std::string, mlir::Operation *> variable_tensor_op_map; +std::unordered_map<std::string, std::string> + variable_tensor_flatbuffer_name_map; +static int variable_tensor_index = 0; + +namespace { + +// for now, this is a global map of variables +void RegisterVariableOp(mlir::Operation &op) { + std::string variable_tensor_flatbuffer_name = + "Variable_" + std::to_string(variable_tensor_index++); + std::string variable_tensor_mlir_name = + op.getAttr("name").cast<mlir::StringAttr>().getValue().str(); + variable_tensor_op_map[variable_tensor_flatbuffer_name] = &op; + variable_tensor_flatbuffer_name_map[variable_tensor_mlir_name] = + variable_tensor_flatbuffer_name; +} + +} // namespace class TosaSerializationOperatorBuilder { public: TosaSerializationOperatorBuilder( TosaSerializationBlockBuilder *_block_builder) : block_builder(_block_builder) {} + template <typename T> TosaSerializationOperator *build(mlir::Operation &op) const; + TosaSerializationHandler *GetTsh() const; + TosaSerializationRegionBuilder *GetRegionBuilder() const; + mlir::LogicalResult GetDataFromAttribute(mlir::Operation &op, + mlir::Attribute &attr, + mlir::Type element_type, + std::vector<uint8_t> &u8_data) const; + + // populate u8_data with either int64_value or float_value depending on + // element_type + mlir::LogicalResult + GetU8DataFromIntOrFloatValue(int64_t 
int64_value, float fp_value, + mlir::Type element_type, + std::vector<uint8_t> &u8_data) const; + + // populate u8_data with int_value depending on non-float element_type + mlir::LogicalResult + GetU8DataFromIntValues(const std::vector<int64_t> &int_values, + mlir::Type element_type, + std::vector<uint8_t> &u8_data) const; + + // populate u8_data with fp_value depending on float element_type + mlir::LogicalResult + GetU8DataFromFloatValues(const std::vector<float> &fp_values, + mlir::Type element_type, + std::vector<uint8_t> &u8_data) const; private: std::string GetTensorName(mlir::Value val) const; + std::string GetVariableTensorName(mlir::Operation *op) const; TosaSerializationOperator *BuildPoolOpFromMlirOp(mlir::Operation &op, Op opcode) const; TosaSerializationOperator *BuildEwiseBinaryOpFromMlirOp(mlir::Operation &op, @@ -120,37 +192,275 @@ private: // This builder assumes each region only has only one block class TosaSerializationBlockBuilder { public: - friend class TosaSerializationOperatorBuilder; - TosaSerializationBlockBuilder(TosaSerializationBasicBlock *_block, - TosaSerializationHandler *_tsh, - mlir::Region *_region) - : block(_block), tsh(_tsh), region(_region) {} + // constructor + TosaSerializationBlockBuilder(TosaSerializationBasicBlock *_ser_block, + TosaSerializationRegionBuilder *_region_builder, + mlir::Block *_block) + : ser_block(_ser_block), region_builder(_region_builder), block(_block) {} mlir::LogicalResult - BuildAllOpsInRegion(std::vector<mlir::Value> &return_values); - TosaSerializationBasicBlock *GetBlock() { return block; } - TosaSerializationHandler *GetTsh() { return tsh; } + BuildAllOpsInBlock(std::vector<mlir::Value> &return_values); + TosaSerializationBasicBlock *GetBlock() const { return ser_block; } + TosaSerializationRegionBuilder *GetRegionBuilder() const { + return region_builder; + } + TosaSerializationHandler *GetTsh() const; + std::unordered_map<mlir::Value, std::string> &GetTensorMap() { + return tensor_map; + } 
private: TosaSerializationOperator *BuildTosaSerializationOperator( const TosaSerializationOperatorBuilder &op_builder, mlir::Operation &op); TosaSerializationTensor * + BuildTosaSerializationVariableTensor(mlir::RankedTensorType tensor_type, + const std::string &name, + const std::string &variable_mlir_name); + TosaSerializationTensor * BuildTosaSerializationTensor(mlir::Value val, const std::string &name); - TosaSerializationBasicBlock *block; - TosaSerializationHandler *tsh; - mlir::Region *region; + TosaSerializationBasicBlock *ser_block; + TosaSerializationRegionBuilder *region_builder; + mlir::Block *block; std::unordered_map<mlir::Value, std::string> tensor_map; std::unordered_map<mlir::Value, std::string> input_tensor_map; }; +class TosaSerializationRegionBuilder { +public: + // Constructor + TosaSerializationRegionBuilder( + TosaSerializationRegion *_ser_region, mlir::Region *_region, + TosaSerializationRegionBuilder *_parent_value_scope, + TosaSerializationHandler *_tsh) + : ser_region(_ser_region), region(_region), + parent_value_scope(_parent_value_scope), tsh(_tsh) {} + TosaSerializationHandler *GetTsh() const { return tsh; } + mlir::LogicalResult + BuildAllBlocksInRegion(bool is_top, std::vector<mlir::Value> &return_values); + TosaSerializationRegionBuilder *GetParentValueScope() const { + return parent_value_scope; + } + std::vector<TosaSerializationBlockBuilder *> &GetBlockBuilders() { + return block_builders; + } + +private: + TosaSerializationRegion *ser_region; + mlir::Region *region; + TosaSerializationRegionBuilder *parent_value_scope; + TosaSerializationHandler *tsh; + std::vector<TosaSerializationBlockBuilder *> block_builders; +}; + +TosaSerializationHandler *TosaSerializationOperatorBuilder::GetTsh() const { + return block_builder->GetTsh(); +} +TosaSerializationHandler *TosaSerializationBlockBuilder::GetTsh() const { + return region_builder->GetTsh(); +} +TosaSerializationRegionBuilder * 
+TosaSerializationOperatorBuilder::GetRegionBuilder() const { + return block_builder->GetRegionBuilder(); +} + std::string TosaSerializationOperatorBuilder::GetTensorName(mlir::Value val) const { - if (block_builder->tensor_map.find(val) == block_builder->tensor_map.end()) { - llvm::errs() << "ERROR: Failed to get mlir::Value from tensor_map"; + auto value_scope = GetRegionBuilder(); + while (value_scope) { + // Traverse through each block builder in the region + for (auto curr_block_builder : value_scope->GetBlockBuilders()) { + const auto &tensor_map = curr_block_builder->GetTensorMap(); + if (tensor_map.count(val)) { + return tensor_map.at(val); + } + } + value_scope = value_scope->GetParentValueScope(); + } + // Didn't find anything + llvm::errs() << "ERROR: Failed to get mlir::Value from tensor_map\n"; + assert(0); +} + +// Unpack 64-bit integer attribute element and pack into a std vector. +template <class T> +static std::vector<T> getDenseI64ArrayAttr(mlir::Attribute attr) { + auto array_ref = attr.cast<mlir::DenseI64ArrayAttr>().asArrayRef(); + + std::vector<T> vec; + for (auto val : array_ref) { + vec.push_back(val); + } + + return vec; +} + +// Unpack 8-bit integer attribute element and pack into a std vector. 
+template <class T> +static std::vector<T> getDenseI8ArrayAttr(mlir::Attribute attr) { + auto array_ref = attr.cast<mlir::DenseI8ArrayAttr>().asArrayRef(); + + std::vector<T> vec; + for (auto val : array_ref) { + vec.push_back(val); + } + + return vec; +} + +std::string TosaSerializationOperatorBuilder::GetVariableTensorName( + mlir::Operation *op) const { + std::string variable_tensor_mlir_name = + op->getAttr("name").cast<mlir::StringAttr>().getValue().str(); + + if (variable_tensor_flatbuffer_name_map.find(variable_tensor_mlir_name) == + variable_tensor_flatbuffer_name_map.end()) { + llvm::errs() << "ERROR: Failed to find key " << variable_tensor_mlir_name + << " from variable_tensor_flatbuffer_name_map\n"; assert(0); } - return block_builder->tensor_map[val]; + return variable_tensor_flatbuffer_name_map[variable_tensor_mlir_name]; +} + +mlir::LogicalResult TosaSerializationOperatorBuilder::GetDataFromAttribute( + mlir::Operation &op, mlir::Attribute &attr, mlir::Type element_type, + std::vector<uint8_t> &u8_data) const { + if (!element_type.isIntOrFloat()) { + return mlir::failure(); + } + auto dense_attr = attr.dyn_cast<mlir::DenseElementsAttr>(); + + // handle float types + if (element_type.isa<mlir::FloatType>()) { + std::vector<float> fp_data; + auto val_attr = attr.dyn_cast<mlir::FloatAttr>(); + + if (dense_attr) { + for (auto val : dense_attr.getValues<mlir::APFloat>()) { + fp_data.push_back(val.convertToFloat()); + } + } else if (val_attr) { + fp_data.push_back((float)val_attr.getValueAsDouble()); + } else { + op.emitOpError("Unknown const attribute"); + return mlir::failure(); + } + + return GetU8DataFromFloatValues(fp_data, element_type, u8_data); + } + + // element_type is integer type + + bool isInt48 = element_type.isInteger(48); + std::vector<int64_t> i64_data; + + auto val_attr = attr.dyn_cast<mlir::IntegerAttr>(); + if (dense_attr) { + for (auto valueIt : dense_attr.getValues<mlir::APInt>()) { + int64_t val = isInt48 ? 
static_cast<int64_t>(valueIt.getLimitedValue()) + : valueIt.getSExtValue(); + i64_data.push_back(val); + } + } else if (val_attr) { + i64_data.push_back(val_attr.getInt()); + } else { + op.emitOpError("Unknown const attribute"); + return mlir::failure(); + } + + return GetU8DataFromIntValues(i64_data, element_type, u8_data); +} + +mlir::LogicalResult TosaSerializationOperatorBuilder::GetU8DataFromIntValues( + const std::vector<int64_t> &int64_values, mlir::Type element_type, + std::vector<uint8_t> &u8_data) const { + switch (element_type.getIntOrFloatBitWidth()) { + case 1: { + // bool use bool vec + std::vector<bool> bool_values; + for (auto v : int64_values) { + bool bool_value = v == 0 ? false : true; + bool_values.push_back(bool_value); + } + TosaSerializationHandler::ConvertBooltoU8(bool_values, u8_data); + break; + } + case 4: + case 8: { + // I4 and I8 use int8_t vec + std::vector<int8_t> i8_values; + for (auto v : int64_values) { + i8_values.push_back(static_cast<int8_t>(v)); + } + if (element_type.isInteger(4)) { + TosaSerializationHandler::ConvertI4toU8(i8_values, u8_data); + } else { + TosaSerializationHandler::ConvertI8toU8(i8_values, u8_data); + } + break; + } + case 16: { + // I16 use int16_t vec + std::vector<int16_t> i16_values; + for (auto v : int64_values) { + i16_values.push_back(static_cast<int16_t>(v)); + } + TosaSerializationHandler::ConvertI16toU8(i16_values, u8_data); + break; + } + case 32: { + // I32 use int32_t vec + std::vector<int32_t> i32_values; + for (auto v : int64_values) { + i32_values.push_back(static_cast<int32_t>(v)); + } + TosaSerializationHandler::ConvertI32toU8(i32_values, u8_data); + break; + } + case 48: { + // I48 use int64_t vec + TosaSerializationHandler::ConvertI48toU8(int64_values, u8_data); + break; + } + default: { + // unsupported bit widths + return mlir::failure(); + } + } + return mlir::success(); +} + +mlir::LogicalResult TosaSerializationOperatorBuilder::GetU8DataFromFloatValues( + const std::vector<float> 
&fp_values, mlir::Type element_type, + std::vector<uint8_t> &u8_data) const { + assert( + element_type + .isa<mlir::FloatType>()); // this should only be called for float type + if (element_type.isF16()) { + TosaSerializationHandler::ConvertF16toU8(fp_values, u8_data); + } else if (element_type.isBF16()) { + TosaSerializationHandler::ConvertBF16toU8(fp_values, u8_data); + } else if (element_type.isFloat8E4M3FN()) { + TosaSerializationHandler::ConvertFP8E4M3toU8(fp_values, u8_data); + } else if (element_type.isFloat8E5M2()) { + TosaSerializationHandler::ConvertFP8E5M2toU8(fp_values, u8_data); + } else if (element_type.isF32()) { + TosaSerializationHandler::ConvertF32toU8(fp_values, u8_data); + } else { + return mlir::failure(); + } + return mlir::success(); +} + +mlir::LogicalResult +TosaSerializationOperatorBuilder::GetU8DataFromIntOrFloatValue( + int64_t int64_value, float fp_value, mlir::Type element_type, + std::vector<uint8_t> &u8_data) const { + if (element_type.isa<mlir::FloatType>()) { + return GetU8DataFromFloatValues({fp_value}, element_type, u8_data); + } else { + return GetU8DataFromIntValues({int64_value}, element_type, u8_data); + } } // Main template to catch unimplemented translation. 
@@ -176,43 +486,42 @@ TosaSerializationOperatorBuilder::build(mlir::Operation &op) const { TosaSerializationOperator * TosaSerializationOperatorBuilder::BuildPoolOpFromMlirOp(mlir::Operation &op, Op opcode) const { - std::vector<int> pad, stride, kernel; - - auto pad_attr = op.getAttr("pad").dyn_cast<mlir::ArrayAttr>().getValue(); - for (auto &int_attr : pad_attr) { - pad.push_back(int_attr.dyn_cast<mlir::IntegerAttr>().getInt()); - } + auto pad = getDenseI64ArrayAttr<int>(op.getAttr("pad")); ASSERT_VECTOR_LENGTH(pad, 4); - auto stride_attr = - op.getAttr("stride").dyn_cast<mlir::ArrayAttr>().getValue(); - for (auto &int_attr : stride_attr) { - stride.push_back(int_attr.dyn_cast<mlir::IntegerAttr>().getInt()); - } + auto stride = getDenseI64ArrayAttr<int>(op.getAttr("stride")); ASSERT_VECTOR_LENGTH(stride, 2); - auto kernel_attr = - op.getAttr("kernel").dyn_cast<mlir::ArrayAttr>().getValue(); - for (auto &int_attr : kernel_attr) { - kernel.push_back(int_attr.dyn_cast<mlir::IntegerAttr>().getInt()); - } + auto kernel = getDenseI64ArrayAttr<int>(op.getAttr("kernel")); ASSERT_VECTOR_LENGTH(kernel, 2); + DType acc_dtype = DType_FP32; + // AvgPool has acc_type, MaxPool does not + if (op.hasAttr("acc_type")) { + auto acc_type = op.getAttr("acc_type").cast<mlir::TypeAttr>().getValue(); + acc_dtype = Type2AccDType(acc_type); + } + std::string input_name = GetTensorName(op.getOperand(0)); std::string output_name = GetTensorName(op.getResult(0)); - auto quant_info = op.getAttrOfType<mlir::tosa::UnaryOpQuantizationAttr>( - "quantization_info"); + int32_t input_zp = + op.hasAttr("input_zp") + ? input_zp = op.getAttr("input_zp").cast<mlir::IntegerAttr>().getInt() + : 0; + int32_t output_zp = + op.hasAttr("output_zp") + ? output_zp = + op.getAttr("output_zp").cast<mlir::IntegerAttr>().getInt() + : 0; - int32_t input_zp = quant_info ? quant_info.input_zp().getInt() : 0; - int32_t output_zp = quant_info ? 
quant_info.output_zp().getInt() : 0; + TosaPoolAttribute attribute(pad, kernel, stride, input_zp, output_zp, + acc_dtype); - TosaPoolAttribute attribute(pad, kernel, stride, input_zp, output_zp); - - TosaSerializationOperator *tyop = new TosaSerializationOperator( - opcode, Attribute_PoolAttribute, &attribute, - std::vector<std::string>{input_name}, - std::vector<std::string>{output_name}); + TosaSerializationOperator *tyop = + new TosaSerializationOperator(opcode, Attribute_PoolAttribute, &attribute, + std::vector<std::string>{input_name}, + std::vector<std::string>{output_name}); return tyop; } @@ -239,8 +548,7 @@ TosaSerializationOperatorBuilder::BuildEwiseUnaryOpFromMlirOp( std::string output_name = GetTensorName(op.getResult(0)); TosaSerializationOperator *tyop = new TosaSerializationOperator( - opcode, Attribute_NONE, nullptr, - std::vector<std::string>{input_name}, + opcode, Attribute_NONE, nullptr, std::vector<std::string>{input_name}, std::vector<std::string>{output_name}); return tyop; @@ -255,10 +563,10 @@ TosaSerializationOperatorBuilder::BuildReductionOpFromMlirOp( int32_t axis = op.getAttr("axis").dyn_cast<mlir::IntegerAttr>().getInt(); TosaAxisAttribute attribute(axis); - TosaSerializationOperator *tyop = new TosaSerializationOperator( - opcode, Attribute_AxisAttribute, &attribute, - std::vector<std::string>{input_name}, - std::vector<std::string>{output_name}); + TosaSerializationOperator *tyop = + new TosaSerializationOperator(opcode, Attribute_AxisAttribute, &attribute, + std::vector<std::string>{input_name}, + std::vector<std::string>{output_name}); return tyop; } @@ -302,7 +610,7 @@ BUILD_OP_ELEMENTWISE_BINARY(Add, ADD) BUILD_OP_ELEMENTWISE_BINARY(BitwiseAnd, BITWISE_AND) BUILD_OP_ELEMENTWISE_BINARY(BitwiseXor, BITWISE_XOR) BUILD_OP_ELEMENTWISE_BINARY(BitwiseOr, BITWISE_OR) -BUILD_OP_ELEMENTWISE_BINARY(Div, INTDIV) +BUILD_OP_ELEMENTWISE_BINARY(IntDiv, INTDIV) BUILD_OP_ELEMENTWISE_BINARY(LogicalAnd, LOGICAL_AND) 
BUILD_OP_ELEMENTWISE_BINARY(LogicalLeftShift, LOGICAL_LEFT_SHIFT) BUILD_OP_ELEMENTWISE_BINARY(LogicalRightShift, LOGICAL_RIGHT_SHIFT) @@ -317,12 +625,14 @@ BUILD_OP_ELEMENTWISE_UNARY(Abs, ABS) BUILD_OP_ELEMENTWISE_UNARY(BitwiseNot, BITWISE_NOT) BUILD_OP_ELEMENTWISE_UNARY(Ceil, CEIL) BUILD_OP_ELEMENTWISE_UNARY(Clz, CLZ) +BUILD_OP_ELEMENTWISE_UNARY(Cos, COS) BUILD_OP_ELEMENTWISE_UNARY(Exp, EXP) BUILD_OP_ELEMENTWISE_UNARY(Floor, FLOOR) BUILD_OP_ELEMENTWISE_UNARY(Log, LOG) BUILD_OP_ELEMENTWISE_UNARY(LogicalNot, LOGICAL_NOT) BUILD_OP_ELEMENTWISE_UNARY(Reciprocal, RECIPROCAL) BUILD_OP_ELEMENTWISE_UNARY(Rsqrt, RSQRT) +BUILD_OP_ELEMENTWISE_UNARY(Sin, SIN) BUILD_OP_REDUCTION(ReduceAny, REDUCE_ANY) BUILD_OP_REDUCTION(ReduceAll, REDUCE_ALL) @@ -335,14 +645,20 @@ BUILD_OP_ELEMENTWISE_BINARY(Equal, EQUAL) BUILD_OP_ELEMENTWISE_BINARY(Greater, GREATER) BUILD_OP_ELEMENTWISE_BINARY(GreaterEqual, GREATER_EQUAL) +BUILD_OP_ELEMENTWISE_UNARY(Erf, ERF) BUILD_OP_ELEMENTWISE_UNARY(Sigmoid, SIGMOID) BUILD_OP_ELEMENTWISE_UNARY(Tanh, TANH) BUILD_OP_ELEMENTWISE_UNARY(Identity, IDENTITY) BUILD_OP_ELEMENTWISE_UNARY(Cast, CAST) +BUILD_OP_ELEMENTWISE_BINARY(AddShape, ADD_SHAPE) +BUILD_OP_ELEMENTWISE_BINARY(SubShape, SUB_SHAPE) +BUILD_OP_ELEMENTWISE_BINARY(MulShape, MUL_SHAPE) +BUILD_OP_ELEMENTWISE_BINARY(DivShape, DIV_SHAPE) + template <> TosaSerializationOperator * -TosaSerializationOperatorBuilder::build<mlir::tosa::ConstOp>( +TosaSerializationOperatorBuilder::build<mlir::tosa::ConstShapeOp>( mlir::Operation &op) const { std::string output_name = GetTensorName(op.getResult(0)); TosaSerializationTensor *ts = @@ -353,142 +669,73 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::ConstOp>( return nullptr; } -#if 0 - // Gracefully handle constants of "constant unit" type which have no value - // by creating a numpy value of 0. 
- auto unit_val = op.getAttr(llvm::StringRef("value")).dyn_cast<mlir::UnitAttr>(); - - if (unit_val) - { - std::vector<float> data = { 0.0 }; - type = DType_FLOAT; - TosaSerializationHandler::ConvertF32toU8(data, u8_data); - } -#endif - // Update tensor.data array with Const value attribute + mlir::Attribute value_attr = op.getAttr("value"); + if (!value_attr) { + op.emitOpError("ERROR: tosa.const_shape doesn't have value"); + return nullptr; + } + assert(ts->GetDtype() == DType::DType_SHAPE); std::vector<uint8_t> u8_data; - DType type = ts->GetDtype(); - if (type == DType_FLOAT) { - std::vector<float> data; - auto dense_attr = op.getAttr(llvm::StringRef("value")) - .dyn_cast<mlir::DenseElementsAttr>(); - auto val_attr = - op.getAttr(llvm::StringRef("value")).dyn_cast<mlir::FloatAttr>(); - if (dense_attr) { - for (auto val : dense_attr.getValues<float>()) { - data.push_back(val); - } - } else if (val_attr) { - data.push_back((float)val_attr.getValueAsDouble()); - } else { - op.emitOpError("Unknown const attribute"); - return nullptr; - } - TosaSerializationHandler::ConvertF32toU8(data, u8_data); - } else if (type == DType_INT8) { - std::vector<int8_t> data; - auto dense_attr = op.getAttr(llvm::StringRef("value")) - .dyn_cast<mlir::DenseElementsAttr>(); - auto val_attr = - op.getAttr(llvm::StringRef("value")).dyn_cast<mlir::IntegerAttr>(); + std::vector<int64_t> data; + auto dense_attr = op.getAttr(llvm::StringRef("value")) + .dyn_cast<mlir::DenseIntElementsAttr>(); + if (!dense_attr) { + op.emitOpError("Unknown const attribute"); + return nullptr; + } - if (dense_attr) { - for (auto val : dense_attr.getValues<int8_t>()) { - data.push_back(val); - } - } else if (val_attr) { - data.push_back(val_attr.getInt()); - } else { - op.emitOpError("Unknown const attribute"); - return nullptr; - } - TosaSerializationHandler::ConvertI8toU8(data, u8_data); - } else if (type == DType_INT16) { - std::vector<int16_t> data; - auto dense_attr = op.getAttr(llvm::StringRef("value")) - 
.dyn_cast<mlir::DenseElementsAttr>(); - auto val_attr = - op.getAttr(llvm::StringRef("value")).dyn_cast<mlir::IntegerAttr>(); + for (auto valueIt : dense_attr.getValues<mlir::APInt>()) { + int64_t val = valueIt.getSExtValue(); + data.push_back(val); + } - if (dense_attr) { - for (auto val : dense_attr.getValues<int16_t>()) { - data.push_back(val); - } - } else if (val_attr) { - data.push_back(val_attr.getInt()); - } else { - op.emitOpError("Unknown const attribute"); - return nullptr; - } - TosaSerializationHandler::ConvertI16toU8(data, u8_data); - } else if (type == DType_INT32) { - std::vector<int32_t> data; - auto dense_attr = op.getAttr(llvm::StringRef("value")) - .dyn_cast<mlir::DenseElementsAttr>(); - auto val_attr = - op.getAttr(llvm::StringRef("value")).dyn_cast<mlir::IntegerAttr>(); + TosaSerializationHandler::ConvertI64toU8(data, u8_data); - if (dense_attr) { - for (auto val : dense_attr.getValues<int32_t>()) { - data.push_back(val); - } - } else if (val_attr) { - data.push_back(val_attr.getInt()); - } else { - op.emitOpError("Unknown const attribute"); - return nullptr; - } - TosaSerializationHandler::ConvertI32toU8(data, u8_data); - } else if (type == DType_INT48) { - std::vector<int64_t> data; - auto dense_attr = op.getAttr(llvm::StringRef("value")) - .dyn_cast<mlir::DenseElementsAttr>(); - auto val_attr = - op.getAttr(llvm::StringRef("value")).dyn_cast<mlir::IntegerAttr>(); + ts->SetData(u8_data); - if (dense_attr) { - for (auto valueIt : dense_attr.getValues<mlir::APInt>()) { - uint64_t val = valueIt.getLimitedValue(); - data.push_back(val); - } - } else if (val_attr) { - data.push_back(val_attr.getInt()); - } else { - op.emitOpError("Unknown const attribute"); - return nullptr; - } - TosaSerializationHandler::ConvertI48toU8(data, u8_data); - } else if (type == DType_BOOL) { - std::vector<bool> data; + TosaSerializationOperator *tyop = new TosaSerializationOperator( + Op_CONST_SHAPE, Attribute_NONE, nullptr, std::vector<std::string>{}, + 
std::vector<std::string>{output_name}); - auto dense_attr = op.getAttr(llvm::StringRef("value")) - .dyn_cast<mlir::DenseElementsAttr>(); - auto val_attr = - op.getAttr(llvm::StringRef("value")).dyn_cast<mlir::BoolAttr>(); + return tyop; +} - if (dense_attr) { - for (auto val : dense_attr.getValues<bool>()) { - data.push_back(val); - } - } else if (val_attr) { - data.push_back(val_attr.getValue()); - } else { - op.emitOpError("Unknown const attribute"); - return nullptr; - } +template <> +TosaSerializationOperator * +TosaSerializationOperatorBuilder::build<mlir::tosa::ConstOp>( + mlir::Operation &op) const { + std::string output_name = GetTensorName(op.getResult(0)); + TosaSerializationTensor *ts = + block_builder->GetBlock()->GetTensorByName(output_name); + if (!ts) { + op.emitOpError( + "ERROR: serialization tensor must be built before building operator"); + return nullptr; + } - TosaSerializationHandler::ConvertBooltoU8(data, u8_data); - } else { - op.emitOpError("Unknown element type of const attribute"); + // Update tensor.data array with Const value attribute + mlir::Attribute value_attr = op.getAttr("value"); + if (!value_attr) { + op.emitOpError("ERROR: tosa.const doesn't have value"); + return nullptr; + } + std::vector<uint8_t> u8_data; + mlir::Attribute attr = op.getAttr(llvm::StringRef("value")); + mlir::Type element_type = + llvm::cast<mlir::ShapedType>(op.getResult(0).getType()).getElementType(); + + if (GetDataFromAttribute(op, attr, element_type, u8_data).failed()) { + op.emitOpError("ERROR: GetDataFromAttribute() fails when building value of " + "const tensor"); return nullptr; } - ts->SetData(u8_data); + ts->SetData(u8_data); TosaSerializationOperator *tyop = new TosaSerializationOperator( - Op_CONST, Attribute_NONE, nullptr, - std::vector<std::string>{}, std::vector<std::string>{output_name}); + Op_CONST, Attribute_NONE, nullptr, std::vector<std::string>{}, + std::vector<std::string>{output_name}); return tyop; } @@ -497,26 +744,13 @@ template <> 
TosaSerializationOperator * TosaSerializationOperatorBuilder::build<mlir::tosa::Conv2DOp>( mlir::Operation &op) const { - std::vector<int> pad, stride, dilation; - - auto pad_attr = op.getAttr("pad").dyn_cast<mlir::ArrayAttr>().getValue(); - for (auto &int_attr : pad_attr) { - pad.push_back(int_attr.dyn_cast<mlir::IntegerAttr>().getInt()); - } + auto pad = getDenseI64ArrayAttr<int>(op.getAttr("pad")); ASSERT_VECTOR_LENGTH(pad, 4); - auto stride_attr = - op.getAttr("stride").dyn_cast<mlir::ArrayAttr>().getValue(); - for (auto &int_attr : stride_attr) { - stride.push_back(int_attr.dyn_cast<mlir::IntegerAttr>().getInt()); - } + auto stride = getDenseI64ArrayAttr<int>(op.getAttr("stride")); ASSERT_VECTOR_LENGTH(stride, 2); - auto dilation_attr = - op.getAttr("dilation").dyn_cast<mlir::ArrayAttr>().getValue(); - for (auto &int_attr : dilation_attr) { - dilation.push_back(int_attr.dyn_cast<mlir::IntegerAttr>().getInt()); - } + auto dilation = getDenseI64ArrayAttr<int>(op.getAttr("dilation")); ASSERT_VECTOR_LENGTH(dilation, 2); std::string input0_name = GetTensorName(op.getOperand(0)); @@ -524,14 +758,25 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::Conv2DOp>( std::string input2_name = GetTensorName(op.getOperand(2)); std::string output_name = GetTensorName(op.getResult(0)); + int32_t input_zp = + op.hasAttr("input_zp") + ? op.getAttr("input_zp").cast<mlir::IntegerAttr>().getInt() + : 0; + int32_t weight_zp = + op.hasAttr("weight_zp") + ? op.getAttr("weight_zp").cast<mlir::IntegerAttr>().getInt() + : 0; - auto quant_info = - op.getAttrOfType<mlir::tosa::ConvOpQuantizationAttr>("quantization_info"); + bool local_bound = + op.hasAttr("local_bound") + ? op.getAttr("local_bound").dyn_cast<mlir::BoolAttr>().getValue() + : false; - int32_t input_zp = quant_info ? quant_info.input_zp().getInt() : 0; - int32_t weight_zp = quant_info ? 
quant_info.weight_zp().getInt() : 0; + auto acc_type = op.getAttr("acc_type").cast<mlir::TypeAttr>().getValue(); + auto acc_dtype = Type2AccDType(acc_type); - TosaConvAttribute attribute(pad, stride, dilation, input_zp, weight_zp); + TosaConvAttribute attribute(pad, stride, dilation, input_zp, weight_zp, + local_bound, acc_dtype); TosaSerializationOperator *tyop = new TosaSerializationOperator( Op_CONV2D, Attribute_ConvAttribute, &attribute, @@ -543,28 +788,61 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::Conv2DOp>( template <> TosaSerializationOperator * -TosaSerializationOperatorBuilder::build<mlir::tosa::DepthwiseConv2DOp>( +TosaSerializationOperatorBuilder::build<mlir::tosa::Conv3DOp>( mlir::Operation &op) const { - std::vector<int> pad, stride, dilation; + auto pad = getDenseI64ArrayAttr<int>(op.getAttr("pad")); + ASSERT_VECTOR_LENGTH(pad, 6); - auto pad_attr = op.getAttr("pad").dyn_cast<mlir::ArrayAttr>().getValue(); - for (auto &int_attr : pad_attr) { - pad.push_back(int_attr.dyn_cast<mlir::IntegerAttr>().getInt()); - } + auto stride = getDenseI64ArrayAttr<int>(op.getAttr("stride")); + ASSERT_VECTOR_LENGTH(stride, 3); + + auto dilation = getDenseI64ArrayAttr<int>(op.getAttr("dilation")); + ASSERT_VECTOR_LENGTH(dilation, 3); + + std::string input0_name = GetTensorName(op.getOperand(0)); + std::string input1_name = GetTensorName(op.getOperand(1)); + std::string input2_name = GetTensorName(op.getOperand(2)); + std::string output_name = GetTensorName(op.getResult(0)); + + int32_t input_zp = + op.hasAttr("input_zp") + ? op.getAttr("input_zp").cast<mlir::IntegerAttr>().getInt() + : 0; + int32_t weight_zp = + op.hasAttr("weight_zp") + ? op.getAttr("weight_zp").cast<mlir::IntegerAttr>().getInt() + : 0; + + bool local_bound = + op.hasAttr("local_bound") + ? 
op.getAttr("local_bound").dyn_cast<mlir::BoolAttr>().getValue() + : false; + + auto acc_type = op.getAttr("acc_type").cast<mlir::TypeAttr>().getValue(); + auto acc_dtype = Type2AccDType(acc_type); + + TosaConvAttribute attribute(pad, stride, dilation, input_zp, weight_zp, + local_bound, acc_dtype); + + TosaSerializationOperator *tyop = new TosaSerializationOperator( + Op_CONV3D, Attribute_ConvAttribute, &attribute, + std::vector<std::string>{input0_name, input1_name, input2_name}, + std::vector<std::string>{output_name}); + + return tyop; +} + +template <> +TosaSerializationOperator * +TosaSerializationOperatorBuilder::build<mlir::tosa::DepthwiseConv2DOp>( + mlir::Operation &op) const { + auto pad = getDenseI64ArrayAttr<int>(op.getAttr("pad")); ASSERT_VECTOR_LENGTH(pad, 4); - auto stride_attr = - op.getAttr("stride").dyn_cast<mlir::ArrayAttr>().getValue(); - for (auto &int_attr : stride_attr) { - stride.push_back(int_attr.dyn_cast<mlir::IntegerAttr>().getInt()); - } + auto stride = getDenseI64ArrayAttr<int>(op.getAttr("stride")); ASSERT_VECTOR_LENGTH(stride, 2); - auto dilation_attr = - op.getAttr("dilation").dyn_cast<mlir::ArrayAttr>().getValue(); - for (auto &int_attr : dilation_attr) { - dilation.push_back(int_attr.dyn_cast<mlir::IntegerAttr>().getInt()); - } + auto dilation = getDenseI64ArrayAttr<int>(op.getAttr("dilation")); ASSERT_VECTOR_LENGTH(dilation, 2); std::string input0_name = GetTensorName(op.getOperand(0)); @@ -572,13 +850,25 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::DepthwiseConv2DOp>( std::string input2_name = GetTensorName(op.getOperand(2)); std::string output_name = GetTensorName(op.getResult(0)); - auto quant_info = - op.getAttrOfType<mlir::tosa::ConvOpQuantizationAttr>("quantization_info"); + int32_t input_zp = + op.hasAttr("input_zp") + ? op.getAttr("input_zp").cast<mlir::IntegerAttr>().getInt() + : 0; + int32_t weight_zp = + op.hasAttr("weight_zp") + ? 
op.getAttr("weight_zp").cast<mlir::IntegerAttr>().getInt() + : 0; + + bool local_bound = + op.hasAttr("local_bound") + ? op.getAttr("local_bound").dyn_cast<mlir::BoolAttr>().getValue() + : false; - int32_t input_zp = quant_info ? quant_info.input_zp().getInt() : 0; - int32_t weight_zp = quant_info ? quant_info.weight_zp().getInt() : 0; + auto acc_type = op.getAttr("acc_type").cast<mlir::TypeAttr>().getValue(); + auto acc_dtype = Type2AccDType(acc_type); - TosaConvAttribute attribute(pad, stride, dilation, input_zp, weight_zp); + TosaConvAttribute attribute(pad, stride, dilation, input_zp, weight_zp, + local_bound, acc_dtype); TosaSerializationOperator *tyop = new TosaSerializationOperator( Op_DEPTHWISE_CONV2D, Attribute_ConvAttribute, &attribute, @@ -592,41 +882,39 @@ template <> TosaSerializationOperator * TosaSerializationOperatorBuilder::build<mlir::tosa::TransposeConv2DOp>( mlir::Operation &op) const { - std::vector<int> outpad, stride, dilation, output_shape; + auto out_pad = getDenseI64ArrayAttr<int>(op.getAttr("out_pad")); + ASSERT_VECTOR_LENGTH(out_pad, 4); - auto outpad_attr = - op.getAttr("out_pad").dyn_cast<mlir::ArrayAttr>().getValue(); - for (auto &int_attr : outpad_attr) { - outpad.push_back(int_attr.dyn_cast<mlir::IntegerAttr>().getInt()); - } - ASSERT_VECTOR_LENGTH(outpad, 4); - - auto stride_attr = - op.getAttr("stride").dyn_cast<mlir::ArrayAttr>().getValue(); - for (auto &int_attr : stride_attr) { - stride.push_back(int_attr.dyn_cast<mlir::IntegerAttr>().getInt()); - } + auto stride = getDenseI64ArrayAttr<int>(op.getAttr("stride")); ASSERT_VECTOR_LENGTH(stride, 2); - auto output_shape_attr = - op.getAttr("out_shape").dyn_cast<mlir::ArrayAttr>().getValue(); - for (auto &int_attr : output_shape_attr) { - output_shape.push_back(int_attr.dyn_cast<mlir::IntegerAttr>().getInt()); - } - ASSERT_VECTOR_LENGTH(output_shape, 4); - std::string input0_name = GetTensorName(op.getOperand(0)); std::string input1_name = GetTensorName(op.getOperand(1)); std::string 
input2_name = GetTensorName(op.getOperand(2)); std::string output_name = GetTensorName(op.getResult(0)); - auto quant_info = - op.getAttrOfType<mlir::tosa::ConvOpQuantizationAttr>("quantization_info"); + int32_t input_zp = + op.hasAttr("input_zp") + ? op.getAttr("input_zp").cast<mlir::IntegerAttr>().getInt() + : 0; + int32_t weight_zp = + op.hasAttr("weight_zp") + ? op.getAttr("weight_zp").cast<mlir::IntegerAttr>().getInt() + : 0; - int32_t input_zp = quant_info ? quant_info.input_zp().getInt() : 0; - int32_t weight_zp = quant_info ? quant_info.weight_zp().getInt() : 0; + mlir::RankedTensorType tensor = + op.getOperand(0).getType().cast<mlir::RankedTensorType>(); - TosaTransposeConvAttribute attribute(outpad, stride, output_shape, input_zp, weight_zp); + bool local_bound = + op.hasAttr("local_bound") + ? op.getAttr("local_bound").dyn_cast<mlir::BoolAttr>().getValue() + : false; + + auto acc_type = op.getAttr("acc_type").cast<mlir::TypeAttr>().getValue(); + auto acc_dtype = Type2AccDType(acc_type); + + TosaTransposeConvAttribute attribute(out_pad, stride, input_zp, weight_zp, + local_bound, acc_dtype); TosaSerializationOperator *tyop = new TosaSerializationOperator( Op_TRANSPOSE_CONV2D, Attribute_TransposeConvAttribute, &attribute, @@ -645,11 +933,15 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::FullyConnectedOp>( std::string input2_name = GetTensorName(op.getOperand(2)); std::string output_name = GetTensorName(op.getResult(0)); - auto quant_info = - op.getAttrOfType<mlir::tosa::ConvOpQuantizationAttr>("quantization_info"); + int32_t input_zp = + op.hasAttr("input_zp") + ? op.getAttr("input_zp").cast<mlir::IntegerAttr>().getInt() + : 0; + int32_t weight_zp = + op.hasAttr("weight_zp") + ? op.getAttr("weight_zp").cast<mlir::IntegerAttr>().getInt() + : 0; - int32_t input_zp = quant_info ? quant_info.input_zp().getInt() : 0; - int32_t weight_zp = quant_info ? 
quant_info.weight_zp().getInt() : 0; TosaFullyConnectedAttribute attribute(input_zp, weight_zp); TosaSerializationOperator *tyop = new TosaSerializationOperator( @@ -668,11 +960,12 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::MatMulOp>( std::string input1_name = GetTensorName(op.getOperand(1)); std::string output_name = GetTensorName(op.getResult(0)); - auto quant_info = op.getAttrOfType<mlir::tosa::MatMulOpQuantizationAttr>( - "quantization_info"); - - int32_t A_zp = quant_info ? quant_info.a_zp().getInt() : 0; - int32_t B_zp = quant_info ? quant_info.b_zp().getInt() : 0; + int32_t A_zp = op.hasAttr("a_zp") + ? op.getAttr("a_zp").cast<mlir::IntegerAttr>().getInt() + : 0; + int32_t B_zp = op.hasAttr("b_zp") + ? op.getAttr("b_zp").cast<mlir::IntegerAttr>().getInt() + : 0; TosaMatMulAttribute attribute(A_zp, B_zp); @@ -705,23 +998,49 @@ template <> TosaSerializationOperator * TosaSerializationOperatorBuilder::build<mlir::tosa::ClampOp>( mlir::Operation &op) const { - int32_t min_int = - op.getAttr("min_int").dyn_cast<mlir::IntegerAttr>().getInt(); - int32_t max_int = - op.getAttr("max_int").dyn_cast<mlir::IntegerAttr>().getInt(); - float min_fp = op.getAttr("min_fp") - .dyn_cast<mlir::FloatAttr>() - .getValue() - .convertToFloat(); - float max_fp = op.getAttr("max_fp") - .dyn_cast<mlir::FloatAttr>() - .getValue() - .convertToFloat(); + auto min_val_attr = op.getAttr("min_val"); + auto max_val_attr = op.getAttr("max_val"); + + mlir::Type input_element_type = + llvm::cast<mlir::ShapedType>(op.getOperand(0).getType()).getElementType(); + if (auto quantType = llvm::dyn_cast<mlir::quant::UniformQuantizedType>( + input_element_type)) { + input_element_type = quantType.getStorageType(); + } + + std::vector<uint8_t> min_val, max_val; + float min_fp, max_fp; + int64_t min_int, max_int; + + if (input_element_type.isa<mlir::FloatType>()) { + min_fp = + mlir::cast<mlir::FloatAttr>(min_val_attr).getValue().convertToFloat(); + max_fp = + 
mlir::cast<mlir::FloatAttr>(max_val_attr).getValue().convertToFloat(); + min_int = max_int = 0; + } else { + assert(input_element_type.isa<mlir::IntegerType>()); + min_int = mlir::cast<mlir::IntegerAttr>(min_val_attr).getInt(); + max_int = mlir::cast<mlir::IntegerAttr>(max_val_attr).getInt(); + min_fp = max_fp = 0.f; + } + + if (GetU8DataFromIntOrFloatValue(min_int, min_fp, input_element_type, min_val) + .failed()) { + op.emitOpError("Failed to serialize min value"); + return nullptr; + } + + if (GetU8DataFromIntOrFloatValue(max_int, max_fp, input_element_type, max_val) + .failed()) { + op.emitOpError("Failed to serialize max value"); + return nullptr; + } std::string input_name = GetTensorName(op.getOperand(0)); std::string output_name = GetTensorName(op.getResult(0)); - TosaClampAttribute attribute(min_int, max_int, min_fp, max_fp); + TosaClampAttribute attribute(min_val, max_val); TosaSerializationOperator *tyop = new TosaSerializationOperator( Op_CLAMP, Attribute_ClampAttribute, &attribute, @@ -767,8 +1086,27 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::ConcatOp>( TosaAxisAttribute attribute(axis); TosaSerializationOperator *tyop = new TosaSerializationOperator( - Op_CONCAT, Attribute_AxisAttribute, &attribute, - inputs, std::vector<std::string>{output_name}); + Op_CONCAT, Attribute_AxisAttribute, &attribute, inputs, + std::vector<std::string>{output_name}); + + return tyop; +} + +template <> +TosaSerializationOperator * +TosaSerializationOperatorBuilder::build<mlir::tosa::ConcatShapeOp>( + mlir::Operation &op) const { + std::vector<std::string> inputs; + for (uint32_t i = 0; i < op.getNumOperands(); i++) { + std::string input_name = GetTensorName(op.getOperand(i)); + inputs.push_back(input_name); + } + + std::string output_name = GetTensorName(op.getResult(0)); + + TosaSerializationOperator *tyop = new TosaSerializationOperator( + Op_CONCAT_SHAPE, Attribute_NONE, nullptr, inputs, + std::vector<std::string>{output_name}); return tyop; } @@ -780,13 
+1118,16 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::NegateOp>( std::string input_name = GetTensorName(op.getOperand(0)); std::string output_name = GetTensorName(op.getResult(0)); - auto quant_info = op.getAttrOfType<mlir::tosa::UnaryOpQuantizationAttr>( - "quantization_info"); - - int32_t input_zp = quant_info ? quant_info.input_zp().getInt() : 0; - int32_t output_zp = quant_info ? quant_info.output_zp().getInt() : 0; + int32_t input1_zp = + op.hasAttr("input1_zp") + ? op.getAttr("input1_zp").cast<mlir::IntegerAttr>().getInt() + : 0; + int32_t output_zp = + op.hasAttr("output_zp") + ? op.getAttr("output_zp").cast<mlir::IntegerAttr>().getInt() + : 0; - TosaNegateAttribute attribute(input_zp, output_zp); + TosaNegateAttribute attribute(input1_zp, output_zp); TosaSerializationOperator *tyop = new TosaSerializationOperator( Op_NEGATE, Attribute_NegateAttribute, &attribute, @@ -800,20 +1141,12 @@ TosaSerializationOperator * TosaSerializationOperatorBuilder::build<mlir::tosa::ReshapeOp>( mlir::Operation &op) const { std::string input_name = GetTensorName(op.getOperand(0)); + std::string shape_name = GetTensorName(op.getOperand(1)); std::string output_name = GetTensorName(op.getResult(0)); - std::vector<int> shape; - auto shape_attr = - op.getAttr("new_shape").dyn_cast<mlir::ArrayAttr>().getValue(); - for (auto &int_attr : shape_attr) { - shape.push_back(int_attr.dyn_cast<mlir::IntegerAttr>().getInt()); - } - - TosaReshapeAttribute attribute(shape); - TosaSerializationOperator *tyop = new TosaSerializationOperator( - Op_RESHAPE, Attribute_ReshapeAttribute, &attribute, - std::vector<std::string>{input_name}, + Op_RESHAPE, Attribute_NONE, nullptr, + std::vector<std::string>{input_name, shape_name}, std::vector<std::string>{output_name}); return tyop; @@ -824,38 +1157,80 @@ TosaSerializationOperator * TosaSerializationOperatorBuilder::build<mlir::tosa::PadOp>( mlir::Operation &op) const { std::string input_name = GetTensorName(op.getOperand(0)); + std::string 
padding_name = GetTensorName(op.getOperand(1)); std::string output_name = GetTensorName(op.getResult(0)); + auto pad_op = llvm::cast<mlir::tosa::PadOp>(op); - // Match padding tensor as compile-time constant attribute - // TODO: fix when MLIR dialect changes - mlir::ElementsAttr paddings_elems; - if (!matchPattern(op.getOperand(1), m_Constant(&paddings_elems))) - return nullptr; + auto input_zp_attr = pad_op.getInputZpAttr(); + // pad_const includes the zero point if the tensor uses a zero point. + int32_t pad_const_int = input_zp_attr ? input_zp_attr.getInt() : 0; + float pad_const_fp = 0.f; + + if (auto tensor = pad_op.getPadConst()) { + // Match pad_const tensor as compile-time constant attribute if present. + mlir::DenseElementsAttr attr; + if (!matchPattern(tensor, m_Constant(&attr))) + return nullptr; + + assert(attr.getNumElements() == 1); + auto elementTy = attr.getElementType(); + + if (elementTy.isa<mlir::IntegerType>()) { + pad_const_int = (attr.getValues<mlir::APInt>()[0]).getSExtValue(); + } else if (elementTy.isa<mlir::FloatType>()) { + pad_const_fp = (attr.getValues<mlir::APFloat>()[0]).convertToFloat(); + } else { + op.emitOpError("Unknown const attribute"); + return nullptr; + } + } + + std::vector<uint8_t> pad_const; + mlir::Type input_element_type = + llvm::cast<mlir::ShapedType>(op.getOperand(0).getType()).getElementType(); - std::vector<int> paddings; - for (int32_t val : paddings_elems.getValues<int32_t>()) { - paddings.push_back(val); + if (GetU8DataFromIntOrFloatValue(pad_const_int, pad_const_fp, + input_element_type, pad_const) + .failed()) { + op.emitOpError("Failed to serialize pad_const value"); + return nullptr; } - TosaPadAttribute attribute(paddings, 0 /* pad_const_int */, - 0.0f /* pad_const_fp */); + TosaPadAttribute attribute(pad_const); TosaSerializationOperator *tyop = new TosaSerializationOperator( Op_PAD, Attribute_PadAttribute, &attribute, - std::vector<std::string>{input_name}, + std::vector<std::string>{input_name, 
padding_name}, std::vector<std::string>{output_name}); return tyop; } template <> TosaSerializationOperator * +TosaSerializationOperatorBuilder::build<mlir::tosa::DimOp>( + mlir::Operation &op) const { + std::string input_name = GetTensorName(op.getOperand(0)); + std::string output_name = GetTensorName(op.getResult(0)); + + int32_t axis = op.getAttr("axis").dyn_cast<mlir::IntegerAttr>().getInt(); + TosaAxisAttribute attribute(axis); + + TosaSerializationOperator *tyop = + new TosaSerializationOperator(Op_DIM, Attribute_AxisAttribute, &attribute, + std::vector<std::string>{input_name}, + std::vector<std::string>{output_name}); + + return tyop; +} + +template <> +TosaSerializationOperator * TosaSerializationOperatorBuilder::build<mlir::tosa::TransposeOp>( mlir::Operation &op) const { std::string input_name = GetTensorName(op.getOperand(0)); std::string output_name = GetTensorName(op.getResult(0)); // Match perm tensor as compile-time constant attribute - // TODO: fix when MLIR dialect changes mlir::ElementsAttr perm_elems; if (!matchPattern(op.getOperand(1), m_Constant(&perm_elems))) return nullptr; @@ -879,26 +1254,14 @@ template <> TosaSerializationOperator * TosaSerializationOperatorBuilder::build<mlir::tosa::SliceOp>( mlir::Operation &op) const { - std::vector<int> start, size; - auto begin_attr = op.getAttr("start").dyn_cast<mlir::ArrayAttr>().getValue(); - auto size_attr = op.getAttr("size").dyn_cast<mlir::ArrayAttr>().getValue(); - - for (auto &int_attr : begin_attr) { - start.push_back(int_attr.dyn_cast<mlir::IntegerAttr>().getInt()); - } - - for (auto &int_attr : size_attr) { - size.push_back(int_attr.dyn_cast<mlir::IntegerAttr>().getInt()); - } - - TosaSliceAttribute attribute(start, size); - std::string input_name = GetTensorName(op.getOperand(0)); + std::string start_name = GetTensorName(op.getOperand(1)); + std::string size_name = GetTensorName(op.getOperand(2)); std::string output_name = GetTensorName(op.getResult(0)); TosaSerializationOperator *tyop = 
new TosaSerializationOperator( - Op_SLICE, Attribute_SliceAttribute, &attribute, - std::vector<std::string>{input_name}, + Op_SLICE, Attribute_NONE, nullptr, + std::vector<std::string>{input_name, start_name, size_name}, std::vector<std::string>{output_name}); return tyop; @@ -908,21 +1271,13 @@ template <> TosaSerializationOperator * TosaSerializationOperatorBuilder::build<mlir::tosa::TileOp>( mlir::Operation &op) const { - std::string input_name = GetTensorName(op.getOperand(0)); + std::string input0_name = GetTensorName(op.getOperand(0)); + std::string input1_name = GetTensorName(op.getOperand(1)); std::string output_name = GetTensorName(op.getResult(0)); - std::vector<int> multiples; - auto multiples_attr = - op.getAttr("multiples").dyn_cast<mlir::ArrayAttr>().getValue(); - for (auto &int_attr : multiples_attr) { - multiples.push_back(int_attr.dyn_cast<mlir::IntegerAttr>().getInt()); - } - - TosaTileAttribute attribute(multiples); - TosaSerializationOperator *tyop = new TosaSerializationOperator( - Op_TILE, Attribute_TileAttribute, &attribute, - std::vector<std::string>{input_name}, + Op_TILE, Attribute_NONE, nullptr, + std::vector<std::string>{input0_name, input1_name}, std::vector<std::string>{output_name}); return tyop; @@ -966,60 +1321,21 @@ TosaSerializationOperator * TosaSerializationOperatorBuilder::build<mlir::tosa::ResizeOp>( mlir::Operation &op) const { std::string input_name = GetTensorName(op.getOperand(0)); + std::string scale_name = GetTensorName(op.getOperand(1)); + std::string offset_name = GetTensorName(op.getOperand(2)); + std::string border_name = GetTensorName(op.getOperand(3)); std::string output_name = GetTensorName(op.getResult(0)); - std::vector<int> output_size; - auto output_size_attr = - op.getAttr("output_size").dyn_cast<mlir::ArrayAttr>().getValue(); - for (auto &int_attr : output_size_attr) { - output_size.push_back(int_attr.dyn_cast<mlir::IntegerAttr>().getInt()); - } - ASSERT_VECTOR_LENGTH(output_size, 2); - - std::vector<int> 
stride; - auto stride_attr = - op.getAttr("stride").dyn_cast<mlir::ArrayAttr>().getValue(); - for (auto &int_attr : stride_attr) { - stride.push_back(int_attr.dyn_cast<mlir::IntegerAttr>().getInt()); - } - ASSERT_VECTOR_LENGTH(stride, 2); - - std::vector<int> offset; - auto offset_attr = - op.getAttr("offset").dyn_cast<mlir::ArrayAttr>().getValue(); - for (auto &int_attr : offset_attr) { - offset.push_back(int_attr.dyn_cast<mlir::IntegerAttr>().getInt()); - } - ASSERT_VECTOR_LENGTH(offset, 2); - - int32_t shift = op.getAttr("shift").dyn_cast<mlir::IntegerAttr>().getInt(); - - std::vector<float> stride_fp; - auto stride_fp_attr = - op.getAttr("stride_fp").dyn_cast<mlir::ArrayAttr>().getValue(); - for (auto &fp_attr : stride_fp_attr) { - stride_fp.push_back(fp_attr.dyn_cast<mlir::FloatAttr>().getValueAsDouble()); - } - ASSERT_VECTOR_LENGTH(stride_fp, 2); - - std::vector<float> offset_fp; - auto offset_fp_attr = - op.getAttr("offset_fp").dyn_cast<mlir::ArrayAttr>().getValue(); - for (auto &fp_attr : offset_fp_attr) { - offset_fp.push_back(fp_attr.dyn_cast<mlir::FloatAttr>().getValueAsDouble()); - } - ASSERT_VECTOR_LENGTH(offset_fp, 2); - auto mode_str = op.getAttr("mode").dyn_cast<mlir::StringAttr>().getValue().str(); ResizeMode mode = ResizeModeStr2Enum(mode_str); - TosaResizeAttribute attribute(output_size, stride, offset, shift, stride_fp, - offset_fp, mode); + TosaResizeAttribute attribute({}, {}, {}, mode); TosaSerializationOperator *tyop = new TosaSerializationOperator( Op_RESIZE, Attribute_ResizeAttribute, &attribute, - std::vector<std::string>{input_name}, + std::vector<std::string>{input_name, scale_name, offset_name, + border_name}, std::vector<std::string>{output_name}); return tyop; @@ -1048,17 +1364,15 @@ template <> TosaSerializationOperator * TosaSerializationOperatorBuilder::build<mlir::tosa::MulOp>( mlir::Operation &op) const { - std::string input0_name = GetTensorName(op.getOperand(0)); - std::string input1_name = GetTensorName(op.getOperand(1)); - 
std::string output_name = GetTensorName(op.getResult(0)); - - int32_t shift = op.getAttr("shift").dyn_cast<mlir::IntegerAttr>().getInt(); - TosaMulAttribute attribute(shift); + mlir::tosa::MulOp mul_op = mlir::cast<mlir::tosa::MulOp>(op); + std::string input0_name = GetTensorName(mul_op.getInput1()); + std::string input1_name = GetTensorName(mul_op.getInput2()); + std::string output_name = GetTensorName(mul_op.getOutput()); + std::string shift_name = GetTensorName(mul_op.getShift()); TosaSerializationOperator *tyop = new TosaSerializationOperator( - Op_MUL, Attribute_MulAttribute, &attribute, - std::vector<std::string>{input0_name, input1_name}, + Op_MUL, Attribute_NONE, nullptr, {input0_name, input1_name, shift_name}, std::vector<std::string>{output_name}); return tyop; @@ -1078,8 +1392,7 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::ArithmeticRightShiftOp>( TosaSerializationOperator *tyop = new TosaSerializationOperator( Op_ARITHMETIC_RIGHT_SHIFT, Attribute_ArithmeticRightShiftAttribute, - &attribute, - std::vector<std::string>{input0_name, input1_name}, + &attribute, std::vector<std::string>{input0_name, input1_name}, std::vector<std::string>{output_name}); return tyop; @@ -1093,7 +1406,6 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::TableOp>( std::string output_name = GetTensorName(op.getResult(0)); // Match table tensor as compile-time constant attribute - // TODO: fix when MLIR dialect changes mlir::ElementsAttr table_elems; if (!matchPattern(op.getOperand(1), m_Constant(&table_elems))) return nullptr; @@ -1127,28 +1439,27 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::RescaleOp>( bool per_channel = op.getAttr("per_channel").dyn_cast<mlir::BoolAttr>().getValue(); - std::vector<int> multiplier, shift; - auto multiplier_attr = - op.getAttr("multiplier").dyn_cast<mlir::ArrayAttr>().getValue(); - auto shift_attr = op.getAttr("shift").dyn_cast<mlir::ArrayAttr>().getValue(); + bool input_unsigned = + 
op.getAttr("input_unsigned").dyn_cast<mlir::BoolAttr>().getValue(); + bool output_unsigned = + op.getAttr("output_unsigned").dyn_cast<mlir::BoolAttr>().getValue(); - for (auto &int_attr : multiplier_attr) { - multiplier.push_back(int_attr.dyn_cast<mlir::IntegerAttr>().getInt()); - } + auto input = op.getOperand(0); + auto input_ty = input.getType().cast<mlir::RankedTensorType>(); + auto output = op.getResult(0); + auto output_ty = output.getType().cast<mlir::RankedTensorType>(); - for (auto &int_attr : shift_attr) { - shift.push_back(int_attr.dyn_cast<mlir::IntegerAttr>().getInt()); - } + std::string input_name = GetTensorName(input); + std::string multiplier_name = GetTensorName(op.getOperand(1)); + std::string shift_name = GetTensorName(op.getOperand(2)); + std::string output_name = GetTensorName(output); - std::string input_name = GetTensorName(op.getOperand(0)); - std::string output_name = GetTensorName(op.getResult(0)); - - TosaRescaleAttribute attribute(input_zp, output_zp, multiplier, shift, - scale32, double_round, per_channel); + TosaRescaleAttribute attribute(input_zp, output_zp, scale32, double_round, + per_channel, input_unsigned, output_unsigned); TosaSerializationOperator *tyop = new TosaSerializationOperator( Op_RESCALE, Attribute_RescaleAttribute, &attribute, - std::vector<std::string>{input_name}, + std::vector<std::string>{input_name, multiplier_name, shift_name}, std::vector<std::string>{output_name}); return tyop; @@ -1161,39 +1472,77 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::CustomOp>( std::string input_name = GetTensorName(op.getOperand(0)); std::string output_name = GetTensorName(op.getResult(0)); + const std::string implementation_attrs = op.getAttr("implementation_attrs") + .cast<mlir::StringAttr>() + .getValue() + .str(); + + std::vector<uint8_t> attrs_data(implementation_attrs.size()); + memcpy(attrs_data.data(), implementation_attrs.data(), attrs_data.size()); + TosaCustomAttribute attribute( + 
op.getAttr("operator_name").cast<mlir::StringAttr>().getValue().str(), + op.getAttr("domain_name").cast<mlir::StringAttr>().getValue().str(), + attrs_data); + TosaSerializationOperator *tyop = new TosaSerializationOperator( - Op_CUSTOM, Attribute_NONE, nullptr, + Op_CUSTOM, Attribute_CustomAttribute, &attribute, std::vector<std::string>{input_name}, std::vector<std::string>{output_name}); return tyop; } +namespace { + +// serialize a region and all its blocks, and return region's return values +TosaSerializationRegion * +BuildRegion(mlir::Region ®ion, const std::string region_name, + const bool isolated_from_above, + TosaSerializationRegionBuilder *curr_region_builder, + TosaSerializationHandler *tsh, + std::vector<mlir::Value> &return_values, bool is_top = false) { + TosaSerializationRegion *ser_region = + new TosaSerializationRegion(region_name, {}); + assert(ser_region); + tsh->GetRegions().push_back(ser_region); + + TosaSerializationRegionBuilder *parent_value_scope = + isolated_from_above ? 
nullptr : curr_region_builder; + + TosaSerializationRegionBuilder region_builder(ser_region, ®ion, + parent_value_scope, tsh); + if (region_builder.BuildAllBlocksInRegion(is_top, return_values).failed()) { + return nullptr; + } + return ser_region; +} + +static int input_tensor_index = 0; +static int intermediate_tensor_index = 0; +static int output_tensor_index = 0; + +} // namespace + template <> TosaSerializationOperator * TosaSerializationOperatorBuilder::build<mlir::tosa::IfOp>( mlir::Operation &op) const { + const std::string op_name = op.getName().getStringRef().str(); + const bool isolated_from_above = + op.hasTrait<mlir::OpTrait::IsIsolatedFromAbove>(); + auto curr_region_builder = GetRegionBuilder(); std::vector<std::string> input_names, output_names; + std::vector<mlir::Value> then_yields, else_yields; + auto tsh = GetTsh(); mlir::Region &then_region = op.getRegion(0); mlir::Region &else_region = op.getRegion(1); - std::vector<mlir::Value> then_yields, else_yields; - TosaSerializationBasicBlock *then_block = nullptr; - TosaSerializationBasicBlock *else_block = nullptr; - - // Building then branch block - std::string then_block_name = - "bb" + std::to_string(block_builder->GetTsh()->GetBlocks().size()); - then_block = new TosaSerializationBasicBlock( - then_block_name, std::vector<TosaSerializationOperator *>(), - std::vector<TosaSerializationTensor *>(), std::vector<std::string>(), - std::vector<std::string>()); - assert(then_block); - block_builder->GetTsh()->GetBlocks().push_back(then_block); - - TosaSerializationBlockBuilder then_block_builder( - then_block, block_builder->GetTsh(), &then_region); - if (then_block_builder.BuildAllOpsInRegion(then_yields).failed()) { + + const std::string then_region_name = op_name + "_then_region"; + TosaSerializationRegion *ser_then_region = + BuildRegion(then_region, then_region_name, isolated_from_above, + curr_region_builder, tsh, then_yields); + if (!ser_then_region) { return nullptr; } if (then_yields.size() != 
op.getNumResults()) { @@ -1202,19 +1551,11 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::IfOp>( return nullptr; } - // Building else branch block - std::string else_block_name = - "bb" + std::to_string(block_builder->GetTsh()->GetBlocks().size()); - else_block = new TosaSerializationBasicBlock( - else_block_name, std::vector<TosaSerializationOperator *>(), - std::vector<TosaSerializationTensor *>(), std::vector<std::string>(), - std::vector<std::string>()); - assert(else_block); - block_builder->GetTsh()->GetBlocks().push_back(else_block); - - TosaSerializationBlockBuilder else_block_builder( - else_block, block_builder->GetTsh(), &else_region); - if (else_block_builder.BuildAllOpsInRegion(else_yields).failed()) { + const std::string else_region_name = op_name + "_else_region"; + TosaSerializationRegion *ser_else_region = + BuildRegion(else_region, else_region_name, isolated_from_above, + curr_region_builder, tsh, else_yields); + if (!ser_else_region) { return nullptr; } if (else_yields.size() != op.getNumResults()) { @@ -1223,7 +1564,7 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::IfOp>( return nullptr; } - TosaCondIfAttribute attribute(then_block->GetName(), else_block->GetName()); + TosaCondIfAttribute attribute(then_region_name, else_region_name); for (size_t i = 0; i < op.getNumOperands(); i++) { std::string input_name = GetTensorName(op.getOperand(i)); @@ -1235,9 +1576,9 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::IfOp>( output_names.push_back(output_name); } - TosaSerializationOperator *tyop = new TosaSerializationOperator( - Op_COND_IF, Attribute_CondIfAttribute, &attribute, - input_names, output_names); + TosaSerializationOperator *tyop = + new TosaSerializationOperator(Op_COND_IF, Attribute_CondIfAttribute, + &attribute, input_names, output_names); return tyop; } @@ -1246,27 +1587,22 @@ template <> TosaSerializationOperator * TosaSerializationOperatorBuilder::build<mlir::tosa::WhileOp>( mlir::Operation &op) const { + const 
std::string op_name = op.getName().getStringRef().str(); + const bool isolated_from_above = + op.hasTrait<mlir::OpTrait::IsIsolatedFromAbove>(); + auto curr_region_builder = GetRegionBuilder(); std::vector<std::string> input_names, output_names; + auto tsh = GetTsh(); mlir::Region &cond_region = op.getRegion(0); mlir::Region &body_region = op.getRegion(1); std::vector<mlir::Value> cond_yields, body_yields; - TosaSerializationBasicBlock *cond_block = nullptr; - TosaSerializationBasicBlock *body_block = nullptr; - - // Building cond branch block - std::string cond_block_name = - "bb" + std::to_string(block_builder->GetTsh()->GetBlocks().size()); - cond_block = new TosaSerializationBasicBlock( - cond_block_name, std::vector<TosaSerializationOperator *>(), - std::vector<TosaSerializationTensor *>(), std::vector<std::string>(), - std::vector<std::string>()); - assert(cond_block); - block_builder->GetTsh()->GetBlocks().push_back(cond_block); - - TosaSerializationBlockBuilder cond_block_builder( - cond_block, block_builder->GetTsh(), &cond_region); - if (cond_block_builder.BuildAllOpsInRegion(cond_yields).failed()) { + + const std::string cond_region_name = op_name + "_cond_region"; + TosaSerializationRegion *ser_cond_region = + BuildRegion(cond_region, cond_region_name, isolated_from_above, + curr_region_builder, tsh, cond_yields); + if (!ser_cond_region) { return nullptr; } if (cond_yields.size() != 1) { @@ -1274,19 +1610,11 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::WhileOp>( return nullptr; } - // Building body branch block - std::string body_block_name = - "bb" + std::to_string(block_builder->GetTsh()->GetBlocks().size()); - body_block = new TosaSerializationBasicBlock( - body_block_name, std::vector<TosaSerializationOperator *>(), - std::vector<TosaSerializationTensor *>(), std::vector<std::string>(), - std::vector<std::string>()); - assert(body_block); - block_builder->GetTsh()->GetBlocks().push_back(body_block); - - TosaSerializationBlockBuilder 
body_block_builder( - body_block, block_builder->GetTsh(), &body_region); - if (body_block_builder.BuildAllOpsInRegion(body_yields).failed()) { + const std::string body_region_name = op_name + "_body_region"; + TosaSerializationRegion *ser_body_region = + BuildRegion(body_region, body_region_name, isolated_from_above, + curr_region_builder, tsh, body_yields); + if (!ser_body_region) { return nullptr; } if (body_yields.size() != op.getNumResults()) { @@ -1295,8 +1623,7 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::WhileOp>( return nullptr; } - TosaWhileLoopAttribute attribute(cond_block->GetName(), - body_block->GetName()); + TosaWhileLoopAttribute attribute(cond_region_name, body_region_name); for (size_t i = 0; i < op.getNumOperands(); i++) { std::string input_name = GetTensorName(op.getOperand(i)); @@ -1308,135 +1635,286 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::WhileOp>( output_names.push_back(output_name); } + TosaSerializationOperator *tyop = + new TosaSerializationOperator(Op_WHILE_LOOP, Attribute_WhileLoopAttribute, + &attribute, input_names, output_names); + + return tyop; +} + +template <> +TosaSerializationOperator * +TosaSerializationOperatorBuilder::build<mlir::tosa::RFFT2dOp>( + mlir::Operation &op) const { + std::string input_name = GetTensorName(op.getOperand(0)); + std::string output_real_name = GetTensorName(op.getResult(0)); + std::string output_imag_name = GetTensorName(op.getResult(1)); + + bool local_bound = + op.hasAttr("local_bound") + ? 
op.getAttr("local_bound").dyn_cast<mlir::BoolAttr>().getValue() + : false; + + TosaRFFTAttribute attribute(local_bound); + + TosaSerializationOperator *tyop = new TosaSerializationOperator( + Op_RFFT2D, Attribute_RFFTAttribute, &attribute, + std::vector<std::string>{input_name}, + std::vector<std::string>{output_real_name, output_imag_name}); + + return tyop; +} + +template <> +TosaSerializationOperator * +TosaSerializationOperatorBuilder::build<mlir::tosa::FFT2dOp>( + mlir::Operation &op) const { + + bool inverse = op.getAttr("inverse").dyn_cast<mlir::BoolAttr>().getValue(); + + bool local_bound = + op.hasAttr("local_bound") + ? op.getAttr("local_bound").dyn_cast<mlir::BoolAttr>().getValue() + : false; + + std::string input_real_name = GetTensorName(op.getOperand(0)); + std::string input_imag_name = GetTensorName(op.getOperand(1)); + std::string output_real_name = GetTensorName(op.getResult(0)); + std::string output_imag_name = GetTensorName(op.getResult(1)); + + TosaFFTAttribute attribute(inverse, local_bound); + TosaSerializationOperator *tyop = new TosaSerializationOperator( - Op_WHILE_LOOP, Attribute_WhileLoopAttribute, &attribute, - input_names, output_names); + Op_FFT2D, Attribute_FFTAttribute, &attribute, + std::vector<std::string>{input_real_name, input_imag_name}, + std::vector<std::string>{output_real_name, output_imag_name}); + + return tyop; +} + +template <> +TosaSerializationOperator * +TosaSerializationOperatorBuilder::build<mlir::tosa::VariableReadOp>( + mlir::Operation &op) const { + + std::string input_name = GetVariableTensorName(&op); + std::string output_name = GetTensorName(op.getResult(0)); + + TosaSerializationOperator *tyop = + new TosaSerializationOperator(Op_IDENTITY, Attribute_NONE, nullptr, + std::vector<std::string>{input_name}, + std::vector<std::string>{output_name}); + + return tyop; +} + +template <> +TosaSerializationOperator * +TosaSerializationOperatorBuilder::build<mlir::tosa::VariableWriteOp>( + mlir::Operation &op) const { + 
+ std::string input_name = GetTensorName(op.getOperand(0)); + std::string output_name = GetVariableTensorName(&op); + + TosaSerializationOperator *tyop = + new TosaSerializationOperator(Op_IDENTITY, Attribute_NONE, nullptr, + std::vector<std::string>{input_name}, + std::vector<std::string>{output_name}); return tyop; } /* End translating TOSA operator */ +mlir::LogicalResult TosaSerializationRegionBuilder::BuildAllBlocksInRegion( + bool is_top, std::vector<mlir::Value> &return_values) { + std::string region_name = ser_region->GetName(); + int block_index = 0; + for (auto &block : this->region->getBlocks()) { + // must name first block of top region "main" + const std::string block_name = + (is_top && block_index == 0) + ? "main" + : (region_name + "_bb" + std::to_string(block_index++)); + TosaSerializationBasicBlock *ser_block = new TosaSerializationBasicBlock( + block_name, region_name, std::vector<TosaSerializationOperator *>(), + std::vector<TosaSerializationTensor *>(), std::vector<std::string>(), + std::vector<std::string>()); + + // build the block + TosaSerializationBlockBuilder block_builder(ser_block, this, &block); + // Region Builders need access to block builders + block_builders.push_back(&block_builder); + + if (block_builder.BuildAllOpsInBlock(return_values).failed()) { + return mlir::failure(); + } + + if (return_values.empty()) { + llvm::errs() << "BWarning: graph doesn't have return values\n"; + } -mlir::LogicalResult TosaSerializationBlockBuilder::BuildAllOpsInRegion( + // Add serialized block to serialized region + ser_region->GetBlocks().push_back(ser_block); + } + + return mlir::success(); +} + +mlir::LogicalResult TosaSerializationBlockBuilder::BuildAllOpsInBlock( std::vector<mlir::Value> &return_values) { TosaSerializationOperator *ser_operator = nullptr; TosaSerializationTensor *ser_tensor = nullptr; size_t num_blocks_in_region = 0; - static int input_tensor_index = 0; - static int intermediate_tensor_index = 0; - static int 
output_tensor_index = 0; TosaSerializationOperatorBuilder op_builder(this); - for (auto &bb : region->getBlocks()) { - num_blocks_in_region++; - - if (num_blocks_in_region > 1) { - llvm::errs() << "Invalid MLIR: multiple blocks in a region\n"; - return mlir::failure(); - } - - // We always have one block for each region right now - assert(bb.isEntryBlock()); - - // Specify block input tensor name - for (auto args : bb.getArguments()) { - std::string block_input_name = - "TosaInput_" + std::to_string(input_tensor_index++); - block->GetInputs().push_back(block_input_name); - tensor_map[args] = block_input_name; - input_tensor_map[args] = block_input_name; - } + // Specify block input tensor name + for (auto args : block->getArguments()) { + std::string block_input_name = + "TosaInput_" + std::to_string(input_tensor_index++); + ser_block->GetInputs().push_back(block_input_name); + tensor_map[args] = block_input_name; + input_tensor_map[args] = block_input_name; + } + + // Build tensor_map + for (auto &op : block->getOperations()) { + if (llvm::isa<mlir::tosa::VariableOp>(op)) { + RegisterVariableOp(op); + } else if (!(llvm::isa<mlir::tosa::YieldOp>(op) || + llvm::isa<mlir::func::ReturnOp>(op) || + llvm::isa<mlir::tensor::CastOp>(op))) { + for (uint32_t i = 0; i < op.getNumResults(); i++) { + std::string intermediate_tensor_name = + "layer_" + std::to_string(intermediate_tensor_index++); + tensor_map[op.getResult(i)] = intermediate_tensor_name; + } + } else { + if (llvm::isa<mlir::tensor::CastOp>(op)) + continue; + // Override return tensor name + for (auto val : op.getOperands()) { + // Workaround to skip mlir::tensor::CastOp before return + mlir::Operation *val_defining_op = val.getDefiningOp(); + if (val_defining_op) { + if (llvm::isa<mlir::tensor::CastOp>(*val_defining_op)) + val = val_defining_op->getOperand(0); + } - // Build tensor_map - for (auto &op : bb) { - if (!(llvm::isa<mlir::tosa::YieldOp>(op) || - llvm::isa<mlir::func::ReturnOp>(op) || - 
llvm::isa<mlir::tensor::CastOp>(op))) { - for (uint32_t i = 0; i < op.getNumResults(); i++) { - std::string intermediate_tensor_name = - "layer_" + std::to_string(intermediate_tensor_index++); - tensor_map[op.getResult(i)] = intermediate_tensor_name; + // Sanity check. This mlir::Value should be built in map since graph + // is DAG + if (tensor_map.find(val) == tensor_map.end()) { + llvm::errs() << "ERROR: Can't find built mlir::Value key.\n"; + return mlir::failure(); } - } else { - if (llvm::isa<mlir::tensor::CastOp>(op)) - continue; - // Override return tensor name - for (auto val : op.getOperands()) { - // Workaround to skip mlir::tensor::CastOp before return - mlir::Operation *val_defining_op = val.getDefiningOp(); - if (val_defining_op) { - if (llvm::isa<mlir::tensor::CastOp>(*val_defining_op)) - val = val_defining_op->getOperand(0); - } - - // Sanity check. This mlir::Value should be built in map since graph - // is DAG - if (tensor_map.find(val) == tensor_map.end()) { - llvm::errs() << "ERROR: Can't find built mlir::Value key.\n"; - return mlir::failure(); - } - - // If returned value is block input, short-circuit the tensor name - // Otherwise, build a new output name and override the origin tensor - // name - if (input_tensor_map.find(val) != input_tensor_map.end()) { - block->GetOutputs().push_back(input_tensor_map[val]); - return_values.push_back(val); - } else { - std::string output_name = - "TosaOutput_" + std::to_string(output_tensor_index++); - tensor_map[val] = output_name; - block->GetOutputs().push_back(output_name); - return_values.push_back(val); - } + + // If returned value is block input, short-circuit the tensor name + // Otherwise, build a new output name and override the origin tensor + // name + if (input_tensor_map.find(val) != input_tensor_map.end()) { + ser_block->GetOutputs().push_back(input_tensor_map[val]); + return_values.push_back(val); + } else { + std::string output_name = + "TosaOutput_" + std::to_string(output_tensor_index++); 
+ tensor_map[val] = output_name; + ser_block->GetOutputs().push_back(output_name); + return_values.push_back(val); } } } + } - // Build tensor + // Build variable tensor + for (auto pair : variable_tensor_op_map) { + mlir::Operation *op = pair.second; + mlir::Value val = op->getResult(0); + mlir::RankedTensorType tensor_type = op->getAttr("type") + .cast<mlir::TypeAttr>() + .getValue() + .cast<mlir::RankedTensorType>(); - // The tensor_map is sorted by hashed mlir::Value types. - // For serialization, sort tensors alphabetically by name for a - // deterministic and human-friendly ordering. - std::map<std::string, mlir::Value> tensor_name_sort; - for (auto pair : tensor_map) - tensor_name_sort[pair.second] = pair.first; + std::string variable_mlir_name = + op->getAttr("name").cast<mlir::StringAttr>().getValue().str(); - for (auto pair : tensor_name_sort) { - ser_tensor = BuildTosaSerializationTensor(pair.second /* val */, - pair.first /* name */); - if (!ser_tensor) { - llvm::errs() << "ERROR: Failed to build TosaSerializationTensor\n"; - return mlir::failure(); - } - block->GetTensors().push_back(ser_tensor); + ser_tensor = BuildTosaSerializationVariableTensor( + tensor_type /* tensor_type */, pair.first /* flatbuffer name */, + variable_mlir_name); + if (!ser_tensor) { + llvm::errs() << "ERROR: Failed to build TosaSerializationTensor\n"; + return mlir::failure(); } - - // Build operator - for (auto &op : bb) { - if (llvm::isa<mlir::tosa::YieldOp>(op) || - llvm::isa<mlir::func::ReturnOp>(op) || - llvm::isa<mlir::tensor::CastOp>(op)) - continue; - ser_operator = BuildTosaSerializationOperator(op_builder, op); - if (!ser_operator) { - llvm::errs() << "ERROR: Failed to build TosaSerializationOperator\n"; + // Initialize if "initial_value" attribute exists. 
If not, set data to all + // zeros + mlir::Attribute initial_value = op->getAttr("initial_value"); + std::vector<uint8_t> u8_data; + if (initial_value) { + if (initial_value.isa<mlir::DenseElementsAttr>()) { + if (op_builder + .GetDataFromAttribute(*op, initial_value, + tensor_type.getElementType(), u8_data) + .failed()) { + llvm::errs() << "ERROR: GetDataFromAttribute() fails when building " + "initial_value of variable tensor\n"; + return mlir::failure(); + } + } else { + llvm::errs() << "ERROR: Unknown initial_value attribute type\n"; return mlir::failure(); } - block->GetOperators().push_back(ser_operator); + } else { + TosaSerializationHandler::ForceAlignTensorData(u8_data); } + ser_tensor->SetData(u8_data); + ser_block->GetTensors().push_back(ser_tensor); + } + + // Build tensor + + // The tensor_map is sorted by hashed mlir::Value types. + // For serialization, sort tensors alphabetically by name for a + // deterministic and human-friendly ordering. + std::map<std::string, mlir::Value> tensor_name_sort; + for (auto pair : tensor_map) + tensor_name_sort[pair.second] = pair.first; + + for (auto pair : tensor_name_sort) { + mlir::RankedTensorType tensor_type = + pair.second.getType().cast<mlir::RankedTensorType>(); + ser_tensor = BuildTosaSerializationTensor(pair.second /* val */, + pair.first /* name */); + if (!ser_tensor) { + llvm::errs() << "ERROR: Failed to build TosaSerializationTensor\n"; + return mlir::failure(); + } + ser_block->GetTensors().push_back(ser_tensor); + } + + // Build operator + for (auto &op : block->getOperations()) { + if (llvm::isa<mlir::tosa::YieldOp>(op) || + llvm::isa<mlir::func::ReturnOp>(op) || + llvm::isa<mlir::tensor::CastOp>(op) || + llvm::isa<mlir::tosa::VariableOp>(op)) + continue; + ser_operator = BuildTosaSerializationOperator(op_builder, op); + if (!ser_operator) { + llvm::errs() << "ERROR: Failed to build TosaSerializationOperator\n"; + return mlir::failure(); + } + ser_block->GetOperators().push_back(ser_operator); } - 
return mlir::success(); } TosaSerializationOperator * TosaSerializationBlockBuilder::BuildTosaSerializationOperator( const TosaSerializationOperatorBuilder &op_builder, mlir::Operation &op) { - std::string full_op_name = op.getName().getStringRef().str(); TosaSerializationOperator *target_operator = nullptr; - if (false) { + if (llvm::isa<mlir::tosa::VariableReadOp>(op)) { + target_operator = op_builder.build<mlir::tosa::VariableReadOp>(op); + } else if (llvm::isa<mlir::tosa::VariableWriteOp>(op)) { + target_operator = op_builder.build<mlir::tosa::VariableWriteOp>(op); } #define DEF_OPERATOR(MLIR_OP) \ else if (llvm::isa<mlir::tosa::MLIR_OP##Op>(op)) { \ @@ -1455,17 +1933,22 @@ TosaSerializationBlockBuilder::BuildTosaSerializationOperator( return nullptr; } + if (llvm::isa<mlir::tosa::VariableReadOp>(op) || + llvm::isa<mlir::tosa::VariableWriteOp>(op)) { + return target_operator; + } + // Sanity check the number of inputs/outputs of TOSA dialect matches the // number of TOSA flatbuffer if (op.getNumOperands() != target_operator->GetInputTensorNames().size()) { - llvm::errs() << "WARNING. MLIR operator has " << op.getNumOperands() + llvm::errs() << "WARNING: MLIR operator has " << op.getNumOperands() << " input tensors != Flatbuffer " "operator has " << target_operator->GetInputTensorNames().size() << " input tensors\n"; } if (op.getNumResults() != target_operator->GetOutputTensorNames().size()) { - llvm::errs() << "WARNING. 
MLIR operator has " << op.getNumResults() + llvm::errs() << "WARNING: MLIR operator has " << op.getNumResults() << " output tensors != Flatbuffer " "operator has " << target_operator->GetOutputTensorNames().size() @@ -1476,30 +1959,83 @@ TosaSerializationBlockBuilder::BuildTosaSerializationOperator( } TosaSerializationTensor * +TosaSerializationBlockBuilder::BuildTosaSerializationVariableTensor( + mlir::RankedTensorType tensor_type, const std::string &name, + const std::string &variable_mlir_name) { + // If tensor already created before, use that tensor directly, create a new + // one otherwise + TosaSerializationTensor *ts = ser_block->GetTensorByName(name); + if (ts) { + return nullptr; + } + + std::vector<int32_t> shape(tensor_type.getShape().begin(), + tensor_type.getShape().end()); + + DType type = Type2DType(tensor_type.getElementType()); + + ts = new TosaSerializationTensor(name, shape, type, std::vector<uint8_t>(), + /* is_variable = */ true, + /* is_unranked = */ false, + variable_mlir_name); + + return ts; +} + +TosaSerializationTensor * TosaSerializationBlockBuilder::BuildTosaSerializationTensor( mlir::Value val, const std::string &name) { // If tensor already created before, use that tensor directly, create a new // one otherwise - TosaSerializationTensor *ts = block->GetTensorByName(name); + TosaSerializationTensor *ts = ser_block->GetTensorByName(name); if (ts) { return nullptr; } - mlir::RankedTensorType tensor = - val.getType().dyn_cast<mlir::RankedTensorType>(); - std::vector<int32_t> shape(tensor.getShape().begin(), - tensor.getShape().end()); - DType type = Type2DType(tensor.getElementType()); + // handling of tosa.shape values + if (auto shape_ty = val.getType().dyn_cast<mlir::tosa::shapeType>()) { + auto rank = shape_ty.getRank(); + std::vector<int32_t> shape; + if (rank > 0) { + shape.push_back(rank); + } + ts = new TosaSerializationTensor(name, + /* shape = */ shape, + /* type = */ DType::DType_SHAPE, + /* data = */ std::vector<uint8_t>()); + 
return ts; + } - ts = new TosaSerializationTensor(name, shape, type, std::vector<uint8_t>()); + auto ttype = val.getType().dyn_cast<mlir::TensorType>(); + if (!ttype) { + llvm::errs() << "TOSA serialization, supplied value is not of TensorType\n"; + return nullptr; + } + + const bool is_unranked = !ttype.hasRank(); + std::vector<int32_t> shape; + if (!is_unranked) { + auto shaped = val.getType().dyn_cast<mlir::ShapedType>(); + assert(shaped); + for (int idx = 0; idx < ttype.getRank(); idx++) { + if (shaped.isDynamicDim(idx)) { + shape.push_back(0); // size of 0 represents dynamic dimension + } else { + auto dim = shaped.getDimSize(idx); + shape.push_back(dim); + } + } + } + + DType type = Type2DType(ttype.getElementType()); + ts = new TosaSerializationTensor(name, shape, type, std::vector<uint8_t>(), + /* variable = */ false, is_unranked); return ts; } mlir::LogicalResult translate2FlatBuffer(mlir::func::FuncOp &func, TosaSerializationHandler &tsh) { - TosaSerializationBasicBlock *main_block; - mlir::Region *main_region = func.getCallableRegion(); std::vector<mlir::Value> main_returns; @@ -1508,21 +2044,22 @@ mlir::LogicalResult translate2FlatBuffer(mlir::func::FuncOp &func, return mlir::failure(); } - if (!tsh.GetBlocks().empty()) { - llvm::errs() << "Internal Error: TosaSerializationHandler's block list " + if (!tsh.GetRegions().empty()) { + llvm::errs() << "Internal Error: TosaSerializationHandler's region list " "must be empty\n"; return mlir::failure(); } - main_block = new TosaSerializationBasicBlock( - std::string("main"), std::vector<TosaSerializationOperator *>(), - std::vector<TosaSerializationTensor *>(), std::vector<std::string>(), - std::vector<std::string>()); - assert(main_block); - tsh.GetBlocks().push_back(main_block); + // reset static counters + input_tensor_index = 0; + intermediate_tensor_index = 0; + output_tensor_index = 0; - TosaSerializationBlockBuilder block_builder(main_block, &tsh, main_region); - if 
(block_builder.BuildAllOpsInRegion(main_returns).failed()) { + TosaSerializationRegion *ser_main_region = + BuildRegion(*main_region, "main", /* isolated_from_above = */ true, + /* parent_value_scope = */ nullptr, &tsh, main_returns, + /* is_top = */ true); + if (!ser_main_region) { return mlir::failure(); } @@ -1580,20 +2117,42 @@ mlir::LogicalResult dumpTosaJSON(mlir::func::FuncOp &func) { return mlir::success(); } -namespace mlir { +#define GEN_PASS_DEF_TOSASERIALIZATIONPASS +namespace mlir { namespace tosa { - namespace { class TosaSerialize : public TosaSerializationPassBase<TosaSerialize> { public: void runOnOperation() final { - auto function = getOperation(); + auto moduleOp = getOperation(); - if (dumpTosaFlatbuffer(function).failed()) { - llvm::errs() << "Failed to generate TOSA flatbuffer...\n"; - return signalPassFailure(); + // iterate through each op in the moduleOp, call dumpTosaFlatbuffer if + // that's a func.funcOp + + auto regions = moduleOp->getRegions(); + auto region_size = regions.size(); + + auto region_0 = regions.begin(); + auto block_size = region_0->getBlocks().size(); + + auto block_0 = region_0->getBlocks().begin(); + + auto op_size = block_0->getOperations().size(); + + for (auto it = block_0->getOperations().begin(); + it != block_0->getOperations().end(); ++it) { + // read variableOps that are declared outside of functionOps + if (llvm::isa<mlir::tosa::VariableOp>(*it)) { + RegisterVariableOp(*it); + } else if (llvm::isa<mlir::func::FuncOp>(*it)) { + auto funcOp = dyn_cast<mlir::func::FuncOp>((*it)); + if (dumpTosaFlatbuffer(funcOp).failed()) { + llvm::errs() << "Failed to generate TOSA flatbuffer...\n"; + return signalPassFailure(); + } + } } } }; @@ -1614,7 +2173,7 @@ public: } // anonymous namespace // Creates an instance of the TOSA flatbuffer generation pass -std::unique_ptr<Pass> createTosaSerializePass() { +std::unique_ptr<OperationPass<ModuleOp>> createTosaSerializePass() { return std::make_unique<TosaSerialize>(); } diff 
--git a/third_party/serialization_lib b/third_party/serialization_lib -Subproject 343d6a703c3a270a01102ec468b59ef2967b595 +Subproject 3aebe2bd863d6e0cb82171984cd49e5ad516d0d |