From 80a022fd103b26a03a04e0565c4d263f73d950b8 Mon Sep 17 00:00:00 2001 From: Kevin Cheng Date: Mon, 15 Nov 2021 17:07:37 -0800 Subject: First commit of tosa serialize passes Signed-off-by: Kevin Cheng Change-Id: I1551017706f6e8af604792f48cdeb49b4da7ef0d --- src/TosaSerialize.cpp | 1764 +++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 1764 insertions(+) create mode 100644 src/TosaSerialize.cpp (limited to 'src/TosaSerialize.cpp') diff --git a/src/TosaSerialize.cpp b/src/TosaSerialize.cpp new file mode 100644 index 0000000..5699ffe --- /dev/null +++ b/src/TosaSerialize.cpp @@ -0,0 +1,1764 @@ + +// Copyright (c) 2020-2021, ARM Limited. +// +// Licensed under the Apache License, Version 2.0 with LLVM Exceptions +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// https://llvm.org/LICENSE.txt +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// TOSA flatbuffer generation + +#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/Dialect/Tosa/IR/TosaOps.h" // from @llvm-project +#include "mlir/IR/Matchers.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "tosa_serialization_handler.h" +#include +#include + +// The namespace might be confusing here. We have mlir::tosa:: defined in MLIR +// and tosa:: defined in serialization library +// TODO: align the namespace +using namespace tosa; + +namespace cl = llvm::cl; + +llvm::cl::opt tosa_flatbuffer_filename( + "tosa-flatbuffer-filename", llvm::cl::desc(""), + llvm::cl::init("tosa_dump.tosa"), llvm::cl::value_desc("filename")); + +llvm::cl::opt tosa_flatbuffer_schema( + "tosa-flatbuffer-schema", llvm::cl::desc(""), + llvm::cl::init(""), llvm::cl::value_desc("filename")); + +// Specialize mlir::Value for std::hash and std::equal_to to be able to +// build std::unordered_map +namespace std { + +template <> struct hash { + std::size_t operator()(const mlir::Value &val) const { + return static_cast(mlir::hash_value(val)); + } +}; + +template <> struct equal_to { + bool operator()(const mlir::Value &lhs, const mlir::Value &rhs) const { + return (lhs == rhs); + } +}; + +} // namespace std + +ResizeMode ResizeModeStr2Enum(const std::string &mode_str) { + if (mode_str == "NEAREST_NEIGHBOR") + return ResizeMode_NEAREST; + else if (mode_str == "BILINEAR") + return ResizeMode_BILINEAR; + else + return ResizeMode_UNKNOWN; +} + +DType Type2DType(mlir::Type element_type) { + if (element_type.isF64() || element_type.isF32() || element_type.isF16() || + element_type.isBF16()) { + return DType_FLOAT; + } else if (element_type.isUnsignedInteger(8)) { + return DType_UINT8; + } else if (element_type.isInteger(4)) { + return DType_INT8; + } else if (element_type.isInteger(8)) { + return DType_INT8; + } else if (element_type.isInteger(16)) { + return DType_INT16; + } else if (element_type.isInteger(32)) { + return DType_INT32; + } else if (element_type.isInteger(48)) { + return DType_INT48; + } + // boolean in MLIR treated as integer with bitwidth 1 + else if (element_type.isInteger(1)) { + return DType_BOOL; + } + return DType_UNKNOWN; +} + +int GetQuantizedParameter(mlir::Type type, std::vector &scale, + std::vector &zeropoint, + int32_t &quantized_dimension, int64_t &quant_min, + int64_t &quant_max) { + if (auto qtype = type.dyn_cast()) { + scale.push_back(qtype.getScale()); + zeropoint.push_back(qtype.getZeroPoint()); + quantized_dimension = 0; + + quant_min = qtype.getStorageTypeMin(); + quant_max = qtype.getStorageTypeMax(); + } else if (auto qtype = + type.dyn_cast()) { + scale.assign(qtype.getScales().begin(), qtype.getScales().end()); + zeropoint.assign(qtype.getZeroPoints().begin(), + qtype.getZeroPoints().end()); + quantized_dimension = qtype.getQuantizedDimension(); + + quant_min = qtype.getStorageTypeMin(); + quant_max = qtype.getStorageTypeMax(); + } else { + return 1; + } + + return 0; +} + +TosaQuantInfoBase * +GetUnaryQuantInfo(mlir::tosa::UnaryOpQuantizationAttr quant_info) { + int32_t input_zp = quant_info.input_zp().getInt(); + int32_t output_zp = quant_info.output_zp().getInt(); + + TosaQuantInfoBase *qinfo = new TosaUnaryQuantInfo(input_zp, output_zp); + + return qinfo; +} + +TosaQuantInfoBase * +GetConvQuantInfo(mlir::tosa::ConvOpQuantizationAttr quant_info) { + int32_t input_zp = quant_info.input_zp().getInt(); + int32_t weight_zp = quant_info.weight_zp().getInt(); + + TosaQuantInfoBase *qinfo = new TosaConvQuantInfo(input_zp, weight_zp); + + return qinfo; +} + +TosaQuantInfoBase * +GetPadQuantInfo(mlir::tosa::PadOpQuantizationAttr quant_info) { + int32_t input_zp = quant_info.input_zp().getInt(); + + TosaQuantInfoBase *qinfo = new TosaPadQuantInfo(input_zp); + + return qinfo; +} + +TosaQuantInfoBase * +GetMatMulQuantInfo(mlir::tosa::MatMulOpQuantizationAttr quant_info) { + int32_t a_zp = quant_info.a_zp().getInt(); + int32_t b_zp = quant_info.b_zp().getInt(); + + TosaQuantInfoBase *qinfo = new TosaMatMulQuantInfo(a_zp, b_zp); + + return qinfo; +} + +class TosaSerializationBlockBuilder; + +class TosaSerializationOperatorBuilder { +public: + TosaSerializationOperatorBuilder( + TosaSerializationBlockBuilder *_block_builder) + : block_builder(_block_builder) {} + template + TosaSerializationOperator *build(mlir::Operation &op) const; + +private: + std::string GetTensorName(mlir::Value val) const; + TosaSerializationOperator *BuildPoolOpFromMlirOp(mlir::Operation &op, + Op opcode) const; + TosaSerializationOperator *BuildEwiseBinaryOpFromMlirOp(mlir::Operation &op, + Op opcode) const; + TosaSerializationOperator *BuildEwiseUnaryOpFromMlirOp(mlir::Operation &op, + Op opcode) const; + TosaSerializationOperator *BuildReductionOpFromMlirOp(mlir::Operation &op, + Op opcode) const; + TosaSerializationBlockBuilder *block_builder; +}; + +// This builder assumes each region only has only one block +class TosaSerializationBlockBuilder { +public: + friend class TosaSerializationOperatorBuilder; + TosaSerializationBlockBuilder(TosaSerializationBasicBlock *_block, + TosaSerializationHandler *_tsh, + mlir::Region *_region) + : block(_block), tsh(_tsh), region(_region) {} + + mlir::LogicalResult + BuildAllOpsInRegion(std::vector &return_values); + TosaSerializationBasicBlock *GetBlock() { return block; } + TosaSerializationHandler *GetTsh() { return tsh; } + +private: + TosaSerializationOperator *BuildTosaSerializationOperator( + const TosaSerializationOperatorBuilder &op_builder, mlir::Operation &op); + TosaSerializationTensor * + BuildTosaSerializationTensor(mlir::Value val, const std::string &name); + + TosaSerializationBasicBlock *block; + TosaSerializationHandler *tsh; + mlir::Region *region; + std::unordered_map tensor_map; +}; + +std::string +TosaSerializationOperatorBuilder::GetTensorName(mlir::Value val) const { + if (block_builder->tensor_map.find(val) == block_builder->tensor_map.end()) { + llvm::errs() << "ERROR: Failed to get mlir::Value from tensor_map"; + assert(0); + } + return block_builder->tensor_map[val]; +} + +// Main template to catch unimplemented translation. +template +TosaSerializationOperator * +TosaSerializationOperatorBuilder::build(mlir::Operation &op) const { + llvm::errs() << "Translation of operator " << op.getName().getStringRef() + << " is not implemented yet\n"; + return nullptr; +} + +/* Start translating TOSA operator */ + +#define ASSERT_VECTOR_LENGTH(VECTOR, LENGTH) \ + if (VECTOR.size() != LENGTH) { \ + std::string msg; \ + msg = std::string(#VECTOR) + " is [" + std::to_string(VECTOR.size()) + \ + "], but expected to be [" + std::to_string(LENGTH) + "]\n"; \ + op.emitOpError(msg.c_str()); \ + return nullptr; \ + } + +TosaSerializationOperator * +TosaSerializationOperatorBuilder::BuildPoolOpFromMlirOp(mlir::Operation &op, + Op opcode) const { + std::vector pad, stride, kernel; + + auto pad_attr = op.getAttr("pad").dyn_cast().getValue(); + for (auto &int_attr : pad_attr) { + pad.push_back(int_attr.dyn_cast().getInt()); + } + ASSERT_VECTOR_LENGTH(pad, 4); + + auto stride_attr = + op.getAttr("stride").dyn_cast().getValue(); + for (auto &int_attr : stride_attr) { + stride.push_back(int_attr.dyn_cast().getInt()); + } + ASSERT_VECTOR_LENGTH(stride, 2); + + auto kernel_attr = + op.getAttr("kernel").dyn_cast().getValue(); + for (auto &int_attr : kernel_attr) { + kernel.push_back(int_attr.dyn_cast().getInt()); + } + ASSERT_VECTOR_LENGTH(kernel, 2); + + std::string input_name = GetTensorName(op.getOperand(0)); + std::string output_name = GetTensorName(op.getResult(0)); + + TosaPoolAttribute attribute(pad, kernel, stride); + auto quant_info = op.getAttrOfType( + "quantization_info"); + QuantInfo qinfo_type; + TosaQuantInfoBase *qinfo; + if (quant_info) { + qinfo_type = QuantInfo_UnaryQuantInfo; + qinfo = GetUnaryQuantInfo(quant_info); + } else { + qinfo_type = QuantInfo_NONE; + qinfo = new TosaNoneQuantInfo(); + } + + TosaSerializationOperator *tyop = new TosaSerializationOperator( + opcode, Attribute_PoolAttribute, &attribute, qinfo_type, qinfo, + std::vector{input_name}, + std::vector{output_name}); + + delete qinfo; + + return tyop; +} + +TosaSerializationOperator * +TosaSerializationOperatorBuilder::BuildEwiseBinaryOpFromMlirOp( + mlir::Operation &op, Op opcode) const { + std::string input0_name = GetTensorName(op.getOperand(0)); + std::string input1_name = GetTensorName(op.getOperand(1)); + std::string output_name = GetTensorName(op.getResult(0)); + + TosaSerializationOperator *tyop = new TosaSerializationOperator( + opcode, Attribute_NONE, nullptr, QuantInfo_NONE, nullptr, + std::vector{input0_name, input1_name}, + std::vector{output_name}); + + return tyop; +} + +TosaSerializationOperator * +TosaSerializationOperatorBuilder::BuildEwiseUnaryOpFromMlirOp( + mlir::Operation &op, Op opcode) const { + std::string input_name = GetTensorName(op.getOperand(0)); + std::string output_name = GetTensorName(op.getResult(0)); + + TosaSerializationOperator *tyop = new TosaSerializationOperator( + opcode, Attribute_NONE, nullptr, QuantInfo_NONE, nullptr, + std::vector{input_name}, + std::vector{output_name}); + + return tyop; +} + +TosaSerializationOperator * +TosaSerializationOperatorBuilder::BuildReductionOpFromMlirOp( + mlir::Operation &op, Op opcode) const { + std::string input_name = GetTensorName(op.getOperand(0)); + std::string output_name = GetTensorName(op.getResult(0)); + + int32_t axis = op.getAttr("axis").dyn_cast().getInt(); + TosaAxisAttribute attribute(axis); + + TosaSerializationOperator *tyop = new TosaSerializationOperator( + opcode, Attribute_AxisAttribute, &attribute, QuantInfo_NONE, nullptr, + std::vector{input_name}, + std::vector{output_name}); + + return tyop; +} + +#define BUILD_OP_POOL2D(MLIR_OP_NAME, SCHEMA_OP_NAME) \ + template <> \ + TosaSerializationOperator * \ + TosaSerializationOperatorBuilder::build( \ + mlir::Operation & op) const { \ + return BuildPoolOpFromMlirOp(op, Op_##SCHEMA_OP_NAME); \ + } + +#define BUILD_OP_ELEMENTWISE_BINARY(MLIR_OP_NAME, SCHEMA_OP_NAME) \ + template <> \ + TosaSerializationOperator * \ + TosaSerializationOperatorBuilder::build( \ + mlir::Operation & op) const { \ + return BuildEwiseBinaryOpFromMlirOp(op, Op_##SCHEMA_OP_NAME); \ + } + +#define BUILD_OP_ELEMENTWISE_UNARY(MLIR_OP_NAME, SCHEMA_OP_NAME) \ + template <> \ + TosaSerializationOperator * \ + TosaSerializationOperatorBuilder::build( \ + mlir::Operation & op) const { \ + return BuildEwiseUnaryOpFromMlirOp(op, Op_##SCHEMA_OP_NAME); \ + } + +#define BUILD_OP_REDUCTION(MLIR_OP_NAME, SCHEMA_OP_NAME) \ + template <> \ + TosaSerializationOperator * \ + TosaSerializationOperatorBuilder::build( \ + mlir::Operation & op) const { \ + return BuildReductionOpFromMlirOp(op, Op_##SCHEMA_OP_NAME); \ + } + +BUILD_OP_POOL2D(MaxPool2d, MAX_POOL2D) +BUILD_OP_POOL2D(AvgPool2d, AVG_POOL2D) + +BUILD_OP_ELEMENTWISE_BINARY(Add, ADD) +BUILD_OP_ELEMENTWISE_BINARY(BitwiseAnd, BITWISE_AND) +BUILD_OP_ELEMENTWISE_BINARY(BitwiseXor, BITWISE_XOR) +BUILD_OP_ELEMENTWISE_BINARY(BitwiseOr, BITWISE_OR) +BUILD_OP_ELEMENTWISE_BINARY(Div, INTDIV) +BUILD_OP_ELEMENTWISE_BINARY(LogicalAnd, LOGICAL_AND) +BUILD_OP_ELEMENTWISE_BINARY(LogicalLeftShift, LOGICAL_LEFT_SHIFT) +BUILD_OP_ELEMENTWISE_BINARY(LogicalRightShift, LOGICAL_RIGHT_SHIFT) +BUILD_OP_ELEMENTWISE_BINARY(LogicalOr, LOGICAL_OR) +BUILD_OP_ELEMENTWISE_BINARY(LogicalXor, LOGICAL_XOR) +BUILD_OP_ELEMENTWISE_BINARY(Maximum, MAXIMUM) +BUILD_OP_ELEMENTWISE_BINARY(Minimum, MINIMUM) +BUILD_OP_ELEMENTWISE_BINARY(Pow, POW) +BUILD_OP_ELEMENTWISE_BINARY(Sub, SUB) + +BUILD_OP_ELEMENTWISE_UNARY(Abs, ABS) +BUILD_OP_ELEMENTWISE_UNARY(BitwiseNot, BITWISE_NOT) +BUILD_OP_ELEMENTWISE_UNARY(Ceil, CEIL) +BUILD_OP_ELEMENTWISE_UNARY(Clz, CLZ) +BUILD_OP_ELEMENTWISE_UNARY(Exp, EXP) +BUILD_OP_ELEMENTWISE_UNARY(Floor, FLOOR) +BUILD_OP_ELEMENTWISE_UNARY(Log, LOG) +BUILD_OP_ELEMENTWISE_UNARY(LogicalNot, LOGICAL_NOT) +BUILD_OP_ELEMENTWISE_UNARY(Reciprocal, RECIPROCAL) +BUILD_OP_ELEMENTWISE_UNARY(Rsqrt, RSQRT) + +BUILD_OP_REDUCTION(ReduceAny, REDUCE_ANY) +BUILD_OP_REDUCTION(ReduceAll, REDUCE_ALL) +BUILD_OP_REDUCTION(ReduceMax, REDUCE_MAX) +BUILD_OP_REDUCTION(ReduceMin, REDUCE_MIN) +BUILD_OP_REDUCTION(ReduceProd, REDUCE_PRODUCT) +BUILD_OP_REDUCTION(ReduceSum, REDUCE_SUM) + +BUILD_OP_ELEMENTWISE_BINARY(Equal, EQUAL) +BUILD_OP_ELEMENTWISE_BINARY(Greater, GREATER) +BUILD_OP_ELEMENTWISE_BINARY(GreaterEqual, GREATER_EQUAL) + +BUILD_OP_ELEMENTWISE_UNARY(Sigmoid, SIGMOID) +BUILD_OP_ELEMENTWISE_UNARY(Tanh, TANH) +BUILD_OP_ELEMENTWISE_UNARY(Identity, IDENTITY) +BUILD_OP_ELEMENTWISE_UNARY(Cast, CAST) + +template <> +TosaSerializationOperator * +TosaSerializationOperatorBuilder::build( + mlir::Operation &op) const { + std::string output_name = GetTensorName(op.getResult(0)); + TosaSerializationTensor *ts = + block_builder->GetBlock()->GetTensorByName(output_name); + if (!ts) { + op.emitOpError( + "ERROR: serialization tensor must be built before building operator"); + return nullptr; + } + +#if 0 + // Gracefully handle constants of "constant unit" type which have no value + // by creating a numpy value of 0. + auto unit_val = op.getAttr(llvm::StringRef("value")).dyn_cast(); + + if (unit_val) + { + std::vector data = { 0.0 }; + type = DType_FLOAT; + TosaSerializationHandler::ConvertF32toU8(data, u8_data); + } +#endif + + // Update tensor.data array with Const value attribute + std::vector u8_data; + DType type = ts->GetDtype(); + if (type == DType_FLOAT) { + std::vector data; + auto dense_attr = op.getAttr(llvm::StringRef("value")) + .dyn_cast(); + auto val_attr = + op.getAttr(llvm::StringRef("value")).dyn_cast(); + + if (dense_attr) { + for (auto val : dense_attr.getValues()) { + data.push_back(val); + } + } else if (val_attr) { + data.push_back((float)val_attr.getValueAsDouble()); + } else { + op.emitOpError("Unknown const attribute"); + return nullptr; + } + TosaSerializationHandler::ConvertF32toU8(data, u8_data); + } else if (type == DType_INT8) { + std::vector data; + auto dense_attr = op.getAttr(llvm::StringRef("value")) + .dyn_cast(); + auto val_attr = + op.getAttr(llvm::StringRef("value")).dyn_cast(); + + if (dense_attr) { + for (auto val : dense_attr.getValues()) { + data.push_back(val); + } + } else if (val_attr) { + data.push_back(val_attr.getInt()); + } else { + op.emitOpError("Unknown const attribute"); + return nullptr; + } + TosaSerializationHandler::ConvertI8toU8(data, u8_data); + } else if (type == DType_INT16) { + std::vector data; + auto dense_attr = op.getAttr(llvm::StringRef("value")) + .dyn_cast(); + auto val_attr = + op.getAttr(llvm::StringRef("value")).dyn_cast(); + + if (dense_attr) { + for (auto val : dense_attr.getValues()) { + data.push_back(val); + } + } else if (val_attr) { + data.push_back(val_attr.getInt()); + } else { + op.emitOpError("Unknown const attribute"); + return nullptr; + } + TosaSerializationHandler::ConvertI16toU8(data, u8_data); + } else if (type == DType_INT32) { + std::vector data; + auto dense_attr = op.getAttr(llvm::StringRef("value")) + .dyn_cast(); + auto val_attr = + op.getAttr(llvm::StringRef("value")).dyn_cast(); + + if (dense_attr) { + for (auto val : dense_attr.getValues()) { + data.push_back(val); + } + } else if (val_attr) { + data.push_back(val_attr.getInt()); + } else { + op.emitOpError("Unknown const attribute"); + return nullptr; + } + TosaSerializationHandler::ConvertI32toU8(data, u8_data); + } else if (type == DType_INT48) { + std::vector data; + auto dense_attr = op.getAttr(llvm::StringRef("value")) + .dyn_cast(); + auto val_attr = + op.getAttr(llvm::StringRef("value")).dyn_cast(); + + if (dense_attr) { + for (auto val : dense_attr.getValues()) { + data.push_back(val); + } + } else if (val_attr) { + data.push_back(val_attr.getInt()); + } else { + op.emitOpError("Unknown const attribute"); + return nullptr; + } + TosaSerializationHandler::ConvertI48toU8(data, u8_data); + } else if (type == DType_BOOL) { + std::vector data; + + auto dense_attr = op.getAttr(llvm::StringRef("value")) + .dyn_cast(); + auto val_attr = + op.getAttr(llvm::StringRef("value")).dyn_cast(); + + if (dense_attr) { + for (auto val : dense_attr.getValues()) { + data.push_back(val); + } + } else if (val_attr) { + data.push_back(val_attr.getValue()); + } else { + op.emitOpError("Unknown const attribute"); + return nullptr; + } + + TosaSerializationHandler::ConvertBooltoU8(data, u8_data); + } else { + op.emitOpError("Unknown element type of const attribute"); + return nullptr; + } + ts->SetData(u8_data); + + TosaSerializationOperator *tyop = new TosaSerializationOperator( + Op_CONST, Attribute_NONE, nullptr, QuantInfo_NONE, nullptr, + std::vector{}, std::vector{output_name}); + + return tyop; +} + +template <> +TosaSerializationOperator * +TosaSerializationOperatorBuilder::build( + mlir::Operation &op) const { + std::vector pad, stride, dilation; + + auto pad_attr = op.getAttr("pad").dyn_cast().getValue(); + for (auto &int_attr : pad_attr) { + pad.push_back(int_attr.dyn_cast().getInt()); + } + ASSERT_VECTOR_LENGTH(pad, 4); + + auto stride_attr = + op.getAttr("stride").dyn_cast().getValue(); + for (auto &int_attr : stride_attr) { + stride.push_back(int_attr.dyn_cast().getInt()); + } + ASSERT_VECTOR_LENGTH(stride, 2); + + auto dilation_attr = + op.getAttr("dilation").dyn_cast().getValue(); + for (auto &int_attr : dilation_attr) { + dilation.push_back(int_attr.dyn_cast().getInt()); + } + ASSERT_VECTOR_LENGTH(dilation, 2); + + std::string input0_name = GetTensorName(op.getOperand(0)); + std::string input1_name = GetTensorName(op.getOperand(1)); + std::string input2_name = GetTensorName(op.getOperand(2)); + std::string output_name = GetTensorName(op.getResult(0)); + + TosaConvAttribute attribute(pad, stride, dilation); + + auto quant_info = + op.getAttrOfType("quantization_info"); + QuantInfo qinfo_type; + TosaQuantInfoBase *qinfo; + if (quant_info) { + qinfo_type = QuantInfo_ConvQuantInfo; + qinfo = GetConvQuantInfo(quant_info); + } else { + qinfo_type = QuantInfo_NONE; + qinfo = new TosaNoneQuantInfo(); + } + + TosaSerializationOperator *tyop = new TosaSerializationOperator( + Op_CONV2D, Attribute_ConvAttribute, &attribute, qinfo_type, qinfo, + std::vector{input0_name, input1_name, input2_name}, + std::vector{output_name}); + + delete qinfo; + + return tyop; +} + +template <> +TosaSerializationOperator * +TosaSerializationOperatorBuilder::build( + mlir::Operation &op) const { + std::vector pad, stride, dilation; + + auto pad_attr = op.getAttr("pad").dyn_cast().getValue(); + for (auto &int_attr : pad_attr) { + pad.push_back(int_attr.dyn_cast().getInt()); + } + ASSERT_VECTOR_LENGTH(pad, 4); + + auto stride_attr = + op.getAttr("stride").dyn_cast().getValue(); + for (auto &int_attr : stride_attr) { + stride.push_back(int_attr.dyn_cast().getInt()); + } + ASSERT_VECTOR_LENGTH(stride, 2); + + auto dilation_attr = + op.getAttr("dilation").dyn_cast().getValue(); + for (auto &int_attr : dilation_attr) { + dilation.push_back(int_attr.dyn_cast().getInt()); + } + ASSERT_VECTOR_LENGTH(dilation, 2); + + std::string input0_name = GetTensorName(op.getOperand(0)); + std::string input1_name = GetTensorName(op.getOperand(1)); + std::string input2_name = GetTensorName(op.getOperand(2)); + std::string output_name = GetTensorName(op.getResult(0)); + + TosaConvAttribute attribute(pad, stride, dilation); + + auto quant_info = + op.getAttrOfType("quantization_info"); + QuantInfo qinfo_type; + TosaQuantInfoBase *qinfo; + if (quant_info) { + qinfo_type = QuantInfo_ConvQuantInfo; + qinfo = GetConvQuantInfo(quant_info); + } else { + qinfo_type = QuantInfo_NONE; + qinfo = new TosaNoneQuantInfo(); + } + + TosaSerializationOperator *tyop = new TosaSerializationOperator( + Op_DEPTHWISE_CONV2D, Attribute_ConvAttribute, &attribute, qinfo_type, + qinfo, std::vector{input0_name, input1_name, input2_name}, + std::vector{output_name}); + + delete qinfo; + + return tyop; +} + +template <> +TosaSerializationOperator * +TosaSerializationOperatorBuilder::build( + mlir::Operation &op) const { + std::vector outpad, stride, dilation, output_shape; + + auto outpad_attr = + op.getAttr("out_pad").dyn_cast().getValue(); + for (auto &int_attr : outpad_attr) { + outpad.push_back(int_attr.dyn_cast().getInt()); + } + ASSERT_VECTOR_LENGTH(outpad, 2); + + auto stride_attr = + op.getAttr("stride").dyn_cast().getValue(); + for (auto &int_attr : stride_attr) { + stride.push_back(int_attr.dyn_cast().getInt()); + } + ASSERT_VECTOR_LENGTH(stride, 2); + + auto dilation_attr = + op.getAttr("dilation").dyn_cast().getValue(); + for (auto &int_attr : dilation_attr) { + dilation.push_back(int_attr.dyn_cast().getInt()); + } + ASSERT_VECTOR_LENGTH(dilation, 2); + + auto output_shape_attr = + op.getAttr("out_shape").dyn_cast().getValue(); + for (auto &int_attr : output_shape_attr) { + output_shape.push_back(int_attr.dyn_cast().getInt()); + } + ASSERT_VECTOR_LENGTH(output_shape, 4); + + std::string input0_name = GetTensorName(op.getOperand(0)); + std::string input1_name = GetTensorName(op.getOperand(1)); + std::string input2_name = GetTensorName(op.getOperand(2)); + std::string output_name = GetTensorName(op.getResult(0)); + + TosaTransposeConvAttribute attribute(outpad, stride, dilation, output_shape); + + auto quant_info = + op.getAttrOfType("quantization_info"); + QuantInfo qinfo_type; + TosaQuantInfoBase *qinfo; + if (quant_info) { + qinfo_type = QuantInfo_ConvQuantInfo; + qinfo = GetConvQuantInfo(quant_info); + } else { + qinfo_type = QuantInfo_NONE; + qinfo = new TosaNoneQuantInfo(); + } + + TosaSerializationOperator *tyop = new TosaSerializationOperator( + Op_TRANSPOSE_CONV2D, Attribute_TransposeConvAttribute, &attribute, + qinfo_type, qinfo, + std::vector{input0_name, input1_name, input2_name}, + std::vector{output_name}); + + delete qinfo; + + return tyop; +} + +template <> +TosaSerializationOperator * +TosaSerializationOperatorBuilder::build( + mlir::Operation &op) const { + std::string input0_name = GetTensorName(op.getOperand(0)); + std::string input1_name = GetTensorName(op.getOperand(1)); + std::string input2_name = GetTensorName(op.getOperand(2)); + std::string output_name = GetTensorName(op.getResult(0)); + + auto quant_info = + op.getAttrOfType("quantization_info"); + QuantInfo qinfo_type; + TosaQuantInfoBase *qinfo; + if (quant_info) { + qinfo_type = QuantInfo_ConvQuantInfo; + qinfo = GetConvQuantInfo(quant_info); + } else { + qinfo_type = QuantInfo_NONE; + qinfo = new TosaNoneQuantInfo(); + } + + TosaSerializationOperator *tyop = new TosaSerializationOperator( + Op_FULLY_CONNECTED, Attribute_NONE, nullptr, qinfo_type, qinfo, + std::vector{input0_name, input1_name, input2_name}, + std::vector{output_name}); + + delete qinfo; + + return tyop; +} + +template <> +TosaSerializationOperator * +TosaSerializationOperatorBuilder::build( + mlir::Operation &op) const { + std::string input0_name = GetTensorName(op.getOperand(0)); + std::string input1_name = GetTensorName(op.getOperand(1)); + std::string output_name = GetTensorName(op.getResult(0)); + + auto quant_info = op.getAttrOfType( + "quantization_info"); + QuantInfo qinfo_type; + TosaQuantInfoBase *qinfo; + if (quant_info) { + qinfo_type = QuantInfo_MatMulQuantInfo; + qinfo = GetMatMulQuantInfo(quant_info); + } else { + qinfo_type = QuantInfo_NONE; + qinfo = new TosaNoneQuantInfo(); + } + + TosaSerializationOperator *tyop = new TosaSerializationOperator( + Op_MATMUL, Attribute_NONE, nullptr, qinfo_type, qinfo, + std::vector{input0_name, input1_name}, + std::vector{output_name}); + + delete qinfo; + + return tyop; +} + +template <> +TosaSerializationOperator * +TosaSerializationOperatorBuilder::build( + mlir::Operation &op) const { + std::string input0_name = GetTensorName(op.getOperand(0)); + std::string input1_name = GetTensorName(op.getOperand(1)); + std::string input2_name = GetTensorName(op.getOperand(2)); + std::string output_name = GetTensorName(op.getResult(0)); + + TosaSerializationOperator *tyop = new TosaSerializationOperator( + Op_SELECT, Attribute_NONE, nullptr, QuantInfo_NONE, nullptr, + std::vector{input0_name, input1_name, input2_name}, + std::vector{output_name}); + + return tyop; +} + +template <> +TosaSerializationOperator * +TosaSerializationOperatorBuilder::build( + mlir::Operation &op) const { + int32_t min_int = + op.getAttr("min_int").dyn_cast().getInt(); + int32_t max_int = + op.getAttr("max_int").dyn_cast().getInt(); + float min_fp = op.getAttr("min_fp") + .dyn_cast() + .getValue() + .convertToFloat(); + float max_fp = op.getAttr("max_fp") + .dyn_cast() + .getValue() + .convertToFloat(); + + std::string input_name = GetTensorName(op.getOperand(0)); + std::string output_name = GetTensorName(op.getResult(0)); + + TosaClampAttribute attribute(min_int, max_int, min_fp, max_fp); + + TosaSerializationOperator *tyop = new TosaSerializationOperator( + Op_CLAMP, Attribute_ClampAttribute, &attribute, QuantInfo_NONE, nullptr, + std::vector{input_name}, + std::vector{output_name}); + + return tyop; +} + +template <> +TosaSerializationOperator * +TosaSerializationOperatorBuilder::build( + mlir::Operation &op) const { + int32_t axis = op.getAttr("axis").dyn_cast().getInt(); + + std::string input_name = GetTensorName(op.getOperand(0)); + std::string output_name = GetTensorName(op.getResult(0)); + + TosaAxisAttribute attribute(axis); + + TosaSerializationOperator *tyop = new TosaSerializationOperator( + Op_ARGMAX, Attribute_AxisAttribute, &attribute, QuantInfo_NONE, nullptr, + std::vector{input_name}, + std::vector{output_name}); + + return tyop; +} + +template <> +TosaSerializationOperator * +TosaSerializationOperatorBuilder::build( + mlir::Operation &op) const { + int32_t axis = op.getAttr("axis").dyn_cast().getInt(); + + std::vector inputs; + for (uint32_t i = 0; i < op.getNumOperands(); i++) { + std::string input_name = GetTensorName(op.getOperand(i)); + inputs.push_back(input_name); + } + + std::string output_name = GetTensorName(op.getResult(0)); + + TosaAxisAttribute attribute(axis); + + TosaSerializationOperator *tyop = new TosaSerializationOperator( + Op_CONCAT, Attribute_AxisAttribute, &attribute, QuantInfo_NONE, nullptr, + inputs, std::vector{output_name}); + + return tyop; +} + +template <> +TosaSerializationOperator * +TosaSerializationOperatorBuilder::build( + mlir::Operation &op) const { + std::string input_name = GetTensorName(op.getOperand(0)); + std::string output_name = GetTensorName(op.getResult(0)); + + auto quant_info = op.getAttrOfType( + "quantization_info"); + QuantInfo qinfo_type; + TosaQuantInfoBase *qinfo; + if (quant_info) { + qinfo_type = QuantInfo_UnaryQuantInfo; + qinfo = GetUnaryQuantInfo(quant_info); + } else { + qinfo_type = QuantInfo_NONE; + qinfo = new TosaNoneQuantInfo(); + } + + TosaSerializationOperator *tyop = new TosaSerializationOperator( + Op_NEGATE, Attribute_NONE, nullptr, qinfo_type, qinfo, + std::vector{input_name}, + std::vector{output_name}); + return tyop; +} + +template <> +TosaSerializationOperator * +TosaSerializationOperatorBuilder::build( + mlir::Operation &op) const { + std::string input_name = GetTensorName(op.getOperand(0)); + std::string output_name = GetTensorName(op.getResult(0)); + + std::vector shape; + auto shape_attr = + op.getAttr("new_shape").dyn_cast().getValue(); + for (auto &int_attr : shape_attr) { + shape.push_back(int_attr.dyn_cast().getInt()); + } + + TosaReshapeAttribute attribute(shape); + + TosaSerializationOperator *tyop = new TosaSerializationOperator( + Op_RESHAPE, Attribute_ReshapeAttribute, &attribute, QuantInfo_NONE, + nullptr, std::vector{input_name}, + std::vector{output_name}); + + return tyop; +} + +template <> +TosaSerializationOperator * +TosaSerializationOperatorBuilder::build( + mlir::Operation &op) const { + std::string input_name = GetTensorName(op.getOperand(0)); + std::string output_name = GetTensorName(op.getResult(0)); + + // Match padding tensor as compile-time constant attribute + // TODO: fix when MLIR dialect changes + mlir::ElementsAttr paddings_elems; + if (!matchPattern(op.getOperand(1), m_Constant(&paddings_elems))) + return nullptr; + + std::vector paddings; + for (int32_t val : paddings_elems.getValues()) { + paddings.push_back(val); + } + + TosaPadAttribute attribute(paddings, 0 /* pad_const_int */, + 0.0f /* pad_const_fp */); + + auto quant_info = + op.getAttrOfType("quantization_info"); + QuantInfo qinfo_type; + TosaQuantInfoBase *qinfo; + if (quant_info) { + qinfo_type = QuantInfo_PadQuantInfo; + qinfo = GetPadQuantInfo(quant_info); + } else { + qinfo_type = QuantInfo_NONE; + qinfo = new TosaNoneQuantInfo(); + } + + TosaSerializationOperator *tyop = new TosaSerializationOperator( + Op_PAD, Attribute_PadAttribute, &attribute, qinfo_type, qinfo, + std::vector{input_name}, + std::vector{output_name}); + return tyop; +} + +template <> +TosaSerializationOperator * +TosaSerializationOperatorBuilder::build( + mlir::Operation &op) const { + std::string input_name = GetTensorName(op.getOperand(0)); + std::string output_name = GetTensorName(op.getResult(0)); + + // Match perm tensor as compile-time constant attribute + // TODO: fix when MLIR dialect changes + mlir::ElementsAttr perm_elems; + if (!matchPattern(op.getOperand(1), m_Constant(&perm_elems))) + return nullptr; + + std::vector perm; + for (int32_t i = 0; i < perm_elems.getNumElements(); i++) { + perm.push_back(perm_elems.getValue(i).getInt()); + } + + TosaTransposeAttribute attribute(perm); + + TosaSerializationOperator *tyop = new TosaSerializationOperator( + Op_TRANSPOSE, Attribute_TransposeAttribute, &attribute, QuantInfo_NONE, + nullptr, std::vector{input_name}, + std::vector{output_name}); + + return tyop; +} + +template <> +TosaSerializationOperator * +TosaSerializationOperatorBuilder::build( + mlir::Operation &op) const { + std::vector start, size; + auto begin_attr = op.getAttr("start").dyn_cast().getValue(); + auto size_attr = op.getAttr("size").dyn_cast().getValue(); + + for (auto &int_attr : begin_attr) { + start.push_back(int_attr.dyn_cast().getInt()); + } + + for (auto &int_attr : size_attr) { + size.push_back(int_attr.dyn_cast().getInt()); + } + + TosaSliceAttribute attribute(start, size); + + std::string input_name = GetTensorName(op.getOperand(0)); + std::string output_name = GetTensorName(op.getResult(0)); + + TosaSerializationOperator *tyop = new TosaSerializationOperator( + Op_SLICE, Attribute_SliceAttribute, &attribute, QuantInfo_NONE, nullptr, + std::vector{input_name}, + std::vector{output_name}); + + return tyop; +} + +template <> +TosaSerializationOperator * +TosaSerializationOperatorBuilder::build( + mlir::Operation &op) const { + std::string input_name = GetTensorName(op.getOperand(0)); + std::string output_name = GetTensorName(op.getResult(0)); + + std::vector multiples; + auto multiples_attr = + op.getAttr("multiples").dyn_cast().getValue(); + for (auto &int_attr : multiples_attr) { + multiples.push_back(int_attr.dyn_cast().getInt()); + } + + TosaTileAttribute attribute(multiples); + + TosaSerializationOperator *tyop = new TosaSerializationOperator( + Op_TILE, Attribute_TileAttribute, &attribute, QuantInfo_NONE, nullptr, + std::vector{input_name}, + std::vector{output_name}); + + return tyop; +} + +template <> +TosaSerializationOperator * +TosaSerializationOperatorBuilder::build( + mlir::Operation &op) const { + std::string input0_name = GetTensorName(op.getOperand(0)); + std::string input1_name = GetTensorName(op.getOperand(1)); + std::string output_name = GetTensorName(op.getResult(0)); + + TosaSerializationOperator *tyop = new TosaSerializationOperator( + Op_GATHER, Attribute_NONE, nullptr, QuantInfo_NONE, nullptr, + std::vector{input0_name, input1_name}, + std::vector{output_name}); + + return tyop; +} + +template <> +TosaSerializationOperator * +TosaSerializationOperatorBuilder::build( + mlir::Operation &op) const { + std::string input0_name = GetTensorName(op.getOperand(0)); + std::string input1_name = GetTensorName(op.getOperand(1)); + std::string input2_name = GetTensorName(op.getOperand(2)); + std::string output_name = GetTensorName(op.getResult(0)); + + TosaSerializationOperator *tyop = new TosaSerializationOperator( + Op_SCATTER, Attribute_NONE, nullptr, QuantInfo_NONE, nullptr, + std::vector{input0_name, input1_name, input2_name}, + std::vector{output_name}); + + return tyop; +} + +template <> +TosaSerializationOperator * +TosaSerializationOperatorBuilder::build( + mlir::Operation &op) const { + std::string input_name = GetTensorName(op.getOperand(0)); + std::string output_name = GetTensorName(op.getResult(0)); + + std::vector output_size; + auto output_size_attr = + op.getAttr("output_size").dyn_cast().getValue(); + for (auto &int_attr : output_size_attr) { + output_size.push_back(int_attr.dyn_cast().getInt()); + } + ASSERT_VECTOR_LENGTH(output_size, 2); + + std::vector stride; + auto stride_attr = + op.getAttr("stride").dyn_cast().getValue(); + for (auto &int_attr : stride_attr) { + stride.push_back(int_attr.dyn_cast().getInt()); + } + ASSERT_VECTOR_LENGTH(stride, 2); + + std::vector offset; + auto offset_attr = + op.getAttr("offset").dyn_cast().getValue(); + for (auto &int_attr : offset_attr) { + offset.push_back(int_attr.dyn_cast().getInt()); + } + ASSERT_VECTOR_LENGTH(offset, 2); + + int32_t shift = op.getAttr("shift").dyn_cast().getInt(); + + std::vector stride_fp; + auto stride_fp_attr = + op.getAttr("stride_fp").dyn_cast().getValue(); + for (auto &fp_attr : stride_fp_attr) { + stride_fp.push_back(fp_attr.dyn_cast().getValueAsDouble()); + } + ASSERT_VECTOR_LENGTH(stride_fp, 2); + + std::vector offset_fp; + auto offset_fp_attr = + op.getAttr("offset_fp").dyn_cast().getValue(); + for (auto &fp_attr : offset_fp_attr) { + offset_fp.push_back(fp_attr.dyn_cast().getValueAsDouble()); + } + ASSERT_VECTOR_LENGTH(offset_fp, 2); + + auto mode_str = + op.getAttr("mode").dyn_cast().getValue().str(); + ResizeMode mode = ResizeModeStr2Enum(mode_str); + + TosaResizeAttribute attribute(output_size, stride, offset, shift, stride_fp, + offset_fp, mode); + + TosaSerializationOperator *tyop = new TosaSerializationOperator( + Op_RESIZE, Attribute_ResizeAttribute, &attribute, QuantInfo_NONE, nullptr, + std::vector{input_name}, + std::vector{output_name}); + + return tyop; +} + +template <> +TosaSerializationOperator * +TosaSerializationOperatorBuilder::build( + mlir::Operation &op) const { + std::string input_name = GetTensorName(op.getOperand(0)); + std::string output_name = GetTensorName(op.getResult(0)); + + int32_t axis = op.getAttr("axis").dyn_cast().getInt(); + + TosaAxisAttribute attribute(axis); + + TosaSerializationOperator *tyop = new TosaSerializationOperator( + Op_REVERSE, Attribute_AxisAttribute, &attribute, QuantInfo_NONE, nullptr, + std::vector{input_name}, + std::vector{output_name}); + + return tyop; +} + +template <> +TosaSerializationOperator * +TosaSerializationOperatorBuilder::build( + mlir::Operation &op) const { + std::string input0_name = GetTensorName(op.getOperand(0)); + std::string input1_name = GetTensorName(op.getOperand(1)); + std::string output_name = GetTensorName(op.getResult(0)); + + int32_t shift = op.getAttr("shift").dyn_cast().getInt(); + + TosaMulAttribute attribute(shift); + + TosaSerializationOperator *tyop = new TosaSerializationOperator( + Op_MUL, Attribute_MulAttribute, &attribute, QuantInfo_NONE, nullptr, + std::vector{input0_name, input1_name}, + std::vector{output_name}); + + return tyop; +} + +template <> +TosaSerializationOperator * +TosaSerializationOperatorBuilder::build( + mlir::Operation &op) const { + std::string input0_name = GetTensorName(op.getOperand(0)); + std::string input1_name = GetTensorName(op.getOperand(1)); + std::string output_name = GetTensorName(op.getResult(0)); + + bool round = op.getAttr("round").dyn_cast().getValue(); + + TosaArithmeticRightShiftAttribute attribute(round); + + TosaSerializationOperator *tyop = new TosaSerializationOperator( + Op_ARITHMETIC_RIGHT_SHIFT, Attribute_ArithmeticRightShiftAttribute, + &attribute, QuantInfo_NONE, nullptr, + std::vector{input0_name, input1_name}, + std::vector{output_name}); + + return tyop; +} + +template <> +TosaSerializationOperator * +TosaSerializationOperatorBuilder::build( + mlir::Operation &op) const { + std::string input_name = GetTensorName(op.getOperand(0)); + std::string output_name = GetTensorName(op.getResult(0)); + + // Match table tensor as compile-time constant attribute + // TODO: fix when MLIR dialect changes + mlir::ElementsAttr table_elems; + if (!matchPattern(op.getOperand(1), m_Constant(&table_elems))) + return nullptr; + + std::vector table; + for (int32_t i = 0; i < table_elems.getNumElements(); i++) { + table.push_back(table_elems.getValue(i).getInt()); + } + + TosaTableAttribute attribute(table); + + TosaSerializationOperator *tyop = new TosaSerializationOperator( + Op_TABLE, Attribute_TableAttribute, &attribute, QuantInfo_NONE, nullptr, + std::vector{input_name}, + std::vector{output_name}); + + return tyop; +} + +template <> +TosaSerializationOperator * +TosaSerializationOperatorBuilder::build( + mlir::Operation &op) const { + int32_t input_zp = + op.getAttr("input_zp").dyn_cast().getInt(); + int32_t output_zp = + op.getAttr("output_zp").dyn_cast().getInt(); + bool scale32 = op.getAttr("scale32").dyn_cast().getValue(); + bool double_round = + op.getAttr("double_round").dyn_cast().getValue(); + bool per_channel = + op.getAttr("per_channel").dyn_cast().getValue(); + + std::vector multiplier, shift; + auto multiplier_attr = + op.getAttr("multiplier").dyn_cast().getValue(); + auto shift_attr = op.getAttr("shift").dyn_cast().getValue(); + + for (auto &int_attr : multiplier_attr) { + multiplier.push_back(int_attr.dyn_cast().getInt()); + } + + for (auto &int_attr : shift_attr) { + shift.push_back(int_attr.dyn_cast().getInt()); + } + + std::string input_name = GetTensorName(op.getOperand(0)); + std::string output_name = GetTensorName(op.getResult(0)); + + TosaRescaleAttribute attribute(input_zp, output_zp, multiplier, shift, + scale32, double_round, per_channel); + + TosaSerializationOperator *tyop = new TosaSerializationOperator( + Op_RESCALE, Attribute_RescaleAttribute, &attribute, QuantInfo_NONE, + nullptr, std::vector{input_name}, + std::vector{output_name}); + + return tyop; +} + +template <> +TosaSerializationOperator * +TosaSerializationOperatorBuilder::build( + mlir::Operation &op) const { + std::string input_name = GetTensorName(op.getOperand(0)); + std::string output_name = GetTensorName(op.getResult(0)); + + TosaSerializationOperator *tyop = new TosaSerializationOperator( + Op_CUSTOM, Attribute_NONE, nullptr, QuantInfo_NONE, nullptr, + std::vector{input_name}, + std::vector{output_name}); + + return tyop; +} + +template <> +TosaSerializationOperator * +TosaSerializationOperatorBuilder::build( + mlir::Operation &op) const { + std::vector input_names, output_names; + + mlir::Region &then_region = op.getRegion(0); + mlir::Region &else_region = op.getRegion(1); + std::vector then_yields, else_yields; + TosaSerializationBasicBlock *then_block = nullptr; + TosaSerializationBasicBlock *else_block = nullptr; + + // Building then branch block + std::string then_block_name = + "bb" + std::to_string(block_builder->GetTsh()->GetBlocks().size()); + then_block = new TosaSerializationBasicBlock( + then_block_name, std::vector(), + std::vector(), std::vector(), + std::vector()); + assert(then_block); + block_builder->GetTsh()->GetBlocks().push_back(then_block); + + TosaSerializationBlockBuilder then_block_builder( + then_block, block_builder->GetTsh(), &then_region); + if (then_block_builder.BuildAllOpsInRegion(then_yields).failed()) { + return nullptr; + } + if (then_yields.size() != op.getNumResults()) { + op.emitOpError("BuildOpCondIf: then_region yield.size() doesn't match " + "cond_if's output size"); + return nullptr; + } + + // Building else branch block + std::string else_block_name = + "bb" + std::to_string(block_builder->GetTsh()->GetBlocks().size()); + else_block = new TosaSerializationBasicBlock( + else_block_name, std::vector(), + std::vector(), std::vector(), + std::vector()); + assert(else_block); + block_builder->GetTsh()->GetBlocks().push_back(else_block); + + TosaSerializationBlockBuilder else_block_builder( + else_block, block_builder->GetTsh(), &else_region); + if (else_block_builder.BuildAllOpsInRegion(else_yields).failed()) { + return nullptr; + } + if (else_yields.size() != op.getNumResults()) { + op.emitOpError("BuildOpCondIf: else_region yield.size() doesn't match " + "cond_if's output size"); + return nullptr; + } + + TosaCondIfAttribute attribute(then_block->GetName(), else_block->GetName()); + + for (size_t i = 0; i < op.getNumOperands(); i++) { + std::string input_name = GetTensorName(op.getOperand(i)); + input_names.push_back(input_name); + } + + for (size_t i = 0; i < op.getNumResults(); i++) { + std::string output_name = GetTensorName(op.getResult(i)); + output_names.push_back(output_name); + } + + TosaSerializationOperator *tyop = new TosaSerializationOperator( + Op_COND_IF, Attribute_CondIfAttribute, &attribute, QuantInfo_NONE, + nullptr, input_names, output_names); + + return tyop; +} + +template <> +TosaSerializationOperator * +TosaSerializationOperatorBuilder::build( + mlir::Operation &op) const { + std::vector input_names, output_names; + + mlir::Region &cond_region = op.getRegion(0); + mlir::Region &body_region = op.getRegion(1); + std::vector cond_yields, body_yields; + TosaSerializationBasicBlock *cond_block = nullptr; + TosaSerializationBasicBlock *body_block = nullptr; + + // Building cond branch block + std::string cond_block_name = + "bb" + std::to_string(block_builder->GetTsh()->GetBlocks().size()); + cond_block = new TosaSerializationBasicBlock( + cond_block_name, std::vector(), + std::vector(), std::vector(), + std::vector()); + assert(cond_block); + block_builder->GetTsh()->GetBlocks().push_back(cond_block); + + TosaSerializationBlockBuilder cond_block_builder( + cond_block, block_builder->GetTsh(), &cond_region); + if (cond_block_builder.BuildAllOpsInRegion(cond_yields).failed()) { + return nullptr; + } + if (cond_yields.size() != 1) { + op.emitOpError("BuildOpWhileLoop: cond_region yield.size() is not 1"); + return nullptr; + } + + // Building body branch block + std::string body_block_name = + "bb" + std::to_string(block_builder->GetTsh()->GetBlocks().size()); + body_block = new TosaSerializationBasicBlock( + body_block_name, std::vector(), + std::vector(), std::vector(), + std::vector()); + assert(body_block); + block_builder->GetTsh()->GetBlocks().push_back(body_block); + + TosaSerializationBlockBuilder body_block_builder( + body_block, block_builder->GetTsh(), &body_region); + if (body_block_builder.BuildAllOpsInRegion(body_yields).failed()) { + return nullptr; + } + if (body_yields.size() != op.getNumResults()) { + op.emitOpError("BuildOpWhileLoop: body_region yield.size() doesn't " + "match while_loop's output size"); + return nullptr; + } + + TosaWhileLoopAttribute attribute(cond_block->GetName(), + body_block->GetName()); + + for (size_t i = 0; i < op.getNumOperands(); i++) { + std::string input_name = GetTensorName(op.getOperand(i)); + input_names.push_back(input_name); + } + + for (size_t i = 0; i < op.getNumResults(); i++) { + std::string output_name = GetTensorName(op.getResult(i)); + output_names.push_back(output_name); + } + + TosaSerializationOperator *tyop = new TosaSerializationOperator( + Op_WHILE_LOOP, Attribute_WhileLoopAttribute, &attribute, QuantInfo_NONE, + nullptr, input_names, output_names); + + return tyop; +} + +/* End translating TOSA operator */ + +mlir::LogicalResult TosaSerializationBlockBuilder::BuildAllOpsInRegion( + std::vector &return_values) { + TosaSerializationOperator *ser_operator = nullptr; + TosaSerializationTensor *ser_tensor = nullptr; + size_t num_blocks_in_region = 0; + static int input_tensor_index = 0; + static int intermediate_tensor_index = 0; + static int output_tensor_index = 0; + TosaSerializationOperatorBuilder op_builder(this); + + for (auto &bb : region->getBlocks()) { + num_blocks_in_region++; + + if (num_blocks_in_region > 1) { + llvm::errs() << "Invalid MLIR: multiple blocks in a region\n"; + return mlir::failure(); + } + + // We always have one block for each region right now + assert(bb.isEntryBlock()); + + // Specify block input tensor name + for (auto args : bb.getArguments()) { + std::string block_input_name = + "TosaInput_" + std::to_string(input_tensor_index++); + block->GetInputs().push_back(block_input_name); + tensor_map[args] = block_input_name; + } + + // Build tensor_map + for (auto &op : bb) { + if (!(llvm::isa(op) || + llvm::isa(op) || + llvm::isa(op))) { + for (uint32_t i = 0; i < op.getNumResults(); i++) { + std::string intermediate_tensor_name = + "layer_" + std::to_string(intermediate_tensor_index++); + tensor_map[op.getResult(i)] = intermediate_tensor_name; + } + } else { + if (llvm::isa(op)) + continue; + // Override return tensor name + for (auto val : op.getOperands()) { + // Workaround to skip mlir::tensor::CastOp before return + mlir::Operation *val_defining_op = val.getDefiningOp(); + if (llvm::isa(*val_defining_op)) + val = val_defining_op->getOperand(0); + + // Sanity check. This mlir::Value should be built in map since graph + // is DAG + if (tensor_map.find(val) == tensor_map.end()) { + llvm::errs() << "ERROR: Can't find built mlir::Value key.\n"; + return mlir::failure(); + } + std::string output_name = + "TosaOutput_" + std::to_string(output_tensor_index++); + tensor_map[val] = output_name; + block->GetOutputs().push_back(output_name); + return_values.push_back(val); + } + } + } + + // Build tensor + for (auto pair : tensor_map) { + ser_tensor = BuildTosaSerializationTensor(pair.first /* val */, + pair.second /* name */); + if (!ser_tensor) { + llvm::errs() << "ERROR: Failed to build TosaSerializationTensor\n"; + return mlir::failure(); + } + block->GetTensors().push_back(ser_tensor); + } + + // Build operator + for (auto &op : bb) { + if (llvm::isa(op) || llvm::isa(op) || + llvm::isa(op)) + continue; + ser_operator = BuildTosaSerializationOperator(op_builder, op); + if (!ser_operator) { + llvm::errs() << "ERROR: Failed to build TosaSerializationOperator\n"; + return mlir::failure(); + } + block->GetOperators().push_back(ser_operator); + } + } + + return mlir::success(); +} + +TosaSerializationOperator * +TosaSerializationBlockBuilder::BuildTosaSerializationOperator( + const TosaSerializationOperatorBuilder &op_builder, mlir::Operation &op) { + std::string full_op_name = op.getName().getStringRef().str(); + TosaSerializationOperator *target_operator = nullptr; + + if (false) { + } +#define DEF_OPERATOR(MLIR_OP) \ + else if (llvm::isa(op)) { \ + target_operator = op_builder.build(op); \ + } +#include "operator.def" +#undef DEF_OPERATOR + else { + llvm::errs() << "unsupported tosa operator " << op.getName().getStringRef() + << "\n"; + } + + if (!target_operator) { + llvm::errs() << op.getName().getStringRef() + << " operator hasn't been translated to flatbuffer, skipped\n"; + return nullptr; + } + + // Sanity check the number of inputs/outputs of TOSA dialect matches the + // number of TOSA flatbuffer + if (op.getNumOperands() != target_operator->GetInputTensorNames().size()) { + llvm::errs() << "WARNING. MLIR operator has " << op.getNumOperands() + << " input tensors != Flatbuffer " + "operator has " + << target_operator->GetInputTensorNames().size() + << " input tensors\n"; + } + if (op.getNumResults() != target_operator->GetOutputTensorNames().size()) { + llvm::errs() << "WARNING. MLIR operator has " << op.getNumResults() + << " output tensors != Flatbuffer " + "operator has " + << target_operator->GetOutputTensorNames().size() + << " output tensors\n"; + } + + return target_operator; +} + +TosaSerializationTensor * +TosaSerializationBlockBuilder::BuildTosaSerializationTensor( + mlir::Value val, const std::string &name) { + // If tensor already created before, use that tensor directly, create a new + // one otherwise + TosaSerializationTensor *ts = block->GetTensorByName(name); + if (ts) { + return nullptr; + } + + mlir::RankedTensorType tensor = + val.getType().dyn_cast(); + std::vector shape(tensor.getShape().begin(), + tensor.getShape().end()); + DType type = Type2DType(tensor.getElementType()); + + ts = new TosaSerializationTensor(name, shape, type, std::vector()); + + return ts; +} + +mlir::LogicalResult translate2FlatBuffer(mlir::FuncOp &func, + TosaSerializationHandler &tsh) { + TosaSerializationBasicBlock *main_block; + + mlir::Region *main_region = func.getCallableRegion(); + std::vector main_returns; + + if (!main_region) { + llvm::errs() << "Invalid MLIR: doesn't have valid \"main\" region\n"; + return mlir::failure(); + } + + if (!tsh.GetBlocks().empty()) { + llvm::errs() << "Internal Error: TosaSerializationHandler's block list " + "must be empty\n"; + return mlir::failure(); + } + + main_block = new TosaSerializationBasicBlock( + std::string("main"), std::vector(), + std::vector(), std::vector(), + std::vector()); + assert(main_block); + tsh.GetBlocks().push_back(main_block); + + TosaSerializationBlockBuilder block_builder(main_block, &tsh, main_region); + if (block_builder.BuildAllOpsInRegion(main_returns).failed()) { + return mlir::failure(); + } + + if (main_returns.empty()) { + llvm::errs() << "Warning: graph doesn't have return values\n"; + } + + return mlir::success(); +} + +mlir::LogicalResult dumpTosaFlatbuffer(mlir::FuncOp &func) { + tosa::TosaSerializationHandler tsh; + + std::string tosa_flatbuffer_directory_fullpath; + if (translate2FlatBuffer(func, tsh).failed()) { + llvm::errs() << "Fail to translate TOSA MLIR to flatbuffer\n"; + return mlir::failure(); + } + + if (tsh.SaveFileTosaFlatbuffer(tosa_flatbuffer_filename.c_str())) { + llvm::errs() << "Fail to save flatbuffer " << tosa_flatbuffer_filename + << "\n"; + return mlir::failure(); + } + return mlir::success(); +} + +mlir::LogicalResult dumpTosaJSON(mlir::FuncOp &func) { + tosa::TosaSerializationHandler tsh; + + const char *tosa_schema = tosa_flatbuffer_schema.c_str(); + + if (!tosa_schema) { + llvm::errs() << "Flatbuffer schema not defined\n"; + return mlir::failure(); + } + + if (tsh.LoadFileSchema(tosa_schema)) { + llvm::errs() << "Error loading tosa schema file: " << tosa_schema << "\n"; + return mlir::failure(); + } + + std::string tosa_flatbuffer_directory_fullpath; + if (translate2FlatBuffer(func, tsh).failed()) { + llvm::errs() << "Fail to translate TOSA MLIR to flatbuffer\n"; + return mlir::failure(); + } + + if (tsh.SaveFileJson(tosa_flatbuffer_filename.c_str())) { + llvm::errs() << "Fail to save flatbuffer " << tosa_flatbuffer_filename + << "\n"; + return mlir::failure(); + } + + return mlir::success(); +} + +namespace mlir { + +namespace tosa { + +namespace { + +class TosaSerialize : public PassWrapper { +public: + TosaSerialize() = default; + + StringRef getArgument() const final { + // This is the argument used to refer to the pass in + // the textual format (on the commandline for example). + return "tosa-serialize"; + } + StringRef getDescription() const final { + // This is a brief description of the pass. + return "Run the TOSA serialization (flatbuffer generation) pass"; + } + + void runOnFunction() override { + auto function = getFunction(); + + if (dumpTosaFlatbuffer(function).failed()) { + llvm::errs() << "Failed to generate TOSA flatbuffer...\n"; + return signalPassFailure(); + } + } +}; + +class TosaSerializeJSON : public PassWrapper { +public: + TosaSerializeJSON() = default; + + StringRef getArgument() const final { + // This is the argument used to refer to the pass in + // the textual format (on the commandline for example). + return "tosa-serialize-json"; + } + StringRef getDescription() const final { + // This is a brief description of the pass. + return "Run the TOSA serialization (JSON generation) pass"; + } + + void runOnFunction() override { + auto function = getFunction(); + + if (dumpTosaJSON(function).failed()) { + llvm::errs() << "Failed to generate TOSA JSON...\n"; + return signalPassFailure(); + } + } +}; + +} // anonymous namespace + +// Creates an instance of the TOSA flatbuffer generation pass +std::unique_ptr> createTosaSerializePass() { + return std::make_unique(); +} + +std::unique_ptr> createTosaSerializeJSONPass() { + return std::make_unique(); +} + +static PassRegistration pass([] { + return createTosaSerializePass(); +}); + +static PassRegistration passJSON([] { + return createTosaSerializeJSONPass(); +}); + +} // namespace tosa +} // namespace mlir -- cgit v1.2.1