// Copyright (c) 2020-2024, ARM Limited. // // Licensed under the Apache License, Version 2.0 with LLVM Exceptions // (the "License"); you may not use this file except in compliance with // the License. You may obtain a copy of the License at // // https://llvm.org/LICENSE.txt // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // TOSA flatbuffer generation #include "include/SerializationPasses.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project #include "mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project #include "mlir/Dialect/Tosa/IR/TosaOps.h" // from @llvm-project #include "mlir/IR/Matchers.h" // from @llvm-project #include "mlir/IR/Operation.h" #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Support/TypeID.h" #include "mlir/Transforms/DialectConversion.h" // from @llvm-project #include "tosa_serialization_handler.h" #include #include #include #include // The namespace might be confusing here. 
We have mlir::tosa:: defined in MLIR // and tosa:: defined in serialization library // TODO: align the namespace using namespace tosa; namespace cl = llvm::cl; llvm::cl::opt tosa_flatbuffer_filename( "tosa-flatbuffer-filename", llvm::cl::desc(""), llvm::cl::init("tosa_dump.tosa"), llvm::cl::value_desc("filename")); llvm::cl::opt tosa_flatbuffer_schema( "tosa-flatbuffer-schema", llvm::cl::desc(""), llvm::cl::init(""), llvm::cl::value_desc("filename")); // Specialize mlir::Value for std::hash and std::equal_to to be able to // build std::unordered_map namespace std { template <> struct hash { std::size_t operator()(const mlir::Value &val) const { return static_cast(mlir::hash_value(val)); } }; template <> struct equal_to { bool operator()(const mlir::Value &lhs, const mlir::Value &rhs) const { return (lhs == rhs); } }; } // namespace std static ResizeMode ResizeModeStr2Enum(const std::string &mode_str) { if (mode_str == "NEAREST_NEIGHBOR") return ResizeMode_NEAREST; else if (mode_str == "BILINEAR") return ResizeMode_BILINEAR; else return ResizeMode_UNKNOWN; } static DType Type2DType(mlir::Type element_type) { if (element_type.isF64() || element_type.isF32()) { return DType_FP32; } else if (element_type.isFloat8E5M2()) { return DType_FP8E5M2; } else if (element_type.isFloat8E4M3FN()) { return DType_FP8E4M3; } else if (element_type.isF16()) { return DType_FP16; } else if (element_type.isBF16()) { return DType_BF16; } else if (element_type.isUnsignedInteger(8)) { return DType_UINT8; } else if (element_type.isInteger(4)) { return DType_INT4; } else if (element_type.isInteger(8)) { return DType_INT8; } else if (element_type.isInteger(16)) { return DType_INT16; } else if (element_type.isInteger(32)) { return DType_INT32; } else if (element_type.isInteger(48)) { return DType_INT48; } // boolean in MLIR treated as integer with bitwidth 1 else if (element_type.isInteger(1)) { return DType_BOOL; } return DType_UNKNOWN; } static DType Type2AccDType(mlir::Type element_type) { // 
def Tosa_AccType : AnyTypeOf<[I<32>, I<48>, F16, F32]>; if (element_type.isF32()) { return DType_FP32; } else if (element_type.isF16()) { return DType_FP16; } else if (element_type.isInteger(32)) { return DType_INT32; } else if (element_type.isInteger(48)) { return DType_INT48; } return DType_UNKNOWN; } class TosaSerializationBlockBuilder; class TosaSerializationRegionBuilder; std::unordered_map variable_tensor_op_map; std::unordered_map variable_tensor_flatbuffer_name_map; static int variable_tensor_index = 0; namespace { // for now, this is a global map of variables void RegisterVariableOp(mlir::Operation &op) { std::string variable_tensor_flatbuffer_name = "Variable_" + std::to_string(variable_tensor_index++); std::string variable_tensor_mlir_name = op.getAttr("name").cast().getValue().str(); variable_tensor_op_map[variable_tensor_flatbuffer_name] = &op; variable_tensor_flatbuffer_name_map[variable_tensor_mlir_name] = variable_tensor_flatbuffer_name; } } // namespace class TosaSerializationOperatorBuilder { public: TosaSerializationOperatorBuilder( TosaSerializationBlockBuilder *_block_builder) : block_builder(_block_builder) {} template TosaSerializationOperator *build(mlir::Operation &op) const; TosaSerializationHandler *GetTsh() const; TosaSerializationRegionBuilder *GetRegionBuilder() const; mlir::LogicalResult GetDataFromAttribute(mlir::Operation &op, mlir::Attribute &attr, mlir::Type element_type, std::vector &u8_data) const; // populate u8_data with either int64_value or float_value depending on // element_type mlir::LogicalResult GetU8DataFromIntOrFloatValue(int64_t int64_value, float fp_value, mlir::Type element_type, std::vector &u8_data) const; // populate u8_data with int_value depending on non-float element_type mlir::LogicalResult GetU8DataFromIntValues(const std::vector &int_values, mlir::Type element_type, std::vector &u8_data) const; // populate u8_data with fp_value depending on float element_type mlir::LogicalResult 
GetU8DataFromFloatValues(const std::vector &fp_values, mlir::Type element_type, std::vector &u8_data) const; private: std::string GetTensorName(mlir::Value val) const; std::string GetVariableTensorName(mlir::Operation *op) const; TosaSerializationOperator *BuildPoolOpFromMlirOp(mlir::Operation &op, Op opcode) const; TosaSerializationOperator *BuildEwiseBinaryOpFromMlirOp(mlir::Operation &op, Op opcode) const; TosaSerializationOperator *BuildEwiseUnaryOpFromMlirOp(mlir::Operation &op, Op opcode) const; TosaSerializationOperator *BuildReductionOpFromMlirOp(mlir::Operation &op, Op opcode) const; TosaSerializationBlockBuilder *block_builder; }; // This builder assumes each region only has only one block class TosaSerializationBlockBuilder { public: // constructor TosaSerializationBlockBuilder(TosaSerializationBasicBlock *_ser_block, TosaSerializationRegionBuilder *_region_builder, mlir::Block *_block) : ser_block(_ser_block), region_builder(_region_builder), block(_block) {} mlir::LogicalResult BuildAllOpsInBlock(std::vector &return_values); TosaSerializationBasicBlock *GetBlock() const { return ser_block; } TosaSerializationRegionBuilder *GetRegionBuilder() const { return region_builder; } TosaSerializationHandler *GetTsh() const; std::unordered_map &GetTensorMap() { return tensor_map; } private: TosaSerializationOperator *BuildTosaSerializationOperator( const TosaSerializationOperatorBuilder &op_builder, mlir::Operation &op); TosaSerializationTensor * BuildTosaSerializationVariableTensor(mlir::RankedTensorType tensor_type, const std::string &name, const std::string &variable_mlir_name); TosaSerializationTensor * BuildTosaSerializationTensor(mlir::Value val, const std::string &name); TosaSerializationBasicBlock *ser_block; TosaSerializationRegionBuilder *region_builder; mlir::Block *block; std::unordered_map tensor_map; std::unordered_map input_tensor_map; }; class TosaSerializationRegionBuilder { public: // Constructor TosaSerializationRegionBuilder( 
TosaSerializationRegion *_ser_region, mlir::Region *_region, TosaSerializationRegionBuilder *_parent_value_scope, TosaSerializationHandler *_tsh) : ser_region(_ser_region), region(_region), parent_value_scope(_parent_value_scope), tsh(_tsh) {} TosaSerializationHandler *GetTsh() const { return tsh; } mlir::LogicalResult BuildAllBlocksInRegion(bool is_top, std::vector &return_values); TosaSerializationRegionBuilder *GetParentValueScope() const { return parent_value_scope; } std::vector &GetBlockBuilders() { return block_builders; } private: TosaSerializationRegion *ser_region; mlir::Region *region; TosaSerializationRegionBuilder *parent_value_scope; TosaSerializationHandler *tsh; std::vector block_builders; }; TosaSerializationHandler *TosaSerializationOperatorBuilder::GetTsh() const { return block_builder->GetTsh(); } TosaSerializationHandler *TosaSerializationBlockBuilder::GetTsh() const { return region_builder->GetTsh(); } TosaSerializationRegionBuilder * TosaSerializationOperatorBuilder::GetRegionBuilder() const { return block_builder->GetRegionBuilder(); } std::string TosaSerializationOperatorBuilder::GetTensorName(mlir::Value val) const { auto value_scope = GetRegionBuilder(); while (value_scope) { // Traverse through each block builder in the region for (auto curr_block_builder : value_scope->GetBlockBuilders()) { const auto &tensor_map = curr_block_builder->GetTensorMap(); if (tensor_map.count(val)) { return tensor_map.at(val); } } value_scope = value_scope->GetParentValueScope(); } // Didn't find anything llvm::errs() << "ERROR: Failed to get mlir::Value from tensor_map\n"; assert(0); } // Unpack 64-bit integer attribute element and pack into a std vector. template static std::vector getDenseI64ArrayAttr(mlir::Attribute attr) { auto array_ref = attr.cast().asArrayRef(); std::vector vec; for (auto val : array_ref) { vec.push_back(val); } return vec; } // Unpack 8-bit integer attribute element and pack into a std vector. 
template static std::vector getDenseI8ArrayAttr(mlir::Attribute attr) { auto array_ref = attr.cast().asArrayRef(); std::vector vec; for (auto val : array_ref) { vec.push_back(val); } return vec; } std::string TosaSerializationOperatorBuilder::GetVariableTensorName( mlir::Operation *op) const { std::string variable_tensor_mlir_name = op->getAttr("name").cast().getValue().str(); if (variable_tensor_flatbuffer_name_map.find(variable_tensor_mlir_name) == variable_tensor_flatbuffer_name_map.end()) { llvm::errs() << "ERROR: Failed to find key " << variable_tensor_mlir_name << " from variable_tensor_flatbuffer_name_map\n"; assert(0); } return variable_tensor_flatbuffer_name_map[variable_tensor_mlir_name]; } mlir::LogicalResult TosaSerializationOperatorBuilder::GetDataFromAttribute( mlir::Operation &op, mlir::Attribute &attr, mlir::Type element_type, std::vector &u8_data) const { if (!element_type.isIntOrFloat()) { return mlir::failure(); } auto dense_attr = attr.dyn_cast(); // handle float types if (element_type.isa()) { std::vector fp_data; auto val_attr = attr.dyn_cast(); if (dense_attr) { for (auto val : dense_attr.getValues()) { fp_data.push_back(val.convertToFloat()); } } else if (val_attr) { fp_data.push_back((float)val_attr.getValueAsDouble()); } else { op.emitOpError("Unknown const attribute"); return mlir::failure(); } return GetU8DataFromFloatValues(fp_data, element_type, u8_data); } // element_type is integer type bool isInt48 = element_type.isInteger(48); std::vector i64_data; auto val_attr = attr.dyn_cast(); if (dense_attr) { for (auto valueIt : dense_attr.getValues()) { int64_t val = isInt48 ? 
static_cast(valueIt.getLimitedValue()) : valueIt.getSExtValue(); i64_data.push_back(val); } } else if (val_attr) { i64_data.push_back(val_attr.getInt()); } else { op.emitOpError("Unknown const attribute"); return mlir::failure(); } return GetU8DataFromIntValues(i64_data, element_type, u8_data); } mlir::LogicalResult TosaSerializationOperatorBuilder::GetU8DataFromIntValues( const std::vector &int64_values, mlir::Type element_type, std::vector &u8_data) const { switch (element_type.getIntOrFloatBitWidth()) { case 1: { // bool use bool vec std::vector bool_values; for (auto v : int64_values) { bool bool_value = v == 0 ? false : true; bool_values.push_back(bool_value); } TosaSerializationHandler::ConvertBooltoU8(bool_values, u8_data); break; } case 4: case 8: { // I4 and I8 use int8_t vec std::vector i8_values; for (auto v : int64_values) { i8_values.push_back(static_cast(v)); } if (element_type.isInteger(4)) { TosaSerializationHandler::ConvertI4toU8(i8_values, u8_data); } else { TosaSerializationHandler::ConvertI8toU8(i8_values, u8_data); } break; } case 16: { // I16 use int16_t vec std::vector i16_values; for (auto v : int64_values) { i16_values.push_back(static_cast(v)); } TosaSerializationHandler::ConvertI16toU8(i16_values, u8_data); break; } case 32: { // I32 use int32_t vec std::vector i32_values; for (auto v : int64_values) { i32_values.push_back(static_cast(v)); } TosaSerializationHandler::ConvertI32toU8(i32_values, u8_data); break; } case 48: { // I48 use int64_t vec TosaSerializationHandler::ConvertI48toU8(int64_values, u8_data); break; } default: { // unsupported bit widths return mlir::failure(); } } return mlir::success(); } mlir::LogicalResult TosaSerializationOperatorBuilder::GetU8DataFromFloatValues( const std::vector &fp_values, mlir::Type element_type, std::vector &u8_data) const { assert( element_type .isa()); // this should only be called for float type if (element_type.isF16()) { TosaSerializationHandler::ConvertF16toU8(fp_values, u8_data); } else 
if (element_type.isBF16()) { TosaSerializationHandler::ConvertBF16toU8(fp_values, u8_data); } else if (element_type.isFloat8E4M3FN()) { TosaSerializationHandler::ConvertFP8E4M3toU8(fp_values, u8_data); } else if (element_type.isFloat8E5M2()) { TosaSerializationHandler::ConvertFP8E5M2toU8(fp_values, u8_data); } else if (element_type.isF32()) { TosaSerializationHandler::ConvertF32toU8(fp_values, u8_data); } else { return mlir::failure(); } return mlir::success(); } mlir::LogicalResult TosaSerializationOperatorBuilder::GetU8DataFromIntOrFloatValue( int64_t int64_value, float fp_value, mlir::Type element_type, std::vector &u8_data) const { if (element_type.isa()) { return GetU8DataFromFloatValues({fp_value}, element_type, u8_data); } else { return GetU8DataFromIntValues({int64_value}, element_type, u8_data); } } // Main template to catch unimplemented translation. template TosaSerializationOperator * TosaSerializationOperatorBuilder::build(mlir::Operation &op) const { llvm::errs() << "Translation of operator " << op.getName().getStringRef() << " is not implemented yet\n"; return nullptr; } /* Start translating TOSA operator */ #define ASSERT_VECTOR_LENGTH(VECTOR, LENGTH) \ if (VECTOR.size() != LENGTH) { \ std::string msg; \ msg = std::string(#VECTOR) + " is [" + std::to_string(VECTOR.size()) + \ "], but expected to be [" + std::to_string(LENGTH) + "]\n"; \ op.emitOpError(msg.c_str()); \ return nullptr; \ } TosaSerializationOperator * TosaSerializationOperatorBuilder::BuildPoolOpFromMlirOp(mlir::Operation &op, Op opcode) const { auto pad = getDenseI64ArrayAttr(op.getAttr("pad")); ASSERT_VECTOR_LENGTH(pad, 4); auto stride = getDenseI64ArrayAttr(op.getAttr("stride")); ASSERT_VECTOR_LENGTH(stride, 2); auto kernel = getDenseI64ArrayAttr(op.getAttr("kernel")); ASSERT_VECTOR_LENGTH(kernel, 2); DType acc_dtype = DType_FP32; // AvgPool has acc_type, MaxPool does not if (op.hasAttr("acc_type")) { auto acc_type = op.getAttr("acc_type").cast().getValue(); acc_dtype = 
Type2AccDType(acc_type); } std::string input_name = GetTensorName(op.getOperand(0)); std::string output_name = GetTensorName(op.getResult(0)); int32_t input_zp = op.hasAttr("input_zp") ? input_zp = op.getAttr("input_zp").cast().getInt() : 0; int32_t output_zp = op.hasAttr("output_zp") ? output_zp = op.getAttr("output_zp").cast().getInt() : 0; TosaPoolAttribute attribute(pad, kernel, stride, input_zp, output_zp, acc_dtype); TosaSerializationOperator *tyop = new TosaSerializationOperator(opcode, Attribute_PoolAttribute, &attribute, std::vector{input_name}, std::vector{output_name}); return tyop; } TosaSerializationOperator * TosaSerializationOperatorBuilder::BuildEwiseBinaryOpFromMlirOp( mlir::Operation &op, Op opcode) const { std::string input0_name = GetTensorName(op.getOperand(0)); std::string input1_name = GetTensorName(op.getOperand(1)); std::string output_name = GetTensorName(op.getResult(0)); TosaSerializationOperator *tyop = new TosaSerializationOperator( opcode, Attribute_NONE, nullptr, std::vector{input0_name, input1_name}, std::vector{output_name}); return tyop; } TosaSerializationOperator * TosaSerializationOperatorBuilder::BuildEwiseUnaryOpFromMlirOp( mlir::Operation &op, Op opcode) const { std::string input_name = GetTensorName(op.getOperand(0)); std::string output_name = GetTensorName(op.getResult(0)); TosaSerializationOperator *tyop = new TosaSerializationOperator( opcode, Attribute_NONE, nullptr, std::vector{input_name}, std::vector{output_name}); return tyop; } TosaSerializationOperator * TosaSerializationOperatorBuilder::BuildReductionOpFromMlirOp( mlir::Operation &op, Op opcode) const { std::string input_name = GetTensorName(op.getOperand(0)); std::string output_name = GetTensorName(op.getResult(0)); int32_t axis = op.getAttr("axis").dyn_cast().getInt(); TosaAxisAttribute attribute(axis); TosaSerializationOperator *tyop = new TosaSerializationOperator(opcode, Attribute_AxisAttribute, &attribute, std::vector{input_name}, std::vector{output_name}); 
return tyop; } #define BUILD_OP_POOL2D(MLIR_OP_NAME, SCHEMA_OP_NAME) \ template <> \ TosaSerializationOperator * \ TosaSerializationOperatorBuilder::build( \ mlir::Operation & op) const { \ return BuildPoolOpFromMlirOp(op, Op_##SCHEMA_OP_NAME); \ } #define BUILD_OP_ELEMENTWISE_BINARY(MLIR_OP_NAME, SCHEMA_OP_NAME) \ template <> \ TosaSerializationOperator * \ TosaSerializationOperatorBuilder::build( \ mlir::Operation & op) const { \ return BuildEwiseBinaryOpFromMlirOp(op, Op_##SCHEMA_OP_NAME); \ } #define BUILD_OP_ELEMENTWISE_UNARY(MLIR_OP_NAME, SCHEMA_OP_NAME) \ template <> \ TosaSerializationOperator * \ TosaSerializationOperatorBuilder::build( \ mlir::Operation & op) const { \ return BuildEwiseUnaryOpFromMlirOp(op, Op_##SCHEMA_OP_NAME); \ } #define BUILD_OP_REDUCTION(MLIR_OP_NAME, SCHEMA_OP_NAME) \ template <> \ TosaSerializationOperator * \ TosaSerializationOperatorBuilder::build( \ mlir::Operation & op) const { \ return BuildReductionOpFromMlirOp(op, Op_##SCHEMA_OP_NAME); \ } BUILD_OP_POOL2D(MaxPool2d, MAX_POOL2D) BUILD_OP_POOL2D(AvgPool2d, AVG_POOL2D) BUILD_OP_ELEMENTWISE_BINARY(Add, ADD) BUILD_OP_ELEMENTWISE_BINARY(BitwiseAnd, BITWISE_AND) BUILD_OP_ELEMENTWISE_BINARY(BitwiseXor, BITWISE_XOR) BUILD_OP_ELEMENTWISE_BINARY(BitwiseOr, BITWISE_OR) BUILD_OP_ELEMENTWISE_BINARY(IntDiv, INTDIV) BUILD_OP_ELEMENTWISE_BINARY(LogicalAnd, LOGICAL_AND) BUILD_OP_ELEMENTWISE_BINARY(LogicalLeftShift, LOGICAL_LEFT_SHIFT) BUILD_OP_ELEMENTWISE_BINARY(LogicalRightShift, LOGICAL_RIGHT_SHIFT) BUILD_OP_ELEMENTWISE_BINARY(LogicalOr, LOGICAL_OR) BUILD_OP_ELEMENTWISE_BINARY(LogicalXor, LOGICAL_XOR) BUILD_OP_ELEMENTWISE_BINARY(Maximum, MAXIMUM) BUILD_OP_ELEMENTWISE_BINARY(Minimum, MINIMUM) BUILD_OP_ELEMENTWISE_BINARY(Pow, POW) BUILD_OP_ELEMENTWISE_BINARY(Sub, SUB) BUILD_OP_ELEMENTWISE_UNARY(Abs, ABS) BUILD_OP_ELEMENTWISE_UNARY(BitwiseNot, BITWISE_NOT) BUILD_OP_ELEMENTWISE_UNARY(Ceil, CEIL) BUILD_OP_ELEMENTWISE_UNARY(Clz, CLZ) BUILD_OP_ELEMENTWISE_UNARY(Cos, COS) 
BUILD_OP_ELEMENTWISE_UNARY(Exp, EXP)
BUILD_OP_ELEMENTWISE_UNARY(Floor, FLOOR)
BUILD_OP_ELEMENTWISE_UNARY(Log, LOG)
BUILD_OP_ELEMENTWISE_UNARY(LogicalNot, LOGICAL_NOT)
BUILD_OP_ELEMENTWISE_UNARY(Reciprocal, RECIPROCAL)
BUILD_OP_ELEMENTWISE_UNARY(Rsqrt, RSQRT)
BUILD_OP_ELEMENTWISE_UNARY(Sin, SIN)

BUILD_OP_REDUCTION(ReduceAny, REDUCE_ANY)
BUILD_OP_REDUCTION(ReduceAll, REDUCE_ALL)
BUILD_OP_REDUCTION(ReduceMax, REDUCE_MAX)
BUILD_OP_REDUCTION(ReduceMin, REDUCE_MIN)
BUILD_OP_REDUCTION(ReduceProd, REDUCE_PRODUCT)
BUILD_OP_REDUCTION(ReduceSum, REDUCE_SUM)

BUILD_OP_ELEMENTWISE_BINARY(Equal, EQUAL)
BUILD_OP_ELEMENTWISE_BINARY(Greater, GREATER)
BUILD_OP_ELEMENTWISE_BINARY(GreaterEqual, GREATER_EQUAL)

BUILD_OP_ELEMENTWISE_UNARY(Erf, ERF)
BUILD_OP_ELEMENTWISE_UNARY(Sigmoid, SIGMOID)
BUILD_OP_ELEMENTWISE_UNARY(Tanh, TANH)
BUILD_OP_ELEMENTWISE_UNARY(Identity, IDENTITY)
BUILD_OP_ELEMENTWISE_UNARY(Cast, CAST)

BUILD_OP_ELEMENTWISE_BINARY(AddShape, ADD_SHAPE)
BUILD_OP_ELEMENTWISE_BINARY(SubShape, SUB_SHAPE)
BUILD_OP_ELEMENTWISE_BINARY(MulShape, MUL_SHAPE)
BUILD_OP_ELEMENTWISE_BINARY(DivShape, DIV_SHAPE)

// CONST_SHAPE builder.
// Writes the "value" attribute's integers, sign-extended to 64 bits, into the
// already-built output tensor, which must have DType_SHAPE. Emits an operator
// with no inputs.
template <>
TosaSerializationOperator *
TosaSerializationOperatorBuilder::build(
    mlir::Operation &op) const {
  std::string output_name = GetTensorName(op.getResult(0));
  TosaSerializationTensor *ts =
      block_builder->GetBlock()->GetTensorByName(output_name);
  if (!ts) {
    op.emitOpError(
        "ERROR: serialization tensor must be built before building operator");
    return nullptr;
  }

  // Update tensor.data array with Const value attribute
  mlir::Attribute value_attr = op.getAttr("value");
  if (!value_attr) {
    op.emitOpError("ERROR: tosa.const_shape doesn't have value");
    return nullptr;
  }

  assert(ts->GetDtype() == DType::DType_SHAPE);
  std::vector u8_data;

  std::vector data;
  // Reuse the attribute fetched above instead of looking "value" up again.
  auto dense_attr = value_attr.dyn_cast();
  if (!dense_attr) {
    op.emitOpError("Unknown const attribute");
    return nullptr;
  }

  for (auto valueIt : dense_attr.getValues()) {
    int64_t val = valueIt.getSExtValue();
    data.push_back(val);
  }

  TosaSerializationHandler::ConvertI64toU8(data, u8_data);

  ts->SetData(u8_data);

  TosaSerializationOperator *tyop = new TosaSerializationOperator(
      Op_CONST_SHAPE, Attribute_NONE, nullptr, std::vector{},
      std::vector{output_name});

  return tyop;
}

// CONST builder.
// Serializes the "value" attribute, using the result's element type, into
// the already-built output tensor. Emits an operator with no inputs.
template <>
TosaSerializationOperator *
TosaSerializationOperatorBuilder::build(
    mlir::Operation &op) const {
  std::string output_name = GetTensorName(op.getResult(0));
  TosaSerializationTensor *ts =
      block_builder->GetBlock()->GetTensorByName(output_name);
  if (!ts) {
    op.emitOpError(
        "ERROR: serialization tensor must be built before building operator");
    return nullptr;
  }

  // Update tensor.data array with Const value attribute
  mlir::Attribute value_attr = op.getAttr("value");
  if (!value_attr) {
    op.emitOpError("ERROR: tosa.const doesn't have value");
    return nullptr;
  }

  std::vector u8_data;

  mlir::Type element_type =
      llvm::cast(op.getResult(0).getType()).getElementType();

  // Pass the attribute fetched above rather than looking "value" up again.
  if (GetDataFromAttribute(op, value_attr, element_type, u8_data).failed()) {
    op.emitOpError("ERROR: GetDataFromAttribute() fails when building value of "
                   "const tensor");
    return nullptr;
  }

  ts->SetData(u8_data);

  TosaSerializationOperator *tyop = new TosaSerializationOperator(
      Op_CONST, Attribute_NONE, nullptr, std::vector{},
      std::vector{output_name});

  return tyop;
}

// CONV2D builder.
// Requires pad (4), stride (2), dilation (2) and acc_type. input_zp and
// weight_zp default to 0; local_bound defaults to false. Operands are
// input, weight and bias.
template <>
TosaSerializationOperator *
TosaSerializationOperatorBuilder::build(
    mlir::Operation &op) const {
  auto pad = getDenseI64ArrayAttr(op.getAttr("pad"));
  ASSERT_VECTOR_LENGTH(pad, 4);

  auto stride = getDenseI64ArrayAttr(op.getAttr("stride"));
  ASSERT_VECTOR_LENGTH(stride, 2);

  auto dilation = getDenseI64ArrayAttr(op.getAttr("dilation"));
  ASSERT_VECTOR_LENGTH(dilation, 2);

  std::string input0_name = GetTensorName(op.getOperand(0));
  std::string input1_name = GetTensorName(op.getOperand(1));
  std::string input2_name = GetTensorName(op.getOperand(2));
  std::string output_name = GetTensorName(op.getResult(0));

  int32_t input_zp =
      op.hasAttr("input_zp") ? op.getAttr("input_zp").cast().getInt() : 0;
  int32_t weight_zp =
      op.hasAttr("weight_zp") ? op.getAttr("weight_zp").cast().getInt() : 0;

  bool local_bound =
      op.hasAttr("local_bound")
          ? op.getAttr("local_bound").dyn_cast().getValue()
          : false;
  auto acc_type = op.getAttr("acc_type").cast().getValue();
  auto acc_dtype = Type2AccDType(acc_type);

  TosaConvAttribute attribute(pad, stride, dilation, input_zp, weight_zp,
                              local_bound, acc_dtype);

  TosaSerializationOperator *tyop = new TosaSerializationOperator(
      Op_CONV2D, Attribute_ConvAttribute, &attribute,
      std::vector{input0_name, input1_name, input2_name},
      std::vector{output_name});

  return tyop;
}

// CONV3D builder.
// Same as CONV2D but with pad (6), stride (3) and dilation (3).
template <>
TosaSerializationOperator *
TosaSerializationOperatorBuilder::build(
    mlir::Operation &op) const {
  auto pad = getDenseI64ArrayAttr(op.getAttr("pad"));
  ASSERT_VECTOR_LENGTH(pad, 6);

  auto stride = getDenseI64ArrayAttr(op.getAttr("stride"));
  ASSERT_VECTOR_LENGTH(stride, 3);

  auto dilation = getDenseI64ArrayAttr(op.getAttr("dilation"));
  ASSERT_VECTOR_LENGTH(dilation, 3);

  std::string input0_name = GetTensorName(op.getOperand(0));
  std::string input1_name = GetTensorName(op.getOperand(1));
  std::string input2_name = GetTensorName(op.getOperand(2));
  std::string output_name = GetTensorName(op.getResult(0));

  int32_t input_zp =
      op.hasAttr("input_zp") ? op.getAttr("input_zp").cast().getInt() : 0;
  int32_t weight_zp =
      op.hasAttr("weight_zp") ? op.getAttr("weight_zp").cast().getInt() : 0;

  bool local_bound =
      op.hasAttr("local_bound")
          ? op.getAttr("local_bound").dyn_cast().getValue()
          : false;
  auto acc_type = op.getAttr("acc_type").cast().getValue();
  auto acc_dtype = Type2AccDType(acc_type);

  TosaConvAttribute attribute(pad, stride, dilation, input_zp, weight_zp,
                              local_bound, acc_dtype);

  TosaSerializationOperator *tyop = new TosaSerializationOperator(
      Op_CONV3D, Attribute_ConvAttribute, &attribute,
      std::vector{input0_name, input1_name, input2_name},
      std::vector{output_name});

  return tyop;
}

// DEPTHWISE_CONV2D builder.
// Same attribute layout as CONV2D: pad (4), stride (2), dilation (2).
template <>
TosaSerializationOperator *
TosaSerializationOperatorBuilder::build(
    mlir::Operation &op) const {
  auto pad = getDenseI64ArrayAttr(op.getAttr("pad"));
  ASSERT_VECTOR_LENGTH(pad, 4);

  auto stride = getDenseI64ArrayAttr(op.getAttr("stride"));
  ASSERT_VECTOR_LENGTH(stride, 2);

  auto dilation = getDenseI64ArrayAttr(op.getAttr("dilation"));
  ASSERT_VECTOR_LENGTH(dilation, 2);

  std::string input0_name = GetTensorName(op.getOperand(0));
  std::string input1_name = GetTensorName(op.getOperand(1));
  std::string input2_name = GetTensorName(op.getOperand(2));
  std::string output_name = GetTensorName(op.getResult(0));

  int32_t input_zp =
      op.hasAttr("input_zp") ? op.getAttr("input_zp").cast().getInt() : 0;
  int32_t weight_zp =
      op.hasAttr("weight_zp") ? op.getAttr("weight_zp").cast().getInt() : 0;

  bool local_bound =
      op.hasAttr("local_bound")
          ? op.getAttr("local_bound").dyn_cast().getValue()
          : false;
  auto acc_type = op.getAttr("acc_type").cast().getValue();
  auto acc_dtype = Type2AccDType(acc_type);

  TosaConvAttribute attribute(pad, stride, dilation, input_zp, weight_zp,
                              local_bound, acc_dtype);

  TosaSerializationOperator *tyop = new TosaSerializationOperator(
      Op_DEPTHWISE_CONV2D, Attribute_ConvAttribute, &attribute,
      std::vector{input0_name, input1_name, input2_name},
      std::vector{output_name});

  return tyop;
}

// TRANSPOSE_CONV2D builder.
// Requires out_pad (4), stride (2) and acc_type. Zero points default to 0;
// local_bound defaults to false. The unused operand-type local was removed.
template <>
TosaSerializationOperator *
TosaSerializationOperatorBuilder::build(
    mlir::Operation &op) const {
  auto out_pad = getDenseI64ArrayAttr(op.getAttr("out_pad"));
  ASSERT_VECTOR_LENGTH(out_pad, 4);

  auto stride = getDenseI64ArrayAttr(op.getAttr("stride"));
  ASSERT_VECTOR_LENGTH(stride, 2);

  std::string input0_name = GetTensorName(op.getOperand(0));
  std::string input1_name = GetTensorName(op.getOperand(1));
  std::string input2_name = GetTensorName(op.getOperand(2));
  std::string output_name = GetTensorName(op.getResult(0));

  int32_t input_zp =
      op.hasAttr("input_zp") ? op.getAttr("input_zp").cast().getInt() : 0;
  int32_t weight_zp =
      op.hasAttr("weight_zp") ? op.getAttr("weight_zp").cast().getInt() : 0;

  bool local_bound =
      op.hasAttr("local_bound")
          ? op.getAttr("local_bound").dyn_cast().getValue()
          : false;
  auto acc_type = op.getAttr("acc_type").cast().getValue();
  auto acc_dtype = Type2AccDType(acc_type);

  TosaTransposeConvAttribute attribute(out_pad, stride, input_zp, weight_zp,
                                       local_bound, acc_dtype);

  TosaSerializationOperator *tyop = new TosaSerializationOperator(
      Op_TRANSPOSE_CONV2D, Attribute_TransposeConvAttribute, &attribute,
      std::vector{input0_name, input1_name, input2_name},
      std::vector{output_name});

  return tyop;
}

// FULLY_CONNECTED builder: three operands. input_zp and weight_zp default
// to 0.
template <>
TosaSerializationOperator *
TosaSerializationOperatorBuilder::build(
    mlir::Operation &op) const {
  std::string input0_name = GetTensorName(op.getOperand(0));
  std::string input1_name = GetTensorName(op.getOperand(1));
  std::string input2_name = GetTensorName(op.getOperand(2));
  std::string output_name = GetTensorName(op.getResult(0));

  int32_t input_zp =
      op.hasAttr("input_zp") ? op.getAttr("input_zp").cast().getInt() : 0;
  int32_t weight_zp =
      op.hasAttr("weight_zp") ? op.getAttr("weight_zp").cast().getInt() : 0;

  TosaFullyConnectedAttribute attribute(input_zp, weight_zp);

  TosaSerializationOperator *tyop = new TosaSerializationOperator(
      Op_FULLY_CONNECTED, Attribute_FullyConnectedAttribute, &attribute,
      std::vector{input0_name, input1_name, input2_name},
      std::vector{output_name});

  return tyop;
}

// MATMUL builder: two operands. a_zp and b_zp default to 0.
template <>
TosaSerializationOperator *
TosaSerializationOperatorBuilder::build(
    mlir::Operation &op) const {
  std::string input0_name = GetTensorName(op.getOperand(0));
  std::string input1_name = GetTensorName(op.getOperand(1));
  std::string output_name = GetTensorName(op.getResult(0));

  int32_t A_zp = op.hasAttr("a_zp") ? op.getAttr("a_zp").cast().getInt() : 0;
  int32_t B_zp = op.hasAttr("b_zp") ? op.getAttr("b_zp").cast().getInt() : 0;

  TosaMatMulAttribute attribute(A_zp, B_zp);

  TosaSerializationOperator *tyop = new TosaSerializationOperator(
      Op_MATMUL, Attribute_MatMulAttribute, &attribute,
      std::vector{input0_name, input1_name},
      std::vector{output_name});

  return tyop;
}

// SELECT builder: three operands, no attribute.
template <>
TosaSerializationOperator *
TosaSerializationOperatorBuilder::build(
    mlir::Operation &op) const {
  std::string input0_name = GetTensorName(op.getOperand(0));
  std::string input1_name = GetTensorName(op.getOperand(1));
  std::string input2_name = GetTensorName(op.getOperand(2));
  std::string output_name = GetTensorName(op.getResult(0));

  TosaSerializationOperator *tyop = new TosaSerializationOperator(
      Op_SELECT, Attribute_NONE, nullptr,
      std::vector{input0_name, input1_name, input2_name},
      std::vector{output_name});

  return tyop;
}

// CLAMP builder.
// min_val and max_val are serialized as raw bytes in the input's element
// type. If that type casts to a type with a storage type (presumably a
// quantized type), the storage type is used instead.
template <>
TosaSerializationOperator *
TosaSerializationOperatorBuilder::build(
    mlir::Operation &op) const {
  auto min_val_attr = op.getAttr("min_val");
  auto max_val_attr = op.getAttr("max_val");

  mlir::Type input_element_type =
      llvm::cast(op.getOperand(0).getType()).getElementType();
  if (auto quantType = llvm::dyn_cast(input_element_type)) {
    input_element_type = quantType.getStorageType();
  }

  std::vector min_val, max_val;
  float min_fp, max_fp;
  int64_t min_int, max_int;

  if (input_element_type.isa()) {
    min_fp = mlir::cast(min_val_attr).getValue().convertToFloat();
    max_fp = mlir::cast(max_val_attr).getValue().convertToFloat();
    min_int = max_int = 0;
  } else {
    assert(input_element_type.isa());
    min_int = mlir::cast(min_val_attr).getInt();
    max_int = mlir::cast(max_val_attr).getInt();
    min_fp = max_fp = 0.f;
  }

  if (GetU8DataFromIntOrFloatValue(min_int, min_fp, input_element_type,
                                   min_val)
          .failed()) {
    op.emitOpError("Failed to serialize min value");
    return nullptr;
  }
  if (GetU8DataFromIntOrFloatValue(max_int, max_fp, input_element_type,
                                   max_val)
          .failed()) {
    op.emitOpError("Failed to serialize max value");
    return nullptr;
  }

  std::string input_name = GetTensorName(op.getOperand(0));
  std::string output_name = GetTensorName(op.getResult(0));

  TosaClampAttribute attribute(min_val, max_val);

  TosaSerializationOperator *tyop = new TosaSerializationOperator(
      Op_CLAMP, Attribute_ClampAttribute, &attribute,
      std::vector{input_name}, std::vector{output_name});

  return tyop;
}

// ARGMAX builder: the integer "axis" attribute becomes a TosaAxisAttribute.
template <>
TosaSerializationOperator *
TosaSerializationOperatorBuilder::build(
    mlir::Operation &op) const {
  int32_t axis = op.getAttr("axis").dyn_cast().getInt();

  std::string input_name = GetTensorName(op.getOperand(0));
  std::string output_name = GetTensorName(op.getResult(0));

  TosaAxisAttribute attribute(axis);

  TosaSerializationOperator *tyop = new TosaSerializationOperator(
      Op_ARGMAX, Attribute_AxisAttribute, &attribute,
      std::vector{input_name}, std::vector{output_name});

  return tyop;
}

// CONCAT builder: every operand is an input. The "axis" attribute selects
// the concatenation axis.
template <>
TosaSerializationOperator *
TosaSerializationOperatorBuilder::build(
    mlir::Operation &op) const {
  int32_t axis = op.getAttr("axis").dyn_cast().getInt();

  std::vector inputs;
  for (mlir::Value operand : op.getOperands()) {
    inputs.push_back(GetTensorName(operand));
  }

  std::string output_name = GetTensorName(op.getResult(0));

  TosaAxisAttribute attribute(axis);

  TosaSerializationOperator *tyop = new TosaSerializationOperator(
      Op_CONCAT, Attribute_AxisAttribute, &attribute, inputs,
      std::vector{output_name});

  return tyop;
}

// CONCAT_SHAPE builder: every operand is an input, with no attribute.
template <>
TosaSerializationOperator *
TosaSerializationOperatorBuilder::build(
    mlir::Operation &op) const {
  std::vector inputs;
  for (mlir::Value operand : op.getOperands()) {
    inputs.push_back(GetTensorName(operand));
  }

  std::string output_name = GetTensorName(op.getResult(0));

  TosaSerializationOperator *tyop = new TosaSerializationOperator(
      Op_CONCAT_SHAPE, Attribute_NONE, nullptr, inputs,
      std::vector{output_name});

  return tyop;
}

// NEGATE builder: input1_zp and output_zp default to 0.
template <>
TosaSerializationOperator *
TosaSerializationOperatorBuilder::build(
    mlir::Operation &op) const {
  std::string input_name = GetTensorName(op.getOperand(0));
  std::string output_name = GetTensorName(op.getResult(0));

  int32_t input1_zp =
      op.hasAttr("input1_zp") ? op.getAttr("input1_zp").cast().getInt() : 0;
  int32_t output_zp =
      op.hasAttr("output_zp") ? op.getAttr("output_zp").cast().getInt() : 0;

  TosaNegateAttribute attribute(input1_zp, output_zp);

  TosaSerializationOperator *tyop = new TosaSerializationOperator(
      Op_NEGATE, Attribute_NegateAttribute, &attribute,
      std::vector{input_name}, std::vector{output_name});

  return tyop;
}

// RESHAPE builder: the target shape is passed as the second operand.
template <>
TosaSerializationOperator *
TosaSerializationOperatorBuilder::build(
    mlir::Operation &op) const {
  std::string input_name = GetTensorName(op.getOperand(0));
  std::string shape_name = GetTensorName(op.getOperand(1));
  std::string output_name = GetTensorName(op.getResult(0));

  TosaSerializationOperator *tyop = new TosaSerializationOperator(
      Op_RESHAPE, Attribute_NONE, nullptr,
      std::vector{input_name, shape_name},
      std::vector{output_name});

  return tyop;
}

// PAD builder.
// pad_const starts as the input zero point (0 if absent) for integer types
// and 0.0 for float types. If a pad_const operand is present, it must fold to
// a single-element constant whose value overrides the default. The value is
// then serialized in the input's element type.
template <>
TosaSerializationOperator *
TosaSerializationOperatorBuilder::build(
    mlir::Operation &op) const {
  std::string input_name = GetTensorName(op.getOperand(0));
  std::string padding_name = GetTensorName(op.getOperand(1));
  std::string output_name = GetTensorName(op.getResult(0));
  auto pad_op = llvm::cast(op);

  auto input_zp_attr = pad_op.getInputZpAttr();
  // pad_const includes the zero point if the tensor uses a zero point.
  int32_t pad_const_int = input_zp_attr ? input_zp_attr.getInt() : 0;
  float pad_const_fp = 0.f;

  if (auto tensor = pad_op.getPadConst()) {
    // Match pad_const tensor as compile-time constant attribute if present.
    mlir::DenseElementsAttr attr;
    if (!matchPattern(tensor, m_Constant(&attr))) {
      // Report the reason instead of failing silently.
      op.emitOpError("pad_const must be a compile-time constant");
      return nullptr;
    }

    assert(attr.getNumElements() == 1);
    auto elementTy = attr.getElementType();
    if (elementTy.isa()) {
      pad_const_int = (attr.getValues()[0]).getSExtValue();
    } else if (elementTy.isa()) {
      pad_const_fp = (attr.getValues()[0]).convertToFloat();
    } else {
      op.emitOpError("Unknown const attribute");
      return nullptr;
    }
  }

  std::vector pad_const;
  mlir::Type input_element_type =
      llvm::cast(op.getOperand(0).getType()).getElementType();

  if (GetU8DataFromIntOrFloatValue(pad_const_int, pad_const_fp,
                                   input_element_type, pad_const)
          .failed()) {
    op.emitOpError("Failed to serialize pad_const value");
    return nullptr;
  }

  TosaPadAttribute attribute(pad_const);

  TosaSerializationOperator *tyop = new TosaSerializationOperator(
      Op_PAD, Attribute_PadAttribute, &attribute,
      std::vector{input_name, padding_name},
      std::vector{output_name});

  return tyop;
}

// DIM builder: the integer "axis" attribute becomes a TosaAxisAttribute.
template <>
TosaSerializationOperator *
TosaSerializationOperatorBuilder::build(
    mlir::Operation &op) const {
  std::string input_name = GetTensorName(op.getOperand(0));
  std::string output_name = GetTensorName(op.getResult(0));

  int32_t axis = op.getAttr("axis").dyn_cast().getInt();
  TosaAxisAttribute attribute(axis);

  TosaSerializationOperator *tyop =
      new TosaSerializationOperator(Op_DIM, Attribute_AxisAttribute, &attribute,
                                    std::vector{input_name},
                                    std::vector{output_name});

  return tyop;
}

template <>
TosaSerializationOperator *
TosaSerializationOperatorBuilder::build(
    mlir::Operation &op) const {
  std::string input_name = GetTensorName(op.getOperand(0));
  std::string output_name = GetTensorName(op.getResult(0));

  // Match perm tensor as compile-time constant attribute
  mlir::ElementsAttr perm_elems;
  if (!matchPattern(op.getOperand(1), m_Constant(&perm_elems)))
    return nullptr;

  std::vector perm;
  for (auto value : perm_elems.getValues()) {
    perm.push_back(value.getInt());
  }
  TosaTransposeAttribute attribute(perm);

  TosaSerializationOperator *tyop = new
TosaSerializationOperator( Op_TRANSPOSE, Attribute_TransposeAttribute, &attribute, std::vector{input_name}, std::vector{output_name}); return tyop; } template <> TosaSerializationOperator * TosaSerializationOperatorBuilder::build( mlir::Operation &op) const { std::string input_name = GetTensorName(op.getOperand(0)); std::string start_name = GetTensorName(op.getOperand(1)); std::string size_name = GetTensorName(op.getOperand(2)); std::string output_name = GetTensorName(op.getResult(0)); TosaSerializationOperator *tyop = new TosaSerializationOperator( Op_SLICE, Attribute_NONE, nullptr, std::vector{input_name, start_name, size_name}, std::vector{output_name}); return tyop; } template <> TosaSerializationOperator * TosaSerializationOperatorBuilder::build( mlir::Operation &op) const { std::string input0_name = GetTensorName(op.getOperand(0)); std::string input1_name = GetTensorName(op.getOperand(1)); std::string output_name = GetTensorName(op.getResult(0)); TosaSerializationOperator *tyop = new TosaSerializationOperator( Op_TILE, Attribute_NONE, nullptr, std::vector{input0_name, input1_name}, std::vector{output_name}); return tyop; } template <> TosaSerializationOperator * TosaSerializationOperatorBuilder::build( mlir::Operation &op) const { std::string input0_name = GetTensorName(op.getOperand(0)); std::string input1_name = GetTensorName(op.getOperand(1)); std::string output_name = GetTensorName(op.getResult(0)); TosaSerializationOperator *tyop = new TosaSerializationOperator( Op_GATHER, Attribute_NONE, nullptr, std::vector{input0_name, input1_name}, std::vector{output_name}); return tyop; } template <> TosaSerializationOperator * TosaSerializationOperatorBuilder::build( mlir::Operation &op) const { std::string input0_name = GetTensorName(op.getOperand(0)); std::string input1_name = GetTensorName(op.getOperand(1)); std::string input2_name = GetTensorName(op.getOperand(2)); std::string output_name = GetTensorName(op.getResult(0)); TosaSerializationOperator *tyop = new 
TosaSerializationOperator( Op_SCATTER, Attribute_NONE, nullptr, std::vector{input0_name, input1_name, input2_name}, std::vector{output_name}); return tyop; } template <> TosaSerializationOperator * TosaSerializationOperatorBuilder::build( mlir::Operation &op) const { std::string input_name = GetTensorName(op.getOperand(0)); std::string scale_name = GetTensorName(op.getOperand(1)); std::string offset_name = GetTensorName(op.getOperand(2)); std::string border_name = GetTensorName(op.getOperand(3)); std::string output_name = GetTensorName(op.getResult(0)); auto mode_str = op.getAttr("mode").dyn_cast().getValue().str(); ResizeMode mode = ResizeModeStr2Enum(mode_str); TosaResizeAttribute attribute({}, {}, {}, mode); TosaSerializationOperator *tyop = new TosaSerializationOperator( Op_RESIZE, Attribute_ResizeAttribute, &attribute, std::vector{input_name, scale_name, offset_name, border_name}, std::vector{output_name}); return tyop; } template <> TosaSerializationOperator * TosaSerializationOperatorBuilder::build( mlir::Operation &op) const { std::string input_name = GetTensorName(op.getOperand(0)); std::string output_name = GetTensorName(op.getResult(0)); int32_t axis = op.getAttr("axis").dyn_cast().getInt(); TosaAxisAttribute attribute(axis); TosaSerializationOperator *tyop = new TosaSerializationOperator( Op_REVERSE, Attribute_AxisAttribute, &attribute, std::vector{input_name}, std::vector{output_name}); return tyop; } template <> TosaSerializationOperator * TosaSerializationOperatorBuilder::build( mlir::Operation &op) const { mlir::tosa::MulOp mul_op = mlir::cast(op); std::string input0_name = GetTensorName(mul_op.getInput1()); std::string input1_name = GetTensorName(mul_op.getInput2()); std::string output_name = GetTensorName(mul_op.getOutput()); std::string shift_name = GetTensorName(mul_op.getShift()); TosaSerializationOperator *tyop = new TosaSerializationOperator( Op_MUL, Attribute_NONE, nullptr, {input0_name, input1_name, shift_name}, std::vector{output_name}); 
return tyop; } template <> TosaSerializationOperator * TosaSerializationOperatorBuilder::build( mlir::Operation &op) const { std::string input0_name = GetTensorName(op.getOperand(0)); std::string input1_name = GetTensorName(op.getOperand(1)); std::string output_name = GetTensorName(op.getResult(0)); bool round = op.getAttr("round").dyn_cast().getValue(); TosaArithmeticRightShiftAttribute attribute(round); TosaSerializationOperator *tyop = new TosaSerializationOperator( Op_ARITHMETIC_RIGHT_SHIFT, Attribute_ArithmeticRightShiftAttribute, &attribute, std::vector{input0_name, input1_name}, std::vector{output_name}); return tyop; } template <> TosaSerializationOperator * TosaSerializationOperatorBuilder::build( mlir::Operation &op) const { std::string input_name = GetTensorName(op.getOperand(0)); std::string output_name = GetTensorName(op.getResult(0)); // Match table tensor as compile-time constant attribute mlir::ElementsAttr table_elems; if (!matchPattern(op.getOperand(1), m_Constant(&table_elems))) return nullptr; std::vector table; for (auto value : table_elems.getValues()) { table.push_back(value.getInt()); } TosaTableAttribute attribute(table); TosaSerializationOperator *tyop = new TosaSerializationOperator( Op_TABLE, Attribute_TableAttribute, &attribute, std::vector{input_name}, std::vector{output_name}); return tyop; } template <> TosaSerializationOperator * TosaSerializationOperatorBuilder::build( mlir::Operation &op) const { int32_t input_zp = op.getAttr("input_zp").dyn_cast().getInt(); int32_t output_zp = op.getAttr("output_zp").dyn_cast().getInt(); bool scale32 = op.getAttr("scale32").dyn_cast().getValue(); bool double_round = op.getAttr("double_round").dyn_cast().getValue(); bool per_channel = op.getAttr("per_channel").dyn_cast().getValue(); bool input_unsigned = op.getAttr("input_unsigned").dyn_cast().getValue(); bool output_unsigned = op.getAttr("output_unsigned").dyn_cast().getValue(); auto input = op.getOperand(0); auto input_ty = 
input.getType().cast(); auto output = op.getResult(0); auto output_ty = output.getType().cast(); std::string input_name = GetTensorName(input); std::string multiplier_name = GetTensorName(op.getOperand(1)); std::string shift_name = GetTensorName(op.getOperand(2)); std::string output_name = GetTensorName(output); TosaRescaleAttribute attribute(input_zp, output_zp, scale32, double_round, per_channel, input_unsigned, output_unsigned); TosaSerializationOperator *tyop = new TosaSerializationOperator( Op_RESCALE, Attribute_RescaleAttribute, &attribute, std::vector{input_name, multiplier_name, shift_name}, std::vector{output_name}); return tyop; } template <> TosaSerializationOperator * TosaSerializationOperatorBuilder::build( mlir::Operation &op) const { std::string input_name = GetTensorName(op.getOperand(0)); std::string output_name = GetTensorName(op.getResult(0)); const std::string implementation_attrs = op.getAttr("implementation_attrs") .cast() .getValue() .str(); std::vector attrs_data(implementation_attrs.size()); memcpy(attrs_data.data(), implementation_attrs.data(), attrs_data.size()); TosaCustomAttribute attribute( op.getAttr("operator_name").cast().getValue().str(), op.getAttr("domain_name").cast().getValue().str(), attrs_data); TosaSerializationOperator *tyop = new TosaSerializationOperator( Op_CUSTOM, Attribute_CustomAttribute, &attribute, std::vector{input_name}, std::vector{output_name}); return tyop; } namespace { // serialize a region and all its blocks, and return region's return values TosaSerializationRegion * BuildRegion(mlir::Region ®ion, const std::string region_name, const bool isolated_from_above, TosaSerializationRegionBuilder *curr_region_builder, TosaSerializationHandler *tsh, std::vector &return_values, bool is_top = false) { TosaSerializationRegion *ser_region = new TosaSerializationRegion(region_name, {}); assert(ser_region); tsh->GetRegions().push_back(ser_region); TosaSerializationRegionBuilder *parent_value_scope = isolated_from_above ? 
nullptr : curr_region_builder; TosaSerializationRegionBuilder region_builder(ser_region, ®ion, parent_value_scope, tsh); if (region_builder.BuildAllBlocksInRegion(is_top, return_values).failed()) { return nullptr; } return ser_region; } static int input_tensor_index = 0; static int intermediate_tensor_index = 0; static int output_tensor_index = 0; } // namespace template <> TosaSerializationOperator * TosaSerializationOperatorBuilder::build( mlir::Operation &op) const { const std::string op_name = op.getName().getStringRef().str(); const bool isolated_from_above = op.hasTrait(); auto curr_region_builder = GetRegionBuilder(); std::vector input_names, output_names; std::vector then_yields, else_yields; auto tsh = GetTsh(); mlir::Region &then_region = op.getRegion(0); mlir::Region &else_region = op.getRegion(1); const std::string then_region_name = op_name + "_then_region"; TosaSerializationRegion *ser_then_region = BuildRegion(then_region, then_region_name, isolated_from_above, curr_region_builder, tsh, then_yields); if (!ser_then_region) { return nullptr; } if (then_yields.size() != op.getNumResults()) { op.emitOpError("BuildOpCondIf: then_region yield.size() doesn't match " "cond_if's output size"); return nullptr; } const std::string else_region_name = op_name + "_else_region"; TosaSerializationRegion *ser_else_region = BuildRegion(else_region, else_region_name, isolated_from_above, curr_region_builder, tsh, else_yields); if (!ser_else_region) { return nullptr; } if (else_yields.size() != op.getNumResults()) { op.emitOpError("BuildOpCondIf: else_region yield.size() doesn't match " "cond_if's output size"); return nullptr; } TosaCondIfAttribute attribute(then_region_name, else_region_name); for (size_t i = 0; i < op.getNumOperands(); i++) { std::string input_name = GetTensorName(op.getOperand(i)); input_names.push_back(input_name); } for (size_t i = 0; i < op.getNumResults(); i++) { std::string output_name = GetTensorName(op.getResult(i)); 
output_names.push_back(output_name); } TosaSerializationOperator *tyop = new TosaSerializationOperator(Op_COND_IF, Attribute_CondIfAttribute, &attribute, input_names, output_names); return tyop; } template <> TosaSerializationOperator * TosaSerializationOperatorBuilder::build( mlir::Operation &op) const { const std::string op_name = op.getName().getStringRef().str(); const bool isolated_from_above = op.hasTrait(); auto curr_region_builder = GetRegionBuilder(); std::vector input_names, output_names; auto tsh = GetTsh(); mlir::Region &cond_region = op.getRegion(0); mlir::Region &body_region = op.getRegion(1); std::vector cond_yields, body_yields; const std::string cond_region_name = op_name + "_cond_region"; TosaSerializationRegion *ser_cond_region = BuildRegion(cond_region, cond_region_name, isolated_from_above, curr_region_builder, tsh, cond_yields); if (!ser_cond_region) { return nullptr; } if (cond_yields.size() != 1) { op.emitOpError("BuildOpWhileLoop: cond_region yield.size() is not 1"); return nullptr; } const std::string body_region_name = op_name + "_body_region"; TosaSerializationRegion *ser_body_region = BuildRegion(body_region, body_region_name, isolated_from_above, curr_region_builder, tsh, body_yields); if (!ser_body_region) { return nullptr; } if (body_yields.size() != op.getNumResults()) { op.emitOpError("BuildOpWhileLoop: body_region yield.size() doesn't " "match while_loop's output size"); return nullptr; } TosaWhileLoopAttribute attribute(cond_region_name, body_region_name); for (size_t i = 0; i < op.getNumOperands(); i++) { std::string input_name = GetTensorName(op.getOperand(i)); input_names.push_back(input_name); } for (size_t i = 0; i < op.getNumResults(); i++) { std::string output_name = GetTensorName(op.getResult(i)); output_names.push_back(output_name); } TosaSerializationOperator *tyop = new TosaSerializationOperator(Op_WHILE_LOOP, Attribute_WhileLoopAttribute, &attribute, input_names, output_names); return tyop; } template <> 
TosaSerializationOperator * TosaSerializationOperatorBuilder::build( mlir::Operation &op) const { std::string input_name = GetTensorName(op.getOperand(0)); std::string output_real_name = GetTensorName(op.getResult(0)); std::string output_imag_name = GetTensorName(op.getResult(1)); bool local_bound = op.hasAttr("local_bound") ? op.getAttr("local_bound").dyn_cast().getValue() : false; TosaRFFTAttribute attribute(local_bound); TosaSerializationOperator *tyop = new TosaSerializationOperator( Op_RFFT2D, Attribute_RFFTAttribute, &attribute, std::vector{input_name}, std::vector{output_real_name, output_imag_name}); return tyop; } template <> TosaSerializationOperator * TosaSerializationOperatorBuilder::build( mlir::Operation &op) const { bool inverse = op.getAttr("inverse").dyn_cast().getValue(); bool local_bound = op.hasAttr("local_bound") ? op.getAttr("local_bound").dyn_cast().getValue() : false; std::string input_real_name = GetTensorName(op.getOperand(0)); std::string input_imag_name = GetTensorName(op.getOperand(1)); std::string output_real_name = GetTensorName(op.getResult(0)); std::string output_imag_name = GetTensorName(op.getResult(1)); TosaFFTAttribute attribute(inverse, local_bound); TosaSerializationOperator *tyop = new TosaSerializationOperator( Op_FFT2D, Attribute_FFTAttribute, &attribute, std::vector{input_real_name, input_imag_name}, std::vector{output_real_name, output_imag_name}); return tyop; } template <> TosaSerializationOperator * TosaSerializationOperatorBuilder::build( mlir::Operation &op) const { std::string input_name = GetVariableTensorName(&op); std::string output_name = GetTensorName(op.getResult(0)); TosaSerializationOperator *tyop = new TosaSerializationOperator(Op_IDENTITY, Attribute_NONE, nullptr, std::vector{input_name}, std::vector{output_name}); return tyop; } template <> TosaSerializationOperator * TosaSerializationOperatorBuilder::build( mlir::Operation &op) const { std::string input_name = GetTensorName(op.getOperand(0)); std::string 
output_name = GetVariableTensorName(&op); TosaSerializationOperator *tyop = new TosaSerializationOperator(Op_IDENTITY, Attribute_NONE, nullptr, std::vector{input_name}, std::vector{output_name}); return tyop; } /* End translating TOSA operator */ mlir::LogicalResult TosaSerializationRegionBuilder::BuildAllBlocksInRegion( bool is_top, std::vector &return_values) { std::string region_name = ser_region->GetName(); int block_index = 0; for (auto &block : this->region->getBlocks()) { // must name first block of top region "main" const std::string block_name = (is_top && block_index == 0) ? "main" : (region_name + "_bb" + std::to_string(block_index++)); TosaSerializationBasicBlock *ser_block = new TosaSerializationBasicBlock( block_name, region_name, std::vector(), std::vector(), std::vector(), std::vector()); // build the block TosaSerializationBlockBuilder block_builder(ser_block, this, &block); // Region Builders need access to block builders block_builders.push_back(&block_builder); if (block_builder.BuildAllOpsInBlock(return_values).failed()) { return mlir::failure(); } if (return_values.empty()) { llvm::errs() << "BWarning: graph doesn't have return values\n"; } // Add serialized block to serialized region ser_region->GetBlocks().push_back(ser_block); } return mlir::success(); } mlir::LogicalResult TosaSerializationBlockBuilder::BuildAllOpsInBlock( std::vector &return_values) { TosaSerializationOperator *ser_operator = nullptr; TosaSerializationTensor *ser_tensor = nullptr; size_t num_blocks_in_region = 0; TosaSerializationOperatorBuilder op_builder(this); // Specify block input tensor name for (auto args : block->getArguments()) { std::string block_input_name = "TosaInput_" + std::to_string(input_tensor_index++); ser_block->GetInputs().push_back(block_input_name); tensor_map[args] = block_input_name; input_tensor_map[args] = block_input_name; } // Build tensor_map for (auto &op : block->getOperations()) { if (llvm::isa(op)) { RegisterVariableOp(op); } else if 
(!(llvm::isa(op) || llvm::isa(op) || llvm::isa(op))) { for (uint32_t i = 0; i < op.getNumResults(); i++) { std::string intermediate_tensor_name = "layer_" + std::to_string(intermediate_tensor_index++); tensor_map[op.getResult(i)] = intermediate_tensor_name; } } else { if (llvm::isa(op)) continue; // Override return tensor name for (auto val : op.getOperands()) { // Workaround to skip mlir::tensor::CastOp before return mlir::Operation *val_defining_op = val.getDefiningOp(); if (val_defining_op) { if (llvm::isa(*val_defining_op)) val = val_defining_op->getOperand(0); } // Sanity check. This mlir::Value should be built in map since graph // is DAG if (tensor_map.find(val) == tensor_map.end()) { llvm::errs() << "ERROR: Can't find built mlir::Value key.\n"; return mlir::failure(); } // If returned value is block input, short-circuit the tensor name // Otherwise, build a new output name and override the origin tensor // name if (input_tensor_map.find(val) != input_tensor_map.end()) { ser_block->GetOutputs().push_back(input_tensor_map[val]); return_values.push_back(val); } else { std::string output_name = "TosaOutput_" + std::to_string(output_tensor_index++); tensor_map[val] = output_name; ser_block->GetOutputs().push_back(output_name); return_values.push_back(val); } } } } // Build variable tensor for (auto pair : variable_tensor_op_map) { mlir::Operation *op = pair.second; mlir::Value val = op->getResult(0); mlir::RankedTensorType tensor_type = op->getAttr("type") .cast() .getValue() .cast(); std::string variable_mlir_name = op->getAttr("name").cast().getValue().str(); ser_tensor = BuildTosaSerializationVariableTensor( tensor_type /* tensor_type */, pair.first /* flatbuffer name */, variable_mlir_name); if (!ser_tensor) { llvm::errs() << "ERROR: Failed to build TosaSerializationTensor\n"; return mlir::failure(); } // Initialize if "initial_value" attribute exists. 
If not, set data to all // zeros mlir::Attribute initial_value = op->getAttr("initial_value"); std::vector u8_data; if (initial_value) { if (initial_value.isa()) { if (op_builder .GetDataFromAttribute(*op, initial_value, tensor_type.getElementType(), u8_data) .failed()) { llvm::errs() << "ERROR: GetDataFromAttribute() fails when building " "initial_value of variable tensor\n"; return mlir::failure(); } } else { llvm::errs() << "ERROR: Unknown initial_value attribute type\n"; return mlir::failure(); } } else { TosaSerializationHandler::ForceAlignTensorData(u8_data); } ser_tensor->SetData(u8_data); ser_block->GetTensors().push_back(ser_tensor); } // Build tensor // The tensor_map is sorted by hashed mlir::Value types. // For serialization, sort tensors alphabetically by name for a // deterministic and human-friendly ordering. std::map tensor_name_sort; for (auto pair : tensor_map) tensor_name_sort[pair.second] = pair.first; for (auto pair : tensor_name_sort) { mlir::RankedTensorType tensor_type = pair.second.getType().cast(); ser_tensor = BuildTosaSerializationTensor(pair.second /* val */, pair.first /* name */); if (!ser_tensor) { llvm::errs() << "ERROR: Failed to build TosaSerializationTensor\n"; return mlir::failure(); } ser_block->GetTensors().push_back(ser_tensor); } // Build operator for (auto &op : block->getOperations()) { if (llvm::isa(op) || llvm::isa(op) || llvm::isa(op) || llvm::isa(op)) continue; ser_operator = BuildTosaSerializationOperator(op_builder, op); if (!ser_operator) { llvm::errs() << "ERROR: Failed to build TosaSerializationOperator\n"; return mlir::failure(); } ser_block->GetOperators().push_back(ser_operator); } return mlir::success(); } TosaSerializationOperator * TosaSerializationBlockBuilder::BuildTosaSerializationOperator( const TosaSerializationOperatorBuilder &op_builder, mlir::Operation &op) { TosaSerializationOperator *target_operator = nullptr; if (llvm::isa(op)) { target_operator = op_builder.build(op); } else if (llvm::isa(op)) { 
target_operator = op_builder.build(op); } #define DEF_OPERATOR(MLIR_OP) \ else if (llvm::isa(op)) { \ target_operator = op_builder.build(op); \ } #include "operator.def" #undef DEF_OPERATOR else { llvm::errs() << "unsupported tosa operator " << op.getName().getStringRef() << "\n"; } if (!target_operator) { llvm::errs() << op.getName().getStringRef() << " operator hasn't been translated to flatbuffer, skipped\n"; return nullptr; } if (llvm::isa(op) || llvm::isa(op)) { return target_operator; } // Sanity check the number of inputs/outputs of TOSA dialect matches the // number of TOSA flatbuffer if (op.getNumOperands() != target_operator->GetInputTensorNames().size()) { llvm::errs() << "WARNING: MLIR operator has " << op.getNumOperands() << " input tensors != Flatbuffer " "operator has " << target_operator->GetInputTensorNames().size() << " input tensors\n"; } if (op.getNumResults() != target_operator->GetOutputTensorNames().size()) { llvm::errs() << "WARNING: MLIR operator has " << op.getNumResults() << " output tensors != Flatbuffer " "operator has " << target_operator->GetOutputTensorNames().size() << " output tensors\n"; } return target_operator; } TosaSerializationTensor * TosaSerializationBlockBuilder::BuildTosaSerializationVariableTensor( mlir::RankedTensorType tensor_type, const std::string &name, const std::string &variable_mlir_name) { // If tensor already created before, use that tensor directly, create a new // one otherwise TosaSerializationTensor *ts = ser_block->GetTensorByName(name); if (ts) { return nullptr; } std::vector shape(tensor_type.getShape().begin(), tensor_type.getShape().end()); DType type = Type2DType(tensor_type.getElementType()); ts = new TosaSerializationTensor(name, shape, type, std::vector(), /* is_variable = */ true, /* is_unranked = */ false, variable_mlir_name); return ts; } TosaSerializationTensor * TosaSerializationBlockBuilder::BuildTosaSerializationTensor( mlir::Value val, const std::string &name) { // If tensor already created 
before, use that tensor directly, create a new // one otherwise TosaSerializationTensor *ts = ser_block->GetTensorByName(name); if (ts) { return nullptr; } // handling of tosa.shape values if (auto shape_ty = val.getType().dyn_cast()) { auto rank = shape_ty.getRank(); std::vector shape; if (rank > 0) { shape.push_back(rank); } ts = new TosaSerializationTensor(name, /* shape = */ shape, /* type = */ DType::DType_SHAPE, /* data = */ std::vector()); return ts; } auto ttype = val.getType().dyn_cast(); if (!ttype) { llvm::errs() << "TOSA serialization, supplied value is not of TensorType\n"; return nullptr; } const bool is_unranked = !ttype.hasRank(); std::vector shape; if (!is_unranked) { auto shaped = val.getType().dyn_cast(); assert(shaped); for (int idx = 0; idx < ttype.getRank(); idx++) { if (shaped.isDynamicDim(idx)) { shape.push_back(0); // size of 0 represents dynamic dimension } else { auto dim = shaped.getDimSize(idx); shape.push_back(dim); } } } DType type = Type2DType(ttype.getElementType()); ts = new TosaSerializationTensor(name, shape, type, std::vector(), /* variable = */ false, is_unranked); return ts; } mlir::LogicalResult translate2FlatBuffer(mlir::func::FuncOp &func, TosaSerializationHandler &tsh) { mlir::Region *main_region = func.getCallableRegion(); std::vector main_returns; if (!main_region) { llvm::errs() << "Invalid MLIR: doesn't have valid \"main\" region\n"; return mlir::failure(); } if (!tsh.GetRegions().empty()) { llvm::errs() << "Internal Error: TosaSerializationHandler's region list " "must be empty\n"; return mlir::failure(); } // reset static counters input_tensor_index = 0; intermediate_tensor_index = 0; output_tensor_index = 0; TosaSerializationRegion *ser_main_region = BuildRegion(*main_region, "main", /* isolated_from_above = */ true, /* parent_value_scope = */ nullptr, &tsh, main_returns, /* is_top = */ true); if (!ser_main_region) { return mlir::failure(); } if (main_returns.empty()) { llvm::errs() << "Warning: graph doesn't have 
return values\n"; } return mlir::success(); } mlir::LogicalResult dumpTosaFlatbuffer(mlir::func::FuncOp &func) { tosa::TosaSerializationHandler tsh; std::string tosa_flatbuffer_directory_fullpath; if (translate2FlatBuffer(func, tsh).failed()) { llvm::errs() << "Fail to translate TOSA MLIR to flatbuffer\n"; return mlir::failure(); } if (tsh.SaveFileTosaFlatbuffer(tosa_flatbuffer_filename.c_str())) { llvm::errs() << "Fail to save flatbuffer " << tosa_flatbuffer_filename << "\n"; return mlir::failure(); } return mlir::success(); } mlir::LogicalResult dumpTosaJSON(mlir::func::FuncOp &func) { tosa::TosaSerializationHandler tsh; const char *tosa_schema = tosa_flatbuffer_schema.c_str(); if (!tosa_schema) { llvm::errs() << "Flatbuffer schema not defined\n"; return mlir::failure(); } if (tsh.LoadFileSchema(tosa_schema)) { llvm::errs() << "Error loading tosa schema file: " << tosa_schema << "\n"; return mlir::failure(); } std::string tosa_flatbuffer_directory_fullpath; if (translate2FlatBuffer(func, tsh).failed()) { llvm::errs() << "Fail to translate TOSA MLIR to flatbuffer\n"; return mlir::failure(); } if (tsh.SaveFileJson(tosa_flatbuffer_filename.c_str())) { llvm::errs() << "Fail to save flatbuffer " << tosa_flatbuffer_filename << "\n"; return mlir::failure(); } return mlir::success(); } #define GEN_PASS_DEF_TOSASERIALIZATIONPASS namespace mlir { namespace tosa { namespace { class TosaSerialize : public TosaSerializationPassBase { public: void runOnOperation() final { auto moduleOp = getOperation(); // iterate through each op in the moduleOp, call dumpTosaFlatbuffer if // that's a func.funcOp auto regions = moduleOp->getRegions(); auto region_size = regions.size(); auto region_0 = regions.begin(); auto block_size = region_0->getBlocks().size(); auto block_0 = region_0->getBlocks().begin(); auto op_size = block_0->getOperations().size(); for (auto it = block_0->getOperations().begin(); it != block_0->getOperations().end(); ++it) { // read variableOps that are declared 
outside of functionOps if (llvm::isa(*it)) { RegisterVariableOp(*it); } else if (llvm::isa(*it)) { auto funcOp = dyn_cast((*it)); if (dumpTosaFlatbuffer(funcOp).failed()) { llvm::errs() << "Failed to generate TOSA flatbuffer...\n"; return signalPassFailure(); } } } } }; class TosaSerializeJSON : public TosaSerializationJSONPassBase { public: void runOnOperation() final { auto function = getOperation(); if (dumpTosaJSON(function).failed()) { llvm::errs() << "Failed to generate TOSA JSON...\n"; return signalPassFailure(); } } }; } // anonymous namespace // Creates an instance of the TOSA flatbuffer generation pass std::unique_ptr> createTosaSerializePass() { return std::make_unique(); } std::unique_ptr createTosaSerializeJSONPass() { return std::make_unique(); } static PassRegistration pass([] { return createTosaSerializePass(); }); static PassRegistration passJSON([] { return createTosaSerializeJSONPass(); }); } // namespace tosa } // namespace mlir