aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTai Ly <tai.ly@arm.com>2023-02-16 22:57:53 +0000
committerTai Ly <tai.ly@arm.com>2023-03-06 21:42:10 +0000
commit581fb5d0e706d8669dd5ce21e1be2770b4951e02 (patch)
treef0d7d607b59f29399e7ac5801cd360c78840e2a1
parent13a5b0fcc6663ed03ab88254493439f6b8b9fda1 (diff)
downloadtosa_mlir_translator-581fb5d0e706d8669dd5ce21e1be2770b4951e02.tar.gz
Add Tosa Deserialization
Signed-off-by: Tai Ly <tai.ly@arm.com> Change-Id: I8b0220a8465e75b1accf6b0854e911a425730da6
-rw-r--r--CMakeLists.txt6
-rw-r--r--include/DeserializationPasses.h37
-rw-r--r--include/DeserializationPasses.td25
-rw-r--r--include/schema_operator.def93
-rw-r--r--src/TosaDeserialize.cpp1385
5 files changed, 1546 insertions, 0 deletions
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 5512401..db0bb78 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -17,15 +17,21 @@ set(LLVM_TARGET_DEFINITIONS include/SerializationPasses.td)
mlir_tablegen(include/SerializationPasses.h.inc -gen-pass-decls -name TosaSerialization)
add_public_tablegen_target(tosa_serialization_passes_inc_gen)
+set(LLVM_TARGET_DEFINITIONS include/DeserializationPasses.td)
+mlir_tablegen(include/DeserializationPasses.h.inc -gen-pass-decls -name TosaDeserialization)
+add_public_tablegen_target(tosa_deserialization_passes_inc_gen)
+
# Compile the TOSA serialization_lib
add_subdirectory(third_party/serialization_lib)
add_mlir_library(tosa_serialize
src/TosaSerialize.cpp
+ src/TosaDeserialize.cpp
DEPENDS
mlir-headers
tosa_serialization_passes_inc_gen
+ tosa_deserialization_passes_inc_gen
LINK_LIBS PRIVATE
tosa_serialization_lib
diff --git a/include/DeserializationPasses.h b/include/DeserializationPasses.h
new file mode 100644
index 0000000..1bc195a
--- /dev/null
+++ b/include/DeserializationPasses.h
@@ -0,0 +1,37 @@
+
+// Copyright (c) 2023, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 with LLVM Exceptions
+// (the "License"); you may not use this file except in compliance with
+// the License. You may obtain a copy of the License at
+//
+// https://llvm.org/LICENSE.txt
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef INCLUDE_DESERIALIZATION_PASSES_H
+#define INCLUDE_DESERIALIZATION_PASSES_H
+
+#include <memory>
+
+#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
+#include "mlir/Pass/Pass.h" // from @llvm-project
+
+namespace mlir {
+namespace tosa {
+
+std::unique_ptr<Pass> createTosaDeserializePass();
+std::unique_ptr<Pass> createTosaDeserializeJSONPass();
+
+#define GEN_PASS_REGISTRATION
+#define GEN_PASS_CLASSES
+#include "include/DeserializationPasses.h.inc"
+
+} // namespace tosa
+} // namespace mlir
+
+#endif // INCLUDE_DESERIALIZATION_PASSES_H
diff --git a/include/DeserializationPasses.td b/include/DeserializationPasses.td
new file mode 100644
index 0000000..999f0b4
--- /dev/null
+++ b/include/DeserializationPasses.td
@@ -0,0 +1,25 @@
+// Copyright (c) 2023, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 with LLVM Exceptions
+// (the "License"); you may not use this file except in compliance with
+// the License. You may obtain a copy of the License at
+//
+// https://llvm.org/LICENSE.txt
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+include "mlir/Pass/PassBase.td"
+
+def TosaDeserializationPass : Pass<"tosa-deserialize", "func::FuncOp"> {
+ let summary = "Deserialize TOSA flatbuffer. Clear original MLIR graph and generate TOSA MLIR";
+ let constructor = "createTosaDeserializePass()";
+}
+
+def TosaDeserializationJSONPass : Pass<"tosa-deserialize-json", "func::FuncOp"> {
+ let summary = "Deserialize TOSA flatbuffer JSON form. Clear original MLIR graph and generate TOSA MLIR";
+ let constructor = "createTosaDeserializeJSONPass()";
+}
diff --git a/include/schema_operator.def b/include/schema_operator.def
new file mode 100644
index 0000000..1af367e
--- /dev/null
+++ b/include/schema_operator.def
@@ -0,0 +1,93 @@
+// Copyright (c) 2023, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 with LLVM Exceptions
+// (the "License"); you may not use this file except in compliance with
+// the License. You may obtain a copy of the License at
+//
+// https://llvm.org/LICENSE.txt
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+/*
+ Syntax:
+ DEF_SCHEMA_OPERATOR(SCHEMA_OP_NAME)
+
+ Description:
+ SCHEMA_OP_NAME: the schema operator names, must match Op names in schema/tosa.fbs in serialization_lib repo
+*/
+
+/* schema operators */
+DEF_SCHEMA_OPERATOR(ARGMAX)
+DEF_SCHEMA_OPERATOR(AVG_POOL2D)
+DEF_SCHEMA_OPERATOR(CONV2D)
+DEF_SCHEMA_OPERATOR(CONV3D)
+DEF_SCHEMA_OPERATOR(DEPTHWISE_CONV2D)
+DEF_SCHEMA_OPERATOR(FULLY_CONNECTED)
+DEF_SCHEMA_OPERATOR(MATMUL)
+DEF_SCHEMA_OPERATOR(MAX_POOL2D)
+DEF_SCHEMA_OPERATOR(TRANSPOSE_CONV2D)
+DEF_SCHEMA_OPERATOR(CLAMP)
+DEF_SCHEMA_OPERATOR(RESERVED)
+DEF_SCHEMA_OPERATOR(SIGMOID)
+DEF_SCHEMA_OPERATOR(TANH)
+DEF_SCHEMA_OPERATOR(ADD)
+DEF_SCHEMA_OPERATOR(ARITHMETIC_RIGHT_SHIFT)
+DEF_SCHEMA_OPERATOR(BITWISE_AND)
+DEF_SCHEMA_OPERATOR(BITWISE_OR)
+DEF_SCHEMA_OPERATOR(BITWISE_XOR)
+DEF_SCHEMA_OPERATOR(INTDIV)
+DEF_SCHEMA_OPERATOR(LOGICAL_AND)
+DEF_SCHEMA_OPERATOR(LOGICAL_LEFT_SHIFT)
+DEF_SCHEMA_OPERATOR(LOGICAL_RIGHT_SHIFT)
+DEF_SCHEMA_OPERATOR(LOGICAL_OR)
+DEF_SCHEMA_OPERATOR(LOGICAL_XOR)
+DEF_SCHEMA_OPERATOR(MAXIMUM)
+DEF_SCHEMA_OPERATOR(MINIMUM)
+DEF_SCHEMA_OPERATOR(MUL)
+DEF_SCHEMA_OPERATOR(POW)
+DEF_SCHEMA_OPERATOR(SUB)
+DEF_SCHEMA_OPERATOR(TABLE)
+DEF_SCHEMA_OPERATOR(ABS)
+DEF_SCHEMA_OPERATOR(BITWISE_NOT)
+DEF_SCHEMA_OPERATOR(CEIL)
+DEF_SCHEMA_OPERATOR(CLZ)
+DEF_SCHEMA_OPERATOR(EXP)
+DEF_SCHEMA_OPERATOR(FLOOR)
+DEF_SCHEMA_OPERATOR(LOG)
+DEF_SCHEMA_OPERATOR(LOGICAL_NOT)
+DEF_SCHEMA_OPERATOR(NEGATE)
+DEF_SCHEMA_OPERATOR(RECIPROCAL)
+DEF_SCHEMA_OPERATOR(RSQRT)
+DEF_SCHEMA_OPERATOR(SELECT)
+DEF_SCHEMA_OPERATOR(EQUAL)
+DEF_SCHEMA_OPERATOR(GREATER)
+DEF_SCHEMA_OPERATOR(GREATER_EQUAL)
+DEF_SCHEMA_OPERATOR(REDUCE_ANY)
+DEF_SCHEMA_OPERATOR(REDUCE_ALL)
+DEF_SCHEMA_OPERATOR(REDUCE_MAX)
+DEF_SCHEMA_OPERATOR(REDUCE_MIN)
+DEF_SCHEMA_OPERATOR(REDUCE_PRODUCT)
+DEF_SCHEMA_OPERATOR(REDUCE_SUM)
+DEF_SCHEMA_OPERATOR(CONCAT)
+DEF_SCHEMA_OPERATOR(PAD)
+DEF_SCHEMA_OPERATOR(RESHAPE)
+DEF_SCHEMA_OPERATOR(REVERSE)
+DEF_SCHEMA_OPERATOR(SLICE)
+DEF_SCHEMA_OPERATOR(TILE)
+DEF_SCHEMA_OPERATOR(TRANSPOSE)
+DEF_SCHEMA_OPERATOR(GATHER)
+DEF_SCHEMA_OPERATOR(SCATTER)
+DEF_SCHEMA_OPERATOR(RESIZE)
+DEF_SCHEMA_OPERATOR(CAST)
+DEF_SCHEMA_OPERATOR(RESCALE)
+DEF_SCHEMA_OPERATOR(CONST)
+DEF_SCHEMA_OPERATOR(IDENTITY)
+DEF_SCHEMA_OPERATOR(CUSTOM)
+DEF_SCHEMA_OPERATOR(COND_IF)
+DEF_SCHEMA_OPERATOR(WHILE_LOOP)
+DEF_SCHEMA_OPERATOR(FFT2D)
+DEF_SCHEMA_OPERATOR(RFFT2D)
diff --git a/src/TosaDeserialize.cpp b/src/TosaDeserialize.cpp
new file mode 100644
index 0000000..f6f78fc
--- /dev/null
+++ b/src/TosaDeserialize.cpp
@@ -0,0 +1,1385 @@
+
+// Copyright (c) 2023, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 with LLVM Exceptions
+// (the "License"); you may not use this file except in compliance with
+// the License. You may obtain a copy of the License at
+//
+// https://llvm.org/LICENSE.txt
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// TOSA MLIR deserialize passes
+
+#include "include/DeserializationPasses.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
+#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project
+#include "mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project
+#include "mlir/Dialect/Tosa/IR/TosaOps.h" // from @llvm-project
+#include "mlir/IR/Matchers.h" // from @llvm-project
+#include "mlir/Pass/Pass.h" // from @llvm-project
+#include "mlir/Support/LogicalResult.h" // from @llvm-project
+#include "tosa_serialization_handler.h"
+#include <functional>
+#include <map>
+#include <queue>
+#include <unordered_map>
+#include <vector>
+
+// The namespace might be confusing here. We have mlir::tosa:: defined in MLIR
+// and tosa:: defined in serialization library
+// TODO: align the namespace
+using namespace tosa;
+
+namespace cl = llvm::cl;
+
+llvm::cl::opt<std::string> tosa_deserialize_filename(
+ "tosa-deserialize-filename", llvm::cl::desc("<tosa flatbuffer filename>"),
+ llvm::cl::init("tosa_dump.tosa"), llvm::cl::value_desc("filename"));
+
+llvm::cl::opt<std::string> tosa_deserialize_schema(
+ "tosa-deserialize-schema", llvm::cl::desc("<tosa flatbuffer schema file>"),
+ llvm::cl::init(""), llvm::cl::value_desc("filename"));
+
+namespace {
+
+// construct tensor type from dtype and shape of TosaSerializationTensor
+mlir::LogicalResult BuildTensorType(mlir::OpBuilder *op_builder,
+ TosaSerializationTensor *ts,
+ mlir::RankedTensorType &type) {
+ mlir::Type element_type;
+ switch (ts->GetDtype()) {
+ case DType_BOOL:
+ element_type = op_builder->getI1Type();
+ break;
+ case DType_UINT8:
+ element_type = op_builder->getIntegerType(8, false);
+ break;
+ case DType_INT4:
+ element_type = op_builder->getI4Type();
+ break;
+ case DType_INT8:
+ element_type = op_builder->getI8Type();
+ break;
+ case DType_INT16:
+ element_type = op_builder->getIntegerType(16);
+ break;
+ case DType_INT32:
+ element_type = op_builder->getI32Type();
+ break;
+ case DType_INT48:
+ element_type = op_builder->getIntegerType(48);
+ break;
+ case DType_FP32:
+ element_type = op_builder->getF32Type();
+ break;
+ case DType_UINT16:
+ element_type = op_builder->getIntegerType(16, false);
+ break;
+ default:
+ llvm::errs() << "ERROR: unknown type " << EnumNamesDType()[ts->GetDtype()]
+ << "\n";
+ return mlir::failure();
+ }
+ llvm::SmallVector<int64_t> shape(ts->GetShape().begin(), ts->GetShape().end());
+ type = mlir::RankedTensorType::get(llvm::makeArrayRef(shape), element_type);
+ return mlir::success();
+}
+
+template <class T>
+mlir::DenseElementsAttr
+BuildDenseI16ElementsAttr(mlir::OpBuilder *op_builder,
+ const std::vector<T> &values) {
+ std::vector<int16_t> vec;
+ for (auto val : values) {
+ vec.push_back(val);
+ }
+ auto type =
+ mlir::RankedTensorType::get({vec.size()}, op_builder->getI16Type());
+ return mlir::DenseElementsAttr::get(type, llvm::ArrayRef(vec));
+}
+
+template <class T>
+mlir::DenseElementsAttr
+BuildDenseI32ElementsAttr(mlir::OpBuilder *op_builder,
+ const std::vector<T> &values) {
+ std::vector<int32_t> vec;
+ for (auto val : values) {
+ vec.push_back(val);
+ }
+ auto type =
+ mlir::RankedTensorType::get({vec.size()}, op_builder->getI32Type());
+ return mlir::DenseElementsAttr::get(type, llvm::ArrayRef(vec));
+}
+
+template <class T>
+mlir::DenseI32ArrayAttr BuildDenseI32ArrayAttr(mlir::OpBuilder *op_builder,
+ const std::vector<T> &values) {
+ std::vector<int32_t> vec;
+ for (auto val : values) {
+ vec.push_back(val);
+ }
+ return op_builder->getDenseI32ArrayAttr(vec);
+}
+
+template <class T>
+mlir::DenseI64ArrayAttr BuildDenseI64ArrayAttr(mlir::OpBuilder *op_builder,
+ const std::vector<T> &values) {
+ std::vector<int64_t> vec;
+ for (auto val : values) {
+ vec.push_back(val);
+ }
+ return op_builder->getDenseI64ArrayAttr(vec);
+}
+
+const std::string ResizeEnum2Str(const tosa::ResizeMode &mode) {
+ if (mode == ResizeMode_NEAREST) {
+ return "NEAREST_NEIGHBOR";
+ } else if (mode == ResizeMode_BILINEAR) {
+ return "BILINEAR";
+ }
+ return "";
+}
+
+} // namespace
+
+class TosaMlirRegionBuilder;
+class TosaMlirBlockBuilder;
+
+class TosaMlirOperatorBuilder {
+public:
+ TosaMlirOperatorBuilder(
+ mlir::OpBuilder *_op_builder, TosaSerializationBasicBlock *_ser_block,
+ mlir::Block *_block, mlir::Location _loc,
+ std::unordered_map<std::string, mlir::Value> *_tensor_map,
+ std::unordered_map<std::string, mlir::RankedTensorType> *_tensor_type_map)
+ : op_builder(_op_builder), ser_block(_ser_block), block(_block),
+ loc(_loc), tensor_map(_tensor_map), tensor_type_map(_tensor_type_map) {}
+
+ template <Op OPCODE>
+ std::vector<mlir::Value> build(TosaSerializationOperator *op) const;
+
+ std::string get_string(TosaSerializationOperator *op) const {
+ std::string op_string;
+ op_string += "operator opcode=";
+ op_string += EnumNamesOp()[op->GetOp()];
+ op_string += ", input=[";
+ for (auto ts : op->GetInputTensorNames()) {
+ op_string += (ts + " ");
+ }
+ op_string += "], output=[";
+ for (auto ts : op->GetOutputTensorNames()) {
+ op_string += (ts + " ");
+ }
+ op_string += "]";
+ return op_string;
+ }
+
+private:
+ template <class MLIR_OP>
+ std::vector<mlir::Value>
+ BuildEwiseUnaryOp(TosaSerializationOperator *op) const;
+
+ template <class MLIR_OP>
+ std::vector<mlir::Value>
+ BuildEwiseBinaryOp(TosaSerializationOperator *op) const;
+
+ template <class MLIR_OP>
+ std::vector<mlir::Value>
+ BuildReductionOp(TosaSerializationOperator *op) const;
+
+ template <class MLIR_OP>
+ std::vector<mlir::Value> BuildConvOp(TosaSerializationOperator *op) const;
+
+ mlir::OpBuilder *op_builder;
+ TosaSerializationBasicBlock *ser_block;
+ mlir::Block *block;
+ mlir::Location loc;
+ std::unordered_map<std::string, mlir::Value> *tensor_map;
+ std::unordered_map<std::string, mlir::RankedTensorType> *tensor_type_map;
+};
+
+// Main template to catch unimplemented translation
+template <Op OPCODE>
+std::vector<mlir::Value> TosaMlirOperatorBuilder::build(TosaSerializationOperator* op) const
+{
+ llvm::errs() << "ERROR: " << get_string(op) << " translation hasn't been implemented\n";
+ return std::vector<mlir::Value>();
+}
+
+// BUILD_OP_POOL2D(MaxPool2d, MAX_POOL2D)
+template <>
+std::vector<mlir::Value> TosaMlirOperatorBuilder::build<Op_MAX_POOL2D>(TosaSerializationOperator* op) const
+{
+ mlir::Value input_val = tensor_map->at(op->GetInputTensorNames()[0]);
+ mlir::RankedTensorType output_type = tensor_type_map->at(op->GetOutputTensorNames()[0]);
+ assert(op->GetAttributeType() == Attribute_PoolAttribute); // double check attribute type
+ TosaPoolAttribute* attr = static_cast<TosaPoolAttribute*>(op->GetAttribute());
+ mlir::DenseI64ArrayAttr kernel = BuildDenseI64ArrayAttr(op_builder, attr->kernel());
+ mlir::DenseI64ArrayAttr stride = BuildDenseI64ArrayAttr(op_builder, attr->stride());
+ mlir::DenseI64ArrayAttr pad = BuildDenseI64ArrayAttr(op_builder, attr->pad());
+ int32_t input_zp = attr->input_zp();
+ int32_t output_zp = attr->output_zp();
+ assert(input_zp == 0 && output_zp == 0);
+
+ mlir::Operation* mlir_op = op_builder->create<mlir::tosa::MaxPool2dOp>(loc, output_type, input_val, kernel, stride, pad);
+ block->push_back(mlir_op);
+ return std::vector<mlir::Value>({ mlir_op->getResult(0) });
+}
+
+// BUILD_OP_POOL2D(AvgPool2d, AVG_POOL2D)
+template <>
+std::vector<mlir::Value> TosaMlirOperatorBuilder::build<Op_AVG_POOL2D>(TosaSerializationOperator* op) const
+{
+ mlir::Value input_val = tensor_map->at(op->GetInputTensorNames()[0]);
+ mlir::RankedTensorType output_type = tensor_type_map->at(op->GetOutputTensorNames()[0]);
+ assert(op->GetAttributeType() == Attribute_PoolAttribute); // double check attribute type
+ TosaPoolAttribute* attr = static_cast<TosaPoolAttribute*>(op->GetAttribute());
+ mlir::DenseI64ArrayAttr kernel = BuildDenseI64ArrayAttr(op_builder, attr->kernel());
+ mlir::DenseI64ArrayAttr stride = BuildDenseI64ArrayAttr(op_builder, attr->stride());
+ mlir::DenseI64ArrayAttr pad = BuildDenseI64ArrayAttr(op_builder, attr->pad());
+ int32_t input_zp = attr->input_zp();
+ int32_t output_zp = attr->output_zp();
+ mlir::Operation* mlir_op;
+ if (input_zp == 0 && output_zp == 0) {
+ mlir_op = op_builder->create<mlir::tosa::AvgPool2dOp>(loc, output_type, input_val, kernel, stride, pad);
+ } else {
+ auto quant = op_builder->getAttr<mlir::tosa::UnaryOpQuantizationAttr>(input_zp, output_zp);
+ mlir_op = op_builder->create<mlir::tosa::AvgPool2dOp>(loc, output_type, input_val, kernel, stride, pad, quant);
+ }
+ block->push_back(mlir_op);
+ return std::vector<mlir::Value>({ mlir_op->getResult(0) });
+}
+
+template <class MLIR_OP>
+std::vector<mlir::Value> TosaMlirOperatorBuilder::BuildEwiseUnaryOp(TosaSerializationOperator* op) const
+{
+ mlir::Value input_val = tensor_map->at(op->GetInputTensorNames()[0]);
+ mlir::RankedTensorType output_type = tensor_type_map->at(op->GetOutputTensorNames()[0]);
+ assert(op->GetAttributeType() == Attribute_NONE); // double check that there is no attribute
+
+ mlir::Operation* mlir_op = op_builder->create<MLIR_OP>(loc, output_type, input_val);
+ block->push_back(mlir_op);
+ return std::vector<mlir::Value>({ mlir_op->getResult(0) });
+}
+
+template <class MLIR_OP>
+std::vector<mlir::Value> TosaMlirOperatorBuilder::BuildEwiseBinaryOp(TosaSerializationOperator* op) const
+{
+ mlir::Value input0_val = tensor_map->at(op->GetInputTensorNames()[0]);
+ mlir::Value input1_val = tensor_map->at(op->GetInputTensorNames()[1]);
+ mlir::RankedTensorType output_type = tensor_type_map->at(op->GetOutputTensorNames()[0]);
+ assert(op->GetAttributeType() == Attribute_NONE); // double check that there is no attribute
+
+ mlir::Operation* mlir_op = op_builder->create<MLIR_OP>(loc, output_type, input0_val, input1_val);
+ block->push_back(mlir_op);
+ return std::vector<mlir::Value>({ mlir_op->getResult(0) });
+}
+
+template <class MLIR_OP>
+std::vector<mlir::Value> TosaMlirOperatorBuilder::BuildReductionOp(TosaSerializationOperator* op) const
+{
+ mlir::Value input_val = tensor_map->at(op->GetInputTensorNames()[0]);
+ mlir::RankedTensorType output_type = tensor_type_map->at(op->GetOutputTensorNames()[0]);
+ assert(op->GetAttributeType() == Attribute_AxisAttribute); // double check attribute type
+
+ TosaAxisAttribute* attr = static_cast<TosaAxisAttribute*>(op->GetAttribute());
+ auto axis = op_builder->getI64IntegerAttr(attr->axis());
+
+ mlir::Operation *mlir_op =
+ op_builder->create<MLIR_OP>(loc, output_type, input_val, axis);
+ block->push_back(mlir_op);
+ return std::vector<mlir::Value>({ mlir_op->getResult(0) });
+}
+
+#define BUILD_OP_ELEMENTWISE_UNARY(MLIR_OP_NAME, SCHEMA_OP_NAME) \
+ template <> \
+ std::vector<mlir::Value> TosaMlirOperatorBuilder::build<Op_##SCHEMA_OP_NAME>( \
+ TosaSerializationOperator* op) const { \
+ return BuildEwiseUnaryOp<mlir::tosa::MLIR_OP_NAME##Op>(op); \
+ }
+
+#define BUILD_OP_ELEMENTWISE_BINARY(MLIR_OP_NAME, SCHEMA_OP_NAME) \
+ template <> \
+ std::vector<mlir::Value> TosaMlirOperatorBuilder::build<Op_##SCHEMA_OP_NAME>( \
+ TosaSerializationOperator* op) const { \
+ return BuildEwiseBinaryOp<mlir::tosa::MLIR_OP_NAME##Op>(op); \
+ }
+
+#define BUILD_OP_REDUCTION(MLIR_OP_NAME, SCHEMA_OP_NAME) \
+ template <> \
+ std::vector<mlir::Value> TosaMlirOperatorBuilder::build<Op_##SCHEMA_OP_NAME>( \
+ TosaSerializationOperator* op) const { \
+ return BuildReductionOp<mlir::tosa::MLIR_OP_NAME##Op>(op); \
+ }
+
+// BUILD_OP_POOL2D(MaxPool2d, MAX_POOL2D)
+// BUILD_OP_POOL2D(AvgPool2d, AVG_POOL2D)
+
+BUILD_OP_ELEMENTWISE_BINARY(Add, ADD)
+BUILD_OP_ELEMENTWISE_BINARY(BitwiseAnd, BITWISE_AND)
+BUILD_OP_ELEMENTWISE_BINARY(BitwiseXor, BITWISE_XOR)
+BUILD_OP_ELEMENTWISE_BINARY(BitwiseOr, BITWISE_OR)
+BUILD_OP_ELEMENTWISE_BINARY(Div, INTDIV)
+BUILD_OP_ELEMENTWISE_BINARY(LogicalAnd, LOGICAL_AND)
+BUILD_OP_ELEMENTWISE_BINARY(LogicalLeftShift, LOGICAL_LEFT_SHIFT)
+BUILD_OP_ELEMENTWISE_BINARY(LogicalRightShift, LOGICAL_RIGHT_SHIFT)
+BUILD_OP_ELEMENTWISE_BINARY(LogicalOr, LOGICAL_OR)
+BUILD_OP_ELEMENTWISE_BINARY(LogicalXor, LOGICAL_XOR)
+BUILD_OP_ELEMENTWISE_BINARY(Maximum, MAXIMUM)
+BUILD_OP_ELEMENTWISE_BINARY(Minimum, MINIMUM)
+BUILD_OP_ELEMENTWISE_BINARY(Pow, POW)
+BUILD_OP_ELEMENTWISE_BINARY(Sub, SUB)
+
+BUILD_OP_ELEMENTWISE_UNARY(Abs, ABS)
+BUILD_OP_ELEMENTWISE_UNARY(BitwiseNot, BITWISE_NOT)
+BUILD_OP_ELEMENTWISE_UNARY(Ceil, CEIL)
+BUILD_OP_ELEMENTWISE_UNARY(Clz, CLZ)
+BUILD_OP_ELEMENTWISE_UNARY(Exp, EXP)
+BUILD_OP_ELEMENTWISE_UNARY(Floor, FLOOR)
+BUILD_OP_ELEMENTWISE_UNARY(Log, LOG)
+BUILD_OP_ELEMENTWISE_UNARY(LogicalNot, LOGICAL_NOT)
+BUILD_OP_ELEMENTWISE_UNARY(Reciprocal, RECIPROCAL)
+BUILD_OP_ELEMENTWISE_UNARY(Rsqrt, RSQRT)
+
+BUILD_OP_REDUCTION(ReduceAny, REDUCE_ANY)
+BUILD_OP_REDUCTION(ReduceAll, REDUCE_ALL)
+BUILD_OP_REDUCTION(ReduceMax, REDUCE_MAX)
+BUILD_OP_REDUCTION(ReduceMin, REDUCE_MIN)
+BUILD_OP_REDUCTION(ReduceProd, REDUCE_PRODUCT)
+BUILD_OP_REDUCTION(ReduceSum, REDUCE_SUM)
+
+BUILD_OP_ELEMENTWISE_BINARY(Equal, EQUAL)
+BUILD_OP_ELEMENTWISE_BINARY(Greater, GREATER)
+BUILD_OP_ELEMENTWISE_BINARY(GreaterEqual, GREATER_EQUAL)
+
+BUILD_OP_ELEMENTWISE_UNARY(Sigmoid, SIGMOID)
+BUILD_OP_ELEMENTWISE_UNARY(Tanh, TANH)
+BUILD_OP_ELEMENTWISE_UNARY(Identity, IDENTITY)
+BUILD_OP_ELEMENTWISE_UNARY(Cast, CAST)
+
+template <>
+std::vector<mlir::Value> TosaMlirOperatorBuilder::build<Op_CONST>(TosaSerializationOperator* op) const
+{
+ const auto &output_name = op->GetOutputTensorNames()[0];
+ mlir::RankedTensorType output_type = tensor_type_map->at(output_name);
+ TosaSerializationTensor *ts = ser_block->GetTensorByName(output_name);
+ const auto &data = ts->GetData();
+ auto &shape = ts->GetShape();
+ // compute output data size
+ uint32_t out_size = 1;
+ for (const auto dim : shape) {
+ out_size *= dim;
+ }
+ mlir::DenseElementsAttr value_attr;
+ switch (ts->GetDtype()) {
+ case DType_FP32: {
+ std::vector<float> float_data;
+ TosaSerializationHandler::ConvertU8toF32(data, out_size, float_data);
+ value_attr =
+ mlir::DenseElementsAttr::get(output_type, llvm::ArrayRef(float_data));
+ break;
+ }
+ case DType_INT8: {
+ std::vector<int8_t> int8_data;
+ TosaSerializationHandler::ConvertU8toI8(data, out_size, int8_data);
+ value_attr =
+ mlir::DenseElementsAttr::get(output_type, llvm::ArrayRef(int8_data));
+ break;
+ }
+ case DType_INT16: {
+ std::vector<int16_t> int16_data;
+ TosaSerializationHandler::ConvertU8toI16(data, out_size, int16_data);
+ value_attr =
+ mlir::DenseElementsAttr::get(output_type, llvm::ArrayRef(int16_data));
+ break;
+ }
+ case DType_INT32: {
+ std::vector<int32_t> int32_data;
+ TosaSerializationHandler::ConvertU8toI32(data, out_size, int32_data);
+ value_attr =
+ mlir::DenseElementsAttr::get(output_type, llvm::ArrayRef(int32_data));
+ break;
+ }
+ case DType_INT48: {
+ std::vector<int64_t> int48_data;
+ TosaSerializationHandler::ConvertU8toI48(data, out_size, int48_data);
+ std::vector<mlir::APInt> apint_data;
+ for (const auto v : int48_data) {
+ mlir::APInt apint_value(48, static_cast<uint64_t>(v),
+ /* isSigned = */ false);
+ apint_data.push_back(apint_value);
+ }
+ value_attr =
+ mlir::DenseElementsAttr::get(output_type, llvm::ArrayRef(apint_data));
+ break;
+ }
+ case DType_BOOL: {
+ std::vector<bool> bool_data;
+ TosaSerializationHandler::ConvertU8toBool(data, out_size, bool_data);
+ llvm::SmallVector<bool> bool_values(bool_data.begin(), bool_data.end());
+ value_attr = mlir::DenseElementsAttr::get(output_type, bool_values);
+ break;
+ }
+ default:
+ llvm::errs() << "ERROR: " << get_string(op)
+ << " contains unsupported element type\n";
+ return std::vector<mlir::Value>();
+ }
+ mlir::Operation *mlir_op =
+ op_builder->create<mlir::tosa::ConstOp>(loc, output_type, value_attr);
+ block->push_back(mlir_op);
+ return std::vector<mlir::Value>({mlir_op->getResult(0)});
+}
+
+template <class MLIR_OP>
+std::vector<mlir::Value>
+TosaMlirOperatorBuilder::BuildConvOp(TosaSerializationOperator *op) const {
+ mlir::Value input0_val = tensor_map->at(op->GetInputTensorNames()[0]);
+ mlir::Value input1_val = tensor_map->at(op->GetInputTensorNames()[1]);
+ mlir::Value input2_val = tensor_map->at(op->GetInputTensorNames()[2]);
+ mlir::RankedTensorType output_type =
+ tensor_type_map->at(op->GetOutputTensorNames()[0]);
+
+ assert(op->GetAttributeType() ==
+ Attribute_ConvAttribute); // double check attribute type
+ TosaConvAttribute *attr =
+ static_cast<TosaConvAttribute *>(op->GetAttribute());
+ mlir::DenseI64ArrayAttr pad = BuildDenseI64ArrayAttr(op_builder, attr->pad());
+ mlir::DenseI64ArrayAttr stride =
+ BuildDenseI64ArrayAttr(op_builder, attr->stride());
+ mlir::DenseI64ArrayAttr dilation =
+ BuildDenseI64ArrayAttr(op_builder, attr->dilation());
+ auto input_zp = attr->input_zp();
+ auto weight_zp = attr->weight_zp();
+
+ // quantizationattr is required for quantized type, and not allowed for float
+ // type
+ mlir::Operation *mlir_op;
+ if (output_type.getElementType().isa<mlir::FloatType>()) {
+ assert(input_zp == 0 && weight_zp == 0);
+ mlir_op =
+ op_builder->create<MLIR_OP>(loc, output_type, input0_val, input1_val,
+ input2_val, pad, stride, dilation);
+ } else {
+ auto quant = op_builder->getAttr<mlir::tosa::ConvOpQuantizationAttr>(
+ input_zp, weight_zp);
+ mlir_op =
+ op_builder->create<MLIR_OP>(loc, output_type, input0_val, input1_val,
+ input2_val, pad, stride, dilation, quant);
+ }
+ block->push_back(mlir_op);
+ return std::vector<mlir::Value>({mlir_op->getResult(0)});
+}
+
+#define BUILD_OP_CONV(MLIR_OP_NAME, SCHEMA_OP_NAME) \
+ template <> \
+ std::vector<mlir::Value> \
+ TosaMlirOperatorBuilder::build<Op_##SCHEMA_OP_NAME>( \
+ TosaSerializationOperator * op) const { \
+ return BuildConvOp<mlir::tosa::MLIR_OP_NAME##Op>(op); \
+ }
+
+BUILD_OP_CONV(Conv2D, CONV2D)
+BUILD_OP_CONV(Conv3D, CONV3D)
+BUILD_OP_CONV(DepthwiseConv2D, DEPTHWISE_CONV2D)
+
+template <>
+std::vector<mlir::Value> TosaMlirOperatorBuilder::build<Op_TRANSPOSE_CONV2D>(TosaSerializationOperator* op) const
+{
+ mlir::Value input0_val = tensor_map->at(op->GetInputTensorNames()[0]);
+ mlir::Value input1_val = tensor_map->at(op->GetInputTensorNames()[1]);
+ mlir::Value input2_val = tensor_map->at(op->GetInputTensorNames()[2]);
+ mlir::RankedTensorType output_type =
+ tensor_type_map->at(op->GetOutputTensorNames()[0]);
+
+ assert(op->GetAttributeType() ==
+ Attribute_TransposeConvAttribute); // double check attribute type
+ TosaTransposeConvAttribute *attr =
+ static_cast<TosaTransposeConvAttribute *>(op->GetAttribute());
+ mlir::DenseI64ArrayAttr out_pad =
+ BuildDenseI64ArrayAttr(op_builder, attr->out_pad());
+ mlir::DenseI64ArrayAttr stride =
+ BuildDenseI64ArrayAttr(op_builder, attr->stride());
+ mlir::DenseI64ArrayAttr output_shape =
+ BuildDenseI64ArrayAttr(op_builder, attr->output_shape());
+ auto input_zp = attr->input_zp();
+ auto weight_zp = attr->weight_zp();
+
+ // quantizationattr is required for quantized type, and not allowed for float
+ // type
+ mlir::Operation *mlir_op;
+ if (output_type.getElementType().isa<mlir::FloatType>()) {
+ assert(input_zp == 0 && weight_zp == 0);
+ mlir_op = op_builder->create<mlir::tosa::TransposeConv2DOp>(
+ loc, output_type, input0_val, input1_val, input2_val, out_pad, stride,
+ output_shape);
+ } else {
+ auto quant = op_builder->getAttr<mlir::tosa::ConvOpQuantizationAttr>(
+ input_zp, weight_zp);
+ mlir_op = op_builder->create<mlir::tosa::TransposeConv2DOp>(
+ loc, output_type, input0_val, input1_val, input2_val, out_pad, stride,
+ output_shape, quant);
+ }
+ block->push_back(mlir_op);
+ return std::vector<mlir::Value>({mlir_op->getResult(0)});
+}
+
+template <>
+std::vector<mlir::Value> TosaMlirOperatorBuilder::build<Op_FULLY_CONNECTED>(TosaSerializationOperator* op) const
+{
+ mlir::Value input0_val = tensor_map->at(op->GetInputTensorNames()[0]);
+ mlir::Value input1_val = tensor_map->at(op->GetInputTensorNames()[1]);
+ mlir::Value input2_val = tensor_map->at(op->GetInputTensorNames()[2]);
+ mlir::RankedTensorType output_type =
+ tensor_type_map->at(op->GetOutputTensorNames()[0]);
+
+ assert(op->GetAttributeType() ==
+ Attribute_FullyConnectedAttribute); // double check attribute type
+ TosaFullyConnectedAttribute *attr =
+ static_cast<TosaFullyConnectedAttribute *>(op->GetAttribute());
+ auto input_zp = attr->input_zp();
+ auto weight_zp = attr->weight_zp();
+
+ // quantizationattr is required for quantized type, and not allowed for float
+ // type
+ mlir::Operation *mlir_op;
+ if (output_type.getElementType().isa<mlir::FloatType>()) {
+ assert(input_zp == 0 && weight_zp == 0);
+ mlir_op = op_builder->create<mlir::tosa::FullyConnectedOp>(
+ loc, output_type, input0_val, input1_val, input2_val);
+ } else {
+ auto quant = op_builder->getAttr<mlir::tosa::ConvOpQuantizationAttr>(
+ input_zp, weight_zp);
+ mlir_op = op_builder->create<mlir::tosa::FullyConnectedOp>(
+ loc, output_type, input0_val, input1_val, input2_val, quant);
+ }
+ block->push_back(mlir_op);
+ return std::vector<mlir::Value>({mlir_op->getResult(0)});
+}
+
+template <>
+std::vector<mlir::Value> TosaMlirOperatorBuilder::build<Op_MATMUL>(TosaSerializationOperator* op) const
+{
+ mlir::Value input0_val = tensor_map->at(op->GetInputTensorNames()[0]);
+ mlir::Value input1_val = tensor_map->at(op->GetInputTensorNames()[1]);
+ mlir::RankedTensorType output_type =
+ tensor_type_map->at(op->GetOutputTensorNames()[0]);
+
+ assert(op->GetAttributeType() ==
+ Attribute_MatMulAttribute); // double check attribute type
+ TosaMatMulAttribute *attr =
+ static_cast<TosaMatMulAttribute *>(op->GetAttribute());
+ auto A_zp = attr->a_zp();
+ auto B_zp = attr->b_zp();
+
+ mlir::Operation *mlir_op;
+ if (A_zp == 0 && B_zp == 0) {
+ mlir_op = op_builder->create<mlir::tosa::MatMulOp>(loc, output_type,
+ input0_val, input1_val);
+ } else {
+ auto quant =
+ op_builder->getAttr<mlir::tosa::MatMulOpQuantizationAttr>(A_zp, B_zp);
+ mlir_op = op_builder->create<mlir::tosa::MatMulOp>(
+ loc, output_type, input0_val, input1_val, quant);
+ }
+ block->push_back(mlir_op);
+ return std::vector<mlir::Value>({mlir_op->getResult(0)});
+}
+
+template <>
+std::vector<mlir::Value> TosaMlirOperatorBuilder::build<Op_SELECT>(TosaSerializationOperator* op) const
+{
+ mlir::Value input0_val = tensor_map->at(op->GetInputTensorNames()[0]);
+ mlir::Value input1_val = tensor_map->at(op->GetInputTensorNames()[1]);
+ mlir::Value input2_val = tensor_map->at(op->GetInputTensorNames()[2]);
+ mlir::RankedTensorType output_type =
+ tensor_type_map->at(op->GetOutputTensorNames()[0]);
+
+ assert(op->GetAttributeType() ==
+ Attribute_NONE); // double check that there is no attribute
+
+ mlir::Operation *mlir_op = op_builder->create<mlir::tosa::SelectOp>(
+ loc, output_type, input0_val, input1_val, input2_val);
+ block->push_back(mlir_op);
+ return std::vector<mlir::Value>({mlir_op->getResult(0)});
+}
+
+template <>
+std::vector<mlir::Value> TosaMlirOperatorBuilder::build<Op_CLAMP>(TosaSerializationOperator* op) const
+{
+ mlir::Value input_val = tensor_map->at(op->GetInputTensorNames()[0]);
+ mlir::RankedTensorType output_type =
+ tensor_type_map->at(op->GetOutputTensorNames()[0]);
+
+ assert(op->GetAttributeType() ==
+ Attribute_ClampAttribute); // double check attribute type
+ TosaClampAttribute *attr =
+ static_cast<TosaClampAttribute *>(op->GetAttribute());
+
+ auto min_int = op_builder->getI64IntegerAttr(attr->min_int());
+ auto max_int = op_builder->getI64IntegerAttr(attr->max_int());
+ auto min_fp = op_builder->getF32FloatAttr(attr->min_fp());
+ auto max_fp = op_builder->getF32FloatAttr(attr->max_fp());
+
+ mlir::Operation *mlir_op = op_builder->create<mlir::tosa::ClampOp>(
+ loc, output_type, input_val, min_int, max_int, min_fp, max_fp);
+ block->push_back(mlir_op);
+ return std::vector<mlir::Value>({mlir_op->getResult(0)});
+}
+
+// ArgMax has single input, and single I64 axis attribute
+BUILD_OP_REDUCTION(ArgMax, ARGMAX)
+
+template <>
+std::vector<mlir::Value> TosaMlirOperatorBuilder::build<Op_CONCAT>(TosaSerializationOperator* op) const
+{
+  // CONCAT is variadic: gather every input value, then join along `axis`.
+  assert(op->GetAttributeType() == Attribute_AxisAttribute);
+  auto *attr = static_cast<TosaAxisAttribute *>(op->GetAttribute());
+
+  llvm::SmallVector<mlir::Value> operands;
+  for (const auto &name : op->GetInputTensorNames()) {
+    operands.push_back(tensor_map->at(name));
+  }
+
+  mlir::RankedTensorType result_type =
+      tensor_type_map->at(op->GetOutputTensorNames()[0]);
+  auto axis_attr = op_builder->getI64IntegerAttr(attr->axis());
+
+  mlir::Operation *concat_op = op_builder->create<mlir::tosa::ConcatOp>(
+      loc, result_type, operands, axis_attr);
+  block->push_back(concat_op);
+  return {concat_op->getResult(0)};
+}
+
+template <>
+std::vector<mlir::Value> TosaMlirOperatorBuilder::build<Op_NEGATE>(TosaSerializationOperator* op) const
+{
+  // NEGATE optionally carries zero points; the quantization attribute is
+  // only attached when either zero point is non-zero.
+  assert(op->GetAttributeType() == Attribute_NegateAttribute);
+  auto *attr = static_cast<TosaNegateAttribute *>(op->GetAttribute());
+
+  mlir::Value operand = tensor_map->at(op->GetInputTensorNames()[0]);
+  mlir::RankedTensorType result_type =
+      tensor_type_map->at(op->GetOutputTensorNames()[0]);
+
+  const auto in_zp = attr->input1_zp();
+  const auto out_zp = attr->output_zp();
+
+  mlir::Operation *negate_op;
+  if (in_zp != 0 || out_zp != 0) {
+    auto quant = op_builder->getAttr<mlir::tosa::UnaryOpQuantizationAttr>(
+        in_zp, out_zp);
+    negate_op = op_builder->create<mlir::tosa::NegateOp>(loc, result_type,
+                                                         operand, quant);
+  } else {
+    negate_op =
+        op_builder->create<mlir::tosa::NegateOp>(loc, result_type, operand);
+  }
+
+  block->push_back(negate_op);
+  return {negate_op->getResult(0)};
+}
+
+template <>
+std::vector<mlir::Value> TosaMlirOperatorBuilder::build<Op_RESHAPE>(TosaSerializationOperator* op) const
+{
+  // RESHAPE carries the target shape in its attribute.
+  assert(op->GetAttributeType() == Attribute_ReshapeAttribute);
+  auto *attr = static_cast<TosaReshapeAttribute *>(op->GetAttribute());
+
+  mlir::Value operand = tensor_map->at(op->GetInputTensorNames()[0]);
+  mlir::RankedTensorType result_type =
+      tensor_type_map->at(op->GetOutputTensorNames()[0]);
+  mlir::DenseI64ArrayAttr shape_attr =
+      BuildDenseI64ArrayAttr(op_builder, attr->new_shape());
+
+  mlir::Operation *reshape_op = op_builder->create<mlir::tosa::ReshapeOp>(
+      loc, result_type, operand, shape_attr);
+  block->push_back(reshape_op);
+  return {reshape_op->getResult(0)};
+}
+
+template <>
+std::vector<mlir::Value> TosaMlirOperatorBuilder::build<Op_PAD>(TosaSerializationOperator* op) const
+{
+  // Builds a tosa.pad op. The serialized padding amounts become a tosa.const
+  // I32 tensor operand; a non-zero pad constant (integer or floating-point)
+  // becomes a second tosa.const operand.
+  mlir::Value input_val = tensor_map->at(op->GetInputTensorNames()[0]);
+  mlir::RankedTensorType output_type =
+      tensor_type_map->at(op->GetOutputTensorNames()[0]);
+
+  assert(op->GetAttributeType() ==
+         Attribute_PadAttribute); // double check attribute type
+  TosaPadAttribute *attr = static_cast<TosaPadAttribute *>(op->GetAttribute());
+
+  // Materialize the padding list as a constant I32 tensor of shape
+  // {padding.size()}.
+  auto padding_attr = BuildDenseI32ElementsAttr(op_builder, attr->padding());
+  auto padding_type = mlir::RankedTensorType::get({attr->padding().size()},
+                                                  op_builder->getI32Type());
+  mlir::Operation *mlir_const_op =
+      op_builder->create<mlir::tosa::ConstOp>(loc, padding_type, padding_attr);
+  block->push_back(mlir_const_op);
+  auto padding_value = mlir_const_op->getResult(0);
+  auto pad_const_int = attr->pad_const_int();
+  auto pad_const_fp = attr->pad_const_fp();
+  // todo: int input_zp = attr->pad_input_zp();
+
+  mlir::Operation *mlir_op;
+  if (pad_const_int == 0 && pad_const_fp == 0.0f) {
+    // no pad_const input
+    // NOTE(review): an explicitly-serialized pad constant of zero is
+    // indistinguishable from "no pad constant" here; both take this branch.
+    mlir_op = op_builder->create<mlir::tosa::PadOp>(loc, output_type, input_val,
+                                                    padding_value);
+  } else {
+    // create a const value for pad_const input
+    // Assumes the int and fp pad constants are mutually exclusive; if both
+    // were non-zero only the integer one would be emitted -- TODO confirm.
+    mlir::Value pad_const_value;
+    if (pad_const_int != 0) {
+      auto pad_const_int_type =
+          mlir::RankedTensorType::get({1}, op_builder->getI32Type());
+      auto pad_const_int_attr =
+          mlir::DenseElementsAttr::get(pad_const_int_type, {pad_const_int});
+      mlir::Operation *pad_const_int_op =
+          op_builder->create<mlir::tosa::ConstOp>(loc, pad_const_int_type,
+                                                  pad_const_int_attr);
+      block->push_back(pad_const_int_op);
+      pad_const_value = pad_const_int_op->getResult(0);
+    } else if (pad_const_fp != 0) {
+      auto pad_const_fp_type =
+          mlir::RankedTensorType::get({1}, op_builder->getF32Type());
+      auto pad_const_fp_attr =
+          mlir::DenseElementsAttr::get(pad_const_fp_type, {pad_const_fp});
+      mlir::Operation *pad_const_fp_op =
+          op_builder->create<mlir::tosa::ConstOp>(loc, pad_const_fp_type,
+                                                  pad_const_fp_attr);
+      block->push_back(pad_const_fp_op);
+      pad_const_value = pad_const_fp_op->getResult(0);
+    }
+    mlir_op = op_builder->create<mlir::tosa::PadOp>(
+        loc, output_type, input_val, padding_value, pad_const_value);
+  }
+  block->push_back(mlir_op);
+  return std::vector<mlir::Value>({mlir_op->getResult(0)});
+}
+
+template <>
+std::vector<mlir::Value> TosaMlirOperatorBuilder::build<Op_TRANSPOSE>(TosaSerializationOperator* op) const
+{
+  // The serialized permutation becomes a tosa.const I32 tensor of shape
+  // {perms.size()} that feeds the transpose as its second operand.
+  assert(op->GetAttributeType() == Attribute_TransposeAttribute);
+  auto *attr = static_cast<TosaTransposeAttribute *>(op->GetAttribute());
+
+  mlir::Value operand = tensor_map->at(op->GetInputTensorNames()[0]);
+  mlir::RankedTensorType result_type =
+      tensor_type_map->at(op->GetOutputTensorNames()[0]);
+
+  const auto perms_values = attr->perms();
+  auto perms_type = mlir::RankedTensorType::get({perms_values.size()},
+                                                op_builder->getI32Type());
+  mlir::Operation *perms_const_op = op_builder->create<mlir::tosa::ConstOp>(
+      loc, perms_type, BuildDenseI32ElementsAttr(op_builder, perms_values));
+
+  mlir::Operation *transpose_op = op_builder->create<mlir::tosa::TransposeOp>(
+      loc, result_type, operand, perms_const_op->getResult(0));
+
+  // The constant must precede its use in the block.
+  block->push_back(perms_const_op);
+  block->push_back(transpose_op);
+  return {transpose_op->getResult(0)};
+}
+
+template <>
+std::vector<mlir::Value> TosaMlirOperatorBuilder::build<Op_SLICE>(TosaSerializationOperator* op) const
+{
+  // SLICE carries start/size index lists in its attribute.
+  assert(op->GetAttributeType() == Attribute_SliceAttribute);
+  auto *attr = static_cast<TosaSliceAttribute *>(op->GetAttribute());
+
+  mlir::Value operand = tensor_map->at(op->GetInputTensorNames()[0]);
+  mlir::RankedTensorType result_type =
+      tensor_type_map->at(op->GetOutputTensorNames()[0]);
+
+  mlir::Operation *slice_op = op_builder->create<mlir::tosa::SliceOp>(
+      loc, result_type, operand,
+      BuildDenseI64ArrayAttr(op_builder, attr->start()),
+      BuildDenseI64ArrayAttr(op_builder, attr->size()));
+  block->push_back(slice_op);
+  return {slice_op->getResult(0)};
+}
+
+template <>
+std::vector<mlir::Value> TosaMlirOperatorBuilder::build<Op_TILE>(TosaSerializationOperator* op) const
+{
+  // TILE carries per-dimension repetition counts in its attribute.
+  assert(op->GetAttributeType() == Attribute_TileAttribute);
+  auto *attr = static_cast<TosaTileAttribute *>(op->GetAttribute());
+
+  mlir::Value operand = tensor_map->at(op->GetInputTensorNames()[0]);
+  mlir::RankedTensorType result_type =
+      tensor_type_map->at(op->GetOutputTensorNames()[0]);
+
+  mlir::Operation *tile_op = op_builder->create<mlir::tosa::TileOp>(
+      loc, result_type, operand,
+      BuildDenseI64ArrayAttr(op_builder, attr->multiples()));
+  block->push_back(tile_op);
+  return {tile_op->getResult(0)};
+}
+
+// Gather is a binary op: two inputs, no attribute, so the generic
+// element-wise-binary builder macro applies.
+BUILD_OP_ELEMENTWISE_BINARY(Gather, GATHER)
+
+template <>
+std::vector<mlir::Value> TosaMlirOperatorBuilder::build<Op_SCATTER>(TosaSerializationOperator* op) const
+{
+  // SCATTER takes three inputs and no attribute.
+  assert(op->GetAttributeType() == Attribute_NONE);
+
+  const auto& in_names = op->GetInputTensorNames();
+  mlir::Value operand0 = tensor_map->at(in_names[0]);
+  mlir::Value operand1 = tensor_map->at(in_names[1]);
+  mlir::Value operand2 = tensor_map->at(in_names[2]);
+  mlir::RankedTensorType result_type =
+      tensor_type_map->at(op->GetOutputTensorNames()[0]);
+
+  mlir::Operation *scatter_op = op_builder->create<mlir::tosa::ScatterOp>(
+      loc, result_type, operand0, operand1, operand2);
+  block->push_back(scatter_op);
+  return {scatter_op->getResult(0)};
+}
+
+template <>
+std::vector<mlir::Value> TosaMlirOperatorBuilder::build<Op_RESIZE>(TosaSerializationOperator* op) const
+{
+  // RESIZE carries scale/offset/border index lists plus a mode enum, which
+  // is attached to the MLIR op as a string attribute.
+  assert(op->GetAttributeType() == Attribute_ResizeAttribute);
+  auto *attr = static_cast<TosaResizeAttribute *>(op->GetAttribute());
+
+  mlir::Value operand = tensor_map->at(op->GetInputTensorNames()[0]);
+  mlir::RankedTensorType result_type =
+      tensor_type_map->at(op->GetOutputTensorNames()[0]);
+
+  auto scale_attr = BuildDenseI64ArrayAttr(op_builder, attr->scale());
+  auto offset_attr = BuildDenseI64ArrayAttr(op_builder, attr->offset());
+  auto border_attr = BuildDenseI64ArrayAttr(op_builder, attr->border());
+  auto mode_attr = op_builder->getStringAttr(ResizeEnum2Str(attr->mode()));
+
+  mlir::Operation *resize_op = op_builder->create<mlir::tosa::ResizeOp>(
+      loc, result_type, operand, scale_attr, offset_attr, border_attr,
+      mode_attr);
+  block->push_back(resize_op);
+  return {resize_op->getResult(0)};
+}
+
+// Reverse has single input, and single I64 axis attribute -- the same shape
+// as the reduction builders, so the reduction macro is reused here.
+BUILD_OP_REDUCTION(Reverse, REVERSE)
+
+template <>
+std::vector<mlir::Value> TosaMlirOperatorBuilder::build<Op_MUL>(TosaSerializationOperator* op) const
+{
+  // MUL is element-wise binary with an extra I32 shift attribute.
+  assert(op->GetAttributeType() == Attribute_MulAttribute);
+  auto *attr = static_cast<TosaMulAttribute *>(op->GetAttribute());
+
+  const auto& in_names = op->GetInputTensorNames();
+  mlir::Value lhs = tensor_map->at(in_names[0]);
+  mlir::Value rhs = tensor_map->at(in_names[1]);
+  mlir::RankedTensorType result_type =
+      tensor_type_map->at(op->GetOutputTensorNames()[0]);
+  auto shift_attr = op_builder->getI32IntegerAttr(attr->shift());
+
+  mlir::Operation *mul_op = op_builder->create<mlir::tosa::MulOp>(
+      loc, result_type, lhs, rhs, shift_attr);
+  block->push_back(mul_op);
+  return {mul_op->getResult(0)};
+}
+
+template <>
+std::vector<mlir::Value> TosaMlirOperatorBuilder::build<Op_ARITHMETIC_RIGHT_SHIFT>(TosaSerializationOperator* op) const
+{
+  // Element-wise binary shift with a boolean `round` attribute.
+  assert(op->GetAttributeType() == Attribute_ArithmeticRightShiftAttribute);
+  auto *attr =
+      static_cast<TosaArithmeticRightShiftAttribute *>(op->GetAttribute());
+
+  const auto& in_names = op->GetInputTensorNames();
+  mlir::Value lhs = tensor_map->at(in_names[0]);
+  mlir::Value rhs = tensor_map->at(in_names[1]);
+  mlir::RankedTensorType result_type =
+      tensor_type_map->at(op->GetOutputTensorNames()[0]);
+
+  mlir::Operation *shift_op =
+      op_builder->create<mlir::tosa::ArithmeticRightShiftOp>(
+          loc, result_type, lhs, rhs, op_builder->getBoolAttr(attr->round()));
+  block->push_back(shift_op);
+  return {shift_op->getResult(0)};
+}
+
+template <>
+std::vector<mlir::Value> TosaMlirOperatorBuilder::build<Op_TABLE>(TosaSerializationOperator* op) const
+{
+  // The serialized lookup table becomes a tosa.const I16 tensor feeding the
+  // table op as its second operand.
+  assert(op->GetAttributeType() == Attribute_TableAttribute);
+  auto *attr = static_cast<TosaTableAttribute *>(op->GetAttribute());
+
+  mlir::Value operand = tensor_map->at(op->GetInputTensorNames()[0]);
+  mlir::RankedTensorType result_type =
+      tensor_type_map->at(op->GetOutputTensorNames()[0]);
+
+  const auto table_values = attr->table();
+  auto table_type = mlir::RankedTensorType::get({table_values.size()},
+                                                op_builder->getI16Type());
+  mlir::Operation *table_const_op = op_builder->create<mlir::tosa::ConstOp>(
+      loc, table_type, BuildDenseI16ElementsAttr(op_builder, table_values));
+
+  mlir::Operation *table_op = op_builder->create<mlir::tosa::TableOp>(
+      loc, result_type, operand, table_const_op->getResult(0));
+
+  // The constant must precede its use in the block.
+  block->push_back(table_const_op);
+  block->push_back(table_op);
+  return {table_op->getResult(0)};
+}
+
+template <>
+std::vector<mlir::Value> TosaMlirOperatorBuilder::build<Op_RESCALE>(TosaSerializationOperator* op) const
+{
+  // RESCALE carries zero points, multiplier/shift arrays and three boolean
+  // flags; all are forwarded onto the MLIR op.
+  assert(op->GetAttributeType() == Attribute_RescaleAttribute);
+  auto *attr = static_cast<TosaRescaleAttribute *>(op->GetAttribute());
+
+  mlir::Value operand = tensor_map->at(op->GetInputTensorNames()[0]);
+  mlir::RankedTensorType result_type =
+      tensor_type_map->at(op->GetOutputTensorNames()[0]);
+
+  mlir::Operation *rescale_op = op_builder->create<mlir::tosa::RescaleOp>(
+      loc, result_type, operand,
+      op_builder->getI32IntegerAttr(attr->input_zp()),
+      op_builder->getI32IntegerAttr(attr->output_zp()),
+      BuildDenseI32ArrayAttr(op_builder, attr->multiplier()),
+      BuildDenseI32ArrayAttr(op_builder, attr->shift()),
+      op_builder->getBoolAttr(attr->scale32()),
+      op_builder->getBoolAttr(attr->double_round()),
+      op_builder->getBoolAttr(attr->per_channel()));
+  block->push_back(rescale_op);
+  return {rescale_op->getResult(0)};
+}
+
+template <>
+std::vector<mlir::Value> TosaMlirOperatorBuilder::build<Op_CUSTOM>(TosaSerializationOperator* op) const
+{
+  // Builds a tosa.custom op: identifier/config/implementation_attrs are
+  // carried as string attributes on the MLIR op.
+  mlir::Value input_val = tensor_map->at(op->GetInputTensorNames()[0]);
+  mlir::RankedTensorType output_type =
+      tensor_type_map->at(op->GetOutputTensorNames()[0]);
+
+  assert(op->GetAttributeType() ==
+         Attribute_CustomAttribute); // double check attribute type
+  TosaCustomAttribute *attr =
+      static_cast<TosaCustomAttribute *>(op->GetAttribute());
+
+  auto identifier = op_builder->getStringAttr(attr->identifier());
+  auto config = op_builder->getStringAttr(attr->config());
+  // Copy the raw implementation bytes verbatim. Constructing the string
+  // from the byte range fixes the previous resize(size() + 1), which left a
+  // spurious trailing '\0' embedded in the attribute value.
+  const auto &impl_bytes = attr->implementation_attrs();
+  std::string impl_str(impl_bytes.begin(), impl_bytes.end());
+  auto impl = op_builder->getStringAttr(impl_str);
+
+  mlir::Operation *mlir_op = op_builder->create<mlir::tosa::CustomOp>(
+      loc, output_type, identifier, config, impl, input_val);
+  block->push_back(mlir_op);
+  return std::vector<mlir::Value>({mlir_op->getResult(0)});
+}
+
+// TosaSerializationBasicBlock
+// template <>
+// std::vector<mlir::Value>
+// TosaMlirOperatorBuilder::build<COND_IF>(TosaSerializationOperator* op) const
+//{
+// // todo: mlir::tosa::IfOp
+// return {};
+//}
+
+// TosaSerializationBasicBlock
+// template <>
+// std::vector<mlir::Value>
+// TosaMlirOperatorBuilder::build<WHILE_LOOP>(TosaSerializationOperator* op)
+// const
+//{
+// // todo: mlir::tosa::WhileOp
+// return {};
+//}
+
+template <>
+std::vector<mlir::Value>
+TosaMlirOperatorBuilder::build<Op_RFFT2D>(TosaSerializationOperator *op) const {
+  // RFFT2D has no attribute and produces two results, one per output tensor.
+  assert(op->GetAttributeType() == Attribute_NONE);
+
+  mlir::Value operand = tensor_map->at(op->GetInputTensorNames()[0]);
+  const auto& out_names = op->GetOutputTensorNames();
+  mlir::RankedTensorType result0_type = tensor_type_map->at(out_names[0]);
+  mlir::RankedTensorType result1_type = tensor_type_map->at(out_names[1]);
+
+  mlir::Operation *fft_op = op_builder->create<mlir::tosa::RFFT2dOp>(
+      loc, result0_type, result1_type, operand);
+  block->push_back(fft_op);
+  return {fft_op->getResult(0), fft_op->getResult(1)};
+}
+
+class TosaMlirRegionBuilder {
+public:
+  // Builds MLIR blocks inside `_region` from the serialized `_ser_region`.
+  // All pointers are non-owning and must outlive this builder.
+  // Initializer list ordered to match member declaration order (members are
+  // initialized in declaration order regardless; this silences -Wreorder).
+  TosaMlirRegionBuilder(TosaSerializationRegion* _ser_region,
+                        TosaSerializationHandler* _tsh,
+                        mlir::Region* _region,
+                        mlir::OpBuilder* _op_builder,
+                        mlir::Location _loc)
+      : region(_region), ser_region(_ser_region), tsh(_tsh),
+        op_builder(_op_builder), loc(_loc) {}
+
+  // Deserializes every basic block of the region; `return_values` receives
+  // the mlir::Value of each block output tensor.
+  mlir::LogicalResult BuildAllBlocksInRegion(std::vector<mlir::Value>& return_values);
+
+  mlir::OpBuilder* GetOpBuilder() { return op_builder; }
+  mlir::Location GetLocation() { return loc; }
+
+private:
+  mlir::Region* region;
+  TosaSerializationRegion* ser_region;
+  TosaSerializationHandler* tsh;
+  mlir::OpBuilder* op_builder;
+  mlir::Location loc;
+  // Non-owning; currently unused -- see BuildAllBlocksInRegion.
+  std::vector<TosaMlirBlockBuilder*> block_builders;
+};
+
+class TosaMlirBlockBuilder {
+public:
+  // Builds the contents of one mlir::Block from its serialized counterpart.
+  // All pointers are non-owning and must outlive this builder.
+  TosaMlirBlockBuilder(TosaSerializationBasicBlock* _ser_block,
+                       TosaMlirRegionBuilder* _region_builder,
+                       mlir::Block* _block)
+      : ser_block(_ser_block), region_builder(_region_builder), block(_block) {}
+
+
+  // Deserializes every operator of the block into `block`; `return_values`
+  // receives the mlir::Value of each block output tensor.
+  mlir::LogicalResult
+  BuildAllOpsInBlock(std::vector<mlir::Value>& return_values);
+
+  // Builder and location are shared with (and owned by) the region builder.
+  mlir::OpBuilder* GetOpBuilder() { return region_builder->GetOpBuilder(); }
+  mlir::Location GetLocation() { return region_builder->GetLocation(); }
+
+private:
+  TosaSerializationBasicBlock* ser_block;
+  TosaMlirRegionBuilder* region_builder;
+  mlir::Block* block;
+  // Tensor name -> deserialized mlir::Value / pre-computed tensor type,
+  // populated by BuildAllOpsInBlock.
+  std::unordered_map<std::string, mlir::Value> tensor_map;
+  std::unordered_map<std::string, mlir::RankedTensorType> tensor_type_map;
+};
+
+mlir::LogicalResult TosaMlirBlockBuilder::BuildAllOpsInBlock(
+    std::vector<mlir::Value> &return_values) {
+  // Deserializes every operator of ser_block into `block` in data-flow
+  // order: a worklist is seeded with the block arguments and the operators
+  // that take no inputs, and an operator is emitted only once all of its
+  // input tensors have been materialized. `return_values` receives the
+  // mlir::Value of each block output tensor.
+  block->clear();
+  auto loc = GetLocation();
+  auto op_builder = GetOpBuilder();
+
+  std::unordered_map<std::string, std::vector<TosaSerializationOperator*>> consumer_map;
+  std::unordered_map<std::string, bool> tensor_built;
+  std::unordered_map<TosaSerializationOperator*, bool> operator_built;
+  std::queue<TosaSerializationOperator*> operator_queue;
+
+  TosaMlirOperatorBuilder tosa_op_builder(op_builder, ser_block, block, loc,
+                                          &tensor_map, &tensor_type_map);
+
+  // Pre-compute the MLIR tensor type of every tensor in the block.
+  for (auto ts : ser_block->GetTensors()) {
+    mlir::RankedTensorType type;
+    if (BuildTensorType(op_builder, ts, type).failed()) {
+      return mlir::failure();
+    }
+    const auto& ts_name = ts->GetName();
+    tensor_type_map[ts_name] = type;
+    tensor_built[ts_name] = false;
+  }
+
+  // Record, for every tensor, which operators consume it.
+  // (const auto& avoids copying each tensor-name string per iteration.)
+  for (auto op : ser_block->GetOperators()) {
+    operator_built[op] = false;
+    for (const auto& ts_name : op->GetInputTensorNames()) {
+      consumer_map[ts_name].push_back(op);
+    }
+  }
+
+  // Update operator_queue if a consumer of tensor_name has all of its inputs already built.
+  // The name is taken by const reference: it is only read, never stored.
+  auto queue_ready_consumers = [&](const std::string& tensor_name) {
+    for (auto consumer_op : consumer_map[tensor_name]) {
+      // Sanity check operator hasn't been built
+      if (operator_built[consumer_op]) {
+        llvm::errs() << "ERROR: " << tosa_op_builder.get_string(consumer_op)
+                     << " is already built before its input is built\n";
+        assert(0);
+      }
+      bool all_inputs_ready = true;
+      for (const auto& input_name : consumer_op->GetInputTensorNames()) {
+        if (!tensor_built[input_name]) {
+          all_inputs_ready = false;
+          break;
+        }
+      }
+      if (all_inputs_ready) {
+        operator_queue.push(consumer_op);
+      }
+    }
+  };
+
+  // Initialize tensor_map/tensor_built/operator_queue based on block input arguments
+  for (const std::string& block_input_name : ser_block->GetInputs()) {
+    auto type = tensor_type_map[block_input_name];
+    auto input_value = block->addArgument(type, loc);
+    if (tensor_map.count(block_input_name)) {
+      llvm::errs() << "ERROR: block input tensor " << block_input_name << " already exists\n";
+      return mlir::failure();
+    }
+    tensor_built[block_input_name] = true;
+    tensor_map[block_input_name] = input_value;
+    queue_ready_consumers(block_input_name);
+  }
+
+  // add all operators with 0 inputs (e.g., constant operators) to
+  // operator_queue
+  for (auto op : ser_block->GetOperators()) {
+    if (op->GetInputTensorNames().empty()) {
+      operator_queue.push(op);
+    }
+  }
+
+  while (!operator_queue.empty()) {
+    TosaSerializationOperator* op = operator_queue.front();
+    operator_queue.pop();
+
+    // skip if operator has been built
+    if (operator_built[op]) {
+      // this happens when same input appears twice or more in operator, eg, concat(%0, %0)
+      continue;
+    }
+    operator_built[op] = true;
+
+    // Dispatch on the opcode to the matching templated builder; the chain of
+    // else-ifs is generated from schema_operator.def.
+    std::vector<mlir::Value> output_values;
+    if (false) {
+    }
+#define DEF_SCHEMA_OPERATOR(SCHEMA_OP_NAME)                                    \
+  else if (op->GetOp() == Op_##SCHEMA_OP_NAME) {                               \
+    output_values = tosa_op_builder.build<Op_##SCHEMA_OP_NAME>(op);            \
+  }
+#include "schema_operator.def"
+#undef DEF_SCHEMA_OPERATOR
+    else {
+      llvm::errs() << "ERROR: unsupported opcode=" << EnumNamesOp()[op->GetOp()] << "\n";
+      return mlir::failure();
+    }
+
+    // Sanity check if number of built mlir::Value is expected
+    if (op->GetOutputTensorNames().size() != output_values.size()) {
+      llvm::errs() << "ERROR: number of built mlir::Value is not matching number of operator output tensor\n";
+      return mlir::failure();
+    }
+
+    for (size_t i = 0; i < output_values.size(); i++) {
+      // Sanity check tensor hasn't been built
+      std::string op_output_name = op->GetOutputTensorNames()[i];
+      if (tensor_built[op_output_name]) {
+        llvm::errs() << "ERROR: tensor " << op_output_name << " is already built\n";
+        return mlir::failure();
+      }
+      tensor_map[op_output_name] = output_values[i];
+      tensor_built[op_output_name] = true;
+      queue_ready_consumers(op_output_name);
+    }
+  }
+
+  // Construct return values
+  std::vector<mlir::Value> return_operands;
+  for (const auto& output_name : ser_block->GetOutputs()) {
+    // Sanity check if terminator mlir::Value is built
+    if (!tensor_built[output_name]) {
+      llvm::errs() << "ERROR: terminator mlir::Value " << output_name << " is not built\n";
+      return mlir::failure();
+    }
+    mlir::Value output_value = tensor_map.at(output_name);
+    return_operands.push_back(output_value);
+    return_values.push_back(output_value);
+  }
+  auto terminator_op = op_builder->create<mlir::func::ReturnOp>(loc, return_operands);
+  block->push_back(terminator_op);
+
+  // The worklist above already emits operators in dependency order, so no
+  // separate topological sort is needed. Operators whose inputs never all
+  // materialize are silently skipped; that only surfaces as an error if one
+  // of their outputs is a block output.
+  return mlir::success();
+}
+
+mlir::LogicalResult TosaMlirRegionBuilder::BuildAllBlocksInRegion(
+    std::vector<mlir::Value>& return_values) {
+  // Deserializes each serialized basic block into a fresh mlir::Block owned
+  // by the region. `return_values` accumulates the block output values.
+  for (auto& ser_block : ser_region->GetBlocks()) {
+    auto& block = region->emplaceBlock();
+    TosaMlirBlockBuilder block_builder(ser_block, this, &block);
+    // NOTE(review): the builder used to be retained via
+    // block_builders.push_back(&block_builder), but that stored a pointer to
+    // this per-iteration stack local, leaving the vector full of dangling
+    // pointers. The builder is only needed for the duration of
+    // BuildAllOpsInBlock, so it is no longer retained.
+
+    if (block_builder.BuildAllOpsInBlock(return_values).failed()) {
+      return mlir::failure();
+    }
+
+    if (return_values.empty()) {
+      llvm::errs() << "Warning: graph doesn't have return values\n";
+    }
+  }
+  return mlir::success();
+}
+
+// Replaces the body of `func` with the MLIR deserialized from `tsh`, which
+// must contain exactly one region.
+// NOTE(review): `context` is currently unused by this function.
+mlir::LogicalResult buildTosaMlir(mlir::func::FuncOp& func,
+                                  mlir::MLIRContext& context,
+                                  tosa::TosaSerializationHandler& tsh) {
+
+  mlir::Region* main_region = func.getCallableRegion();
+  if (!main_region) {
+    llvm::errs() << "Invalid MLIR: doesn't have valid \"main\" region\n";
+    return mlir::failure();
+  }
+
+  if (tsh.GetRegions().size() != 1) {
+    llvm::errs() << "Internal Error: TosaSerializationHandler's region list "
+                    "must contain exactly one region\n";
+    return mlir::failure();
+  }
+
+  TosaSerializationRegion* ser_main_region = tsh.GetRegions().front();
+
+  auto loc = func.getLoc();
+  std::vector<mlir::Value> main_returns;
+
+  // NOTE(review): takeBody() is called on the region itself; the visible
+  // effect is that the old blocks are dropped, but confirm this self-call is
+  // intentional rather than a copy/paste slip.
+  main_region->takeBody(*main_region); // empty old func body
+  auto op_builder = mlir::OpBuilder(func.getBody());
+
+  TosaMlirRegionBuilder region_builder(ser_main_region, &tsh, main_region, &op_builder, loc);
+  if (region_builder.BuildAllBlocksInRegion(main_returns).failed()) {
+    return mlir::failure();
+  }
+
+  if (main_returns.empty()) {
+    llvm::errs() << "Warning: graph doesn't have return values\n";
+  }
+
+  return mlir::success();
+}
+
+// Load Tosa Schema into TosaSerializationHandler, required for JSON save/load
+mlir::LogicalResult loadTosaSchema(tosa::TosaSerializationHandler& tsh) {
+  // tosa_deserialize_schema is the pass option holding the schema file path.
+  // std::string::c_str() can never return nullptr, so the previous
+  // null-pointer check was dead code; test the string for emptiness instead.
+  if (tosa_deserialize_schema.empty()) {
+    llvm::errs() << "Flatbuffer schema not defined\n";
+    return mlir::failure();
+  }
+
+  const char *tosa_schema = tosa_deserialize_schema.c_str();
+  if (tsh.LoadFileSchema(tosa_schema)) {
+    llvm::errs() << "Error loading tosa schema file: " << tosa_schema << "\n";
+    return mlir::failure();
+  }
+  return mlir::success();
+}
+
+namespace mlir {
+
+namespace tosa {
+
+namespace {
+
+class TosaDeserialize : public TosaDeserializationPassBase<TosaDeserialize> {
+public:
+  // Replaces the body of the current function with MLIR deserialized from
+  // the flatbuffer file named by the tosa_deserialize_filename pass option.
+  void runOnOperation() final {
+    TosaSerializationHandler tsh;
+    if (tsh.LoadFileTosaFlatbuffer(tosa_deserialize_filename.c_str())) {
+      llvm::errs() << "Fail to load TOSA file " << tosa_deserialize_filename << "\n";
+      return signalPassFailure();
+    }
+
+    auto func_op = getOperation();
+    auto& ctx = getContext();
+
+    if (buildTosaMlir(func_op, ctx, tsh).failed()) {
+      llvm::errs() << "Failed to deserialize flatbuffer " << tosa_deserialize_filename << "\n";
+      return signalPassFailure();
+    }
+  }
+};
+
+class TosaDeserializeJSON
+    : public TosaDeserializationJSONPassBase<TosaDeserializeJSON> {
+public:
+  // Like TosaDeserialize, but reads the JSON form; the flatbuffer schema
+  // has to be loaded before the JSON file can be parsed.
+  void runOnOperation() final {
+    TosaSerializationHandler tsh;
+
+    // must load tosa schema before loading json file
+    if (loadTosaSchema(tsh).failed()) {
+      return signalPassFailure();
+    }
+
+    if (tsh.LoadFileJson(tosa_deserialize_filename.c_str())) {
+      llvm::errs() << "Fail to load TOSA JSON file " << tosa_deserialize_filename << "\n";
+      return signalPassFailure();
+    }
+
+    auto func_op = getOperation();
+    auto& ctx = getContext();
+
+    if (buildTosaMlir(func_op, ctx, tsh).failed()) {
+      llvm::errs() << "Failed to deserialize flatbuffer " << tosa_deserialize_filename << "\n";
+      return signalPassFailure();
+    }
+  }
+};
+
+} // anonymous namespace
+
+// Creates an instance of the TOSA flatbuffer deserialization pass (reads the
+// file named by the tosa_deserialize_filename pass option).
+std::unique_ptr<Pass> createTosaDeserializePass() {
+  return std::make_unique<TosaDeserialize>();
+}
+
+// Creates an instance of the TOSA JSON deserialization pass.
+std::unique_ptr<Pass> createTosaDeserializeJSONPass() {
+  return std::make_unique<TosaDeserializeJSON>();
+}
+
+// Register both passes with the global pass registry so they are available
+// to mlir-opt style drivers.
+static PassRegistration<TosaDeserialize> passDeserialize([] {
+  return createTosaDeserializePass();
+});
+
+static PassRegistration<TosaDeserializeJSON> passDeserializeJSON([] {
+  return createTosaDeserializeJSONPass();
+});
+
+} // namespace tosa
+} // namespace mlir