Diffstat (limited to 'src/TosaDeserialize.cpp')
 -rw-r--r--  src/TosaDeserialize.cpp | 2128
 1 file changed, 2128 insertions, 0 deletions
diff --git a/src/TosaDeserialize.cpp b/src/TosaDeserialize.cpp
new file mode 100644
index 0000000..215d760
--- /dev/null
+++ b/src/TosaDeserialize.cpp
@@ -0,0 +1,2128 @@
+
+// Copyright (c) 2023-2024, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 with LLVM Exceptions
+// (the "License"); you may not use this file except in compliance with
+// the License. You may obtain a copy of the License at
+//
+// https://llvm.org/LICENSE.txt
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// TOSA MLIR deserialize passes
+
+#include "include/DeserializationPasses.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
+#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project
+#include "mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project
+#include "mlir/Dialect/Tosa/IR/TosaOps.h" // from @llvm-project
+#include "mlir/IR/Matchers.h" // from @llvm-project
+#include "mlir/Pass/Pass.h" // from @llvm-project
+#include "mlir/Support/LogicalResult.h" // from @llvm-project
+#include "tosa_serialization_handler.h"
+#include <functional>
+#include <queue>
+#include <unordered_map>
+#include <unordered_set>
+#include <vector>
+
+// The namespace might be confusing here: mlir::tosa:: is defined in MLIR,
+// while tosa:: is defined in the serialization library.
+// TODO: align the namespace
+using namespace tosa;
+
+namespace cl = llvm::cl;
+
+llvm::cl::opt<std::string> tosa_deserialize_filename(
+ "tosa-deserialize-filename", llvm::cl::desc("<tosa flatbuffer filename>"),
+ llvm::cl::init("tosa_dump.tosa"), llvm::cl::value_desc("filename"));
+
+llvm::cl::opt<std::string> tosa_deserialize_schema(
+ "tosa-deserialize-schema", llvm::cl::desc("<tosa flatbuffer schema file>"),
+ llvm::cl::init(""), llvm::cl::value_desc("filename"));
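+// Example (hypothetical invocation; the exact tool and pass flags depend on
+// how the pass from DeserializationPasses.h is registered):
+//   <tosa-opt-tool> --tosa-deserialize-filename=model.tosa \
+//                   --tosa-deserialize-schema=tosa.fbs ...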
+
+const std::string kDefaultExportedName = "tosa_deserialized";
+const std::string kDefaultInputPrefix = "input_";
+const std::string kDefaultOutputPrefix = "output_";
+const std::string kDefaultFBSDescription = "Tosa FBS Converted";
+const std::string kDefaultJSONDescription = "Tosa JSON Converted";
+const std::string kMainFunctionName = "main";
+
+namespace {
+
+// a global map from flatbuffer tensor names to their serialized variable
+// tensors
+std::unordered_map<std::string, TosaSerializationTensor *> variable_tensor_map;
+
+void RegisterVariableTensor(TosaSerializationTensor *ts) {
+ assert(ts->GetVariable());
+ // insert variable tensor ts only if not already present
+ variable_tensor_map.insert({ts->GetName(), ts});
+}
+
+bool IsVariableTensor(const std::string flatbuffer_tensor_name) {
+ return variable_tensor_map.count(flatbuffer_tensor_name);
+}
+
+// return the variable name of the given serialized variable tensor
+const std::string GetVariableTensorName(TosaSerializationTensor *ts) {
+ assert(ts->GetVariable());
+ const auto name = ts->GetVariableName();
+ if (name == "") {
+ // for legacy flatbuffers which may not have variable_name fields
+ return ts->GetName();
+ }
+ return name;
+}
+
+// return the variable name corresponding to flatbuffer_tensor_name
+const std::string
+GetVariableTensorName(const std::string flatbuffer_tensor_name) {
+ if (!IsVariableTensor(flatbuffer_tensor_name)) {
+    llvm::errs() << "ERROR: Variable tensor " << flatbuffer_tensor_name
+                 << " is not found in variable_tensor_map\n";
+ return "";
+ }
+ return GetVariableTensorName(variable_tensor_map[flatbuffer_tensor_name]);
+}
+
+bool IsVariableReadOp(TosaSerializationOperator *op) {
+ return (op->GetOp() == tosa::Op::Op_IDENTITY) &&
+ IsVariableTensor(op->GetInputTensorNames()[0]);
+}
+
+bool IsVariableWriteOp(TosaSerializationOperator *op) {
+ return (op->GetOp() == tosa::Op::Op_IDENTITY) &&
+ IsVariableTensor(op->GetOutputTensorNames()[0]);
+}
+
+// construct tensor type from dtype and shape of TosaSerializationTensor
+mlir::LogicalResult BuildTensorType(mlir::OpBuilder *op_builder,
+ TosaSerializationTensor *ts,
+ mlir::RankedTensorType &type) {
+ mlir::Type element_type;
+ switch (ts->GetDtype()) {
+ case DType_BOOL:
+ element_type = op_builder->getI1Type();
+ break;
+ case DType_UINT8:
+ element_type = op_builder->getIntegerType(8, false);
+ break;
+ case DType_INT4:
+ element_type = op_builder->getI4Type();
+ break;
+ case DType_INT8:
+ element_type = op_builder->getI8Type();
+ break;
+ case DType_INT16:
+ element_type = op_builder->getIntegerType(16);
+ break;
+ case DType_INT32:
+ element_type = op_builder->getI32Type();
+ break;
+ case DType_INT48:
+ element_type = op_builder->getIntegerType(48);
+ break;
+ case DType_FP32:
+ element_type = op_builder->getF32Type();
+ break;
+ case DType_UINT16:
+ element_type = op_builder->getIntegerType(16, false);
+ break;
+ case DType_FP16:
+ element_type = op_builder->getF16Type();
+ break;
+ case DType_BF16:
+ element_type = op_builder->getBF16Type();
+ break;
+ case DType_FP8E4M3:
+ element_type = op_builder->getFloat8E4M3FNType();
+ break;
+ case DType_FP8E5M2:
+ element_type = op_builder->getFloat8E5M2Type();
+ break;
+ case DType_SHAPE:
+    llvm::errs()
+        << "ERROR: Cannot construct RankedTensorType out of tosa.shape type\n";
+ return mlir::failure();
+ default:
+ llvm::errs() << "ERROR: unknown type " << EnumNamesDType()[ts->GetDtype()]
+ << "\n";
+ return mlir::failure();
+ }
+ llvm::SmallVector<int64_t> shape;
+ for (auto dim : ts->GetShape()) {
+ if (dim > 0) {
+ shape.push_back(dim);
+ } else {
+ // dynamic dim
+ shape.push_back(mlir::ShapedType::kDynamic);
+ }
+ }
+ type = mlir::RankedTensorType::get(llvm::ArrayRef(shape), element_type);
+ return mlir::success();
+}
+
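+// Decode the raw byte payload of a serialized tensor into a
+// DenseElementsAttr of 'output_type' holding 'out_size' elements; returns
+// nullptr for unsupported element types.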
+mlir::DenseElementsAttr GetConstAttr(const std::vector<uint8_t> &data,
+ const mlir::RankedTensorType &output_type,
+ uint32_t out_size) {
+ auto element_type = output_type.getElementType();
+ if (element_type.isF32()) {
+ // for FP32, value attributes are stored as FP32 values
+ std::vector<float> float_data;
+ TosaSerializationHandler::ConvertU8toF32(data, out_size, float_data);
+ return mlir::DenseElementsAttr::get(output_type,
+ llvm::ArrayRef(float_data));
+ }
+ if (element_type.isBF16()) {
+ mlir::SmallVector<mlir::APFloat> bf16_data;
+ for (uint32_t i = 0; i < out_size; i++) {
+ uint64_t byte0 = data[i * sizeof(int16_t)];
+ uint64_t byte1 = data[i * sizeof(int16_t) + 1];
+ uint64_t bits = byte0 + (byte1 << 8);
+ mlir::APInt bf16_bits(16, bits);
+ mlir::APFloat bf16(mlir::APFloat::BFloat(), bf16_bits);
+ bf16_data.push_back(bf16);
+ }
+ return mlir::DenseElementsAttr::get(output_type, bf16_data);
+ }
+ if (element_type.isFloat8E4M3FN()) {
+ mlir::SmallVector<mlir::APFloat> f8_data;
+ for (uint32_t i = 0; i < out_size; i++) {
+ mlir::APInt f8_bits(8, static_cast<uint64_t>(data[i]));
+ mlir::APFloat f8(mlir::APFloat::Float8E4M3FN(), f8_bits);
+ f8_data.push_back(f8);
+ }
+ return mlir::DenseElementsAttr::get(output_type, f8_data);
+ }
+ if (element_type.isFloat8E5M2()) {
+ mlir::SmallVector<mlir::APFloat> f8_data;
+ for (uint32_t i = 0; i < out_size; i++) {
+ mlir::APInt f8_bits(8, static_cast<uint64_t>(data[i]));
+ mlir::APFloat f8(mlir::APFloat::Float8E5M2(), f8_bits);
+ f8_data.push_back(f8);
+ }
+ return mlir::DenseElementsAttr::get(output_type, f8_data);
+ }
+ if (element_type.isInteger(4)) {
+ std::vector<int8_t> int4_data;
+ TosaSerializationHandler::ConvertU8toI4(data, out_size, int4_data);
+ return mlir::DenseElementsAttr::get(output_type, llvm::ArrayRef(int4_data));
+ }
+ if (element_type.isInteger(8)) {
+ std::vector<int8_t> int8_data;
+ TosaSerializationHandler::ConvertU8toI8(data, out_size, int8_data);
+ return mlir::DenseElementsAttr::get(output_type, llvm::ArrayRef(int8_data));
+ }
+ if (element_type.isInteger(16)) {
+ std::vector<int16_t> int16_data;
+ TosaSerializationHandler::ConvertU8toI16(data, out_size, int16_data);
+ return mlir::DenseElementsAttr::get(output_type,
+ llvm::ArrayRef(int16_data));
+ }
+ if (element_type.isInteger(32)) {
+ std::vector<int32_t> int32_data;
+ TosaSerializationHandler::ConvertU8toI32(data, out_size, int32_data);
+ return mlir::DenseElementsAttr::get(output_type,
+ llvm::ArrayRef(int32_data));
+ }
+ if (element_type.isInteger(48)) {
+ std::vector<int64_t> int48_data;
+ TosaSerializationHandler::ConvertU8toI48(data, out_size, int48_data);
+ std::vector<mlir::APInt> apint_data;
+ for (const auto v : int48_data) {
+ mlir::APInt apint_value(48, static_cast<uint64_t>(v),
+ /* isSigned = */ false);
+ apint_data.push_back(apint_value);
+ }
+ return mlir::DenseElementsAttr::get(output_type,
+ llvm::ArrayRef(apint_data));
+ }
+ if (element_type.isInteger(1)) {
+ std::vector<bool> bool_data;
+ TosaSerializationHandler::ConvertU8toBool(data, out_size, bool_data);
+ llvm::SmallVector<bool> bool_values(bool_data.begin(), bool_data.end());
+ return mlir::DenseElementsAttr::get(output_type, bool_values);
+ }
+ if (element_type.isF16()) {
+ std::vector<half_float::half> half_data;
+ TosaSerializationHandler::ConvertU8toF16(data, out_size, half_data);
+ return mlir::DenseElementsAttr::get(output_type, llvm::ArrayRef(half_data));
+ }
+
+ return nullptr;
+}
+
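+// Build the constant value attribute for a serialized tensor; the element
+// count is derived from the tensor's shape.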
+mlir::DenseElementsAttr
+ConstructConstAttr(const mlir::RankedTensorType &output_type,
+ TosaSerializationTensor *ts, const std::string &op_name) {
+ // compute output data size
+ uint32_t out_size = 1;
+ for (const auto dim : ts->GetShape()) {
+ out_size *= dim;
+ }
+ auto attr = GetConstAttr(ts->GetData(), output_type, out_size);
+ if (!attr) {
+ llvm::errs() << "ERROR: " << op_name
+ << " contains unsupported element type\n";
+ }
+ return attr;
+}
+
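+// Create a module-level tosa.variable op for every registered variable
+// tensor, attaching its serialized data (if any) as the initial value.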
+mlir::LogicalResult ConstructVariableOps(mlir::ModuleOp &module) {
+ if (variable_tensor_map.empty()) {
+ return mlir::success();
+ }
+ auto loc = module.getLoc();
+ auto op_builder = mlir::OpBuilder(module.getBodyRegion());
+ for (auto [flatbuffer_name, ts] : variable_tensor_map) {
+ auto name = GetVariableTensorName(ts);
+ mlir::RankedTensorType type;
+ if (BuildTensorType(&op_builder, ts, type).failed()) {
+ return mlir::failure();
+ }
+
+ mlir::Attribute value_attr = nullptr;
+ if (!ts->GetData().empty()) {
+ value_attr = ConstructConstAttr(type, ts, name);
+ }
+ op_builder.create<mlir::tosa::VariableOp>(loc, llvm::StringRef(name), type,
+ value_attr);
+ }
+
+ return mlir::success();
+}
+
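+// Helper templates that copy serialized attribute vectors into MLIR dense
+// element / array attributes of the requested integer width.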
+template <class T>
+mlir::DenseElementsAttr BuildDenseI8ElementsAttr(mlir::OpBuilder *op_builder,
+ const std::vector<T> &values) {
+ llvm::SmallVector<int8_t> vec;
+ for (auto val : values) {
+ vec.push_back(val);
+ }
+ auto type = mlir::RankedTensorType::get({static_cast<int64_t>(vec.size())},
+ op_builder->getI8Type());
+ return mlir::DenseElementsAttr::get(type, llvm::ArrayRef(vec));
+}
+
+template <class T>
+mlir::DenseElementsAttr
+BuildDenseI16ElementsAttr(mlir::OpBuilder *op_builder,
+ const std::vector<T> &values) {
+ llvm::SmallVector<int16_t> vec;
+ for (auto val : values) {
+ vec.push_back(val);
+ }
+ auto type = mlir::RankedTensorType::get({static_cast<int64_t>(vec.size())},
+ op_builder->getI16Type());
+ return mlir::DenseElementsAttr::get(type, llvm::ArrayRef(vec));
+}
+
+template <class T>
+mlir::DenseElementsAttr
+BuildDenseI32ElementsAttr(mlir::OpBuilder *op_builder,
+ mlir::RankedTensorType &type,
+ const std::vector<T> &values) {
+ llvm::SmallVector<int32_t> vec;
+ for (auto val : values) {
+ vec.push_back(val);
+ }
+ return mlir::DenseElementsAttr::get(type, llvm::ArrayRef(vec));
+}
+
+template <class T>
+mlir::DenseI8ArrayAttr BuildDenseI8ArrayAttr(mlir::OpBuilder *op_builder,
+ const std::vector<T> &values) {
+ std::vector<int8_t> vec;
+ for (auto val : values) {
+ vec.push_back(val);
+ }
+ return op_builder->getDenseI8ArrayAttr(vec);
+}
+
+template <class T>
+mlir::DenseI32ArrayAttr BuildDenseI32ArrayAttr(mlir::OpBuilder *op_builder,
+ const std::vector<T> &values) {
+ std::vector<int32_t> vec;
+ for (auto val : values) {
+ vec.push_back(val);
+ }
+ return op_builder->getDenseI32ArrayAttr(vec);
+}
+
+template <class T>
+mlir::DenseI64ArrayAttr BuildDenseI64ArrayAttr(mlir::OpBuilder *op_builder,
+ const std::vector<T> &values) {
+ std::vector<int64_t> vec;
+ for (auto val : values) {
+ vec.push_back(val);
+ }
+ return op_builder->getDenseI64ArrayAttr(vec);
+}
+
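+// Map the serialized resize mode enum onto the string form expected by the
+// tosa.resize mode attribute.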
+const std::string ResizeEnum2Str(const tosa::ResizeMode &mode) {
+ if (mode == ResizeMode_NEAREST) {
+ return "NEAREST_NEIGHBOR";
+ } else if (mode == ResizeMode_BILINEAR) {
+ return "BILINEAR";
+ }
+ return "";
+}
+
+// this is a counterpart to Type2AccDType
+mlir::Type AccDType2Type(mlir::OpBuilder *op_builder, DType dtype) {
+ // def Tosa_AccType : AnyTypeOf<[I<32>, I<48>, F16, F32]>;
+ if (dtype == DType_INT32) {
+ return op_builder->getI32Type();
+ } else if (dtype == DType_INT48) {
+ return op_builder->getIntegerType(48);
+ } else if (dtype == DType_FP32) {
+ return op_builder->getF32Type();
+ } else if (dtype == DType_FP16) {
+ return op_builder->getF16Type();
+ } else {
+ // unknown acc type
+ // for now, default to F32
+ return op_builder->getF32Type();
+ }
+}
+
+} // namespace
+
+class TosaMlirRegionBuilder;
+class TosaMlirBlockBuilder;
+
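+// Translates a single serialized operator into the corresponding MLIR op(s),
+// appending them to the current block and resolving operands and results
+// through the shared tensor, tensor-type and shape-type maps.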
+class TosaMlirOperatorBuilder {
+public:
+ TosaMlirOperatorBuilder(
+ mlir::OpBuilder *_op_builder, TosaSerializationBasicBlock *_ser_block,
+ mlir::Block *_block, mlir::Location _loc,
+ TosaMlirBlockBuilder *_block_builder,
+ std::unordered_map<std::string, mlir::Value> *_tensor_map,
+ std::unordered_map<std::string, mlir::RankedTensorType> *_tensor_type_map,
+ std::unordered_map<std::string, mlir::tosa::shapeType> *_shape_type_map)
+ : op_builder(_op_builder), ser_block(_ser_block), block(_block),
+ loc(_loc), block_builder(_block_builder), tensor_map(_tensor_map),
+ tensor_type_map(_tensor_type_map), shape_type_map(_shape_type_map) {}
+
+ template <Op OPCODE>
+ std::vector<mlir::Value> build(TosaSerializationOperator *op) const;
+
+ std::vector<mlir::Value> BuildVariableOp(TosaSerializationOperator *op) const;
+
+ std::vector<mlir::Value>
+ BuildVariableReadOp(TosaSerializationOperator *op) const;
+
+ void BuildVariableWriteOp(TosaSerializationOperator *op) const;
+
+ std::string get_string(TosaSerializationOperator *op) const {
+ std::string op_string;
+ op_string += "operator opcode=";
+ op_string += EnumNamesOp()[op->GetOp()];
+ op_string += ", input=[";
+ for (auto ts : op->GetInputTensorNames()) {
+ op_string += (ts + " ");
+ }
+ op_string += "], output=[";
+ for (auto ts : op->GetOutputTensorNames()) {
+ op_string += (ts + " ");
+ }
+ op_string += "]";
+ return op_string;
+ }
+
+ TosaSerializationHandler *GetTsh() const;
+ TosaMlirRegionBuilder *GetRegionBuilder() const;
+
+private:
+ template <class MLIR_OP>
+ std::vector<mlir::Value>
+ BuildEwiseUnaryOp(TosaSerializationOperator *op) const;
+
+ template <class MLIR_OP>
+ std::vector<mlir::Value>
+ BuildEwiseBinaryOp(TosaSerializationOperator *op) const;
+
+ template <class MLIR_OP>
+ std::vector<mlir::Value>
+ BuildEwiseBinaryShapeOp(TosaSerializationOperator *op) const;
+
+ template <class MLIR_OP>
+ std::vector<mlir::Value>
+ BuildReductionOp(TosaSerializationOperator *op) const;
+
+ template <class T>
+ mlir::Value BuildConstShape(mlir::OpBuilder *op_builder, mlir::Location loc,
+ const std::vector<T> &values) const;
+
+ template <class MLIR_OP>
+ std::vector<mlir::Value> BuildConvOp(TosaSerializationOperator *op) const;
+
+ mlir::OpBuilder *op_builder;
+ TosaSerializationBasicBlock *ser_block;
+ mlir::Block *block;
+ mlir::Location loc;
+ TosaMlirBlockBuilder *block_builder;
+ std::unordered_map<std::string, mlir::Value> *tensor_map;
+ std::unordered_map<std::string, mlir::RankedTensorType> *tensor_type_map;
+ std::unordered_map<std::string, mlir::tosa::shapeType> *shape_type_map;
+};
+
+// Main template to catch unimplemented translation
+template <Op OPCODE>
+std::vector<mlir::Value>
+TosaMlirOperatorBuilder::build(TosaSerializationOperator *op) const {
+ llvm::errs() << "ERROR: " << get_string(op)
+ << " translation hasn't been implemented\n";
+ return {};
+}
+
+// BUILD_OP_POOL2D(MaxPool2d, MAX_POOL2D)
+template <>
+std::vector<mlir::Value> TosaMlirOperatorBuilder::build<Op_MAX_POOL2D>(
+ TosaSerializationOperator *op) const {
+ mlir::Value input_val = tensor_map->at(op->GetInputTensorNames()[0]);
+ mlir::RankedTensorType output_type =
+ tensor_type_map->at(op->GetOutputTensorNames()[0]);
+ assert(op->GetAttributeType() ==
+ Attribute_PoolAttribute); // double check attribute type
+ TosaPoolAttribute *attr =
+ static_cast<TosaPoolAttribute *>(op->GetAttribute());
+ mlir::DenseI64ArrayAttr kernel =
+ BuildDenseI64ArrayAttr(op_builder, attr->kernel());
+ mlir::DenseI64ArrayAttr stride =
+ BuildDenseI64ArrayAttr(op_builder, attr->stride());
+ mlir::DenseI64ArrayAttr pad = BuildDenseI64ArrayAttr(op_builder, attr->pad());
+ int32_t input_zp = attr->input_zp();
+ int32_t output_zp = attr->output_zp();
+ assert(input_zp == 0 && output_zp == 0);
+
+ mlir::Operation *mlir_op = op_builder->create<mlir::tosa::MaxPool2dOp>(
+ loc, output_type, input_val, kernel, stride, pad);
+ block->push_back(mlir_op);
+ return std::vector<mlir::Value>({mlir_op->getResult(0)});
+}
+
+// BUILD_OP_POOL2D(AvgPool2d, AVG_POOL2D)
+template <>
+std::vector<mlir::Value> TosaMlirOperatorBuilder::build<Op_AVG_POOL2D>(
+ TosaSerializationOperator *op) const {
+ mlir::Value input_val = tensor_map->at(op->GetInputTensorNames()[0]);
+ mlir::RankedTensorType output_type =
+ tensor_type_map->at(op->GetOutputTensorNames()[0]);
+ assert(op->GetAttributeType() ==
+ Attribute_PoolAttribute); // double check attribute type
+ TosaPoolAttribute *attr =
+ static_cast<TosaPoolAttribute *>(op->GetAttribute());
+ mlir::DenseI64ArrayAttr kernel =
+ BuildDenseI64ArrayAttr(op_builder, attr->kernel());
+ mlir::DenseI64ArrayAttr stride =
+ BuildDenseI64ArrayAttr(op_builder, attr->stride());
+ mlir::DenseI64ArrayAttr pad = BuildDenseI64ArrayAttr(op_builder, attr->pad());
+ auto acc_attr =
+ mlir::TypeAttr::get(AccDType2Type(op_builder, attr->acc_type()));
+
+ int32_t input_zp = attr->input_zp();
+ int32_t output_zp = attr->output_zp();
+ mlir::Operation *mlir_op;
+ if (input_zp == 0 && output_zp == 0) {
+ mlir_op = op_builder->create<mlir::tosa::AvgPool2dOp>(
+ loc, output_type, input_val, kernel, stride, pad, acc_attr);
+ } else {
+ auto input_zp_attr = op_builder->getI32IntegerAttr(input_zp);
+ auto output_zp_attr = op_builder->getI32IntegerAttr(output_zp);
+ mlir_op = op_builder->create<mlir::tosa::AvgPool2dOp>(
+ loc, output_type, input_val, kernel, stride, pad, acc_attr,
+ input_zp_attr, output_zp_attr);
+ }
+ block->push_back(mlir_op);
+ return std::vector<mlir::Value>({mlir_op->getResult(0)});
+}
+
+std::vector<mlir::Value> TosaMlirOperatorBuilder::BuildVariableReadOp(
+ TosaSerializationOperator *op) const {
+ auto input_tensor_name = op->GetInputTensorNames()[0];
+ auto output_tensor_name = op->GetOutputTensorNames()[0];
+
+ assert(IsVariableTensor(input_tensor_name));
+
+ auto variable_name = GetVariableTensorName(input_tensor_name);
+ mlir::RankedTensorType output_type = tensor_type_map->at(output_tensor_name);
+ assert(op->GetAttributeType() ==
+ Attribute_NONE); // double check that there is no attribute
+ mlir::Operation *mlir_op = op_builder->create<mlir::tosa::VariableReadOp>(
+ loc, output_type, llvm::StringRef(variable_name));
+ block->push_back(mlir_op);
+ return std::vector<mlir::Value>({mlir_op->getResult(0)});
+}
+
+void TosaMlirOperatorBuilder::BuildVariableWriteOp(
+ TosaSerializationOperator *op) const {
+ auto input_tensor_name = op->GetInputTensorNames()[0];
+ auto output_tensor_name = op->GetOutputTensorNames()[0];
+
+ assert(IsVariableTensor(output_tensor_name));
+ auto variable_name = GetVariableTensorName(output_tensor_name);
+ mlir::Value input_val = tensor_map->at(input_tensor_name);
+
+ mlir::Operation *mlir_op = op_builder->create<mlir::tosa::VariableWriteOp>(
+ loc, llvm::StringRef(variable_name), input_val);
+ block->push_back(mlir_op);
+}
+
+template <class MLIR_OP>
+std::vector<mlir::Value> TosaMlirOperatorBuilder::BuildEwiseUnaryOp(
+ TosaSerializationOperator *op) const {
+ mlir::Value input_val = tensor_map->at(op->GetInputTensorNames()[0]);
+ mlir::RankedTensorType output_type =
+ tensor_type_map->at(op->GetOutputTensorNames()[0]);
+ assert(op->GetAttributeType() ==
+ Attribute_NONE); // double check that there is no attribute
+
+ mlir::Operation *mlir_op =
+ op_builder->create<MLIR_OP>(loc, output_type, input_val);
+ block->push_back(mlir_op);
+ return std::vector<mlir::Value>({mlir_op->getResult(0)});
+}
+
+template <class MLIR_OP>
+std::vector<mlir::Value> TosaMlirOperatorBuilder::BuildEwiseBinaryOp(
+ TosaSerializationOperator *op) const {
+ mlir::Value input0_val = tensor_map->at(op->GetInputTensorNames()[0]);
+ mlir::Value input1_val = tensor_map->at(op->GetInputTensorNames()[1]);
+ mlir::RankedTensorType output_type =
+ tensor_type_map->at(op->GetOutputTensorNames()[0]);
+ assert(op->GetAttributeType() ==
+ Attribute_NONE); // double check that there is no attribute
+
+ mlir::Operation *mlir_op =
+ op_builder->create<MLIR_OP>(loc, output_type, input0_val, input1_val);
+ block->push_back(mlir_op);
+ return std::vector<mlir::Value>({mlir_op->getResult(0)});
+}
+
+template <class MLIR_OP>
+std::vector<mlir::Value> TosaMlirOperatorBuilder::BuildEwiseBinaryShapeOp(
+ TosaSerializationOperator *op) const {
+ mlir::Value input0_val = tensor_map->at(op->GetInputTensorNames()[0]);
+ mlir::Value input1_val = tensor_map->at(op->GetInputTensorNames()[1]);
+ mlir::tosa::shapeType output_type =
+ shape_type_map->at(op->GetOutputTensorNames()[0]);
+ assert(op->GetAttributeType() ==
+ Attribute_NONE); // double check that there is no attribute
+
+ mlir::Operation *mlir_op =
+ op_builder->create<MLIR_OP>(loc, output_type, input0_val, input1_val);
+ block->push_back(mlir_op);
+ return std::vector<mlir::Value>({mlir_op->getResult(0)});
+}
+
+template <class MLIR_OP>
+std::vector<mlir::Value>
+TosaMlirOperatorBuilder::BuildReductionOp(TosaSerializationOperator *op) const {
+ mlir::Value input_val = tensor_map->at(op->GetInputTensorNames()[0]);
+ mlir::RankedTensorType output_type =
+ tensor_type_map->at(op->GetOutputTensorNames()[0]);
+ assert(op->GetAttributeType() ==
+ Attribute_AxisAttribute); // double check attribute type
+
+ TosaAxisAttribute *attr =
+ static_cast<TosaAxisAttribute *>(op->GetAttribute());
+ auto axis = op_builder->getI32IntegerAttr(attr->axis());
+
+ mlir::Operation *mlir_op =
+ op_builder->create<MLIR_OP>(loc, output_type, input_val, axis);
+ block->push_back(mlir_op);
+ return std::vector<mlir::Value>({mlir_op->getResult(0)});
+}
+
+#define BUILD_OP_ELEMENTWISE_UNARY(MLIR_OP_NAME, SCHEMA_OP_NAME) \
+ template <> \
+ std::vector<mlir::Value> \
+ TosaMlirOperatorBuilder::build<Op_##SCHEMA_OP_NAME>( \
+ TosaSerializationOperator * op) const { \
+ return BuildEwiseUnaryOp<mlir::tosa::MLIR_OP_NAME##Op>(op); \
+ }
+
+#define BUILD_OP_ELEMENTWISE_BINARY(MLIR_OP_NAME, SCHEMA_OP_NAME) \
+ template <> \
+ std::vector<mlir::Value> \
+ TosaMlirOperatorBuilder::build<Op_##SCHEMA_OP_NAME>( \
+ TosaSerializationOperator * op) const { \
+ return BuildEwiseBinaryOp<mlir::tosa::MLIR_OP_NAME##Op>(op); \
+ }
+
+#define BUILD_OP_ELEMENTWISE_BINARY_SHAPE(MLIR_OP_NAME, SCHEMA_OP_NAME) \
+ template <> \
+ std::vector<mlir::Value> \
+ TosaMlirOperatorBuilder::build<Op_##SCHEMA_OP_NAME>( \
+ TosaSerializationOperator * op) const { \
+ return BuildEwiseBinaryShapeOp<mlir::tosa::MLIR_OP_NAME##Op>(op); \
+ }
+
+#define BUILD_OP_REDUCTION(MLIR_OP_NAME, SCHEMA_OP_NAME) \
+ template <> \
+ std::vector<mlir::Value> \
+ TosaMlirOperatorBuilder::build<Op_##SCHEMA_OP_NAME>( \
+ TosaSerializationOperator * op) const { \
+ return BuildReductionOp<mlir::tosa::MLIR_OP_NAME##Op>(op); \
+ }
+
+// BUILD_OP_POOL2D(MaxPool2d, MAX_POOL2D)
+// BUILD_OP_POOL2D(AvgPool2d, AVG_POOL2D)
+
+BUILD_OP_ELEMENTWISE_BINARY(Add, ADD)
+BUILD_OP_ELEMENTWISE_BINARY(BitwiseAnd, BITWISE_AND)
+BUILD_OP_ELEMENTWISE_BINARY(BitwiseXor, BITWISE_XOR)
+BUILD_OP_ELEMENTWISE_BINARY(BitwiseOr, BITWISE_OR)
+BUILD_OP_ELEMENTWISE_BINARY(IntDiv, INTDIV)
+BUILD_OP_ELEMENTWISE_BINARY(LogicalAnd, LOGICAL_AND)
+BUILD_OP_ELEMENTWISE_BINARY(LogicalLeftShift, LOGICAL_LEFT_SHIFT)
+BUILD_OP_ELEMENTWISE_BINARY(LogicalRightShift, LOGICAL_RIGHT_SHIFT)
+BUILD_OP_ELEMENTWISE_BINARY(LogicalOr, LOGICAL_OR)
+BUILD_OP_ELEMENTWISE_BINARY(LogicalXor, LOGICAL_XOR)
+BUILD_OP_ELEMENTWISE_BINARY(Maximum, MAXIMUM)
+BUILD_OP_ELEMENTWISE_BINARY(Minimum, MINIMUM)
+BUILD_OP_ELEMENTWISE_BINARY(Pow, POW)
+BUILD_OP_ELEMENTWISE_BINARY(Sub, SUB)
+
+BUILD_OP_ELEMENTWISE_UNARY(Abs, ABS)
+BUILD_OP_ELEMENTWISE_UNARY(BitwiseNot, BITWISE_NOT)
+BUILD_OP_ELEMENTWISE_UNARY(Ceil, CEIL)
+BUILD_OP_ELEMENTWISE_UNARY(Clz, CLZ)
+BUILD_OP_ELEMENTWISE_UNARY(Cos, COS)
+BUILD_OP_ELEMENTWISE_UNARY(Exp, EXP)
+BUILD_OP_ELEMENTWISE_UNARY(Floor, FLOOR)
+BUILD_OP_ELEMENTWISE_UNARY(Log, LOG)
+BUILD_OP_ELEMENTWISE_UNARY(LogicalNot, LOGICAL_NOT)
+BUILD_OP_ELEMENTWISE_UNARY(Reciprocal, RECIPROCAL)
+BUILD_OP_ELEMENTWISE_UNARY(Rsqrt, RSQRT)
+BUILD_OP_ELEMENTWISE_UNARY(Sin, SIN)
+
+BUILD_OP_REDUCTION(ReduceAny, REDUCE_ANY)
+BUILD_OP_REDUCTION(ReduceAll, REDUCE_ALL)
+BUILD_OP_REDUCTION(ReduceMax, REDUCE_MAX)
+BUILD_OP_REDUCTION(ReduceMin, REDUCE_MIN)
+BUILD_OP_REDUCTION(ReduceProd, REDUCE_PRODUCT)
+BUILD_OP_REDUCTION(ReduceSum, REDUCE_SUM)
+
+BUILD_OP_ELEMENTWISE_BINARY(Equal, EQUAL)
+BUILD_OP_ELEMENTWISE_BINARY(Greater, GREATER)
+BUILD_OP_ELEMENTWISE_BINARY(GreaterEqual, GREATER_EQUAL)
+
+BUILD_OP_ELEMENTWISE_UNARY(Erf, ERF)
+BUILD_OP_ELEMENTWISE_UNARY(Sigmoid, SIGMOID)
+BUILD_OP_ELEMENTWISE_UNARY(Tanh, TANH)
+BUILD_OP_ELEMENTWISE_UNARY(Identity, IDENTITY)
+BUILD_OP_ELEMENTWISE_UNARY(Cast, CAST)
+
+BUILD_OP_ELEMENTWISE_BINARY_SHAPE(AddShape, ADD_SHAPE)
+BUILD_OP_ELEMENTWISE_BINARY_SHAPE(SubShape, SUB_SHAPE)
+BUILD_OP_ELEMENTWISE_BINARY_SHAPE(MulShape, MUL_SHAPE)
+BUILD_OP_ELEMENTWISE_BINARY_SHAPE(DivShape, DIV_SHAPE)
+
+template <>
+std::vector<mlir::Value>
+TosaMlirOperatorBuilder::build<Op_CONST>(TosaSerializationOperator *op) const {
+ const auto &output_name = op->GetOutputTensorNames()[0];
+ mlir::RankedTensorType output_type = tensor_type_map->at(output_name);
+ TosaSerializationTensor *ts = ser_block->GetTensorByName(output_name);
+ auto value_attr = ConstructConstAttr(output_type, ts, get_string(op));
+ if (!value_attr) {
+ return {};
+ }
+ mlir::Operation *mlir_op =
+ op_builder->create<mlir::tosa::ConstOp>(loc, output_type, value_attr);
+ block->push_back(mlir_op);
+ return std::vector<mlir::Value>({mlir_op->getResult(0)});
+}
+
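+// Materialize a tosa.const_shape op holding the given index values and
+// return its result value.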
+template <class T>
+mlir::Value
+TosaMlirOperatorBuilder::BuildConstShape(mlir::OpBuilder *op_builder,
+ mlir::Location loc,
+ const std::vector<T> &values) const {
+ std::vector<int64_t> vec;
+ for (auto val : values) {
+ vec.push_back(val);
+ }
+ auto attr = op_builder->getIndexTensorAttr(vec);
+ auto type = mlir::tosa::shapeType::get(op_builder->getContext(),
+ /* rank = */ vec.size());
+ mlir::Operation *mlir_op =
+ op_builder->create<mlir::tosa::ConstShapeOp>(loc, type, attr);
+ block->push_back(mlir_op);
+ return mlir_op->getResult(0);
+}
+
+template <>
+std::vector<mlir::Value> TosaMlirOperatorBuilder::build<Op_CONST_SHAPE>(
+ TosaSerializationOperator *op) const {
+ const auto &output_name = op->GetOutputTensorNames()[0];
+ mlir::tosa::shapeType output_type = shape_type_map->at(output_name);
+ TosaSerializationTensor *ts = ser_block->GetTensorByName(output_name);
+
+ const auto &data = ts->GetData();
+
+ std::vector<int64_t> i64_data;
+ TosaSerializationHandler::ConvertU8toI64(data, output_type.getRank(),
+ i64_data);
+ mlir::Value result = BuildConstShape(op_builder, loc, i64_data);
+ return std::vector<mlir::Value>({result});
+}
+
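+// Shared builder for CONV2D / CONV3D / DEPTHWISE_CONV2D: three operands
+// (input, weight, bias) plus pad/stride/dilation, accumulator type,
+// zero-point and local_bound attributes.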
+template <class MLIR_OP>
+std::vector<mlir::Value>
+TosaMlirOperatorBuilder::BuildConvOp(TosaSerializationOperator *op) const {
+ mlir::Value input0_val = tensor_map->at(op->GetInputTensorNames()[0]);
+ mlir::Value input1_val = tensor_map->at(op->GetInputTensorNames()[1]);
+ mlir::Value input2_val = tensor_map->at(op->GetInputTensorNames()[2]);
+ mlir::RankedTensorType output_type =
+ tensor_type_map->at(op->GetOutputTensorNames()[0]);
+
+ assert(op->GetAttributeType() ==
+ Attribute_ConvAttribute); // double check attribute type
+ TosaConvAttribute *attr =
+ static_cast<TosaConvAttribute *>(op->GetAttribute());
+ mlir::DenseI64ArrayAttr pad = BuildDenseI64ArrayAttr(op_builder, attr->pad());
+ mlir::DenseI64ArrayAttr stride =
+ BuildDenseI64ArrayAttr(op_builder, attr->stride());
+ mlir::DenseI64ArrayAttr dilation =
+ BuildDenseI64ArrayAttr(op_builder, attr->dilation());
+ auto input_zp = attr->input_zp();
+ auto weight_zp = attr->weight_zp();
+ bool local_bound = attr->local_bound();
+ auto acc_type = AccDType2Type(op_builder, attr->acc_type());
+
+  // input_zp/weight_zp are not allowed for float types
+ mlir::Operation *mlir_op;
+ if (output_type.getElementType().isa<mlir::FloatType>()) {
+ assert(input_zp == 0 && weight_zp == 0);
+ }
+
+ auto input_zp_attr = op_builder->getI32IntegerAttr(input_zp);
+ auto weight_zp_attr = op_builder->getI32IntegerAttr(weight_zp);
+ mlir_op = op_builder->create<MLIR_OP>(
+ loc, output_type, input0_val, input1_val, input2_val, pad, stride,
+ dilation, acc_type, input_zp_attr, weight_zp_attr, local_bound);
+
+ block->push_back(mlir_op);
+ return std::vector<mlir::Value>({mlir_op->getResult(0)});
+}
+
+#define BUILD_OP_CONV(MLIR_OP_NAME, SCHEMA_OP_NAME) \
+ template <> \
+ std::vector<mlir::Value> \
+ TosaMlirOperatorBuilder::build<Op_##SCHEMA_OP_NAME>( \
+ TosaSerializationOperator * op) const { \
+ return BuildConvOp<mlir::tosa::MLIR_OP_NAME##Op>(op); \
+ }
+
+BUILD_OP_CONV(Conv2D, CONV2D)
+BUILD_OP_CONV(Conv3D, CONV3D)
+BUILD_OP_CONV(DepthwiseConv2D, DEPTHWISE_CONV2D)
+
+template <>
+std::vector<mlir::Value> TosaMlirOperatorBuilder::build<Op_TRANSPOSE_CONV2D>(
+ TosaSerializationOperator *op) const {
+ mlir::Value input0_val = tensor_map->at(op->GetInputTensorNames()[0]);
+ mlir::Value input1_val = tensor_map->at(op->GetInputTensorNames()[1]);
+ mlir::Value input2_val = tensor_map->at(op->GetInputTensorNames()[2]);
+ mlir::RankedTensorType output_type =
+ tensor_type_map->at(op->GetOutputTensorNames()[0]);
+
+ assert(op->GetAttributeType() ==
+ Attribute_TransposeConvAttribute); // double check attribute type
+ TosaTransposeConvAttribute *attr =
+ static_cast<TosaTransposeConvAttribute *>(op->GetAttribute());
+ mlir::DenseI64ArrayAttr out_pad =
+ BuildDenseI64ArrayAttr(op_builder, attr->out_pad());
+ mlir::DenseI64ArrayAttr stride =
+ BuildDenseI64ArrayAttr(op_builder, attr->stride());
+ auto input_zp = attr->input_zp();
+ auto weight_zp = attr->weight_zp();
+ bool local_bound = attr->local_bound();
+ auto acc_type = AccDType2Type(op_builder, attr->acc_type());
+
+  // input_zp/weight_zp are not allowed for float types
+ mlir::Operation *mlir_op;
+ if (output_type.getElementType().isa<mlir::FloatType>()) {
+ assert(input_zp == 0 && weight_zp == 0);
+ }
+
+ auto input_zp_attr = op_builder->getI32IntegerAttr(input_zp);
+ auto weight_zp_attr = op_builder->getI32IntegerAttr(weight_zp);
+
+ mlir_op = op_builder->create<mlir::tosa::TransposeConv2DOp>(
+ loc, output_type, input0_val, input1_val, input2_val, out_pad, stride,
+ acc_type, input_zp_attr, weight_zp_attr, local_bound);
+
+ block->push_back(mlir_op);
+ return std::vector<mlir::Value>({mlir_op->getResult(0)});
+}
+
+template <>
+std::vector<mlir::Value> TosaMlirOperatorBuilder::build<Op_FULLY_CONNECTED>(
+ TosaSerializationOperator *op) const {
+ mlir::Value input0_val = tensor_map->at(op->GetInputTensorNames()[0]);
+ mlir::Value input1_val = tensor_map->at(op->GetInputTensorNames()[1]);
+ mlir::Value input2_val = tensor_map->at(op->GetInputTensorNames()[2]);
+ mlir::RankedTensorType output_type =
+ tensor_type_map->at(op->GetOutputTensorNames()[0]);
+
+ assert(op->GetAttributeType() ==
+ Attribute_FullyConnectedAttribute); // double check attribute type
+ TosaFullyConnectedAttribute *attr =
+ static_cast<TosaFullyConnectedAttribute *>(op->GetAttribute());
+ auto input_zp = attr->input_zp();
+ auto weight_zp = attr->weight_zp();
+
+  // input_zp/weight_zp are not allowed for float types
+ mlir::Operation *mlir_op;
+ if (output_type.getElementType().isa<mlir::FloatType>()) {
+ assert(input_zp == 0 && weight_zp == 0);
+ mlir_op = op_builder->create<mlir::tosa::FullyConnectedOp>(
+ loc, output_type, input0_val, input1_val, input2_val);
+ } else {
+ auto input_zp_attr = op_builder->getI32IntegerAttr(input_zp);
+ auto weight_zp_attr = op_builder->getI32IntegerAttr(weight_zp);
+ mlir_op = op_builder->create<mlir::tosa::FullyConnectedOp>(
+ loc, output_type, input0_val, input1_val, input2_val, input_zp_attr,
+ weight_zp_attr);
+ }
+ block->push_back(mlir_op);
+ return std::vector<mlir::Value>({mlir_op->getResult(0)});
+}
+
+template <>
+std::vector<mlir::Value>
+TosaMlirOperatorBuilder::build<Op_MATMUL>(TosaSerializationOperator *op) const {
+ mlir::Value input0_val = tensor_map->at(op->GetInputTensorNames()[0]);
+ mlir::Value input1_val = tensor_map->at(op->GetInputTensorNames()[1]);
+ mlir::RankedTensorType output_type =
+ tensor_type_map->at(op->GetOutputTensorNames()[0]);
+
+ assert(op->GetAttributeType() ==
+ Attribute_MatMulAttribute); // double check attribute type
+ TosaMatMulAttribute *attr =
+ static_cast<TosaMatMulAttribute *>(op->GetAttribute());
+ auto A_zp = attr->a_zp();
+ auto B_zp = attr->b_zp();
+
+ mlir::Operation *mlir_op;
+ if (A_zp == 0 && B_zp == 0) {
+ mlir_op = op_builder->create<mlir::tosa::MatMulOp>(loc, output_type,
+ input0_val, input1_val);
+ } else {
+ auto a_zp_attr = op_builder->getI32IntegerAttr(A_zp);
+ auto b_zp_attr = op_builder->getI32IntegerAttr(B_zp);
+ mlir_op = op_builder->create<mlir::tosa::MatMulOp>(
+ loc, output_type, input0_val, input1_val, a_zp_attr, b_zp_attr);
+ }
+ block->push_back(mlir_op);
+ return std::vector<mlir::Value>({mlir_op->getResult(0)});
+}
+
+template <>
+std::vector<mlir::Value>
+TosaMlirOperatorBuilder::build<Op_SELECT>(TosaSerializationOperator *op) const {
+ mlir::Value input0_val = tensor_map->at(op->GetInputTensorNames()[0]);
+ mlir::Value input1_val = tensor_map->at(op->GetInputTensorNames()[1]);
+ mlir::Value input2_val = tensor_map->at(op->GetInputTensorNames()[2]);
+ mlir::RankedTensorType output_type =
+ tensor_type_map->at(op->GetOutputTensorNames()[0]);
+
+ assert(op->GetAttributeType() ==
+ Attribute_NONE); // double check that there is no attribute
+
+ mlir::Operation *mlir_op = op_builder->create<mlir::tosa::SelectOp>(
+ loc, output_type, input0_val, input1_val, input2_val);
+ block->push_back(mlir_op);
+ return std::vector<mlir::Value>({mlir_op->getResult(0)});
+}
+
+template <>
+std::vector<mlir::Value>
+TosaMlirOperatorBuilder::build<Op_CLAMP>(TosaSerializationOperator *op) const {
+ mlir::Value input_val = tensor_map->at(op->GetInputTensorNames()[0]);
+ mlir::RankedTensorType output_type =
+ tensor_type_map->at(op->GetOutputTensorNames()[0]);
+
+ assert(op->GetAttributeType() ==
+ Attribute_ClampAttribute); // double check attribute type
+ TosaClampAttribute *attr =
+ static_cast<TosaClampAttribute *>(op->GetAttribute());
+
+ mlir::Type element_type =
+ llvm::cast<mlir::ShapedType>(input_val.getType()).getElementType();
+ if (auto quantType =
+ llvm::dyn_cast<mlir::quant::UniformQuantizedType>(element_type)) {
+ element_type = quantType.getStorageType();
+ }
+
+ auto element_const_type = mlir::RankedTensorType::get({1}, element_type);
+ auto min_values_attr = GetConstAttr(attr->min_val(), element_const_type, 1);
+ auto max_values_attr = GetConstAttr(attr->max_val(), element_const_type, 1);
+
+ mlir::Attribute min_val_attr, max_val_attr;
+ if (element_type.isa<mlir::FloatType>()) {
+ min_val_attr = op_builder->getFloatAttr(
+ element_type, min_values_attr.getValues<mlir::APFloat>()[0]);
+ max_val_attr = op_builder->getFloatAttr(
+ element_type, max_values_attr.getValues<mlir::APFloat>()[0]);
+ } else {
+ min_val_attr = op_builder->getIntegerAttr(
+ element_type, min_values_attr.getValues<mlir::APInt>()[0]);
+ max_val_attr = op_builder->getIntegerAttr(
+ element_type, max_values_attr.getValues<mlir::APInt>()[0]);
+ }
+
+ auto mlir_op = op_builder->create<mlir::tosa::ClampOp>(
+ loc, output_type, input_val, min_val_attr, max_val_attr);
+ block->push_back(mlir_op);
+ return std::vector<mlir::Value>({mlir_op->getResult(0)});
+}
+
+// ArgMax has a single input and a single I32 axis attribute
+BUILD_OP_REDUCTION(ArgMax, ARGMAX)
+
+template <>
+std::vector<mlir::Value>
+TosaMlirOperatorBuilder::build<Op_CONCAT>(TosaSerializationOperator *op) const {
+ mlir::RankedTensorType output_type =
+ tensor_type_map->at(op->GetOutputTensorNames()[0]);
+
+ llvm::SmallVector<mlir::Value> input_values;
+ for (auto &input_name : op->GetInputTensorNames()) {
+ mlir::Value input_val = tensor_map->at(input_name);
+ input_values.push_back(input_val);
+ }
+
+ assert(op->GetAttributeType() ==
+ Attribute_AxisAttribute); // double check attribute type
+ TosaAxisAttribute *attr =
+ static_cast<TosaAxisAttribute *>(op->GetAttribute());
+ auto axis = op_builder->getI32IntegerAttr(attr->axis());
+
+ mlir::Operation *mlir_op = op_builder->create<mlir::tosa::ConcatOp>(
+ loc, output_type, input_values, axis);
+ block->push_back(mlir_op);
+ return std::vector<mlir::Value>({mlir_op->getResult(0)});
+}
+
+template <>
+std::vector<mlir::Value> TosaMlirOperatorBuilder::build<Op_CONCAT_SHAPE>(
+ TosaSerializationOperator *op) const {
+ mlir::tosa::shapeType output_type =
+ shape_type_map->at(op->GetOutputTensorNames()[0]);
+
+ llvm::SmallVector<mlir::Value> input_values;
+ for (auto &input_name : op->GetInputTensorNames()) {
+ mlir::Value input_val = tensor_map->at(input_name);
+ input_values.push_back(input_val);
+ }
+
+ mlir::Operation *mlir_op = op_builder->create<mlir::tosa::ConcatShapeOp>(
+ loc, output_type, input_values);
+ block->push_back(mlir_op);
+ return std::vector<mlir::Value>({mlir_op->getResult(0)});
+}
+
+template <>
+std::vector<mlir::Value>
+TosaMlirOperatorBuilder::build<Op_NEGATE>(TosaSerializationOperator *op) const {
+ mlir::Value input_val = tensor_map->at(op->GetInputTensorNames()[0]);
+ mlir::RankedTensorType output_type =
+ tensor_type_map->at(op->GetOutputTensorNames()[0]);
+
+ assert(op->GetAttributeType() ==
+ Attribute_NegateAttribute); // double check attribute type
+ TosaNegateAttribute *attr =
+ static_cast<TosaNegateAttribute *>(op->GetAttribute());
+
+ auto input_zp = attr->input1_zp();
+ auto output_zp = attr->output_zp();
+
+ mlir::Operation *mlir_op;
+ if (input_zp == 0 && output_zp == 0) {
+ mlir_op =
+ op_builder->create<mlir::tosa::NegateOp>(loc, output_type, input_val);
+ } else {
+ auto input_zp_attr = op_builder->getI32IntegerAttr(input_zp);
+ auto output_zp_attr = op_builder->getI32IntegerAttr(output_zp);
+ mlir_op = op_builder->create<mlir::tosa::NegateOp>(
+ loc, output_type, input_val, input_zp_attr, output_zp_attr);
+ }
+
+ block->push_back(mlir_op);
+ return std::vector<mlir::Value>({mlir_op->getResult(0)});
+}
+
+template <>
+std::vector<mlir::Value> TosaMlirOperatorBuilder::build<Op_RESHAPE>(
+ TosaSerializationOperator *op) const {
+ mlir::Value input_val = tensor_map->at(op->GetInputTensorNames()[0]);
+ mlir::RankedTensorType output_type =
+ tensor_type_map->at(op->GetOutputTensorNames()[0]);
+ mlir::Value shape_val = tensor_map->at(op->GetInputTensorNames()[1]);
+
+ mlir::Operation *mlir_op = op_builder->create<mlir::tosa::ReshapeOp>(
+ loc, output_type, input_val, shape_val);
+ block->push_back(mlir_op);
+ return std::vector<mlir::Value>({mlir_op->getResult(0)});
+}
+
+template <>
+std::vector<mlir::Value>
+TosaMlirOperatorBuilder::build<Op_PAD>(TosaSerializationOperator *op) const {
+ mlir::Value input_val = tensor_map->at(op->GetInputTensorNames()[0]);
+ mlir::Value padding_val = tensor_map->at(op->GetInputTensorNames()[1]);
+ mlir::RankedTensorType input_type =
+ tensor_type_map->at(op->GetInputTensorNames()[0]);
+ mlir::RankedTensorType output_type =
+ tensor_type_map->at(op->GetOutputTensorNames()[0]);
+ const auto element_type =
+ input_val.getType().cast<mlir::ShapedType>().getElementType();
+
+ assert(op->GetAttributeType() ==
+ Attribute_PadAttribute); // double check attribute type
+ TosaPadAttribute *attr = static_cast<TosaPadAttribute *>(op->GetAttribute());
+ const auto &pad_const_u8_data = attr->pad_const();
+
+  // check whether pad_const_u8_data contains any non-zero byte
+ bool has_pad_const = false;
+ for (auto v : pad_const_u8_data) {
+ if (v != 0) {
+ has_pad_const = true;
+ break;
+ }
+ }
+ if (!has_pad_const) {
+    // handle the case where there is no explicit pad_const input.
+ auto mlir_op = op_builder->create<mlir::tosa::PadOp>(
+ loc, output_type, input_val, padding_val);
+ block->push_back(mlir_op);
+ return std::vector<mlir::Value>({mlir_op->getResult(0)});
+ }
+
+ // has pad const - create a const op for pad_const input
+ auto pad_const_type = mlir::RankedTensorType::get({}, element_type);
+ auto pad_const_attr = GetConstAttr(pad_const_u8_data, pad_const_type, 1);
+
+ auto pad_const_op = op_builder->create<mlir::tosa::ConstOp>(
+ loc, pad_const_type, pad_const_attr);
+
+ block->push_back(pad_const_op);
+ mlir::Value pad_const_value = pad_const_op->getResult(0);
+
+ auto mlir_op = op_builder->create<mlir::tosa::PadOp>(
+ loc, output_type, input_val, padding_val, pad_const_value);
+
+ block->push_back(mlir_op);
+ return std::vector<mlir::Value>({mlir_op->getResult(0)});
+}
+
+template <>
+std::vector<mlir::Value>
+TosaMlirOperatorBuilder::build<Op_DIM>(TosaSerializationOperator *op) const {
+ mlir::Value input_val = tensor_map->at(op->GetInputTensorNames()[0]);
+ mlir::RankedTensorType output_type =
+ tensor_type_map->at(op->GetOutputTensorNames()[0]);
+
+ assert(op->GetAttributeType() ==
+ Attribute_AxisAttribute); // double check attribute type
+ TosaAxisAttribute *attr =
+ static_cast<TosaAxisAttribute *>(op->GetAttribute());
+ auto axis = op_builder->getI32IntegerAttr(attr->axis());
+
+ mlir::Operation *mlir_op =
+ op_builder->create<mlir::tosa::DimOp>(loc, output_type, input_val, axis);
+ block->push_back(mlir_op);
+ return std::vector<mlir::Value>({mlir_op->getResult(0)});
+}
+
+template <>
+std::vector<mlir::Value> TosaMlirOperatorBuilder::build<Op_TRANSPOSE>(
+ TosaSerializationOperator *op) const {
+ mlir::Value input_val = tensor_map->at(op->GetInputTensorNames()[0]);
+ mlir::RankedTensorType output_type =
+ tensor_type_map->at(op->GetOutputTensorNames()[0]);
+
+ assert(op->GetAttributeType() ==
+ Attribute_TransposeAttribute); // double check attribute type
+ TosaTransposeAttribute *attr =
+ static_cast<TosaTransposeAttribute *>(op->GetAttribute());
+  // make a constant op from attr->perms values, of type
+  // tensor<perms.size() x i32>
+ const auto perms_values = attr->perms();
+ auto const_type = mlir::RankedTensorType::get(
+ {static_cast<int64_t>(perms_values.size())}, op_builder->getI32Type());
+ mlir::DenseElementsAttr const_attr =
+ BuildDenseI32ElementsAttr(op_builder, const_type, perms_values);
+ mlir::Operation *mlir_const_op =
+ op_builder->create<mlir::tosa::ConstOp>(loc, const_type, const_attr);
+ auto perms_val = mlir_const_op->getResult(0);
+
+ mlir::Operation *mlir_op = op_builder->create<mlir::tosa::TransposeOp>(
+ loc, output_type, input_val, perms_val);
+
+ block->push_back(mlir_const_op);
+ block->push_back(mlir_op);
+ return std::vector<mlir::Value>({mlir_op->getResult(0)});
+}
+
+template <>
+std::vector<mlir::Value>
+TosaMlirOperatorBuilder::build<Op_SLICE>(TosaSerializationOperator *op) const {
+ mlir::Value input_val = tensor_map->at(op->GetInputTensorNames()[0]);
+ mlir::RankedTensorType output_type =
+ tensor_type_map->at(op->GetOutputTensorNames()[0]);
+
+ mlir::Value start = tensor_map->at(op->GetInputTensorNames()[1]);
+ mlir::Value size = tensor_map->at(op->GetInputTensorNames()[2]);
+
+ mlir::Operation *mlir_op = op_builder->create<mlir::tosa::SliceOp>(
+ loc, output_type, input_val, start, size);
+ block->push_back(mlir_op);
+ return std::vector<mlir::Value>({mlir_op->getResult(0)});
+}
+
+template <>
+std::vector<mlir::Value>
+TosaMlirOperatorBuilder::build<Op_TILE>(TosaSerializationOperator *op) const {
+ mlir::Value input_val = tensor_map->at(op->GetInputTensorNames()[0]);
+ mlir::Value multiples = tensor_map->at(op->GetInputTensorNames()[1]);
+ mlir::RankedTensorType output_type =
+ tensor_type_map->at(op->GetOutputTensorNames()[0]);
+
+  assert(op->GetAttributeType() ==
+         Attribute_NONE); // double check that there is no attribute
+
+ mlir::Operation *mlir_op = op_builder->create<mlir::tosa::TileOp>(
+ loc, output_type, input_val, multiples);
+ block->push_back(mlir_op);
+ return std::vector<mlir::Value>({mlir_op->getResult(0)});
+}
+
+// Gather is a binary op
+BUILD_OP_ELEMENTWISE_BINARY(Gather, GATHER)
+
+template <>
+std::vector<mlir::Value> TosaMlirOperatorBuilder::build<Op_SCATTER>(
+ TosaSerializationOperator *op) const {
+ mlir::Value input0_val = tensor_map->at(op->GetInputTensorNames()[0]);
+ mlir::Value input1_val = tensor_map->at(op->GetInputTensorNames()[1]);
+ mlir::Value input2_val = tensor_map->at(op->GetInputTensorNames()[2]);
+ mlir::RankedTensorType output_type =
+ tensor_type_map->at(op->GetOutputTensorNames()[0]);
+ assert(op->GetAttributeType() ==
+ Attribute_NONE); // double check that there is no attribute
+
+ mlir::Operation *mlir_op = op_builder->create<mlir::tosa::ScatterOp>(
+ loc, output_type, input0_val, input1_val, input2_val);
+ block->push_back(mlir_op);
+ return std::vector<mlir::Value>({mlir_op->getResult(0)});
+}
+
+template <>
+std::vector<mlir::Value>
+TosaMlirOperatorBuilder::build<Op_RESIZE>(TosaSerializationOperator *op) const {
+ mlir::Value input_val = tensor_map->at(op->GetInputTensorNames()[0]);
+ mlir::Value scale_val = tensor_map->at(op->GetInputTensorNames()[1]);
+ mlir::Value offset_val = tensor_map->at(op->GetInputTensorNames()[2]);
+ mlir::Value border_val = tensor_map->at(op->GetInputTensorNames()[3]);
+ mlir::RankedTensorType output_type =
+ tensor_type_map->at(op->GetOutputTensorNames()[0]);
+
+ assert(op->GetAttributeType() ==
+ Attribute_ResizeAttribute); // double check attribute type
+ TosaResizeAttribute *attr =
+ static_cast<TosaResizeAttribute *>(op->GetAttribute());
+
+ auto mode = op_builder->getStringAttr(ResizeEnum2Str(attr->mode()));
+
+ mlir::Operation *mlir_op = op_builder->create<mlir::tosa::ResizeOp>(
+ loc, output_type, input_val, scale_val, offset_val, border_val, mode);
+ block->push_back(mlir_op);
+ return std::vector<mlir::Value>({mlir_op->getResult(0)});
+}
+
+// Reverse has a single input and a single I32 axis attribute
+BUILD_OP_REDUCTION(Reverse, REVERSE)
+
+template <>
+std::vector<mlir::Value>
+TosaMlirOperatorBuilder::build<Op_MUL>(TosaSerializationOperator *op) const {
+ mlir::Value input0_val = tensor_map->at(op->GetInputTensorNames()[0]);
+ mlir::Value input1_val = tensor_map->at(op->GetInputTensorNames()[1]);
+ mlir::Value shift_val = tensor_map->at(op->GetInputTensorNames()[2]);
+
+ mlir::RankedTensorType output_type =
+ tensor_type_map->at(op->GetOutputTensorNames()[0]);
+
+ assert(op->GetAttributeType() ==
+ Attribute_MulAttribute); // double check attribute type
+
+ mlir::Operation *mlir_op = op_builder->create<mlir::tosa::MulOp>(
+ loc, output_type, input0_val, input1_val, shift_val);
+ block->push_back(mlir_op);
+ return std::vector<mlir::Value>({mlir_op->getResult(0)});
+}
+
+template <>
+std::vector<mlir::Value>
+TosaMlirOperatorBuilder::build<Op_ARITHMETIC_RIGHT_SHIFT>(
+ TosaSerializationOperator *op) const {
+ mlir::Value input0_val = tensor_map->at(op->GetInputTensorNames()[0]);
+ mlir::Value input1_val = tensor_map->at(op->GetInputTensorNames()[1]);
+ mlir::RankedTensorType output_type =
+ tensor_type_map->at(op->GetOutputTensorNames()[0]);
+
+ assert(
+ op->GetAttributeType() ==
+ Attribute_ArithmeticRightShiftAttribute); // double check attribute type
+ TosaArithmeticRightShiftAttribute *attr =
+ static_cast<TosaArithmeticRightShiftAttribute *>(op->GetAttribute());
+
+ auto round = op_builder->getBoolAttr(attr->round());
+
+ mlir::Operation *mlir_op =
+ op_builder->create<mlir::tosa::ArithmeticRightShiftOp>(
+ loc, output_type, input0_val, input1_val, round);
+ block->push_back(mlir_op);
+ return std::vector<mlir::Value>({mlir_op->getResult(0)});
+}
+
+template <>
+std::vector<mlir::Value>
+TosaMlirOperatorBuilder::build<Op_TABLE>(TosaSerializationOperator *op) const {
+ mlir::Value input_val = tensor_map->at(op->GetInputTensorNames()[0]);
+ mlir::RankedTensorType output_type =
+ tensor_type_map->at(op->GetOutputTensorNames()[0]);
+
+ assert(op->GetAttributeType() ==
+ Attribute_TableAttribute); // double check attribute type
+ TosaTableAttribute *attr =
+ static_cast<TosaTableAttribute *>(op->GetAttribute());
+
+ // create a const op for table value attribute
+ const auto table_values = attr->table();
+ mlir::RankedTensorType const_type;
+ mlir::DenseElementsAttr const_attr;
+ const auto input_element_type =
+ input_val.getType().cast<mlir::ShapedType>().getElementType();
+ if (input_element_type.isInteger(8)) {
+ // table is signed 8 mode
+ const_type = mlir::RankedTensorType::get(
+ {static_cast<int64_t>(table_values.size())}, op_builder->getI8Type());
+ const_attr = BuildDenseI8ElementsAttr(op_builder, table_values);
+ } else {
+ // table is signed 16 mode
+ const_type = mlir::RankedTensorType::get(
+ {static_cast<int64_t>(table_values.size())}, op_builder->getI16Type());
+ const_attr = BuildDenseI16ElementsAttr(op_builder, table_values);
+ }
+ mlir::Operation *mlir_const_op =
+ op_builder->create<mlir::tosa::ConstOp>(loc, const_type, const_attr);
+ auto table_value = mlir_const_op->getResult(0);
+
+ mlir::Operation *mlir_op = op_builder->create<mlir::tosa::TableOp>(
+ loc, output_type, input_val, table_value);
+ block->push_back(mlir_const_op);
+ block->push_back(mlir_op);
+ return std::vector<mlir::Value>({mlir_op->getResult(0)});
+}
+
+template <>
+std::vector<mlir::Value> TosaMlirOperatorBuilder::build<Op_RESCALE>(
+ TosaSerializationOperator *op) const {
+ mlir::Value input_val = tensor_map->at(op->GetInputTensorNames()[0]);
+ mlir::Value multiplier_val = tensor_map->at(op->GetInputTensorNames()[1]);
+ mlir::Value shift_val = tensor_map->at(op->GetInputTensorNames()[2]);
+ mlir::RankedTensorType output_type =
+ tensor_type_map->at(op->GetOutputTensorNames()[0]);
+
+ assert(op->GetAttributeType() ==
+ Attribute_RescaleAttribute); // double check attribute type
+ TosaRescaleAttribute *attr =
+ static_cast<TosaRescaleAttribute *>(op->GetAttribute());
+
+ auto input_zp = op_builder->getI32IntegerAttr(attr->input_zp());
+ auto output_zp = op_builder->getI32IntegerAttr(attr->output_zp());
+
+ auto scale32 = op_builder->getBoolAttr(attr->scale32());
+ auto double_round = op_builder->getBoolAttr(attr->double_round());
+ auto per_channel = op_builder->getBoolAttr(attr->per_channel());
+
+ auto input_unsigned = op_builder->getBoolAttr(attr->input_unsigned());
+ auto output_unsigned = op_builder->getBoolAttr(attr->output_unsigned());
+
+ mlir::Operation *mlir_op = op_builder->create<mlir::tosa::RescaleOp>(
+ loc, output_type, input_val, multiplier_val, shift_val, input_zp,
+ output_zp, scale32, double_round, per_channel, input_unsigned,
+ output_unsigned);
+ block->push_back(mlir_op);
+ return std::vector<mlir::Value>({mlir_op->getResult(0)});
+}
+
+template <>
+std::vector<mlir::Value>
+TosaMlirOperatorBuilder::build<Op_CUSTOM>(TosaSerializationOperator *op) const {
+ mlir::Value input_val = tensor_map->at(op->GetInputTensorNames()[0]);
+ mlir::RankedTensorType output_type =
+ tensor_type_map->at(op->GetOutputTensorNames()[0]);
+
+ assert(op->GetAttributeType() ==
+ Attribute_CustomAttribute); // double check attribute type
+ TosaCustomAttribute *attr =
+ static_cast<TosaCustomAttribute *>(op->GetAttribute());
+
+ auto operator_name = op_builder->getStringAttr(attr->operator_name());
+ auto domain_name = op_builder->getStringAttr(attr->domain_name());
+ std::string impl_str;
+ impl_str.resize(attr->implementation_attrs().size() + 1);
+ int idx = 0;
+ for (auto c : attr->implementation_attrs()) {
+ impl_str[idx++] = c;
+ }
+ auto impl = op_builder->getStringAttr(impl_str);
+
+ mlir::Operation *mlir_op = op_builder->create<mlir::tosa::CustomOp>(
+ loc, output_type, operator_name, domain_name, impl, input_val);
+ block->push_back(mlir_op);
+ return std::vector<mlir::Value>({mlir_op->getResult(0)});
+}
+
+template <>
+std::vector<mlir::Value>
+TosaMlirOperatorBuilder::build<Op_RFFT2D>(TosaSerializationOperator *op) const {
+ mlir::Value input_val = tensor_map->at(op->GetInputTensorNames()[0]);
+ mlir::RankedTensorType output0_type =
+ tensor_type_map->at(op->GetOutputTensorNames()[0]);
+ mlir::RankedTensorType output1_type =
+ tensor_type_map->at(op->GetOutputTensorNames()[1]);
+ assert(op->GetAttributeType() ==
+ Attribute_RFFTAttribute); // double check attribute type
+ TosaRFFTAttribute *attr =
+ static_cast<TosaRFFTAttribute *>(op->GetAttribute());
+
+ bool local_bound = attr->local_bound();
+
+ mlir::Operation *mlir_op = op_builder->create<mlir::tosa::RFFT2dOp>(
+ loc, output0_type, output1_type, input_val, local_bound);
+ block->push_back(mlir_op);
+ return std::vector<mlir::Value>(
+ {mlir_op->getResult(0), mlir_op->getResult(1)});
+}
+
+template <>
+std::vector<mlir::Value>
+TosaMlirOperatorBuilder::build<Op_FFT2D>(TosaSerializationOperator *op) const {
+ mlir::Value input0_val = tensor_map->at(op->GetInputTensorNames()[0]);
+ mlir::Value input1_val = tensor_map->at(op->GetInputTensorNames()[1]);
+ mlir::RankedTensorType output0_type =
+ tensor_type_map->at(op->GetOutputTensorNames()[0]);
+ mlir::RankedTensorType output1_type =
+ tensor_type_map->at(op->GetOutputTensorNames()[1]);
+
+ assert(op->GetAttributeType() == Attribute_FFTAttribute);
+ TosaFFTAttribute *attr = static_cast<TosaFFTAttribute *>(op->GetAttribute());
+ auto inverse = op_builder->getBoolAttr(attr->inverse());
+ auto local_bound = op_builder->getBoolAttr(attr->local_bound());
+
+ mlir::Operation *mlir_op = op_builder->create<mlir::tosa::FFT2dOp>(
+ loc, output0_type, output1_type, input0_val, input1_val, inverse,
+ local_bound);
+ block->push_back(mlir_op);
+ return std::vector<mlir::Value>(
+ {mlir_op->getResult(0), mlir_op->getResult(1)});
+}
+
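+// Builds all basic blocks of a serialized region into an MLIR region; a
+// non-isolated region inherits the tensor value map of its parent scope.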
+class TosaMlirRegionBuilder {
+public:
+ TosaMlirRegionBuilder(TosaSerializationRegion *_ser_region,
+ TosaSerializationHandler *_tsh, mlir::Region *_region,
+ mlir::OpBuilder *_op_builder, mlir::Location _loc,
+ TosaMlirRegionBuilder *_parent_value_scope = nullptr)
+ : ser_region(_ser_region), tsh(_tsh), region(_region),
+ op_builder(_op_builder), loc(_loc) {
+ if (_parent_value_scope) {
+ // inherit parent_value_scope's tensor_map
+ for (auto &kv : _parent_value_scope->GetTensorMap()) {
+ tensor_map.insert(kv);
+ }
+ }
+ }
+
+ mlir::LogicalResult
+ BuildAllBlocksInRegion(std::vector<mlir::Value> &return_values);
+
+ mlir::OpBuilder *GetOpBuilder() { return op_builder; }
+ mlir::Location GetLocation() { return loc; }
+ std::unordered_map<std::string, mlir::Value> &GetTensorMap() {
+ return tensor_map;
+ }
+ TosaSerializationHandler *GetTsh() const { return tsh; }
+
+private:
+ mlir::Region *region;
+ TosaSerializationRegion *ser_region;
+ TosaSerializationHandler *tsh;
+ mlir::OpBuilder *op_builder;
+ mlir::Location loc;
+ std::unordered_map<std::string, mlir::Value> tensor_map;
+};
+
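+// Builds all operators of a serialized basic block into an MLIR block,
+// tracking the ranked tensor types and shape types of the block's tensors.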
+class TosaMlirBlockBuilder {
+public:
+ TosaMlirBlockBuilder(TosaSerializationBasicBlock *_ser_block,
+ TosaMlirRegionBuilder *_region_builder,
+ mlir::Block *_block)
+ : ser_block(_ser_block), region_builder(_region_builder), block(_block) {}
+
+ mlir::LogicalResult
+ BuildAllOpsInBlock(std::vector<mlir::Value> &return_values);
+
+ mlir::OpBuilder *GetOpBuilder() { return region_builder->GetOpBuilder(); }
+ mlir::Location GetLocation() { return region_builder->GetLocation(); }
+ std::unordered_map<std::string, mlir::Value> &GetTensorMap() {
+ return region_builder->GetTensorMap();
+ }
+
+ TosaSerializationHandler *GetTsh() const { return region_builder->GetTsh(); }
+ TosaMlirRegionBuilder *GetRegionBuilder() const { return region_builder; }
+
+private:
+ TosaSerializationBasicBlock *ser_block;
+ TosaMlirRegionBuilder *region_builder;
+ mlir::Block *block;
+ std::unordered_map<std::string, mlir::RankedTensorType> tensor_type_map;
+ std::unordered_map<std::string, mlir::tosa::shapeType> shape_type_map;
+ std::unordered_set<std::string> unranked_tensors;
+};
+
+TosaSerializationHandler *TosaMlirOperatorBuilder::GetTsh() const {
+ return block_builder->GetTsh();
+}
+
+TosaMlirRegionBuilder *TosaMlirOperatorBuilder::GetRegionBuilder() const {
+ return block_builder->GetRegionBuilder();
+}
+
+// build control flow ops:
+
+namespace {
+
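+// Build an MLIR region for ser_region. When the enclosing op is
+// IsolatedFromAbove, the parent value scope is dropped so the new region
+// cannot capture values from the surrounding region.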
+mlir::LogicalResult
+BuildRegion(TosaSerializationRegion *ser_region, TosaSerializationHandler *tsh,
+ mlir::Region *mlir_region, mlir::OpBuilder *op_builder,
+ mlir::Location loc, std::vector<mlir::Value> &return_values,
+ bool isolated_from_above = false,
+ TosaMlirRegionBuilder *parent_region_builder = nullptr) {
+ TosaMlirRegionBuilder *parent_value_scope =
+ isolated_from_above ? nullptr : parent_region_builder;
+ TosaMlirRegionBuilder region_builder(ser_region, tsh, mlir_region, op_builder,
+ loc, parent_value_scope);
+ return region_builder.BuildAllBlocksInRegion(return_values);
+}
+
+} // namespace
+
+template <>
+std::vector<mlir::Value> TosaMlirOperatorBuilder::build<Op_COND_IF>(
+ TosaSerializationOperator *op) const {
+ mlir::Value cond_val = tensor_map->at(op->GetInputTensorNames().at(0));
+ std::vector<mlir::Value> input_values;
+ for (auto idx = 1u; idx < op->GetInputTensorNames().size(); idx++) {
+ input_values.push_back(tensor_map->at(op->GetInputTensorNames().at(idx)));
+ }
+ std::vector<mlir::Type> output_types;
+  for (auto &name : op->GetOutputTensorNames()) {
+ output_types.push_back(tensor_type_map->at(name));
+ }
+
+ assert(op->GetAttributeType() ==
+ Attribute_CondIfAttribute); // double check attribute type
+ TosaCondIfAttribute *attr =
+ static_cast<TosaCondIfAttribute *>(op->GetAttribute());
+ auto ser_then_region = GetTsh()->GetRegionByName(attr->then_graph());
+ auto ser_else_region = GetTsh()->GetRegionByName(attr->else_graph());
+
+ if (!ser_then_region || !ser_else_region) {
+ llvm::errs() << "ERROR: " << get_string(op)
+ << " region serialization hasn't been implemented\n";
+ return {};
+ }
+
+ mlir::Operation *mlir_op = op_builder->create<mlir::tosa::IfOp>(
+ loc, output_types, cond_val, input_values);
+
+ const bool isolated_from_above =
+ mlir_op->hasTrait<mlir::OpTrait::IsIsolatedFromAbove>();
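+  // region 0 of tosa.cond_if holds the then branch; region 1 the else branch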
+ mlir::Region &then_region = mlir_op->getRegion(0);
+ mlir::Region &else_region = mlir_op->getRegion(1);
+
+ auto curr_region_builder = GetRegionBuilder();
+
+ std::vector<mlir::Value> then_returns, else_returns;
+
+ if (BuildRegion(ser_then_region, GetTsh(), &then_region, op_builder, loc,
+ then_returns, isolated_from_above, curr_region_builder)
+ .failed()) {
+ return {};
+ }
+ if (then_returns.size() != mlir_op->getNumResults()) {
+ llvm::errs()
+ << "ERROR: " << get_string(op)
+ << " then_region yield.size() doesn't match cond_if's output size\n";
+ return {};
+ }
+
+ if (BuildRegion(ser_else_region, GetTsh(), &else_region, op_builder, loc,
+ else_returns, isolated_from_above, curr_region_builder)
+ .failed()) {
+ return {};
+ }
+ if (else_returns.size() != mlir_op->getNumResults()) {
+ llvm::errs()
+ << "ERROR: " << get_string(op)
+ << " else_region yield.size() doesn't match cond_if's output size\n";
+ return {};
+ }
+
+ block->push_back(mlir_op);
+ return std::vector<mlir::Value>(mlir_op->getResults().begin(),
+ mlir_op->getResults().end());
+}
+
+template <>
+std::vector<mlir::Value> TosaMlirOperatorBuilder::build<Op_WHILE_LOOP>(
+ TosaSerializationOperator *op) const {
+ std::vector<mlir::Value> input_values;
+ for (auto idx = 0u; idx < op->GetInputTensorNames().size(); idx++) {
+ input_values.push_back(tensor_map->at(op->GetInputTensorNames().at(idx)));
+ }
+ std::vector<mlir::Type> output_types;
+  for (auto &name : op->GetOutputTensorNames()) {
+ output_types.push_back(tensor_type_map->at(name));
+ }
+ assert(op->GetAttributeType() ==
+ Attribute_WhileLoopAttribute); // double check attribute type
+ TosaWhileLoopAttribute *attr =
+ static_cast<TosaWhileLoopAttribute *>(op->GetAttribute());
+ auto ser_cond_region = GetTsh()->GetRegionByName(attr->cond_graph());
+  auto ser_body_region = GetTsh()->GetRegionByName(attr->body_graph());
+
+  if (!ser_cond_region || !ser_body_region) {
+    llvm::errs() << "ERROR: " << get_string(op)
+                 << " region serialization hasn't been implemented\n";
+    return {};
+  }
+
+ mlir::Operation *mlir_op =
+ op_builder->create<mlir::tosa::WhileOp>(loc, output_types, input_values);
+
+ const bool isolated_from_above =
+ mlir_op->hasTrait<mlir::OpTrait::IsIsolatedFromAbove>();
+
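+  // region 0 of tosa.while_loop holds the cond graph; region 1 the body graph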
+ mlir::Region &cond_region = mlir_op->getRegion(0);
+ mlir::Region &body_region = mlir_op->getRegion(1);
+
+ auto curr_region_builder = GetRegionBuilder();
+
+ std::vector<mlir::Value> cond_returns, body_returns;
+
+ if (BuildRegion(ser_cond_region, GetTsh(), &cond_region, op_builder, loc,
+ cond_returns, isolated_from_above, curr_region_builder)
+ .failed()) {
+ return {};
+ }
+ if (cond_returns.size() != 1) {
+ llvm::errs() << "ERROR: " << get_string(op)
+ << " cond_region yield.size() is not 1\n";
+ return {};
+ }
+
+ if (BuildRegion(ser_body_region, GetTsh(), &body_region, op_builder, loc,
+ body_returns, isolated_from_above, curr_region_builder)
+ .failed()) {
+ return {};
+ }
+ if (body_returns.size() != mlir_op->getNumResults()) {
+ llvm::errs()
+ << "ERROR: " << get_string(op)
+ << " body_region yield.size() doesn't match while_loop's output size\n";
+ return {};
+ }
+
+ block->push_back(mlir_op);
+ return std::vector<mlir::Value>(mlir_op->getResults().begin(),
+ mlir_op->getResults().end());
+}
+
+mlir::LogicalResult TosaMlirBlockBuilder::BuildAllOpsInBlock(
+ std::vector<mlir::Value> &return_values) {
+ block->clear();
+ auto loc = GetLocation();
+ auto op_builder = GetOpBuilder();
+ auto &tensor_map = GetTensorMap();
+
+ std::unordered_set<TosaSerializationOperator *> operator_built;
+
+ TosaMlirOperatorBuilder tosa_op_builder(op_builder, ser_block, block, loc,
+ this, &tensor_map, &tensor_type_map,
+ &shape_type_map);
+
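+  // Pre-build MLIR types for every tensor declared in the block: tosa.shape
+  // tensors go into shape_type_map, everything else into tensor_type_map, and
+  // variable tensors are additionally registered in variable_tensor_map.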
+ for (auto ts : ser_block->GetTensors()) {
+ if (ts->GetVariable()) {
+ RegisterVariableTensor(ts);
+ }
+ const auto &ts_name = ts->GetName();
+ if (ts->GetDtype() == DType::DType_SHAPE) {
+ // ts is tosa.shape type
+ auto shape_rank = ts->GetShape()[0];
+ shape_type_map[ts_name] =
+ mlir::tosa::shapeType::get(op_builder->getContext(), shape_rank);
+ continue;
+ }
+ mlir::RankedTensorType type;
+ if (BuildTensorType(op_builder, ts, type).failed()) {
+ return mlir::failure();
+ }
+ tensor_type_map[ts_name] = type;
+ if (ts->GetIsUnranked()) {
+ assert(ts->GetShape().empty()); // unranked tensors should have shape = {}
+ unranked_tensors.insert(ts_name);
+ }
+ }
+
+  // Initialize tensor_map based on block input arguments
+ for (const std::string &block_input_name : ser_block->GetInputs()) {
+ mlir::Type type = tensor_type_map[block_input_name];
+ if (unranked_tensors.count(block_input_name)) {
+ // recast type as unranked tensor type
+ auto element_type = type.cast<mlir::RankedTensorType>().getElementType();
+ type = mlir::UnrankedTensorType::get(element_type);
+ }
+ auto input_value = block->addArgument(type, loc);
+ if (tensor_map.count(block_input_name)) {
+ llvm::errs() << "ERROR: block input tensor " << block_input_name
+ << " already exists\n";
+ return mlir::failure();
+ }
+ tensor_map[block_input_name] = input_value;
+ }
+
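+  // Build each serialized operator in order and record the produced
+  // mlir::Values under the operator's output tensor names.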
+ for (auto op : ser_block->GetOperators()) {
+
+ // skip if operator has been built
+ if (operator_built.count(op)) {
+      // this happens when the same input appears twice or more in an
+      // operator, e.g., concat(%0, %0)
+ continue;
+ }
+ operator_built.insert(op);
+
+ std::vector<mlir::Value> output_values;
+ if (IsVariableReadOp(op)) {
+ output_values = tosa_op_builder.BuildVariableReadOp(op);
+ } else if (IsVariableWriteOp(op)) {
+ tosa_op_builder.BuildVariableWriteOp(op);
+ }
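+// schema_operator.def expands DEF_SCHEMA_OPERATOR into one "else if" branch
+// per TOSA opcode, dispatching to the matching build<Op_...>() specialization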
+#define DEF_SCHEMA_OPERATOR(SCHEMA_OP_NAME) \
+ else if (op->GetOp() == Op_##SCHEMA_OP_NAME) { \
+ output_values = tosa_op_builder.build<Op_##SCHEMA_OP_NAME>(op); \
+ }
+#include "schema_operator.def"
+#undef DEF_SCHEMA_OPERATOR
+ else {
+ llvm::errs() << "ERROR: unsupported opcode=" << EnumNamesOp()[op->GetOp()]
+ << "\n";
+ return mlir::failure();
+ }
+
+ if (IsVariableWriteOp(op)) {
+      // the sanity checking below does not apply to a variable write op
+      // because it has no output tensors, unlike the identity op it was
+      // serialized from
+ continue;
+ }
+
+ // Sanity check if number of built mlir::Value is expected
+ if (op->GetOutputTensorNames().size() != output_values.size()) {
+ llvm::errs() << "ERROR: number of built mlir::Value is not matching "
+ "number of operator output tensor\n";
+ return mlir::failure();
+ }
+
+ for (size_t i = 0; i < output_values.size(); i++) {
+ // Sanity check tensor hasn't been built
+ std::string op_output_name = op->GetOutputTensorNames()[i];
+ if (tensor_map.count(op_output_name)) {
+ llvm::errs() << "ERROR: tensor " << op_output_name
+ << " is already built\n";
+ return mlir::failure();
+ }
+ tensor_map[op_output_name] = output_values[i];
+ }
+ }
+
+ // Construct return values
+ std::vector<mlir::Value> return_operands;
+ for (const auto &output_name : ser_block->GetOutputs()) {
+ // Sanity check if terminator mlir::Value is built
+ if (!tensor_map.count(output_name)) {
+ llvm::errs() << "ERROR: terminator mlir::Value " << output_name
+ << " is not built in block " << ser_block->GetName() << "\n";
+ return mlir::failure();
+ }
+ mlir::Value output_value = tensor_map.at(output_name);
+ return_operands.push_back(output_value);
+ return_values.push_back(output_value);
+ }
+ mlir::Operation *terminator_op;
+ auto parent_op = block->getParentOp();
+ if (mlir::isa<mlir::func::FuncOp>(parent_op)) {
+ terminator_op =
+ op_builder->create<mlir::func::ReturnOp>(loc, return_operands);
+ } else {
+ terminator_op =
+ op_builder->create<mlir::tosa::YieldOp>(loc, return_operands);
+ }
+ block->push_back(terminator_op);
+
+  // TODO: topological sorting may be needed if the serialized operators are
+  // not already in execution order
+
+ return mlir::success();
+}
+
+mlir::LogicalResult TosaMlirRegionBuilder::BuildAllBlocksInRegion(
+ std::vector<mlir::Value> &return_values) {
+ for (auto &ser_block : ser_region->GetBlocks()) {
+ auto &block = region->emplaceBlock();
+ TosaMlirBlockBuilder block_builder(ser_block, this, &block);
+
+ if (block_builder.BuildAllOpsInBlock(return_values).failed()) {
+ return mlir::failure();
+ }
+
+ if (return_values.empty()) {
+ llvm::errs() << "Warning: graph doesn't have return values\n";
+ }
+ }
+ return mlir::success();
+}
+
+mlir::LogicalResult buildTosaMlir(mlir::func::FuncOp &func,
+ mlir::MLIRContext &context,
+ tosa::TosaSerializationHandler &tsh,
+ std::vector<mlir::Value> &main_returns) {
+
+ mlir::Region *main_region = func.getCallableRegion();
+ if (!main_region) {
+ llvm::errs() << "Invalid MLIR: doesn't have valid \"main\" region\n";
+ return mlir::failure();
+ }
+
+  if (tsh.GetRegions().empty()) {
+    llvm::errs() << "Invalid TOSA flatbuffer: no region found\n";
+    return mlir::failure();
+  }
+  TosaSerializationRegion *ser_main_region = tsh.GetRegions().front();
+
+ auto loc = func.getLoc();
+
+ main_region->takeBody(*main_region); // empty old func body
+ auto op_builder = mlir::OpBuilder(func.getBody());
+
+ if (BuildRegion(ser_main_region, &tsh, main_region, &op_builder, loc,
+ main_returns)
+ .failed()) {
+ return mlir::failure();
+ }
+
+ if (main_returns.empty()) {
+ llvm::errs() << "Warning: graph doesn't have return values\n";
+ }
+
+ return mlir::success();
+}
+
+// Load Tosa Schema into TosaSerializationHandler, required for JSON save/load
+mlir::LogicalResult loadTosaSchema(tosa::TosaSerializationHandler &tsh) {
+  // std::string::c_str() never returns nullptr, so check for an empty option
+  // value instead
+  if (tosa_deserialize_schema.empty()) {
+    llvm::errs() << "Flatbuffer schema not defined\n";
+    return mlir::failure();
+  }
+  const char *tosa_schema = tosa_deserialize_schema.c_str();
+
+ if (tsh.LoadFileSchema(tosa_schema)) {
+ llvm::errs() << "Error loading tosa schema file: " << tosa_schema << "\n";
+ return mlir::failure();
+ }
+ return mlir::success();
+}
+
+namespace {
+
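+// build the "inputs" or "outputs" entry of the tf.entry_function attribute as
+// a comma-separated list of names such as "tosa_deserialized_input_0:0"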
+mlir::NamedAttribute DefaultEntryFunctionAttr(mlir::Builder &builder,
+ bool is_input, int count) {
+ std::string names;
+ for (int i = 0; i < count; i++) {
+ std::string name = kDefaultExportedName + "_";
+ name += (is_input ? kDefaultInputPrefix : kDefaultOutputPrefix);
+ name += std::to_string(i) + ":0";
+ if (i > 0) {
+ names += ",";
+ }
+ names += name;
+ }
+ return builder.getNamedAttr((is_input ? "inputs" : "outputs"),
+ builder.getStringAttr(names));
+}
+
+// erase all ops in block except for FuncOp
+void ClearNonFuncOps(mlir::Block *block) {
+ std::vector<mlir::Operation *> to_delete;
+ for (auto &op : block->getOperations()) {
+ if (!mlir::isa<mlir::func::FuncOp>(op)) {
+ to_delete.push_back(&op);
+ }
+ }
+ for (mlir::Operation *op : to_delete) {
+ op->erase();
+ }
+}
+
+// erase function attrs and empty function region's body
+void ResetFunction(mlir::func::FuncOp &function, mlir::MLIRContext &context) {
+ function->setAttrs(mlir::DictionaryAttr::get(&context, {}));
+ mlir::Region *main_region = function.getCallableRegion();
+ main_region->takeBody(*main_region);
+}
+
+// replace the attrs and body of @a to_function and its parent @a to_module
+// with those of @a from_function and its parent @a from_module
+mlir::LogicalResult CloneIntoModuleAndFunction(
+ mlir::MLIRContext &context, mlir::func::FuncOp &to_function,
+ mlir::ModuleOp &to_module, mlir::func::FuncOp &from_function,
+ mlir::ModuleOp &from_module) {
+ auto from_block = from_function.getOperation()->getBlock();
+ auto to_block = to_function.getOperation()->getBlock();
+ ClearNonFuncOps(to_block);
+  // copy all attrs from from_module to to_module
+  to_module->setAttrs(from_module->getAttrDictionary());
+  // erase attrs and body of to_function
+  ResetFunction(to_function, context);
+  // clone from_function's attrs and region into to_function
+ mlir::IRMapping mapping;
+ from_function.cloneInto(to_function, mapping);
+
+  // copy variable ops in from_block to to_block: collect the variable ops,
+  // then clone them to the front of to_block in reverse order so that their
+  // original order is preserved
+ std::vector<mlir::Operation *> variable_ops;
+ for (mlir::Operation &op : *from_block) {
+ if (mlir::isa<mlir::tosa::VariableOp>(op)) {
+ variable_ops.push_back(&op);
+ }
+ }
+ auto cloneOptions =
+ mlir::Operation::CloneOptions::all().cloneRegions(false).cloneOperands(
+ false);
+ for (auto iter = variable_ops.rbegin(); iter != variable_ops.rend(); iter++) {
+ auto op = *iter;
+ to_block->push_front(op->clone(mapping, cloneOptions));
+ }
+ return mlir::success();
+}
+
+} // namespace
+
+namespace mlir {
+
+namespace tosa {
+
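+// Load a TOSA flatbuffer or JSON file and build an equivalent MLIR module
+// with a single "main" function, the reconstructed variable ops, and
+// tf_saved_model / tf.entry_function attributes describing the entry point.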
+mlir::OwningOpRef<mlir::ModuleOp>
+BuildMlirFromTosaFile(const char *file_name, mlir::MLIRContext *context,
+ bool file_is_fbs) {
+ TosaSerializationHandler tsh;
+ if (file_is_fbs) {
+ if (tsh.LoadFileTosaFlatbuffer(file_name)) {
+ llvm::errs() << "Fail to load TOSA file " << file_name << "\n";
+ return nullptr;
+ }
+ } else {
+ // must load tosa schema before loading json file
+ if (loadTosaSchema(tsh).failed()) {
+ return nullptr;
+ }
+ if (tsh.LoadFileJson(file_name)) {
+ llvm::errs() << "Fail to load TOSA JSON file " << file_name << "\n";
+ return nullptr;
+ }
+ }
+
+ // create new module
+ auto base_loc = mlir::FileLineColLoc::get(context, file_name, 0, 0);
+ auto module = mlir::ModuleOp::create(base_loc);
+
+ // set module attributes
+ const auto &tosa_version = tsh.GetVersion().to_string();
+ std::string tosa_description =
+ file_is_fbs ? kDefaultFBSDescription : kDefaultJSONDescription;
+ auto builder = mlir::Builder(context);
+ module->setAttr("tosa.fbs_version", builder.getStringAttr(tosa_version));
+ module->setAttr("tosa.description", builder.getStringAttr(tosa_description));
+ module->setAttr("tf_saved_model.semantics", mlir::UnitAttr::get(context));
+
+ // construct function with input and return types
+ llvm::SmallVector<mlir::Type, 2> ret_types;
+ llvm::SmallVector<mlir::Type, 4> input_types;
+ auto func_type = builder.getFunctionType(input_types, ret_types);
+ auto func_loc =
+ mlir::NameLoc::get(builder.getStringAttr(kMainFunctionName), base_loc);
+ auto func = mlir::func::FuncOp::create(func_loc, kMainFunctionName, func_type,
+ /* attrs= */ {});
+ func.addEntryBlock();
+
+ // deserialize tosa fbs into function
+ std::vector<mlir::Value> main_returns;
+ if (buildTosaMlir(func, *context, tsh, main_returns).failed()) {
+ llvm::errs() << "Failed to deserialize flatbuffer "
+ << tosa_deserialize_filename << "\n";
+ return nullptr;
+ }
+ auto main_args = func.getCallableRegion()->getArguments();
+ // extract function input types
+ for (auto arg : main_args) {
+ input_types.push_back(arg.getType());
+ }
+ // extract function return types
+ for (auto ret : main_returns) {
+ ret_types.push_back(ret.getType());
+ }
+ // set function type with full input and return types
+ func_type = builder.getFunctionType(input_types, ret_types);
+ func.setType(func_type);
+
+ // set function attributes
+ llvm::SmallVector<mlir::NamedAttribute, 2> attributes;
+ if (!input_types.empty()) {
+    attributes.push_back(DefaultEntryFunctionAttr(
+ builder, /* is_input = */ true, /* count = */ input_types.size()));
+ for (int i = 0; i < input_types.size(); i++) {
+ std::string input_i = kDefaultInputPrefix + std::to_string(i);
+ func.setArgAttr(i, "tf_saved_model.index_path",
+ mlir::ArrayAttr::get(
+ context, {mlir::StringAttr::get(context, input_i)}));
+ }
+ }
+ if (!ret_types.empty()) {
+    attributes.push_back(DefaultEntryFunctionAttr(
+ builder, /* is_input = */ false, /* count = */ ret_types.size()));
+ for (int i = 0; i < ret_types.size(); i++) {
+ std::string output_i = kDefaultOutputPrefix + std::to_string(i);
+ func.setResultAttr(
+ i, "tf_saved_model.index_path",
+ mlir::ArrayAttr::get(context,
+ {mlir::StringAttr::get(context, output_i)}));
+ }
+ }
+ func->setAttr("tf.entry_function", builder.getDictionaryAttr(attributes));
+ func->setAttr(
+ "tf_saved_model.exported_names",
+ mlir::ArrayAttr::get(
+ context, {mlir::StringAttr::get(context, kDefaultExportedName)}));
+
+ // deserialize variable ops in the new module just before adding func op
+ if (ConstructVariableOps(module).failed()) {
+ return nullptr;
+ }
+
+ // add func to module
+ module.push_back(std::move(func));
+ return mlir::OwningOpRef<mlir::ModuleOp>(module);
+}
+
+namespace {
+
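+// Pass that deserializes the flatbuffer named by -tosa-deserialize-filename
+// and replaces the current function's body and attributes (and the parent
+// module's attributes and variable ops) with the deserialized contents.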
+class TosaDeserialize : public TosaDeserializationPassBase<TosaDeserialize> {
+public:
+ void runOnOperation() final {
+ auto function = getOperation();
+ auto &context = getContext();
+
+ auto new_module_ref = BuildMlirFromTosaFile(
+ tosa_deserialize_filename.c_str(), &context, /* file_is_fbs = */ true);
+ if (!new_module_ref) {
+ return signalPassFailure();
+ }
+
+ mlir::ModuleOp new_module = *new_module_ref;
+ auto builder = mlir::Builder(&context);
+ auto module = function->getParentOfType<mlir::ModuleOp>();
+ auto new_function = new_module.lookupSymbol<mlir::func::FuncOp>(
+ builder.getStringAttr(kMainFunctionName));
+ if (!new_function) {
+ llvm::errs() << "Failed to find main function in deserialized module\n";
+ return signalPassFailure();
+ }
+ if (CloneIntoModuleAndFunction(context,
+ /* to_function = */ function,
+ /* to_module = */ module,
+ /* from_function = */ new_function,
+ /* from_module = */ new_module)
+ .failed()) {
+ return signalPassFailure();
+ }
+ }
+};
+
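+// Same as TosaDeserialize, but loads the JSON form of the flatbuffer and
+// therefore requires the schema given by -tosa-deserialize-schema.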
+class TosaDeserializeJSON
+ : public TosaDeserializationJSONPassBase<TosaDeserializeJSON> {
+public:
+ void runOnOperation() final {
+ auto function = getOperation();
+ auto &context = getContext();
+
+ auto new_module_ref = BuildMlirFromTosaFile(
+ tosa_deserialize_filename.c_str(), &context, /* file_is_fbs = */ false);
+ if (!new_module_ref) {
+ return signalPassFailure();
+ }
+
+ mlir::ModuleOp new_module = *new_module_ref;
+ auto builder = mlir::Builder(&context);
+ auto module = function->getParentOfType<mlir::ModuleOp>();
+ auto new_function = new_module.lookupSymbol<mlir::func::FuncOp>(
+ builder.getStringAttr(kMainFunctionName));
+ if (!new_function) {
+ llvm::errs() << "Failed to find main function in deserialized module\n";
+ return signalPassFailure();
+ }
+ if (CloneIntoModuleAndFunction(context,
+ /* to_function = */ function,
+ /* to_module = */ module,
+ /* from_function = */ new_function,
+ /* from_module = */ new_module)
+ .failed()) {
+ return signalPassFailure();
+ }
+ }
+};
+
+} // anonymous namespace
+
+// Creates an instance of the TOSA flatbuffer deserialization pass
+std::unique_ptr<Pass> createTosaDeserializePass() {
+ return std::make_unique<TosaDeserialize>();
+}
+
+std::unique_ptr<Pass> createTosaDeserializeJSONPass() {
+ return std::make_unique<TosaDeserializeJSON>();
+}
+
+static PassRegistration<TosaDeserialize> passDeserialize([] {
+ return createTosaDeserializePass();
+});
+
+static PassRegistration<TosaDeserializeJSON> passDeserializeJSON([] {
+ return createTosaDeserializeJSONPass();
+});
+
+} // namespace tosa
+} // namespace mlir