diff options
Diffstat (limited to 'src/TosaDeserialize.cpp')
-rw-r--r-- | src/TosaDeserialize.cpp | 2128 |
1 files changed, 2128 insertions, 0 deletions
diff --git a/src/TosaDeserialize.cpp b/src/TosaDeserialize.cpp new file mode 100644 index 0000000..215d760 --- /dev/null +++ b/src/TosaDeserialize.cpp @@ -0,0 +1,2128 @@ + +// Copyright (c) 2023-2024, ARM Limited. +// +// Licensed under the Apache License, Version 2.0 with LLVM Exceptions +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// https://llvm.org/LICENSE.txt +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// TOSA MLIR deserialize passes + +#include "include/DeserializationPasses.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project +#include "mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project +#include "mlir/Dialect/Tosa/IR/TosaOps.h" // from @llvm-project +#include "mlir/IR/Matchers.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "tosa_serialization_handler.h" +#include <functional> +#include <queue> +#include <unordered_map> +#include <unordered_set> +#include <vector> + +// The namespace might be confusing here. We have mlir::tosa:: defined in MLIR +// and tosa:: defined in serialization library +// TODO: align the namespace +using namespace tosa; + +namespace cl = llvm::cl; + +llvm::cl::opt<std::string> tosa_deserialize_filename( + "tosa-deserialize-filename", llvm::cl::desc("<tosa flatbuffer filename>"), + llvm::cl::init("tosa_dump.tosa"), llvm::cl::value_desc("filename")); + +llvm::cl::opt<std::string> tosa_deserialize_schema( + "tosa-deserialize-schema", llvm::cl::desc("<tosa flatbuffer schema file>"), + llvm::cl::init(""), llvm::cl::value_desc("filename")); + +const std::string kDefaultExportedName = "tosa_deserialized"; +const std::string kDefaultInputPrefix = "input_"; +const std::string kDefaultOutputPrefix = "output_"; +const std::string kDefaultFBSDescription = "Tosa FBS Converted"; +const std::string kDefaultJSONDescription = "Tosa JSON Converted"; +const std::string kMainFunctionName = "main"; + +namespace { + +// a global map from flatbuffer variable names to serialized tensors +std::unordered_map<std::string, TosaSerializationTensor *> variable_tensor_map; + +void RegisterVariableTensor(TosaSerializationTensor *ts) { + assert(ts->GetVariable()); + // insert variable tensor ts only if not already present + variable_tensor_map.insert({ts->GetName(), ts}); +} + +bool IsVariableTensor(const std::string flatbuffer_tensor_name) { + return variable_tensor_map.count(flatbuffer_tensor_name); +} + +// return the variable name corresponding to flatbuffer_tensor_name +const std::string GetVariableTensorName(TosaSerializationTensor *ts) { + assert(ts->GetVariable()); + const auto name = ts->GetVariableName(); + if (name == "") { + // for legacy flatbuffers which may not have variable_name fields + return ts->GetName(); + } + return name; +} + +// return the variable name corresponding to flatbuffer_tensor_name +const std::string +GetVariableTensorName(const std::string flatbuffer_tensor_name) { + if (!IsVariableTensor(flatbuffer_tensor_name)) { + llvm::errs() << "ERROR: Variable tensor " << flatbuffer_tensor_name + << " is not found in variable_tensor_map"; + return ""; + } + return GetVariableTensorName(variable_tensor_map[flatbuffer_tensor_name]); +} + +bool IsVariableReadOp(TosaSerializationOperator *op) { + return (op->GetOp() == tosa::Op::Op_IDENTITY) && + IsVariableTensor(op->GetInputTensorNames()[0]); +} + +bool IsVariableWriteOp(TosaSerializationOperator *op) { + return (op->GetOp() == tosa::Op::Op_IDENTITY) && + IsVariableTensor(op->GetOutputTensorNames()[0]); +} + +// construct tensor type from dtype and shape of TosaSerializationTensor +mlir::LogicalResult BuildTensorType(mlir::OpBuilder *op_builder, + TosaSerializationTensor *ts, + mlir::RankedTensorType &type) { + mlir::Type element_type; + switch (ts->GetDtype()) { + case DType_BOOL: + element_type = op_builder->getI1Type(); + break; + case DType_UINT8: + element_type = op_builder->getIntegerType(8, false); + break; + case DType_INT4: + element_type = op_builder->getI4Type(); + break; + case DType_INT8: + element_type = op_builder->getI8Type(); + break; + case DType_INT16: + element_type = op_builder->getIntegerType(16); + break; + case DType_INT32: + element_type = op_builder->getI32Type(); + break; + case DType_INT48: + element_type = op_builder->getIntegerType(48); + break; + case DType_FP32: + element_type = op_builder->getF32Type(); + break; + case DType_UINT16: + element_type = op_builder->getIntegerType(16, false); + break; + case DType_FP16: + element_type = op_builder->getF16Type(); + break; + case DType_BF16: + element_type = op_builder->getBF16Type(); + break; + case DType_FP8E4M3: + element_type = op_builder->getFloat8E4M3FNType(); + break; + case DType_FP8E5M2: + element_type = op_builder->getFloat8E5M2Type(); + break; + case DType_SHAPE: + llvm::errs() + << "ERROR: Cannot construct RankedTensorType out of tosa.shape type \n"; + return mlir::failure(); + default: + llvm::errs() << "ERROR: unknown type " << EnumNamesDType()[ts->GetDtype()] + << "\n"; + return mlir::failure(); + } + llvm::SmallVector<int64_t> shape; + for (auto dim : ts->GetShape()) { + if (dim > 0) { + shape.push_back(dim); + } else { + // dynamic dim + shape.push_back(mlir::ShapedType::kDynamic); + } + } + type = mlir::RankedTensorType::get(llvm::ArrayRef(shape), element_type); + return mlir::success(); +} + +mlir::DenseElementsAttr GetConstAttr(const std::vector<uint8_t> &data, + const mlir::RankedTensorType &output_type, + uint32_t out_size) { + auto element_type = output_type.getElementType(); + if (element_type.isF32()) { + // for FP32, value attributes are stored as FP32 values + std::vector<float> float_data; + TosaSerializationHandler::ConvertU8toF32(data, out_size, float_data); + return mlir::DenseElementsAttr::get(output_type, + llvm::ArrayRef(float_data)); + } + if (element_type.isBF16()) { + mlir::SmallVector<mlir::APFloat> bf16_data; + for (uint32_t i = 0; i < out_size; i++) { + uint64_t byte0 = data[i * sizeof(int16_t)]; + uint64_t byte1 = data[i * sizeof(int16_t) + 1]; + uint64_t bits = byte0 + (byte1 << 8); + mlir::APInt bf16_bits(16, bits); + mlir::APFloat bf16(mlir::APFloat::BFloat(), bf16_bits); + bf16_data.push_back(bf16); + } + return mlir::DenseElementsAttr::get(output_type, bf16_data); + } + if (element_type.isFloat8E4M3FN()) { + mlir::SmallVector<mlir::APFloat> f8_data; + for (uint32_t i = 0; i < out_size; i++) { + mlir::APInt f8_bits(8, static_cast<uint64_t>(data[i])); + mlir::APFloat f8(mlir::APFloat::Float8E4M3FN(), f8_bits); + f8_data.push_back(f8); + } + return mlir::DenseElementsAttr::get(output_type, f8_data); + } + if (element_type.isFloat8E5M2()) { + mlir::SmallVector<mlir::APFloat> f8_data; + for (uint32_t i = 0; i < out_size; i++) { + mlir::APInt f8_bits(8, static_cast<uint64_t>(data[i])); + mlir::APFloat f8(mlir::APFloat::Float8E5M2(), f8_bits); + f8_data.push_back(f8); + } + return mlir::DenseElementsAttr::get(output_type, f8_data); + } + if (element_type.isInteger(4)) { + std::vector<int8_t> int4_data; + TosaSerializationHandler::ConvertU8toI4(data, out_size, int4_data); + return mlir::DenseElementsAttr::get(output_type, llvm::ArrayRef(int4_data)); + } + if (element_type.isInteger(8)) { + std::vector<int8_t> int8_data; + TosaSerializationHandler::ConvertU8toI8(data, out_size, int8_data); + return mlir::DenseElementsAttr::get(output_type, llvm::ArrayRef(int8_data)); + } + if (element_type.isInteger(16)) { + std::vector<int16_t> int16_data; + TosaSerializationHandler::ConvertU8toI16(data, out_size, int16_data); + return mlir::DenseElementsAttr::get(output_type, + llvm::ArrayRef(int16_data)); + } + if (element_type.isInteger(32)) { + std::vector<int32_t> int32_data; + TosaSerializationHandler::ConvertU8toI32(data, out_size, int32_data); + return mlir::DenseElementsAttr::get(output_type, + llvm::ArrayRef(int32_data)); + } + if (element_type.isInteger(48)) { + std::vector<int64_t> int48_data; + TosaSerializationHandler::ConvertU8toI48(data, out_size, int48_data); + std::vector<mlir::APInt> apint_data; + for (const auto v : int48_data) { + mlir::APInt apint_value(48, static_cast<uint64_t>(v), + /* isSigned = */ false); + apint_data.push_back(apint_value); + } + return mlir::DenseElementsAttr::get(output_type, + llvm::ArrayRef(apint_data)); + } + if (element_type.isInteger(1)) { + std::vector<bool> bool_data; + TosaSerializationHandler::ConvertU8toBool(data, out_size, bool_data); + llvm::SmallVector<bool> bool_values(bool_data.begin(), bool_data.end()); + return mlir::DenseElementsAttr::get(output_type, bool_values); + } + if (element_type.isF16()) { + std::vector<half_float::half> half_data; + TosaSerializationHandler::ConvertU8toF16(data, out_size, half_data); + return mlir::DenseElementsAttr::get(output_type, llvm::ArrayRef(half_data)); + } + + return nullptr; +} + +mlir::DenseElementsAttr +ConstructConstAttr(const mlir::RankedTensorType &output_type, + TosaSerializationTensor *ts, const std::string &op_name) { + // compute output data size + uint32_t out_size = 1; + for (const auto dim : ts->GetShape()) { + out_size *= dim; + } + auto attr = GetConstAttr(ts->GetData(), output_type, out_size); + if (!attr) { + llvm::errs() << "ERROR: " << op_name + << " contains unsupported element type\n"; + } + return attr; +} + +mlir::LogicalResult ConstructVariableOps(mlir::ModuleOp &module) { + if (variable_tensor_map.empty()) { + return mlir::success(); + } + auto loc = module.getLoc(); + auto op_builder = mlir::OpBuilder(module.getBodyRegion()); + for (auto [flatbuffer_name, ts] : variable_tensor_map) { + auto name = GetVariableTensorName(ts); + mlir::RankedTensorType type; + if (BuildTensorType(&op_builder, ts, type).failed()) { + return mlir::failure(); + } + + mlir::Attribute value_attr = nullptr; + if (!ts->GetData().empty()) { + value_attr = ConstructConstAttr(type, ts, name); + } + op_builder.create<mlir::tosa::VariableOp>(loc, llvm::StringRef(name), type, + value_attr); + } + + return mlir::success(); +} + +template <class T> +mlir::DenseElementsAttr BuildDenseI8ElementsAttr(mlir::OpBuilder *op_builder, + const std::vector<T> &values) { + llvm::SmallVector<int8_t> vec; + for (auto val : values) { + vec.push_back(val); + } + auto type = mlir::RankedTensorType::get({static_cast<int64_t>(vec.size())}, + op_builder->getI8Type()); + return mlir::DenseElementsAttr::get(type, llvm::ArrayRef(vec)); +} + +template <class T> +mlir::DenseElementsAttr +BuildDenseI16ElementsAttr(mlir::OpBuilder *op_builder, + const std::vector<T> &values) { + llvm::SmallVector<int16_t> vec; + for (auto val : values) { + vec.push_back(val); + } + auto type = mlir::RankedTensorType::get({static_cast<int64_t>(vec.size())}, + op_builder->getI16Type()); + return mlir::DenseElementsAttr::get(type, llvm::ArrayRef(vec)); +} + +template <class T> +mlir::DenseElementsAttr +BuildDenseI32ElementsAttr(mlir::OpBuilder *op_builder, + mlir::RankedTensorType &type, + const std::vector<T> &values) { + llvm::SmallVector<int32_t> vec; + for (auto val : values) { + vec.push_back(val); + } + return mlir::DenseElementsAttr::get(type, llvm::ArrayRef(vec)); +} + +template <class T> +mlir::DenseI8ArrayAttr BuildDenseI8ArrayAttr(mlir::OpBuilder *op_builder, + const std::vector<T> &values) { + std::vector<int8_t> vec; + for (auto val : values) { + vec.push_back(val); + } + return op_builder->getDenseI8ArrayAttr(vec); +} + +template <class T> +mlir::DenseI32ArrayAttr BuildDenseI32ArrayAttr(mlir::OpBuilder *op_builder, + const std::vector<T> &values) { + std::vector<int32_t> vec; + for (auto val : values) { + vec.push_back(val); + } + return op_builder->getDenseI32ArrayAttr(vec); +} + +template <class T> +mlir::DenseI64ArrayAttr BuildDenseI64ArrayAttr(mlir::OpBuilder *op_builder, + const std::vector<T> &values) { + std::vector<int64_t> vec; + for (auto val : values) { + vec.push_back(val); + } + return op_builder->getDenseI64ArrayAttr(vec); +} + +const std::string ResizeEnum2Str(const tosa::ResizeMode &mode) { + if (mode == ResizeMode_NEAREST) { + return "NEAREST_NEIGHBOR"; + } else if (mode == ResizeMode_BILINEAR) { + return "BILINEAR"; + } + return ""; +} + +// this is a counter part to Type2AccDType +mlir::Type AccDType2Type(mlir::OpBuilder *op_builder, DType dtype) { + // def Tosa_AccType : AnyTypeOf<[I<32>, I<48>, F16, F32]>; + if (dtype == DType_INT32) { + return op_builder->getI32Type(); + } else if (dtype == DType_INT48) { + return op_builder->getIntegerType(48); + } else if (dtype == DType_FP32) { + return op_builder->getF32Type(); + } else if (dtype == DType_FP16) { + return op_builder->getF16Type(); + } else { + // unknown acc type + // for now, default to F32 + return op_builder->getF32Type(); + } +} + +} // namespace + +class TosaMlirRegionBuilder; +class TosaMlirBlockBuilder; + +class TosaMlirOperatorBuilder { +public: + TosaMlirOperatorBuilder( + mlir::OpBuilder *_op_builder, TosaSerializationBasicBlock *_ser_block, + mlir::Block *_block, mlir::Location _loc, + TosaMlirBlockBuilder *_block_builder, + std::unordered_map<std::string, mlir::Value> *_tensor_map, + std::unordered_map<std::string, mlir::RankedTensorType> *_tensor_type_map, + std::unordered_map<std::string, mlir::tosa::shapeType> *_shape_type_map) + : op_builder(_op_builder), ser_block(_ser_block), block(_block), + loc(_loc), block_builder(_block_builder), tensor_map(_tensor_map), + tensor_type_map(_tensor_type_map), shape_type_map(_shape_type_map) {} + + template <Op OPCODE> + std::vector<mlir::Value> build(TosaSerializationOperator *op) const; + + std::vector<mlir::Value> BuildVariableOp(TosaSerializationOperator *op) const; + + std::vector<mlir::Value> + BuildVariableReadOp(TosaSerializationOperator *op) const; + + void BuildVariableWriteOp(TosaSerializationOperator *op) const; + + std::string get_string(TosaSerializationOperator *op) const { + std::string op_string; + op_string += "operator opcode="; + op_string += EnumNamesOp()[op->GetOp()]; + op_string += ", input=["; + for (auto ts : op->GetInputTensorNames()) { + op_string += (ts + " "); + } + op_string += "], output=["; + for (auto ts : op->GetOutputTensorNames()) { + op_string += (ts + " "); + } + op_string += "]"; + return op_string; + } + + TosaSerializationHandler *GetTsh() const; + TosaMlirRegionBuilder *GetRegionBuilder() const; + +private: + template <class MLIR_OP> + std::vector<mlir::Value> + BuildEwiseUnaryOp(TosaSerializationOperator *op) const; + + template <class MLIR_OP> + std::vector<mlir::Value> + BuildEwiseBinaryOp(TosaSerializationOperator *op) const; + + template <class MLIR_OP> + std::vector<mlir::Value> + BuildEwiseBinaryShapeOp(TosaSerializationOperator *op) const; + + template <class MLIR_OP> + std::vector<mlir::Value> + BuildReductionOp(TosaSerializationOperator *op) const; + + template <class T> + mlir::Value BuildConstShape(mlir::OpBuilder *op_builder, mlir::Location loc, + const std::vector<T> &values) const; + + template <class MLIR_OP> + std::vector<mlir::Value> BuildConvOp(TosaSerializationOperator *op) const; + + mlir::OpBuilder *op_builder; + TosaSerializationBasicBlock *ser_block; + mlir::Block *block; + mlir::Location loc; + TosaMlirBlockBuilder *block_builder; + std::unordered_map<std::string, mlir::Value> *tensor_map; + std::unordered_map<std::string, mlir::RankedTensorType> *tensor_type_map; + std::unordered_map<std::string, mlir::tosa::shapeType> *shape_type_map; +}; + +// Main template to catch unimplemented translation +template <Op OPCODE> +std::vector<mlir::Value> +TosaMlirOperatorBuilder::build(TosaSerializationOperator *op) const { + llvm::errs() << "ERROR: " << get_string(op) + << " translation hasn't been implemented\n"; + return {}; +} + +// BUILD_OP_POOL2D(MaxPool2d, MAX_POOL2D) +template <> +std::vector<mlir::Value> TosaMlirOperatorBuilder::build<Op_MAX_POOL2D>( + TosaSerializationOperator *op) const { + mlir::Value input_val = tensor_map->at(op->GetInputTensorNames()[0]); + mlir::RankedTensorType output_type = + tensor_type_map->at(op->GetOutputTensorNames()[0]); + assert(op->GetAttributeType() == + Attribute_PoolAttribute); // double check attribute type + TosaPoolAttribute *attr = + static_cast<TosaPoolAttribute *>(op->GetAttribute()); + mlir::DenseI64ArrayAttr kernel = + BuildDenseI64ArrayAttr(op_builder, attr->kernel()); + mlir::DenseI64ArrayAttr stride = + BuildDenseI64ArrayAttr(op_builder, attr->stride()); + mlir::DenseI64ArrayAttr pad = BuildDenseI64ArrayAttr(op_builder, attr->pad()); + int32_t input_zp = attr->input_zp(); + int32_t output_zp = attr->output_zp(); + assert(input_zp == 0 && output_zp == 0); + + mlir::Operation *mlir_op = op_builder->create<mlir::tosa::MaxPool2dOp>( + loc, output_type, input_val, kernel, stride, pad); + block->push_back(mlir_op); + return std::vector<mlir::Value>({mlir_op->getResult(0)}); +} + +// BUILD_OP_POOL2D(AvgPool2d, AVG_POOL2D) +template <> +std::vector<mlir::Value> TosaMlirOperatorBuilder::build<Op_AVG_POOL2D>( + TosaSerializationOperator *op) const { + mlir::Value input_val = tensor_map->at(op->GetInputTensorNames()[0]); + mlir::RankedTensorType output_type = + tensor_type_map->at(op->GetOutputTensorNames()[0]); + assert(op->GetAttributeType() == + Attribute_PoolAttribute); // double check attribute type + TosaPoolAttribute *attr = + static_cast<TosaPoolAttribute *>(op->GetAttribute()); + mlir::DenseI64ArrayAttr kernel = + BuildDenseI64ArrayAttr(op_builder, attr->kernel()); + mlir::DenseI64ArrayAttr stride = + BuildDenseI64ArrayAttr(op_builder, attr->stride()); + mlir::DenseI64ArrayAttr pad = BuildDenseI64ArrayAttr(op_builder, attr->pad()); + auto acc_attr = + mlir::TypeAttr::get(AccDType2Type(op_builder, attr->acc_type())); + + int32_t input_zp = attr->input_zp(); + int32_t output_zp = attr->output_zp(); + mlir::Operation *mlir_op; + if (input_zp == 0 && output_zp == 0) { + mlir_op = op_builder->create<mlir::tosa::AvgPool2dOp>( + loc, output_type, input_val, kernel, stride, pad, acc_attr); + } else { + auto input_zp_attr = op_builder->getI32IntegerAttr(input_zp); + auto output_zp_attr = op_builder->getI32IntegerAttr(output_zp); + mlir_op = op_builder->create<mlir::tosa::AvgPool2dOp>( + loc, output_type, input_val, kernel, stride, pad, acc_attr, + input_zp_attr, output_zp_attr); + } + block->push_back(mlir_op); + return std::vector<mlir::Value>({mlir_op->getResult(0)}); +} + +std::vector<mlir::Value> TosaMlirOperatorBuilder::BuildVariableReadOp( + TosaSerializationOperator *op) const { + auto input_tensor_name = op->GetInputTensorNames()[0]; + auto output_tensor_name = op->GetOutputTensorNames()[0]; + + assert(IsVariableTensor(input_tensor_name)); + + auto variable_name = GetVariableTensorName(input_tensor_name); + mlir::RankedTensorType output_type = tensor_type_map->at(output_tensor_name); + assert(op->GetAttributeType() == + Attribute_NONE); // double check that there is no attribute + mlir::Operation *mlir_op = op_builder->create<mlir::tosa::VariableReadOp>( + loc, output_type, llvm::StringRef(variable_name)); + block->push_back(mlir_op); + return std::vector<mlir::Value>({mlir_op->getResult(0)}); +} + +void TosaMlirOperatorBuilder::BuildVariableWriteOp( + TosaSerializationOperator *op) const { + auto input_tensor_name = op->GetInputTensorNames()[0]; + auto output_tensor_name = op->GetOutputTensorNames()[0]; + + assert(IsVariableTensor(output_tensor_name)); + auto variable_name = GetVariableTensorName(output_tensor_name); + mlir::Value input_val = tensor_map->at(input_tensor_name); + + mlir::Operation *mlir_op = op_builder->create<mlir::tosa::VariableWriteOp>( + loc, llvm::StringRef(variable_name), input_val); + block->push_back(mlir_op); +} + +template <class MLIR_OP> +std::vector<mlir::Value> TosaMlirOperatorBuilder::BuildEwiseUnaryOp( + TosaSerializationOperator *op) const { + mlir::Value input_val = tensor_map->at(op->GetInputTensorNames()[0]); + mlir::RankedTensorType output_type = + tensor_type_map->at(op->GetOutputTensorNames()[0]); + assert(op->GetAttributeType() == + Attribute_NONE); // double check that there is no attribute + + mlir::Operation *mlir_op = + op_builder->create<MLIR_OP>(loc, output_type, input_val); + block->push_back(mlir_op); + return std::vector<mlir::Value>({mlir_op->getResult(0)}); +} + +template <class MLIR_OP> +std::vector<mlir::Value> TosaMlirOperatorBuilder::BuildEwiseBinaryOp( + TosaSerializationOperator *op) const { + mlir::Value input0_val = tensor_map->at(op->GetInputTensorNames()[0]); + mlir::Value input1_val = tensor_map->at(op->GetInputTensorNames()[1]); + mlir::RankedTensorType output_type = + tensor_type_map->at(op->GetOutputTensorNames()[0]); + assert(op->GetAttributeType() == + Attribute_NONE); // double check that there is no attribute + + mlir::Operation *mlir_op = + op_builder->create<MLIR_OP>(loc, output_type, input0_val, input1_val); + block->push_back(mlir_op); + return std::vector<mlir::Value>({mlir_op->getResult(0)}); +} + +template <class MLIR_OP> +std::vector<mlir::Value> TosaMlirOperatorBuilder::BuildEwiseBinaryShapeOp( + TosaSerializationOperator *op) const { + mlir::Value input0_val = tensor_map->at(op->GetInputTensorNames()[0]); + mlir::Value input1_val = tensor_map->at(op->GetInputTensorNames()[1]); + mlir::tosa::shapeType output_type = + shape_type_map->at(op->GetOutputTensorNames()[0]); + assert(op->GetAttributeType() == + Attribute_NONE); // double check that there is no attribute + + mlir::Operation *mlir_op = + op_builder->create<MLIR_OP>(loc, output_type, input0_val, input1_val); + block->push_back(mlir_op); + return std::vector<mlir::Value>({mlir_op->getResult(0)}); +} + +template <class MLIR_OP> +std::vector<mlir::Value> +TosaMlirOperatorBuilder::BuildReductionOp(TosaSerializationOperator *op) const { + mlir::Value input_val = tensor_map->at(op->GetInputTensorNames()[0]); + mlir::RankedTensorType output_type = + tensor_type_map->at(op->GetOutputTensorNames()[0]); + assert(op->GetAttributeType() == + Attribute_AxisAttribute); // double check attribute type + + TosaAxisAttribute *attr = + static_cast<TosaAxisAttribute *>(op->GetAttribute()); + auto axis = op_builder->getI32IntegerAttr(attr->axis()); + + mlir::Operation *mlir_op = + op_builder->create<MLIR_OP>(loc, output_type, input_val, axis); + block->push_back(mlir_op); + return std::vector<mlir::Value>({mlir_op->getResult(0)}); +} + +#define BUILD_OP_ELEMENTWISE_UNARY(MLIR_OP_NAME, SCHEMA_OP_NAME) \ + template <> \ + std::vector<mlir::Value> \ + TosaMlirOperatorBuilder::build<Op_##SCHEMA_OP_NAME>( \ + TosaSerializationOperator * op) const { \ + return BuildEwiseUnaryOp<mlir::tosa::MLIR_OP_NAME##Op>(op); \ + } + +#define BUILD_OP_ELEMENTWISE_BINARY(MLIR_OP_NAME, SCHEMA_OP_NAME) \ + template <> \ + std::vector<mlir::Value> \ + TosaMlirOperatorBuilder::build<Op_##SCHEMA_OP_NAME>( \ + TosaSerializationOperator * op) const { \ + return BuildEwiseBinaryOp<mlir::tosa::MLIR_OP_NAME##Op>(op); \ + } + +#define BUILD_OP_ELEMENTWISE_BINARY_SHAPE(MLIR_OP_NAME, SCHEMA_OP_NAME) \ + template <> \ + std::vector<mlir::Value> \ + TosaMlirOperatorBuilder::build<Op_##SCHEMA_OP_NAME>( \ + TosaSerializationOperator * op) const { \ + return BuildEwiseBinaryShapeOp<mlir::tosa::MLIR_OP_NAME##Op>(op); \ + } + +#define BUILD_OP_REDUCTION(MLIR_OP_NAME, SCHEMA_OP_NAME) \ + template <> \ + std::vector<mlir::Value> \ + TosaMlirOperatorBuilder::build<Op_##SCHEMA_OP_NAME>( \ + TosaSerializationOperator * op) const { \ + return BuildReductionOp<mlir::tosa::MLIR_OP_NAME##Op>(op); \ + } + +// BUILD_OP_POOL2D(MaxPool2d, MAX_POOL2D) +// BUILD_OP_POOL2D(AvgPool2d, AVG_POOL2D) + +BUILD_OP_ELEMENTWISE_BINARY(Add, ADD) +BUILD_OP_ELEMENTWISE_BINARY(BitwiseAnd, BITWISE_AND) +BUILD_OP_ELEMENTWISE_BINARY(BitwiseXor, BITWISE_XOR) +BUILD_OP_ELEMENTWISE_BINARY(BitwiseOr, BITWISE_OR) +BUILD_OP_ELEMENTWISE_BINARY(IntDiv, INTDIV) +BUILD_OP_ELEMENTWISE_BINARY(LogicalAnd, LOGICAL_AND) +BUILD_OP_ELEMENTWISE_BINARY(LogicalLeftShift, LOGICAL_LEFT_SHIFT) +BUILD_OP_ELEMENTWISE_BINARY(LogicalRightShift, LOGICAL_RIGHT_SHIFT) +BUILD_OP_ELEMENTWISE_BINARY(LogicalOr, LOGICAL_OR) +BUILD_OP_ELEMENTWISE_BINARY(LogicalXor, LOGICAL_XOR) +BUILD_OP_ELEMENTWISE_BINARY(Maximum, MAXIMUM) +BUILD_OP_ELEMENTWISE_BINARY(Minimum, MINIMUM) +BUILD_OP_ELEMENTWISE_BINARY(Pow, POW) +BUILD_OP_ELEMENTWISE_BINARY(Sub, SUB) + +BUILD_OP_ELEMENTWISE_UNARY(Abs, ABS) +BUILD_OP_ELEMENTWISE_UNARY(BitwiseNot, BITWISE_NOT) +BUILD_OP_ELEMENTWISE_UNARY(Ceil, CEIL) +BUILD_OP_ELEMENTWISE_UNARY(Clz, CLZ) +BUILD_OP_ELEMENTWISE_UNARY(Cos, COS) +BUILD_OP_ELEMENTWISE_UNARY(Exp, EXP) +BUILD_OP_ELEMENTWISE_UNARY(Floor, FLOOR) +BUILD_OP_ELEMENTWISE_UNARY(Log, LOG) +BUILD_OP_ELEMENTWISE_UNARY(LogicalNot, LOGICAL_NOT) +BUILD_OP_ELEMENTWISE_UNARY(Reciprocal, RECIPROCAL) +BUILD_OP_ELEMENTWISE_UNARY(Rsqrt, RSQRT) +BUILD_OP_ELEMENTWISE_UNARY(Sin, SIN) + +BUILD_OP_REDUCTION(ReduceAny, REDUCE_ANY) +BUILD_OP_REDUCTION(ReduceAll, REDUCE_ALL) +BUILD_OP_REDUCTION(ReduceMax, REDUCE_MAX) +BUILD_OP_REDUCTION(ReduceMin, REDUCE_MIN) +BUILD_OP_REDUCTION(ReduceProd, REDUCE_PRODUCT) +BUILD_OP_REDUCTION(ReduceSum, REDUCE_SUM) + +BUILD_OP_ELEMENTWISE_BINARY(Equal, EQUAL) +BUILD_OP_ELEMENTWISE_BINARY(Greater, GREATER) +BUILD_OP_ELEMENTWISE_BINARY(GreaterEqual, GREATER_EQUAL) + +BUILD_OP_ELEMENTWISE_UNARY(Erf, ERF) +BUILD_OP_ELEMENTWISE_UNARY(Sigmoid, SIGMOID) +BUILD_OP_ELEMENTWISE_UNARY(Tanh, TANH) +BUILD_OP_ELEMENTWISE_UNARY(Identity, IDENTITY) +BUILD_OP_ELEMENTWISE_UNARY(Cast, CAST) + +BUILD_OP_ELEMENTWISE_BINARY_SHAPE(AddShape, ADD_SHAPE) +BUILD_OP_ELEMENTWISE_BINARY_SHAPE(SubShape, SUB_SHAPE) +BUILD_OP_ELEMENTWISE_BINARY_SHAPE(MulShape, MUL_SHAPE) +BUILD_OP_ELEMENTWISE_BINARY_SHAPE(DivShape, DIV_SHAPE) + +template <> +std::vector<mlir::Value> +TosaMlirOperatorBuilder::build<Op_CONST>(TosaSerializationOperator *op) const { + const auto &output_name = op->GetOutputTensorNames()[0]; + mlir::RankedTensorType output_type = tensor_type_map->at(output_name); + TosaSerializationTensor *ts = ser_block->GetTensorByName(output_name); + auto value_attr = ConstructConstAttr(output_type, ts, get_string(op)); + if (!value_attr) { + return {}; + } + mlir::Operation *mlir_op = + op_builder->create<mlir::tosa::ConstOp>(loc, output_type, value_attr); + block->push_back(mlir_op); + return std::vector<mlir::Value>({mlir_op->getResult(0)}); +} + +template <class T> +mlir::Value +TosaMlirOperatorBuilder::BuildConstShape(mlir::OpBuilder *op_builder, + mlir::Location loc, + const std::vector<T> &values) const { + std::vector<int64_t> vec; + for (auto val : values) { + vec.push_back(val); + } + auto attr = op_builder->getIndexTensorAttr(vec); + auto type = mlir::tosa::shapeType::get(op_builder->getContext(), + /* rank = */ vec.size()); + mlir::Operation *mlir_op = + op_builder->create<mlir::tosa::ConstShapeOp>(loc, type, attr); + block->push_back(mlir_op); + return mlir_op->getResult(0); +} + +template <> +std::vector<mlir::Value> TosaMlirOperatorBuilder::build<Op_CONST_SHAPE>( + TosaSerializationOperator *op) const { + const auto &output_name = op->GetOutputTensorNames()[0]; + mlir::tosa::shapeType output_type = shape_type_map->at(output_name); + TosaSerializationTensor *ts = ser_block->GetTensorByName(output_name); + + const auto &data = ts->GetData(); + + std::vector<int64_t> i64_data; + TosaSerializationHandler::ConvertU8toI64(data, output_type.getRank(), + i64_data); + mlir::Value result = BuildConstShape(op_builder, loc, i64_data); + return std::vector<mlir::Value>({result}); +} + +template <class MLIR_OP> +std::vector<mlir::Value> +TosaMlirOperatorBuilder::BuildConvOp(TosaSerializationOperator *op) const { + mlir::Value input0_val = tensor_map->at(op->GetInputTensorNames()[0]); + mlir::Value input1_val = tensor_map->at(op->GetInputTensorNames()[1]); + mlir::Value input2_val = tensor_map->at(op->GetInputTensorNames()[2]); + mlir::RankedTensorType output_type = + tensor_type_map->at(op->GetOutputTensorNames()[0]); + + assert(op->GetAttributeType() == + Attribute_ConvAttribute); // double check attribute type + TosaConvAttribute *attr = + static_cast<TosaConvAttribute *>(op->GetAttribute()); + mlir::DenseI64ArrayAttr pad = BuildDenseI64ArrayAttr(op_builder, attr->pad()); + mlir::DenseI64ArrayAttr stride = + BuildDenseI64ArrayAttr(op_builder, attr->stride()); + mlir::DenseI64ArrayAttr dilation = + BuildDenseI64ArrayAttr(op_builder, attr->dilation()); + auto input_zp = attr->input_zp(); + auto weight_zp = attr->weight_zp(); + bool local_bound = attr->local_bound(); + auto acc_type = AccDType2Type(op_builder, attr->acc_type()); + + // input_zp/weight_zp is not allowed for float type + mlir::Operation *mlir_op; + if (output_type.getElementType().isa<mlir::FloatType>()) { + assert(input_zp == 0 && weight_zp == 0); + } + + auto input_zp_attr = op_builder->getI32IntegerAttr(input_zp); + auto weight_zp_attr = op_builder->getI32IntegerAttr(weight_zp); + mlir_op = op_builder->create<MLIR_OP>( + loc, output_type, input0_val, input1_val, input2_val, pad, stride, + dilation, acc_type, input_zp_attr, weight_zp_attr, local_bound); + + block->push_back(mlir_op); + return std::vector<mlir::Value>({mlir_op->getResult(0)}); +} + +#define BUILD_OP_CONV(MLIR_OP_NAME, SCHEMA_OP_NAME) \ + template <> \ + std::vector<mlir::Value> \ + TosaMlirOperatorBuilder::build<Op_##SCHEMA_OP_NAME>( \ + TosaSerializationOperator * op) const { \ + return BuildConvOp<mlir::tosa::MLIR_OP_NAME##Op>(op); \ + } + +BUILD_OP_CONV(Conv2D, CONV2D) +BUILD_OP_CONV(Conv3D, CONV3D) +BUILD_OP_CONV(DepthwiseConv2D, DEPTHWISE_CONV2D) + +template <> +std::vector<mlir::Value> TosaMlirOperatorBuilder::build<Op_TRANSPOSE_CONV2D>( + TosaSerializationOperator *op) const { + mlir::Value input0_val = tensor_map->at(op->GetInputTensorNames()[0]); + mlir::Value input1_val = tensor_map->at(op->GetInputTensorNames()[1]); + mlir::Value input2_val = tensor_map->at(op->GetInputTensorNames()[2]); + mlir::RankedTensorType output_type = + tensor_type_map->at(op->GetOutputTensorNames()[0]); + + assert(op->GetAttributeType() == + Attribute_TransposeConvAttribute); // double check attribute type + TosaTransposeConvAttribute *attr = + static_cast<TosaTransposeConvAttribute *>(op->GetAttribute()); + mlir::DenseI64ArrayAttr out_pad = + BuildDenseI64ArrayAttr(op_builder, attr->out_pad()); + mlir::DenseI64ArrayAttr stride = + BuildDenseI64ArrayAttr(op_builder, attr->stride()); + auto input_zp = attr->input_zp(); + auto weight_zp = attr->weight_zp(); + bool local_bound = attr->local_bound(); + auto acc_type = AccDType2Type(op_builder, attr->acc_type()); + + // input_zp/weight_zp is not allowed for float type + mlir::Operation *mlir_op; + if (output_type.getElementType().isa<mlir::FloatType>()) { + assert(input_zp == 0 && weight_zp == 0); + } + + auto input_zp_attr = op_builder->getI32IntegerAttr(input_zp); + auto weight_zp_attr = op_builder->getI32IntegerAttr(weight_zp); + + mlir_op = op_builder->create<mlir::tosa::TransposeConv2DOp>( + loc, output_type, input0_val, input1_val, input2_val, out_pad, stride, + acc_type, input_zp_attr, weight_zp_attr, local_bound); + + block->push_back(mlir_op); + return std::vector<mlir::Value>({mlir_op->getResult(0)}); +} + +template <> +std::vector<mlir::Value> TosaMlirOperatorBuilder::build<Op_FULLY_CONNECTED>( + TosaSerializationOperator *op) const { + mlir::Value input0_val = tensor_map->at(op->GetInputTensorNames()[0]); + mlir::Value input1_val = tensor_map->at(op->GetInputTensorNames()[1]); + mlir::Value input2_val = tensor_map->at(op->GetInputTensorNames()[2]); + mlir::RankedTensorType output_type = + tensor_type_map->at(op->GetOutputTensorNames()[0]); + + assert(op->GetAttributeType() == + Attribute_FullyConnectedAttribute); // double check attribute type + TosaFullyConnectedAttribute *attr = + static_cast<TosaFullyConnectedAttribute *>(op->GetAttribute()); + auto input_zp = attr->input_zp(); + auto weight_zp = attr->weight_zp(); + + // input_zp/weight_zp is not allowed for float type + mlir::Operation *mlir_op; + if (output_type.getElementType().isa<mlir::FloatType>()) { + assert(input_zp == 0 && weight_zp == 0); + mlir_op = op_builder->create<mlir::tosa::FullyConnectedOp>( + loc, output_type, input0_val, input1_val, input2_val); + } else { + auto input_zp_attr = op_builder->getI32IntegerAttr(input_zp); + auto weight_zp_attr = op_builder->getI32IntegerAttr(weight_zp); + mlir_op = op_builder->create<mlir::tosa::FullyConnectedOp>( + loc, output_type, input0_val, input1_val, input2_val, input_zp_attr, + weight_zp_attr); + } + block->push_back(mlir_op); + return std::vector<mlir::Value>({mlir_op->getResult(0)}); +} + +template <> +std::vector<mlir::Value> +TosaMlirOperatorBuilder::build<Op_MATMUL>(TosaSerializationOperator *op) const { + mlir::Value input0_val = tensor_map->at(op->GetInputTensorNames()[0]); + mlir::Value input1_val = tensor_map->at(op->GetInputTensorNames()[1]); + mlir::RankedTensorType output_type = + tensor_type_map->at(op->GetOutputTensorNames()[0]); + + assert(op->GetAttributeType() == + Attribute_MatMulAttribute); // double check attribute type + TosaMatMulAttribute *attr = + static_cast<TosaMatMulAttribute *>(op->GetAttribute()); + auto A_zp = attr->a_zp(); + auto B_zp = attr->b_zp(); + + mlir::Operation *mlir_op; + if (A_zp == 0 && B_zp == 0) { + mlir_op = op_builder->create<mlir::tosa::MatMulOp>(loc, output_type, + input0_val, input1_val); + } else { + auto a_zp_attr = op_builder->getI32IntegerAttr(A_zp); + auto b_zp_attr = op_builder->getI32IntegerAttr(B_zp); + mlir_op = op_builder->create<mlir::tosa::MatMulOp>( + loc, output_type, input0_val, input1_val, a_zp_attr, b_zp_attr); + } + block->push_back(mlir_op); + return std::vector<mlir::Value>({mlir_op->getResult(0)}); +} + +template <> +std::vector<mlir::Value> +TosaMlirOperatorBuilder::build<Op_SELECT>(TosaSerializationOperator *op) const { + mlir::Value input0_val = tensor_map->at(op->GetInputTensorNames()[0]); + mlir::Value input1_val = tensor_map->at(op->GetInputTensorNames()[1]); + mlir::Value input2_val = tensor_map->at(op->GetInputTensorNames()[2]); + mlir::RankedTensorType output_type = + tensor_type_map->at(op->GetOutputTensorNames()[0]); + + assert(op->GetAttributeType() == + Attribute_NONE); // double check that there is no attribute + + mlir::Operation *mlir_op = op_builder->create<mlir::tosa::SelectOp>( + loc, output_type, input0_val, input1_val, input2_val); + block->push_back(mlir_op); + return std::vector<mlir::Value>({mlir_op->getResult(0)}); +} + +template <> +std::vector<mlir::Value> +TosaMlirOperatorBuilder::build<Op_CLAMP>(TosaSerializationOperator *op) const { + mlir::Value input_val = tensor_map->at(op->GetInputTensorNames()[0]); + mlir::RankedTensorType output_type = + tensor_type_map->at(op->GetOutputTensorNames()[0]); + + assert(op->GetAttributeType() == + Attribute_ClampAttribute); // double check attribute type + TosaClampAttribute *attr = + static_cast<TosaClampAttribute *>(op->GetAttribute()); + + mlir::Type element_type = + llvm::cast<mlir::ShapedType>(input_val.getType()).getElementType(); + if (auto quantType = + llvm::dyn_cast<mlir::quant::UniformQuantizedType>(element_type)) { + element_type = quantType.getStorageType(); + } + + auto element_const_type = mlir::RankedTensorType::get({1}, element_type); + auto min_values_attr = GetConstAttr(attr->min_val(), element_const_type, 1); + auto max_values_attr = GetConstAttr(attr->max_val(), element_const_type, 1); + + mlir::Attribute min_val_attr, max_val_attr; + if (element_type.isa<mlir::FloatType>()) { + min_val_attr = op_builder->getFloatAttr( + element_type, min_values_attr.getValues<mlir::APFloat>()[0]); + max_val_attr = op_builder->getFloatAttr( + element_type, max_values_attr.getValues<mlir::APFloat>()[0]); + } else { + min_val_attr = op_builder->getIntegerAttr( + element_type, min_values_attr.getValues<mlir::APInt>()[0]); + max_val_attr = op_builder->getIntegerAttr( + element_type, max_values_attr.getValues<mlir::APInt>()[0]); + } + + auto mlir_op = op_builder->create<mlir::tosa::ClampOp>( + loc, output_type, input_val, min_val_attr, max_val_attr); + block->push_back(mlir_op); + return std::vector<mlir::Value>({mlir_op->getResult(0)}); +} + +// ArgMax has single input, and single I64 axis attribute +BUILD_OP_REDUCTION(ArgMax, ARGMAX) + +template <> +std::vector<mlir::Value> +TosaMlirOperatorBuilder::build<Op_CONCAT>(TosaSerializationOperator *op) const { + mlir::RankedTensorType output_type = + tensor_type_map->at(op->GetOutputTensorNames()[0]); + + llvm::SmallVector<mlir::Value> input_values; + for (auto &input_name : op->GetInputTensorNames()) { + mlir::Value input_val = tensor_map->at(input_name); + input_values.push_back(input_val); + } + + assert(op->GetAttributeType() == + Attribute_AxisAttribute); // double check attribute type + TosaAxisAttribute *attr = + static_cast<TosaAxisAttribute *>(op->GetAttribute()); + auto axis = op_builder->getI32IntegerAttr(attr->axis()); + + mlir::Operation *mlir_op = op_builder->create<mlir::tosa::ConcatOp>( + loc, output_type, input_values, axis); + block->push_back(mlir_op); + return std::vector<mlir::Value>({mlir_op->getResult(0)}); +} + +template <> +std::vector<mlir::Value> TosaMlirOperatorBuilder::build<Op_CONCAT_SHAPE>( + TosaSerializationOperator *op) const { + mlir::tosa::shapeType output_type = + shape_type_map->at(op->GetOutputTensorNames()[0]); + + llvm::SmallVector<mlir::Value> input_values; + for (auto &input_name : op->GetInputTensorNames()) { + mlir::Value input_val = tensor_map->at(input_name); + input_values.push_back(input_val); + } + + mlir::Operation *mlir_op = op_builder->create<mlir::tosa::ConcatShapeOp>( + loc, output_type, input_values); + block->push_back(mlir_op); + return std::vector<mlir::Value>({mlir_op->getResult(0)}); +} + +template <> +std::vector<mlir::Value> +TosaMlirOperatorBuilder::build<Op_NEGATE>(TosaSerializationOperator *op) const { + mlir::Value input_val = tensor_map->at(op->GetInputTensorNames()[0]); + mlir::RankedTensorType output_type = + tensor_type_map->at(op->GetOutputTensorNames()[0]); + + assert(op->GetAttributeType() == + Attribute_NegateAttribute); // double check attribute type + TosaNegateAttribute *attr = + static_cast<TosaNegateAttribute *>(op->GetAttribute()); + + auto input_zp = attr->input1_zp(); + auto output_zp = attr->output_zp(); + + mlir::Operation *mlir_op; + if (input_zp == 0 && output_zp == 0) { + mlir_op = + op_builder->create<mlir::tosa::NegateOp>(loc, output_type, input_val); + } else { + auto input_zp_attr = op_builder->getI32IntegerAttr(input_zp); + auto output_zp_attr = op_builder->getI32IntegerAttr(output_zp); + mlir_op = op_builder->create<mlir::tosa::NegateOp>( + loc, output_type, input_val, input_zp_attr, output_zp_attr); + } + + block->push_back(mlir_op); + return std::vector<mlir::Value>({mlir_op->getResult(0)}); +} + +template <> +std::vector<mlir::Value> TosaMlirOperatorBuilder::build<Op_RESHAPE>( + TosaSerializationOperator *op) const { + mlir::Value input_val = tensor_map->at(op->GetInputTensorNames()[0]); + mlir::RankedTensorType output_type = + tensor_type_map->at(op->GetOutputTensorNames()[0]); + mlir::Value shape_val = tensor_map->at(op->GetInputTensorNames()[1]); + + mlir::Operation *mlir_op = op_builder->create<mlir::tosa::ReshapeOp>( + loc, output_type, input_val, shape_val); + block->push_back(mlir_op); + return std::vector<mlir::Value>({mlir_op->getResult(0)}); +} + +template <> +std::vector<mlir::Value> +TosaMlirOperatorBuilder::build<Op_PAD>(TosaSerializationOperator *op) const { + mlir::Value input_val = tensor_map->at(op->GetInputTensorNames()[0]); + mlir::Value padding_val = tensor_map->at(op->GetInputTensorNames()[1]); + mlir::RankedTensorType input_type = + tensor_type_map->at(op->GetInputTensorNames()[0]); + mlir::RankedTensorType output_type = + tensor_type_map->at(op->GetOutputTensorNames()[0]); + const auto element_type = + input_val.getType().cast<mlir::ShapedType>().getElementType(); + + assert(op->GetAttributeType() == + Attribute_PadAttribute); // double check attribute type + TosaPadAttribute *attr = static_cast<TosaPadAttribute *>(op->GetAttribute()); + const auto &pad_const_u8_data = attr->pad_const(); + + // check for any value in pad_const_u8_data + bool has_pad_const = false; + for (auto v : pad_const_u8_data) { + if (v != 0) { + has_pad_const = true; + break; + } + } + if (!has_pad_const) { + // handle the cases where no explicit pad_const input. + auto mlir_op = op_builder->create<mlir::tosa::PadOp>( + loc, output_type, input_val, padding_val); + block->push_back(mlir_op); + return std::vector<mlir::Value>({mlir_op->getResult(0)}); + } + + // has pad const - create a const op for pad_const input + auto pad_const_type = mlir::RankedTensorType::get({}, element_type); + auto pad_const_attr = GetConstAttr(pad_const_u8_data, pad_const_type, 1); + + auto pad_const_op = op_builder->create<mlir::tosa::ConstOp>( + loc, pad_const_type, pad_const_attr); + + block->push_back(pad_const_op); + mlir::Value pad_const_value = pad_const_op->getResult(0); + + auto mlir_op = op_builder->create<mlir::tosa::PadOp>( + loc, output_type, input_val, padding_val, pad_const_value); + + block->push_back(mlir_op); + return std::vector<mlir::Value>({mlir_op->getResult(0)}); +} + +template <> +std::vector<mlir::Value> +TosaMlirOperatorBuilder::build<Op_DIM>(TosaSerializationOperator *op) const { + mlir::Value input_val = tensor_map->at(op->GetInputTensorNames()[0]); + mlir::RankedTensorType output_type = + tensor_type_map->at(op->GetOutputTensorNames()[0]); + + assert(op->GetAttributeType() == + Attribute_AxisAttribute); // double check attribute type + TosaAxisAttribute *attr = + static_cast<TosaAxisAttribute *>(op->GetAttribute()); + auto axis = op_builder->getI32IntegerAttr(attr->axis()); + + mlir::Operation *mlir_op = + op_builder->create<mlir::tosa::DimOp>(loc, output_type, input_val, axis); + block->push_back(mlir_op); + return std::vector<mlir::Value>({mlir_op->getResult(0)}); +} + +template <> +std::vector<mlir::Value> TosaMlirOperatorBuilder::build<Op_TRANSPOSE>( + TosaSerializationOperator *op) const { + mlir::Value input_val = tensor_map->at(op->GetInputTensorNames()[0]); + mlir::RankedTensorType output_type = + tensor_type_map->at(op->GetOutputTensorNames()[0]); + + assert(op->GetAttributeType() == + Attribute_TransposeAttribute); // double check attribute type + TosaTransposeAttribute *attr = + static_cast<TosaTransposeAttribute *>(op->GetAttribute()); + // make a constant op from attr->perms values, of type: { shape = { perms.size + // }, element_type = I32 } + const auto perms_values = attr->perms(); + auto const_type = mlir::RankedTensorType::get( + {static_cast<int64_t>(perms_values.size())}, op_builder->getI32Type()); + mlir::DenseElementsAttr const_attr = + BuildDenseI32ElementsAttr(op_builder, const_type, perms_values); + mlir::Operation *mlir_const_op = + op_builder->create<mlir::tosa::ConstOp>(loc, const_type, const_attr); + auto perms_val = mlir_const_op->getResult(0); + + mlir::Operation *mlir_op = op_builder->create<mlir::tosa::TransposeOp>( + loc, output_type, input_val, perms_val); + + block->push_back(mlir_const_op); + block->push_back(mlir_op); + return std::vector<mlir::Value>({mlir_op->getResult(0)}); +} + +template <> +std::vector<mlir::Value> +TosaMlirOperatorBuilder::build<Op_SLICE>(TosaSerializationOperator *op) const { + mlir::Value input_val = tensor_map->at(op->GetInputTensorNames()[0]); + mlir::RankedTensorType output_type = + tensor_type_map->at(op->GetOutputTensorNames()[0]); + + mlir::Value start = tensor_map->at(op->GetInputTensorNames()[1]); + mlir::Value size = tensor_map->at(op->GetInputTensorNames()[2]); + + mlir::Operation *mlir_op = op_builder->create<mlir::tosa::SliceOp>( + loc, output_type, input_val, start, size); + block->push_back(mlir_op); + return std::vector<mlir::Value>({mlir_op->getResult(0)}); +} + +template <> +std::vector<mlir::Value> +TosaMlirOperatorBuilder::build<Op_TILE>(TosaSerializationOperator *op) const { + mlir::Value input_val = tensor_map->at(op->GetInputTensorNames()[0]); + mlir::Value multiples = tensor_map->at(op->GetInputTensorNames()[1]); + mlir::RankedTensorType output_type = + tensor_type_map->at(op->GetOutputTensorNames()[0]); + + assert(op->GetAttributeType() == + Attribute_NONE); // double check attribute type + + mlir::Operation *mlir_op = op_builder->create<mlir::tosa::TileOp>( + loc, output_type, input_val, multiples); + block->push_back(mlir_op); + return std::vector<mlir::Value>({mlir_op->getResult(0)}); +} + +// Gather is a binary op +BUILD_OP_ELEMENTWISE_BINARY(Gather, GATHER) + +template <> +std::vector<mlir::Value> TosaMlirOperatorBuilder::build<Op_SCATTER>( + TosaSerializationOperator *op) const { + mlir::Value input0_val = tensor_map->at(op->GetInputTensorNames()[0]); + mlir::Value input1_val = tensor_map->at(op->GetInputTensorNames()[1]); + mlir::Value input2_val = tensor_map->at(op->GetInputTensorNames()[2]); + mlir::RankedTensorType output_type = + tensor_type_map->at(op->GetOutputTensorNames()[0]); + assert(op->GetAttributeType() == + Attribute_NONE); // double check that there is no attribute + + mlir::Operation *mlir_op = op_builder->create<mlir::tosa::ScatterOp>( + loc, output_type, input0_val, input1_val, input2_val); + block->push_back(mlir_op); + return std::vector<mlir::Value>({mlir_op->getResult(0)}); +} + +template <> +std::vector<mlir::Value> +TosaMlirOperatorBuilder::build<Op_RESIZE>(TosaSerializationOperator *op) const { + mlir::Value input_val = tensor_map->at(op->GetInputTensorNames()[0]); + mlir::Value scale_val = tensor_map->at(op->GetInputTensorNames()[1]); + mlir::Value offset_val = tensor_map->at(op->GetInputTensorNames()[2]); + mlir::Value border_val = tensor_map->at(op->GetInputTensorNames()[3]); + mlir::RankedTensorType output_type = + tensor_type_map->at(op->GetOutputTensorNames()[0]); + + assert(op->GetAttributeType() == + Attribute_ResizeAttribute); // double check attribute type + TosaResizeAttribute *attr = + static_cast<TosaResizeAttribute *>(op->GetAttribute()); + + auto mode = op_builder->getStringAttr(ResizeEnum2Str(attr->mode())); + + mlir::Operation *mlir_op = op_builder->create<mlir::tosa::ResizeOp>( + loc, output_type, input_val, scale_val, offset_val, border_val, mode); + block->push_back(mlir_op); + return std::vector<mlir::Value>({mlir_op->getResult(0)}); +} + +// Reverse has single input, and single I64 axis attribute +BUILD_OP_REDUCTION(Reverse, REVERSE) + +template <> +std::vector<mlir::Value> +TosaMlirOperatorBuilder::build<Op_MUL>(TosaSerializationOperator *op) const { + mlir::Value input0_val = tensor_map->at(op->GetInputTensorNames()[0]); + mlir::Value input1_val = tensor_map->at(op->GetInputTensorNames()[1]); + mlir::Value shift_val = tensor_map->at(op->GetInputTensorNames()[2]); + + mlir::RankedTensorType output_type = + tensor_type_map->at(op->GetOutputTensorNames()[0]); + + assert(op->GetAttributeType() == + Attribute_MulAttribute); // double check attribute type + + mlir::Operation *mlir_op = op_builder->create<mlir::tosa::MulOp>( + loc, output_type, input0_val, input1_val, shift_val); + block->push_back(mlir_op); + return std::vector<mlir::Value>({mlir_op->getResult(0)}); +} + +template <> +std::vector<mlir::Value> +TosaMlirOperatorBuilder::build<Op_ARITHMETIC_RIGHT_SHIFT>( + TosaSerializationOperator *op) const { + mlir::Value input0_val = tensor_map->at(op->GetInputTensorNames()[0]); + mlir::Value input1_val = tensor_map->at(op->GetInputTensorNames()[1]); + mlir::RankedTensorType output_type = + tensor_type_map->at(op->GetOutputTensorNames()[0]); + + assert( + op->GetAttributeType() == + Attribute_ArithmeticRightShiftAttribute); // double check attribute type + TosaArithmeticRightShiftAttribute *attr = + static_cast<TosaArithmeticRightShiftAttribute *>(op->GetAttribute()); + + auto round = op_builder->getBoolAttr(attr->round()); + + mlir::Operation *mlir_op = + op_builder->create<mlir::tosa::ArithmeticRightShiftOp>( + loc, output_type, input0_val, input1_val, round); + block->push_back(mlir_op); + return std::vector<mlir::Value>({mlir_op->getResult(0)}); +} + +template <> +std::vector<mlir::Value> +TosaMlirOperatorBuilder::build<Op_TABLE>(TosaSerializationOperator *op) const { + mlir::Value input_val = tensor_map->at(op->GetInputTensorNames()[0]); + mlir::RankedTensorType output_type = + tensor_type_map->at(op->GetOutputTensorNames()[0]); + + assert(op->GetAttributeType() == + Attribute_TableAttribute); // double check attribute type + TosaTableAttribute *attr = + static_cast<TosaTableAttribute *>(op->GetAttribute()); + + // create a const op for table value attribute + const auto table_values = attr->table(); + mlir::RankedTensorType const_type; + mlir::DenseElementsAttr const_attr; + const auto input_element_type = + input_val.getType().cast<mlir::ShapedType>().getElementType(); + if (input_element_type.isInteger(8)) { + // table is signed 8 mode + const_type = mlir::RankedTensorType::get( + {static_cast<int64_t>(table_values.size())}, op_builder->getI8Type()); + const_attr = BuildDenseI8ElementsAttr(op_builder, table_values); + } else { + // table is signed 16 mode + const_type = mlir::RankedTensorType::get( + {static_cast<int64_t>(table_values.size())}, op_builder->getI16Type()); + const_attr = BuildDenseI16ElementsAttr(op_builder, table_values); + } + mlir::Operation *mlir_const_op = + op_builder->create<mlir::tosa::ConstOp>(loc, const_type, const_attr); + auto table_value = mlir_const_op->getResult(0); + + mlir::Operation *mlir_op = op_builder->create<mlir::tosa::TableOp>( + loc, output_type, input_val, table_value); + block->push_back(mlir_const_op); + block->push_back(mlir_op); + return std::vector<mlir::Value>({mlir_op->getResult(0)}); +} + +template <> +std::vector<mlir::Value> TosaMlirOperatorBuilder::build<Op_RESCALE>( + TosaSerializationOperator *op) const { + mlir::Value input_val = tensor_map->at(op->GetInputTensorNames()[0]); + mlir::Value multiplier_val = tensor_map->at(op->GetInputTensorNames()[1]); + mlir::Value shift_val = tensor_map->at(op->GetInputTensorNames()[2]); + mlir::RankedTensorType output_type = + tensor_type_map->at(op->GetOutputTensorNames()[0]); + + assert(op->GetAttributeType() == + Attribute_RescaleAttribute); // double check attribute type + TosaRescaleAttribute *attr = + static_cast<TosaRescaleAttribute *>(op->GetAttribute()); + + auto input_zp = op_builder->getI32IntegerAttr(attr->input_zp()); + auto output_zp = op_builder->getI32IntegerAttr(attr->output_zp()); + + auto scale32 = op_builder->getBoolAttr(attr->scale32()); + auto double_round = op_builder->getBoolAttr(attr->double_round()); + auto per_channel = op_builder->getBoolAttr(attr->per_channel()); + + auto input_unsigned = op_builder->getBoolAttr(attr->input_unsigned()); + auto output_unsigned = op_builder->getBoolAttr(attr->output_unsigned()); + + mlir::Operation *mlir_op = op_builder->create<mlir::tosa::RescaleOp>( + loc, output_type, input_val, multiplier_val, shift_val, input_zp, + output_zp, scale32, double_round, per_channel, input_unsigned, + output_unsigned); + block->push_back(mlir_op); + return std::vector<mlir::Value>({mlir_op->getResult(0)}); +} + +template <> +std::vector<mlir::Value> +TosaMlirOperatorBuilder::build<Op_CUSTOM>(TosaSerializationOperator *op) const { + mlir::Value input_val = tensor_map->at(op->GetInputTensorNames()[0]); + mlir::RankedTensorType output_type = + tensor_type_map->at(op->GetOutputTensorNames()[0]); + + assert(op->GetAttributeType() == + Attribute_CustomAttribute); // double check attribute type + TosaCustomAttribute *attr = + static_cast<TosaCustomAttribute *>(op->GetAttribute()); + + auto operator_name = op_builder->getStringAttr(attr->operator_name()); + auto domain_name = op_builder->getStringAttr(attr->domain_name()); + std::string impl_str; + impl_str.resize(attr->implementation_attrs().size() + 1); + int idx = 0; + for (auto c : attr->implementation_attrs()) { + impl_str[idx++] = c; + } + auto impl = op_builder->getStringAttr(impl_str); + + mlir::Operation *mlir_op = op_builder->create<mlir::tosa::CustomOp>( + loc, output_type, operator_name, domain_name, impl, input_val); + block->push_back(mlir_op); + return std::vector<mlir::Value>({mlir_op->getResult(0)}); +} + +template <> +std::vector<mlir::Value> +TosaMlirOperatorBuilder::build<Op_RFFT2D>(TosaSerializationOperator *op) const { + mlir::Value input_val = tensor_map->at(op->GetInputTensorNames()[0]); + mlir::RankedTensorType output0_type = + tensor_type_map->at(op->GetOutputTensorNames()[0]); + mlir::RankedTensorType output1_type = + tensor_type_map->at(op->GetOutputTensorNames()[1]); + assert(op->GetAttributeType() == + Attribute_RFFTAttribute); // double check attribute type + TosaRFFTAttribute *attr = + static_cast<TosaRFFTAttribute *>(op->GetAttribute()); + + bool local_bound = attr->local_bound(); + + mlir::Operation *mlir_op = op_builder->create<mlir::tosa::RFFT2dOp>( + loc, output0_type, output1_type, input_val, local_bound); + block->push_back(mlir_op); + return std::vector<mlir::Value>( + {mlir_op->getResult(0), mlir_op->getResult(1)}); +} + +template <> +std::vector<mlir::Value> +TosaMlirOperatorBuilder::build<Op_FFT2D>(TosaSerializationOperator *op) const { + mlir::Value input0_val = tensor_map->at(op->GetInputTensorNames()[0]); + mlir::Value input1_val = tensor_map->at(op->GetInputTensorNames()[1]); + mlir::RankedTensorType output0_type = + tensor_type_map->at(op->GetOutputTensorNames()[0]); + mlir::RankedTensorType output1_type = + tensor_type_map->at(op->GetOutputTensorNames()[1]); + + assert(op->GetAttributeType() == Attribute_FFTAttribute); + TosaFFTAttribute *attr = static_cast<TosaFFTAttribute *>(op->GetAttribute()); + auto inverse = op_builder->getBoolAttr(attr->inverse()); + auto local_bound = op_builder->getBoolAttr(attr->local_bound()); + + mlir::Operation *mlir_op = op_builder->create<mlir::tosa::FFT2dOp>( + loc, output0_type, output1_type, input0_val, input1_val, inverse, + local_bound); + block->push_back(mlir_op); + return std::vector<mlir::Value>( + {mlir_op->getResult(0), mlir_op->getResult(1)}); +} + +class TosaMlirRegionBuilder { +public: + TosaMlirRegionBuilder(TosaSerializationRegion *_ser_region, + TosaSerializationHandler *_tsh, mlir::Region *_region, + mlir::OpBuilder *_op_builder, mlir::Location _loc, + TosaMlirRegionBuilder *_parent_value_scope = nullptr) + : ser_region(_ser_region), tsh(_tsh), region(_region), + op_builder(_op_builder), loc(_loc) { + if (_parent_value_scope) { + // inherit parent_value_scope's tensor_map + for (auto &kv : _parent_value_scope->GetTensorMap()) { + tensor_map.insert(kv); + } + } + } + + mlir::LogicalResult + BuildAllBlocksInRegion(std::vector<mlir::Value> &return_values); + + mlir::OpBuilder *GetOpBuilder() { return op_builder; } + mlir::Location GetLocation() { return loc; } + std::unordered_map<std::string, mlir::Value> &GetTensorMap() { + return tensor_map; + } + TosaSerializationHandler *GetTsh() const { return tsh; } + +private: + mlir::Region *region; + TosaSerializationRegion *ser_region; + TosaSerializationHandler *tsh; + mlir::OpBuilder *op_builder; + mlir::Location loc; + std::unordered_map<std::string, mlir::Value> tensor_map; +}; + +class TosaMlirBlockBuilder { +public: + TosaMlirBlockBuilder(TosaSerializationBasicBlock *_ser_block, + TosaMlirRegionBuilder *_region_builder, + mlir::Block *_block) + : ser_block(_ser_block), region_builder(_region_builder), block(_block) {} + + mlir::LogicalResult + BuildAllOpsInBlock(std::vector<mlir::Value> &return_values); + + mlir::OpBuilder *GetOpBuilder() { return region_builder->GetOpBuilder(); } + mlir::Location GetLocation() { return region_builder->GetLocation(); } + std::unordered_map<std::string, mlir::Value> &GetTensorMap() { + return region_builder->GetTensorMap(); + } + + TosaSerializationHandler *GetTsh() const { return region_builder->GetTsh(); } + TosaMlirRegionBuilder *GetRegionBuilder() const { return region_builder; } + +private: + TosaSerializationBasicBlock *ser_block; + TosaMlirRegionBuilder *region_builder; + mlir::Block *block; + std::unordered_map<std::string, mlir::RankedTensorType> tensor_type_map; + std::unordered_map<std::string, mlir::tosa::shapeType> shape_type_map; + std::unordered_set<std::string> unranked_tensors; +}; + +TosaSerializationHandler *TosaMlirOperatorBuilder::GetTsh() const { + return block_builder->GetTsh(); +} + +TosaMlirRegionBuilder *TosaMlirOperatorBuilder::GetRegionBuilder() const { + return block_builder->GetRegionBuilder(); +} + +// build control flow ops: + +namespace { + +mlir::LogicalResult +BuildRegion(TosaSerializationRegion *ser_region, TosaSerializationHandler *tsh, + mlir::Region *mlir_region, mlir::OpBuilder *op_builder, + mlir::Location loc, std::vector<mlir::Value> &return_values, + bool isolated_from_above = false, + TosaMlirRegionBuilder *parent_region_builder = nullptr) { + TosaMlirRegionBuilder *parent_value_scope = + isolated_from_above ? nullptr : parent_region_builder; + TosaMlirRegionBuilder region_builder(ser_region, tsh, mlir_region, op_builder, + loc, parent_value_scope); + return region_builder.BuildAllBlocksInRegion(return_values); +} + +} // namespace + +template <> +std::vector<mlir::Value> TosaMlirOperatorBuilder::build<Op_COND_IF>( + TosaSerializationOperator *op) const { + mlir::Value cond_val = tensor_map->at(op->GetInputTensorNames().at(0)); + std::vector<mlir::Value> input_values; + for (auto idx = 1u; idx < op->GetInputTensorNames().size(); idx++) { + input_values.push_back(tensor_map->at(op->GetInputTensorNames().at(idx))); + } + std::vector<mlir::Type> output_types; + for (auto &name : op->GetInputTensorNames()) { + output_types.push_back(tensor_type_map->at(name)); + } + + assert(op->GetAttributeType() == + Attribute_CondIfAttribute); // double check attribute type + TosaCondIfAttribute *attr = + static_cast<TosaCondIfAttribute *>(op->GetAttribute()); + auto ser_then_region = GetTsh()->GetRegionByName(attr->then_graph()); + auto ser_else_region = GetTsh()->GetRegionByName(attr->else_graph()); + + if (!ser_then_region || !ser_else_region) { + llvm::errs() << "ERROR: " << get_string(op) + << " region serialization hasn't been implemented\n"; + return {}; + } + + mlir::Operation *mlir_op = op_builder->create<mlir::tosa::IfOp>( + loc, output_types, cond_val, input_values); + + const bool isolated_from_above = + mlir_op->hasTrait<mlir::OpTrait::IsIsolatedFromAbove>(); + mlir::Region &then_region = mlir_op->getRegion(0); + mlir::Region &else_region = mlir_op->getRegion(1); + + auto curr_region_builder = GetRegionBuilder(); + + std::vector<mlir::Value> then_returns, else_returns; + + if (BuildRegion(ser_then_region, GetTsh(), &then_region, op_builder, loc, + then_returns, isolated_from_above, curr_region_builder) + .failed()) { + return {}; + } + if (then_returns.size() != mlir_op->getNumResults()) { + llvm::errs() + << "ERROR: " << get_string(op) + << " then_region yield.size() doesn't match cond_if's output size\n"; + return {}; + } + + if (BuildRegion(ser_else_region, GetTsh(), &else_region, op_builder, loc, + else_returns, isolated_from_above, curr_region_builder) + .failed()) { + return {}; + } + if (else_returns.size() != mlir_op->getNumResults()) { + llvm::errs() + << "ERROR: " << get_string(op) + << " else_region yield.size() doesn't match cond_if's output size\n"; + return {}; + } + + block->push_back(mlir_op); + return std::vector<mlir::Value>(mlir_op->getResults().begin(), + mlir_op->getResults().end()); +} + +template <> +std::vector<mlir::Value> TosaMlirOperatorBuilder::build<Op_WHILE_LOOP>( + TosaSerializationOperator *op) const { + std::vector<mlir::Value> input_values; + for (auto idx = 0u; idx < op->GetInputTensorNames().size(); idx++) { + input_values.push_back(tensor_map->at(op->GetInputTensorNames().at(idx))); + } + std::vector<mlir::Type> output_types; + for (auto &name : op->GetInputTensorNames()) { + output_types.push_back(tensor_type_map->at(name)); + } + assert(op->GetAttributeType() == + Attribute_WhileLoopAttribute); // double check attribute type + TosaWhileLoopAttribute *attr = + static_cast<TosaWhileLoopAttribute *>(op->GetAttribute()); + auto ser_cond_region = GetTsh()->GetRegionByName(attr->cond_graph()); + auto ser_body_region = GetTsh()->GetRegionByName(attr->body_graph()); + + mlir::Operation *mlir_op = + op_builder->create<mlir::tosa::WhileOp>(loc, output_types, input_values); + + const bool isolated_from_above = + mlir_op->hasTrait<mlir::OpTrait::IsIsolatedFromAbove>(); + + mlir::Region &cond_region = mlir_op->getRegion(0); + mlir::Region &body_region = mlir_op->getRegion(1); + + auto curr_region_builder = GetRegionBuilder(); + + std::vector<mlir::Value> cond_returns, body_returns; + + if (BuildRegion(ser_cond_region, GetTsh(), &cond_region, op_builder, loc, + cond_returns, isolated_from_above, curr_region_builder) + .failed()) { + return {}; + } + if (cond_returns.size() != 1) { + llvm::errs() << "ERROR: " << get_string(op) + << " cond_region yield.size() is not 1\n"; + return {}; + } + + if (BuildRegion(ser_body_region, GetTsh(), &body_region, op_builder, loc, + body_returns, isolated_from_above, curr_region_builder) + .failed()) { + return {}; + } + if (body_returns.size() != mlir_op->getNumResults()) { + llvm::errs() + << "ERROR: " << get_string(op) + << " body_region yield.size() doesn't match while_loop's output size\n"; + return {}; + } + + block->push_back(mlir_op); + return std::vector<mlir::Value>(mlir_op->getResults().begin(), + mlir_op->getResults().end()); +} + +mlir::LogicalResult TosaMlirBlockBuilder::BuildAllOpsInBlock( + std::vector<mlir::Value> &return_values) { + block->clear(); + auto loc = GetLocation(); + auto op_builder = GetOpBuilder(); + auto &tensor_map = GetTensorMap(); + + std::unordered_set<TosaSerializationOperator *> operator_built; + std::queue<TosaSerializationOperator *> operator_queue; + + TosaMlirOperatorBuilder tosa_op_builder(op_builder, ser_block, block, loc, + this, &tensor_map, &tensor_type_map, + &shape_type_map); + + for (auto ts : ser_block->GetTensors()) { + if (ts->GetVariable()) { + RegisterVariableTensor(ts); + } + const auto &ts_name = ts->GetName(); + if (ts->GetDtype() == DType::DType_SHAPE) { + // ts is tosa.shape type + auto shape_rank = ts->GetShape()[0]; + shape_type_map[ts_name] = + mlir::tosa::shapeType::get(op_builder->getContext(), shape_rank); + continue; + } + mlir::RankedTensorType type; + if (BuildTensorType(op_builder, ts, type).failed()) { + return mlir::failure(); + } + tensor_type_map[ts_name] = type; + if (ts->GetIsUnranked()) { + assert(ts->GetShape().empty()); // unranked tensors should have shape = {} + unranked_tensors.insert(ts_name); + } + } + + // Initialize tensor_map/operator_queue based on block input arguments + for (const std::string &block_input_name : ser_block->GetInputs()) { + mlir::Type type = tensor_type_map[block_input_name]; + if (unranked_tensors.count(block_input_name)) { + // recast type as unranked tensor type + auto element_type = type.cast<mlir::RankedTensorType>().getElementType(); + type = mlir::UnrankedTensorType::get(element_type); + } + auto input_value = block->addArgument(type, loc); + if (tensor_map.count(block_input_name)) { + llvm::errs() << "ERROR: block input tensor " << block_input_name + << " already exists\n"; + return mlir::failure(); + } + tensor_map[block_input_name] = input_value; + } + + for (auto op : ser_block->GetOperators()) { + + // skip if operator has been built + if (operator_built.count(op)) { + // this happens when same input appears twice or more in operator, eg, + // concat(%0, %0) + continue; + } + operator_built.insert(op); + + std::vector<mlir::Value> output_values; + if (IsVariableReadOp(op)) { + output_values = tosa_op_builder.BuildVariableReadOp(op); + } else if (IsVariableWriteOp(op)) { + tosa_op_builder.BuildVariableWriteOp(op); + } +#define DEF_SCHEMA_OPERATOR(SCHEMA_OP_NAME) \ + else if (op->GetOp() == Op_##SCHEMA_OP_NAME) { \ + output_values = tosa_op_builder.build<Op_##SCHEMA_OP_NAME>(op); \ + } +#include "schema_operator.def" +#undef DEF_SCHEMA_OPERATOR + else { + llvm::errs() << "ERROR: unsupported opcode=" << EnumNamesOp()[op->GetOp()] + << "\n"; + return mlir::failure(); + } + + if (IsVariableWriteOp(op)) { + // the sanity checking below does not apply for variable write op because + // it has no output tensors whereas the original identity op has + continue; + } + + // Sanity check if number of built mlir::Value is expected + if (op->GetOutputTensorNames().size() != output_values.size()) { + llvm::errs() << "ERROR: number of built mlir::Value is not matching " + "number of operator output tensor\n"; + return mlir::failure(); + } + + for (size_t i = 0; i < output_values.size(); i++) { + // Sanity check tensor hasn't been built + std::string op_output_name = op->GetOutputTensorNames()[i]; + if (tensor_map.count(op_output_name)) { + llvm::errs() << "ERROR: tensor " << op_output_name + << " is already built\n"; + return mlir::failure(); + } + tensor_map[op_output_name] = output_values[i]; + } + } + + // Construct return values + std::vector<mlir::Value> return_operands; + for (const auto &output_name : ser_block->GetOutputs()) { + // Sanity check if terminator mlir::Value is built + if (!tensor_map.count(output_name)) { + llvm::errs() << "ERROR: terminator mlir::Value " << output_name + << " is not built in block " << ser_block->GetName() << "\n"; + return mlir::failure(); + } + mlir::Value output_value = tensor_map.at(output_name); + return_operands.push_back(output_value); + return_values.push_back(output_value); + } + mlir::Operation *terminator_op; + auto parent_op = block->getParentOp(); + if (mlir::isa<mlir::func::FuncOp>(parent_op)) { + terminator_op = + op_builder->create<mlir::func::ReturnOp>(loc, return_operands); + } else { + terminator_op = + op_builder->create<mlir::tosa::YieldOp>(loc, return_operands); + } + block->push_back(terminator_op); + + // need topological sorting? + + return mlir::success(); +} + +mlir::LogicalResult TosaMlirRegionBuilder::BuildAllBlocksInRegion( + std::vector<mlir::Value> &return_values) { + for (auto &ser_block : ser_region->GetBlocks()) { + auto &block = region->emplaceBlock(); + TosaMlirBlockBuilder block_builder(ser_block, this, &block); + + if (block_builder.BuildAllOpsInBlock(return_values).failed()) { + return mlir::failure(); + } + + if (return_values.empty()) { + llvm::errs() << "Warning: graph doesn't have return values\n"; + } + } + return mlir::success(); +} + +mlir::LogicalResult buildTosaMlir(mlir::func::FuncOp &func, + mlir::MLIRContext &context, + tosa::TosaSerializationHandler &tsh, + std::vector<mlir::Value> &main_returns) { + + mlir::Region *main_region = func.getCallableRegion(); + if (!main_region) { + llvm::errs() << "Invalid MLIR: doesn't have valid \"main\" region\n"; + return mlir::failure(); + } + + TosaSerializationRegion *ser_main_region = tsh.GetRegions().front(); + + auto loc = func.getLoc(); + + main_region->takeBody(*main_region); // empty old func body + auto op_builder = mlir::OpBuilder(func.getBody()); + + if (BuildRegion(ser_main_region, &tsh, main_region, &op_builder, loc, + main_returns) + .failed()) { + return mlir::failure(); + } + + if (main_returns.empty()) { + llvm::errs() << "Warning: graph doesn't have return values\n"; + } + + return mlir::success(); +} + +// Load Tosa Schema into TosaSerializationHandler, required for JSON save/load +mlir::LogicalResult loadTosaSchema(tosa::TosaSerializationHandler &tsh) { + const char *tosa_schema = tosa_deserialize_schema.c_str(); + + if (!tosa_schema) { + llvm::errs() << "Flatbuffer schema not defined\n"; + return mlir::failure(); + } + + if (tsh.LoadFileSchema(tosa_schema)) { + llvm::errs() << "Error loading tosa schema file: " << tosa_schema << "\n"; + return mlir::failure(); + } + return mlir::success(); +} + +namespace { + +mlir::NamedAttribute DefaultEntryFuncitonAttr(mlir::Builder &builder, + bool is_input, int count) { + std::string names; + for (int i = 0; i < count; i++) { + std::string name = kDefaultExportedName + "_"; + name += (is_input ? kDefaultInputPrefix : kDefaultOutputPrefix); + name += std::to_string(i) + ":0"; + if (i > 0) { + names += ","; + } + names += name; + } + return builder.getNamedAttr((is_input ? "inputs" : "outputs"), + builder.getStringAttr(names)); +} + +// erase all ops in block except for FuncOp +void ClearNonFuncOps(mlir::Block *block) { + std::vector<mlir::Operation *> to_delete; + for (auto &op : block->getOperations()) { + if (!mlir::isa<mlir::func::FuncOp>(op)) { + to_delete.push_back(&op); + } + } + for (mlir::Operation *op : to_delete) { + op->erase(); + } +} + +// erase function attrs and empty function region's body +void ResetFunction(mlir::func::FuncOp &function, mlir::MLIRContext &context) { + function->setAttrs(mlir::DictionaryAttr::get(&context, {})); + mlir::Region *main_region = function.getCallableRegion(); + main_region->takeBody(*main_region); +} + +// replace attrs and body of @a to_function and its parent module +// by @a from_module and its "main" function +mlir::LogicalResult CloneIntoModuleAndFunction( + mlir::MLIRContext &context, mlir::func::FuncOp &to_function, + mlir::ModuleOp &to_module, mlir::func::FuncOp &from_function, + mlir::ModuleOp &from_module) { + auto from_block = from_function.getOperation()->getBlock(); + auto to_block = to_function.getOperation()->getBlock(); + ClearNonFuncOps(to_block); + // copy all attrs from new_module to module + to_module->setAttrs(from_module->getAttrDictionary()); + // erase attrs and body of function + ResetFunction(to_function, context); + // clone new_func attrs and region into function + mlir::IRMapping mapping; + from_function.cloneInto(to_function, mapping); + + // copy variable ops in from_block to to_block + // collect variable ops in from_block in reverse order + std::vector<mlir::Operation *> variable_ops; + for (mlir::Operation &op : *from_block) { + if (mlir::isa<mlir::tosa::VariableOp>(op)) { + variable_ops.push_back(&op); + } + } + auto cloneOptions = + mlir::Operation::CloneOptions::all().cloneRegions(false).cloneOperands( + false); + for (auto iter = variable_ops.rbegin(); iter != variable_ops.rend(); iter++) { + auto op = *iter; + to_block->push_front(op->clone(mapping, cloneOptions)); + } + return mlir::success(); +} + +} // namespace + +namespace mlir { + +namespace tosa { + +mlir::OwningOpRef<mlir::ModuleOp> +BuildMlirFromTosaFile(const char *file_name, mlir::MLIRContext *context, + bool file_is_fbs) { + TosaSerializationHandler tsh; + if (file_is_fbs) { + if (tsh.LoadFileTosaFlatbuffer(file_name)) { + llvm::errs() << "Fail to load TOSA file " << file_name << "\n"; + return nullptr; + } + } else { + // must load tosa schema before loading json file + if (loadTosaSchema(tsh).failed()) { + return nullptr; + } + if (tsh.LoadFileJson(file_name)) { + llvm::errs() << "Fail to load TOSA JSON file " << file_name << "\n"; + return nullptr; + } + } + + // create new module + auto base_loc = mlir::FileLineColLoc::get(context, file_name, 0, 0); + auto module = mlir::ModuleOp::create(base_loc); + + // set module attributes + const auto &tosa_version = tsh.GetVersion().to_string(); + std::string tosa_description = + file_is_fbs ? kDefaultFBSDescription : kDefaultJSONDescription; + auto builder = mlir::Builder(context); + module->setAttr("tosa.fbs_version", builder.getStringAttr(tosa_version)); + module->setAttr("tosa.description", builder.getStringAttr(tosa_description)); + module->setAttr("tf_saved_model.semantics", mlir::UnitAttr::get(context)); + + // construct function with input and return types + llvm::SmallVector<mlir::Type, 2> ret_types; + llvm::SmallVector<mlir::Type, 4> input_types; + auto func_type = builder.getFunctionType(input_types, ret_types); + auto func_loc = + mlir::NameLoc::get(builder.getStringAttr(kMainFunctionName), base_loc); + auto func = mlir::func::FuncOp::create(func_loc, kMainFunctionName, func_type, + /* attrs= */ {}); + func.addEntryBlock(); + + // deserialize tosa fbs into function + std::vector<mlir::Value> main_returns; + if (buildTosaMlir(func, *context, tsh, main_returns).failed()) { + llvm::errs() << "Failed to deserialize flatbuffer " + << tosa_deserialize_filename << "\n"; + return nullptr; + } + auto main_args = func.getCallableRegion()->getArguments(); + // extract function input types + for (auto arg : main_args) { + input_types.push_back(arg.getType()); + } + // extract function return types + for (auto ret : main_returns) { + ret_types.push_back(ret.getType()); + } + // set function type with full input and return types + func_type = builder.getFunctionType(input_types, ret_types); + func.setType(func_type); + + // set function attributes + llvm::SmallVector<mlir::NamedAttribute, 2> attributes; + if (!input_types.empty()) { + attributes.push_back(DefaultEntryFuncitonAttr( + builder, /* is_input = */ true, /* count = */ input_types.size())); + for (int i = 0; i < input_types.size(); i++) { + std::string input_i = kDefaultInputPrefix + std::to_string(i); + func.setArgAttr(i, "tf_saved_model.index_path", + mlir::ArrayAttr::get( + context, {mlir::StringAttr::get(context, input_i)})); + } + } + if (!ret_types.empty()) { + attributes.push_back(DefaultEntryFuncitonAttr( + builder, /* is_input = */ false, /* count = */ ret_types.size())); + for (int i = 0; i < ret_types.size(); i++) { + std::string output_i = kDefaultOutputPrefix + std::to_string(i); + func.setResultAttr( + i, "tf_saved_model.index_path", + mlir::ArrayAttr::get(context, + {mlir::StringAttr::get(context, output_i)})); + } + } + func->setAttr("tf.entry_function", builder.getDictionaryAttr(attributes)); + func->setAttr( + "tf_saved_model.exported_names", + mlir::ArrayAttr::get( + context, {mlir::StringAttr::get(context, kDefaultExportedName)})); + + // deserialize variable ops in the new module just before adding func op + if (ConstructVariableOps(module).failed()) { + return nullptr; + } + + // add func to module + module.push_back(std::move(func)); + return mlir::OwningOpRef<mlir::ModuleOp>(module); +} + +namespace { + +class TosaDeserialize : public TosaDeserializationPassBase<TosaDeserialize> { +public: + void runOnOperation() final { + auto function = getOperation(); + auto &context = getContext(); + + auto new_module_ref = BuildMlirFromTosaFile( + tosa_deserialize_filename.c_str(), &context, /* file_is_fbs = */ true); + if (!new_module_ref) { + return signalPassFailure(); + } + + mlir::ModuleOp new_module = *new_module_ref; + auto builder = mlir::Builder(&context); + auto module = function->getParentOfType<mlir::ModuleOp>(); + auto new_function = new_module.lookupSymbol<mlir::func::FuncOp>( + builder.getStringAttr(kMainFunctionName)); + if (!new_function) { + llvm::errs() << "Failed to find main function in deserialized module\n"; + return signalPassFailure(); + } + if (CloneIntoModuleAndFunction(context, + /* to_function = */ function, + /* to_module = */ module, + /* from_function = */ new_function, + /* from_module = */ new_module) + .failed()) { + return signalPassFailure(); + } + } +}; + +class TosaDeserializeJSON + : public TosaDeserializationJSONPassBase<TosaDeserializeJSON> { +public: + void runOnOperation() final { + auto function = getOperation(); + auto &context = getContext(); + + auto new_module_ref = BuildMlirFromTosaFile( + tosa_deserialize_filename.c_str(), &context, /* file_is_fbs = */ false); + if (!new_module_ref) { + return signalPassFailure(); + } + + mlir::ModuleOp new_module = *new_module_ref; + auto builder = mlir::Builder(&context); + auto module = function->getParentOfType<mlir::ModuleOp>(); + auto new_function = new_module.lookupSymbol<mlir::func::FuncOp>( + builder.getStringAttr(kMainFunctionName)); + if (!new_function) { + llvm::errs() << "Failed to find main function in deserialized module\n"; + return signalPassFailure(); + } + if (CloneIntoModuleAndFunction(context, + /* to_function = */ function, + /* to_module = */ module, + /* from_function = */ new_function, + /* from_module = */ new_module) + .failed()) { + return signalPassFailure(); + } + } +}; + +} // anonymous namespace + +// Creates an instance of the TOSA flatbuffer deserialization pass +std::unique_ptr<Pass> createTosaDeserializePass() { + return std::make_unique<TosaDeserialize>(); +} + +std::unique_ptr<Pass> createTosaDeserializeJSONPass() { + return std::make_unique<TosaDeserializeJSON>(); +} + +static PassRegistration<TosaDeserialize> passDeserialize([] { + return createTosaDeserializePass(); +}); + +static PassRegistration<TosaDeserializeJSON> passDeserializeJSON([] { + return createTosaDeserializeJSONPass(); +}); + +} // namespace tosa +} // namespace mlir |