From 581fb5d0e706d8669dd5ce21e1be2770b4951e02 Mon Sep 17 00:00:00 2001 From: Tai Ly Date: Thu, 16 Feb 2023 22:57:53 +0000 Subject: Add Tosa Deserialization Signed-off-by: Tai Ly Change-Id: I8b0220a8465e75b1accf6b0854e911a425730da6 --- CMakeLists.txt | 6 + include/DeserializationPasses.h | 37 + include/DeserializationPasses.td | 25 + include/schema_operator.def | 93 +++ src/TosaDeserialize.cpp | 1385 ++++++++++++++++++++++++++++++++++++++ 5 files changed, 1546 insertions(+) create mode 100644 include/DeserializationPasses.h create mode 100644 include/DeserializationPasses.td create mode 100644 include/schema_operator.def create mode 100644 src/TosaDeserialize.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index 5512401..db0bb78 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -17,15 +17,21 @@ set(LLVM_TARGET_DEFINITIONS include/SerializationPasses.td) mlir_tablegen(include/SerializationPasses.h.inc -gen-pass-decls -name TosaSerialization) add_public_tablegen_target(tosa_serialization_passes_inc_gen) +set(LLVM_TARGET_DEFINITIONS include/DeserializationPasses.td) +mlir_tablegen(include/DeserializationPasses.h.inc -gen-pass-decls -name TosaDeserialization) +add_public_tablegen_target(tosa_deserialization_passes_inc_gen) + # Compile the TOSA serialization_lib add_subdirectory(third_party/serialization_lib) add_mlir_library(tosa_serialize src/TosaSerialize.cpp + src/TosaDeserialize.cpp DEPENDS mlir-headers tosa_serialization_passes_inc_gen + tosa_deserialization_passes_inc_gen LINK_LIBS PRIVATE tosa_serialization_lib diff --git a/include/DeserializationPasses.h b/include/DeserializationPasses.h new file mode 100644 index 0000000..1bc195a --- /dev/null +++ b/include/DeserializationPasses.h @@ -0,0 +1,37 @@ + +// Copyright (c) 2023, ARM Limited. +// +// Licensed under the Apache License, Version 2.0 with LLVM Exceptions +// (the "License"); you may not use this file except in compliance with +// the License. 
You may obtain a copy of the License at +// +// https://llvm.org/LICENSE.txt +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef INCLUDE_DESERIALIZATION_PASSES_H +#define INCLUDE_DESERIALIZATION_PASSES_H + +#include + +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project + +namespace mlir { +namespace tosa { + +std::unique_ptr createTosaDeserializePass(); +std::unique_ptr createTosaDeserializeJSONPass(); + +#define GEN_PASS_REGISTRATION +#define GEN_PASS_CLASSES +#include "include/DeserializationPasses.h.inc" + +} // namespace tosa +} // namespace mlir + +#endif // INCLUDE_DESERIALIZATION_PASSES_H diff --git a/include/DeserializationPasses.td b/include/DeserializationPasses.td new file mode 100644 index 0000000..999f0b4 --- /dev/null +++ b/include/DeserializationPasses.td @@ -0,0 +1,25 @@ +// Copyright (c) 2023, ARM Limited. +// +// Licensed under the Apache License, Version 2.0 with LLVM Exceptions +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// https://llvm.org/LICENSE.txt +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +include "mlir/Pass/PassBase.td" + +def TosaDeserializationPass : Pass<"tosa-deserialize", "func::FuncOp"> { + let summary = "Deserialize TOSA flatbuffer. 
Clear original MLIR graph and generate TOSA MLIR"; + let constructor = "createTosaDeserializePass()"; +} + +def TosaDeserializationJSONPass : Pass<"tosa-deserialize-json", "func::FuncOp"> { + let summary = "Deserialize TOSA flatbuffer JSON form. Clear original MLIR graph and generate TOSA MLIR"; + let constructor = "createTosaDeserializeJSONPass()"; +} diff --git a/include/schema_operator.def b/include/schema_operator.def new file mode 100644 index 0000000..1af367e --- /dev/null +++ b/include/schema_operator.def @@ -0,0 +1,93 @@ +// Copyright (c) 2023, ARM Limited. +// +// Licensed under the Apache License, Version 2.0 with LLVM Exceptions +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// https://llvm.org/LICENSE.txt +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +/* + Syntax: + DEF_SCHEMA_OPERATOR(SCHEMA_OP_NAME) + + Description: + SCHEMA_OP_NAME: the name of one schema operator; every entry must match the corresponding Op name in schema/tosa.fbs in the serialization_lib repo +*/ + +/* schema operators: one DEF_SCHEMA_OPERATOR entry per operator */ +DEF_SCHEMA_OPERATOR(ARGMAX) +DEF_SCHEMA_OPERATOR(AVG_POOL2D) +DEF_SCHEMA_OPERATOR(CONV2D) +DEF_SCHEMA_OPERATOR(CONV3D) +DEF_SCHEMA_OPERATOR(DEPTHWISE_CONV2D) +DEF_SCHEMA_OPERATOR(FULLY_CONNECTED) +DEF_SCHEMA_OPERATOR(MATMUL) +DEF_SCHEMA_OPERATOR(MAX_POOL2D) +DEF_SCHEMA_OPERATOR(TRANSPOSE_CONV2D) +DEF_SCHEMA_OPERATOR(CLAMP) +DEF_SCHEMA_OPERATOR(RESERVED) +DEF_SCHEMA_OPERATOR(SIGMOID) +DEF_SCHEMA_OPERATOR(TANH) +DEF_SCHEMA_OPERATOR(ADD) +DEF_SCHEMA_OPERATOR(ARITHMETIC_RIGHT_SHIFT) +DEF_SCHEMA_OPERATOR(BITWISE_AND) +DEF_SCHEMA_OPERATOR(BITWISE_OR) +DEF_SCHEMA_OPERATOR(BITWISE_XOR) +DEF_SCHEMA_OPERATOR(INTDIV) +DEF_SCHEMA_OPERATOR(LOGICAL_AND) +DEF_SCHEMA_OPERATOR(LOGICAL_LEFT_SHIFT) +DEF_SCHEMA_OPERATOR(LOGICAL_RIGHT_SHIFT) +DEF_SCHEMA_OPERATOR(LOGICAL_OR) +DEF_SCHEMA_OPERATOR(LOGICAL_XOR) +DEF_SCHEMA_OPERATOR(MAXIMUM) +DEF_SCHEMA_OPERATOR(MINIMUM) +DEF_SCHEMA_OPERATOR(MUL) +DEF_SCHEMA_OPERATOR(POW) +DEF_SCHEMA_OPERATOR(SUB) +DEF_SCHEMA_OPERATOR(TABLE) +DEF_SCHEMA_OPERATOR(ABS) +DEF_SCHEMA_OPERATOR(BITWISE_NOT) +DEF_SCHEMA_OPERATOR(CEIL) +DEF_SCHEMA_OPERATOR(CLZ) +DEF_SCHEMA_OPERATOR(EXP) +DEF_SCHEMA_OPERATOR(FLOOR) +DEF_SCHEMA_OPERATOR(LOG) +DEF_SCHEMA_OPERATOR(LOGICAL_NOT) +DEF_SCHEMA_OPERATOR(NEGATE) +DEF_SCHEMA_OPERATOR(RECIPROCAL) +DEF_SCHEMA_OPERATOR(RSQRT) +DEF_SCHEMA_OPERATOR(SELECT) +DEF_SCHEMA_OPERATOR(EQUAL) +DEF_SCHEMA_OPERATOR(GREATER) +DEF_SCHEMA_OPERATOR(GREATER_EQUAL) +DEF_SCHEMA_OPERATOR(REDUCE_ANY) +DEF_SCHEMA_OPERATOR(REDUCE_ALL) +DEF_SCHEMA_OPERATOR(REDUCE_MAX) +DEF_SCHEMA_OPERATOR(REDUCE_MIN) +DEF_SCHEMA_OPERATOR(REDUCE_PRODUCT) +DEF_SCHEMA_OPERATOR(REDUCE_SUM) +DEF_SCHEMA_OPERATOR(CONCAT) +DEF_SCHEMA_OPERATOR(PAD) +DEF_SCHEMA_OPERATOR(RESHAPE) +DEF_SCHEMA_OPERATOR(REVERSE) +DEF_SCHEMA_OPERATOR(SLICE) +DEF_SCHEMA_OPERATOR(TILE) 
+DEF_SCHEMA_OPERATOR(TRANSPOSE) +DEF_SCHEMA_OPERATOR(GATHER) +DEF_SCHEMA_OPERATOR(SCATTER) +DEF_SCHEMA_OPERATOR(RESIZE) +DEF_SCHEMA_OPERATOR(CAST) +DEF_SCHEMA_OPERATOR(RESCALE) +DEF_SCHEMA_OPERATOR(CONST) +DEF_SCHEMA_OPERATOR(IDENTITY) +DEF_SCHEMA_OPERATOR(CUSTOM) +DEF_SCHEMA_OPERATOR(COND_IF) +DEF_SCHEMA_OPERATOR(WHILE_LOOP) +DEF_SCHEMA_OPERATOR(FFT2D) +DEF_SCHEMA_OPERATOR(RFFT2D) diff --git a/src/TosaDeserialize.cpp b/src/TosaDeserialize.cpp new file mode 100644 index 0000000..f6f78fc --- /dev/null +++ b/src/TosaDeserialize.cpp @@ -0,0 +1,1385 @@ + +// Copyright (c) 2023, ARM Limited. +// +// Licensed under the Apache License, Version 2.0 with LLVM Exceptions +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// https://llvm.org/LICENSE.txt +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// TOSA MLIR deserialize passes + +#include "include/DeserializationPasses.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project +#include "mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project +#include "mlir/Dialect/Tosa/IR/TosaOps.h" // from @llvm-project +#include "mlir/IR/Matchers.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "tosa_serialization_handler.h" +#include +#include +#include +#include +#include + +// The namespace might be confusing here. 
We have mlir::tosa:: defined in MLIR +// and tosa:: defined in serialization library +// TODO: align the namespace +using namespace tosa; + +namespace cl = llvm::cl; + +llvm::cl::opt tosa_deserialize_filename( + "tosa-deserialize-filename", llvm::cl::desc(""), + llvm::cl::init("tosa_dump.tosa"), llvm::cl::value_desc("filename")); + +llvm::cl::opt tosa_deserialize_schema( + "tosa-deserialize-schema", llvm::cl::desc(""), + llvm::cl::init(""), llvm::cl::value_desc("filename")); + +namespace { + +// construct tensor type from dtype and shape of TosaSerializationTensor +mlir::LogicalResult BuildTensorType(mlir::OpBuilder *op_builder, + TosaSerializationTensor *ts, + mlir::RankedTensorType &type) { + mlir::Type element_type; + switch (ts->GetDtype()) { + case DType_BOOL: + element_type = op_builder->getI1Type(); + break; + case DType_UINT8: + element_type = op_builder->getIntegerType(8, false); + break; + case DType_INT4: + element_type = op_builder->getI4Type(); + break; + case DType_INT8: + element_type = op_builder->getI8Type(); + break; + case DType_INT16: + element_type = op_builder->getIntegerType(16); + break; + case DType_INT32: + element_type = op_builder->getI32Type(); + break; + case DType_INT48: + element_type = op_builder->getIntegerType(48); + break; + case DType_FP32: + element_type = op_builder->getF32Type(); + break; + case DType_UINT16: + element_type = op_builder->getIntegerType(16, false); + break; + default: + llvm::errs() << "ERROR: unknown type " << EnumNamesDType()[ts->GetDtype()] + << "\n"; + return mlir::failure(); + } + llvm::SmallVector shape(ts->GetShape().begin(), ts->GetShape().end()); + type = mlir::RankedTensorType::get(llvm::makeArrayRef(shape), element_type); + return mlir::success(); +} + +template +mlir::DenseElementsAttr +BuildDenseI16ElementsAttr(mlir::OpBuilder *op_builder, + const std::vector &values) { + std::vector vec; + for (auto val : values) { + vec.push_back(val); + } + auto type = + 
mlir::RankedTensorType::get({vec.size()}, op_builder->getI16Type()); + return mlir::DenseElementsAttr::get(type, llvm::ArrayRef(vec)); +} + +template +mlir::DenseElementsAttr +BuildDenseI32ElementsAttr(mlir::OpBuilder *op_builder, + const std::vector &values) { + std::vector vec; + for (auto val : values) { + vec.push_back(val); + } + auto type = + mlir::RankedTensorType::get({vec.size()}, op_builder->getI32Type()); + return mlir::DenseElementsAttr::get(type, llvm::ArrayRef(vec)); +} + +template +mlir::DenseI32ArrayAttr BuildDenseI32ArrayAttr(mlir::OpBuilder *op_builder, + const std::vector &values) { + std::vector vec; + for (auto val : values) { + vec.push_back(val); + } + return op_builder->getDenseI32ArrayAttr(vec); +} + +template +mlir::DenseI64ArrayAttr BuildDenseI64ArrayAttr(mlir::OpBuilder *op_builder, + const std::vector &values) { + std::vector vec; + for (auto val : values) { + vec.push_back(val); + } + return op_builder->getDenseI64ArrayAttr(vec); +} + +const std::string ResizeEnum2Str(const tosa::ResizeMode &mode) { + if (mode == ResizeMode_NEAREST) { + return "NEAREST_NEIGHBOR"; + } else if (mode == ResizeMode_BILINEAR) { + return "BILINEAR"; + } + return ""; +} + +} // namespace + +class TosaMlirRegionBuilder; +class TosaMlirBlockBuilder; + +class TosaMlirOperatorBuilder { +public: + TosaMlirOperatorBuilder( + mlir::OpBuilder *_op_builder, TosaSerializationBasicBlock *_ser_block, + mlir::Block *_block, mlir::Location _loc, + std::unordered_map *_tensor_map, + std::unordered_map *_tensor_type_map) + : op_builder(_op_builder), ser_block(_ser_block), block(_block), + loc(_loc), tensor_map(_tensor_map), tensor_type_map(_tensor_type_map) {} + + template + std::vector build(TosaSerializationOperator *op) const; + + std::string get_string(TosaSerializationOperator *op) const { + std::string op_string; + op_string += "operator opcode="; + op_string += EnumNamesOp()[op->GetOp()]; + op_string += ", input=["; + for (auto ts : op->GetInputTensorNames()) { + 
op_string += (ts + " "); + } + op_string += "], output=["; + for (auto ts : op->GetOutputTensorNames()) { + op_string += (ts + " "); + } + op_string += "]"; + return op_string; + } + +private: + template + std::vector + BuildEwiseUnaryOp(TosaSerializationOperator *op) const; + + template + std::vector + BuildEwiseBinaryOp(TosaSerializationOperator *op) const; + + template + std::vector + BuildReductionOp(TosaSerializationOperator *op) const; + + template + std::vector BuildConvOp(TosaSerializationOperator *op) const; + + mlir::OpBuilder *op_builder; + TosaSerializationBasicBlock *ser_block; + mlir::Block *block; + mlir::Location loc; + std::unordered_map *tensor_map; + std::unordered_map *tensor_type_map; +}; + +// Main template to catch unimplemented translation +template +std::vector TosaMlirOperatorBuilder::build(TosaSerializationOperator* op) const +{ + llvm::errs() << "ERROR: " << get_string(op) << " translation hasn't been implemented\n"; + return std::vector(); +} + +// BUILD_OP_POOL2D(MaxPool2d, MAX_POOL2D) +template <> +std::vector TosaMlirOperatorBuilder::build(TosaSerializationOperator* op) const +{ + mlir::Value input_val = tensor_map->at(op->GetInputTensorNames()[0]); + mlir::RankedTensorType output_type = tensor_type_map->at(op->GetOutputTensorNames()[0]); + assert(op->GetAttributeType() == Attribute_PoolAttribute); // double check attribute type + TosaPoolAttribute* attr = static_cast(op->GetAttribute()); + mlir::DenseI64ArrayAttr kernel = BuildDenseI64ArrayAttr(op_builder, attr->kernel()); + mlir::DenseI64ArrayAttr stride = BuildDenseI64ArrayAttr(op_builder, attr->stride()); + mlir::DenseI64ArrayAttr pad = BuildDenseI64ArrayAttr(op_builder, attr->pad()); + int32_t input_zp = attr->input_zp(); + int32_t output_zp = attr->output_zp(); + assert(input_zp == 0 && output_zp == 0); + + mlir::Operation* mlir_op = op_builder->create(loc, output_type, input_val, kernel, stride, pad); + block->push_back(mlir_op); + return std::vector({ mlir_op->getResult(0) 
}); +} + +// BUILD_OP_POOL2D(AvgPool2d, AVG_POOL2D) +template <> +std::vector TosaMlirOperatorBuilder::build(TosaSerializationOperator* op) const +{ + mlir::Value input_val = tensor_map->at(op->GetInputTensorNames()[0]); + mlir::RankedTensorType output_type = tensor_type_map->at(op->GetOutputTensorNames()[0]); + assert(op->GetAttributeType() == Attribute_PoolAttribute); // double check attribute type + TosaPoolAttribute* attr = static_cast(op->GetAttribute()); + mlir::DenseI64ArrayAttr kernel = BuildDenseI64ArrayAttr(op_builder, attr->kernel()); + mlir::DenseI64ArrayAttr stride = BuildDenseI64ArrayAttr(op_builder, attr->stride()); + mlir::DenseI64ArrayAttr pad = BuildDenseI64ArrayAttr(op_builder, attr->pad()); + int32_t input_zp = attr->input_zp(); + int32_t output_zp = attr->output_zp(); + mlir::Operation* mlir_op; + if (input_zp == 0 && output_zp == 0) { + mlir_op = op_builder->create(loc, output_type, input_val, kernel, stride, pad); + } else { + auto quant = op_builder->getAttr(input_zp, output_zp); + mlir_op = op_builder->create(loc, output_type, input_val, kernel, stride, pad, quant); + } + block->push_back(mlir_op); + return std::vector({ mlir_op->getResult(0) }); +} + +template +std::vector TosaMlirOperatorBuilder::BuildEwiseUnaryOp(TosaSerializationOperator* op) const +{ + mlir::Value input_val = tensor_map->at(op->GetInputTensorNames()[0]); + mlir::RankedTensorType output_type = tensor_type_map->at(op->GetOutputTensorNames()[0]); + assert(op->GetAttributeType() == Attribute_NONE); // double check that there is no attribute + + mlir::Operation* mlir_op = op_builder->create(loc, output_type, input_val); + block->push_back(mlir_op); + return std::vector({ mlir_op->getResult(0) }); +} + +template +std::vector TosaMlirOperatorBuilder::BuildEwiseBinaryOp(TosaSerializationOperator* op) const +{ + mlir::Value input0_val = tensor_map->at(op->GetInputTensorNames()[0]); + mlir::Value input1_val = tensor_map->at(op->GetInputTensorNames()[1]); + mlir::RankedTensorType 
output_type = tensor_type_map->at(op->GetOutputTensorNames()[0]); + assert(op->GetAttributeType() == Attribute_NONE); // double check that there is no attribute + + mlir::Operation* mlir_op = op_builder->create(loc, output_type, input0_val, input1_val); + block->push_back(mlir_op); + return std::vector({ mlir_op->getResult(0) }); +} + +template +std::vector TosaMlirOperatorBuilder::BuildReductionOp(TosaSerializationOperator* op) const +{ + mlir::Value input_val = tensor_map->at(op->GetInputTensorNames()[0]); + mlir::RankedTensorType output_type = tensor_type_map->at(op->GetOutputTensorNames()[0]); + assert(op->GetAttributeType() == Attribute_AxisAttribute); // double check attribute type + + TosaAxisAttribute* attr = static_cast(op->GetAttribute()); + auto axis = op_builder->getI64IntegerAttr(attr->axis()); + + mlir::Operation *mlir_op = + op_builder->create(loc, output_type, input_val, axis); + block->push_back(mlir_op); + return std::vector({ mlir_op->getResult(0) }); +} + +#define BUILD_OP_ELEMENTWISE_UNARY(MLIR_OP_NAME, SCHEMA_OP_NAME) \ + template <> \ + std::vector TosaMlirOperatorBuilder::build( \ + TosaSerializationOperator* op) const { \ + return BuildEwiseUnaryOp(op); \ + } + +#define BUILD_OP_ELEMENTWISE_BINARY(MLIR_OP_NAME, SCHEMA_OP_NAME) \ + template <> \ + std::vector TosaMlirOperatorBuilder::build( \ + TosaSerializationOperator* op) const { \ + return BuildEwiseBinaryOp(op); \ + } + +#define BUILD_OP_REDUCTION(MLIR_OP_NAME, SCHEMA_OP_NAME) \ + template <> \ + std::vector TosaMlirOperatorBuilder::build( \ + TosaSerializationOperator* op) const { \ + return BuildReductionOp(op); \ + } + +// BUILD_OP_POOL2D(MaxPool2d, MAX_POOL2D) +// BUILD_OP_POOL2D(AvgPool2d, AVG_POOL2D) + +BUILD_OP_ELEMENTWISE_BINARY(Add, ADD) +BUILD_OP_ELEMENTWISE_BINARY(BitwiseAnd, BITWISE_AND) +BUILD_OP_ELEMENTWISE_BINARY(BitwiseXor, BITWISE_XOR) +BUILD_OP_ELEMENTWISE_BINARY(BitwiseOr, BITWISE_OR) +BUILD_OP_ELEMENTWISE_BINARY(Div, INTDIV) +BUILD_OP_ELEMENTWISE_BINARY(LogicalAnd, 
LOGICAL_AND) +BUILD_OP_ELEMENTWISE_BINARY(LogicalLeftShift, LOGICAL_LEFT_SHIFT) +BUILD_OP_ELEMENTWISE_BINARY(LogicalRightShift, LOGICAL_RIGHT_SHIFT) +BUILD_OP_ELEMENTWISE_BINARY(LogicalOr, LOGICAL_OR) +BUILD_OP_ELEMENTWISE_BINARY(LogicalXor, LOGICAL_XOR) +BUILD_OP_ELEMENTWISE_BINARY(Maximum, MAXIMUM) +BUILD_OP_ELEMENTWISE_BINARY(Minimum, MINIMUM) +BUILD_OP_ELEMENTWISE_BINARY(Pow, POW) +BUILD_OP_ELEMENTWISE_BINARY(Sub, SUB) + +BUILD_OP_ELEMENTWISE_UNARY(Abs, ABS) +BUILD_OP_ELEMENTWISE_UNARY(BitwiseNot, BITWISE_NOT) +BUILD_OP_ELEMENTWISE_UNARY(Ceil, CEIL) +BUILD_OP_ELEMENTWISE_UNARY(Clz, CLZ) +BUILD_OP_ELEMENTWISE_UNARY(Exp, EXP) +BUILD_OP_ELEMENTWISE_UNARY(Floor, FLOOR) +BUILD_OP_ELEMENTWISE_UNARY(Log, LOG) +BUILD_OP_ELEMENTWISE_UNARY(LogicalNot, LOGICAL_NOT) +BUILD_OP_ELEMENTWISE_UNARY(Reciprocal, RECIPROCAL) +BUILD_OP_ELEMENTWISE_UNARY(Rsqrt, RSQRT) + +BUILD_OP_REDUCTION(ReduceAny, REDUCE_ANY) +BUILD_OP_REDUCTION(ReduceAll, REDUCE_ALL) +BUILD_OP_REDUCTION(ReduceMax, REDUCE_MAX) +BUILD_OP_REDUCTION(ReduceMin, REDUCE_MIN) +BUILD_OP_REDUCTION(ReduceProd, REDUCE_PRODUCT) +BUILD_OP_REDUCTION(ReduceSum, REDUCE_SUM) + +BUILD_OP_ELEMENTWISE_BINARY(Equal, EQUAL) +BUILD_OP_ELEMENTWISE_BINARY(Greater, GREATER) +BUILD_OP_ELEMENTWISE_BINARY(GreaterEqual, GREATER_EQUAL) + +BUILD_OP_ELEMENTWISE_UNARY(Sigmoid, SIGMOID) +BUILD_OP_ELEMENTWISE_UNARY(Tanh, TANH) +BUILD_OP_ELEMENTWISE_UNARY(Identity, IDENTITY) +BUILD_OP_ELEMENTWISE_UNARY(Cast, CAST) + +template <> +std::vector TosaMlirOperatorBuilder::build(TosaSerializationOperator* op) const +{ + const auto &output_name = op->GetOutputTensorNames()[0]; + mlir::RankedTensorType output_type = tensor_type_map->at(output_name); + TosaSerializationTensor *ts = ser_block->GetTensorByName(output_name); + const auto &data = ts->GetData(); + auto &shape = ts->GetShape(); + // compute output data size + uint32_t out_size = 1; + for (const auto dim : shape) { + out_size *= dim; + } + mlir::DenseElementsAttr value_attr; + switch 
(ts->GetDtype()) { + case DType_FP32: { + std::vector float_data; + TosaSerializationHandler::ConvertU8toF32(data, out_size, float_data); + value_attr = + mlir::DenseElementsAttr::get(output_type, llvm::ArrayRef(float_data)); + break; + } + case DType_INT8: { + std::vector int8_data; + TosaSerializationHandler::ConvertU8toI8(data, out_size, int8_data); + value_attr = + mlir::DenseElementsAttr::get(output_type, llvm::ArrayRef(int8_data)); + break; + } + case DType_INT16: { + std::vector int16_data; + TosaSerializationHandler::ConvertU8toI16(data, out_size, int16_data); + value_attr = + mlir::DenseElementsAttr::get(output_type, llvm::ArrayRef(int16_data)); + break; + } + case DType_INT32: { + std::vector int32_data; + TosaSerializationHandler::ConvertU8toI32(data, out_size, int32_data); + value_attr = + mlir::DenseElementsAttr::get(output_type, llvm::ArrayRef(int32_data)); + break; + } + case DType_INT48: { + std::vector int48_data; + TosaSerializationHandler::ConvertU8toI48(data, out_size, int48_data); + std::vector apint_data; + for (const auto v : int48_data) { + mlir::APInt apint_value(48, static_cast(v), + /* isSigned = */ false); + apint_data.push_back(apint_value); + } + value_attr = + mlir::DenseElementsAttr::get(output_type, llvm::ArrayRef(apint_data)); + break; + } + case DType_BOOL: { + std::vector bool_data; + TosaSerializationHandler::ConvertU8toBool(data, out_size, bool_data); + llvm::SmallVector bool_values(bool_data.begin(), bool_data.end()); + value_attr = mlir::DenseElementsAttr::get(output_type, bool_values); + break; + } + default: + llvm::errs() << "ERROR: " << get_string(op) + << " contains unsupported element type\n"; + return std::vector(); + } + mlir::Operation *mlir_op = + op_builder->create(loc, output_type, value_attr); + block->push_back(mlir_op); + return std::vector({mlir_op->getResult(0)}); +} + +template +std::vector +TosaMlirOperatorBuilder::BuildConvOp(TosaSerializationOperator *op) const { + mlir::Value input0_val = 
tensor_map->at(op->GetInputTensorNames()[0]); + mlir::Value input1_val = tensor_map->at(op->GetInputTensorNames()[1]); + mlir::Value input2_val = tensor_map->at(op->GetInputTensorNames()[2]); + mlir::RankedTensorType output_type = + tensor_type_map->at(op->GetOutputTensorNames()[0]); + + assert(op->GetAttributeType() == + Attribute_ConvAttribute); // double check attribute type + TosaConvAttribute *attr = + static_cast(op->GetAttribute()); + mlir::DenseI64ArrayAttr pad = BuildDenseI64ArrayAttr(op_builder, attr->pad()); + mlir::DenseI64ArrayAttr stride = + BuildDenseI64ArrayAttr(op_builder, attr->stride()); + mlir::DenseI64ArrayAttr dilation = + BuildDenseI64ArrayAttr(op_builder, attr->dilation()); + auto input_zp = attr->input_zp(); + auto weight_zp = attr->weight_zp(); + + // quantizationattr is required for quantized type, and not allowed for float + // type + mlir::Operation *mlir_op; + if (output_type.getElementType().isa()) { + assert(input_zp == 0 && weight_zp == 0); + mlir_op = + op_builder->create(loc, output_type, input0_val, input1_val, + input2_val, pad, stride, dilation); + } else { + auto quant = op_builder->getAttr( + input_zp, weight_zp); + mlir_op = + op_builder->create(loc, output_type, input0_val, input1_val, + input2_val, pad, stride, dilation, quant); + } + block->push_back(mlir_op); + return std::vector({mlir_op->getResult(0)}); +} + +#define BUILD_OP_CONV(MLIR_OP_NAME, SCHEMA_OP_NAME) \ + template <> \ + std::vector \ + TosaMlirOperatorBuilder::build( \ + TosaSerializationOperator * op) const { \ + return BuildConvOp(op); \ + } + +BUILD_OP_CONV(Conv2D, CONV2D) +BUILD_OP_CONV(Conv3D, CONV3D) +BUILD_OP_CONV(DepthwiseConv2D, DEPTHWISE_CONV2D) + +template <> +std::vector TosaMlirOperatorBuilder::build(TosaSerializationOperator* op) const +{ + mlir::Value input0_val = tensor_map->at(op->GetInputTensorNames()[0]); + mlir::Value input1_val = tensor_map->at(op->GetInputTensorNames()[1]); + mlir::Value input2_val = 
tensor_map->at(op->GetInputTensorNames()[2]); + mlir::RankedTensorType output_type = + tensor_type_map->at(op->GetOutputTensorNames()[0]); + + assert(op->GetAttributeType() == + Attribute_TransposeConvAttribute); // double check attribute type + TosaTransposeConvAttribute *attr = + static_cast(op->GetAttribute()); + mlir::DenseI64ArrayAttr out_pad = + BuildDenseI64ArrayAttr(op_builder, attr->out_pad()); + mlir::DenseI64ArrayAttr stride = + BuildDenseI64ArrayAttr(op_builder, attr->stride()); + mlir::DenseI64ArrayAttr output_shape = + BuildDenseI64ArrayAttr(op_builder, attr->output_shape()); + auto input_zp = attr->input_zp(); + auto weight_zp = attr->weight_zp(); + + // quantizationattr is required for quantized type, and not allowed for float + // type + mlir::Operation *mlir_op; + if (output_type.getElementType().isa()) { + assert(input_zp == 0 && weight_zp == 0); + mlir_op = op_builder->create( + loc, output_type, input0_val, input1_val, input2_val, out_pad, stride, + output_shape); + } else { + auto quant = op_builder->getAttr( + input_zp, weight_zp); + mlir_op = op_builder->create( + loc, output_type, input0_val, input1_val, input2_val, out_pad, stride, + output_shape, quant); + } + block->push_back(mlir_op); + return std::vector({mlir_op->getResult(0)}); +} + +template <> +std::vector TosaMlirOperatorBuilder::build(TosaSerializationOperator* op) const +{ + mlir::Value input0_val = tensor_map->at(op->GetInputTensorNames()[0]); + mlir::Value input1_val = tensor_map->at(op->GetInputTensorNames()[1]); + mlir::Value input2_val = tensor_map->at(op->GetInputTensorNames()[2]); + mlir::RankedTensorType output_type = + tensor_type_map->at(op->GetOutputTensorNames()[0]); + + assert(op->GetAttributeType() == + Attribute_FullyConnectedAttribute); // double check attribute type + TosaFullyConnectedAttribute *attr = + static_cast(op->GetAttribute()); + auto input_zp = attr->input_zp(); + auto weight_zp = attr->weight_zp(); + + // quantizationattr is required for quantized 
type, and not allowed for float + // type + mlir::Operation *mlir_op; + if (output_type.getElementType().isa()) { + assert(input_zp == 0 && weight_zp == 0); + mlir_op = op_builder->create( + loc, output_type, input0_val, input1_val, input2_val); + } else { + auto quant = op_builder->getAttr( + input_zp, weight_zp); + mlir_op = op_builder->create( + loc, output_type, input0_val, input1_val, input2_val, quant); + } + block->push_back(mlir_op); + return std::vector({mlir_op->getResult(0)}); +} + +template <> +std::vector TosaMlirOperatorBuilder::build(TosaSerializationOperator* op) const +{ + mlir::Value input0_val = tensor_map->at(op->GetInputTensorNames()[0]); + mlir::Value input1_val = tensor_map->at(op->GetInputTensorNames()[1]); + mlir::RankedTensorType output_type = + tensor_type_map->at(op->GetOutputTensorNames()[0]); + + assert(op->GetAttributeType() == + Attribute_MatMulAttribute); // double check attribute type + TosaMatMulAttribute *attr = + static_cast(op->GetAttribute()); + auto A_zp = attr->a_zp(); + auto B_zp = attr->b_zp(); + + mlir::Operation *mlir_op; + if (A_zp == 0 && B_zp == 0) { + mlir_op = op_builder->create(loc, output_type, + input0_val, input1_val); + } else { + auto quant = + op_builder->getAttr(A_zp, B_zp); + mlir_op = op_builder->create( + loc, output_type, input0_val, input1_val, quant); + } + block->push_back(mlir_op); + return std::vector({mlir_op->getResult(0)}); +} + +template <> +std::vector TosaMlirOperatorBuilder::build(TosaSerializationOperator* op) const +{ + mlir::Value input0_val = tensor_map->at(op->GetInputTensorNames()[0]); + mlir::Value input1_val = tensor_map->at(op->GetInputTensorNames()[1]); + mlir::Value input2_val = tensor_map->at(op->GetInputTensorNames()[2]); + mlir::RankedTensorType output_type = + tensor_type_map->at(op->GetOutputTensorNames()[0]); + + assert(op->GetAttributeType() == + Attribute_NONE); // double check that there is no attribute + + mlir::Operation *mlir_op = op_builder->create( + loc, output_type, 
input0_val, input1_val, input2_val); + block->push_back(mlir_op); + return std::vector({mlir_op->getResult(0)}); +} + +template <> +std::vector TosaMlirOperatorBuilder::build(TosaSerializationOperator* op) const +{ + mlir::Value input_val = tensor_map->at(op->GetInputTensorNames()[0]); + mlir::RankedTensorType output_type = + tensor_type_map->at(op->GetOutputTensorNames()[0]); + + assert(op->GetAttributeType() == + Attribute_ClampAttribute); // double check attribute type + TosaClampAttribute *attr = + static_cast(op->GetAttribute()); + + auto min_int = op_builder->getI64IntegerAttr(attr->min_int()); + auto max_int = op_builder->getI64IntegerAttr(attr->max_int()); + auto min_fp = op_builder->getF32FloatAttr(attr->min_fp()); + auto max_fp = op_builder->getF32FloatAttr(attr->max_fp()); + + mlir::Operation *mlir_op = op_builder->create( + loc, output_type, input_val, min_int, max_int, min_fp, max_fp); + block->push_back(mlir_op); + return std::vector({mlir_op->getResult(0)}); +} + +// ArgMax has single input, and single I64 axis attribute +BUILD_OP_REDUCTION(ArgMax, ARGMAX) + +template <> +std::vector TosaMlirOperatorBuilder::build(TosaSerializationOperator* op) const +{ + mlir::RankedTensorType output_type = + tensor_type_map->at(op->GetOutputTensorNames()[0]); + + llvm::SmallVector input_values; + for (auto &input_name : op->GetInputTensorNames()) { + mlir::Value input_val = tensor_map->at(input_name); + input_values.push_back(input_val); + } + + assert(op->GetAttributeType() == + Attribute_AxisAttribute); // double check attribute type + TosaAxisAttribute *attr = + static_cast(op->GetAttribute()); + auto axis = op_builder->getI64IntegerAttr(attr->axis()); + + mlir::Operation *mlir_op = op_builder->create( + loc, output_type, input_values, axis); + block->push_back(mlir_op); + return std::vector({mlir_op->getResult(0)}); +} + +template <> +std::vector TosaMlirOperatorBuilder::build(TosaSerializationOperator* op) const +{ + mlir::Value input_val = 
tensor_map->at(op->GetInputTensorNames()[0]); + mlir::RankedTensorType output_type = + tensor_type_map->at(op->GetOutputTensorNames()[0]); + + assert(op->GetAttributeType() == + Attribute_NegateAttribute); // double check attribute type + TosaNegateAttribute *attr = + static_cast(op->GetAttribute()); + + auto input_zp = attr->input1_zp(); + auto output_zp = attr->output_zp(); + + mlir::Operation *mlir_op; + if (input_zp == 0 && output_zp == 0) { + mlir_op = + op_builder->create(loc, output_type, input_val); + } else { + auto quant = op_builder->getAttr( + input_zp, output_zp); + mlir_op = op_builder->create(loc, output_type, + input_val, quant); + } + + block->push_back(mlir_op); + return std::vector({mlir_op->getResult(0)}); +} + +template <> +std::vector TosaMlirOperatorBuilder::build(TosaSerializationOperator* op) const +{ + mlir::Value input_val = tensor_map->at(op->GetInputTensorNames()[0]); + mlir::RankedTensorType output_type = + tensor_type_map->at(op->GetOutputTensorNames()[0]); + + assert(op->GetAttributeType() == + Attribute_ReshapeAttribute); // double check attribute type + TosaReshapeAttribute *attr = + static_cast(op->GetAttribute()); + + mlir::DenseI64ArrayAttr new_shape = + BuildDenseI64ArrayAttr(op_builder, attr->new_shape()); + + mlir::Operation *mlir_op = op_builder->create( + loc, output_type, input_val, new_shape); + block->push_back(mlir_op); + return std::vector({mlir_op->getResult(0)}); +} + +template <> +std::vector TosaMlirOperatorBuilder::build(TosaSerializationOperator* op) const +{ + mlir::Value input_val = tensor_map->at(op->GetInputTensorNames()[0]); + mlir::RankedTensorType output_type = + tensor_type_map->at(op->GetOutputTensorNames()[0]); + + assert(op->GetAttributeType() == + Attribute_PadAttribute); // double check attribute type + TosaPadAttribute *attr = static_cast(op->GetAttribute()); + + auto padding_attr = BuildDenseI32ElementsAttr(op_builder, attr->padding()); + auto padding_type = 
mlir::RankedTensorType::get({attr->padding().size()}, + op_builder->getI32Type()); + mlir::Operation *mlir_const_op = + op_builder->create(loc, padding_type, padding_attr); + block->push_back(mlir_const_op); + auto padding_value = mlir_const_op->getResult(0); + auto pad_const_int = attr->pad_const_int(); + auto pad_const_fp = attr->pad_const_fp(); + // todo: int input_zp = attr->pad_input_zp(); + /* When both pad constants are zero the op is created with only (input, padding). Otherwise a one-element constant tensor is created for the pad value: the integer constant (I32) wins if it is non-zero, else the float constant (F32) is used. */ + mlir::Operation *mlir_op; + if (pad_const_int == 0 && pad_const_fp == 0.0f) { + // no pad_const input + mlir_op = op_builder->create(loc, output_type, input_val, + padding_value); + } else { + // create a const value for pad_const input + mlir::Value pad_const_value; + if (pad_const_int != 0) { + auto pad_const_int_type = + mlir::RankedTensorType::get({1}, op_builder->getI32Type()); + auto pad_const_int_attr = + mlir::DenseElementsAttr::get(pad_const_int_type, {pad_const_int}); + mlir::Operation *pad_const_int_op = + op_builder->create(loc, pad_const_int_type, + pad_const_int_attr); + block->push_back(pad_const_int_op); + pad_const_value = pad_const_int_op->getResult(0); + } else if (pad_const_fp != 0) { + auto pad_const_fp_type = + mlir::RankedTensorType::get({1}, op_builder->getF32Type()); + auto pad_const_fp_attr = + mlir::DenseElementsAttr::get(pad_const_fp_type, {pad_const_fp}); + mlir::Operation *pad_const_fp_op = + op_builder->create(loc, pad_const_fp_type, + pad_const_fp_attr); + block->push_back(pad_const_fp_op); + pad_const_value = pad_const_fp_op->getResult(0); + } + mlir_op = op_builder->create( + loc, output_type, input_val, padding_value, pad_const_value); + } + block->push_back(mlir_op); + return std::vector({mlir_op->getResult(0)}); +} + /* Builder specialization guarded by Attribute_TransposeAttribute: the permutation is materialised as an I32 constant op that becomes the second operand. */ +template <> +std::vector TosaMlirOperatorBuilder::build(TosaSerializationOperator* op) const +{ + mlir::Value input_val = tensor_map->at(op->GetInputTensorNames()[0]); + mlir::RankedTensorType output_type = + tensor_type_map->at(op->GetOutputTensorNames()[0]); + + assert(op->GetAttributeType() == + 
Attribute_TransposeAttribute); // double check attribute type + TosaTransposeAttribute *attr = + static_cast(op->GetAttribute()); + // make a constant op from attr->perms values, of type: { shape = { perms.size + // }, element_type = I32 } + const auto perms_values = attr->perms(); + auto const_type = mlir::RankedTensorType::get({perms_values.size()}, + op_builder->getI32Type()); + mlir::DenseElementsAttr const_attr = + BuildDenseI32ElementsAttr(op_builder, perms_values); + mlir::Operation *mlir_const_op = + op_builder->create(loc, const_type, const_attr); + auto perms_val = mlir_const_op->getResult(0); + + mlir::Operation *mlir_op = op_builder->create( + loc, output_type, input_val, perms_val); + /* The constant op is appended to the block before the op that consumes its result. */ + block->push_back(mlir_const_op); + block->push_back(mlir_op); + return std::vector({mlir_op->getResult(0)}); +} + /* Builder specialization guarded by Attribute_SliceAttribute: single input; start and size are converted with BuildDenseI64ArrayAttr and passed to the created op. */ +template <> +std::vector TosaMlirOperatorBuilder::build(TosaSerializationOperator* op) const +{ + mlir::Value input_val = tensor_map->at(op->GetInputTensorNames()[0]); + mlir::RankedTensorType output_type = + tensor_type_map->at(op->GetOutputTensorNames()[0]); + + assert(op->GetAttributeType() == + Attribute_SliceAttribute); // double check attribute type + TosaSliceAttribute *attr = + static_cast(op->GetAttribute()); + + mlir::DenseI64ArrayAttr start = + BuildDenseI64ArrayAttr(op_builder, attr->start()); + mlir::DenseI64ArrayAttr size = + BuildDenseI64ArrayAttr(op_builder, attr->size()); + + mlir::Operation *mlir_op = op_builder->create( + loc, output_type, input_val, start, size); + block->push_back(mlir_op); + return std::vector({mlir_op->getResult(0)}); +} + /* Builder specialization guarded by Attribute_TileAttribute: single input; multiples is converted with BuildDenseI64ArrayAttr. */ +template <> +std::vector TosaMlirOperatorBuilder::build(TosaSerializationOperator* op) const +{ + mlir::Value input_val = tensor_map->at(op->GetInputTensorNames()[0]); + mlir::RankedTensorType output_type = + tensor_type_map->at(op->GetOutputTensorNames()[0]); + + assert(op->GetAttributeType() == + Attribute_TileAttribute); // double check attribute type + TosaTileAttribute *attr = + 
static_cast(op->GetAttribute()); + + mlir::DenseI64ArrayAttr multiples = + BuildDenseI64ArrayAttr(op_builder, attr->multiples()); + + mlir::Operation *mlir_op = op_builder->create( + loc, output_type, input_val, multiples); + block->push_back(mlir_op); + return std::vector({mlir_op->getResult(0)}); +} + +// Gather is a binary op +BUILD_OP_ELEMENTWISE_BINARY(Gather, GATHER) + /* Builder specialization for a three-input op that must carry no attribute (Attribute_NONE); presumably scatter, given its position after Gather - the op type is not visible here. All three inputs are forwarded in order. */ +template <> +std::vector TosaMlirOperatorBuilder::build(TosaSerializationOperator* op) const +{ + mlir::Value input0_val = tensor_map->at(op->GetInputTensorNames()[0]); + mlir::Value input1_val = tensor_map->at(op->GetInputTensorNames()[1]); + mlir::Value input2_val = tensor_map->at(op->GetInputTensorNames()[2]); + mlir::RankedTensorType output_type = + tensor_type_map->at(op->GetOutputTensorNames()[0]); + assert(op->GetAttributeType() == + Attribute_NONE); // double check that there is no attribute + + mlir::Operation *mlir_op = op_builder->create( + loc, output_type, input0_val, input1_val, input2_val); + block->push_back(mlir_op); + return std::vector({mlir_op->getResult(0)}); +} + /* Builder specialization guarded by Attribute_ResizeAttribute: scale, offset and border are converted with BuildDenseI64ArrayAttr; the mode enum is converted to a string attribute through ResizeEnum2Str. */ +template <> +std::vector TosaMlirOperatorBuilder::build(TosaSerializationOperator* op) const +{ + mlir::Value input_val = tensor_map->at(op->GetInputTensorNames()[0]); + mlir::RankedTensorType output_type = + tensor_type_map->at(op->GetOutputTensorNames()[0]); + + assert(op->GetAttributeType() == + Attribute_ResizeAttribute); // double check attribute type + TosaResizeAttribute *attr = + static_cast(op->GetAttribute()); + + mlir::DenseI64ArrayAttr scale = + BuildDenseI64ArrayAttr(op_builder, attr->scale()); + mlir::DenseI64ArrayAttr offset = + BuildDenseI64ArrayAttr(op_builder, attr->offset()); + mlir::DenseI64ArrayAttr border = + BuildDenseI64ArrayAttr(op_builder, attr->border()); + auto mode = op_builder->getStringAttr(ResizeEnum2Str(attr->mode())); + + mlir::Operation *mlir_op = op_builder->create( + loc, output_type, input_val, scale, offset, border, mode); + block->push_back(mlir_op); + return 
std::vector({mlir_op->getResult(0)}); +} + +// Reverse has single input, and single I64 axis attribute +BUILD_OP_REDUCTION(Reverse, REVERSE) + /* Builder specialization guarded by Attribute_MulAttribute: two inputs; shift is passed as an I32 integer attribute. */ +template <> +std::vector TosaMlirOperatorBuilder::build(TosaSerializationOperator* op) const +{ + mlir::Value input0_val = tensor_map->at(op->GetInputTensorNames()[0]); + mlir::Value input1_val = tensor_map->at(op->GetInputTensorNames()[1]); + mlir::RankedTensorType output_type = + tensor_type_map->at(op->GetOutputTensorNames()[0]); + + assert(op->GetAttributeType() == + Attribute_MulAttribute); // double check attribute type + TosaMulAttribute *attr = static_cast(op->GetAttribute()); + + auto shift = op_builder->getI32IntegerAttr(attr->shift()); + + mlir::Operation *mlir_op = op_builder->create( + loc, output_type, input0_val, input1_val, shift); + block->push_back(mlir_op); + return std::vector({mlir_op->getResult(0)}); +} + /* Builder specialization guarded by Attribute_ArithmeticRightShiftAttribute: two inputs; round is passed as a bool attribute. */ +template <> +std::vector TosaMlirOperatorBuilder::build(TosaSerializationOperator* op) const +{ + mlir::Value input0_val = tensor_map->at(op->GetInputTensorNames()[0]); + mlir::Value input1_val = tensor_map->at(op->GetInputTensorNames()[1]); + mlir::RankedTensorType output_type = + tensor_type_map->at(op->GetOutputTensorNames()[0]); + + assert( + op->GetAttributeType() == + Attribute_ArithmeticRightShiftAttribute); // double check attribute type + TosaArithmeticRightShiftAttribute *attr = + static_cast(op->GetAttribute()); + + auto round = op_builder->getBoolAttr(attr->round()); + + mlir::Operation *mlir_op = + op_builder->create( + loc, output_type, input0_val, input1_val, round); + block->push_back(mlir_op); + return std::vector({mlir_op->getResult(0)}); +} + /* Builder specialization guarded by Attribute_TableAttribute: the table values are materialised as an I16 constant op that becomes the second operand. */ +template <> +std::vector TosaMlirOperatorBuilder::build(TosaSerializationOperator* op) const +{ + mlir::Value input_val = tensor_map->at(op->GetInputTensorNames()[0]); + mlir::RankedTensorType output_type = + tensor_type_map->at(op->GetOutputTensorNames()[0]); + + assert(op->GetAttributeType() == + Attribute_TableAttribute); // double check 
attribute type + TosaTableAttribute *attr = + static_cast(op->GetAttribute()); + + // create a const op for table value attribute + const auto table_values = attr->table(); + auto const_type = mlir::RankedTensorType::get({table_values.size()}, + op_builder->getI16Type()); + mlir::DenseElementsAttr const_attr = + BuildDenseI16ElementsAttr(op_builder, table_values); + mlir::Operation *mlir_const_op = + op_builder->create(loc, const_type, const_attr); + auto table_value = mlir_const_op->getResult(0); + + mlir::Operation *mlir_op = op_builder->create( + loc, output_type, input_val, table_value); + block->push_back(mlir_const_op); + block->push_back(mlir_op); + return std::vector({mlir_op->getResult(0)}); +} + /* Builder specialization guarded by Attribute_RescaleAttribute: zero points become I32 integer attributes, multiplier and shift are converted with BuildDenseI32ArrayAttr, and scale32 / double_round / per_channel become bool attributes. */ +template <> +std::vector TosaMlirOperatorBuilder::build(TosaSerializationOperator* op) const +{ + mlir::Value input_val = tensor_map->at(op->GetInputTensorNames()[0]); + mlir::RankedTensorType output_type = + tensor_type_map->at(op->GetOutputTensorNames()[0]); + + assert(op->GetAttributeType() == + Attribute_RescaleAttribute); // double check attribute type + TosaRescaleAttribute *attr = + static_cast(op->GetAttribute()); + + auto input_zp = op_builder->getI32IntegerAttr(attr->input_zp()); + auto output_zp = op_builder->getI32IntegerAttr(attr->output_zp()); + + auto multiplier = BuildDenseI32ArrayAttr(op_builder, attr->multiplier()); + auto shift = BuildDenseI32ArrayAttr(op_builder, attr->shift()); + + auto scale32 = op_builder->getBoolAttr(attr->scale32()); + auto double_round = op_builder->getBoolAttr(attr->double_round()); + auto per_channel = op_builder->getBoolAttr(attr->per_channel()); + + mlir::Operation *mlir_op = op_builder->create( + loc, output_type, input_val, input_zp, output_zp, multiplier, shift, + scale32, double_round, per_channel); + block->push_back(mlir_op); + return std::vector({mlir_op->getResult(0)}); +} + /* Builder specialization guarded by Attribute_CustomAttribute (single input, single result). */ +template <> +std::vector TosaMlirOperatorBuilder::build(TosaSerializationOperator* op) const +{ + mlir::Value input_val = 
tensor_map->at(op->GetInputTensorNames()[0]); + mlir::RankedTensorType output_type = + tensor_type_map->at(op->GetOutputTensorNames()[0]); + + assert(op->GetAttributeType() == + Attribute_CustomAttribute); // double check attribute type + TosaCustomAttribute *attr = + static_cast(op->GetAttribute()); + + auto identifier = op_builder->getStringAttr(attr->identifier()); + auto config = op_builder->getStringAttr(attr->config()); + /* Copy the implementation bytes verbatim. The string is sized to exactly the payload: an extra slot would leave a trailing NUL byte inside the string attribute. The bytes are bound to a named reference first so begin()/end() refer to the same container even if the accessor returns by value. */ + const auto &impl_bytes = attr->implementation_attrs(); + std::string impl_str(impl_bytes.begin(), impl_bytes.end()); + auto impl = op_builder->getStringAttr(impl_str); + + mlir::Operation *mlir_op = op_builder->create( + loc, output_type, identifier, config, impl, input_val); + block->push_back(mlir_op); + return std::vector({mlir_op->getResult(0)}); +} + +// TosaSerializationBasicBlock +// template <> +// std::vector +// TosaMlirOperatorBuilder::build(TosaSerializationOperator* op) const +//{ +// // todo: mlir::tosa::IfOp +// return {}; +//} + +// TosaSerializationBasicBlock +// template <> +// std::vector +// TosaMlirOperatorBuilder::build(TosaSerializationOperator* op) +// const +//{ +// // todo: mlir::tosa::WhileOp +// return {}; +//} + /* Builder specialization for a single-input op with two results and no attribute (Attribute_NONE): both result types are looked up by output tensor name and both results are returned, in order. */ +template <> +std::vector +TosaMlirOperatorBuilder::build(TosaSerializationOperator *op) const { + mlir::Value input_val = tensor_map->at(op->GetInputTensorNames()[0]); + mlir::RankedTensorType output0_type = + tensor_type_map->at(op->GetOutputTensorNames()[0]); + mlir::RankedTensorType output1_type = + tensor_type_map->at(op->GetOutputTensorNames()[1]); + assert(op->GetAttributeType() == + Attribute_NONE); // double check that there is no attribute + + mlir::Operation *mlir_op = op_builder->create( + loc, output0_type, output1_type, input_val); + block->push_back(mlir_op); + return std::vector( + {mlir_op->getResult(0), mlir_op->getResult(1)}); +} + /* Builds every serialized block of one serialized region into an mlir::Region, using a shared OpBuilder and Location. */ +class TosaMlirRegionBuilder { +public: + TosaMlirRegionBuilder(TosaSerializationRegion* _ser_region, 
+ TosaSerializationHandler* _tsh, + mlir::Region* _region, + mlir::OpBuilder* _op_builder, + mlir::Location _loc) + : region(_region), ser_region(_ser_region), tsh(_tsh), op_builder(_op_builder), loc(_loc) {} + /* The initializer list above follows the member declaration order below, which is the order members are actually initialized in (avoids -Wreorder). */ + mlir::LogicalResult BuildAllBlocksInRegion(std::vector& return_values); + + mlir::OpBuilder* GetOpBuilder() { return op_builder; } + mlir::Location GetLocation() { return loc; } + +private: + mlir::Region* region; + TosaSerializationRegion* ser_region; + TosaSerializationHandler* tsh; + mlir::OpBuilder* op_builder; + mlir::Location loc; + std::vector block_builders; +}; + /* Builds all operators of one serialized basic block into an mlir::Block. It borrows the OpBuilder and Location from its region builder and owns the name-to-value and name-to-type maps used while building. */ +class TosaMlirBlockBuilder { +public: + TosaMlirBlockBuilder(TosaSerializationBasicBlock* _ser_block, + TosaMlirRegionBuilder* _region_builder, + mlir::Block* _block) + : ser_block(_ser_block), region_builder(_region_builder), block(_block) {} + + + mlir::LogicalResult + BuildAllOpsInBlock(std::vector& return_values); + + mlir::OpBuilder* GetOpBuilder() { return region_builder->GetOpBuilder(); } + mlir::Location GetLocation() { return region_builder->GetLocation(); } + +private: + TosaSerializationBasicBlock* ser_block; + TosaMlirRegionBuilder* region_builder; + mlir::Block* block; + std::unordered_map tensor_map; + std::unordered_map tensor_type_map; +}; + /* Clears the target block, then rebuilds it from the serialized block: registers every tensor type, schedules operators once all their inputs exist, and appends a terminator returning the block outputs (also appended to return_values). Returns failure on any inconsistency. */ +mlir::LogicalResult TosaMlirBlockBuilder::BuildAllOpsInBlock( + std::vector &return_values) { + block->clear(); + auto loc = GetLocation(); + auto op_builder = GetOpBuilder(); + + std::unordered_map> consumer_map; + std::unordered_map tensor_built; + std::unordered_map operator_built; + std::queue operator_queue; + + TosaMlirOperatorBuilder tosa_op_builder(op_builder, ser_block, block, loc, + &tensor_map, &tensor_type_map); + + for (auto ts : ser_block->GetTensors()) { + mlir::RankedTensorType type; + if (BuildTensorType(op_builder, ts, type).failed()) { + return mlir::failure(); + } + const auto& ts_name = ts->GetName(); + tensor_type_map[ts_name] = type; + tensor_built[ts_name] = false; + } + + for (auto op : 
ser_block->GetOperators()) { + operator_built[op] = false; + for (auto ts_name : op->GetInputTensorNames()) { + consumer_map[ts_name].push_back(op); + } + } + + // Update operator_queue if a consumer of tensor_name has all of its inputs already built + auto queue_ready_consumers = [&](const std::string tensor_name) { + for (auto consumer_op : consumer_map[tensor_name]) { + // Sanity check operator hasn't been built + if (operator_built[consumer_op]) { + llvm::errs() << "ERROR: " << tosa_op_builder.get_string(consumer_op) + << " is already built before its input is built\n"; + assert(0); + } + bool all_inputs_ready = true; + for (const auto& input_name : consumer_op->GetInputTensorNames()) { + if (!tensor_built[input_name]) { + all_inputs_ready = false; + break; + } + } + if (all_inputs_ready) { + operator_queue.push(consumer_op); + } + } + }; + + // Initialize tensor_map/tensor_built/operator_queue based on block input arguments + for (const std::string& block_input_name : ser_block->GetInputs()) { + auto type = tensor_type_map[block_input_name]; + auto input_value = block->addArgument(type, loc); + if (tensor_map.count(block_input_name)) { + llvm::errs() << "ERROR: block input tensor " << block_input_name << " already exists\n"; + return mlir::failure(); + } + tensor_built[block_input_name] = true; + tensor_map[block_input_name] = input_value; + queue_ready_consumers(block_input_name); + } + + // add all operators with 0 inputs (e.g., constant operators) to + // operator_queue + for (auto op : ser_block->GetOperators()) { + if (op->GetInputTensorNames().empty()) { + operator_queue.push(op); + } + } + + while (!operator_queue.empty()) { + TosaSerializationOperator* op = operator_queue.front(); + operator_queue.pop(); + + // skip if operator has been built + if (operator_built[op]) { + // this happens when same input appears twice or more in operator, eg, concat(%0, %0) + continue; + } + operator_built[op] = true; + + std::vector output_values; + if (false) { + } 
+#define DEF_SCHEMA_OPERATOR(SCHEMA_OP_NAME) \ + else if (op->GetOp() == Op_##SCHEMA_OP_NAME) { \ + output_values = tosa_op_builder.build(op); \ + } +#include "schema_operator.def" +#undef DEF_SCHEMA_OPERATOR + else { + llvm::errs() << "ERROR: unsupported opcode=" << EnumNamesOp()[op->GetOp()] << "\n"; + return mlir::failure(); + } + /* The include above expands to one else-if branch per entry of schema_operator.def, dispatching on the serialized opcode; any opcode without an entry lands in the error branch. */ + // Sanity check if number of built mlir::Value is expected + if (op->GetOutputTensorNames().size() != output_values.size()) { + llvm::errs() << "ERROR: number of built mlir::Value is not matching number of operator output tensor\n"; + return mlir::failure(); + } + + for (size_t i = 0; i < output_values.size(); i++) { + // Sanity check tensor hasn't been built + std::string op_output_name = op->GetOutputTensorNames()[i]; + if (tensor_built[op_output_name]) { + llvm::errs() << "ERROR: tensor " << op_output_name << " is already built\n"; + return mlir::failure(); + } + tensor_map[op_output_name] = output_values[i]; + tensor_built[op_output_name] = true; + queue_ready_consumers(op_output_name); + } + } + + // Construct return values + std::vector return_operands; + for (const auto& output_name : ser_block->GetOutputs()) { + // Sanity check if terminator mlir::Value is built + if (!tensor_built[output_name]) { + llvm::errs() << "ERROR: terminator mlir::Value " << output_name << " is not built\n"; + return mlir::failure(); + } + mlir::Value output_value = tensor_map.at(output_name); + return_operands.push_back(output_value); + return_values.push_back(output_value); + } + auto terminator_op = op_builder->create(loc, return_operands); + block->push_back(terminator_op); + + // need topological sorting? + + return mlir::success(); +} + /* Creates one new mlir::Block in the region per serialized block and builds it with a TosaMlirBlockBuilder; every block appends its outputs to the same return_values vector. Stops at the first block that fails to build. */ +mlir::LogicalResult TosaMlirRegionBuilder::BuildAllBlocksInRegion( + std::vector& return_values) { + for (auto& ser_block : ser_region->GetBlocks()) { + auto& block = region->emplaceBlock(); + TosaMlirBlockBuilder block_builder(ser_block, this, &block); + // Region Builders need access to block builders (?) 
+ block_builders.push_back(&block_builder); + + if (block_builder.BuildAllOpsInBlock(return_values).failed()) { + return mlir::failure(); + } + + if (return_values.empty()) { + llvm::errs() << "Warning: graph doesn't have return values\n"; + } + } + return mlir::success(); +} + +mlir::LogicalResult buildTosaMlir(mlir::func::FuncOp& func, + mlir::MLIRContext& context, + tosa::TosaSerializationHandler& tsh) { + + mlir::Region* main_region = func.getCallableRegion(); + if (!main_region) { + llvm::errs() << "Invalid MLIR: doesn't have valid \"main\" region\n"; + return mlir::failure(); + } + + if (tsh.GetRegions().size() != 1) { + llvm::errs() << "Internal Error: TosaSerializationHandler's region list " + "must contain exactly one region\n"; + return mlir::failure(); + } + + TosaSerializationRegion* ser_main_region = tsh.GetRegions().front(); + + auto loc = func.getLoc(); + std::vector main_returns; + + main_region->takeBody(*main_region); // empty old func body + auto op_builder = mlir::OpBuilder(func.getBody()); + + TosaMlirRegionBuilder region_builder(ser_main_region, &tsh, main_region, &op_builder, loc); + if (region_builder.BuildAllBlocksInRegion(main_returns).failed()) { + return mlir::failure(); + } + + if (main_returns.empty()) { + llvm::errs() << "Warning: graph doesn't have return values\n"; + } + + return mlir::success(); +} + +// Load Tosa Schema into TosaSerializationHandler, required for JSON save/load +mlir::LogicalResult loadTosaSchema(tosa::TosaSerializationHandler& tsh) { + const char *tosa_schema = tosa_deserialize_schema.c_str(); + + if (!tosa_schema) { + llvm::errs() << "Flatbuffer schema not defined\n"; + return mlir::failure(); + } + + if (tsh.LoadFileSchema(tosa_schema)) { + llvm::errs() << "Error loading tosa schema file: " << tosa_schema << "\n"; + return mlir::failure(); + } + return mlir::success(); +} + +namespace mlir { + +namespace tosa { + +namespace { + +class TosaDeserialize : public TosaDeserializationPassBase { +public: + void 
runOnOperation() final { + TosaSerializationHandler tsh; + if (tsh.LoadFileTosaFlatbuffer(tosa_deserialize_filename.c_str())) { + llvm::errs() << "Fail to load TOSA file " << tosa_deserialize_filename << "\n"; + return signalPassFailure(); + } + + auto function = getOperation(); + auto& context = getContext(); + + if (buildTosaMlir(function, context, tsh).failed()) { + llvm::errs() << "Failed to deserialize flatbuffer " << tosa_deserialize_filename << "\n"; + return signalPassFailure(); + } + } +}; + /* JSON variant of the pass: the flatbuffer schema has to be loaded into the handler before the JSON file can be parsed; the function is then rebuilt the same way as in the binary flatbuffer pass. */ +class TosaDeserializeJSON + : public TosaDeserializationJSONPassBase { +public: + void runOnOperation() final { + TosaSerializationHandler tsh; + + // must load tosa schema before loading json file + if (loadTosaSchema(tsh).failed()) { + return signalPassFailure(); + } + + if (tsh.LoadFileJson(tosa_deserialize_filename.c_str())) { + llvm::errs() << "Fail to load TOSA JSON file " << tosa_deserialize_filename << "\n"; + return signalPassFailure(); + } + + auto function = getOperation(); + auto& context = getContext(); + + if (buildTosaMlir(function, context, tsh).failed()) { + llvm::errs() << "Failed to deserialize flatbuffer " << tosa_deserialize_filename << "\n"; + return signalPassFailure(); + } + } +}; + +} // anonymous namespace + +// Creates an instance of the TOSA flatbuffer deserialization pass +std::unique_ptr createTosaDeserializePass() { + return std::make_unique(); +} + /* Creates an instance of the TOSA JSON deserialization pass. */ +std::unique_ptr createTosaDeserializeJSONPass() { + return std::make_unique(); +} + /* Static registration objects: each registers its pass through the corresponding factory function above. */ +static PassRegistration passDeserialize([] { + return createTosaDeserializePass(); +}); + +static PassRegistration passDeserializeJSON([] { + return createTosaDeserializeJSONPass(); +}); + +} // namespace tosa +} // namespace mlir -- cgit v1.2.1