aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r-- .pre-commit-config.yaml          |   15
-rw-r--r-- CMakeLists.txt                   |    8
-rw-r--r-- include/DeserializationPasses.h  |   46
-rw-r--r-- include/DeserializationPasses.td |   25
-rw-r--r-- include/SerializationPasses.h    |    6
-rw-r--r-- include/SerializationPasses.td   |    2
-rw-r--r-- include/operator.def             |   22
-rw-r--r-- include/schema_operator.def      |  103
-rw-r--r-- src/TosaDeserialize.cpp          | 2128
-rw-r--r-- src/TosaSerialize.cpp            | 1735
m--------- third_party/serialization_lib    |    0
11 files changed, 3494 insertions, 596 deletions
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
new file mode 100644
index 0000000..e52e423
--- /dev/null
+++ b/.pre-commit-config.yaml
@@ -0,0 +1,15 @@
+# Copyright (c) 2023 Arm Limited.
+# SPDX-License-Identifier: Apache-2.0
+
+# See https://pre-commit.com for more information
+# See https://pre-commit.com/hooks.html for more hooks
+repos:
+- repo: local
+ hooks:
+ - id: clang-format
+ name: clang-format
+ exclude: build|third_party
+ language: system
+ entry: clang-format
+ types: ["c++"]
+ args: ["-i"]
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 968d73b..6bc3e6b 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -3,7 +3,6 @@
cmake_minimum_required(VERSION 3.13.4)
project(MlirTosaPasses)
-set(CMAKE_CXX_STANDARD 14 CACHE STRING "C++ standard to conform to")
set(CMAKE_CXX_STANDARD_REQUIRED YES)
set(CMAKE_VERBOSE_MAKEFILE ON)
@@ -12,21 +11,28 @@ set(CMAKE_VERBOSE_MAKEFILE ON)
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include)
include_directories(${PROJECT_SOURCE_DIR}/third_party/serialization_lib/include)
+include_directories(${PROJECT_SOURCE_DIR}/third_party/serialization_lib/third_party/half/include)
include_directories(${PROJECT_SOURCE_DIR}/third_party/serialization_lib/third_party/flatbuffers/include)
set(LLVM_TARGET_DEFINITIONS include/SerializationPasses.td)
mlir_tablegen(include/SerializationPasses.h.inc -gen-pass-decls -name TosaSerialization)
add_public_tablegen_target(tosa_serialization_passes_inc_gen)
+set(LLVM_TARGET_DEFINITIONS include/DeserializationPasses.td)
+mlir_tablegen(include/DeserializationPasses.h.inc -gen-pass-decls -name TosaDeserialization)
+add_public_tablegen_target(tosa_deserialization_passes_inc_gen)
+
# Compile the TOSA serialization_lib
add_subdirectory(third_party/serialization_lib)
add_mlir_library(tosa_serialize
src/TosaSerialize.cpp
+ src/TosaDeserialize.cpp
DEPENDS
mlir-headers
tosa_serialization_passes_inc_gen
+ tosa_deserialization_passes_inc_gen
LINK_LIBS PRIVATE
tosa_serialization_lib
diff --git a/include/DeserializationPasses.h b/include/DeserializationPasses.h
new file mode 100644
index 0000000..1a38814
--- /dev/null
+++ b/include/DeserializationPasses.h
@@ -0,0 +1,46 @@
+
+// Copyright (c) 2023, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 with LLVM Exceptions
+// (the "License"); you may not use this file except in compliance with
+// the License. You may obtain a copy of the License at
+//
+// https://llvm.org/LICENSE.txt
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef INCLUDE_DESERIALIZATION_PASSES_H
+#define INCLUDE_DESERIALIZATION_PASSES_H
+
+#include <memory>
+
+#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
+#include "mlir/IR/BuiltinOps.h" // from @llvm-project
+#include "mlir/IR/OwningOpRef.h" // from @llvm-project
+#include "mlir/Pass/Pass.h" // from @llvm-project
+
+namespace mlir {
+namespace tosa {
+
+std::unique_ptr<Pass> createTosaDeserializePass();
+std::unique_ptr<Pass> createTosaDeserializeJSONPass();
+
+// deserializes a tosa file and return an mlir module
+// if file_is_fbs is true, then treat file_name as a tosa flatbuffer file
+// otherwise, treat file_name as a tosa json file
+mlir::OwningOpRef<mlir::ModuleOp>
+BuildMlirFromTosaFile(const char *file_name, mlir::MLIRContext *context,
+ bool file_is_fbs = true);
+
+#define GEN_PASS_REGISTRATION
+#define GEN_PASS_CLASSES
+#include "include/DeserializationPasses.h.inc"
+
+} // namespace tosa
+} // namespace mlir
+
+#endif // INCLUDE_DESERIALIZATION_PASSES_H
diff --git a/include/DeserializationPasses.td b/include/DeserializationPasses.td
new file mode 100644
index 0000000..999f0b4
--- /dev/null
+++ b/include/DeserializationPasses.td
@@ -0,0 +1,25 @@
+// Copyright (c) 2023, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 with LLVM Exceptions
+// (the "License"); you may not use this file except in compliance with
+// the License. You may obtain a copy of the License at
+//
+// https://llvm.org/LICENSE.txt
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+include "mlir/Pass/PassBase.td"
+
+def TosaDeserializationPass : Pass<"tosa-deserialize", "func::FuncOp"> {
+ let summary = "Deserialize TOSA flatbuffer. Clear original MLIR graph and generate TOSA MLIR";
+ let constructor = "createTosaDeserializePass()";
+}
+
+def TosaDeserializationJSONPass : Pass<"tosa-deserialize-json", "func::FuncOp"> {
+ let summary = "Deserialize TOSA flatbuffer JSON form. Clear original MLIR graph and generate TOSA MLIR";
+ let constructor = "createTosaDeserializeJSONPass()";
+}
diff --git a/include/SerializationPasses.h b/include/SerializationPasses.h
index 66c6d80..c769b15 100644
--- a/include/SerializationPasses.h
+++ b/include/SerializationPasses.h
@@ -19,16 +19,18 @@
#include <memory>
#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
-#include "mlir/Pass/Pass.h" // from @llvm-project
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/Pass/Pass.h" // from @llvm-project
namespace mlir {
namespace tosa {
-std::unique_ptr<Pass> createTosaSerializePass();
+std::unique_ptr<OperationPass<ModuleOp>> createTosaSerializePass();
std::unique_ptr<Pass> createTosaSerializeJSONPass();
#define GEN_PASS_REGISTRATION
#define GEN_PASS_CLASSES
+#define GEN_PASS_DECL_TOSASERIALIZEPASS
#include "include/SerializationPasses.h.inc"
} // namespace tosa
diff --git a/include/SerializationPasses.td b/include/SerializationPasses.td
index 3bdeb1b..9cfc204 100644
--- a/include/SerializationPasses.td
+++ b/include/SerializationPasses.td
@@ -14,7 +14,7 @@
include "mlir/Pass/PassBase.td"
-def TosaSerializationPass : Pass<"tosa-serialize", "func::FuncOp"> {
+def TosaSerializationPass : Pass<"tosa-serialize", "mlir::ModuleOp"> {
let summary = "Generate TOSA flatbuffer serialized form";
let constructor = "createTosaSerializePass()";
}
diff --git a/include/operator.def b/include/operator.def
index 85bb5c9..0bd0d08 100644
--- a/include/operator.def
+++ b/include/operator.def
@@ -1,8 +1,8 @@
-// Copyright (c) 2020-2021, ARM Limited.
+// Copyright (c) 2020-2024, ARM Limited.
//
-// Licensed under the Apache License, Version 2.0 with LLVM Exceptions
-// (the "License"); you may not use this file except in compliance with
+// Licensed under the Apache License, Version 2.0 with LLVM Exceptions
+// (the "License"); you may not use this file except in compliance with
// the License. You may obtain a copy of the License at
//
// https://llvm.org/LICENSE.txt
@@ -27,12 +27,15 @@ DEF_OPERATOR(AvgPool2d)
DEF_OPERATOR(Conv2D)
DEF_OPERATOR(Conv3D)
DEF_OPERATOR(DepthwiseConv2D)
+DEF_OPERATOR(FFT2d)
DEF_OPERATOR(FullyConnected)
DEF_OPERATOR(MatMul)
DEF_OPERATOR(MaxPool2d)
+DEF_OPERATOR(RFFT2d)
DEF_OPERATOR(TransposeConv2D)
/* activation */
+DEF_OPERATOR(Erf)
DEF_OPERATOR(Clamp)
DEF_OPERATOR(Sigmoid)
DEF_OPERATOR(Tanh)
@@ -43,7 +46,7 @@ DEF_OPERATOR(ArithmeticRightShift)
DEF_OPERATOR(BitwiseAnd)
DEF_OPERATOR(BitwiseOr)
DEF_OPERATOR(BitwiseXor)
-DEF_OPERATOR(Div)
+DEF_OPERATOR(IntDiv)
DEF_OPERATOR(LogicalAnd)
DEF_OPERATOR(LogicalLeftShift)
DEF_OPERATOR(LogicalRightShift)
@@ -61,6 +64,7 @@ DEF_OPERATOR(Abs)
DEF_OPERATOR(BitwiseNot)
DEF_OPERATOR(Ceil)
DEF_OPERATOR(Clz)
+DEF_OPERATOR(Cos)
DEF_OPERATOR(Exp)
DEF_OPERATOR(Floor)
DEF_OPERATOR(Log)
@@ -68,6 +72,7 @@ DEF_OPERATOR(LogicalNot)
DEF_OPERATOR(Negate)
DEF_OPERATOR(Reciprocal)
DEF_OPERATOR(Rsqrt)
+DEF_OPERATOR(Sin)
/* elementwise - ternary */
DEF_OPERATOR(Select)
@@ -88,6 +93,7 @@ DEF_OPERATOR(ReduceSum)
/* memory operation */
DEF_OPERATOR(Concat)
DEF_OPERATOR(Pad)
+DEF_OPERATOR(Dim)
DEF_OPERATOR(Reshape)
DEF_OPERATOR(Reverse)
DEF_OPERATOR(Slice)
@@ -115,3 +121,11 @@ DEF_OPERATOR(Custom)
/* control flow operators */
DEF_OPERATOR(If)
DEF_OPERATOR(While)
+
+/* shape operators */
+DEF_OPERATOR(ConstShape)
+DEF_OPERATOR(ConcatShape)
+DEF_OPERATOR(AddShape)
+DEF_OPERATOR(SubShape)
+DEF_OPERATOR(MulShape)
+DEF_OPERATOR(DivShape)
diff --git a/include/schema_operator.def b/include/schema_operator.def
new file mode 100644
index 0000000..02b639e
--- /dev/null
+++ b/include/schema_operator.def
@@ -0,0 +1,103 @@
+// Copyright (c) 2023-2024, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 with LLVM Exceptions
+// (the "License"); you may not use this file except in compliance with
+// the License. You may obtain a copy of the License at
+//
+// https://llvm.org/LICENSE.txt
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+/*
+ Syntax:
+ DEF_SCHEMA_OPERATOR(SCHEMA_OP_NAME)
+
+ Description:
+ SCHEMA_OP_NAME: the schema operator names, must match Op names in schema/tosa.fbs in serialization_lib repo
+*/
+
+/* schema operators */
+DEF_SCHEMA_OPERATOR(ARGMAX)
+DEF_SCHEMA_OPERATOR(AVG_POOL2D)
+DEF_SCHEMA_OPERATOR(CONV2D)
+DEF_SCHEMA_OPERATOR(CONV3D)
+DEF_SCHEMA_OPERATOR(DEPTHWISE_CONV2D)
+DEF_SCHEMA_OPERATOR(FULLY_CONNECTED)
+DEF_SCHEMA_OPERATOR(MATMUL)
+DEF_SCHEMA_OPERATOR(MAX_POOL2D)
+DEF_SCHEMA_OPERATOR(TRANSPOSE_CONV2D)
+DEF_SCHEMA_OPERATOR(CLAMP)
+DEF_SCHEMA_OPERATOR(RESERVED)
+DEF_SCHEMA_OPERATOR(ERF)
+DEF_SCHEMA_OPERATOR(SIGMOID)
+DEF_SCHEMA_OPERATOR(TANH)
+DEF_SCHEMA_OPERATOR(ADD)
+DEF_SCHEMA_OPERATOR(ARITHMETIC_RIGHT_SHIFT)
+DEF_SCHEMA_OPERATOR(BITWISE_AND)
+DEF_SCHEMA_OPERATOR(BITWISE_OR)
+DEF_SCHEMA_OPERATOR(BITWISE_XOR)
+DEF_SCHEMA_OPERATOR(INTDIV)
+DEF_SCHEMA_OPERATOR(LOGICAL_AND)
+DEF_SCHEMA_OPERATOR(LOGICAL_LEFT_SHIFT)
+DEF_SCHEMA_OPERATOR(LOGICAL_RIGHT_SHIFT)
+DEF_SCHEMA_OPERATOR(LOGICAL_OR)
+DEF_SCHEMA_OPERATOR(LOGICAL_XOR)
+DEF_SCHEMA_OPERATOR(MAXIMUM)
+DEF_SCHEMA_OPERATOR(MINIMUM)
+DEF_SCHEMA_OPERATOR(MUL)
+DEF_SCHEMA_OPERATOR(POW)
+DEF_SCHEMA_OPERATOR(SUB)
+DEF_SCHEMA_OPERATOR(TABLE)
+DEF_SCHEMA_OPERATOR(ABS)
+DEF_SCHEMA_OPERATOR(BITWISE_NOT)
+DEF_SCHEMA_OPERATOR(CEIL)
+DEF_SCHEMA_OPERATOR(CLZ)
+DEF_SCHEMA_OPERATOR(EXP)
+DEF_SCHEMA_OPERATOR(FLOOR)
+DEF_SCHEMA_OPERATOR(LOG)
+DEF_SCHEMA_OPERATOR(LOGICAL_NOT)
+DEF_SCHEMA_OPERATOR(NEGATE)
+DEF_SCHEMA_OPERATOR(RECIPROCAL)
+DEF_SCHEMA_OPERATOR(RSQRT)
+DEF_SCHEMA_OPERATOR(SELECT)
+DEF_SCHEMA_OPERATOR(EQUAL)
+DEF_SCHEMA_OPERATOR(GREATER)
+DEF_SCHEMA_OPERATOR(GREATER_EQUAL)
+DEF_SCHEMA_OPERATOR(REDUCE_ANY)
+DEF_SCHEMA_OPERATOR(REDUCE_ALL)
+DEF_SCHEMA_OPERATOR(REDUCE_MAX)
+DEF_SCHEMA_OPERATOR(REDUCE_MIN)
+DEF_SCHEMA_OPERATOR(REDUCE_PRODUCT)
+DEF_SCHEMA_OPERATOR(REDUCE_SUM)
+DEF_SCHEMA_OPERATOR(CONCAT)
+DEF_SCHEMA_OPERATOR(PAD)
+DEF_SCHEMA_OPERATOR(DIM)
+DEF_SCHEMA_OPERATOR(RESHAPE)
+DEF_SCHEMA_OPERATOR(REVERSE)
+DEF_SCHEMA_OPERATOR(SLICE)
+DEF_SCHEMA_OPERATOR(TILE)
+DEF_SCHEMA_OPERATOR(TRANSPOSE)
+DEF_SCHEMA_OPERATOR(GATHER)
+DEF_SCHEMA_OPERATOR(SCATTER)
+DEF_SCHEMA_OPERATOR(RESIZE)
+DEF_SCHEMA_OPERATOR(CAST)
+DEF_SCHEMA_OPERATOR(RESCALE)
+DEF_SCHEMA_OPERATOR(CONST)
+DEF_SCHEMA_OPERATOR(IDENTITY)
+DEF_SCHEMA_OPERATOR(CUSTOM)
+DEF_SCHEMA_OPERATOR(COND_IF)
+DEF_SCHEMA_OPERATOR(WHILE_LOOP)
+DEF_SCHEMA_OPERATOR(FFT2D)
+DEF_SCHEMA_OPERATOR(RFFT2D)
+DEF_SCHEMA_OPERATOR(CONST_SHAPE)
+DEF_SCHEMA_OPERATOR(CONCAT_SHAPE)
+DEF_SCHEMA_OPERATOR(ADD_SHAPE)
+DEF_SCHEMA_OPERATOR(SUB_SHAPE)
+DEF_SCHEMA_OPERATOR(MUL_SHAPE)
+DEF_SCHEMA_OPERATOR(DIV_SHAPE)
+DEF_SCHEMA_OPERATOR(COS)
+DEF_SCHEMA_OPERATOR(SIN)
\ No newline at end of file
diff --git a/src/TosaDeserialize.cpp b/src/TosaDeserialize.cpp
new file mode 100644
index 0000000..215d760
--- /dev/null
+++ b/src/TosaDeserialize.cpp
@@ -0,0 +1,2128 @@
+
+// Copyright (c) 2023-2024, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 with LLVM Exceptions
+// (the "License"); you may not use this file except in compliance with
+// the License. You may obtain a copy of the License at
+//
+// https://llvm.org/LICENSE.txt
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// TOSA MLIR deserialize passes
+
+#include "include/DeserializationPasses.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
+#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project
+#include "mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project
+#include "mlir/Dialect/Tosa/IR/TosaOps.h" // from @llvm-project
+#include "mlir/IR/Matchers.h" // from @llvm-project
+#include "mlir/Pass/Pass.h" // from @llvm-project
+#include "mlir/Support/LogicalResult.h" // from @llvm-project
+#include "tosa_serialization_handler.h"
+#include <functional>
+#include <queue>
+#include <unordered_map>
+#include <unordered_set>
+#include <vector>
+
+// The namespace might be confusing here. We have mlir::tosa:: defined in MLIR
+// and tosa:: defined in serialization library
+// TODO: align the namespace
+using namespace tosa;
+
+namespace cl = llvm::cl;
+
+llvm::cl::opt<std::string> tosa_deserialize_filename(
+ "tosa-deserialize-filename", llvm::cl::desc("<tosa flatbuffer filename>"),
+ llvm::cl::init("tosa_dump.tosa"), llvm::cl::value_desc("filename"));
+
+llvm::cl::opt<std::string> tosa_deserialize_schema(
+ "tosa-deserialize-schema", llvm::cl::desc("<tosa flatbuffer schema file>"),
+ llvm::cl::init(""), llvm::cl::value_desc("filename"));
+
+const std::string kDefaultExportedName = "tosa_deserialized";
+const std::string kDefaultInputPrefix = "input_";
+const std::string kDefaultOutputPrefix = "output_";
+const std::string kDefaultFBSDescription = "Tosa FBS Converted";
+const std::string kDefaultJSONDescription = "Tosa JSON Converted";
+const std::string kMainFunctionName = "main";
+
+namespace {
+
+// a global map from flatbuffer variable names to serialized tensors
+std::unordered_map<std::string, TosaSerializationTensor *> variable_tensor_map;
+
+void RegisterVariableTensor(TosaSerializationTensor *ts) {
+ assert(ts->GetVariable());
+ // insert variable tensor ts only if not already present
+ variable_tensor_map.insert({ts->GetName(), ts});
+}
+
+bool IsVariableTensor(const std::string flatbuffer_tensor_name) {
+ return variable_tensor_map.count(flatbuffer_tensor_name);
+}
+
+// return the variable name corresponding to flatbuffer_tensor_name
+const std::string GetVariableTensorName(TosaSerializationTensor *ts) {
+ assert(ts->GetVariable());
+ const auto name = ts->GetVariableName();
+ if (name == "") {
+ // for legacy flatbuffers which may not have variable_name fields
+ return ts->GetName();
+ }
+ return name;
+}
+
+// return the variable name corresponding to flatbuffer_tensor_name
+const std::string
+GetVariableTensorName(const std::string flatbuffer_tensor_name) {
+ if (!IsVariableTensor(flatbuffer_tensor_name)) {
+ llvm::errs() << "ERROR: Variable tensor " << flatbuffer_tensor_name
+ << " is not found in variable_tensor_map";
+ return "";
+ }
+ return GetVariableTensorName(variable_tensor_map[flatbuffer_tensor_name]);
+}
+
+bool IsVariableReadOp(TosaSerializationOperator *op) {
+ return (op->GetOp() == tosa::Op::Op_IDENTITY) &&
+ IsVariableTensor(op->GetInputTensorNames()[0]);
+}
+
+bool IsVariableWriteOp(TosaSerializationOperator *op) {
+ return (op->GetOp() == tosa::Op::Op_IDENTITY) &&
+ IsVariableTensor(op->GetOutputTensorNames()[0]);
+}
+
+// construct tensor type from dtype and shape of TosaSerializationTensor
+mlir::LogicalResult BuildTensorType(mlir::OpBuilder *op_builder,
+ TosaSerializationTensor *ts,
+ mlir::RankedTensorType &type) {
+ mlir::Type element_type;
+ switch (ts->GetDtype()) {
+ case DType_BOOL:
+ element_type = op_builder->getI1Type();
+ break;
+ case DType_UINT8:
+ element_type = op_builder->getIntegerType(8, false);
+ break;
+ case DType_INT4:
+ element_type = op_builder->getI4Type();
+ break;
+ case DType_INT8:
+ element_type = op_builder->getI8Type();
+ break;
+ case DType_INT16:
+ element_type = op_builder->getIntegerType(16);
+ break;
+ case DType_INT32:
+ element_type = op_builder->getI32Type();
+ break;
+ case DType_INT48:
+ element_type = op_builder->getIntegerType(48);
+ break;
+ case DType_FP32:
+ element_type = op_builder->getF32Type();
+ break;
+ case DType_UINT16:
+ element_type = op_builder->getIntegerType(16, false);
+ break;
+ case DType_FP16:
+ element_type = op_builder->getF16Type();
+ break;
+ case DType_BF16:
+ element_type = op_builder->getBF16Type();
+ break;
+ case DType_FP8E4M3:
+ element_type = op_builder->getFloat8E4M3FNType();
+ break;
+ case DType_FP8E5M2:
+ element_type = op_builder->getFloat8E5M2Type();
+ break;
+ case DType_SHAPE:
+ llvm::errs()
+ << "ERROR: Cannot construct RankedTensorType out of tosa.shape type \n";
+ return mlir::failure();
+ default:
+ llvm::errs() << "ERROR: unknown type " << EnumNamesDType()[ts->GetDtype()]
+ << "\n";
+ return mlir::failure();
+ }
+ llvm::SmallVector<int64_t> shape;
+ for (auto dim : ts->GetShape()) {
+ if (dim > 0) {
+ shape.push_back(dim);
+ } else {
+ // dynamic dim
+ shape.push_back(mlir::ShapedType::kDynamic);
+ }
+ }
+ type = mlir::RankedTensorType::get(llvm::ArrayRef(shape), element_type);
+ return mlir::success();
+}
+
+mlir::DenseElementsAttr GetConstAttr(const std::vector<uint8_t> &data,
+ const mlir::RankedTensorType &output_type,
+ uint32_t out_size) {
+ auto element_type = output_type.getElementType();
+ if (element_type.isF32()) {
+ // for FP32, value attributes are stored as FP32 values
+ std::vector<float> float_data;
+ TosaSerializationHandler::ConvertU8toF32(data, out_size, float_data);
+ return mlir::DenseElementsAttr::get(output_type,
+ llvm::ArrayRef(float_data));
+ }
+ if (element_type.isBF16()) {
+ mlir::SmallVector<mlir::APFloat> bf16_data;
+ for (uint32_t i = 0; i < out_size; i++) {
+ uint64_t byte0 = data[i * sizeof(int16_t)];
+ uint64_t byte1 = data[i * sizeof(int16_t) + 1];
+ uint64_t bits = byte0 + (byte1 << 8);
+ mlir::APInt bf16_bits(16, bits);
+ mlir::APFloat bf16(mlir::APFloat::BFloat(), bf16_bits);
+ bf16_data.push_back(bf16);
+ }
+ return mlir::DenseElementsAttr::get(output_type, bf16_data);
+ }
+ if (element_type.isFloat8E4M3FN()) {
+ mlir::SmallVector<mlir::APFloat> f8_data;
+ for (uint32_t i = 0; i < out_size; i++) {
+ mlir::APInt f8_bits(8, static_cast<uint64_t>(data[i]));
+ mlir::APFloat f8(mlir::APFloat::Float8E4M3FN(), f8_bits);
+ f8_data.push_back(f8);
+ }
+ return mlir::DenseElementsAttr::get(output_type, f8_data);
+ }
+ if (element_type.isFloat8E5M2()) {
+ mlir::SmallVector<mlir::APFloat> f8_data;
+ for (uint32_t i = 0; i < out_size; i++) {
+ mlir::APInt f8_bits(8, static_cast<uint64_t>(data[i]));
+ mlir::APFloat f8(mlir::APFloat::Float8E5M2(), f8_bits);
+ f8_data.push_back(f8);
+ }
+ return mlir::DenseElementsAttr::get(output_type, f8_data);
+ }
+ if (element_type.isInteger(4)) {
+ std::vector<int8_t> int4_data;
+ TosaSerializationHandler::ConvertU8toI4(data, out_size, int4_data);
+ return mlir::DenseElementsAttr::get(output_type, llvm::ArrayRef(int4_data));
+ }
+ if (element_type.isInteger(8)) {
+ std::vector<int8_t> int8_data;
+ TosaSerializationHandler::ConvertU8toI8(data, out_size, int8_data);
+ return mlir::DenseElementsAttr::get(output_type, llvm::ArrayRef(int8_data));
+ }
+ if (element_type.isInteger(16)) {
+ std::vector<int16_t> int16_data;
+ TosaSerializationHandler::ConvertU8toI16(data, out_size, int16_data);
+ return mlir::DenseElementsAttr::get(output_type,
+ llvm::ArrayRef(int16_data));
+ }
+ if (element_type.isInteger(32)) {
+ std::vector<int32_t> int32_data;
+ TosaSerializationHandler::ConvertU8toI32(data, out_size, int32_data);
+ return mlir::DenseElementsAttr::get(output_type,
+ llvm::ArrayRef(int32_data));
+ }
+ if (element_type.isInteger(48)) {
+ std::vector<int64_t> int48_data;
+ TosaSerializationHandler::ConvertU8toI48(data, out_size, int48_data);
+ std::vector<mlir::APInt> apint_data;
+ for (const auto v : int48_data) {
+ mlir::APInt apint_value(48, static_cast<uint64_t>(v),
+ /* isSigned = */ false);
+ apint_data.push_back(apint_value);
+ }
+ return mlir::DenseElementsAttr::get(output_type,
+ llvm::ArrayRef(apint_data));
+ }
+ if (element_type.isInteger(1)) {
+ std::vector<bool> bool_data;
+ TosaSerializationHandler::ConvertU8toBool(data, out_size, bool_data);
+ llvm::SmallVector<bool> bool_values(bool_data.begin(), bool_data.end());
+ return mlir::DenseElementsAttr::get(output_type, bool_values);
+ }
+ if (element_type.isF16()) {
+ std::vector<half_float::half> half_data;
+ TosaSerializationHandler::ConvertU8toF16(data, out_size, half_data);
+ return mlir::DenseElementsAttr::get(output_type, llvm::ArrayRef(half_data));
+ }
+
+ return nullptr;
+}
+
+mlir::DenseElementsAttr
+ConstructConstAttr(const mlir::RankedTensorType &output_type,
+ TosaSerializationTensor *ts, const std::string &op_name) {
+ // compute output data size
+ uint32_t out_size = 1;
+ for (const auto dim : ts->GetShape()) {
+ out_size *= dim;
+ }
+ auto attr = GetConstAttr(ts->GetData(), output_type, out_size);
+ if (!attr) {
+ llvm::errs() << "ERROR: " << op_name
+ << " contains unsupported element type\n";
+ }
+ return attr;
+}
+
+mlir::LogicalResult ConstructVariableOps(mlir::ModuleOp &module) {
+ if (variable_tensor_map.empty()) {
+ return mlir::success();
+ }
+ auto loc = module.getLoc();
+ auto op_builder = mlir::OpBuilder(module.getBodyRegion());
+ for (auto [flatbuffer_name, ts] : variable_tensor_map) {
+ auto name = GetVariableTensorName(ts);
+ mlir::RankedTensorType type;
+ if (BuildTensorType(&op_builder, ts, type).failed()) {
+ return mlir::failure();
+ }
+
+ mlir::Attribute value_attr = nullptr;
+ if (!ts->GetData().empty()) {
+ value_attr = ConstructConstAttr(type, ts, name);
+ }
+ op_builder.create<mlir::tosa::VariableOp>(loc, llvm::StringRef(name), type,
+ value_attr);
+ }
+
+ return mlir::success();
+}
+
+template <class T>
+mlir::DenseElementsAttr BuildDenseI8ElementsAttr(mlir::OpBuilder *op_builder,
+ const std::vector<T> &values) {
+ llvm::SmallVector<int8_t> vec;
+ for (auto val : values) {
+ vec.push_back(val);
+ }
+ auto type = mlir::RankedTensorType::get({static_cast<int64_t>(vec.size())},
+ op_builder->getI8Type());
+ return mlir::DenseElementsAttr::get(type, llvm::ArrayRef(vec));
+}
+
+template <class T>
+mlir::DenseElementsAttr
+BuildDenseI16ElementsAttr(mlir::OpBuilder *op_builder,
+ const std::vector<T> &values) {
+ llvm::SmallVector<int16_t> vec;
+ for (auto val : values) {
+ vec.push_back(val);
+ }
+ auto type = mlir::RankedTensorType::get({static_cast<int64_t>(vec.size())},
+ op_builder->getI16Type());
+ return mlir::DenseElementsAttr::get(type, llvm::ArrayRef(vec));
+}
+
+template <class T>
+mlir::DenseElementsAttr
+BuildDenseI32ElementsAttr(mlir::OpBuilder *op_builder,
+ mlir::RankedTensorType &type,
+ const std::vector<T> &values) {
+ llvm::SmallVector<int32_t> vec;
+ for (auto val : values) {
+ vec.push_back(val);
+ }
+ return mlir::DenseElementsAttr::get(type, llvm::ArrayRef(vec));
+}
+
+template <class T>
+mlir::DenseI8ArrayAttr BuildDenseI8ArrayAttr(mlir::OpBuilder *op_builder,
+ const std::vector<T> &values) {
+ std::vector<int8_t> vec;
+ for (auto val : values) {
+ vec.push_back(val);
+ }
+ return op_builder->getDenseI8ArrayAttr(vec);
+}
+
+template <class T>
+mlir::DenseI32ArrayAttr BuildDenseI32ArrayAttr(mlir::OpBuilder *op_builder,
+ const std::vector<T> &values) {
+ std::vector<int32_t> vec;
+ for (auto val : values) {
+ vec.push_back(val);
+ }
+ return op_builder->getDenseI32ArrayAttr(vec);
+}
+
+template <class T>
+mlir::DenseI64ArrayAttr BuildDenseI64ArrayAttr(mlir::OpBuilder *op_builder,
+ const std::vector<T> &values) {
+ std::vector<int64_t> vec;
+ for (auto val : values) {
+ vec.push_back(val);
+ }
+ return op_builder->getDenseI64ArrayAttr(vec);
+}
+
+const std::string ResizeEnum2Str(const tosa::ResizeMode &mode) {
+ if (mode == ResizeMode_NEAREST) {
+ return "NEAREST_NEIGHBOR";
+ } else if (mode == ResizeMode_BILINEAR) {
+ return "BILINEAR";
+ }
+ return "";
+}
+
+// this is a counter part to Type2AccDType
+mlir::Type AccDType2Type(mlir::OpBuilder *op_builder, DType dtype) {
+ // def Tosa_AccType : AnyTypeOf<[I<32>, I<48>, F16, F32]>;
+ if (dtype == DType_INT32) {
+ return op_builder->getI32Type();
+ } else if (dtype == DType_INT48) {
+ return op_builder->getIntegerType(48);
+ } else if (dtype == DType_FP32) {
+ return op_builder->getF32Type();
+ } else if (dtype == DType_FP16) {
+ return op_builder->getF16Type();
+ } else {
+ // unknown acc type
+ // for now, default to F32
+ return op_builder->getF32Type();
+ }
+}
+
+} // namespace
+
+class TosaMlirRegionBuilder;
+class TosaMlirBlockBuilder;
+
+class TosaMlirOperatorBuilder {
+public:
+ TosaMlirOperatorBuilder(
+ mlir::OpBuilder *_op_builder, TosaSerializationBasicBlock *_ser_block,
+ mlir::Block *_block, mlir::Location _loc,
+ TosaMlirBlockBuilder *_block_builder,
+ std::unordered_map<std::string, mlir::Value> *_tensor_map,
+ std::unordered_map<std::string, mlir::RankedTensorType> *_tensor_type_map,
+ std::unordered_map<std::string, mlir::tosa::shapeType> *_shape_type_map)
+ : op_builder(_op_builder), ser_block(_ser_block), block(_block),
+ loc(_loc), block_builder(_block_builder), tensor_map(_tensor_map),
+ tensor_type_map(_tensor_type_map), shape_type_map(_shape_type_map) {}
+
+ template <Op OPCODE>
+ std::vector<mlir::Value> build(TosaSerializationOperator *op) const;
+
+ std::vector<mlir::Value> BuildVariableOp(TosaSerializationOperator *op) const;
+
+ std::vector<mlir::Value>
+ BuildVariableReadOp(TosaSerializationOperator *op) const;
+
+ void BuildVariableWriteOp(TosaSerializationOperator *op) const;
+
+ std::string get_string(TosaSerializationOperator *op) const {
+ std::string op_string;
+ op_string += "operator opcode=";
+ op_string += EnumNamesOp()[op->GetOp()];
+ op_string += ", input=[";
+ for (auto ts : op->GetInputTensorNames()) {
+ op_string += (ts + " ");
+ }
+ op_string += "], output=[";
+ for (auto ts : op->GetOutputTensorNames()) {
+ op_string += (ts + " ");
+ }
+ op_string += "]";
+ return op_string;
+ }
+
+ TosaSerializationHandler *GetTsh() const;
+ TosaMlirRegionBuilder *GetRegionBuilder() const;
+
+private:
+ template <class MLIR_OP>
+ std::vector<mlir::Value>
+ BuildEwiseUnaryOp(TosaSerializationOperator *op) const;
+
+ template <class MLIR_OP>
+ std::vector<mlir::Value>
+ BuildEwiseBinaryOp(TosaSerializationOperator *op) const;
+
+ template <class MLIR_OP>
+ std::vector<mlir::Value>
+ BuildEwiseBinaryShapeOp(TosaSerializationOperator *op) const;
+
+ template <class MLIR_OP>
+ std::vector<mlir::Value>
+ BuildReductionOp(TosaSerializationOperator *op) const;
+
+ template <class T>
+ mlir::Value BuildConstShape(mlir::OpBuilder *op_builder, mlir::Location loc,
+ const std::vector<T> &values) const;
+
+ template <class MLIR_OP>
+ std::vector<mlir::Value> BuildConvOp(TosaSerializationOperator *op) const;
+
+ mlir::OpBuilder *op_builder;
+ TosaSerializationBasicBlock *ser_block;
+ mlir::Block *block;
+ mlir::Location loc;
+ TosaMlirBlockBuilder *block_builder;
+ std::unordered_map<std::string, mlir::Value> *tensor_map;
+ std::unordered_map<std::string, mlir::RankedTensorType> *tensor_type_map;
+ std::unordered_map<std::string, mlir::tosa::shapeType> *shape_type_map;
+};
+
+// Main template to catch unimplemented translation
+template <Op OPCODE>
+std::vector<mlir::Value>
+TosaMlirOperatorBuilder::build(TosaSerializationOperator *op) const {
+ llvm::errs() << "ERROR: " << get_string(op)
+ << " translation hasn't been implemented\n";
+ return {};
+}
+
+// BUILD_OP_POOL2D(MaxPool2d, MAX_POOL2D)
+template <>
+std::vector<mlir::Value> TosaMlirOperatorBuilder::build<Op_MAX_POOL2D>(
+ TosaSerializationOperator *op) const {
+ mlir::Value input_val = tensor_map->at(op->GetInputTensorNames()[0]);
+ mlir::RankedTensorType output_type =
+ tensor_type_map->at(op->GetOutputTensorNames()[0]);
+ assert(op->GetAttributeType() ==
+ Attribute_PoolAttribute); // double check attribute type
+ TosaPoolAttribute *attr =
+ static_cast<TosaPoolAttribute *>(op->GetAttribute());
+ mlir::DenseI64ArrayAttr kernel =
+ BuildDenseI64ArrayAttr(op_builder, attr->kernel());
+ mlir::DenseI64ArrayAttr stride =
+ BuildDenseI64ArrayAttr(op_builder, attr->stride());
+ mlir::DenseI64ArrayAttr pad = BuildDenseI64ArrayAttr(op_builder, attr->pad());
+ int32_t input_zp = attr->input_zp();
+ int32_t output_zp = attr->output_zp();
+ assert(input_zp == 0 && output_zp == 0);
+
+ mlir::Operation *mlir_op = op_builder->create<mlir::tosa::MaxPool2dOp>(
+ loc, output_type, input_val, kernel, stride, pad);
+ block->push_back(mlir_op);
+ return std::vector<mlir::Value>({mlir_op->getResult(0)});
+}
+
+// BUILD_OP_POOL2D(AvgPool2d, AVG_POOL2D)
+template <>
+std::vector<mlir::Value> TosaMlirOperatorBuilder::build<Op_AVG_POOL2D>(
+ TosaSerializationOperator *op) const {
+ mlir::Value input_val = tensor_map->at(op->GetInputTensorNames()[0]);
+ mlir::RankedTensorType output_type =
+ tensor_type_map->at(op->GetOutputTensorNames()[0]);
+ assert(op->GetAttributeType() ==
+ Attribute_PoolAttribute); // double check attribute type
+ TosaPoolAttribute *attr =
+ static_cast<TosaPoolAttribute *>(op->GetAttribute());
+ mlir::DenseI64ArrayAttr kernel =
+ BuildDenseI64ArrayAttr(op_builder, attr->kernel());
+ mlir::DenseI64ArrayAttr stride =
+ BuildDenseI64ArrayAttr(op_builder, attr->stride());
+ mlir::DenseI64ArrayAttr pad = BuildDenseI64ArrayAttr(op_builder, attr->pad());
+ auto acc_attr =
+ mlir::TypeAttr::get(AccDType2Type(op_builder, attr->acc_type()));
+
+ int32_t input_zp = attr->input_zp();
+ int32_t output_zp = attr->output_zp();
+ mlir::Operation *mlir_op;
+ if (input_zp == 0 && output_zp == 0) {
+ mlir_op = op_builder->create<mlir::tosa::AvgPool2dOp>(
+ loc, output_type, input_val, kernel, stride, pad, acc_attr);
+ } else {
+ auto input_zp_attr = op_builder->getI32IntegerAttr(input_zp);
+ auto output_zp_attr = op_builder->getI32IntegerAttr(output_zp);
+ mlir_op = op_builder->create<mlir::tosa::AvgPool2dOp>(
+ loc, output_type, input_val, kernel, stride, pad, acc_attr,
+ input_zp_attr, output_zp_attr);
+ }
+ block->push_back(mlir_op);
+ return std::vector<mlir::Value>({mlir_op->getResult(0)});
+}
+
+std::vector<mlir::Value> TosaMlirOperatorBuilder::BuildVariableReadOp(
+ TosaSerializationOperator *op) const {
+ auto input_tensor_name = op->GetInputTensorNames()[0];
+ auto output_tensor_name = op->GetOutputTensorNames()[0];
+
+ assert(IsVariableTensor(input_tensor_name));
+
+ auto variable_name = GetVariableTensorName(input_tensor_name);
+ mlir::RankedTensorType output_type = tensor_type_map->at(output_tensor_name);
+ assert(op->GetAttributeType() ==
+ Attribute_NONE); // double check that there is no attribute
+ mlir::Operation *mlir_op = op_builder->create<mlir::tosa::VariableReadOp>(
+ loc, output_type, llvm::StringRef(variable_name));
+ block->push_back(mlir_op);
+ return std::vector<mlir::Value>({mlir_op->getResult(0)});
+}
+
+void TosaMlirOperatorBuilder::BuildVariableWriteOp(
+ TosaSerializationOperator *op) const {
+ auto input_tensor_name = op->GetInputTensorNames()[0];
+ auto output_tensor_name = op->GetOutputTensorNames()[0];
+
+ assert(IsVariableTensor(output_tensor_name));
+ auto variable_name = GetVariableTensorName(output_tensor_name);
+ mlir::Value input_val = tensor_map->at(input_tensor_name);
+
+ mlir::Operation *mlir_op = op_builder->create<mlir::tosa::VariableWriteOp>(
+ loc, llvm::StringRef(variable_name), input_val);
+ block->push_back(mlir_op);
+}
+
+template <class MLIR_OP>
+std::vector<mlir::Value> TosaMlirOperatorBuilder::BuildEwiseUnaryOp(
+ TosaSerializationOperator *op) const {
+ mlir::Value input_val = tensor_map->at(op->GetInputTensorNames()[0]);
+ mlir::RankedTensorType output_type =
+ tensor_type_map->at(op->GetOutputTensorNames()[0]);
+ assert(op->GetAttributeType() ==
+ Attribute_NONE); // double check that there is no attribute
+
+ mlir::Operation *mlir_op =
+ op_builder->create<MLIR_OP>(loc, output_type, input_val);
+ block->push_back(mlir_op);
+ return std::vector<mlir::Value>({mlir_op->getResult(0)});
+}
+
+template <class MLIR_OP>
+std::vector<mlir::Value> TosaMlirOperatorBuilder::BuildEwiseBinaryOp(
+ TosaSerializationOperator *op) const {
+ mlir::Value input0_val = tensor_map->at(op->GetInputTensorNames()[0]);
+ mlir::Value input1_val = tensor_map->at(op->GetInputTensorNames()[1]);
+ mlir::RankedTensorType output_type =
+ tensor_type_map->at(op->GetOutputTensorNames()[0]);
+ assert(op->GetAttributeType() ==
+ Attribute_NONE); // double check that there is no attribute
+
+ mlir::Operation *mlir_op =
+ op_builder->create<MLIR_OP>(loc, output_type, input0_val, input1_val);
+ block->push_back(mlir_op);
+ return std::vector<mlir::Value>({mlir_op->getResult(0)});
+}
+
+template <class MLIR_OP>
+std::vector<mlir::Value> TosaMlirOperatorBuilder::BuildEwiseBinaryShapeOp(
+ TosaSerializationOperator *op) const {
+ mlir::Value input0_val = tensor_map->at(op->GetInputTensorNames()[0]);
+ mlir::Value input1_val = tensor_map->at(op->GetInputTensorNames()[1]);
+ mlir::tosa::shapeType output_type =
+ shape_type_map->at(op->GetOutputTensorNames()[0]);
+ assert(op->GetAttributeType() ==
+ Attribute_NONE); // double check that there is no attribute
+
+ mlir::Operation *mlir_op =
+ op_builder->create<MLIR_OP>(loc, output_type, input0_val, input1_val);
+ block->push_back(mlir_op);
+ return std::vector<mlir::Value>({mlir_op->getResult(0)});
+}
+
+template <class MLIR_OP>
+std::vector<mlir::Value>
+TosaMlirOperatorBuilder::BuildReductionOp(TosaSerializationOperator *op) const {
+ mlir::Value input_val = tensor_map->at(op->GetInputTensorNames()[0]);
+ mlir::RankedTensorType output_type =
+ tensor_type_map->at(op->GetOutputTensorNames()[0]);
+ assert(op->GetAttributeType() ==
+ Attribute_AxisAttribute); // double check attribute type
+
+ TosaAxisAttribute *attr =
+ static_cast<TosaAxisAttribute *>(op->GetAttribute());
+ auto axis = op_builder->getI32IntegerAttr(attr->axis());
+
+ mlir::Operation *mlir_op =
+ op_builder->create<MLIR_OP>(loc, output_type, input_val, axis);
+ block->push_back(mlir_op);
+ return std::vector<mlir::Value>({mlir_op->getResult(0)});
+}
+
// The BUILD_OP_* macros stamp out explicit specializations of
// TosaMlirOperatorBuilder::build<Op_...>() that forward to the shared builder
// templates above. MLIR_OP_NAME is the mlir::tosa op class prefix (e.g.
// Add -> mlir::tosa::AddOp); SCHEMA_OP_NAME is the serialization schema enum
// suffix (e.g. ADD -> Op_ADD).

// Unary elementwise: one input tensor, no attribute.
#define BUILD_OP_ELEMENTWISE_UNARY(MLIR_OP_NAME, SCHEMA_OP_NAME)               \
  template <>                                                                  \
  std::vector<mlir::Value>                                                     \
  TosaMlirOperatorBuilder::build<Op_##SCHEMA_OP_NAME>(                         \
      TosaSerializationOperator * op) const {                                  \
    return BuildEwiseUnaryOp<mlir::tosa::MLIR_OP_NAME##Op>(op);                \
  }

// Binary elementwise: two input tensors, no attribute.
#define BUILD_OP_ELEMENTWISE_BINARY(MLIR_OP_NAME, SCHEMA_OP_NAME)              \
  template <>                                                                  \
  std::vector<mlir::Value>                                                     \
  TosaMlirOperatorBuilder::build<Op_##SCHEMA_OP_NAME>(                         \
      TosaSerializationOperator * op) const {                                  \
    return BuildEwiseBinaryOp<mlir::tosa::MLIR_OP_NAME##Op>(op);               \
  }

// Binary elementwise over !tosa.shape operands/results.
#define BUILD_OP_ELEMENTWISE_BINARY_SHAPE(MLIR_OP_NAME, SCHEMA_OP_NAME)        \
  template <>                                                                  \
  std::vector<mlir::Value>                                                     \
  TosaMlirOperatorBuilder::build<Op_##SCHEMA_OP_NAME>(                         \
      TosaSerializationOperator * op) const {                                  \
    return BuildEwiseBinaryShapeOp<mlir::tosa::MLIR_OP_NAME##Op>(op);          \
  }

// One input tensor plus a single axis attribute.
#define BUILD_OP_REDUCTION(MLIR_OP_NAME, SCHEMA_OP_NAME)                       \
  template <>                                                                  \
  std::vector<mlir::Value>                                                     \
  TosaMlirOperatorBuilder::build<Op_##SCHEMA_OP_NAME>(                         \
      TosaSerializationOperator * op) const {                                  \
    return BuildReductionOp<mlir::tosa::MLIR_OP_NAME##Op>(op);                 \
  }
+
// Pool2d ops are built by hand-written specializations instead of a macro.
// BUILD_OP_POOL2D(MaxPool2d, MAX_POOL2D)
// BUILD_OP_POOL2D(AvgPool2d, AVG_POOL2D)

// Arithmetic / bitwise / logical binary elementwise operators.
BUILD_OP_ELEMENTWISE_BINARY(Add, ADD)
BUILD_OP_ELEMENTWISE_BINARY(BitwiseAnd, BITWISE_AND)
BUILD_OP_ELEMENTWISE_BINARY(BitwiseXor, BITWISE_XOR)
BUILD_OP_ELEMENTWISE_BINARY(BitwiseOr, BITWISE_OR)
BUILD_OP_ELEMENTWISE_BINARY(IntDiv, INTDIV)
BUILD_OP_ELEMENTWISE_BINARY(LogicalAnd, LOGICAL_AND)
BUILD_OP_ELEMENTWISE_BINARY(LogicalLeftShift, LOGICAL_LEFT_SHIFT)
BUILD_OP_ELEMENTWISE_BINARY(LogicalRightShift, LOGICAL_RIGHT_SHIFT)
BUILD_OP_ELEMENTWISE_BINARY(LogicalOr, LOGICAL_OR)
BUILD_OP_ELEMENTWISE_BINARY(LogicalXor, LOGICAL_XOR)
BUILD_OP_ELEMENTWISE_BINARY(Maximum, MAXIMUM)
BUILD_OP_ELEMENTWISE_BINARY(Minimum, MINIMUM)
BUILD_OP_ELEMENTWISE_BINARY(Pow, POW)
BUILD_OP_ELEMENTWISE_BINARY(Sub, SUB)

// Unary elementwise operators.
BUILD_OP_ELEMENTWISE_UNARY(Abs, ABS)
BUILD_OP_ELEMENTWISE_UNARY(BitwiseNot, BITWISE_NOT)
BUILD_OP_ELEMENTWISE_UNARY(Ceil, CEIL)
BUILD_OP_ELEMENTWISE_UNARY(Clz, CLZ)
BUILD_OP_ELEMENTWISE_UNARY(Cos, COS)
BUILD_OP_ELEMENTWISE_UNARY(Exp, EXP)
BUILD_OP_ELEMENTWISE_UNARY(Floor, FLOOR)
BUILD_OP_ELEMENTWISE_UNARY(Log, LOG)
BUILD_OP_ELEMENTWISE_UNARY(LogicalNot, LOGICAL_NOT)
BUILD_OP_ELEMENTWISE_UNARY(Reciprocal, RECIPROCAL)
BUILD_OP_ELEMENTWISE_UNARY(Rsqrt, RSQRT)
BUILD_OP_ELEMENTWISE_UNARY(Sin, SIN)

// Axis-attribute reductions.
BUILD_OP_REDUCTION(ReduceAny, REDUCE_ANY)
BUILD_OP_REDUCTION(ReduceAll, REDUCE_ALL)
BUILD_OP_REDUCTION(ReduceMax, REDUCE_MAX)
BUILD_OP_REDUCTION(ReduceMin, REDUCE_MIN)
BUILD_OP_REDUCTION(ReduceProd, REDUCE_PRODUCT)
BUILD_OP_REDUCTION(ReduceSum, REDUCE_SUM)

// Comparison operators (binary, no attribute).
BUILD_OP_ELEMENTWISE_BINARY(Equal, EQUAL)
BUILD_OP_ELEMENTWISE_BINARY(Greater, GREATER)
BUILD_OP_ELEMENTWISE_BINARY(GreaterEqual, GREATER_EQUAL)

// Activation-style and type-conversion unary operators.
BUILD_OP_ELEMENTWISE_UNARY(Erf, ERF)
BUILD_OP_ELEMENTWISE_UNARY(Sigmoid, SIGMOID)
BUILD_OP_ELEMENTWISE_UNARY(Tanh, TANH)
BUILD_OP_ELEMENTWISE_UNARY(Identity, IDENTITY)
BUILD_OP_ELEMENTWISE_UNARY(Cast, CAST)

// Binary operators over !tosa.shape values.
BUILD_OP_ELEMENTWISE_BINARY_SHAPE(AddShape, ADD_SHAPE)
BUILD_OP_ELEMENTWISE_BINARY_SHAPE(SubShape, SUB_SHAPE)
BUILD_OP_ELEMENTWISE_BINARY_SHAPE(MulShape, MUL_SHAPE)
BUILD_OP_ELEMENTWISE_BINARY_SHAPE(DivShape, DIV_SHAPE)
+
+template <>
+std::vector<mlir::Value>
+TosaMlirOperatorBuilder::build<Op_CONST>(TosaSerializationOperator *op) const {
+ const auto &output_name = op->GetOutputTensorNames()[0];
+ mlir::RankedTensorType output_type = tensor_type_map->at(output_name);
+ TosaSerializationTensor *ts = ser_block->GetTensorByName(output_name);
+ auto value_attr = ConstructConstAttr(output_type, ts, get_string(op));
+ if (!value_attr) {
+ return {};
+ }
+ mlir::Operation *mlir_op =
+ op_builder->create<mlir::tosa::ConstOp>(loc, output_type, value_attr);
+ block->push_back(mlir_op);
+ return std::vector<mlir::Value>({mlir_op->getResult(0)});
+}
+
+template <class T>
+mlir::Value
+TosaMlirOperatorBuilder::BuildConstShape(mlir::OpBuilder *op_builder,
+ mlir::Location loc,
+ const std::vector<T> &values) const {
+ std::vector<int64_t> vec;
+ for (auto val : values) {
+ vec.push_back(val);
+ }
+ auto attr = op_builder->getIndexTensorAttr(vec);
+ auto type = mlir::tosa::shapeType::get(op_builder->getContext(),
+ /* rank = */ vec.size());
+ mlir::Operation *mlir_op =
+ op_builder->create<mlir::tosa::ConstShapeOp>(loc, type, attr);
+ block->push_back(mlir_op);
+ return mlir_op->getResult(0);
+}
+
+template <>
+std::vector<mlir::Value> TosaMlirOperatorBuilder::build<Op_CONST_SHAPE>(
+ TosaSerializationOperator *op) const {
+ const auto &output_name = op->GetOutputTensorNames()[0];
+ mlir::tosa::shapeType output_type = shape_type_map->at(output_name);
+ TosaSerializationTensor *ts = ser_block->GetTensorByName(output_name);
+
+ const auto &data = ts->GetData();
+
+ std::vector<int64_t> i64_data;
+ TosaSerializationHandler::ConvertU8toI64(data, output_type.getRank(),
+ i64_data);
+ mlir::Value result = BuildConstShape(op_builder, loc, i64_data);
+ return std::vector<mlir::Value>({result});
+}
+
+template <class MLIR_OP>
+std::vector<mlir::Value>
+TosaMlirOperatorBuilder::BuildConvOp(TosaSerializationOperator *op) const {
+ mlir::Value input0_val = tensor_map->at(op->GetInputTensorNames()[0]);
+ mlir::Value input1_val = tensor_map->at(op->GetInputTensorNames()[1]);
+ mlir::Value input2_val = tensor_map->at(op->GetInputTensorNames()[2]);
+ mlir::RankedTensorType output_type =
+ tensor_type_map->at(op->GetOutputTensorNames()[0]);
+
+ assert(op->GetAttributeType() ==
+ Attribute_ConvAttribute); // double check attribute type
+ TosaConvAttribute *attr =
+ static_cast<TosaConvAttribute *>(op->GetAttribute());
+ mlir::DenseI64ArrayAttr pad = BuildDenseI64ArrayAttr(op_builder, attr->pad());
+ mlir::DenseI64ArrayAttr stride =
+ BuildDenseI64ArrayAttr(op_builder, attr->stride());
+ mlir::DenseI64ArrayAttr dilation =
+ BuildDenseI64ArrayAttr(op_builder, attr->dilation());
+ auto input_zp = attr->input_zp();
+ auto weight_zp = attr->weight_zp();
+ bool local_bound = attr->local_bound();
+ auto acc_type = AccDType2Type(op_builder, attr->acc_type());
+
+ // input_zp/weight_zp is not allowed for float type
+ mlir::Operation *mlir_op;
+ if (output_type.getElementType().isa<mlir::FloatType>()) {
+ assert(input_zp == 0 && weight_zp == 0);
+ }
+
+ auto input_zp_attr = op_builder->getI32IntegerAttr(input_zp);
+ auto weight_zp_attr = op_builder->getI32IntegerAttr(weight_zp);
+ mlir_op = op_builder->create<MLIR_OP>(
+ loc, output_type, input0_val, input1_val, input2_val, pad, stride,
+ dilation, acc_type, input_zp_attr, weight_zp_attr, local_bound);
+
+ block->push_back(mlir_op);
+ return std::vector<mlir::Value>({mlir_op->getResult(0)});
+}
+
// Stamp out build<Op_...>() specializations that forward to BuildConvOp for
// the three convolution variants sharing a ConvAttribute.
#define BUILD_OP_CONV(MLIR_OP_NAME, SCHEMA_OP_NAME)                            \
  template <>                                                                  \
  std::vector<mlir::Value>                                                     \
  TosaMlirOperatorBuilder::build<Op_##SCHEMA_OP_NAME>(                         \
      TosaSerializationOperator * op) const {                                  \
    return BuildConvOp<mlir::tosa::MLIR_OP_NAME##Op>(op);                      \
  }

BUILD_OP_CONV(Conv2D, CONV2D)
BUILD_OP_CONV(Conv3D, CONV3D)
BUILD_OP_CONV(DepthwiseConv2D, DEPTHWISE_CONV2D)
+
+template <>
+std::vector<mlir::Value> TosaMlirOperatorBuilder::build<Op_TRANSPOSE_CONV2D>(
+ TosaSerializationOperator *op) const {
+ mlir::Value input0_val = tensor_map->at(op->GetInputTensorNames()[0]);
+ mlir::Value input1_val = tensor_map->at(op->GetInputTensorNames()[1]);
+ mlir::Value input2_val = tensor_map->at(op->GetInputTensorNames()[2]);
+ mlir::RankedTensorType output_type =
+ tensor_type_map->at(op->GetOutputTensorNames()[0]);
+
+ assert(op->GetAttributeType() ==
+ Attribute_TransposeConvAttribute); // double check attribute type
+ TosaTransposeConvAttribute *attr =
+ static_cast<TosaTransposeConvAttribute *>(op->GetAttribute());
+ mlir::DenseI64ArrayAttr out_pad =
+ BuildDenseI64ArrayAttr(op_builder, attr->out_pad());
+ mlir::DenseI64ArrayAttr stride =
+ BuildDenseI64ArrayAttr(op_builder, attr->stride());
+ auto input_zp = attr->input_zp();
+ auto weight_zp = attr->weight_zp();
+ bool local_bound = attr->local_bound();
+ auto acc_type = AccDType2Type(op_builder, attr->acc_type());
+
+ // input_zp/weight_zp is not allowed for float type
+ mlir::Operation *mlir_op;
+ if (output_type.getElementType().isa<mlir::FloatType>()) {
+ assert(input_zp == 0 && weight_zp == 0);
+ }
+
+ auto input_zp_attr = op_builder->getI32IntegerAttr(input_zp);
+ auto weight_zp_attr = op_builder->getI32IntegerAttr(weight_zp);
+
+ mlir_op = op_builder->create<mlir::tosa::TransposeConv2DOp>(
+ loc, output_type, input0_val, input1_val, input2_val, out_pad, stride,
+ acc_type, input_zp_attr, weight_zp_attr, local_bound);
+
+ block->push_back(mlir_op);
+ return std::vector<mlir::Value>({mlir_op->getResult(0)});
+}
+
+template <>
+std::vector<mlir::Value> TosaMlirOperatorBuilder::build<Op_FULLY_CONNECTED>(
+ TosaSerializationOperator *op) const {
+ mlir::Value input0_val = tensor_map->at(op->GetInputTensorNames()[0]);
+ mlir::Value input1_val = tensor_map->at(op->GetInputTensorNames()[1]);
+ mlir::Value input2_val = tensor_map->at(op->GetInputTensorNames()[2]);
+ mlir::RankedTensorType output_type =
+ tensor_type_map->at(op->GetOutputTensorNames()[0]);
+
+ assert(op->GetAttributeType() ==
+ Attribute_FullyConnectedAttribute); // double check attribute type
+ TosaFullyConnectedAttribute *attr =
+ static_cast<TosaFullyConnectedAttribute *>(op->GetAttribute());
+ auto input_zp = attr->input_zp();
+ auto weight_zp = attr->weight_zp();
+
+ // input_zp/weight_zp is not allowed for float type
+ mlir::Operation *mlir_op;
+ if (output_type.getElementType().isa<mlir::FloatType>()) {
+ assert(input_zp == 0 && weight_zp == 0);
+ mlir_op = op_builder->create<mlir::tosa::FullyConnectedOp>(
+ loc, output_type, input0_val, input1_val, input2_val);
+ } else {
+ auto input_zp_attr = op_builder->getI32IntegerAttr(input_zp);
+ auto weight_zp_attr = op_builder->getI32IntegerAttr(weight_zp);
+ mlir_op = op_builder->create<mlir::tosa::FullyConnectedOp>(
+ loc, output_type, input0_val, input1_val, input2_val, input_zp_attr,
+ weight_zp_attr);
+ }
+ block->push_back(mlir_op);
+ return std::vector<mlir::Value>({mlir_op->getResult(0)});
+}
+
+template <>
+std::vector<mlir::Value>
+TosaMlirOperatorBuilder::build<Op_MATMUL>(TosaSerializationOperator *op) const {
+ mlir::Value input0_val = tensor_map->at(op->GetInputTensorNames()[0]);
+ mlir::Value input1_val = tensor_map->at(op->GetInputTensorNames()[1]);
+ mlir::RankedTensorType output_type =
+ tensor_type_map->at(op->GetOutputTensorNames()[0]);
+
+ assert(op->GetAttributeType() ==
+ Attribute_MatMulAttribute); // double check attribute type
+ TosaMatMulAttribute *attr =
+ static_cast<TosaMatMulAttribute *>(op->GetAttribute());
+ auto A_zp = attr->a_zp();
+ auto B_zp = attr->b_zp();
+
+ mlir::Operation *mlir_op;
+ if (A_zp == 0 && B_zp == 0) {
+ mlir_op = op_builder->create<mlir::tosa::MatMulOp>(loc, output_type,
+ input0_val, input1_val);
+ } else {
+ auto a_zp_attr = op_builder->getI32IntegerAttr(A_zp);
+ auto b_zp_attr = op_builder->getI32IntegerAttr(B_zp);
+ mlir_op = op_builder->create<mlir::tosa::MatMulOp>(
+ loc, output_type, input0_val, input1_val, a_zp_attr, b_zp_attr);
+ }
+ block->push_back(mlir_op);
+ return std::vector<mlir::Value>({mlir_op->getResult(0)});
+}
+
+template <>
+std::vector<mlir::Value>
+TosaMlirOperatorBuilder::build<Op_SELECT>(TosaSerializationOperator *op) const {
+ mlir::Value input0_val = tensor_map->at(op->GetInputTensorNames()[0]);
+ mlir::Value input1_val = tensor_map->at(op->GetInputTensorNames()[1]);
+ mlir::Value input2_val = tensor_map->at(op->GetInputTensorNames()[2]);
+ mlir::RankedTensorType output_type =
+ tensor_type_map->at(op->GetOutputTensorNames()[0]);
+
+ assert(op->GetAttributeType() ==
+ Attribute_NONE); // double check that there is no attribute
+
+ mlir::Operation *mlir_op = op_builder->create<mlir::tosa::SelectOp>(
+ loc, output_type, input0_val, input1_val, input2_val);
+ block->push_back(mlir_op);
+ return std::vector<mlir::Value>({mlir_op->getResult(0)});
+}
+
template <>
std::vector<mlir::Value>
TosaMlirOperatorBuilder::build<Op_CLAMP>(TosaSerializationOperator *op) const {
  // CLAMP: the serialized min/max bounds arrive as raw bytes. Decode them
  // through a single-element dense attribute of the input's element type,
  // then re-wrap the decoded value as a scalar attribute on the op.
  mlir::Value input_val = tensor_map->at(op->GetInputTensorNames()[0]);
  mlir::RankedTensorType output_type =
      tensor_type_map->at(op->GetOutputTensorNames()[0]);

  assert(op->GetAttributeType() ==
         Attribute_ClampAttribute); // double check attribute type
  TosaClampAttribute *attr =
      static_cast<TosaClampAttribute *>(op->GetAttribute());

  // For quantized element types the bounds are expressed in the storage type
  // (e.g. i8), so unwrap to the storage type before decoding.
  mlir::Type element_type =
      llvm::cast<mlir::ShapedType>(input_val.getType()).getElementType();
  if (auto quantType =
          llvm::dyn_cast<mlir::quant::UniformQuantizedType>(element_type)) {
    element_type = quantType.getStorageType();
  }

  // Decode each byte payload into a {1}-shaped dense attribute.
  auto element_const_type = mlir::RankedTensorType::get({1}, element_type);
  auto min_values_attr = GetConstAttr(attr->min_val(), element_const_type, 1);
  auto max_values_attr = GetConstAttr(attr->max_val(), element_const_type, 1);

  // Extract the single decoded value as APFloat or APInt depending on the
  // element type, and build the matching scalar attribute.
  mlir::Attribute min_val_attr, max_val_attr;
  if (element_type.isa<mlir::FloatType>()) {
    min_val_attr = op_builder->getFloatAttr(
        element_type, min_values_attr.getValues<mlir::APFloat>()[0]);
    max_val_attr = op_builder->getFloatAttr(
        element_type, max_values_attr.getValues<mlir::APFloat>()[0]);
  } else {
    min_val_attr = op_builder->getIntegerAttr(
        element_type, min_values_attr.getValues<mlir::APInt>()[0]);
    max_val_attr = op_builder->getIntegerAttr(
        element_type, max_values_attr.getValues<mlir::APInt>()[0]);
  }

  auto mlir_op = op_builder->create<mlir::tosa::ClampOp>(
      loc, output_type, input_val, min_val_attr, max_val_attr);
  block->push_back(mlir_op);
  return std::vector<mlir::Value>({mlir_op->getResult(0)});
}
+
// ArgMax reuses the reduction builder: single input plus a single axis
// attribute (built as an I32 integer attribute in BuildReductionOp).
BUILD_OP_REDUCTION(ArgMax, ARGMAX)
+
+template <>
+std::vector<mlir::Value>
+TosaMlirOperatorBuilder::build<Op_CONCAT>(TosaSerializationOperator *op) const {
+ mlir::RankedTensorType output_type =
+ tensor_type_map->at(op->GetOutputTensorNames()[0]);
+
+ llvm::SmallVector<mlir::Value> input_values;
+ for (auto &input_name : op->GetInputTensorNames()) {
+ mlir::Value input_val = tensor_map->at(input_name);
+ input_values.push_back(input_val);
+ }
+
+ assert(op->GetAttributeType() ==
+ Attribute_AxisAttribute); // double check attribute type
+ TosaAxisAttribute *attr =
+ static_cast<TosaAxisAttribute *>(op->GetAttribute());
+ auto axis = op_builder->getI32IntegerAttr(attr->axis());
+
+ mlir::Operation *mlir_op = op_builder->create<mlir::tosa::ConcatOp>(
+ loc, output_type, input_values, axis);
+ block->push_back(mlir_op);
+ return std::vector<mlir::Value>({mlir_op->getResult(0)});
+}
+
+template <>
+std::vector<mlir::Value> TosaMlirOperatorBuilder::build<Op_CONCAT_SHAPE>(
+ TosaSerializationOperator *op) const {
+ mlir::tosa::shapeType output_type =
+ shape_type_map->at(op->GetOutputTensorNames()[0]);
+
+ llvm::SmallVector<mlir::Value> input_values;
+ for (auto &input_name : op->GetInputTensorNames()) {
+ mlir::Value input_val = tensor_map->at(input_name);
+ input_values.push_back(input_val);
+ }
+
+ mlir::Operation *mlir_op = op_builder->create<mlir::tosa::ConcatShapeOp>(
+ loc, output_type, input_values);
+ block->push_back(mlir_op);
+ return std::vector<mlir::Value>({mlir_op->getResult(0)});
+}
+
+template <>
+std::vector<mlir::Value>
+TosaMlirOperatorBuilder::build<Op_NEGATE>(TosaSerializationOperator *op) const {
+ mlir::Value input_val = tensor_map->at(op->GetInputTensorNames()[0]);
+ mlir::RankedTensorType output_type =
+ tensor_type_map->at(op->GetOutputTensorNames()[0]);
+
+ assert(op->GetAttributeType() ==
+ Attribute_NegateAttribute); // double check attribute type
+ TosaNegateAttribute *attr =
+ static_cast<TosaNegateAttribute *>(op->GetAttribute());
+
+ auto input_zp = attr->input1_zp();
+ auto output_zp = attr->output_zp();
+
+ mlir::Operation *mlir_op;
+ if (input_zp == 0 && output_zp == 0) {
+ mlir_op =
+ op_builder->create<mlir::tosa::NegateOp>(loc, output_type, input_val);
+ } else {
+ auto input_zp_attr = op_builder->getI32IntegerAttr(input_zp);
+ auto output_zp_attr = op_builder->getI32IntegerAttr(output_zp);
+ mlir_op = op_builder->create<mlir::tosa::NegateOp>(
+ loc, output_type, input_val, input_zp_attr, output_zp_attr);
+ }
+
+ block->push_back(mlir_op);
+ return std::vector<mlir::Value>({mlir_op->getResult(0)});
+}
+
+template <>
+std::vector<mlir::Value> TosaMlirOperatorBuilder::build<Op_RESHAPE>(
+ TosaSerializationOperator *op) const {
+ mlir::Value input_val = tensor_map->at(op->GetInputTensorNames()[0]);
+ mlir::RankedTensorType output_type =
+ tensor_type_map->at(op->GetOutputTensorNames()[0]);
+ mlir::Value shape_val = tensor_map->at(op->GetInputTensorNames()[1]);
+
+ mlir::Operation *mlir_op = op_builder->create<mlir::tosa::ReshapeOp>(
+ loc, output_type, input_val, shape_val);
+ block->push_back(mlir_op);
+ return std::vector<mlir::Value>({mlir_op->getResult(0)});
+}
+
+template <>
+std::vector<mlir::Value>
+TosaMlirOperatorBuilder::build<Op_PAD>(TosaSerializationOperator *op) const {
+ mlir::Value input_val = tensor_map->at(op->GetInputTensorNames()[0]);
+ mlir::Value padding_val = tensor_map->at(op->GetInputTensorNames()[1]);
+ mlir::RankedTensorType input_type =
+ tensor_type_map->at(op->GetInputTensorNames()[0]);
+ mlir::RankedTensorType output_type =
+ tensor_type_map->at(op->GetOutputTensorNames()[0]);
+ const auto element_type =
+ input_val.getType().cast<mlir::ShapedType>().getElementType();
+
+ assert(op->GetAttributeType() ==
+ Attribute_PadAttribute); // double check attribute type
+ TosaPadAttribute *attr = static_cast<TosaPadAttribute *>(op->GetAttribute());
+ const auto &pad_const_u8_data = attr->pad_const();
+
+ // check for any value in pad_const_u8_data
+ bool has_pad_const = false;
+ for (auto v : pad_const_u8_data) {
+ if (v != 0) {
+ has_pad_const = true;
+ break;
+ }
+ }
+ if (!has_pad_const) {
+ // handle the cases where no explicit pad_const input.
+ auto mlir_op = op_builder->create<mlir::tosa::PadOp>(
+ loc, output_type, input_val, padding_val);
+ block->push_back(mlir_op);
+ return std::vector<mlir::Value>({mlir_op->getResult(0)});
+ }
+
+ // has pad const - create a const op for pad_const input
+ auto pad_const_type = mlir::RankedTensorType::get({}, element_type);
+ auto pad_const_attr = GetConstAttr(pad_const_u8_data, pad_const_type, 1);
+
+ auto pad_const_op = op_builder->create<mlir::tosa::ConstOp>(
+ loc, pad_const_type, pad_const_attr);
+
+ block->push_back(pad_const_op);
+ mlir::Value pad_const_value = pad_const_op->getResult(0);
+
+ auto mlir_op = op_builder->create<mlir::tosa::PadOp>(
+ loc, output_type, input_val, padding_val, pad_const_value);
+
+ block->push_back(mlir_op);
+ return std::vector<mlir::Value>({mlir_op->getResult(0)});
+}
+
+template <>
+std::vector<mlir::Value>
+TosaMlirOperatorBuilder::build<Op_DIM>(TosaSerializationOperator *op) const {
+ mlir::Value input_val = tensor_map->at(op->GetInputTensorNames()[0]);
+ mlir::RankedTensorType output_type =
+ tensor_type_map->at(op->GetOutputTensorNames()[0]);
+
+ assert(op->GetAttributeType() ==
+ Attribute_AxisAttribute); // double check attribute type
+ TosaAxisAttribute *attr =
+ static_cast<TosaAxisAttribute *>(op->GetAttribute());
+ auto axis = op_builder->getI32IntegerAttr(attr->axis());
+
+ mlir::Operation *mlir_op =
+ op_builder->create<mlir::tosa::DimOp>(loc, output_type, input_val, axis);
+ block->push_back(mlir_op);
+ return std::vector<mlir::Value>({mlir_op->getResult(0)});
+}
+
+template <>
+std::vector<mlir::Value> TosaMlirOperatorBuilder::build<Op_TRANSPOSE>(
+ TosaSerializationOperator *op) const {
+ mlir::Value input_val = tensor_map->at(op->GetInputTensorNames()[0]);
+ mlir::RankedTensorType output_type =
+ tensor_type_map->at(op->GetOutputTensorNames()[0]);
+
+ assert(op->GetAttributeType() ==
+ Attribute_TransposeAttribute); // double check attribute type
+ TosaTransposeAttribute *attr =
+ static_cast<TosaTransposeAttribute *>(op->GetAttribute());
+ // make a constant op from attr->perms values, of type: { shape = { perms.size
+ // }, element_type = I32 }
+ const auto perms_values = attr->perms();
+ auto const_type = mlir::RankedTensorType::get(
+ {static_cast<int64_t>(perms_values.size())}, op_builder->getI32Type());
+ mlir::DenseElementsAttr const_attr =
+ BuildDenseI32ElementsAttr(op_builder, const_type, perms_values);
+ mlir::Operation *mlir_const_op =
+ op_builder->create<mlir::tosa::ConstOp>(loc, const_type, const_attr);
+ auto perms_val = mlir_const_op->getResult(0);
+
+ mlir::Operation *mlir_op = op_builder->create<mlir::tosa::TransposeOp>(
+ loc, output_type, input_val, perms_val);
+
+ block->push_back(mlir_const_op);
+ block->push_back(mlir_op);
+ return std::vector<mlir::Value>({mlir_op->getResult(0)});
+}
+
+template <>
+std::vector<mlir::Value>
+TosaMlirOperatorBuilder::build<Op_SLICE>(TosaSerializationOperator *op) const {
+ mlir::Value input_val = tensor_map->at(op->GetInputTensorNames()[0]);
+ mlir::RankedTensorType output_type =
+ tensor_type_map->at(op->GetOutputTensorNames()[0]);
+
+ mlir::Value start = tensor_map->at(op->GetInputTensorNames()[1]);
+ mlir::Value size = tensor_map->at(op->GetInputTensorNames()[2]);
+
+ mlir::Operation *mlir_op = op_builder->create<mlir::tosa::SliceOp>(
+ loc, output_type, input_val, start, size);
+ block->push_back(mlir_op);
+ return std::vector<mlir::Value>({mlir_op->getResult(0)});
+}
+
+template <>
+std::vector<mlir::Value>
+TosaMlirOperatorBuilder::build<Op_TILE>(TosaSerializationOperator *op) const {
+ mlir::Value input_val = tensor_map->at(op->GetInputTensorNames()[0]);
+ mlir::Value multiples = tensor_map->at(op->GetInputTensorNames()[1]);
+ mlir::RankedTensorType output_type =
+ tensor_type_map->at(op->GetOutputTensorNames()[0]);
+
+ assert(op->GetAttributeType() ==
+ Attribute_NONE); // double check attribute type
+
+ mlir::Operation *mlir_op = op_builder->create<mlir::tosa::TileOp>(
+ loc, output_type, input_val, multiples);
+ block->push_back(mlir_op);
+ return std::vector<mlir::Value>({mlir_op->getResult(0)});
+}
+
// Gather takes two tensor operands and no attribute, so it reuses the
// elementwise-binary builder.
BUILD_OP_ELEMENTWISE_BINARY(Gather, GATHER)
+
+template <>
+std::vector<mlir::Value> TosaMlirOperatorBuilder::build<Op_SCATTER>(
+ TosaSerializationOperator *op) const {
+ mlir::Value input0_val = tensor_map->at(op->GetInputTensorNames()[0]);
+ mlir::Value input1_val = tensor_map->at(op->GetInputTensorNames()[1]);
+ mlir::Value input2_val = tensor_map->at(op->GetInputTensorNames()[2]);
+ mlir::RankedTensorType output_type =
+ tensor_type_map->at(op->GetOutputTensorNames()[0]);
+ assert(op->GetAttributeType() ==
+ Attribute_NONE); // double check that there is no attribute
+
+ mlir::Operation *mlir_op = op_builder->create<mlir::tosa::ScatterOp>(
+ loc, output_type, input0_val, input1_val, input2_val);
+ block->push_back(mlir_op);
+ return std::vector<mlir::Value>({mlir_op->getResult(0)});
+}
+
+template <>
+std::vector<mlir::Value>
+TosaMlirOperatorBuilder::build<Op_RESIZE>(TosaSerializationOperator *op) const {
+ mlir::Value input_val = tensor_map->at(op->GetInputTensorNames()[0]);
+ mlir::Value scale_val = tensor_map->at(op->GetInputTensorNames()[1]);
+ mlir::Value offset_val = tensor_map->at(op->GetInputTensorNames()[2]);
+ mlir::Value border_val = tensor_map->at(op->GetInputTensorNames()[3]);
+ mlir::RankedTensorType output_type =
+ tensor_type_map->at(op->GetOutputTensorNames()[0]);
+
+ assert(op->GetAttributeType() ==
+ Attribute_ResizeAttribute); // double check attribute type
+ TosaResizeAttribute *attr =
+ static_cast<TosaResizeAttribute *>(op->GetAttribute());
+
+ auto mode = op_builder->getStringAttr(ResizeEnum2Str(attr->mode()));
+
+ mlir::Operation *mlir_op = op_builder->create<mlir::tosa::ResizeOp>(
+ loc, output_type, input_val, scale_val, offset_val, border_val, mode);
+ block->push_back(mlir_op);
+ return std::vector<mlir::Value>({mlir_op->getResult(0)});
+}
+
// Reverse reuses the reduction builder: single input plus a single axis
// attribute (built as an I32 integer attribute in BuildReductionOp).
BUILD_OP_REDUCTION(Reverse, REVERSE)
+
+template <>
+std::vector<mlir::Value>
+TosaMlirOperatorBuilder::build<Op_MUL>(TosaSerializationOperator *op) const {
+ mlir::Value input0_val = tensor_map->at(op->GetInputTensorNames()[0]);
+ mlir::Value input1_val = tensor_map->at(op->GetInputTensorNames()[1]);
+ mlir::Value shift_val = tensor_map->at(op->GetInputTensorNames()[2]);
+
+ mlir::RankedTensorType output_type =
+ tensor_type_map->at(op->GetOutputTensorNames()[0]);
+
+ assert(op->GetAttributeType() ==
+ Attribute_MulAttribute); // double check attribute type
+
+ mlir::Operation *mlir_op = op_builder->create<mlir::tosa::MulOp>(
+ loc, output_type, input0_val, input1_val, shift_val);
+ block->push_back(mlir_op);
+ return std::vector<mlir::Value>({mlir_op->getResult(0)});
+}
+
+template <>
+std::vector<mlir::Value>
+TosaMlirOperatorBuilder::build<Op_ARITHMETIC_RIGHT_SHIFT>(
+ TosaSerializationOperator *op) const {
+ mlir::Value input0_val = tensor_map->at(op->GetInputTensorNames()[0]);
+ mlir::Value input1_val = tensor_map->at(op->GetInputTensorNames()[1]);
+ mlir::RankedTensorType output_type =
+ tensor_type_map->at(op->GetOutputTensorNames()[0]);
+
+ assert(
+ op->GetAttributeType() ==
+ Attribute_ArithmeticRightShiftAttribute); // double check attribute type
+ TosaArithmeticRightShiftAttribute *attr =
+ static_cast<TosaArithmeticRightShiftAttribute *>(op->GetAttribute());
+
+ auto round = op_builder->getBoolAttr(attr->round());
+
+ mlir::Operation *mlir_op =
+ op_builder->create<mlir::tosa::ArithmeticRightShiftOp>(
+ loc, output_type, input0_val, input1_val, round);
+ block->push_back(mlir_op);
+ return std::vector<mlir::Value>({mlir_op->getResult(0)});
+}
+
+template <>
+std::vector<mlir::Value>
+TosaMlirOperatorBuilder::build<Op_TABLE>(TosaSerializationOperator *op) const {
+ mlir::Value input_val = tensor_map->at(op->GetInputTensorNames()[0]);
+ mlir::RankedTensorType output_type =
+ tensor_type_map->at(op->GetOutputTensorNames()[0]);
+
+ assert(op->GetAttributeType() ==
+ Attribute_TableAttribute); // double check attribute type
+ TosaTableAttribute *attr =
+ static_cast<TosaTableAttribute *>(op->GetAttribute());
+
+ // create a const op for table value attribute
+ const auto table_values = attr->table();
+ mlir::RankedTensorType const_type;
+ mlir::DenseElementsAttr const_attr;
+ const auto input_element_type =
+ input_val.getType().cast<mlir::ShapedType>().getElementType();
+ if (input_element_type.isInteger(8)) {
+ // table is signed 8 mode
+ const_type = mlir::RankedTensorType::get(
+ {static_cast<int64_t>(table_values.size())}, op_builder->getI8Type());
+ const_attr = BuildDenseI8ElementsAttr(op_builder, table_values);
+ } else {
+ // table is signed 16 mode
+ const_type = mlir::RankedTensorType::get(
+ {static_cast<int64_t>(table_values.size())}, op_builder->getI16Type());
+ const_attr = BuildDenseI16ElementsAttr(op_builder, table_values);
+ }
+ mlir::Operation *mlir_const_op =
+ op_builder->create<mlir::tosa::ConstOp>(loc, const_type, const_attr);
+ auto table_value = mlir_const_op->getResult(0);
+
+ mlir::Operation *mlir_op = op_builder->create<mlir::tosa::TableOp>(
+ loc, output_type, input_val, table_value);
+ block->push_back(mlir_const_op);
+ block->push_back(mlir_op);
+ return std::vector<mlir::Value>({mlir_op->getResult(0)});
+}
+
+template <>
+std::vector<mlir::Value> TosaMlirOperatorBuilder::build<Op_RESCALE>(
+ TosaSerializationOperator *op) const {
+ mlir::Value input_val = tensor_map->at(op->GetInputTensorNames()[0]);
+ mlir::Value multiplier_val = tensor_map->at(op->GetInputTensorNames()[1]);
+ mlir::Value shift_val = tensor_map->at(op->GetInputTensorNames()[2]);
+ mlir::RankedTensorType output_type =
+ tensor_type_map->at(op->GetOutputTensorNames()[0]);
+
+ assert(op->GetAttributeType() ==
+ Attribute_RescaleAttribute); // double check attribute type
+ TosaRescaleAttribute *attr =
+ static_cast<TosaRescaleAttribute *>(op->GetAttribute());
+
+ auto input_zp = op_builder->getI32IntegerAttr(attr->input_zp());
+ auto output_zp = op_builder->getI32IntegerAttr(attr->output_zp());
+
+ auto scale32 = op_builder->getBoolAttr(attr->scale32());
+ auto double_round = op_builder->getBoolAttr(attr->double_round());
+ auto per_channel = op_builder->getBoolAttr(attr->per_channel());
+
+ auto input_unsigned = op_builder->getBoolAttr(attr->input_unsigned());
+ auto output_unsigned = op_builder->getBoolAttr(attr->output_unsigned());
+
+ mlir::Operation *mlir_op = op_builder->create<mlir::tosa::RescaleOp>(
+ loc, output_type, input_val, multiplier_val, shift_val, input_zp,
+ output_zp, scale32, double_round, per_channel, input_unsigned,
+ output_unsigned);
+ block->push_back(mlir_op);
+ return std::vector<mlir::Value>({mlir_op->getResult(0)});
+}
+
+template <>
+std::vector<mlir::Value>
+TosaMlirOperatorBuilder::build<Op_CUSTOM>(TosaSerializationOperator *op) const {
+  mlir::Value input_val = tensor_map->at(op->GetInputTensorNames()[0]);
+  mlir::RankedTensorType output_type =
+      tensor_type_map->at(op->GetOutputTensorNames()[0]);
+
+  assert(op->GetAttributeType() ==
+         Attribute_CustomAttribute); // double check attribute type
+  TosaCustomAttribute *attr =
+      static_cast<TosaCustomAttribute *>(op->GetAttribute());
+
+  auto operator_name = op_builder->getStringAttr(attr->operator_name());
+  auto domain_name = op_builder->getStringAttr(attr->domain_name());
+  // Copy the raw implementation attrs byte-for-byte into the StringAttr.
+  // The previous code resized to size()+1 and copied element-wise, which
+  // left a spurious trailing '\0' embedded in the attribute value.
+  const auto &impl_attrs = attr->implementation_attrs();
+  std::string impl_str(impl_attrs.begin(), impl_attrs.end());
+  auto impl = op_builder->getStringAttr(impl_str);
+
+  mlir::Operation *mlir_op = op_builder->create<mlir::tosa::CustomOp>(
+      loc, output_type, operator_name, domain_name, impl, input_val);
+  block->push_back(mlir_op);
+  return std::vector<mlir::Value>({mlir_op->getResult(0)});
+}
+
+template <>
+std::vector<mlir::Value>
+TosaMlirOperatorBuilder::build<Op_RFFT2D>(TosaSerializationOperator *op) const {
+  mlir::Value input = tensor_map->at(op->GetInputTensorNames()[0]);
+  // RFFT2D produces two results (real and imaginary parts).
+  const auto &output_names = op->GetOutputTensorNames();
+  mlir::RankedTensorType result0_type = tensor_type_map->at(output_names[0]);
+  mlir::RankedTensorType result1_type = tensor_type_map->at(output_names[1]);
+
+  // double check attribute type
+  assert(op->GetAttributeType() == Attribute_RFFTAttribute);
+  auto *rfft_attr = static_cast<TosaRFFTAttribute *>(op->GetAttribute());
+
+  mlir::Operation *rfft_op = op_builder->create<mlir::tosa::RFFT2dOp>(
+      loc, result0_type, result1_type, input, rfft_attr->local_bound());
+  block->push_back(rfft_op);
+  return {rfft_op->getResult(0), rfft_op->getResult(1)};
+}
+
+template <>
+std::vector<mlir::Value>
+TosaMlirOperatorBuilder::build<Op_FFT2D>(TosaSerializationOperator *op) const {
+  // FFT2D takes two operands (real, imaginary) and yields two results.
+  const auto &input_names = op->GetInputTensorNames();
+  mlir::Value real_in = tensor_map->at(input_names[0]);
+  mlir::Value imag_in = tensor_map->at(input_names[1]);
+  const auto &output_names = op->GetOutputTensorNames();
+  mlir::RankedTensorType result0_type = tensor_type_map->at(output_names[0]);
+  mlir::RankedTensorType result1_type = tensor_type_map->at(output_names[1]);
+
+  // double check attribute type
+  assert(op->GetAttributeType() == Attribute_FFTAttribute);
+  auto *fft_attr = static_cast<TosaFFTAttribute *>(op->GetAttribute());
+
+  mlir::Operation *fft_op = op_builder->create<mlir::tosa::FFT2dOp>(
+      loc, result0_type, result1_type, real_in, imag_in,
+      op_builder->getBoolAttr(fft_attr->inverse()),
+      op_builder->getBoolAttr(fft_attr->local_bound()));
+  block->push_back(fft_op);
+  return {fft_op->getResult(0), fft_op->getResult(1)};
+}
+
+// Scope object that deserializes one TosaSerializationRegion into an
+// mlir::Region. It owns the name -> mlir::Value map used while building the
+// region's blocks; a non-isolated child region inherits a copy of its
+// parent's map.
+class TosaMlirRegionBuilder {
+public:
+  TosaMlirRegionBuilder(TosaSerializationRegion *_ser_region,
+                        TosaSerializationHandler *_tsh, mlir::Region *_region,
+                        mlir::OpBuilder *_op_builder, mlir::Location _loc,
+                        TosaMlirRegionBuilder *_parent_value_scope = nullptr)
+      // Initializer list now follows the member declaration order below;
+      // the original listed ser_region/tsh before region, triggering
+      // -Wreorder-ctor (members are always initialized in declaration order).
+      : region(_region), ser_region(_ser_region), tsh(_tsh),
+        op_builder(_op_builder), loc(_loc) {
+    if (_parent_value_scope) {
+      // inherit parent_value_scope's tensor_map
+      for (auto &kv : _parent_value_scope->GetTensorMap()) {
+        tensor_map.insert(kv);
+      }
+    }
+  }
+
+  // Build every block of ser_region into `region`; the region's yielded
+  // values are appended to return_values.
+  mlir::LogicalResult
+  BuildAllBlocksInRegion(std::vector<mlir::Value> &return_values);
+
+  mlir::OpBuilder *GetOpBuilder() { return op_builder; }
+  mlir::Location GetLocation() { return loc; }
+  std::unordered_map<std::string, mlir::Value> &GetTensorMap() {
+    return tensor_map;
+  }
+  TosaSerializationHandler *GetTsh() const { return tsh; }
+
+private:
+  mlir::Region *region;
+  TosaSerializationRegion *ser_region;
+  TosaSerializationHandler *tsh;
+  mlir::OpBuilder *op_builder;
+  mlir::Location loc;
+  // Maps serialized tensor names to the mlir::Value that defines them.
+  std::unordered_map<std::string, mlir::Value> tensor_map;
+};
+
+// Builds the MLIR ops for a single serialized basic block. Tensor values are
+// shared through the owning region builder's map; type maps are per block.
+class TosaMlirBlockBuilder {
+public:
+  TosaMlirBlockBuilder(TosaSerializationBasicBlock *_ser_block,
+                       TosaMlirRegionBuilder *_region_builder,
+                       mlir::Block *_block)
+      : ser_block(_ser_block), region_builder(_region_builder), block(_block) {}
+
+  // Deserialize every operator of ser_block into `block`; the block's output
+  // values are appended to return_values.
+  mlir::LogicalResult
+  BuildAllOpsInBlock(std::vector<mlir::Value> &return_values);
+
+  // Accessors that delegate to the enclosing region builder.
+  mlir::OpBuilder *GetOpBuilder() { return region_builder->GetOpBuilder(); }
+  mlir::Location GetLocation() { return region_builder->GetLocation(); }
+  std::unordered_map<std::string, mlir::Value> &GetTensorMap() {
+    return region_builder->GetTensorMap();
+  }
+
+  TosaSerializationHandler *GetTsh() const { return region_builder->GetTsh(); }
+  TosaMlirRegionBuilder *GetRegionBuilder() const { return region_builder; }
+
+private:
+  TosaSerializationBasicBlock *ser_block;
+  TosaMlirRegionBuilder *region_builder;
+  mlir::Block *block;
+  // Name -> type maps populated from the serialized tensor list.
+  std::unordered_map<std::string, mlir::RankedTensorType> tensor_type_map;
+  std::unordered_map<std::string, mlir::tosa::shapeType> shape_type_map;
+  // Names flagged unranked in the flatbuffer (their serialized shape is {}).
+  std::unordered_set<std::string> unranked_tensors;
+};
+
+// Out-of-line because TosaMlirBlockBuilder is only declared after the
+// operator builder: delegate to the owning block builder.
+TosaSerializationHandler *TosaMlirOperatorBuilder::GetTsh() const {
+  return block_builder->GetTsh();
+}
+
+TosaMlirRegionBuilder *TosaMlirOperatorBuilder::GetRegionBuilder() const {
+  return block_builder->GetRegionBuilder();
+}
+
+// build control flow ops:
+
+namespace {
+
+// Deserialize @a ser_region into @a mlir_region via a fresh
+// TosaMlirRegionBuilder; yielded values are returned through
+// @a return_values. When @a isolated_from_above is set the new region does
+// not inherit the parent scope's tensor map.
+mlir::LogicalResult
+BuildRegion(TosaSerializationRegion *ser_region, TosaSerializationHandler *tsh,
+            mlir::Region *mlir_region, mlir::OpBuilder *op_builder,
+            mlir::Location loc, std::vector<mlir::Value> &return_values,
+            bool isolated_from_above = false,
+            TosaMlirRegionBuilder *parent_region_builder = nullptr) {
+  TosaMlirRegionBuilder *value_scope = nullptr;
+  if (!isolated_from_above) {
+    value_scope = parent_region_builder;
+  }
+  TosaMlirRegionBuilder builder(ser_region, tsh, mlir_region, op_builder, loc,
+                                value_scope);
+  return builder.BuildAllBlocksInRegion(return_values);
+}
+
+} // namespace
+
+template <>
+std::vector<mlir::Value> TosaMlirOperatorBuilder::build<Op_COND_IF>(
+    TosaSerializationOperator *op) const {
+  // Operand 0 is the condition; the remaining operands feed the regions.
+  mlir::Value cond_val = tensor_map->at(op->GetInputTensorNames().at(0));
+  std::vector<mlir::Value> input_values;
+  for (auto idx = 1u; idx < op->GetInputTensorNames().size(); idx++) {
+    input_values.push_back(tensor_map->at(op->GetInputTensorNames().at(idx)));
+  }
+  // Result types come from the operator's OUTPUT tensors. The original code
+  // iterated GetInputTensorNames() here, which always included the condition
+  // tensor's type and broke the result-count checks below whenever the
+  // cond_if's outputs differ from its inputs.
+  std::vector<mlir::Type> output_types;
+  for (auto &name : op->GetOutputTensorNames()) {
+    output_types.push_back(tensor_type_map->at(name));
+  }
+
+  assert(op->GetAttributeType() ==
+         Attribute_CondIfAttribute); // double check attribute type
+  TosaCondIfAttribute *attr =
+      static_cast<TosaCondIfAttribute *>(op->GetAttribute());
+  auto ser_then_region = GetTsh()->GetRegionByName(attr->then_graph());
+  auto ser_else_region = GetTsh()->GetRegionByName(attr->else_graph());
+
+  if (!ser_then_region || !ser_else_region) {
+    llvm::errs() << "ERROR: " << get_string(op)
+                 << " region serialization hasn't been implemented\n";
+    return {};
+  }
+
+  mlir::Operation *mlir_op = op_builder->create<mlir::tosa::IfOp>(
+      loc, output_types, cond_val, input_values);
+
+  const bool isolated_from_above =
+      mlir_op->hasTrait<mlir::OpTrait::IsIsolatedFromAbove>();
+  mlir::Region &then_region = mlir_op->getRegion(0);
+  mlir::Region &else_region = mlir_op->getRegion(1);
+
+  auto curr_region_builder = GetRegionBuilder();
+
+  std::vector<mlir::Value> then_returns, else_returns;
+
+  // Build both regions and verify each yields exactly one value per result.
+  if (BuildRegion(ser_then_region, GetTsh(), &then_region, op_builder, loc,
+                  then_returns, isolated_from_above, curr_region_builder)
+          .failed()) {
+    return {};
+  }
+  if (then_returns.size() != mlir_op->getNumResults()) {
+    llvm::errs()
+        << "ERROR: " << get_string(op)
+        << " then_region yield.size() doesn't match cond_if's output size\n";
+    return {};
+  }
+
+  if (BuildRegion(ser_else_region, GetTsh(), &else_region, op_builder, loc,
+                  else_returns, isolated_from_above, curr_region_builder)
+          .failed()) {
+    return {};
+  }
+  if (else_returns.size() != mlir_op->getNumResults()) {
+    llvm::errs()
+        << "ERROR: " << get_string(op)
+        << " else_region yield.size() doesn't match cond_if's output size\n";
+    return {};
+  }
+
+  block->push_back(mlir_op);
+  return std::vector<mlir::Value>(mlir_op->getResults().begin(),
+                                  mlir_op->getResults().end());
+}
+
+template <>
+std::vector<mlir::Value> TosaMlirOperatorBuilder::build<Op_WHILE_LOOP>(
+    TosaSerializationOperator *op) const {
+  std::vector<mlir::Value> input_values;
+  for (auto idx = 0u; idx < op->GetInputTensorNames().size(); idx++) {
+    input_values.push_back(tensor_map->at(op->GetInputTensorNames().at(idx)));
+  }
+  // while_loop results are loop-carried: one per input, with matching types.
+  std::vector<mlir::Type> output_types;
+  for (auto &name : op->GetInputTensorNames()) {
+    output_types.push_back(tensor_type_map->at(name));
+  }
+  assert(op->GetAttributeType() ==
+         Attribute_WhileLoopAttribute); // double check attribute type
+  TosaWhileLoopAttribute *attr =
+      static_cast<TosaWhileLoopAttribute *>(op->GetAttribute());
+  auto ser_cond_region = GetTsh()->GetRegionByName(attr->cond_graph());
+  auto ser_body_region = GetTsh()->GetRegionByName(attr->body_graph());
+
+  // Guard against missing regions, mirroring the cond_if builder; without
+  // this, BuildRegion below dereferences a null pointer.
+  if (!ser_cond_region || !ser_body_region) {
+    llvm::errs() << "ERROR: " << get_string(op)
+                 << " region serialization hasn't been implemented\n";
+    return {};
+  }
+
+  mlir::Operation *mlir_op =
+      op_builder->create<mlir::tosa::WhileOp>(loc, output_types, input_values);
+
+  const bool isolated_from_above =
+      mlir_op->hasTrait<mlir::OpTrait::IsIsolatedFromAbove>();
+
+  mlir::Region &cond_region = mlir_op->getRegion(0);
+  mlir::Region &body_region = mlir_op->getRegion(1);
+
+  auto curr_region_builder = GetRegionBuilder();
+
+  std::vector<mlir::Value> cond_returns, body_returns;
+
+  // The cond region must yield exactly one (boolean) value.
+  if (BuildRegion(ser_cond_region, GetTsh(), &cond_region, op_builder, loc,
+                  cond_returns, isolated_from_above, curr_region_builder)
+          .failed()) {
+    return {};
+  }
+  if (cond_returns.size() != 1) {
+    llvm::errs() << "ERROR: " << get_string(op)
+                 << " cond_region yield.size() is not 1\n";
+    return {};
+  }
+
+  // The body region must yield one value per loop-carried result.
+  if (BuildRegion(ser_body_region, GetTsh(), &body_region, op_builder, loc,
+                  body_returns, isolated_from_above, curr_region_builder)
+          .failed()) {
+    return {};
+  }
+  if (body_returns.size() != mlir_op->getNumResults()) {
+    llvm::errs()
+        << "ERROR: " << get_string(op)
+        << " body_region yield.size() doesn't match while_loop's output size\n";
+    return {};
+  }
+
+  block->push_back(mlir_op);
+  return std::vector<mlir::Value>(mlir_op->getResults().begin(),
+                                  mlir_op->getResults().end());
+}
+
+// Deserialize every operator of ser_block into the mlir::Block:
+// 1) register tensor/shape types, 2) add block arguments for the block's
+// inputs, 3) build one MLIR op per serialized operator, 4) append the
+// terminator. The block's output values are appended to return_values.
+mlir::LogicalResult TosaMlirBlockBuilder::BuildAllOpsInBlock(
+    std::vector<mlir::Value> &return_values) {
+  block->clear();
+  auto loc = GetLocation();
+  auto op_builder = GetOpBuilder();
+  auto &tensor_map = GetTensorMap();
+
+  // Tracks operators already lowered so the same operator is never built
+  // twice. (The original also declared an `operator_queue` that was never
+  // used; it has been removed.)
+  std::unordered_set<TosaSerializationOperator *> operator_built;
+
+  TosaMlirOperatorBuilder tosa_op_builder(op_builder, ser_block, block, loc,
+                                          this, &tensor_map, &tensor_type_map,
+                                          &shape_type_map);
+
+  // Record the type of every serialized tensor so operator builders can look
+  // them up by name; variable tensors are additionally registered globally.
+  for (auto ts : ser_block->GetTensors()) {
+    if (ts->GetVariable()) {
+      RegisterVariableTensor(ts);
+    }
+    const auto &ts_name = ts->GetName();
+    if (ts->GetDtype() == DType::DType_SHAPE) {
+      // ts is tosa.shape type
+      auto shape_rank = ts->GetShape()[0];
+      shape_type_map[ts_name] =
+          mlir::tosa::shapeType::get(op_builder->getContext(), shape_rank);
+      continue;
+    }
+    mlir::RankedTensorType type;
+    if (BuildTensorType(op_builder, ts, type).failed()) {
+      return mlir::failure();
+    }
+    tensor_type_map[ts_name] = type;
+    if (ts->GetIsUnranked()) {
+      assert(ts->GetShape().empty()); // unranked tensors should have shape = {}
+      unranked_tensors.insert(ts_name);
+    }
+  }
+
+  // Initialize tensor_map based on block input arguments
+  for (const std::string &block_input_name : ser_block->GetInputs()) {
+    mlir::Type type = tensor_type_map[block_input_name];
+    if (unranked_tensors.count(block_input_name)) {
+      // recast type as unranked tensor type
+      auto element_type = type.cast<mlir::RankedTensorType>().getElementType();
+      type = mlir::UnrankedTensorType::get(element_type);
+    }
+    auto input_value = block->addArgument(type, loc);
+    if (tensor_map.count(block_input_name)) {
+      llvm::errs() << "ERROR: block input tensor " << block_input_name
+                   << " already exists\n";
+      return mlir::failure();
+    }
+    tensor_map[block_input_name] = input_value;
+  }
+
+  for (auto op : ser_block->GetOperators()) {
+
+    // skip if operator has been built
+    if (operator_built.count(op)) {
+      // this happens when same input appears twice or more in operator, eg,
+      // concat(%0, %0)
+      continue;
+    }
+    operator_built.insert(op);
+
+    std::vector<mlir::Value> output_values;
+    if (IsVariableReadOp(op)) {
+      output_values = tosa_op_builder.BuildVariableReadOp(op);
+    } else if (IsVariableWriteOp(op)) {
+      tosa_op_builder.BuildVariableWriteOp(op);
+    }
+// Expands to one `else if` clause per schema opcode, dispatching to the
+// specialized build<Op_*>() template above.
+#define DEF_SCHEMA_OPERATOR(SCHEMA_OP_NAME)                                    \
+  else if (op->GetOp() == Op_##SCHEMA_OP_NAME) {                               \
+    output_values = tosa_op_builder.build<Op_##SCHEMA_OP_NAME>(op);            \
+  }
+#include "schema_operator.def"
+#undef DEF_SCHEMA_OPERATOR
+    else {
+      llvm::errs() << "ERROR: unsupported opcode=" << EnumNamesOp()[op->GetOp()]
+                   << "\n";
+      return mlir::failure();
+    }
+
+    if (IsVariableWriteOp(op)) {
+      // the sanity checking below does not apply for variable write op because
+      // it has no output tensors whereas the original identity op has
+      continue;
+    }
+
+    // Sanity check if number of built mlir::Value is expected
+    if (op->GetOutputTensorNames().size() != output_values.size()) {
+      llvm::errs() << "ERROR: number of built mlir::Value is not matching "
+                   "number of operator output tensor\n";
+      return mlir::failure();
+    }
+
+    for (size_t i = 0; i < output_values.size(); i++) {
+      // Sanity check tensor hasn't been built
+      std::string op_output_name = op->GetOutputTensorNames()[i];
+      if (tensor_map.count(op_output_name)) {
+        llvm::errs() << "ERROR: tensor " << op_output_name
+                     << " is already built\n";
+        return mlir::failure();
+      }
+      tensor_map[op_output_name] = output_values[i];
+    }
+  }
+
+  // Construct return values
+  std::vector<mlir::Value> return_operands;
+  for (const auto &output_name : ser_block->GetOutputs()) {
+    // Sanity check if terminator mlir::Value is built
+    if (!tensor_map.count(output_name)) {
+      llvm::errs() << "ERROR: terminator mlir::Value " << output_name
+                   << " is not built in block " << ser_block->GetName() << "\n";
+      return mlir::failure();
+    }
+    mlir::Value output_value = tensor_map.at(output_name);
+    return_operands.push_back(output_value);
+    return_values.push_back(output_value);
+  }
+  // func.func bodies end in func.return; every other parent op (cond_if,
+  // while_loop regions) ends in tosa.yield.
+  mlir::Operation *terminator_op;
+  auto parent_op = block->getParentOp();
+  if (mlir::isa<mlir::func::FuncOp>(parent_op)) {
+    terminator_op =
+        op_builder->create<mlir::func::ReturnOp>(loc, return_operands);
+  } else {
+    terminator_op =
+        op_builder->create<mlir::tosa::YieldOp>(loc, return_operands);
+  }
+  block->push_back(terminator_op);
+
+  // need topological sorting?
+
+  return mlir::success();
+}
+
+// Materialize every serialized block of ser_region into a fresh mlir::Block
+// of `region`, accumulating the blocks' output values in return_values.
+mlir::LogicalResult TosaMlirRegionBuilder::BuildAllBlocksInRegion(
+    std::vector<mlir::Value> &return_values) {
+  for (auto &ser_block : ser_region->GetBlocks()) {
+    mlir::Block &mlir_block = region->emplaceBlock();
+    TosaMlirBlockBuilder builder(ser_block, this, &mlir_block);
+
+    if (builder.BuildAllOpsInBlock(return_values).failed()) {
+      return mlir::failure();
+    }
+
+    if (return_values.empty()) {
+      llvm::errs() << "Warning: graph doesn't have return values\n";
+    }
+  }
+  return mlir::success();
+}
+
+// Deserialize tsh's top-level region into @a func's body. Block arguments
+// and @a main_returns are populated as a side effect; the caller uses them
+// to fix up the function's type afterwards.
+mlir::LogicalResult buildTosaMlir(mlir::func::FuncOp &func,
+                                  mlir::MLIRContext &context,
+                                  tosa::TosaSerializationHandler &tsh,
+                                  std::vector<mlir::Value> &main_returns) {
+
+  mlir::Region *main_region = func.getCallableRegion();
+  if (!main_region) {
+    llvm::errs() << "Invalid MLIR: doesn't have valid \"main\" region\n";
+    return mlir::failure();
+  }
+
+  // NOTE(review): assumes the handler holds at least one region — confirm
+  // behavior for malformed/empty flatbuffers.
+  TosaSerializationRegion *ser_main_region = tsh.GetRegions().front();
+
+  auto loc = func.getLoc();
+
+  // takeBody() from the region itself drops its current blocks,
+  // leaving the region empty before it is rebuilt below.
+  main_region->takeBody(*main_region); // empty old func body
+  auto op_builder = mlir::OpBuilder(func.getBody());
+
+  if (BuildRegion(ser_main_region, &tsh, main_region, &op_builder, loc,
+                  main_returns)
+          .failed()) {
+    return mlir::failure();
+  }
+
+  if (main_returns.empty()) {
+    llvm::errs() << "Warning: graph doesn't have return values\n";
+  }
+
+  return mlir::success();
+}
+
+// Load Tosa Schema into TosaSerializationHandler, required for JSON save/load
+mlir::LogicalResult loadTosaSchema(tosa::TosaSerializationHandler &tsh) {
+  // std::string::c_str() never returns nullptr, so the original null-pointer
+  // test could never fire; check for an empty schema path instead.
+  if (tosa_deserialize_schema.empty()) {
+    llvm::errs() << "Flatbuffer schema not defined\n";
+    return mlir::failure();
+  }
+
+  const char *tosa_schema = tosa_deserialize_schema.c_str();
+  if (tsh.LoadFileSchema(tosa_schema)) {
+    llvm::errs() << "Error loading tosa schema file: " << tosa_schema << "\n";
+    return mlir::failure();
+  }
+  return mlir::success();
+}
+
+namespace {
+
+// Build the "inputs"/"outputs" entry of the tf.entry_function dictionary: a
+// comma-separated list of "<exported_name>_<prefix><i>:0" names.
+// (The misspelled function name is kept — callers reference it.)
+mlir::NamedAttribute DefaultEntryFuncitonAttr(mlir::Builder &builder,
+                                              bool is_input, int count) {
+  const std::string stem =
+      std::string(kDefaultExportedName) + "_" +
+      (is_input ? kDefaultInputPrefix : kDefaultOutputPrefix);
+  std::string names;
+  for (int i = 0; i < count; i++) {
+    if (!names.empty()) {
+      names += ",";
+    }
+    names += stem + std::to_string(i) + ":0";
+  }
+  return builder.getNamedAttr((is_input ? "inputs" : "outputs"),
+                              builder.getStringAttr(names));
+}
+
+// erase all ops in block except for FuncOp
+void ClearNonFuncOps(mlir::Block *block) {
+  // Collect first, erase second: erasing while iterating would invalidate
+  // the iterator.
+  std::vector<mlir::Operation *> victims;
+  for (mlir::Operation &op : *block) {
+    if (mlir::isa<mlir::func::FuncOp>(op)) {
+      continue;
+    }
+    victims.push_back(&op);
+  }
+  for (auto *op : victims) {
+    op->erase();
+  }
+}
+
+// erase function attrs and empty function region's body
+void ResetFunction(mlir::func::FuncOp &function, mlir::MLIRContext &context) {
+  // Replace the attribute dictionary with an empty one.
+  function->setAttrs(mlir::DictionaryAttr::get(&context, {}));
+  mlir::Region *main_region = function.getCallableRegion();
+  // takeBody() from the region itself drops its current blocks, leaving the
+  // region empty (same idiom as in buildTosaMlir).
+  main_region->takeBody(*main_region);
+}
+
+// replace attrs and body of @a to_function and its parent module
+// by @a from_module and its "main" function
+mlir::LogicalResult CloneIntoModuleAndFunction(
+    mlir::MLIRContext &context, mlir::func::FuncOp &to_function,
+    mlir::ModuleOp &to_module, mlir::func::FuncOp &from_function,
+    mlir::ModuleOp &from_module) {
+  auto from_block = from_function.getOperation()->getBlock();
+  auto to_block = to_function.getOperation()->getBlock();
+  // Drop everything except FuncOps from the destination block first.
+  ClearNonFuncOps(to_block);
+  // copy all attrs from new_module to module
+  to_module->setAttrs(from_module->getAttrDictionary());
+  // erase attrs and body of function
+  ResetFunction(to_function, context);
+  // clone new_func attrs and region into function
+  mlir::IRMapping mapping;
+  from_function.cloneInto(to_function, mapping);
+
+  // copy variable ops in from_block to to_block
+  // collect variable ops in from_block in reverse order
+  std::vector<mlir::Operation *> variable_ops;
+  for (mlir::Operation &op : *from_block) {
+    if (mlir::isa<mlir::tosa::VariableOp>(op)) {
+      variable_ops.push_back(&op);
+    }
+  }
+  // Clone without regions/operands — assumes tosa.variable ops carry only
+  // attributes (TODO confirm against the op definition).
+  auto cloneOptions =
+      mlir::Operation::CloneOptions::all().cloneRegions(false).cloneOperands(
+          false);
+  // push_front while iterating in reverse so the clones end up at the top of
+  // to_block in their original order.
+  for (auto iter = variable_ops.rbegin(); iter != variable_ops.rend(); iter++) {
+    auto op = *iter;
+    to_block->push_front(op->clone(mapping, cloneOptions));
+  }
+  return mlir::success();
+}
+
+} // namespace
+
+namespace mlir {
+
+namespace tosa {
+
+// Load a TOSA flatbuffer (or JSON, when file_is_fbs is false) from file_name
+// and build a fresh module containing a single "main" function plus any
+// variable ops. Returns nullptr on any load/build failure.
+mlir::OwningOpRef<mlir::ModuleOp>
+BuildMlirFromTosaFile(const char *file_name, mlir::MLIRContext *context,
+                      bool file_is_fbs) {
+  TosaSerializationHandler tsh;
+  if (file_is_fbs) {
+    if (tsh.LoadFileTosaFlatbuffer(file_name)) {
+      llvm::errs() << "Fail to load TOSA file " << file_name << "\n";
+      return nullptr;
+    }
+  } else {
+    // must load tosa schema before loading json file
+    if (loadTosaSchema(tsh).failed()) {
+      return nullptr;
+    }
+    if (tsh.LoadFileJson(file_name)) {
+      llvm::errs() << "Fail to load TOSA JSON file " << file_name << "\n";
+      return nullptr;
+    }
+  }
+
+  // create new module
+  auto base_loc = mlir::FileLineColLoc::get(context, file_name, 0, 0);
+  auto module = mlir::ModuleOp::create(base_loc);
+
+  // set module attributes
+  const auto &tosa_version = tsh.GetVersion().to_string();
+  std::string tosa_description =
+      file_is_fbs ? kDefaultFBSDescription : kDefaultJSONDescription;
+  auto builder = mlir::Builder(context);
+  module->setAttr("tosa.fbs_version", builder.getStringAttr(tosa_version));
+  module->setAttr("tosa.description", builder.getStringAttr(tosa_description));
+  module->setAttr("tf_saved_model.semantics", mlir::UnitAttr::get(context));
+
+  // construct function with input and return types
+  // The type starts empty and is patched below once deserialization has
+  // produced the real argument and return values.
+  llvm::SmallVector<mlir::Type, 2> ret_types;
+  llvm::SmallVector<mlir::Type, 4> input_types;
+  auto func_type = builder.getFunctionType(input_types, ret_types);
+  auto func_loc =
+      mlir::NameLoc::get(builder.getStringAttr(kMainFunctionName), base_loc);
+  auto func = mlir::func::FuncOp::create(func_loc, kMainFunctionName, func_type,
+                                         /* attrs= */ {});
+  func.addEntryBlock();
+
+  // deserialize tosa fbs into function
+  std::vector<mlir::Value> main_returns;
+  if (buildTosaMlir(func, *context, tsh, main_returns).failed()) {
+    llvm::errs() << "Failed to deserialize flatbuffer "
+                 << tosa_deserialize_filename << "\n";
+    return nullptr;
+  }
+  auto main_args = func.getCallableRegion()->getArguments();
+  // extract function input types
+  for (auto arg : main_args) {
+    input_types.push_back(arg.getType());
+  }
+  // extract function return types
+  for (auto ret : main_returns) {
+    ret_types.push_back(ret.getType());
+  }
+  // set function type with full input and return types
+  func_type = builder.getFunctionType(input_types, ret_types);
+  func.setType(func_type);
+
+  // set function attributes
+  // tf_saved_model/tf.entry_function attributes mirror what the TF exporter
+  // would have produced, so downstream tooling can consume the module.
+  llvm::SmallVector<mlir::NamedAttribute, 2> attributes;
+  if (!input_types.empty()) {
+    attributes.push_back(DefaultEntryFuncitonAttr(
+        builder, /* is_input = */ true, /* count = */ input_types.size()));
+    // NOTE(review): `int i` vs size() comparison triggers -Wsign-compare.
+    for (int i = 0; i < input_types.size(); i++) {
+      std::string input_i = kDefaultInputPrefix + std::to_string(i);
+      func.setArgAttr(i, "tf_saved_model.index_path",
+                      mlir::ArrayAttr::get(
+                          context, {mlir::StringAttr::get(context, input_i)}));
+    }
+  }
+  if (!ret_types.empty()) {
+    attributes.push_back(DefaultEntryFuncitonAttr(
+        builder, /* is_input = */ false, /* count = */ ret_types.size()));
+    for (int i = 0; i < ret_types.size(); i++) {
+      std::string output_i = kDefaultOutputPrefix + std::to_string(i);
+      func.setResultAttr(
+          i, "tf_saved_model.index_path",
+          mlir::ArrayAttr::get(context,
+                               {mlir::StringAttr::get(context, output_i)}));
+    }
+  }
+  func->setAttr("tf.entry_function", builder.getDictionaryAttr(attributes));
+  func->setAttr(
+      "tf_saved_model.exported_names",
+      mlir::ArrayAttr::get(
+          context, {mlir::StringAttr::get(context, kDefaultExportedName)}));
+
+  // deserialize variable ops in the new module just before adding func op
+  if (ConstructVariableOps(module).failed()) {
+    return nullptr;
+  }
+
+  // add func to module
+  module.push_back(std::move(func));
+  return mlir::OwningOpRef<mlir::ModuleOp>(module);
+}
+
+namespace {
+
+// Pass that replaces the current function (and its parent module's
+// attributes) with the contents of the flatbuffer named by
+// tosa_deserialize_filename.
+class TosaDeserialize : public TosaDeserializationPassBase<TosaDeserialize> {
+public:
+  void runOnOperation() final {
+    auto target_func = getOperation();
+    auto &ctx = getContext();
+
+    // Build a standalone module from the flatbuffer file.
+    auto loaded_module_ref = BuildMlirFromTosaFile(
+        tosa_deserialize_filename.c_str(), &ctx, /* file_is_fbs = */ true);
+    if (!loaded_module_ref) {
+      return signalPassFailure();
+    }
+
+    mlir::ModuleOp loaded_module = *loaded_module_ref;
+    auto builder = mlir::Builder(&ctx);
+    auto target_module = target_func->getParentOfType<mlir::ModuleOp>();
+    auto loaded_func = loaded_module.lookupSymbol<mlir::func::FuncOp>(
+        builder.getStringAttr(kMainFunctionName));
+    if (!loaded_func) {
+      llvm::errs() << "Failed to find main function in deserialized module\n";
+      return signalPassFailure();
+    }
+    // Splice the deserialized main function over the current one.
+    if (CloneIntoModuleAndFunction(ctx,
+                                   /* to_function = */ target_func,
+                                   /* to_module = */ target_module,
+                                   /* from_function = */ loaded_func,
+                                   /* from_module = */ loaded_module)
+            .failed()) {
+      return signalPassFailure();
+    }
+  }
+};
+
+// JSON variant of the deserialization pass; identical flow, but the input
+// file is parsed as JSON against the loaded TOSA schema.
+class TosaDeserializeJSON
+    : public TosaDeserializationJSONPassBase<TosaDeserializeJSON> {
+public:
+  void runOnOperation() final {
+    auto current_func = getOperation();
+    auto &ctx = getContext();
+
+    auto json_module_ref = BuildMlirFromTosaFile(
+        tosa_deserialize_filename.c_str(), &ctx, /* file_is_fbs = */ false);
+    if (!json_module_ref) {
+      return signalPassFailure();
+    }
+
+    mlir::ModuleOp json_module = *json_module_ref;
+    auto builder = mlir::Builder(&ctx);
+    auto current_module = current_func->getParentOfType<mlir::ModuleOp>();
+    auto json_main = json_module.lookupSymbol<mlir::func::FuncOp>(
+        builder.getStringAttr(kMainFunctionName));
+    if (!json_main) {
+      llvm::errs() << "Failed to find main function in deserialized module\n";
+      return signalPassFailure();
+    }
+    if (CloneIntoModuleAndFunction(ctx,
+                                   /* to_function = */ current_func,
+                                   /* to_module = */ current_module,
+                                   /* from_function = */ json_main,
+                                   /* from_module = */ json_module)
+            .failed()) {
+      return signalPassFailure();
+    }
+  }
+};
+
+} // anonymous namespace
+
+// Creates an instance of the TOSA flatbuffer deserialization pass
+std::unique_ptr<Pass> createTosaDeserializePass() {
+  return std::make_unique<TosaDeserialize>();
+}
+
+// Creates an instance of the TOSA JSON deserialization pass
+std::unique_ptr<Pass> createTosaDeserializeJSONPass() {
+  return std::make_unique<TosaDeserializeJSON>();
+}
+
+// Static registration so both passes are available to mlir-opt-style tools.
+static PassRegistration<TosaDeserialize> passDeserialize([] {
+  return createTosaDeserializePass();
+});
+
+static PassRegistration<TosaDeserializeJSON> passDeserializeJSON([] {
+  return createTosaDeserializeJSONPass();
+});
+
+} // namespace tosa
+} // namespace mlir
diff --git a/src/TosaSerialize.cpp b/src/TosaSerialize.cpp
index a6ea5ff..c3a9878 100644
--- a/src/TosaSerialize.cpp
+++ b/src/TosaSerialize.cpp
@@ -1,5 +1,5 @@
-// Copyright (c) 2020-2022, ARM Limited.
+// Copyright (c) 2020-2024, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 with LLVM Exceptions
// (the "License"); you may not use this file except in compliance with
@@ -21,9 +21,13 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project
#include "mlir/Dialect/Tosa/IR/TosaOps.h" // from @llvm-project
#include "mlir/IR/Matchers.h" // from @llvm-project
-#include "mlir/Pass/Pass.h" // from @llvm-project
-#include "mlir/Support/LogicalResult.h" // from @llvm-project
+#include "mlir/IR/Operation.h"
+#include "mlir/Pass/Pass.h" // from @llvm-project
+#include "mlir/Support/LogicalResult.h" // from @llvm-project
+#include "mlir/Support/TypeID.h"
+#include "mlir/Transforms/DialectConversion.h" // from @llvm-project
#include "tosa_serialization_handler.h"
+#include <algorithm>
#include <functional>
#include <map>
#include <unordered_map>
@@ -61,7 +65,7 @@ template <> struct equal_to<mlir::Value> {
} // namespace std
-ResizeMode ResizeModeStr2Enum(const std::string &mode_str) {
+static ResizeMode ResizeModeStr2Enum(const std::string &mode_str) {
if (mode_str == "NEAREST_NEIGHBOR")
return ResizeMode_NEAREST;
else if (mode_str == "BILINEAR")
@@ -70,14 +74,21 @@ ResizeMode ResizeModeStr2Enum(const std::string &mode_str) {
return ResizeMode_UNKNOWN;
}
-DType Type2DType(mlir::Type element_type) {
- if (element_type.isF64() || element_type.isF32() || element_type.isF16() ||
- element_type.isBF16()) {
- return DType_FLOAT;
+static DType Type2DType(mlir::Type element_type) {
+ if (element_type.isF64() || element_type.isF32()) {
+ return DType_FP32;
+ } else if (element_type.isFloat8E5M2()) {
+ return DType_FP8E5M2;
+ } else if (element_type.isFloat8E4M3FN()) {
+ return DType_FP8E4M3;
+ } else if (element_type.isF16()) {
+ return DType_FP16;
+ } else if (element_type.isBF16()) {
+ return DType_BF16;
} else if (element_type.isUnsignedInteger(8)) {
return DType_UINT8;
} else if (element_type.isInteger(4)) {
- return DType_INT8;
+ return DType_INT4;
} else if (element_type.isInteger(8)) {
return DType_INT8;
} else if (element_type.isInteger(16)) {
@@ -94,18 +105,79 @@ DType Type2DType(mlir::Type element_type) {
return DType_UNKNOWN;
}
+static DType Type2AccDType(mlir::Type element_type) {
+  // def Tosa_AccType : AnyTypeOf<[I<32>, I<48>, F16, F32]>;
+  // The predicates are mutually exclusive, so query order is irrelevant.
+  if (element_type.isInteger(32)) {
+    return DType_INT32;
+  }
+  if (element_type.isInteger(48)) {
+    return DType_INT48;
+  }
+  if (element_type.isF16()) {
+    return DType_FP16;
+  }
+  if (element_type.isF32()) {
+    return DType_FP32;
+  }
+  return DType_UNKNOWN;
+}
class TosaSerializationBlockBuilder;
+class TosaSerializationRegionBuilder;
+
+// Maps the generated flatbuffer name ("Variable_<n>") to the defining op.
+// NOTE(review): file-scope mutable state — not reentrant if the pass runs
+// more than once per process; confirm intended lifetime.
+std::unordered_map<std::string, mlir::Operation *> variable_tensor_op_map;
+// Maps a variable's MLIR "name" attribute to its flatbuffer name.
+std::unordered_map<std::string, std::string>
+    variable_tensor_flatbuffer_name_map;
+static int variable_tensor_index = 0;
+
+namespace {
+
+// for now, this is a global map of variables
+// Assign the next "Variable_<n>" flatbuffer name to @a op and record it in
+// both global maps, keyed by the op's "name" StringAttr.
+void RegisterVariableOp(mlir::Operation &op) {
+  std::string variable_tensor_flatbuffer_name =
+      "Variable_" + std::to_string(variable_tensor_index++);
+  std::string variable_tensor_mlir_name =
+      op.getAttr("name").cast<mlir::StringAttr>().getValue().str();
+  variable_tensor_op_map[variable_tensor_flatbuffer_name] = &op;
+  variable_tensor_flatbuffer_name_map[variable_tensor_mlir_name] =
+      variable_tensor_flatbuffer_name;
+}
+
+} // namespace
class TosaSerializationOperatorBuilder {
public:
TosaSerializationOperatorBuilder(
TosaSerializationBlockBuilder *_block_builder)
: block_builder(_block_builder) {}
+
template <typename T>
TosaSerializationOperator *build(mlir::Operation &op) const;
+ TosaSerializationHandler *GetTsh() const;
+ TosaSerializationRegionBuilder *GetRegionBuilder() const;
+ mlir::LogicalResult GetDataFromAttribute(mlir::Operation &op,
+ mlir::Attribute &attr,
+ mlir::Type element_type,
+ std::vector<uint8_t> &u8_data) const;
+
+ // populate u8_data with either int64_value or float_value depending on
+ // element_type
+ mlir::LogicalResult
+ GetU8DataFromIntOrFloatValue(int64_t int64_value, float fp_value,
+ mlir::Type element_type,
+ std::vector<uint8_t> &u8_data) const;
+
+ // populate u8_data with int_value depending on non-float element_type
+ mlir::LogicalResult
+ GetU8DataFromIntValues(const std::vector<int64_t> &int_values,
+ mlir::Type element_type,
+ std::vector<uint8_t> &u8_data) const;
+
+ // populate u8_data with fp_value depending on float element_type
+ mlir::LogicalResult
+ GetU8DataFromFloatValues(const std::vector<float> &fp_values,
+ mlir::Type element_type,
+ std::vector<uint8_t> &u8_data) const;
private:
std::string GetTensorName(mlir::Value val) const;
+ std::string GetVariableTensorName(mlir::Operation *op) const;
TosaSerializationOperator *BuildPoolOpFromMlirOp(mlir::Operation &op,
Op opcode) const;
TosaSerializationOperator *BuildEwiseBinaryOpFromMlirOp(mlir::Operation &op,
@@ -120,37 +192,275 @@ private:
// This builder assumes each region only has only one block
class TosaSerializationBlockBuilder {
public:
- friend class TosaSerializationOperatorBuilder;
- TosaSerializationBlockBuilder(TosaSerializationBasicBlock *_block,
- TosaSerializationHandler *_tsh,
- mlir::Region *_region)
- : block(_block), tsh(_tsh), region(_region) {}
+ // constructor
+ TosaSerializationBlockBuilder(TosaSerializationBasicBlock *_ser_block,
+ TosaSerializationRegionBuilder *_region_builder,
+ mlir::Block *_block)
+ : ser_block(_ser_block), region_builder(_region_builder), block(_block) {}
mlir::LogicalResult
- BuildAllOpsInRegion(std::vector<mlir::Value> &return_values);
- TosaSerializationBasicBlock *GetBlock() { return block; }
- TosaSerializationHandler *GetTsh() { return tsh; }
+ BuildAllOpsInBlock(std::vector<mlir::Value> &return_values);
+ TosaSerializationBasicBlock *GetBlock() const { return ser_block; }
+ TosaSerializationRegionBuilder *GetRegionBuilder() const {
+ return region_builder;
+ }
+ TosaSerializationHandler *GetTsh() const;
+ std::unordered_map<mlir::Value, std::string> &GetTensorMap() {
+ return tensor_map;
+ }
private:
TosaSerializationOperator *BuildTosaSerializationOperator(
const TosaSerializationOperatorBuilder &op_builder, mlir::Operation &op);
TosaSerializationTensor *
+ BuildTosaSerializationVariableTensor(mlir::RankedTensorType tensor_type,
+ const std::string &name,
+ const std::string &variable_mlir_name);
+ TosaSerializationTensor *
BuildTosaSerializationTensor(mlir::Value val, const std::string &name);
- TosaSerializationBasicBlock *block;
- TosaSerializationHandler *tsh;
- mlir::Region *region;
+ TosaSerializationBasicBlock *ser_block;
+ TosaSerializationRegionBuilder *region_builder;
+ mlir::Block *block;
std::unordered_map<mlir::Value, std::string> tensor_map;
std::unordered_map<mlir::Value, std::string> input_tensor_map;
};
+class TosaSerializationRegionBuilder {
+public:
+ // Constructor
+ TosaSerializationRegionBuilder(
+ TosaSerializationRegion *_ser_region, mlir::Region *_region,
+ TosaSerializationRegionBuilder *_parent_value_scope,
+ TosaSerializationHandler *_tsh)
+ : ser_region(_ser_region), region(_region),
+ parent_value_scope(_parent_value_scope), tsh(_tsh) {}
+ TosaSerializationHandler *GetTsh() const { return tsh; }
+ mlir::LogicalResult
+ BuildAllBlocksInRegion(bool is_top, std::vector<mlir::Value> &return_values);
+ TosaSerializationRegionBuilder *GetParentValueScope() const {
+ return parent_value_scope;
+ }
+ std::vector<TosaSerializationBlockBuilder *> &GetBlockBuilders() {
+ return block_builders;
+ }
+
+private:
+ TosaSerializationRegion *ser_region;
+ mlir::Region *region;
+ TosaSerializationRegionBuilder *parent_value_scope;
+ TosaSerializationHandler *tsh;
+ std::vector<TosaSerializationBlockBuilder *> block_builders;
+};
+
+TosaSerializationHandler *TosaSerializationOperatorBuilder::GetTsh() const {
+ return block_builder->GetTsh();
+}
+TosaSerializationHandler *TosaSerializationBlockBuilder::GetTsh() const {
+ return region_builder->GetTsh();
+}
+TosaSerializationRegionBuilder *
+TosaSerializationOperatorBuilder::GetRegionBuilder() const {
+ return block_builder->GetRegionBuilder();
+}
+
std::string
TosaSerializationOperatorBuilder::GetTensorName(mlir::Value val) const {
- if (block_builder->tensor_map.find(val) == block_builder->tensor_map.end()) {
- llvm::errs() << "ERROR: Failed to get mlir::Value from tensor_map";
+ auto value_scope = GetRegionBuilder();
+ while (value_scope) {
+ // Traverse through each block builder in the region
+ for (auto curr_block_builder : value_scope->GetBlockBuilders()) {
+ const auto &tensor_map = curr_block_builder->GetTensorMap();
+ if (tensor_map.count(val)) {
+ return tensor_map.at(val);
+ }
+ }
+ value_scope = value_scope->GetParentValueScope();
+ }
+ // Didn't find anything
+ llvm::errs() << "ERROR: Failed to get mlir::Value from tensor_map\n";
+ assert(0);
+}
+
+// Unpack 64-bit integer attribute element and pack into a std vector.
+template <class T>
+static std::vector<T> getDenseI64ArrayAttr(mlir::Attribute attr) {
+ auto array_ref = attr.cast<mlir::DenseI64ArrayAttr>().asArrayRef();
+
+ std::vector<T> vec;
+ for (auto val : array_ref) {
+ vec.push_back(val);
+ }
+
+ return vec;
+}
+
+// Unpack 8-bit integer attribute element and pack into a std vector.
+template <class T>
+static std::vector<T> getDenseI8ArrayAttr(mlir::Attribute attr) {
+ auto array_ref = attr.cast<mlir::DenseI8ArrayAttr>().asArrayRef();
+
+ std::vector<T> vec;
+ for (auto val : array_ref) {
+ vec.push_back(val);
+ }
+
+ return vec;
+}
+
+std::string TosaSerializationOperatorBuilder::GetVariableTensorName(
+ mlir::Operation *op) const {
+ std::string variable_tensor_mlir_name =
+ op->getAttr("name").cast<mlir::StringAttr>().getValue().str();
+
+ if (variable_tensor_flatbuffer_name_map.find(variable_tensor_mlir_name) ==
+ variable_tensor_flatbuffer_name_map.end()) {
+ llvm::errs() << "ERROR: Failed to find key " << variable_tensor_mlir_name
+ << " from variable_tensor_flatbuffer_name_map\n";
assert(0);
}
- return block_builder->tensor_map[val];
+ return variable_tensor_flatbuffer_name_map[variable_tensor_mlir_name];
+}
+
+mlir::LogicalResult TosaSerializationOperatorBuilder::GetDataFromAttribute(
+ mlir::Operation &op, mlir::Attribute &attr, mlir::Type element_type,
+ std::vector<uint8_t> &u8_data) const {
+ if (!element_type.isIntOrFloat()) {
+ return mlir::failure();
+ }
+ auto dense_attr = attr.dyn_cast<mlir::DenseElementsAttr>();
+
+ // handle float types
+ if (element_type.isa<mlir::FloatType>()) {
+ std::vector<float> fp_data;
+ auto val_attr = attr.dyn_cast<mlir::FloatAttr>();
+
+ if (dense_attr) {
+ for (auto val : dense_attr.getValues<mlir::APFloat>()) {
+ fp_data.push_back(val.convertToFloat());
+ }
+ } else if (val_attr) {
+ fp_data.push_back((float)val_attr.getValueAsDouble());
+ } else {
+ op.emitOpError("Unknown const attribute");
+ return mlir::failure();
+ }
+
+ return GetU8DataFromFloatValues(fp_data, element_type, u8_data);
+ }
+
+ // element_type is integer type
+
+ bool isInt48 = element_type.isInteger(48);
+ std::vector<int64_t> i64_data;
+
+ auto val_attr = attr.dyn_cast<mlir::IntegerAttr>();
+ if (dense_attr) {
+ for (auto valueIt : dense_attr.getValues<mlir::APInt>()) {
+ int64_t val = isInt48 ? static_cast<int64_t>(valueIt.getLimitedValue())
+ : valueIt.getSExtValue();
+ i64_data.push_back(val);
+ }
+ } else if (val_attr) {
+ i64_data.push_back(val_attr.getInt());
+ } else {
+ op.emitOpError("Unknown const attribute");
+ return mlir::failure();
+ }
+
+ return GetU8DataFromIntValues(i64_data, element_type, u8_data);
+}
+
+mlir::LogicalResult TosaSerializationOperatorBuilder::GetU8DataFromIntValues(
+ const std::vector<int64_t> &int64_values, mlir::Type element_type,
+ std::vector<uint8_t> &u8_data) const {
+ switch (element_type.getIntOrFloatBitWidth()) {
+ case 1: {
+ // bool use bool vec
+ std::vector<bool> bool_values;
+ for (auto v : int64_values) {
+ bool bool_value = v == 0 ? false : true;
+ bool_values.push_back(bool_value);
+ }
+ TosaSerializationHandler::ConvertBooltoU8(bool_values, u8_data);
+ break;
+ }
+ case 4:
+ case 8: {
+ // I4 and I8 use int8_t vec
+ std::vector<int8_t> i8_values;
+ for (auto v : int64_values) {
+ i8_values.push_back(static_cast<int8_t>(v));
+ }
+ if (element_type.isInteger(4)) {
+ TosaSerializationHandler::ConvertI4toU8(i8_values, u8_data);
+ } else {
+ TosaSerializationHandler::ConvertI8toU8(i8_values, u8_data);
+ }
+ break;
+ }
+ case 16: {
+ // I16 use int16_t vec
+ std::vector<int16_t> i16_values;
+ for (auto v : int64_values) {
+ i16_values.push_back(static_cast<int16_t>(v));
+ }
+ TosaSerializationHandler::ConvertI16toU8(i16_values, u8_data);
+ break;
+ }
+ case 32: {
+ // I32 use int32_t vec
+ std::vector<int32_t> i32_values;
+ for (auto v : int64_values) {
+ i32_values.push_back(static_cast<int32_t>(v));
+ }
+ TosaSerializationHandler::ConvertI32toU8(i32_values, u8_data);
+ break;
+ }
+ case 48: {
+ // I48 use int64_t vec
+ TosaSerializationHandler::ConvertI48toU8(int64_values, u8_data);
+ break;
+ }
+ default: {
+ // unsupported bit widths
+ return mlir::failure();
+ }
+ }
+ return mlir::success();
+}
+
+mlir::LogicalResult TosaSerializationOperatorBuilder::GetU8DataFromFloatValues(
+ const std::vector<float> &fp_values, mlir::Type element_type,
+ std::vector<uint8_t> &u8_data) const {
+ assert(
+ element_type
+ .isa<mlir::FloatType>()); // this should only be called for float type
+ if (element_type.isF16()) {
+ TosaSerializationHandler::ConvertF16toU8(fp_values, u8_data);
+ } else if (element_type.isBF16()) {
+ TosaSerializationHandler::ConvertBF16toU8(fp_values, u8_data);
+ } else if (element_type.isFloat8E4M3FN()) {
+ TosaSerializationHandler::ConvertFP8E4M3toU8(fp_values, u8_data);
+ } else if (element_type.isFloat8E5M2()) {
+ TosaSerializationHandler::ConvertFP8E5M2toU8(fp_values, u8_data);
+ } else if (element_type.isF32()) {
+ TosaSerializationHandler::ConvertF32toU8(fp_values, u8_data);
+ } else {
+ return mlir::failure();
+ }
+ return mlir::success();
+}
+
+mlir::LogicalResult
+TosaSerializationOperatorBuilder::GetU8DataFromIntOrFloatValue(
+ int64_t int64_value, float fp_value, mlir::Type element_type,
+ std::vector<uint8_t> &u8_data) const {
+ if (element_type.isa<mlir::FloatType>()) {
+ return GetU8DataFromFloatValues({fp_value}, element_type, u8_data);
+ } else {
+ return GetU8DataFromIntValues({int64_value}, element_type, u8_data);
+ }
}
// Main template to catch unimplemented translation.
@@ -176,43 +486,42 @@ TosaSerializationOperatorBuilder::build(mlir::Operation &op) const {
TosaSerializationOperator *
TosaSerializationOperatorBuilder::BuildPoolOpFromMlirOp(mlir::Operation &op,
Op opcode) const {
- std::vector<int> pad, stride, kernel;
-
- auto pad_attr = op.getAttr("pad").dyn_cast<mlir::ArrayAttr>().getValue();
- for (auto &int_attr : pad_attr) {
- pad.push_back(int_attr.dyn_cast<mlir::IntegerAttr>().getInt());
- }
+ auto pad = getDenseI64ArrayAttr<int>(op.getAttr("pad"));
ASSERT_VECTOR_LENGTH(pad, 4);
- auto stride_attr =
- op.getAttr("stride").dyn_cast<mlir::ArrayAttr>().getValue();
- for (auto &int_attr : stride_attr) {
- stride.push_back(int_attr.dyn_cast<mlir::IntegerAttr>().getInt());
- }
+ auto stride = getDenseI64ArrayAttr<int>(op.getAttr("stride"));
ASSERT_VECTOR_LENGTH(stride, 2);
- auto kernel_attr =
- op.getAttr("kernel").dyn_cast<mlir::ArrayAttr>().getValue();
- for (auto &int_attr : kernel_attr) {
- kernel.push_back(int_attr.dyn_cast<mlir::IntegerAttr>().getInt());
- }
+ auto kernel = getDenseI64ArrayAttr<int>(op.getAttr("kernel"));
ASSERT_VECTOR_LENGTH(kernel, 2);
+ DType acc_dtype = DType_FP32;
+ // AvgPool has acc_type, MaxPool does not
+ if (op.hasAttr("acc_type")) {
+ auto acc_type = op.getAttr("acc_type").cast<mlir::TypeAttr>().getValue();
+ acc_dtype = Type2AccDType(acc_type);
+ }
+
std::string input_name = GetTensorName(op.getOperand(0));
std::string output_name = GetTensorName(op.getResult(0));
- auto quant_info = op.getAttrOfType<mlir::tosa::UnaryOpQuantizationAttr>(
- "quantization_info");
+ int32_t input_zp =
+ op.hasAttr("input_zp")
+ ? input_zp = op.getAttr("input_zp").cast<mlir::IntegerAttr>().getInt()
+ : 0;
+ int32_t output_zp =
+ op.hasAttr("output_zp")
+ ? output_zp =
+ op.getAttr("output_zp").cast<mlir::IntegerAttr>().getInt()
+ : 0;
- int32_t input_zp = quant_info ? quant_info.input_zp().getInt() : 0;
- int32_t output_zp = quant_info ? quant_info.output_zp().getInt() : 0;
+ TosaPoolAttribute attribute(pad, kernel, stride, input_zp, output_zp,
+ acc_dtype);
- TosaPoolAttribute attribute(pad, kernel, stride, input_zp, output_zp);
-
- TosaSerializationOperator *tyop = new TosaSerializationOperator(
- opcode, Attribute_PoolAttribute, &attribute,
- std::vector<std::string>{input_name},
- std::vector<std::string>{output_name});
+ TosaSerializationOperator *tyop =
+ new TosaSerializationOperator(opcode, Attribute_PoolAttribute, &attribute,
+ std::vector<std::string>{input_name},
+ std::vector<std::string>{output_name});
return tyop;
}
@@ -239,8 +548,7 @@ TosaSerializationOperatorBuilder::BuildEwiseUnaryOpFromMlirOp(
std::string output_name = GetTensorName(op.getResult(0));
TosaSerializationOperator *tyop = new TosaSerializationOperator(
- opcode, Attribute_NONE, nullptr,
- std::vector<std::string>{input_name},
+ opcode, Attribute_NONE, nullptr, std::vector<std::string>{input_name},
std::vector<std::string>{output_name});
return tyop;
@@ -255,10 +563,10 @@ TosaSerializationOperatorBuilder::BuildReductionOpFromMlirOp(
int32_t axis = op.getAttr("axis").dyn_cast<mlir::IntegerAttr>().getInt();
TosaAxisAttribute attribute(axis);
- TosaSerializationOperator *tyop = new TosaSerializationOperator(
- opcode, Attribute_AxisAttribute, &attribute,
- std::vector<std::string>{input_name},
- std::vector<std::string>{output_name});
+ TosaSerializationOperator *tyop =
+ new TosaSerializationOperator(opcode, Attribute_AxisAttribute, &attribute,
+ std::vector<std::string>{input_name},
+ std::vector<std::string>{output_name});
return tyop;
}
@@ -302,7 +610,7 @@ BUILD_OP_ELEMENTWISE_BINARY(Add, ADD)
BUILD_OP_ELEMENTWISE_BINARY(BitwiseAnd, BITWISE_AND)
BUILD_OP_ELEMENTWISE_BINARY(BitwiseXor, BITWISE_XOR)
BUILD_OP_ELEMENTWISE_BINARY(BitwiseOr, BITWISE_OR)
-BUILD_OP_ELEMENTWISE_BINARY(Div, INTDIV)
+BUILD_OP_ELEMENTWISE_BINARY(IntDiv, INTDIV)
BUILD_OP_ELEMENTWISE_BINARY(LogicalAnd, LOGICAL_AND)
BUILD_OP_ELEMENTWISE_BINARY(LogicalLeftShift, LOGICAL_LEFT_SHIFT)
BUILD_OP_ELEMENTWISE_BINARY(LogicalRightShift, LOGICAL_RIGHT_SHIFT)
@@ -317,12 +625,14 @@ BUILD_OP_ELEMENTWISE_UNARY(Abs, ABS)
BUILD_OP_ELEMENTWISE_UNARY(BitwiseNot, BITWISE_NOT)
BUILD_OP_ELEMENTWISE_UNARY(Ceil, CEIL)
BUILD_OP_ELEMENTWISE_UNARY(Clz, CLZ)
+BUILD_OP_ELEMENTWISE_UNARY(Cos, COS)
BUILD_OP_ELEMENTWISE_UNARY(Exp, EXP)
BUILD_OP_ELEMENTWISE_UNARY(Floor, FLOOR)
BUILD_OP_ELEMENTWISE_UNARY(Log, LOG)
BUILD_OP_ELEMENTWISE_UNARY(LogicalNot, LOGICAL_NOT)
BUILD_OP_ELEMENTWISE_UNARY(Reciprocal, RECIPROCAL)
BUILD_OP_ELEMENTWISE_UNARY(Rsqrt, RSQRT)
+BUILD_OP_ELEMENTWISE_UNARY(Sin, SIN)
BUILD_OP_REDUCTION(ReduceAny, REDUCE_ANY)
BUILD_OP_REDUCTION(ReduceAll, REDUCE_ALL)
@@ -335,14 +645,20 @@ BUILD_OP_ELEMENTWISE_BINARY(Equal, EQUAL)
BUILD_OP_ELEMENTWISE_BINARY(Greater, GREATER)
BUILD_OP_ELEMENTWISE_BINARY(GreaterEqual, GREATER_EQUAL)
+BUILD_OP_ELEMENTWISE_UNARY(Erf, ERF)
BUILD_OP_ELEMENTWISE_UNARY(Sigmoid, SIGMOID)
BUILD_OP_ELEMENTWISE_UNARY(Tanh, TANH)
BUILD_OP_ELEMENTWISE_UNARY(Identity, IDENTITY)
BUILD_OP_ELEMENTWISE_UNARY(Cast, CAST)
+BUILD_OP_ELEMENTWISE_BINARY(AddShape, ADD_SHAPE)
+BUILD_OP_ELEMENTWISE_BINARY(SubShape, SUB_SHAPE)
+BUILD_OP_ELEMENTWISE_BINARY(MulShape, MUL_SHAPE)
+BUILD_OP_ELEMENTWISE_BINARY(DivShape, DIV_SHAPE)
+
template <>
TosaSerializationOperator *
-TosaSerializationOperatorBuilder::build<mlir::tosa::ConstOp>(
+TosaSerializationOperatorBuilder::build<mlir::tosa::ConstShapeOp>(
mlir::Operation &op) const {
std::string output_name = GetTensorName(op.getResult(0));
TosaSerializationTensor *ts =
@@ -353,142 +669,73 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::ConstOp>(
return nullptr;
}
-#if 0
- // Gracefully handle constants of "constant unit" type which have no value
- // by creating a numpy value of 0.
- auto unit_val = op.getAttr(llvm::StringRef("value")).dyn_cast<mlir::UnitAttr>();
-
- if (unit_val)
- {
- std::vector<float> data = { 0.0 };
- type = DType_FLOAT;
- TosaSerializationHandler::ConvertF32toU8(data, u8_data);
- }
-#endif
-
// Update tensor.data array with Const value attribute
+ mlir::Attribute value_attr = op.getAttr("value");
+ if (!value_attr) {
+ op.emitOpError("ERROR: tosa.const_shape doesn't have value");
+ return nullptr;
+ }
+ assert(ts->GetDtype() == DType::DType_SHAPE);
std::vector<uint8_t> u8_data;
- DType type = ts->GetDtype();
- if (type == DType_FLOAT) {
- std::vector<float> data;
- auto dense_attr = op.getAttr(llvm::StringRef("value"))
- .dyn_cast<mlir::DenseElementsAttr>();
- auto val_attr =
- op.getAttr(llvm::StringRef("value")).dyn_cast<mlir::FloatAttr>();
- if (dense_attr) {
- for (auto val : dense_attr.getValues<float>()) {
- data.push_back(val);
- }
- } else if (val_attr) {
- data.push_back((float)val_attr.getValueAsDouble());
- } else {
- op.emitOpError("Unknown const attribute");
- return nullptr;
- }
- TosaSerializationHandler::ConvertF32toU8(data, u8_data);
- } else if (type == DType_INT8) {
- std::vector<int8_t> data;
- auto dense_attr = op.getAttr(llvm::StringRef("value"))
- .dyn_cast<mlir::DenseElementsAttr>();
- auto val_attr =
- op.getAttr(llvm::StringRef("value")).dyn_cast<mlir::IntegerAttr>();
+ std::vector<int64_t> data;
+ auto dense_attr = op.getAttr(llvm::StringRef("value"))
+ .dyn_cast<mlir::DenseIntElementsAttr>();
+ if (!dense_attr) {
+ op.emitOpError("Unknown const attribute");
+ return nullptr;
+ }
- if (dense_attr) {
- for (auto val : dense_attr.getValues<int8_t>()) {
- data.push_back(val);
- }
- } else if (val_attr) {
- data.push_back(val_attr.getInt());
- } else {
- op.emitOpError("Unknown const attribute");
- return nullptr;
- }
- TosaSerializationHandler::ConvertI8toU8(data, u8_data);
- } else if (type == DType_INT16) {
- std::vector<int16_t> data;
- auto dense_attr = op.getAttr(llvm::StringRef("value"))
- .dyn_cast<mlir::DenseElementsAttr>();
- auto val_attr =
- op.getAttr(llvm::StringRef("value")).dyn_cast<mlir::IntegerAttr>();
+ for (auto valueIt : dense_attr.getValues<mlir::APInt>()) {
+ int64_t val = valueIt.getSExtValue();
+ data.push_back(val);
+ }
- if (dense_attr) {
- for (auto val : dense_attr.getValues<int16_t>()) {
- data.push_back(val);
- }
- } else if (val_attr) {
- data.push_back(val_attr.getInt());
- } else {
- op.emitOpError("Unknown const attribute");
- return nullptr;
- }
- TosaSerializationHandler::ConvertI16toU8(data, u8_data);
- } else if (type == DType_INT32) {
- std::vector<int32_t> data;
- auto dense_attr = op.getAttr(llvm::StringRef("value"))
- .dyn_cast<mlir::DenseElementsAttr>();
- auto val_attr =
- op.getAttr(llvm::StringRef("value")).dyn_cast<mlir::IntegerAttr>();
+ TosaSerializationHandler::ConvertI64toU8(data, u8_data);
- if (dense_attr) {
- for (auto val : dense_attr.getValues<int32_t>()) {
- data.push_back(val);
- }
- } else if (val_attr) {
- data.push_back(val_attr.getInt());
- } else {
- op.emitOpError("Unknown const attribute");
- return nullptr;
- }
- TosaSerializationHandler::ConvertI32toU8(data, u8_data);
- } else if (type == DType_INT48) {
- std::vector<int64_t> data;
- auto dense_attr = op.getAttr(llvm::StringRef("value"))
- .dyn_cast<mlir::DenseElementsAttr>();
- auto val_attr =
- op.getAttr(llvm::StringRef("value")).dyn_cast<mlir::IntegerAttr>();
+ ts->SetData(u8_data);
- if (dense_attr) {
- for (auto valueIt : dense_attr.getValues<mlir::APInt>()) {
- uint64_t val = valueIt.getLimitedValue();
- data.push_back(val);
- }
- } else if (val_attr) {
- data.push_back(val_attr.getInt());
- } else {
- op.emitOpError("Unknown const attribute");
- return nullptr;
- }
- TosaSerializationHandler::ConvertI48toU8(data, u8_data);
- } else if (type == DType_BOOL) {
- std::vector<bool> data;
+ TosaSerializationOperator *tyop = new TosaSerializationOperator(
+ Op_CONST_SHAPE, Attribute_NONE, nullptr, std::vector<std::string>{},
+ std::vector<std::string>{output_name});
- auto dense_attr = op.getAttr(llvm::StringRef("value"))
- .dyn_cast<mlir::DenseElementsAttr>();
- auto val_attr =
- op.getAttr(llvm::StringRef("value")).dyn_cast<mlir::BoolAttr>();
+ return tyop;
+}
- if (dense_attr) {
- for (auto val : dense_attr.getValues<bool>()) {
- data.push_back(val);
- }
- } else if (val_attr) {
- data.push_back(val_attr.getValue());
- } else {
- op.emitOpError("Unknown const attribute");
- return nullptr;
- }
+template <>
+TosaSerializationOperator *
+TosaSerializationOperatorBuilder::build<mlir::tosa::ConstOp>(
+ mlir::Operation &op) const {
+ std::string output_name = GetTensorName(op.getResult(0));
+ TosaSerializationTensor *ts =
+ block_builder->GetBlock()->GetTensorByName(output_name);
+ if (!ts) {
+ op.emitOpError(
+ "ERROR: serialization tensor must be built before building operator");
+ return nullptr;
+ }
- TosaSerializationHandler::ConvertBooltoU8(data, u8_data);
- } else {
- op.emitOpError("Unknown element type of const attribute");
+ // Update tensor.data array with Const value attribute
+ mlir::Attribute value_attr = op.getAttr("value");
+ if (!value_attr) {
+ op.emitOpError("ERROR: tosa.const doesn't have value");
+ return nullptr;
+ }
+ std::vector<uint8_t> u8_data;
+ mlir::Attribute attr = op.getAttr(llvm::StringRef("value"));
+ mlir::Type element_type =
+ llvm::cast<mlir::ShapedType>(op.getResult(0).getType()).getElementType();
+
+ if (GetDataFromAttribute(op, attr, element_type, u8_data).failed()) {
+ op.emitOpError("ERROR: GetDataFromAttribute() fails when building value of "
+ "const tensor");
return nullptr;
}
- ts->SetData(u8_data);
+ ts->SetData(u8_data);
TosaSerializationOperator *tyop = new TosaSerializationOperator(
- Op_CONST, Attribute_NONE, nullptr,
- std::vector<std::string>{}, std::vector<std::string>{output_name});
+ Op_CONST, Attribute_NONE, nullptr, std::vector<std::string>{},
+ std::vector<std::string>{output_name});
return tyop;
}
@@ -497,26 +744,13 @@ template <>
TosaSerializationOperator *
TosaSerializationOperatorBuilder::build<mlir::tosa::Conv2DOp>(
mlir::Operation &op) const {
- std::vector<int> pad, stride, dilation;
-
- auto pad_attr = op.getAttr("pad").dyn_cast<mlir::ArrayAttr>().getValue();
- for (auto &int_attr : pad_attr) {
- pad.push_back(int_attr.dyn_cast<mlir::IntegerAttr>().getInt());
- }
+ auto pad = getDenseI64ArrayAttr<int>(op.getAttr("pad"));
ASSERT_VECTOR_LENGTH(pad, 4);
- auto stride_attr =
- op.getAttr("stride").dyn_cast<mlir::ArrayAttr>().getValue();
- for (auto &int_attr : stride_attr) {
- stride.push_back(int_attr.dyn_cast<mlir::IntegerAttr>().getInt());
- }
+ auto stride = getDenseI64ArrayAttr<int>(op.getAttr("stride"));
ASSERT_VECTOR_LENGTH(stride, 2);
- auto dilation_attr =
- op.getAttr("dilation").dyn_cast<mlir::ArrayAttr>().getValue();
- for (auto &int_attr : dilation_attr) {
- dilation.push_back(int_attr.dyn_cast<mlir::IntegerAttr>().getInt());
- }
+ auto dilation = getDenseI64ArrayAttr<int>(op.getAttr("dilation"));
ASSERT_VECTOR_LENGTH(dilation, 2);
std::string input0_name = GetTensorName(op.getOperand(0));
@@ -524,14 +758,25 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::Conv2DOp>(
std::string input2_name = GetTensorName(op.getOperand(2));
std::string output_name = GetTensorName(op.getResult(0));
+ int32_t input_zp =
+ op.hasAttr("input_zp")
+ ? op.getAttr("input_zp").cast<mlir::IntegerAttr>().getInt()
+ : 0;
+ int32_t weight_zp =
+ op.hasAttr("weight_zp")
+ ? op.getAttr("weight_zp").cast<mlir::IntegerAttr>().getInt()
+ : 0;
- auto quant_info =
- op.getAttrOfType<mlir::tosa::ConvOpQuantizationAttr>("quantization_info");
+ bool local_bound =
+ op.hasAttr("local_bound")
+ ? op.getAttr("local_bound").dyn_cast<mlir::BoolAttr>().getValue()
+ : false;
- int32_t input_zp = quant_info ? quant_info.input_zp().getInt() : 0;
- int32_t weight_zp = quant_info ? quant_info.weight_zp().getInt() : 0;
+ auto acc_type = op.getAttr("acc_type").cast<mlir::TypeAttr>().getValue();
+ auto acc_dtype = Type2AccDType(acc_type);
- TosaConvAttribute attribute(pad, stride, dilation, input_zp, weight_zp);
+ TosaConvAttribute attribute(pad, stride, dilation, input_zp, weight_zp,
+ local_bound, acc_dtype);
TosaSerializationOperator *tyop = new TosaSerializationOperator(
Op_CONV2D, Attribute_ConvAttribute, &attribute,
@@ -543,28 +788,61 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::Conv2DOp>(
template <>
TosaSerializationOperator *
-TosaSerializationOperatorBuilder::build<mlir::tosa::DepthwiseConv2DOp>(
+TosaSerializationOperatorBuilder::build<mlir::tosa::Conv3DOp>(
mlir::Operation &op) const {
- std::vector<int> pad, stride, dilation;
+ auto pad = getDenseI64ArrayAttr<int>(op.getAttr("pad"));
+ ASSERT_VECTOR_LENGTH(pad, 6);
- auto pad_attr = op.getAttr("pad").dyn_cast<mlir::ArrayAttr>().getValue();
- for (auto &int_attr : pad_attr) {
- pad.push_back(int_attr.dyn_cast<mlir::IntegerAttr>().getInt());
- }
+ auto stride = getDenseI64ArrayAttr<int>(op.getAttr("stride"));
+ ASSERT_VECTOR_LENGTH(stride, 3);
+
+ auto dilation = getDenseI64ArrayAttr<int>(op.getAttr("dilation"));
+ ASSERT_VECTOR_LENGTH(dilation, 3);
+
+ std::string input0_name = GetTensorName(op.getOperand(0));
+ std::string input1_name = GetTensorName(op.getOperand(1));
+ std::string input2_name = GetTensorName(op.getOperand(2));
+ std::string output_name = GetTensorName(op.getResult(0));
+
+ int32_t input_zp =
+ op.hasAttr("input_zp")
+ ? op.getAttr("input_zp").cast<mlir::IntegerAttr>().getInt()
+ : 0;
+ int32_t weight_zp =
+ op.hasAttr("weight_zp")
+ ? op.getAttr("weight_zp").cast<mlir::IntegerAttr>().getInt()
+ : 0;
+
+ bool local_bound =
+ op.hasAttr("local_bound")
+ ? op.getAttr("local_bound").dyn_cast<mlir::BoolAttr>().getValue()
+ : false;
+
+ auto acc_type = op.getAttr("acc_type").cast<mlir::TypeAttr>().getValue();
+ auto acc_dtype = Type2AccDType(acc_type);
+
+ TosaConvAttribute attribute(pad, stride, dilation, input_zp, weight_zp,
+ local_bound, acc_dtype);
+
+ TosaSerializationOperator *tyop = new TosaSerializationOperator(
+ Op_CONV3D, Attribute_ConvAttribute, &attribute,
+ std::vector<std::string>{input0_name, input1_name, input2_name},
+ std::vector<std::string>{output_name});
+
+ return tyop;
+}
+
+template <>
+TosaSerializationOperator *
+TosaSerializationOperatorBuilder::build<mlir::tosa::DepthwiseConv2DOp>(
+ mlir::Operation &op) const {
+ auto pad = getDenseI64ArrayAttr<int>(op.getAttr("pad"));
ASSERT_VECTOR_LENGTH(pad, 4);
- auto stride_attr =
- op.getAttr("stride").dyn_cast<mlir::ArrayAttr>().getValue();
- for (auto &int_attr : stride_attr) {
- stride.push_back(int_attr.dyn_cast<mlir::IntegerAttr>().getInt());
- }
+ auto stride = getDenseI64ArrayAttr<int>(op.getAttr("stride"));
ASSERT_VECTOR_LENGTH(stride, 2);
- auto dilation_attr =
- op.getAttr("dilation").dyn_cast<mlir::ArrayAttr>().getValue();
- for (auto &int_attr : dilation_attr) {
- dilation.push_back(int_attr.dyn_cast<mlir::IntegerAttr>().getInt());
- }
+ auto dilation = getDenseI64ArrayAttr<int>(op.getAttr("dilation"));
ASSERT_VECTOR_LENGTH(dilation, 2);
std::string input0_name = GetTensorName(op.getOperand(0));
@@ -572,13 +850,25 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::DepthwiseConv2DOp>(
std::string input2_name = GetTensorName(op.getOperand(2));
std::string output_name = GetTensorName(op.getResult(0));
- auto quant_info =
- op.getAttrOfType<mlir::tosa::ConvOpQuantizationAttr>("quantization_info");
+ int32_t input_zp =
+ op.hasAttr("input_zp")
+ ? op.getAttr("input_zp").cast<mlir::IntegerAttr>().getInt()
+ : 0;
+ int32_t weight_zp =
+ op.hasAttr("weight_zp")
+ ? op.getAttr("weight_zp").cast<mlir::IntegerAttr>().getInt()
+ : 0;
+
+ bool local_bound =
+ op.hasAttr("local_bound")
+ ? op.getAttr("local_bound").dyn_cast<mlir::BoolAttr>().getValue()
+ : false;
- int32_t input_zp = quant_info ? quant_info.input_zp().getInt() : 0;
- int32_t weight_zp = quant_info ? quant_info.weight_zp().getInt() : 0;
+ auto acc_type = op.getAttr("acc_type").cast<mlir::TypeAttr>().getValue();
+ auto acc_dtype = Type2AccDType(acc_type);
- TosaConvAttribute attribute(pad, stride, dilation, input_zp, weight_zp);
+ TosaConvAttribute attribute(pad, stride, dilation, input_zp, weight_zp,
+ local_bound, acc_dtype);
TosaSerializationOperator *tyop = new TosaSerializationOperator(
Op_DEPTHWISE_CONV2D, Attribute_ConvAttribute, &attribute,
@@ -592,41 +882,39 @@ template <>
TosaSerializationOperator *
TosaSerializationOperatorBuilder::build<mlir::tosa::TransposeConv2DOp>(
mlir::Operation &op) const {
- std::vector<int> outpad, stride, dilation, output_shape;
+ auto out_pad = getDenseI64ArrayAttr<int>(op.getAttr("out_pad"));
+ ASSERT_VECTOR_LENGTH(out_pad, 4);
- auto outpad_attr =
- op.getAttr("out_pad").dyn_cast<mlir::ArrayAttr>().getValue();
- for (auto &int_attr : outpad_attr) {
- outpad.push_back(int_attr.dyn_cast<mlir::IntegerAttr>().getInt());
- }
- ASSERT_VECTOR_LENGTH(outpad, 4);
-
- auto stride_attr =
- op.getAttr("stride").dyn_cast<mlir::ArrayAttr>().getValue();
- for (auto &int_attr : stride_attr) {
- stride.push_back(int_attr.dyn_cast<mlir::IntegerAttr>().getInt());
- }
+ auto stride = getDenseI64ArrayAttr<int>(op.getAttr("stride"));
ASSERT_VECTOR_LENGTH(stride, 2);
- auto output_shape_attr =
- op.getAttr("out_shape").dyn_cast<mlir::ArrayAttr>().getValue();
- for (auto &int_attr : output_shape_attr) {
- output_shape.push_back(int_attr.dyn_cast<mlir::IntegerAttr>().getInt());
- }
- ASSERT_VECTOR_LENGTH(output_shape, 4);
-
std::string input0_name = GetTensorName(op.getOperand(0));
std::string input1_name = GetTensorName(op.getOperand(1));
std::string input2_name = GetTensorName(op.getOperand(2));
std::string output_name = GetTensorName(op.getResult(0));
- auto quant_info =
- op.getAttrOfType<mlir::tosa::ConvOpQuantizationAttr>("quantization_info");
+ int32_t input_zp =
+ op.hasAttr("input_zp")
+ ? op.getAttr("input_zp").cast<mlir::IntegerAttr>().getInt()
+ : 0;
+ int32_t weight_zp =
+ op.hasAttr("weight_zp")
+ ? op.getAttr("weight_zp").cast<mlir::IntegerAttr>().getInt()
+ : 0;
- int32_t input_zp = quant_info ? quant_info.input_zp().getInt() : 0;
- int32_t weight_zp = quant_info ? quant_info.weight_zp().getInt() : 0;
+ mlir::RankedTensorType tensor =
+ op.getOperand(0).getType().cast<mlir::RankedTensorType>();
- TosaTransposeConvAttribute attribute(outpad, stride, output_shape, input_zp, weight_zp);
+ bool local_bound =
+ op.hasAttr("local_bound")
+ ? op.getAttr("local_bound").dyn_cast<mlir::BoolAttr>().getValue()
+ : false;
+
+ auto acc_type = op.getAttr("acc_type").cast<mlir::TypeAttr>().getValue();
+ auto acc_dtype = Type2AccDType(acc_type);
+
+ TosaTransposeConvAttribute attribute(out_pad, stride, input_zp, weight_zp,
+ local_bound, acc_dtype);
TosaSerializationOperator *tyop = new TosaSerializationOperator(
Op_TRANSPOSE_CONV2D, Attribute_TransposeConvAttribute, &attribute,
@@ -645,11 +933,15 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::FullyConnectedOp>(
std::string input2_name = GetTensorName(op.getOperand(2));
std::string output_name = GetTensorName(op.getResult(0));
- auto quant_info =
- op.getAttrOfType<mlir::tosa::ConvOpQuantizationAttr>("quantization_info");
+ int32_t input_zp =
+ op.hasAttr("input_zp")
+ ? op.getAttr("input_zp").cast<mlir::IntegerAttr>().getInt()
+ : 0;
+ int32_t weight_zp =
+ op.hasAttr("weight_zp")
+ ? op.getAttr("weight_zp").cast<mlir::IntegerAttr>().getInt()
+ : 0;
- int32_t input_zp = quant_info ? quant_info.input_zp().getInt() : 0;
- int32_t weight_zp = quant_info ? quant_info.weight_zp().getInt() : 0;
TosaFullyConnectedAttribute attribute(input_zp, weight_zp);
TosaSerializationOperator *tyop = new TosaSerializationOperator(
@@ -668,11 +960,12 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::MatMulOp>(
std::string input1_name = GetTensorName(op.getOperand(1));
std::string output_name = GetTensorName(op.getResult(0));
- auto quant_info = op.getAttrOfType<mlir::tosa::MatMulOpQuantizationAttr>(
- "quantization_info");
-
- int32_t A_zp = quant_info ? quant_info.a_zp().getInt() : 0;
- int32_t B_zp = quant_info ? quant_info.b_zp().getInt() : 0;
+ int32_t A_zp = op.hasAttr("a_zp")
+ ? op.getAttr("a_zp").cast<mlir::IntegerAttr>().getInt()
+ : 0;
+ int32_t B_zp = op.hasAttr("b_zp")
+ ? op.getAttr("b_zp").cast<mlir::IntegerAttr>().getInt()
+ : 0;
TosaMatMulAttribute attribute(A_zp, B_zp);
@@ -705,23 +998,49 @@ template <>
TosaSerializationOperator *
TosaSerializationOperatorBuilder::build<mlir::tosa::ClampOp>(
mlir::Operation &op) const {
- int32_t min_int =
- op.getAttr("min_int").dyn_cast<mlir::IntegerAttr>().getInt();
- int32_t max_int =
- op.getAttr("max_int").dyn_cast<mlir::IntegerAttr>().getInt();
- float min_fp = op.getAttr("min_fp")
- .dyn_cast<mlir::FloatAttr>()
- .getValue()
- .convertToFloat();
- float max_fp = op.getAttr("max_fp")
- .dyn_cast<mlir::FloatAttr>()
- .getValue()
- .convertToFloat();
+ auto min_val_attr = op.getAttr("min_val");
+ auto max_val_attr = op.getAttr("max_val");
+
+ mlir::Type input_element_type =
+ llvm::cast<mlir::ShapedType>(op.getOperand(0).getType()).getElementType();
+ if (auto quantType = llvm::dyn_cast<mlir::quant::UniformQuantizedType>(
+ input_element_type)) {
+ input_element_type = quantType.getStorageType();
+ }
+
+ std::vector<uint8_t> min_val, max_val;
+ float min_fp, max_fp;
+ int64_t min_int, max_int;
+
+ if (input_element_type.isa<mlir::FloatType>()) {
+ min_fp =
+ mlir::cast<mlir::FloatAttr>(min_val_attr).getValue().convertToFloat();
+ max_fp =
+ mlir::cast<mlir::FloatAttr>(max_val_attr).getValue().convertToFloat();
+ min_int = max_int = 0;
+ } else {
+ assert(input_element_type.isa<mlir::IntegerType>());
+ min_int = mlir::cast<mlir::IntegerAttr>(min_val_attr).getInt();
+ max_int = mlir::cast<mlir::IntegerAttr>(max_val_attr).getInt();
+ min_fp = max_fp = 0.f;
+ }
+
+ if (GetU8DataFromIntOrFloatValue(min_int, min_fp, input_element_type, min_val)
+ .failed()) {
+ op.emitOpError("Failed to serialize min value");
+ return nullptr;
+ }
+
+ if (GetU8DataFromIntOrFloatValue(max_int, max_fp, input_element_type, max_val)
+ .failed()) {
+ op.emitOpError("Failed to serialize max value");
+ return nullptr;
+ }
std::string input_name = GetTensorName(op.getOperand(0));
std::string output_name = GetTensorName(op.getResult(0));
- TosaClampAttribute attribute(min_int, max_int, min_fp, max_fp);
+ TosaClampAttribute attribute(min_val, max_val);
TosaSerializationOperator *tyop = new TosaSerializationOperator(
Op_CLAMP, Attribute_ClampAttribute, &attribute,
@@ -767,8 +1086,27 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::ConcatOp>(
TosaAxisAttribute attribute(axis);
TosaSerializationOperator *tyop = new TosaSerializationOperator(
- Op_CONCAT, Attribute_AxisAttribute, &attribute,
- inputs, std::vector<std::string>{output_name});
+ Op_CONCAT, Attribute_AxisAttribute, &attribute, inputs,
+ std::vector<std::string>{output_name});
+
+ return tyop;
+}
+
+template <>
+TosaSerializationOperator *
+TosaSerializationOperatorBuilder::build<mlir::tosa::ConcatShapeOp>(
+ mlir::Operation &op) const {
+ std::vector<std::string> inputs;
+ for (uint32_t i = 0; i < op.getNumOperands(); i++) {
+ std::string input_name = GetTensorName(op.getOperand(i));
+ inputs.push_back(input_name);
+ }
+
+ std::string output_name = GetTensorName(op.getResult(0));
+
+ TosaSerializationOperator *tyop = new TosaSerializationOperator(
+ Op_CONCAT_SHAPE, Attribute_NONE, nullptr, inputs,
+ std::vector<std::string>{output_name});
return tyop;
}
@@ -780,13 +1118,16 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::NegateOp>(
std::string input_name = GetTensorName(op.getOperand(0));
std::string output_name = GetTensorName(op.getResult(0));
- auto quant_info = op.getAttrOfType<mlir::tosa::UnaryOpQuantizationAttr>(
- "quantization_info");
-
- int32_t input_zp = quant_info ? quant_info.input_zp().getInt() : 0;
- int32_t output_zp = quant_info ? quant_info.output_zp().getInt() : 0;
+ int32_t input1_zp =
+ op.hasAttr("input1_zp")
+ ? op.getAttr("input1_zp").cast<mlir::IntegerAttr>().getInt()
+ : 0;
+ int32_t output_zp =
+ op.hasAttr("output_zp")
+ ? op.getAttr("output_zp").cast<mlir::IntegerAttr>().getInt()
+ : 0;
- TosaNegateAttribute attribute(input_zp, output_zp);
+ TosaNegateAttribute attribute(input1_zp, output_zp);
TosaSerializationOperator *tyop = new TosaSerializationOperator(
Op_NEGATE, Attribute_NegateAttribute, &attribute,
@@ -800,20 +1141,12 @@ TosaSerializationOperator *
TosaSerializationOperatorBuilder::build<mlir::tosa::ReshapeOp>(
mlir::Operation &op) const {
std::string input_name = GetTensorName(op.getOperand(0));
+ std::string shape_name = GetTensorName(op.getOperand(1));
std::string output_name = GetTensorName(op.getResult(0));
- std::vector<int> shape;
- auto shape_attr =
- op.getAttr("new_shape").dyn_cast<mlir::ArrayAttr>().getValue();
- for (auto &int_attr : shape_attr) {
- shape.push_back(int_attr.dyn_cast<mlir::IntegerAttr>().getInt());
- }
-
- TosaReshapeAttribute attribute(shape);
-
TosaSerializationOperator *tyop = new TosaSerializationOperator(
- Op_RESHAPE, Attribute_ReshapeAttribute, &attribute,
- std::vector<std::string>{input_name},
+ Op_RESHAPE, Attribute_NONE, nullptr,
+ std::vector<std::string>{input_name, shape_name},
std::vector<std::string>{output_name});
return tyop;
@@ -824,38 +1157,80 @@ TosaSerializationOperator *
TosaSerializationOperatorBuilder::build<mlir::tosa::PadOp>(
mlir::Operation &op) const {
std::string input_name = GetTensorName(op.getOperand(0));
+ std::string padding_name = GetTensorName(op.getOperand(1));
std::string output_name = GetTensorName(op.getResult(0));
+ auto pad_op = llvm::cast<mlir::tosa::PadOp>(op);
- // Match padding tensor as compile-time constant attribute
- // TODO: fix when MLIR dialect changes
- mlir::ElementsAttr paddings_elems;
- if (!matchPattern(op.getOperand(1), m_Constant(&paddings_elems)))
- return nullptr;
+ auto input_zp_attr = pad_op.getInputZpAttr();
+ // pad_const includes the zero point if the tensor uses a zero point.
+ int32_t pad_const_int = input_zp_attr ? input_zp_attr.getInt() : 0;
+ float pad_const_fp = 0.f;
+
+ if (auto tensor = pad_op.getPadConst()) {
+ // Match pad_const tensor as compile-time constant attribute if present.
+ mlir::DenseElementsAttr attr;
+ if (!matchPattern(tensor, m_Constant(&attr)))
+ return nullptr;
+
+ assert(attr.getNumElements() == 1);
+ auto elementTy = attr.getElementType();
+
+ if (elementTy.isa<mlir::IntegerType>()) {
+ pad_const_int = (attr.getValues<mlir::APInt>()[0]).getSExtValue();
+ } else if (elementTy.isa<mlir::FloatType>()) {
+ pad_const_fp = (attr.getValues<mlir::APFloat>()[0]).convertToFloat();
+ } else {
+ op.emitOpError("Unknown const attribute");
+ return nullptr;
+ }
+ }
+
+ std::vector<uint8_t> pad_const;
+ mlir::Type input_element_type =
+ llvm::cast<mlir::ShapedType>(op.getOperand(0).getType()).getElementType();
- std::vector<int> paddings;
- for (int32_t val : paddings_elems.getValues<int32_t>()) {
- paddings.push_back(val);
+ if (GetU8DataFromIntOrFloatValue(pad_const_int, pad_const_fp,
+ input_element_type, pad_const)
+ .failed()) {
+ op.emitOpError("Failed to serialize pad_const value");
+ return nullptr;
}
- TosaPadAttribute attribute(paddings, 0 /* pad_const_int */,
- 0.0f /* pad_const_fp */);
+ TosaPadAttribute attribute(pad_const);
TosaSerializationOperator *tyop = new TosaSerializationOperator(
Op_PAD, Attribute_PadAttribute, &attribute,
- std::vector<std::string>{input_name},
+ std::vector<std::string>{input_name, padding_name},
std::vector<std::string>{output_name});
return tyop;
}
template <>
TosaSerializationOperator *
+TosaSerializationOperatorBuilder::build<mlir::tosa::DimOp>(
+ mlir::Operation &op) const {
+ std::string input_name = GetTensorName(op.getOperand(0));
+ std::string output_name = GetTensorName(op.getResult(0));
+
+ int32_t axis = op.getAttr("axis").dyn_cast<mlir::IntegerAttr>().getInt();
+ TosaAxisAttribute attribute(axis);
+
+ TosaSerializationOperator *tyop =
+ new TosaSerializationOperator(Op_DIM, Attribute_AxisAttribute, &attribute,
+ std::vector<std::string>{input_name},
+ std::vector<std::string>{output_name});
+
+ return tyop;
+}
+
+template <>
+TosaSerializationOperator *
TosaSerializationOperatorBuilder::build<mlir::tosa::TransposeOp>(
mlir::Operation &op) const {
std::string input_name = GetTensorName(op.getOperand(0));
std::string output_name = GetTensorName(op.getResult(0));
// Match perm tensor as compile-time constant attribute
- // TODO: fix when MLIR dialect changes
mlir::ElementsAttr perm_elems;
if (!matchPattern(op.getOperand(1), m_Constant(&perm_elems)))
return nullptr;
@@ -879,26 +1254,14 @@ template <>
TosaSerializationOperator *
TosaSerializationOperatorBuilder::build<mlir::tosa::SliceOp>(
mlir::Operation &op) const {
- std::vector<int> start, size;
- auto begin_attr = op.getAttr("start").dyn_cast<mlir::ArrayAttr>().getValue();
- auto size_attr = op.getAttr("size").dyn_cast<mlir::ArrayAttr>().getValue();
-
- for (auto &int_attr : begin_attr) {
- start.push_back(int_attr.dyn_cast<mlir::IntegerAttr>().getInt());
- }
-
- for (auto &int_attr : size_attr) {
- size.push_back(int_attr.dyn_cast<mlir::IntegerAttr>().getInt());
- }
-
- TosaSliceAttribute attribute(start, size);
-
std::string input_name = GetTensorName(op.getOperand(0));
+ std::string start_name = GetTensorName(op.getOperand(1));
+ std::string size_name = GetTensorName(op.getOperand(2));
std::string output_name = GetTensorName(op.getResult(0));
TosaSerializationOperator *tyop = new TosaSerializationOperator(
- Op_SLICE, Attribute_SliceAttribute, &attribute,
- std::vector<std::string>{input_name},
+ Op_SLICE, Attribute_NONE, nullptr,
+ std::vector<std::string>{input_name, start_name, size_name},
std::vector<std::string>{output_name});
return tyop;
@@ -908,21 +1271,13 @@ template <>
TosaSerializationOperator *
TosaSerializationOperatorBuilder::build<mlir::tosa::TileOp>(
mlir::Operation &op) const {
- std::string input_name = GetTensorName(op.getOperand(0));
+ std::string input0_name = GetTensorName(op.getOperand(0));
+ std::string input1_name = GetTensorName(op.getOperand(1));
std::string output_name = GetTensorName(op.getResult(0));
- std::vector<int> multiples;
- auto multiples_attr =
- op.getAttr("multiples").dyn_cast<mlir::ArrayAttr>().getValue();
- for (auto &int_attr : multiples_attr) {
- multiples.push_back(int_attr.dyn_cast<mlir::IntegerAttr>().getInt());
- }
-
- TosaTileAttribute attribute(multiples);
-
TosaSerializationOperator *tyop = new TosaSerializationOperator(
- Op_TILE, Attribute_TileAttribute, &attribute,
- std::vector<std::string>{input_name},
+ Op_TILE, Attribute_NONE, nullptr,
+ std::vector<std::string>{input0_name, input1_name},
std::vector<std::string>{output_name});
return tyop;
@@ -966,60 +1321,21 @@ TosaSerializationOperator *
TosaSerializationOperatorBuilder::build<mlir::tosa::ResizeOp>(
mlir::Operation &op) const {
std::string input_name = GetTensorName(op.getOperand(0));
+ std::string scale_name = GetTensorName(op.getOperand(1));
+ std::string offset_name = GetTensorName(op.getOperand(2));
+ std::string border_name = GetTensorName(op.getOperand(3));
std::string output_name = GetTensorName(op.getResult(0));
- std::vector<int> output_size;
- auto output_size_attr =
- op.getAttr("output_size").dyn_cast<mlir::ArrayAttr>().getValue();
- for (auto &int_attr : output_size_attr) {
- output_size.push_back(int_attr.dyn_cast<mlir::IntegerAttr>().getInt());
- }
- ASSERT_VECTOR_LENGTH(output_size, 2);
-
- std::vector<int> stride;
- auto stride_attr =
- op.getAttr("stride").dyn_cast<mlir::ArrayAttr>().getValue();
- for (auto &int_attr : stride_attr) {
- stride.push_back(int_attr.dyn_cast<mlir::IntegerAttr>().getInt());
- }
- ASSERT_VECTOR_LENGTH(stride, 2);
-
- std::vector<int> offset;
- auto offset_attr =
- op.getAttr("offset").dyn_cast<mlir::ArrayAttr>().getValue();
- for (auto &int_attr : offset_attr) {
- offset.push_back(int_attr.dyn_cast<mlir::IntegerAttr>().getInt());
- }
- ASSERT_VECTOR_LENGTH(offset, 2);
-
- int32_t shift = op.getAttr("shift").dyn_cast<mlir::IntegerAttr>().getInt();
-
- std::vector<float> stride_fp;
- auto stride_fp_attr =
- op.getAttr("stride_fp").dyn_cast<mlir::ArrayAttr>().getValue();
- for (auto &fp_attr : stride_fp_attr) {
- stride_fp.push_back(fp_attr.dyn_cast<mlir::FloatAttr>().getValueAsDouble());
- }
- ASSERT_VECTOR_LENGTH(stride_fp, 2);
-
- std::vector<float> offset_fp;
- auto offset_fp_attr =
- op.getAttr("offset_fp").dyn_cast<mlir::ArrayAttr>().getValue();
- for (auto &fp_attr : offset_fp_attr) {
- offset_fp.push_back(fp_attr.dyn_cast<mlir::FloatAttr>().getValueAsDouble());
- }
- ASSERT_VECTOR_LENGTH(offset_fp, 2);
-
auto mode_str =
op.getAttr("mode").dyn_cast<mlir::StringAttr>().getValue().str();
ResizeMode mode = ResizeModeStr2Enum(mode_str);
- TosaResizeAttribute attribute(output_size, stride, offset, shift, stride_fp,
- offset_fp, mode);
+ TosaResizeAttribute attribute({}, {}, {}, mode);
TosaSerializationOperator *tyop = new TosaSerializationOperator(
Op_RESIZE, Attribute_ResizeAttribute, &attribute,
- std::vector<std::string>{input_name},
+ std::vector<std::string>{input_name, scale_name, offset_name,
+ border_name},
std::vector<std::string>{output_name});
return tyop;
@@ -1048,17 +1364,15 @@ template <>
TosaSerializationOperator *
TosaSerializationOperatorBuilder::build<mlir::tosa::MulOp>(
mlir::Operation &op) const {
- std::string input0_name = GetTensorName(op.getOperand(0));
- std::string input1_name = GetTensorName(op.getOperand(1));
- std::string output_name = GetTensorName(op.getResult(0));
-
- int32_t shift = op.getAttr("shift").dyn_cast<mlir::IntegerAttr>().getInt();
- TosaMulAttribute attribute(shift);
+ mlir::tosa::MulOp mul_op = mlir::cast<mlir::tosa::MulOp>(op);
+ std::string input0_name = GetTensorName(mul_op.getInput1());
+ std::string input1_name = GetTensorName(mul_op.getInput2());
+ std::string output_name = GetTensorName(mul_op.getOutput());
+ std::string shift_name = GetTensorName(mul_op.getShift());
TosaSerializationOperator *tyop = new TosaSerializationOperator(
- Op_MUL, Attribute_MulAttribute, &attribute,
- std::vector<std::string>{input0_name, input1_name},
+ Op_MUL, Attribute_NONE, nullptr, {input0_name, input1_name, shift_name},
std::vector<std::string>{output_name});
return tyop;
@@ -1078,8 +1392,7 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::ArithmeticRightShiftOp>(
TosaSerializationOperator *tyop = new TosaSerializationOperator(
Op_ARITHMETIC_RIGHT_SHIFT, Attribute_ArithmeticRightShiftAttribute,
- &attribute,
- std::vector<std::string>{input0_name, input1_name},
+ &attribute, std::vector<std::string>{input0_name, input1_name},
std::vector<std::string>{output_name});
return tyop;
@@ -1093,7 +1406,6 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::TableOp>(
std::string output_name = GetTensorName(op.getResult(0));
// Match table tensor as compile-time constant attribute
- // TODO: fix when MLIR dialect changes
mlir::ElementsAttr table_elems;
if (!matchPattern(op.getOperand(1), m_Constant(&table_elems)))
return nullptr;
@@ -1127,28 +1439,27 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::RescaleOp>(
bool per_channel =
op.getAttr("per_channel").dyn_cast<mlir::BoolAttr>().getValue();
- std::vector<int> multiplier, shift;
- auto multiplier_attr =
- op.getAttr("multiplier").dyn_cast<mlir::ArrayAttr>().getValue();
- auto shift_attr = op.getAttr("shift").dyn_cast<mlir::ArrayAttr>().getValue();
+ bool input_unsigned =
+ op.getAttr("input_unsigned").dyn_cast<mlir::BoolAttr>().getValue();
+ bool output_unsigned =
+ op.getAttr("output_unsigned").dyn_cast<mlir::BoolAttr>().getValue();
- for (auto &int_attr : multiplier_attr) {
- multiplier.push_back(int_attr.dyn_cast<mlir::IntegerAttr>().getInt());
- }
+ auto input = op.getOperand(0);
+ auto input_ty = input.getType().cast<mlir::RankedTensorType>();
+ auto output = op.getResult(0);
+ auto output_ty = output.getType().cast<mlir::RankedTensorType>();
- for (auto &int_attr : shift_attr) {
- shift.push_back(int_attr.dyn_cast<mlir::IntegerAttr>().getInt());
- }
+ std::string input_name = GetTensorName(input);
+ std::string multiplier_name = GetTensorName(op.getOperand(1));
+ std::string shift_name = GetTensorName(op.getOperand(2));
+ std::string output_name = GetTensorName(output);
- std::string input_name = GetTensorName(op.getOperand(0));
- std::string output_name = GetTensorName(op.getResult(0));
-
- TosaRescaleAttribute attribute(input_zp, output_zp, multiplier, shift,
- scale32, double_round, per_channel);
+ TosaRescaleAttribute attribute(input_zp, output_zp, scale32, double_round,
+ per_channel, input_unsigned, output_unsigned);
TosaSerializationOperator *tyop = new TosaSerializationOperator(
Op_RESCALE, Attribute_RescaleAttribute, &attribute,
- std::vector<std::string>{input_name},
+ std::vector<std::string>{input_name, multiplier_name, shift_name},
std::vector<std::string>{output_name});
return tyop;
@@ -1161,39 +1472,77 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::CustomOp>(
std::string input_name = GetTensorName(op.getOperand(0));
std::string output_name = GetTensorName(op.getResult(0));
+ const std::string implementation_attrs = op.getAttr("implementation_attrs")
+ .cast<mlir::StringAttr>()
+ .getValue()
+ .str();
+
+ std::vector<uint8_t> attrs_data(implementation_attrs.size());
+ memcpy(attrs_data.data(), implementation_attrs.data(), attrs_data.size());
+ TosaCustomAttribute attribute(
+ op.getAttr("operator_name").cast<mlir::StringAttr>().getValue().str(),
+ op.getAttr("domain_name").cast<mlir::StringAttr>().getValue().str(),
+ attrs_data);
+
TosaSerializationOperator *tyop = new TosaSerializationOperator(
- Op_CUSTOM, Attribute_NONE, nullptr,
+ Op_CUSTOM, Attribute_CustomAttribute, &attribute,
std::vector<std::string>{input_name},
std::vector<std::string>{output_name});
return tyop;
}
+namespace {
+
+// serialize a region and all its blocks, and return region's return values
+TosaSerializationRegion *
+BuildRegion(mlir::Region &region, const std::string region_name,
+ const bool isolated_from_above,
+ TosaSerializationRegionBuilder *curr_region_builder,
+ TosaSerializationHandler *tsh,
+ std::vector<mlir::Value> &return_values, bool is_top = false) {
+ TosaSerializationRegion *ser_region =
+ new TosaSerializationRegion(region_name, {});
+ assert(ser_region);
+ tsh->GetRegions().push_back(ser_region);
+
+ TosaSerializationRegionBuilder *parent_value_scope =
+ isolated_from_above ? nullptr : curr_region_builder;
+
+ TosaSerializationRegionBuilder region_builder(ser_region, &region,
+ parent_value_scope, tsh);
+ if (region_builder.BuildAllBlocksInRegion(is_top, return_values).failed()) {
+ return nullptr;
+ }
+ return ser_region;
+}
+
+static int input_tensor_index = 0;
+static int intermediate_tensor_index = 0;
+static int output_tensor_index = 0;
+
+} // namespace
+
template <>
TosaSerializationOperator *
TosaSerializationOperatorBuilder::build<mlir::tosa::IfOp>(
mlir::Operation &op) const {
+ const std::string op_name = op.getName().getStringRef().str();
+ const bool isolated_from_above =
+ op.hasTrait<mlir::OpTrait::IsIsolatedFromAbove>();
+ auto curr_region_builder = GetRegionBuilder();
std::vector<std::string> input_names, output_names;
+ std::vector<mlir::Value> then_yields, else_yields;
+ auto tsh = GetTsh();
mlir::Region &then_region = op.getRegion(0);
mlir::Region &else_region = op.getRegion(1);
- std::vector<mlir::Value> then_yields, else_yields;
- TosaSerializationBasicBlock *then_block = nullptr;
- TosaSerializationBasicBlock *else_block = nullptr;
-
- // Building then branch block
- std::string then_block_name =
- "bb" + std::to_string(block_builder->GetTsh()->GetBlocks().size());
- then_block = new TosaSerializationBasicBlock(
- then_block_name, std::vector<TosaSerializationOperator *>(),
- std::vector<TosaSerializationTensor *>(), std::vector<std::string>(),
- std::vector<std::string>());
- assert(then_block);
- block_builder->GetTsh()->GetBlocks().push_back(then_block);
-
- TosaSerializationBlockBuilder then_block_builder(
- then_block, block_builder->GetTsh(), &then_region);
- if (then_block_builder.BuildAllOpsInRegion(then_yields).failed()) {
+
+ const std::string then_region_name = op_name + "_then_region";
+ TosaSerializationRegion *ser_then_region =
+ BuildRegion(then_region, then_region_name, isolated_from_above,
+ curr_region_builder, tsh, then_yields);
+ if (!ser_then_region) {
return nullptr;
}
if (then_yields.size() != op.getNumResults()) {
@@ -1202,19 +1551,11 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::IfOp>(
return nullptr;
}
- // Building else branch block
- std::string else_block_name =
- "bb" + std::to_string(block_builder->GetTsh()->GetBlocks().size());
- else_block = new TosaSerializationBasicBlock(
- else_block_name, std::vector<TosaSerializationOperator *>(),
- std::vector<TosaSerializationTensor *>(), std::vector<std::string>(),
- std::vector<std::string>());
- assert(else_block);
- block_builder->GetTsh()->GetBlocks().push_back(else_block);
-
- TosaSerializationBlockBuilder else_block_builder(
- else_block, block_builder->GetTsh(), &else_region);
- if (else_block_builder.BuildAllOpsInRegion(else_yields).failed()) {
+ const std::string else_region_name = op_name + "_else_region";
+ TosaSerializationRegion *ser_else_region =
+ BuildRegion(else_region, else_region_name, isolated_from_above,
+ curr_region_builder, tsh, else_yields);
+ if (!ser_else_region) {
return nullptr;
}
if (else_yields.size() != op.getNumResults()) {
@@ -1223,7 +1564,7 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::IfOp>(
return nullptr;
}
- TosaCondIfAttribute attribute(then_block->GetName(), else_block->GetName());
+ TosaCondIfAttribute attribute(then_region_name, else_region_name);
for (size_t i = 0; i < op.getNumOperands(); i++) {
std::string input_name = GetTensorName(op.getOperand(i));
@@ -1235,9 +1576,9 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::IfOp>(
output_names.push_back(output_name);
}
- TosaSerializationOperator *tyop = new TosaSerializationOperator(
- Op_COND_IF, Attribute_CondIfAttribute, &attribute,
- input_names, output_names);
+ TosaSerializationOperator *tyop =
+ new TosaSerializationOperator(Op_COND_IF, Attribute_CondIfAttribute,
+ &attribute, input_names, output_names);
return tyop;
}
@@ -1246,27 +1587,22 @@ template <>
TosaSerializationOperator *
TosaSerializationOperatorBuilder::build<mlir::tosa::WhileOp>(
mlir::Operation &op) const {
+ const std::string op_name = op.getName().getStringRef().str();
+ const bool isolated_from_above =
+ op.hasTrait<mlir::OpTrait::IsIsolatedFromAbove>();
+ auto curr_region_builder = GetRegionBuilder();
std::vector<std::string> input_names, output_names;
+ auto tsh = GetTsh();
mlir::Region &cond_region = op.getRegion(0);
mlir::Region &body_region = op.getRegion(1);
std::vector<mlir::Value> cond_yields, body_yields;
- TosaSerializationBasicBlock *cond_block = nullptr;
- TosaSerializationBasicBlock *body_block = nullptr;
-
- // Building cond branch block
- std::string cond_block_name =
- "bb" + std::to_string(block_builder->GetTsh()->GetBlocks().size());
- cond_block = new TosaSerializationBasicBlock(
- cond_block_name, std::vector<TosaSerializationOperator *>(),
- std::vector<TosaSerializationTensor *>(), std::vector<std::string>(),
- std::vector<std::string>());
- assert(cond_block);
- block_builder->GetTsh()->GetBlocks().push_back(cond_block);
-
- TosaSerializationBlockBuilder cond_block_builder(
- cond_block, block_builder->GetTsh(), &cond_region);
- if (cond_block_builder.BuildAllOpsInRegion(cond_yields).failed()) {
+
+ const std::string cond_region_name = op_name + "_cond_region";
+ TosaSerializationRegion *ser_cond_region =
+ BuildRegion(cond_region, cond_region_name, isolated_from_above,
+ curr_region_builder, tsh, cond_yields);
+ if (!ser_cond_region) {
return nullptr;
}
if (cond_yields.size() != 1) {
@@ -1274,19 +1610,11 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::WhileOp>(
return nullptr;
}
- // Building body branch block
- std::string body_block_name =
- "bb" + std::to_string(block_builder->GetTsh()->GetBlocks().size());
- body_block = new TosaSerializationBasicBlock(
- body_block_name, std::vector<TosaSerializationOperator *>(),
- std::vector<TosaSerializationTensor *>(), std::vector<std::string>(),
- std::vector<std::string>());
- assert(body_block);
- block_builder->GetTsh()->GetBlocks().push_back(body_block);
-
- TosaSerializationBlockBuilder body_block_builder(
- body_block, block_builder->GetTsh(), &body_region);
- if (body_block_builder.BuildAllOpsInRegion(body_yields).failed()) {
+ const std::string body_region_name = op_name + "_body_region";
+ TosaSerializationRegion *ser_body_region =
+ BuildRegion(body_region, body_region_name, isolated_from_above,
+ curr_region_builder, tsh, body_yields);
+ if (!ser_body_region) {
return nullptr;
}
if (body_yields.size() != op.getNumResults()) {
@@ -1295,8 +1623,7 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::WhileOp>(
return nullptr;
}
- TosaWhileLoopAttribute attribute(cond_block->GetName(),
- body_block->GetName());
+ TosaWhileLoopAttribute attribute(cond_region_name, body_region_name);
for (size_t i = 0; i < op.getNumOperands(); i++) {
std::string input_name = GetTensorName(op.getOperand(i));
@@ -1308,135 +1635,286 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::WhileOp>(
output_names.push_back(output_name);
}
+ TosaSerializationOperator *tyop =
+ new TosaSerializationOperator(Op_WHILE_LOOP, Attribute_WhileLoopAttribute,
+ &attribute, input_names, output_names);
+
+ return tyop;
+}
+
+template <>
+TosaSerializationOperator *
+TosaSerializationOperatorBuilder::build<mlir::tosa::RFFT2dOp>(
+ mlir::Operation &op) const {
+ std::string input_name = GetTensorName(op.getOperand(0));
+ std::string output_real_name = GetTensorName(op.getResult(0));
+ std::string output_imag_name = GetTensorName(op.getResult(1));
+
+ bool local_bound =
+ op.hasAttr("local_bound")
+ ? op.getAttr("local_bound").dyn_cast<mlir::BoolAttr>().getValue()
+ : false;
+
+ TosaRFFTAttribute attribute(local_bound);
+
+ TosaSerializationOperator *tyop = new TosaSerializationOperator(
+ Op_RFFT2D, Attribute_RFFTAttribute, &attribute,
+ std::vector<std::string>{input_name},
+ std::vector<std::string>{output_real_name, output_imag_name});
+
+ return tyop;
+}
+
+template <>
+TosaSerializationOperator *
+TosaSerializationOperatorBuilder::build<mlir::tosa::FFT2dOp>(
+ mlir::Operation &op) const {
+
+ bool inverse = op.getAttr("inverse").dyn_cast<mlir::BoolAttr>().getValue();
+
+ bool local_bound =
+ op.hasAttr("local_bound")
+ ? op.getAttr("local_bound").dyn_cast<mlir::BoolAttr>().getValue()
+ : false;
+
+ std::string input_real_name = GetTensorName(op.getOperand(0));
+ std::string input_imag_name = GetTensorName(op.getOperand(1));
+ std::string output_real_name = GetTensorName(op.getResult(0));
+ std::string output_imag_name = GetTensorName(op.getResult(1));
+
+ TosaFFTAttribute attribute(inverse, local_bound);
+
TosaSerializationOperator *tyop = new TosaSerializationOperator(
- Op_WHILE_LOOP, Attribute_WhileLoopAttribute, &attribute,
- input_names, output_names);
+ Op_FFT2D, Attribute_FFTAttribute, &attribute,
+ std::vector<std::string>{input_real_name, input_imag_name},
+ std::vector<std::string>{output_real_name, output_imag_name});
+
+ return tyop;
+}
+
+template <>
+TosaSerializationOperator *
+TosaSerializationOperatorBuilder::build<mlir::tosa::VariableReadOp>(
+ mlir::Operation &op) const {
+
+ std::string input_name = GetVariableTensorName(&op);
+ std::string output_name = GetTensorName(op.getResult(0));
+
+ TosaSerializationOperator *tyop =
+ new TosaSerializationOperator(Op_IDENTITY, Attribute_NONE, nullptr,
+ std::vector<std::string>{input_name},
+ std::vector<std::string>{output_name});
+
+ return tyop;
+}
+
+template <>
+TosaSerializationOperator *
+TosaSerializationOperatorBuilder::build<mlir::tosa::VariableWriteOp>(
+ mlir::Operation &op) const {
+
+ std::string input_name = GetTensorName(op.getOperand(0));
+ std::string output_name = GetVariableTensorName(&op);
+
+ TosaSerializationOperator *tyop =
+ new TosaSerializationOperator(Op_IDENTITY, Attribute_NONE, nullptr,
+ std::vector<std::string>{input_name},
+ std::vector<std::string>{output_name});
return tyop;
}
/* End translating TOSA operator */
+mlir::LogicalResult TosaSerializationRegionBuilder::BuildAllBlocksInRegion(
+ bool is_top, std::vector<mlir::Value> &return_values) {
+ std::string region_name = ser_region->GetName();
+ int block_index = 0;
+ for (auto &block : this->region->getBlocks()) {
+ // must name first block of top region "main"
+ const std::string block_name =
+ (is_top && block_index == 0)
+ ? "main"
+ : (region_name + "_bb" + std::to_string(block_index++));
+ TosaSerializationBasicBlock *ser_block = new TosaSerializationBasicBlock(
+ block_name, region_name, std::vector<TosaSerializationOperator *>(),
+ std::vector<TosaSerializationTensor *>(), std::vector<std::string>(),
+ std::vector<std::string>());
+
+ // build the block
+ TosaSerializationBlockBuilder block_builder(ser_block, this, &block);
+ // Region Builders need access to block builders
+ block_builders.push_back(&block_builder);
+
+ if (block_builder.BuildAllOpsInBlock(return_values).failed()) {
+ return mlir::failure();
+ }
+
+ if (return_values.empty()) {
+ llvm::errs() << "Warning: graph doesn't have return values\n";
+ }
-mlir::LogicalResult TosaSerializationBlockBuilder::BuildAllOpsInRegion(
+ // Add serialized block to serialized region
+ ser_region->GetBlocks().push_back(ser_block);
+ }
+
+ return mlir::success();
+}
+
+mlir::LogicalResult TosaSerializationBlockBuilder::BuildAllOpsInBlock(
std::vector<mlir::Value> &return_values) {
TosaSerializationOperator *ser_operator = nullptr;
TosaSerializationTensor *ser_tensor = nullptr;
size_t num_blocks_in_region = 0;
- static int input_tensor_index = 0;
- static int intermediate_tensor_index = 0;
- static int output_tensor_index = 0;
TosaSerializationOperatorBuilder op_builder(this);
- for (auto &bb : region->getBlocks()) {
- num_blocks_in_region++;
-
- if (num_blocks_in_region > 1) {
- llvm::errs() << "Invalid MLIR: multiple blocks in a region\n";
- return mlir::failure();
- }
-
- // We always have one block for each region right now
- assert(bb.isEntryBlock());
-
- // Specify block input tensor name
- for (auto args : bb.getArguments()) {
- std::string block_input_name =
- "TosaInput_" + std::to_string(input_tensor_index++);
- block->GetInputs().push_back(block_input_name);
- tensor_map[args] = block_input_name;
- input_tensor_map[args] = block_input_name;
- }
+ // Specify block input tensor name
+ for (auto args : block->getArguments()) {
+ std::string block_input_name =
+ "TosaInput_" + std::to_string(input_tensor_index++);
+ ser_block->GetInputs().push_back(block_input_name);
+ tensor_map[args] = block_input_name;
+ input_tensor_map[args] = block_input_name;
+ }
+
+ // Build tensor_map
+ for (auto &op : block->getOperations()) {
+ if (llvm::isa<mlir::tosa::VariableOp>(op)) {
+ RegisterVariableOp(op);
+ } else if (!(llvm::isa<mlir::tosa::YieldOp>(op) ||
+ llvm::isa<mlir::func::ReturnOp>(op) ||
+ llvm::isa<mlir::tensor::CastOp>(op))) {
+ for (uint32_t i = 0; i < op.getNumResults(); i++) {
+ std::string intermediate_tensor_name =
+ "layer_" + std::to_string(intermediate_tensor_index++);
+ tensor_map[op.getResult(i)] = intermediate_tensor_name;
+ }
+ } else {
+ if (llvm::isa<mlir::tensor::CastOp>(op))
+ continue;
+ // Override return tensor name
+ for (auto val : op.getOperands()) {
+ // Workaround to skip mlir::tensor::CastOp before return
+ mlir::Operation *val_defining_op = val.getDefiningOp();
+ if (val_defining_op) {
+ if (llvm::isa<mlir::tensor::CastOp>(*val_defining_op))
+ val = val_defining_op->getOperand(0);
+ }
- // Build tensor_map
- for (auto &op : bb) {
- if (!(llvm::isa<mlir::tosa::YieldOp>(op) ||
- llvm::isa<mlir::func::ReturnOp>(op) ||
- llvm::isa<mlir::tensor::CastOp>(op))) {
- for (uint32_t i = 0; i < op.getNumResults(); i++) {
- std::string intermediate_tensor_name =
- "layer_" + std::to_string(intermediate_tensor_index++);
- tensor_map[op.getResult(i)] = intermediate_tensor_name;
+ // Sanity check. This mlir::Value should be built in map since graph
+ // is DAG
+ if (tensor_map.find(val) == tensor_map.end()) {
+ llvm::errs() << "ERROR: Can't find built mlir::Value key.\n";
+ return mlir::failure();
}
- } else {
- if (llvm::isa<mlir::tensor::CastOp>(op))
- continue;
- // Override return tensor name
- for (auto val : op.getOperands()) {
- // Workaround to skip mlir::tensor::CastOp before return
- mlir::Operation *val_defining_op = val.getDefiningOp();
- if (val_defining_op) {
- if (llvm::isa<mlir::tensor::CastOp>(*val_defining_op))
- val = val_defining_op->getOperand(0);
- }
-
- // Sanity check. This mlir::Value should be built in map since graph
- // is DAG
- if (tensor_map.find(val) == tensor_map.end()) {
- llvm::errs() << "ERROR: Can't find built mlir::Value key.\n";
- return mlir::failure();
- }
-
- // If returned value is block input, short-circuit the tensor name
- // Otherwise, build a new output name and override the origin tensor
- // name
- if (input_tensor_map.find(val) != input_tensor_map.end()) {
- block->GetOutputs().push_back(input_tensor_map[val]);
- return_values.push_back(val);
- } else {
- std::string output_name =
- "TosaOutput_" + std::to_string(output_tensor_index++);
- tensor_map[val] = output_name;
- block->GetOutputs().push_back(output_name);
- return_values.push_back(val);
- }
+
+ // If returned value is block input, short-circuit the tensor name
+ // Otherwise, build a new output name and override the origin tensor
+ // name
+ if (input_tensor_map.find(val) != input_tensor_map.end()) {
+ ser_block->GetOutputs().push_back(input_tensor_map[val]);
+ return_values.push_back(val);
+ } else {
+ std::string output_name =
+ "TosaOutput_" + std::to_string(output_tensor_index++);
+ tensor_map[val] = output_name;
+ ser_block->GetOutputs().push_back(output_name);
+ return_values.push_back(val);
}
}
}
+ }
- // Build tensor
+ // Build variable tensor
+ for (auto pair : variable_tensor_op_map) {
+ mlir::Operation *op = pair.second;
+ mlir::Value val = op->getResult(0);
+ mlir::RankedTensorType tensor_type = op->getAttr("type")
+ .cast<mlir::TypeAttr>()
+ .getValue()
+ .cast<mlir::RankedTensorType>();
- // The tensor_map is sorted by hashed mlir::Value types.
- // For serialization, sort tensors alphabetically by name for a
- // deterministic and human-friendly ordering.
- std::map<std::string, mlir::Value> tensor_name_sort;
- for (auto pair : tensor_map)
- tensor_name_sort[pair.second] = pair.first;
+ std::string variable_mlir_name =
+ op->getAttr("name").cast<mlir::StringAttr>().getValue().str();
- for (auto pair : tensor_name_sort) {
- ser_tensor = BuildTosaSerializationTensor(pair.second /* val */,
- pair.first /* name */);
- if (!ser_tensor) {
- llvm::errs() << "ERROR: Failed to build TosaSerializationTensor\n";
- return mlir::failure();
- }
- block->GetTensors().push_back(ser_tensor);
+ ser_tensor = BuildTosaSerializationVariableTensor(
+ tensor_type /* tensor_type */, pair.first /* flatbuffer name */,
+ variable_mlir_name);
+ if (!ser_tensor) {
+ llvm::errs() << "ERROR: Failed to build TosaSerializationTensor\n";
+ return mlir::failure();
}
-
- // Build operator
- for (auto &op : bb) {
- if (llvm::isa<mlir::tosa::YieldOp>(op) ||
- llvm::isa<mlir::func::ReturnOp>(op) ||
- llvm::isa<mlir::tensor::CastOp>(op))
- continue;
- ser_operator = BuildTosaSerializationOperator(op_builder, op);
- if (!ser_operator) {
- llvm::errs() << "ERROR: Failed to build TosaSerializationOperator\n";
+ // Initialize if "initial_value" attribute exists. If not, set data to all
+ // zeros
+ mlir::Attribute initial_value = op->getAttr("initial_value");
+ std::vector<uint8_t> u8_data;
+ if (initial_value) {
+ if (initial_value.isa<mlir::DenseElementsAttr>()) {
+ if (op_builder
+ .GetDataFromAttribute(*op, initial_value,
+ tensor_type.getElementType(), u8_data)
+ .failed()) {
+ llvm::errs() << "ERROR: GetDataFromAttribute() fails when building "
+ "initial_value of variable tensor\n";
+ return mlir::failure();
+ }
+ } else {
+ llvm::errs() << "ERROR: Unknown initial_value attribute type\n";
return mlir::failure();
}
- block->GetOperators().push_back(ser_operator);
+ } else {
+ TosaSerializationHandler::ForceAlignTensorData(u8_data);
}
+ ser_tensor->SetData(u8_data);
+ ser_block->GetTensors().push_back(ser_tensor);
+ }
+
+ // Build tensor
+
+ // The tensor_map is sorted by hashed mlir::Value types.
+ // For serialization, sort tensors alphabetically by name for a
+ // deterministic and human-friendly ordering.
+ std::map<std::string, mlir::Value> tensor_name_sort;
+ for (auto pair : tensor_map)
+ tensor_name_sort[pair.second] = pair.first;
+
+ for (auto pair : tensor_name_sort) {
+ mlir::RankedTensorType tensor_type =
+ pair.second.getType().cast<mlir::RankedTensorType>();
+ ser_tensor = BuildTosaSerializationTensor(pair.second /* val */,
+ pair.first /* name */);
+ if (!ser_tensor) {
+ llvm::errs() << "ERROR: Failed to build TosaSerializationTensor\n";
+ return mlir::failure();
+ }
+ ser_block->GetTensors().push_back(ser_tensor);
+ }
+
+ // Build operator
+ for (auto &op : block->getOperations()) {
+ if (llvm::isa<mlir::tosa::YieldOp>(op) ||
+ llvm::isa<mlir::func::ReturnOp>(op) ||
+ llvm::isa<mlir::tensor::CastOp>(op) ||
+ llvm::isa<mlir::tosa::VariableOp>(op))
+ continue;
+ ser_operator = BuildTosaSerializationOperator(op_builder, op);
+ if (!ser_operator) {
+ llvm::errs() << "ERROR: Failed to build TosaSerializationOperator\n";
+ return mlir::failure();
+ }
+ ser_block->GetOperators().push_back(ser_operator);
}
-
return mlir::success();
}
TosaSerializationOperator *
TosaSerializationBlockBuilder::BuildTosaSerializationOperator(
const TosaSerializationOperatorBuilder &op_builder, mlir::Operation &op) {
- std::string full_op_name = op.getName().getStringRef().str();
TosaSerializationOperator *target_operator = nullptr;
- if (false) {
+ if (llvm::isa<mlir::tosa::VariableReadOp>(op)) {
+ target_operator = op_builder.build<mlir::tosa::VariableReadOp>(op);
+ } else if (llvm::isa<mlir::tosa::VariableWriteOp>(op)) {
+ target_operator = op_builder.build<mlir::tosa::VariableWriteOp>(op);
}
#define DEF_OPERATOR(MLIR_OP) \
else if (llvm::isa<mlir::tosa::MLIR_OP##Op>(op)) { \
@@ -1455,17 +1933,22 @@ TosaSerializationBlockBuilder::BuildTosaSerializationOperator(
return nullptr;
}
+ if (llvm::isa<mlir::tosa::VariableReadOp>(op) ||
+ llvm::isa<mlir::tosa::VariableWriteOp>(op)) {
+ return target_operator;
+ }
+
// Sanity check the number of inputs/outputs of TOSA dialect matches the
// number of TOSA flatbuffer
if (op.getNumOperands() != target_operator->GetInputTensorNames().size()) {
- llvm::errs() << "WARNING. MLIR operator has " << op.getNumOperands()
+ llvm::errs() << "WARNING: MLIR operator has " << op.getNumOperands()
<< " input tensors != Flatbuffer "
"operator has "
<< target_operator->GetInputTensorNames().size()
<< " input tensors\n";
}
if (op.getNumResults() != target_operator->GetOutputTensorNames().size()) {
- llvm::errs() << "WARNING. MLIR operator has " << op.getNumResults()
+ llvm::errs() << "WARNING: MLIR operator has " << op.getNumResults()
<< " output tensors != Flatbuffer "
"operator has "
<< target_operator->GetOutputTensorNames().size()
@@ -1476,30 +1959,83 @@ TosaSerializationBlockBuilder::BuildTosaSerializationOperator(
}
TosaSerializationTensor *
+TosaSerializationBlockBuilder::BuildTosaSerializationVariableTensor(
+ mlir::RankedTensorType tensor_type, const std::string &name,
+ const std::string &variable_mlir_name) {
+ // If tensor already created before, use that tensor directly, create a new
+ // one otherwise
+ TosaSerializationTensor *ts = ser_block->GetTensorByName(name);
+ if (ts) {
+ return nullptr;
+ }
+
+ std::vector<int32_t> shape(tensor_type.getShape().begin(),
+ tensor_type.getShape().end());
+
+ DType type = Type2DType(tensor_type.getElementType());
+
+ ts = new TosaSerializationTensor(name, shape, type, std::vector<uint8_t>(),
+ /* is_variable = */ true,
+ /* is_unranked = */ false,
+ variable_mlir_name);
+
+ return ts;
+}
+
+TosaSerializationTensor *
TosaSerializationBlockBuilder::BuildTosaSerializationTensor(
mlir::Value val, const std::string &name) {
// If tensor already created before, use that tensor directly, create a new
// one otherwise
- TosaSerializationTensor *ts = block->GetTensorByName(name);
+ TosaSerializationTensor *ts = ser_block->GetTensorByName(name);
if (ts) {
return nullptr;
}
- mlir::RankedTensorType tensor =
- val.getType().dyn_cast<mlir::RankedTensorType>();
- std::vector<int32_t> shape(tensor.getShape().begin(),
- tensor.getShape().end());
- DType type = Type2DType(tensor.getElementType());
+ // handling of tosa.shape values
+ if (auto shape_ty = val.getType().dyn_cast<mlir::tosa::shapeType>()) {
+ auto rank = shape_ty.getRank();
+ std::vector<int32_t> shape;
+ if (rank > 0) {
+ shape.push_back(rank);
+ }
+ ts = new TosaSerializationTensor(name,
+ /* shape = */ shape,
+ /* type = */ DType::DType_SHAPE,
+ /* data = */ std::vector<uint8_t>());
+ return ts;
+ }
- ts = new TosaSerializationTensor(name, shape, type, std::vector<uint8_t>());
+ auto ttype = val.getType().dyn_cast<mlir::TensorType>();
+ if (!ttype) {
+ llvm::errs() << "TOSA serialization, supplied value is not of TensorType\n";
+ return nullptr;
+ }
+
+ const bool is_unranked = !ttype.hasRank();
+ std::vector<int32_t> shape;
+ if (!is_unranked) {
+ auto shaped = val.getType().dyn_cast<mlir::ShapedType>();
+ assert(shaped);
+ for (int idx = 0; idx < ttype.getRank(); idx++) {
+ if (shaped.isDynamicDim(idx)) {
+ shape.push_back(0); // size of 0 represents dynamic dimension
+ } else {
+ auto dim = shaped.getDimSize(idx);
+ shape.push_back(dim);
+ }
+ }
+ }
+
+ DType type = Type2DType(ttype.getElementType());
+ ts = new TosaSerializationTensor(name, shape, type, std::vector<uint8_t>(),
+ /* variable = */ false, is_unranked);
return ts;
}
mlir::LogicalResult translate2FlatBuffer(mlir::func::FuncOp &func,
TosaSerializationHandler &tsh) {
- TosaSerializationBasicBlock *main_block;
-
mlir::Region *main_region = func.getCallableRegion();
std::vector<mlir::Value> main_returns;
@@ -1508,21 +2044,22 @@ mlir::LogicalResult translate2FlatBuffer(mlir::func::FuncOp &func,
return mlir::failure();
}
- if (!tsh.GetBlocks().empty()) {
- llvm::errs() << "Internal Error: TosaSerializationHandler's block list "
+ if (!tsh.GetRegions().empty()) {
+ llvm::errs() << "Internal Error: TosaSerializationHandler's region list "
"must be empty\n";
return mlir::failure();
}
- main_block = new TosaSerializationBasicBlock(
- std::string("main"), std::vector<TosaSerializationOperator *>(),
- std::vector<TosaSerializationTensor *>(), std::vector<std::string>(),
- std::vector<std::string>());
- assert(main_block);
- tsh.GetBlocks().push_back(main_block);
+ // reset static counters
+ input_tensor_index = 0;
+ intermediate_tensor_index = 0;
+ output_tensor_index = 0;
- TosaSerializationBlockBuilder block_builder(main_block, &tsh, main_region);
- if (block_builder.BuildAllOpsInRegion(main_returns).failed()) {
+ TosaSerializationRegion *ser_main_region =
+ BuildRegion(*main_region, "main", /* isolated_from_above = */ true,
+ /* parent_value_scope = */ nullptr, &tsh, main_returns,
+ /* is_top = */ true);
+ if (!ser_main_region) {
return mlir::failure();
}
@@ -1580,20 +2117,42 @@ mlir::LogicalResult dumpTosaJSON(mlir::func::FuncOp &func) {
return mlir::success();
}
-namespace mlir {
+#define GEN_PASS_DEF_TOSASERIALIZATIONPASS
+namespace mlir {
namespace tosa {
-
namespace {
class TosaSerialize : public TosaSerializationPassBase<TosaSerialize> {
public:
void runOnOperation() final {
- auto function = getOperation();
+ auto moduleOp = getOperation();
- if (dumpTosaFlatbuffer(function).failed()) {
- llvm::errs() << "Failed to generate TOSA flatbuffer...\n";
- return signalPassFailure();
+ // iterate through each op in the moduleOp, call dumpTosaFlatbuffer if
+ // that's a func.funcOp
+
+ auto regions = moduleOp->getRegions();
+ auto region_size = regions.size();
+
+ auto region_0 = regions.begin();
+ auto block_size = region_0->getBlocks().size();
+
+ auto block_0 = region_0->getBlocks().begin();
+
+ auto op_size = block_0->getOperations().size();
+
+ for (auto it = block_0->getOperations().begin();
+ it != block_0->getOperations().end(); ++it) {
+ // read variableOps that are declared outside of functionOps
+ if (llvm::isa<mlir::tosa::VariableOp>(*it)) {
+ RegisterVariableOp(*it);
+ } else if (llvm::isa<mlir::func::FuncOp>(*it)) {
+ auto funcOp = dyn_cast<mlir::func::FuncOp>((*it));
+ if (dumpTosaFlatbuffer(funcOp).failed()) {
+ llvm::errs() << "Failed to generate TOSA flatbuffer...\n";
+ return signalPassFailure();
+ }
+ }
}
}
};
@@ -1614,7 +2173,7 @@ public:
} // anonymous namespace
// Creates an instance of the TOSA flatbuffer generation pass
-std::unique_ptr<Pass> createTosaSerializePass() {
+std::unique_ptr<OperationPass<ModuleOp>> createTosaSerializePass() {
return std::make_unique<TosaSerialize>();
}
diff --git a/third_party/serialization_lib b/third_party/serialization_lib
-Subproject 343d6a703c3a270a01102ec468b59ef2967b595
+Subproject 3aebe2bd863d6e0cb82171984cd49e5ad516d0d