// Copyright (c) 2023-2024, ARM Limited. // // Licensed under the Apache License, Version 2.0 with LLVM Exceptions // (the "License"); you may not use this file except in compliance with // the License. You may obtain a copy of the License at // // https://llvm.org/LICENSE.txt // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // TOSA MLIR deserialize passes #include "include/DeserializationPasses.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project #include "mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project #include "mlir/Dialect/Tosa/IR/TosaOps.h" // from @llvm-project #include "mlir/IR/Matchers.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tosa_serialization_handler.h" #include #include #include #include #include // The namespace might be confusing here. 
We have mlir::tosa:: defined in MLIR // and tosa:: defined in serialization library // TODO: align the namespace using namespace tosa; namespace cl = llvm::cl; llvm::cl::opt tosa_deserialize_filename( "tosa-deserialize-filename", llvm::cl::desc(""), llvm::cl::init("tosa_dump.tosa"), llvm::cl::value_desc("filename")); llvm::cl::opt tosa_deserialize_schema( "tosa-deserialize-schema", llvm::cl::desc(""), llvm::cl::init(""), llvm::cl::value_desc("filename")); const std::string kDefaultExportedName = "tosa_deserialized"; const std::string kDefaultInputPrefix = "input_"; const std::string kDefaultOutputPrefix = "output_"; const std::string kDefaultFBSDescription = "Tosa FBS Converted"; const std::string kDefaultJSONDescription = "Tosa JSON Converted"; const std::string kMainFunctionName = "main"; namespace { // a global map from flatbuffer variable names to serialized tensors std::unordered_map variable_tensor_map; void RegisterVariableTensor(TosaSerializationTensor *ts) { assert(ts->GetVariable()); // insert variable tensor ts only if not already present variable_tensor_map.insert({ts->GetName(), ts}); } bool IsVariableTensor(const std::string flatbuffer_tensor_name) { return variable_tensor_map.count(flatbuffer_tensor_name); } // return the variable name corresponding to flatbuffer_tensor_name const std::string GetVariableTensorName(TosaSerializationTensor *ts) { assert(ts->GetVariable()); const auto name = ts->GetVariableName(); if (name == "") { // for legacy flatbuffers which may not have variable_name fields return ts->GetName(); } return name; } // return the variable name corresponding to flatbuffer_tensor_name const std::string GetVariableTensorName(const std::string flatbuffer_tensor_name) { if (!IsVariableTensor(flatbuffer_tensor_name)) { llvm::errs() << "ERROR: Variable tensor " << flatbuffer_tensor_name << " is not found in variable_tensor_map"; return ""; } return GetVariableTensorName(variable_tensor_map[flatbuffer_tensor_name]); } bool 
IsVariableReadOp(TosaSerializationOperator *op) { return (op->GetOp() == tosa::Op::Op_IDENTITY) && IsVariableTensor(op->GetInputTensorNames()[0]); } bool IsVariableWriteOp(TosaSerializationOperator *op) { return (op->GetOp() == tosa::Op::Op_IDENTITY) && IsVariableTensor(op->GetOutputTensorNames()[0]); } // construct tensor type from dtype and shape of TosaSerializationTensor mlir::LogicalResult BuildTensorType(mlir::OpBuilder *op_builder, TosaSerializationTensor *ts, mlir::RankedTensorType &type) { mlir::Type element_type; switch (ts->GetDtype()) { case DType_BOOL: element_type = op_builder->getI1Type(); break; case DType_UINT8: element_type = op_builder->getIntegerType(8, false); break; case DType_INT4: element_type = op_builder->getI4Type(); break; case DType_INT8: element_type = op_builder->getI8Type(); break; case DType_INT16: element_type = op_builder->getIntegerType(16); break; case DType_INT32: element_type = op_builder->getI32Type(); break; case DType_INT48: element_type = op_builder->getIntegerType(48); break; case DType_FP32: element_type = op_builder->getF32Type(); break; case DType_UINT16: element_type = op_builder->getIntegerType(16, false); break; case DType_FP16: element_type = op_builder->getF16Type(); break; case DType_BF16: element_type = op_builder->getBF16Type(); break; case DType_FP8E4M3: element_type = op_builder->getFloat8E4M3FNType(); break; case DType_FP8E5M2: element_type = op_builder->getFloat8E5M2Type(); break; case DType_SHAPE: llvm::errs() << "ERROR: Cannot construct RankedTensorType out of tosa.shape type \n"; return mlir::failure(); default: llvm::errs() << "ERROR: unknown type " << EnumNamesDType()[ts->GetDtype()] << "\n"; return mlir::failure(); } llvm::SmallVector shape; for (auto dim : ts->GetShape()) { if (dim > 0) { shape.push_back(dim); } else { // dynamic dim shape.push_back(mlir::ShapedType::kDynamic); } } type = mlir::RankedTensorType::get(llvm::ArrayRef(shape), element_type); return mlir::success(); } 
mlir::DenseElementsAttr GetConstAttr(const std::vector &data, const mlir::RankedTensorType &output_type, uint32_t out_size) { auto element_type = output_type.getElementType(); if (element_type.isF32()) { // for FP32, value attributes are stored as FP32 values std::vector float_data; TosaSerializationHandler::ConvertU8toF32(data, out_size, float_data); return mlir::DenseElementsAttr::get(output_type, llvm::ArrayRef(float_data)); } if (element_type.isBF16()) { mlir::SmallVector bf16_data; for (uint32_t i = 0; i < out_size; i++) { uint64_t byte0 = data[i * sizeof(int16_t)]; uint64_t byte1 = data[i * sizeof(int16_t) + 1]; uint64_t bits = byte0 + (byte1 << 8); mlir::APInt bf16_bits(16, bits); mlir::APFloat bf16(mlir::APFloat::BFloat(), bf16_bits); bf16_data.push_back(bf16); } return mlir::DenseElementsAttr::get(output_type, bf16_data); } if (element_type.isFloat8E4M3FN()) { mlir::SmallVector f8_data; for (uint32_t i = 0; i < out_size; i++) { mlir::APInt f8_bits(8, static_cast(data[i])); mlir::APFloat f8(mlir::APFloat::Float8E4M3FN(), f8_bits); f8_data.push_back(f8); } return mlir::DenseElementsAttr::get(output_type, f8_data); } if (element_type.isFloat8E5M2()) { mlir::SmallVector f8_data; for (uint32_t i = 0; i < out_size; i++) { mlir::APInt f8_bits(8, static_cast(data[i])); mlir::APFloat f8(mlir::APFloat::Float8E5M2(), f8_bits); f8_data.push_back(f8); } return mlir::DenseElementsAttr::get(output_type, f8_data); } if (element_type.isInteger(4)) { std::vector int4_data; TosaSerializationHandler::ConvertU8toI4(data, out_size, int4_data); return mlir::DenseElementsAttr::get(output_type, llvm::ArrayRef(int4_data)); } if (element_type.isInteger(8)) { std::vector int8_data; TosaSerializationHandler::ConvertU8toI8(data, out_size, int8_data); return mlir::DenseElementsAttr::get(output_type, llvm::ArrayRef(int8_data)); } if (element_type.isInteger(16)) { std::vector int16_data; TosaSerializationHandler::ConvertU8toI16(data, out_size, int16_data); return 
mlir::DenseElementsAttr::get(output_type, llvm::ArrayRef(int16_data)); } if (element_type.isInteger(32)) { std::vector int32_data; TosaSerializationHandler::ConvertU8toI32(data, out_size, int32_data); return mlir::DenseElementsAttr::get(output_type, llvm::ArrayRef(int32_data)); } if (element_type.isInteger(48)) { std::vector int48_data; TosaSerializationHandler::ConvertU8toI48(data, out_size, int48_data); std::vector apint_data; for (const auto v : int48_data) { mlir::APInt apint_value(48, static_cast(v), /* isSigned = */ false); apint_data.push_back(apint_value); } return mlir::DenseElementsAttr::get(output_type, llvm::ArrayRef(apint_data)); } if (element_type.isInteger(1)) { std::vector bool_data; TosaSerializationHandler::ConvertU8toBool(data, out_size, bool_data); llvm::SmallVector bool_values(bool_data.begin(), bool_data.end()); return mlir::DenseElementsAttr::get(output_type, bool_values); } if (element_type.isF16()) { std::vector half_data; TosaSerializationHandler::ConvertU8toF16(data, out_size, half_data); return mlir::DenseElementsAttr::get(output_type, llvm::ArrayRef(half_data)); } return nullptr; } mlir::DenseElementsAttr ConstructConstAttr(const mlir::RankedTensorType &output_type, TosaSerializationTensor *ts, const std::string &op_name) { // compute output data size uint32_t out_size = 1; for (const auto dim : ts->GetShape()) { out_size *= dim; } auto attr = GetConstAttr(ts->GetData(), output_type, out_size); if (!attr) { llvm::errs() << "ERROR: " << op_name << " contains unsupported element type\n"; } return attr; } mlir::LogicalResult ConstructVariableOps(mlir::ModuleOp &module) { if (variable_tensor_map.empty()) { return mlir::success(); } auto loc = module.getLoc(); auto op_builder = mlir::OpBuilder(module.getBodyRegion()); for (auto [flatbuffer_name, ts] : variable_tensor_map) { auto name = GetVariableTensorName(ts); mlir::RankedTensorType type; if (BuildTensorType(&op_builder, ts, type).failed()) { return mlir::failure(); } mlir::Attribute 
value_attr = nullptr; if (!ts->GetData().empty()) { value_attr = ConstructConstAttr(type, ts, name); } op_builder.create(loc, llvm::StringRef(name), type, value_attr); } return mlir::success(); } template mlir::DenseElementsAttr BuildDenseI8ElementsAttr(mlir::OpBuilder *op_builder, const std::vector &values) { llvm::SmallVector vec; for (auto val : values) { vec.push_back(val); } auto type = mlir::RankedTensorType::get({static_cast(vec.size())}, op_builder->getI8Type()); return mlir::DenseElementsAttr::get(type, llvm::ArrayRef(vec)); } template mlir::DenseElementsAttr BuildDenseI16ElementsAttr(mlir::OpBuilder *op_builder, const std::vector &values) { llvm::SmallVector vec; for (auto val : values) { vec.push_back(val); } auto type = mlir::RankedTensorType::get({static_cast(vec.size())}, op_builder->getI16Type()); return mlir::DenseElementsAttr::get(type, llvm::ArrayRef(vec)); } template mlir::DenseElementsAttr BuildDenseI32ElementsAttr(mlir::OpBuilder *op_builder, mlir::RankedTensorType &type, const std::vector &values) { llvm::SmallVector vec; for (auto val : values) { vec.push_back(val); } return mlir::DenseElementsAttr::get(type, llvm::ArrayRef(vec)); } template mlir::DenseI8ArrayAttr BuildDenseI8ArrayAttr(mlir::OpBuilder *op_builder, const std::vector &values) { std::vector vec; for (auto val : values) { vec.push_back(val); } return op_builder->getDenseI8ArrayAttr(vec); } template mlir::DenseI32ArrayAttr BuildDenseI32ArrayAttr(mlir::OpBuilder *op_builder, const std::vector &values) { std::vector vec; for (auto val : values) { vec.push_back(val); } return op_builder->getDenseI32ArrayAttr(vec); } template mlir::DenseI64ArrayAttr BuildDenseI64ArrayAttr(mlir::OpBuilder *op_builder, const std::vector &values) { std::vector vec; for (auto val : values) { vec.push_back(val); } return op_builder->getDenseI64ArrayAttr(vec); } const std::string ResizeEnum2Str(const tosa::ResizeMode &mode) { if (mode == ResizeMode_NEAREST) { return "NEAREST_NEIGHBOR"; } else if (mode == 
ResizeMode_BILINEAR) { return "BILINEAR"; } return ""; } // this is a counter part to Type2AccDType mlir::Type AccDType2Type(mlir::OpBuilder *op_builder, DType dtype) { // def Tosa_AccType : AnyTypeOf<[I<32>, I<48>, F16, F32]>; if (dtype == DType_INT32) { return op_builder->getI32Type(); } else if (dtype == DType_INT48) { return op_builder->getIntegerType(48); } else if (dtype == DType_FP32) { return op_builder->getF32Type(); } else if (dtype == DType_FP16) { return op_builder->getF16Type(); } else { // unknown acc type // for now, default to F32 return op_builder->getF32Type(); } } } // namespace class TosaMlirRegionBuilder; class TosaMlirBlockBuilder; class TosaMlirOperatorBuilder { public: TosaMlirOperatorBuilder( mlir::OpBuilder *_op_builder, TosaSerializationBasicBlock *_ser_block, mlir::Block *_block, mlir::Location _loc, TosaMlirBlockBuilder *_block_builder, std::unordered_map *_tensor_map, std::unordered_map *_tensor_type_map, std::unordered_map *_shape_type_map) : op_builder(_op_builder), ser_block(_ser_block), block(_block), loc(_loc), block_builder(_block_builder), tensor_map(_tensor_map), tensor_type_map(_tensor_type_map), shape_type_map(_shape_type_map) {} template std::vector build(TosaSerializationOperator *op) const; std::vector BuildVariableOp(TosaSerializationOperator *op) const; std::vector BuildVariableReadOp(TosaSerializationOperator *op) const; void BuildVariableWriteOp(TosaSerializationOperator *op) const; std::string get_string(TosaSerializationOperator *op) const { std::string op_string; op_string += "operator opcode="; op_string += EnumNamesOp()[op->GetOp()]; op_string += ", input=["; for (auto ts : op->GetInputTensorNames()) { op_string += (ts + " "); } op_string += "], output=["; for (auto ts : op->GetOutputTensorNames()) { op_string += (ts + " "); } op_string += "]"; return op_string; } TosaSerializationHandler *GetTsh() const; TosaMlirRegionBuilder *GetRegionBuilder() const; private: template std::vector 
BuildEwiseUnaryOp(TosaSerializationOperator *op) const; template std::vector BuildEwiseBinaryOp(TosaSerializationOperator *op) const; template std::vector BuildEwiseBinaryShapeOp(TosaSerializationOperator *op) const; template std::vector BuildReductionOp(TosaSerializationOperator *op) const; template mlir::Value BuildConstShape(mlir::OpBuilder *op_builder, mlir::Location loc, const std::vector &values) const; template std::vector BuildConvOp(TosaSerializationOperator *op) const; mlir::OpBuilder *op_builder; TosaSerializationBasicBlock *ser_block; mlir::Block *block; mlir::Location loc; TosaMlirBlockBuilder *block_builder; std::unordered_map *tensor_map; std::unordered_map *tensor_type_map; std::unordered_map *shape_type_map; }; // Main template to catch unimplemented translation template std::vector TosaMlirOperatorBuilder::build(TosaSerializationOperator *op) const { llvm::errs() << "ERROR: " << get_string(op) << " translation hasn't been implemented\n"; return {}; } // BUILD_OP_POOL2D(MaxPool2d, MAX_POOL2D) template <> std::vector TosaMlirOperatorBuilder::build( TosaSerializationOperator *op) const { mlir::Value input_val = tensor_map->at(op->GetInputTensorNames()[0]); mlir::RankedTensorType output_type = tensor_type_map->at(op->GetOutputTensorNames()[0]); assert(op->GetAttributeType() == Attribute_PoolAttribute); // double check attribute type TosaPoolAttribute *attr = static_cast(op->GetAttribute()); mlir::DenseI64ArrayAttr kernel = BuildDenseI64ArrayAttr(op_builder, attr->kernel()); mlir::DenseI64ArrayAttr stride = BuildDenseI64ArrayAttr(op_builder, attr->stride()); mlir::DenseI64ArrayAttr pad = BuildDenseI64ArrayAttr(op_builder, attr->pad()); int32_t input_zp = attr->input_zp(); int32_t output_zp = attr->output_zp(); assert(input_zp == 0 && output_zp == 0); mlir::Operation *mlir_op = op_builder->create( loc, output_type, input_val, kernel, stride, pad); block->push_back(mlir_op); return std::vector({mlir_op->getResult(0)}); } // BUILD_OP_POOL2D(AvgPool2d, 
AVG_POOL2D) template <> std::vector TosaMlirOperatorBuilder::build( TosaSerializationOperator *op) const { mlir::Value input_val = tensor_map->at(op->GetInputTensorNames()[0]); mlir::RankedTensorType output_type = tensor_type_map->at(op->GetOutputTensorNames()[0]); assert(op->GetAttributeType() == Attribute_PoolAttribute); // double check attribute type TosaPoolAttribute *attr = static_cast(op->GetAttribute()); mlir::DenseI64ArrayAttr kernel = BuildDenseI64ArrayAttr(op_builder, attr->kernel()); mlir::DenseI64ArrayAttr stride = BuildDenseI64ArrayAttr(op_builder, attr->stride()); mlir::DenseI64ArrayAttr pad = BuildDenseI64ArrayAttr(op_builder, attr->pad()); auto acc_attr = mlir::TypeAttr::get(AccDType2Type(op_builder, attr->acc_type())); int32_t input_zp = attr->input_zp(); int32_t output_zp = attr->output_zp(); mlir::Operation *mlir_op; if (input_zp == 0 && output_zp == 0) { mlir_op = op_builder->create( loc, output_type, input_val, kernel, stride, pad, acc_attr); } else { auto input_zp_attr = op_builder->getI32IntegerAttr(input_zp); auto output_zp_attr = op_builder->getI32IntegerAttr(output_zp); mlir_op = op_builder->create( loc, output_type, input_val, kernel, stride, pad, acc_attr, input_zp_attr, output_zp_attr); } block->push_back(mlir_op); return std::vector({mlir_op->getResult(0)}); } std::vector TosaMlirOperatorBuilder::BuildVariableReadOp( TosaSerializationOperator *op) const { auto input_tensor_name = op->GetInputTensorNames()[0]; auto output_tensor_name = op->GetOutputTensorNames()[0]; assert(IsVariableTensor(input_tensor_name)); auto variable_name = GetVariableTensorName(input_tensor_name); mlir::RankedTensorType output_type = tensor_type_map->at(output_tensor_name); assert(op->GetAttributeType() == Attribute_NONE); // double check that there is no attribute mlir::Operation *mlir_op = op_builder->create( loc, output_type, llvm::StringRef(variable_name)); block->push_back(mlir_op); return std::vector({mlir_op->getResult(0)}); } void 
TosaMlirOperatorBuilder::BuildVariableWriteOp( TosaSerializationOperator *op) const { auto input_tensor_name = op->GetInputTensorNames()[0]; auto output_tensor_name = op->GetOutputTensorNames()[0]; assert(IsVariableTensor(output_tensor_name)); auto variable_name = GetVariableTensorName(output_tensor_name); mlir::Value input_val = tensor_map->at(input_tensor_name); mlir::Operation *mlir_op = op_builder->create( loc, llvm::StringRef(variable_name), input_val); block->push_back(mlir_op); } template std::vector TosaMlirOperatorBuilder::BuildEwiseUnaryOp( TosaSerializationOperator *op) const { mlir::Value input_val = tensor_map->at(op->GetInputTensorNames()[0]); mlir::RankedTensorType output_type = tensor_type_map->at(op->GetOutputTensorNames()[0]); assert(op->GetAttributeType() == Attribute_NONE); // double check that there is no attribute mlir::Operation *mlir_op = op_builder->create(loc, output_type, input_val); block->push_back(mlir_op); return std::vector({mlir_op->getResult(0)}); } template std::vector TosaMlirOperatorBuilder::BuildEwiseBinaryOp( TosaSerializationOperator *op) const { mlir::Value input0_val = tensor_map->at(op->GetInputTensorNames()[0]); mlir::Value input1_val = tensor_map->at(op->GetInputTensorNames()[1]); mlir::RankedTensorType output_type = tensor_type_map->at(op->GetOutputTensorNames()[0]); assert(op->GetAttributeType() == Attribute_NONE); // double check that there is no attribute mlir::Operation *mlir_op = op_builder->create(loc, output_type, input0_val, input1_val); block->push_back(mlir_op); return std::vector({mlir_op->getResult(0)}); } template std::vector TosaMlirOperatorBuilder::BuildEwiseBinaryShapeOp( TosaSerializationOperator *op) const { mlir::Value input0_val = tensor_map->at(op->GetInputTensorNames()[0]); mlir::Value input1_val = tensor_map->at(op->GetInputTensorNames()[1]); mlir::tosa::shapeType output_type = shape_type_map->at(op->GetOutputTensorNames()[0]); assert(op->GetAttributeType() == Attribute_NONE); // double check that 
there is no attribute mlir::Operation *mlir_op = op_builder->create(loc, output_type, input0_val, input1_val); block->push_back(mlir_op); return std::vector({mlir_op->getResult(0)}); } template std::vector TosaMlirOperatorBuilder::BuildReductionOp(TosaSerializationOperator *op) const { mlir::Value input_val = tensor_map->at(op->GetInputTensorNames()[0]); mlir::RankedTensorType output_type = tensor_type_map->at(op->GetOutputTensorNames()[0]); assert(op->GetAttributeType() == Attribute_AxisAttribute); // double check attribute type TosaAxisAttribute *attr = static_cast(op->GetAttribute()); auto axis = op_builder->getI32IntegerAttr(attr->axis()); mlir::Operation *mlir_op = op_builder->create(loc, output_type, input_val, axis); block->push_back(mlir_op); return std::vector({mlir_op->getResult(0)}); } #define BUILD_OP_ELEMENTWISE_UNARY(MLIR_OP_NAME, SCHEMA_OP_NAME) \ template <> \ std::vector \ TosaMlirOperatorBuilder::build( \ TosaSerializationOperator * op) const { \ return BuildEwiseUnaryOp(op); \ } #define BUILD_OP_ELEMENTWISE_BINARY(MLIR_OP_NAME, SCHEMA_OP_NAME) \ template <> \ std::vector \ TosaMlirOperatorBuilder::build( \ TosaSerializationOperator * op) const { \ return BuildEwiseBinaryOp(op); \ } #define BUILD_OP_ELEMENTWISE_BINARY_SHAPE(MLIR_OP_NAME, SCHEMA_OP_NAME) \ template <> \ std::vector \ TosaMlirOperatorBuilder::build( \ TosaSerializationOperator * op) const { \ return BuildEwiseBinaryShapeOp(op); \ } #define BUILD_OP_REDUCTION(MLIR_OP_NAME, SCHEMA_OP_NAME) \ template <> \ std::vector \ TosaMlirOperatorBuilder::build( \ TosaSerializationOperator * op) const { \ return BuildReductionOp(op); \ } // BUILD_OP_POOL2D(MaxPool2d, MAX_POOL2D) // BUILD_OP_POOL2D(AvgPool2d, AVG_POOL2D) BUILD_OP_ELEMENTWISE_BINARY(Add, ADD) BUILD_OP_ELEMENTWISE_BINARY(BitwiseAnd, BITWISE_AND) BUILD_OP_ELEMENTWISE_BINARY(BitwiseXor, BITWISE_XOR) BUILD_OP_ELEMENTWISE_BINARY(BitwiseOr, BITWISE_OR) BUILD_OP_ELEMENTWISE_BINARY(IntDiv, INTDIV) BUILD_OP_ELEMENTWISE_BINARY(LogicalAnd, 
LOGICAL_AND) BUILD_OP_ELEMENTWISE_BINARY(LogicalLeftShift, LOGICAL_LEFT_SHIFT) BUILD_OP_ELEMENTWISE_BINARY(LogicalRightShift, LOGICAL_RIGHT_SHIFT) BUILD_OP_ELEMENTWISE_BINARY(LogicalOr, LOGICAL_OR) BUILD_OP_ELEMENTWISE_BINARY(LogicalXor, LOGICAL_XOR) BUILD_OP_ELEMENTWISE_BINARY(Maximum, MAXIMUM) BUILD_OP_ELEMENTWISE_BINARY(Minimum, MINIMUM) BUILD_OP_ELEMENTWISE_BINARY(Pow, POW) BUILD_OP_ELEMENTWISE_BINARY(Sub, SUB) BUILD_OP_ELEMENTWISE_UNARY(Abs, ABS) BUILD_OP_ELEMENTWISE_UNARY(BitwiseNot, BITWISE_NOT) BUILD_OP_ELEMENTWISE_UNARY(Ceil, CEIL) BUILD_OP_ELEMENTWISE_UNARY(Clz, CLZ) BUILD_OP_ELEMENTWISE_UNARY(Cos, COS) BUILD_OP_ELEMENTWISE_UNARY(Exp, EXP) BUILD_OP_ELEMENTWISE_UNARY(Floor, FLOOR) BUILD_OP_ELEMENTWISE_UNARY(Log, LOG) BUILD_OP_ELEMENTWISE_UNARY(LogicalNot, LOGICAL_NOT) BUILD_OP_ELEMENTWISE_UNARY(Reciprocal, RECIPROCAL) BUILD_OP_ELEMENTWISE_UNARY(Rsqrt, RSQRT) BUILD_OP_ELEMENTWISE_UNARY(Sin, SIN) BUILD_OP_REDUCTION(ReduceAny, REDUCE_ANY) BUILD_OP_REDUCTION(ReduceAll, REDUCE_ALL) BUILD_OP_REDUCTION(ReduceMax, REDUCE_MAX) BUILD_OP_REDUCTION(ReduceMin, REDUCE_MIN) BUILD_OP_REDUCTION(ReduceProd, REDUCE_PRODUCT) BUILD_OP_REDUCTION(ReduceSum, REDUCE_SUM) BUILD_OP_ELEMENTWISE_BINARY(Equal, EQUAL) BUILD_OP_ELEMENTWISE_BINARY(Greater, GREATER) BUILD_OP_ELEMENTWISE_BINARY(GreaterEqual, GREATER_EQUAL) BUILD_OP_ELEMENTWISE_UNARY(Erf, ERF) BUILD_OP_ELEMENTWISE_UNARY(Sigmoid, SIGMOID) BUILD_OP_ELEMENTWISE_UNARY(Tanh, TANH) BUILD_OP_ELEMENTWISE_UNARY(Identity, IDENTITY) BUILD_OP_ELEMENTWISE_UNARY(Cast, CAST) BUILD_OP_ELEMENTWISE_BINARY_SHAPE(AddShape, ADD_SHAPE) BUILD_OP_ELEMENTWISE_BINARY_SHAPE(SubShape, SUB_SHAPE) BUILD_OP_ELEMENTWISE_BINARY_SHAPE(MulShape, MUL_SHAPE) BUILD_OP_ELEMENTWISE_BINARY_SHAPE(DivShape, DIV_SHAPE) template <> std::vector TosaMlirOperatorBuilder::build(TosaSerializationOperator *op) const { const auto &output_name = op->GetOutputTensorNames()[0]; mlir::RankedTensorType output_type = tensor_type_map->at(output_name); TosaSerializationTensor *ts = 
ser_block->GetTensorByName(output_name); auto value_attr = ConstructConstAttr(output_type, ts, get_string(op)); if (!value_attr) { return {}; } mlir::Operation *mlir_op = op_builder->create(loc, output_type, value_attr); block->push_back(mlir_op); return std::vector({mlir_op->getResult(0)}); } template mlir::Value TosaMlirOperatorBuilder::BuildConstShape(mlir::OpBuilder *op_builder, mlir::Location loc, const std::vector &values) const { std::vector vec; for (auto val : values) { vec.push_back(val); } auto attr = op_builder->getIndexTensorAttr(vec); auto type = mlir::tosa::shapeType::get(op_builder->getContext(), /* rank = */ vec.size()); mlir::Operation *mlir_op = op_builder->create(loc, type, attr); block->push_back(mlir_op); return mlir_op->getResult(0); } template <> std::vector TosaMlirOperatorBuilder::build( TosaSerializationOperator *op) const { const auto &output_name = op->GetOutputTensorNames()[0]; mlir::tosa::shapeType output_type = shape_type_map->at(output_name); TosaSerializationTensor *ts = ser_block->GetTensorByName(output_name); const auto &data = ts->GetData(); std::vector i64_data; TosaSerializationHandler::ConvertU8toI64(data, output_type.getRank(), i64_data); mlir::Value result = BuildConstShape(op_builder, loc, i64_data); return std::vector({result}); } template std::vector TosaMlirOperatorBuilder::BuildConvOp(TosaSerializationOperator *op) const { mlir::Value input0_val = tensor_map->at(op->GetInputTensorNames()[0]); mlir::Value input1_val = tensor_map->at(op->GetInputTensorNames()[1]); mlir::Value input2_val = tensor_map->at(op->GetInputTensorNames()[2]); mlir::RankedTensorType output_type = tensor_type_map->at(op->GetOutputTensorNames()[0]); assert(op->GetAttributeType() == Attribute_ConvAttribute); // double check attribute type TosaConvAttribute *attr = static_cast(op->GetAttribute()); mlir::DenseI64ArrayAttr pad = BuildDenseI64ArrayAttr(op_builder, attr->pad()); mlir::DenseI64ArrayAttr stride = BuildDenseI64ArrayAttr(op_builder, 
attr->stride()); mlir::DenseI64ArrayAttr dilation = BuildDenseI64ArrayAttr(op_builder, attr->dilation()); auto input_zp = attr->input_zp(); auto weight_zp = attr->weight_zp(); bool local_bound = attr->local_bound(); auto acc_type = AccDType2Type(op_builder, attr->acc_type()); // input_zp/weight_zp is not allowed for float type mlir::Operation *mlir_op; if (output_type.getElementType().isa()) { assert(input_zp == 0 && weight_zp == 0); } auto input_zp_attr = op_builder->getI32IntegerAttr(input_zp); auto weight_zp_attr = op_builder->getI32IntegerAttr(weight_zp); mlir_op = op_builder->create( loc, output_type, input0_val, input1_val, input2_val, pad, stride, dilation, acc_type, input_zp_attr, weight_zp_attr, local_bound); block->push_back(mlir_op); return std::vector({mlir_op->getResult(0)}); } #define BUILD_OP_CONV(MLIR_OP_NAME, SCHEMA_OP_NAME) \ template <> \ std::vector \ TosaMlirOperatorBuilder::build( \ TosaSerializationOperator * op) const { \ return BuildConvOp(op); \ } BUILD_OP_CONV(Conv2D, CONV2D) BUILD_OP_CONV(Conv3D, CONV3D) BUILD_OP_CONV(DepthwiseConv2D, DEPTHWISE_CONV2D) template <> std::vector TosaMlirOperatorBuilder::build( TosaSerializationOperator *op) const { mlir::Value input0_val = tensor_map->at(op->GetInputTensorNames()[0]); mlir::Value input1_val = tensor_map->at(op->GetInputTensorNames()[1]); mlir::Value input2_val = tensor_map->at(op->GetInputTensorNames()[2]); mlir::RankedTensorType output_type = tensor_type_map->at(op->GetOutputTensorNames()[0]); assert(op->GetAttributeType() == Attribute_TransposeConvAttribute); // double check attribute type TosaTransposeConvAttribute *attr = static_cast(op->GetAttribute()); mlir::DenseI64ArrayAttr out_pad = BuildDenseI64ArrayAttr(op_builder, attr->out_pad()); mlir::DenseI64ArrayAttr stride = BuildDenseI64ArrayAttr(op_builder, attr->stride()); auto input_zp = attr->input_zp(); auto weight_zp = attr->weight_zp(); bool local_bound = attr->local_bound(); auto acc_type = AccDType2Type(op_builder, 
attr->acc_type()); // input_zp/weight_zp is not allowed for float type mlir::Operation *mlir_op; if (output_type.getElementType().isa()) { assert(input_zp == 0 && weight_zp == 0); } auto input_zp_attr = op_builder->getI32IntegerAttr(input_zp); auto weight_zp_attr = op_builder->getI32IntegerAttr(weight_zp); mlir_op = op_builder->create( loc, output_type, input0_val, input1_val, input2_val, out_pad, stride, acc_type, input_zp_attr, weight_zp_attr, local_bound); block->push_back(mlir_op); return std::vector({mlir_op->getResult(0)}); } template <> std::vector TosaMlirOperatorBuilder::build( TosaSerializationOperator *op) const { mlir::Value input0_val = tensor_map->at(op->GetInputTensorNames()[0]); mlir::Value input1_val = tensor_map->at(op->GetInputTensorNames()[1]); mlir::Value input2_val = tensor_map->at(op->GetInputTensorNames()[2]); mlir::RankedTensorType output_type = tensor_type_map->at(op->GetOutputTensorNames()[0]); assert(op->GetAttributeType() == Attribute_FullyConnectedAttribute); // double check attribute type TosaFullyConnectedAttribute *attr = static_cast(op->GetAttribute()); auto input_zp = attr->input_zp(); auto weight_zp = attr->weight_zp(); // input_zp/weight_zp is not allowed for float type mlir::Operation *mlir_op; if (output_type.getElementType().isa()) { assert(input_zp == 0 && weight_zp == 0); mlir_op = op_builder->create( loc, output_type, input0_val, input1_val, input2_val); } else { auto input_zp_attr = op_builder->getI32IntegerAttr(input_zp); auto weight_zp_attr = op_builder->getI32IntegerAttr(weight_zp); mlir_op = op_builder->create( loc, output_type, input0_val, input1_val, input2_val, input_zp_attr, weight_zp_attr); } block->push_back(mlir_op); return std::vector({mlir_op->getResult(0)}); } template <> std::vector TosaMlirOperatorBuilder::build(TosaSerializationOperator *op) const { mlir::Value input0_val = tensor_map->at(op->GetInputTensorNames()[0]); mlir::Value input1_val = tensor_map->at(op->GetInputTensorNames()[1]); 
mlir::RankedTensorType output_type =
      tensor_type_map->at(op->GetOutputTensorNames()[0]);
  assert(op->GetAttributeType() ==
         Attribute_MatMulAttribute); // double check attribute type

  TosaMatMulAttribute *attr =
      static_cast(op->GetAttribute());
  auto A_zp = attr->a_zp();
  auto B_zp = attr->b_zp();

  // When both zero-points are 0 the op is created without zero-point
  // attributes; otherwise both are attached as i32 attributes.
  mlir::Operation *mlir_op;
  if (A_zp == 0 && B_zp == 0) {
    mlir_op = op_builder->create(loc, output_type, input0_val, input1_val);
  } else {
    auto a_zp_attr = op_builder->getI32IntegerAttr(A_zp);
    auto b_zp_attr = op_builder->getI32IntegerAttr(B_zp);
    mlir_op = op_builder->create(
        loc, output_type, input0_val, input1_val, a_zp_attr, b_zp_attr);
  }

  block->push_back(mlir_op);
  return std::vector({mlir_op->getResult(0)});
}

// Builder for a three-input operator that carries no attribute: inputs 0-2
// are forwarded to one new op whose single result is returned.
// NOTE(review): presumably the SELECT specialization -- confirm against the
// template argument.
template <>
std::vector
TosaMlirOperatorBuilder::build(TosaSerializationOperator *op) const {
  mlir::Value input0_val = tensor_map->at(op->GetInputTensorNames()[0]);
  mlir::Value input1_val = tensor_map->at(op->GetInputTensorNames()[1]);
  mlir::Value input2_val = tensor_map->at(op->GetInputTensorNames()[2]);
  mlir::RankedTensorType output_type =
      tensor_type_map->at(op->GetOutputTensorNames()[0]);
  assert(op->GetAttributeType() ==
         Attribute_NONE); // double check that there is no attribute

  mlir::Operation *mlir_op = op_builder->create(
      loc, output_type, input0_val, input1_val, input2_val);

  block->push_back(mlir_op);
  return std::vector({mlir_op->getResult(0)});
}

// Builder for the operator carrying a ClampAttribute. The serialized
// min_val/max_val payloads are decoded with GetConstAttr as one-element ({1})
// constants of the input's element type. That type is replaced by its storage
// type when the dyn_cast below succeeds (presumably quantized element types).
// The bounds are re-emitted with getFloatAttr in the first branch of the isa()
// check (presumably float element types) and with getIntegerAttr otherwise.
template <>
std::vector
TosaMlirOperatorBuilder::build(TosaSerializationOperator *op) const {
  mlir::Value input_val = tensor_map->at(op->GetInputTensorNames()[0]);
  mlir::RankedTensorType output_type =
      tensor_type_map->at(op->GetOutputTensorNames()[0]);
  assert(op->GetAttributeType() ==
         Attribute_ClampAttribute); // double check attribute type

  TosaClampAttribute *attr =
      static_cast(op->GetAttribute());

  mlir::Type element_type =
      llvm::cast(input_val.getType()).getElementType();
  if (auto quantType =
          llvm::dyn_cast(element_type)) {
    element_type = quantType.getStorageType();
  }

  auto element_const_type =
      mlir::RankedTensorType::get({1}, element_type);
  auto min_values_attr =
      GetConstAttr(attr->min_val(), element_const_type, 1);
  auto max_values_attr =
      GetConstAttr(attr->max_val(), element_const_type, 1);

  mlir::Attribute min_val_attr, max_val_attr;
  if (element_type.isa()) {
    min_val_attr = op_builder->getFloatAttr(
        element_type, min_values_attr.getValues()[0]);
    max_val_attr = op_builder->getFloatAttr(
        element_type, max_values_attr.getValues()[0]);
  } else {
    min_val_attr = op_builder->getIntegerAttr(
        element_type, min_values_attr.getValues()[0]);
    max_val_attr = op_builder->getIntegerAttr(
        element_type, max_values_attr.getValues()[0]);
  }

  auto mlir_op = op_builder->create(
      loc, output_type, input_val, min_val_attr, max_val_attr);

  block->push_back(mlir_op);
  return std::vector({mlir_op->getResult(0)});
}

// ArgMax has single input, and single I64 axis attribute
BUILD_OP_REDUCTION(ArgMax, ARGMAX)

// Builder for a variadic-input operator carrying an AxisAttribute: every
// serialized input is looked up in tensor_map and the axis is emitted as an
// i32 attribute. NOTE(review): presumably the CONCAT specialization -- confirm.
template <>
std::vector
TosaMlirOperatorBuilder::build(TosaSerializationOperator *op) const {
  mlir::RankedTensorType output_type =
      tensor_type_map->at(op->GetOutputTensorNames()[0]);

  llvm::SmallVector input_values;
  for (auto &input_name : op->GetInputTensorNames()) {
    mlir::Value input_val = tensor_map->at(input_name);
    input_values.push_back(input_val);
  }

  assert(op->GetAttributeType() ==
         Attribute_AxisAttribute); // double check attribute type
  TosaAxisAttribute *attr = static_cast(op->GetAttribute());
  auto axis = op_builder->getI32IntegerAttr(attr->axis());

  mlir::Operation *mlir_op = op_builder->create(
      loc, output_type, input_values, axis);
  block->push_back(mlir_op);
  return std::vector({mlir_op->getResult(0)});
}

// Builder for a variadic-input operator whose result is a tosa shape type:
// the output type comes from shape_type_map rather than tensor_type_map, and
// no attribute is read. NOTE(review): presumably CONCAT_SHAPE -- confirm.
template <>
std::vector TosaMlirOperatorBuilder::build(
    TosaSerializationOperator *op) const {
  mlir::tosa::shapeType output_type =
      shape_type_map->at(op->GetOutputTensorNames()[0]);

  llvm::SmallVector input_values;
  for (auto &input_name : op->GetInputTensorNames()) {
    mlir::Value input_val = tensor_map->at(input_name);
    input_values.push_back(input_val);
  }

  mlir::Operation *mlir_op = op_builder->create(
      loc, output_type, input_values);
  block->push_back(mlir_op);
  return std::vector({mlir_op->getResult(0)});
}

// Builder for the operator carrying a NegateAttribute. Zero-point attributes
// are attached only when input1_zp or output_zp is non-zero.
template <>
std::vector
TosaMlirOperatorBuilder::build(TosaSerializationOperator *op) const {
  mlir::Value input_val = tensor_map->at(op->GetInputTensorNames()[0]);
  mlir::RankedTensorType output_type =
      tensor_type_map->at(op->GetOutputTensorNames()[0]);
  assert(op->GetAttributeType() ==
         Attribute_NegateAttribute); // double check attribute type
  TosaNegateAttribute *attr =
      static_cast(op->GetAttribute());

  auto input_zp = attr->input1_zp();
  auto output_zp = attr->output_zp();

  mlir::Operation *mlir_op;
  if (input_zp == 0 && output_zp == 0) {
    mlir_op = op_builder->create(loc, output_type, input_val);
  } else {
    auto input_zp_attr = op_builder->getI32IntegerAttr(input_zp);
    auto output_zp_attr = op_builder->getI32IntegerAttr(output_zp);
    mlir_op = op_builder->create(
        loc, output_type, input_val, input_zp_attr, output_zp_attr);
  }

  block->push_back(mlir_op);
  return std::vector({mlir_op->getResult(0)});
}

// Builder for an operator taking an input tensor plus a second operand
// (shape_val, input 1). No attribute type is checked here.
// NOTE(review): presumably RESHAPE -- confirm against the template argument.
template <>
std::vector TosaMlirOperatorBuilder::build(
    TosaSerializationOperator *op) const {
  mlir::Value input_val = tensor_map->at(op->GetInputTensorNames()[0]);
  mlir::RankedTensorType output_type =
      tensor_type_map->at(op->GetOutputTensorNames()[0]);
  mlir::Value shape_val = tensor_map->at(op->GetInputTensorNames()[1]);

  mlir::Operation *mlir_op = op_builder->create(
      loc, output_type, input_val, shape_val);
  block->push_back(mlir_op);
  return std::vector({mlir_op->getResult(0)});
}

// Builder for the operator carrying a PadAttribute.
// - If every element of the serialized pad_const is zero, the op is built
//   from (input, padding) only.
// - Otherwise a rank-0 constant of the input's element type is decoded from
//   pad_const with GetConstAttr. That constant op is pushed into the block
//   first and passed as the third operand.
template <>
std::vector
TosaMlirOperatorBuilder::build(TosaSerializationOperator *op) const {
  mlir::Value input_val = tensor_map->at(op->GetInputTensorNames()[0]);
  mlir::Value padding_val = tensor_map->at(op->GetInputTensorNames()[1]);
  // NOTE(review): input_type is not used below.
  mlir::RankedTensorType input_type =
      tensor_type_map->at(op->GetInputTensorNames()[0]);
  mlir::RankedTensorType output_type =
      tensor_type_map->at(op->GetOutputTensorNames()[0]);
  const auto element_type =
      input_val.getType().cast().getElementType();

  assert(op->GetAttributeType() ==
         Attribute_PadAttribute); // double check attribute type
  TosaPadAttribute *attr = static_cast(op->GetAttribute());

  const auto &pad_const_u8_data = attr->pad_const();

  // check for any value in pad_const_u8_data
  bool has_pad_const = false;
  for (auto v : pad_const_u8_data) {
    if (v != 0) {
      has_pad_const = true;
      break;
    }
  }

  if (!has_pad_const) {
    // handle the cases where no explicit pad_const input.
    auto mlir_op = op_builder->create(
        loc, output_type, input_val, padding_val);
    block->push_back(mlir_op);
    return std::vector({mlir_op->getResult(0)});
  }

  // has pad const - create a const op for pad_const input
  auto pad_const_type = mlir::RankedTensorType::get({}, element_type);
  auto pad_const_attr = GetConstAttr(pad_const_u8_data, pad_const_type, 1);

  auto pad_const_op = op_builder->create(
      loc, pad_const_type, pad_const_attr);

  block->push_back(pad_const_op);
  mlir::Value pad_const_value = pad_const_op->getResult(0);

  auto mlir_op = op_builder->create(
      loc, output_type, input_val, padding_val, pad_const_value);

  block->push_back(mlir_op);
  return std::vector({mlir_op->getResult(0)});
}

// Builder for a single-input operator carrying an AxisAttribute; the axis is
// emitted as an i32 attribute. NOTE(review): the operator this specializes is
// not identifiable from the body alone -- confirm against the template
// argument.
template <>
std::vector
TosaMlirOperatorBuilder::build(TosaSerializationOperator *op) const {
  mlir::Value input_val = tensor_map->at(op->GetInputTensorNames()[0]);
  mlir::RankedTensorType output_type =
      tensor_type_map->at(op->GetOutputTensorNames()[0]);
  assert(op->GetAttributeType() ==
         Attribute_AxisAttribute); // double check attribute type
  TosaAxisAttribute *attr = static_cast(op->GetAttribute());

  auto axis = op_builder->getI32IntegerAttr(attr->axis());

  mlir::Operation *mlir_op =
      op_builder->create(loc, output_type, input_val, axis);

  block->push_back(mlir_op);
  return std::vector({mlir_op->getResult(0)});
}

// Builder for the operator carrying a TransposeAttribute. The perms list is
// materialized (BuildDenseI32ElementsAttr) as a rank-1 i32 constant op that
// feeds the op being built.
template <>
std::vector TosaMlirOperatorBuilder::build(
    TosaSerializationOperator *op) const {
  mlir::Value input_val = tensor_map->at(op->GetInputTensorNames()[0]);
  mlir::RankedTensorType output_type =
      tensor_type_map->at(op->GetOutputTensorNames()[0]);
  assert(op->GetAttributeType() ==
         Attribute_TransposeAttribute); // double check attribute type
  TosaTransposeAttribute *attr =
      static_cast(op->GetAttribute());

  // make a constant op from attr->perms values, of type: { shape = { perms.size
  // }, element_type = I32 }
  const auto perms_values = attr->perms();
  auto const_type = mlir::RankedTensorType::get(
      {static_cast(perms_values.size())}, op_builder->getI32Type());
  mlir::DenseElementsAttr const_attr =
      BuildDenseI32ElementsAttr(op_builder, const_type, perms_values);
  mlir::Operation *mlir_const_op =
      op_builder->create(loc, const_type, const_attr);
  auto perms_val = mlir_const_op->getResult(0);

  mlir::Operation *mlir_op = op_builder->create(
      loc, output_type, input_val, perms_val);

  // The constant is pushed first so its result is defined before its user.
  block->push_back(mlir_const_op);
  block->push_back(mlir_op);
  return std::vector({mlir_op->getResult(0)});
}

// Builder for an operator taking an input tensor plus start and size operands
// (inputs 1 and 2). NOTE(review): presumably SLICE -- confirm.
template <>
std::vector
TosaMlirOperatorBuilder::build(TosaSerializationOperator *op) const {
  mlir::Value input_val = tensor_map->at(op->GetInputTensorNames()[0]);
  mlir::RankedTensorType output_type =
      tensor_type_map->at(op->GetOutputTensorNames()[0]);
  mlir::Value start = tensor_map->at(op->GetInputTensorNames()[1]);
  mlir::Value size = tensor_map->at(op->GetInputTensorNames()[2]);

  mlir::Operation *mlir_op = op_builder->create(
      loc, output_type, input_val, start, size);
  block->push_back(mlir_op);
  return std::vector({mlir_op->getResult(0)});
}

// Builder for an attribute-less operator taking an input tensor plus a
// multiples operand (input 1). NOTE(review): presumably TILE -- confirm.
template <>
std::vector
TosaMlirOperatorBuilder::build(TosaSerializationOperator *op) const {
  mlir::Value input_val = tensor_map->at(op->GetInputTensorNames()[0]);
  mlir::Value multiples = tensor_map->at(op->GetInputTensorNames()[1]);
  mlir::RankedTensorType output_type =
      tensor_type_map->at(op->GetOutputTensorNames()[0]);
  assert(op->GetAttributeType() ==
         Attribute_NONE); // double check attribute type

  mlir::Operation *mlir_op = op_builder->create(
      loc, output_type, input_val, multiples);
  block->push_back(mlir_op);
  return std::vector({mlir_op->getResult(0)});
}

// Gather is a binary op
BUILD_OP_ELEMENTWISE_BINARY(Gather, GATHER)

// Builder for a three-input operator with no attribute.
// NOTE(review): presumably SCATTER -- confirm against the template argument.
template <>
std::vector TosaMlirOperatorBuilder::build(
    TosaSerializationOperator *op) const {
  mlir::Value input0_val = tensor_map->at(op->GetInputTensorNames()[0]);
  mlir::Value input1_val = tensor_map->at(op->GetInputTensorNames()[1]);
  mlir::Value input2_val = tensor_map->at(op->GetInputTensorNames()[2]);
  mlir::RankedTensorType output_type =
      tensor_type_map->at(op->GetOutputTensorNames()[0]);
  assert(op->GetAttributeType() ==
         Attribute_NONE); // double check that there is no attribute

  mlir::Operation *mlir_op = op_builder->create(
      loc, output_type, input0_val, input1_val, input2_val);

  block->push_back(mlir_op);
  return std::vector({mlir_op->getResult(0)});
}

// Builder for the operator carrying a ResizeAttribute: inputs 1-3 are the
// scale, offset and border operands, and the mode enum is converted to a
// string attribute with ResizeEnum2Str.
template <>
std::vector
TosaMlirOperatorBuilder::build(TosaSerializationOperator *op) const {
  mlir::Value input_val = tensor_map->at(op->GetInputTensorNames()[0]);
  mlir::Value scale_val = tensor_map->at(op->GetInputTensorNames()[1]);
  mlir::Value offset_val = tensor_map->at(op->GetInputTensorNames()[2]);
  mlir::Value border_val = tensor_map->at(op->GetInputTensorNames()[3]);
  mlir::RankedTensorType output_type =
      tensor_type_map->at(op->GetOutputTensorNames()[0]);
  assert(op->GetAttributeType() ==
         Attribute_ResizeAttribute); // double check attribute type
  TosaResizeAttribute *attr =
      static_cast(op->GetAttribute());

  auto mode = op_builder->getStringAttr(ResizeEnum2Str(attr->mode()));

  mlir::Operation *mlir_op = op_builder->create(
      loc, output_type, input_val, scale_val, offset_val, border_val, mode);

  block->push_back(mlir_op);
  return std::vector({mlir_op->getResult(0)});
}

// Reverse has single input, and single I64 axis attribute
BUILD_OP_REDUCTION(Reverse, REVERSE)

// Builder for the operator carrying a MulAttribute: input 2 is forwarded as
// the shift operand. Only the attribute type is checked; the attribute's
// contents are not read here.
template <>
std::vector
TosaMlirOperatorBuilder::build(TosaSerializationOperator *op) const {
  mlir::Value input0_val = tensor_map->at(op->GetInputTensorNames()[0]);
  mlir::Value input1_val = tensor_map->at(op->GetInputTensorNames()[1]);
  mlir::Value shift_val = tensor_map->at(op->GetInputTensorNames()[2]);
  mlir::RankedTensorType output_type =
      tensor_type_map->at(op->GetOutputTensorNames()[0]);
  assert(op->GetAttributeType() ==
         Attribute_MulAttribute); // double check attribute type

  mlir::Operation *mlir_op = op_builder->create(
      loc, output_type, input0_val, input1_val, shift_val);

  block->push_back(mlir_op);
  return std::vector({mlir_op->getResult(0)});
}

// Builder for the operator carrying an ArithmeticRightShiftAttribute; its
// round flag is emitted as a bool attribute.
template <>
std::vector TosaMlirOperatorBuilder::build(
    TosaSerializationOperator *op) const {
  mlir::Value input0_val = tensor_map->at(op->GetInputTensorNames()[0]);
  mlir::Value input1_val = tensor_map->at(op->GetInputTensorNames()[1]);
  mlir::RankedTensorType output_type =
      tensor_type_map->at(op->GetOutputTensorNames()[0]);
  assert(
      op->GetAttributeType() ==
      Attribute_ArithmeticRightShiftAttribute); // double check attribute type
  TosaArithmeticRightShiftAttribute *attr =
      static_cast(op->GetAttribute());

  auto round = op_builder->getBoolAttr(attr->round());

  mlir::Operation *mlir_op = op_builder->create(
      loc, output_type, input0_val, input1_val, round);

  block->push_back(mlir_op);
  return std::vector({mlir_op->getResult(0)});
}

// Builder for the operator carrying a TableAttribute. The table values are
// materialized as a rank-1 constant op: i8 elements when the input element
// type is an 8-bit integer, i16 otherwise. That constant feeds the op built.
template <>
std::vector
TosaMlirOperatorBuilder::build(TosaSerializationOperator *op) const {
  mlir::Value input_val = tensor_map->at(op->GetInputTensorNames()[0]);
  mlir::RankedTensorType output_type =
      tensor_type_map->at(op->GetOutputTensorNames()[0]);
  assert(op->GetAttributeType() ==
         Attribute_TableAttribute); // double check attribute type
  TosaTableAttribute *attr =
      static_cast(op->GetAttribute());

  // create a const op for table value attribute
  const auto table_values = attr->table();
  mlir::RankedTensorType const_type;
  mlir::DenseElementsAttr const_attr;
  const auto input_element_type =
      input_val.getType().cast().getElementType();
  if (input_element_type.isInteger(8)) {
    // table is signed 8 mode
    const_type = mlir::RankedTensorType::get(
        {static_cast(table_values.size())}, op_builder->getI8Type());
    const_attr =
BuildDenseI8ElementsAttr(op_builder, table_values); } else { // table is signed 16 mode const_type = mlir::RankedTensorType::get( {static_cast(table_values.size())}, op_builder->getI16Type()); const_attr = BuildDenseI16ElementsAttr(op_builder, table_values); } mlir::Operation *mlir_const_op = op_builder->create(loc, const_type, const_attr); auto table_value = mlir_const_op->getResult(0); mlir::Operation *mlir_op = op_builder->create( loc, output_type, input_val, table_value); block->push_back(mlir_const_op); block->push_back(mlir_op); return std::vector({mlir_op->getResult(0)}); } template <> std::vector TosaMlirOperatorBuilder::build( TosaSerializationOperator *op) const { mlir::Value input_val = tensor_map->at(op->GetInputTensorNames()[0]); mlir::Value multiplier_val = tensor_map->at(op->GetInputTensorNames()[1]); mlir::Value shift_val = tensor_map->at(op->GetInputTensorNames()[2]); mlir::RankedTensorType output_type = tensor_type_map->at(op->GetOutputTensorNames()[0]); assert(op->GetAttributeType() == Attribute_RescaleAttribute); // double check attribute type TosaRescaleAttribute *attr = static_cast(op->GetAttribute()); auto input_zp = op_builder->getI32IntegerAttr(attr->input_zp()); auto output_zp = op_builder->getI32IntegerAttr(attr->output_zp()); auto scale32 = op_builder->getBoolAttr(attr->scale32()); auto double_round = op_builder->getBoolAttr(attr->double_round()); auto per_channel = op_builder->getBoolAttr(attr->per_channel()); auto input_unsigned = op_builder->getBoolAttr(attr->input_unsigned()); auto output_unsigned = op_builder->getBoolAttr(attr->output_unsigned()); mlir::Operation *mlir_op = op_builder->create( loc, output_type, input_val, multiplier_val, shift_val, input_zp, output_zp, scale32, double_round, per_channel, input_unsigned, output_unsigned); block->push_back(mlir_op); return std::vector({mlir_op->getResult(0)}); } template <> std::vector TosaMlirOperatorBuilder::build(TosaSerializationOperator *op) const { mlir::Value input_val = 
tensor_map->at(op->GetInputTensorNames()[0]); mlir::RankedTensorType output_type = tensor_type_map->at(op->GetOutputTensorNames()[0]); assert(op->GetAttributeType() == Attribute_CustomAttribute); // double check attribute type TosaCustomAttribute *attr = static_cast(op->GetAttribute()); auto operator_name = op_builder->getStringAttr(attr->operator_name()); auto domain_name = op_builder->getStringAttr(attr->domain_name()); std::string impl_str; impl_str.resize(attr->implementation_attrs().size() + 1); int idx = 0; for (auto c : attr->implementation_attrs()) { impl_str[idx++] = c; } auto impl = op_builder->getStringAttr(impl_str); mlir::Operation *mlir_op = op_builder->create( loc, output_type, operator_name, domain_name, impl, input_val); block->push_back(mlir_op); return std::vector({mlir_op->getResult(0)}); } template <> std::vector TosaMlirOperatorBuilder::build(TosaSerializationOperator *op) const { mlir::Value input_val = tensor_map->at(op->GetInputTensorNames()[0]); mlir::RankedTensorType output0_type = tensor_type_map->at(op->GetOutputTensorNames()[0]); mlir::RankedTensorType output1_type = tensor_type_map->at(op->GetOutputTensorNames()[1]); assert(op->GetAttributeType() == Attribute_RFFTAttribute); // double check attribute type TosaRFFTAttribute *attr = static_cast(op->GetAttribute()); bool local_bound = attr->local_bound(); mlir::Operation *mlir_op = op_builder->create( loc, output0_type, output1_type, input_val, local_bound); block->push_back(mlir_op); return std::vector( {mlir_op->getResult(0), mlir_op->getResult(1)}); } template <> std::vector TosaMlirOperatorBuilder::build(TosaSerializationOperator *op) const { mlir::Value input0_val = tensor_map->at(op->GetInputTensorNames()[0]); mlir::Value input1_val = tensor_map->at(op->GetInputTensorNames()[1]); mlir::RankedTensorType output0_type = tensor_type_map->at(op->GetOutputTensorNames()[0]); mlir::RankedTensorType output1_type = tensor_type_map->at(op->GetOutputTensorNames()[1]); 
assert(op->GetAttributeType() == Attribute_FFTAttribute); TosaFFTAttribute *attr = static_cast(op->GetAttribute()); auto inverse = op_builder->getBoolAttr(attr->inverse()); auto local_bound = op_builder->getBoolAttr(attr->local_bound()); mlir::Operation *mlir_op = op_builder->create( loc, output0_type, output1_type, input0_val, input1_val, inverse, local_bound); block->push_back(mlir_op); return std::vector( {mlir_op->getResult(0), mlir_op->getResult(1)}); } class TosaMlirRegionBuilder { public: TosaMlirRegionBuilder(TosaSerializationRegion *_ser_region, TosaSerializationHandler *_tsh, mlir::Region *_region, mlir::OpBuilder *_op_builder, mlir::Location _loc, TosaMlirRegionBuilder *_parent_value_scope = nullptr) : ser_region(_ser_region), tsh(_tsh), region(_region), op_builder(_op_builder), loc(_loc) { if (_parent_value_scope) { // inherit parent_value_scope's tensor_map for (auto &kv : _parent_value_scope->GetTensorMap()) { tensor_map.insert(kv); } } } mlir::LogicalResult BuildAllBlocksInRegion(std::vector &return_values); mlir::OpBuilder *GetOpBuilder() { return op_builder; } mlir::Location GetLocation() { return loc; } std::unordered_map &GetTensorMap() { return tensor_map; } TosaSerializationHandler *GetTsh() const { return tsh; } private: mlir::Region *region; TosaSerializationRegion *ser_region; TosaSerializationHandler *tsh; mlir::OpBuilder *op_builder; mlir::Location loc; std::unordered_map tensor_map; }; class TosaMlirBlockBuilder { public: TosaMlirBlockBuilder(TosaSerializationBasicBlock *_ser_block, TosaMlirRegionBuilder *_region_builder, mlir::Block *_block) : ser_block(_ser_block), region_builder(_region_builder), block(_block) {} mlir::LogicalResult BuildAllOpsInBlock(std::vector &return_values); mlir::OpBuilder *GetOpBuilder() { return region_builder->GetOpBuilder(); } mlir::Location GetLocation() { return region_builder->GetLocation(); } std::unordered_map &GetTensorMap() { return region_builder->GetTensorMap(); } TosaSerializationHandler *GetTsh() 
const { return region_builder->GetTsh(); } TosaMlirRegionBuilder *GetRegionBuilder() const { return region_builder; } private: TosaSerializationBasicBlock *ser_block; TosaMlirRegionBuilder *region_builder; mlir::Block *block; std::unordered_map tensor_type_map; std::unordered_map shape_type_map; std::unordered_set unranked_tensors; }; TosaSerializationHandler *TosaMlirOperatorBuilder::GetTsh() const { return block_builder->GetTsh(); } TosaMlirRegionBuilder *TosaMlirOperatorBuilder::GetRegionBuilder() const { return block_builder->GetRegionBuilder(); } // build control flow ops: namespace { mlir::LogicalResult BuildRegion(TosaSerializationRegion *ser_region, TosaSerializationHandler *tsh, mlir::Region *mlir_region, mlir::OpBuilder *op_builder, mlir::Location loc, std::vector &return_values, bool isolated_from_above = false, TosaMlirRegionBuilder *parent_region_builder = nullptr) { TosaMlirRegionBuilder *parent_value_scope = isolated_from_above ? nullptr : parent_region_builder; TosaMlirRegionBuilder region_builder(ser_region, tsh, mlir_region, op_builder, loc, parent_value_scope); return region_builder.BuildAllBlocksInRegion(return_values); } } // namespace template <> std::vector TosaMlirOperatorBuilder::build( TosaSerializationOperator *op) const { mlir::Value cond_val = tensor_map->at(op->GetInputTensorNames().at(0)); std::vector input_values; for (auto idx = 1u; idx < op->GetInputTensorNames().size(); idx++) { input_values.push_back(tensor_map->at(op->GetInputTensorNames().at(idx))); } std::vector output_types; for (auto &name : op->GetInputTensorNames()) { output_types.push_back(tensor_type_map->at(name)); } assert(op->GetAttributeType() == Attribute_CondIfAttribute); // double check attribute type TosaCondIfAttribute *attr = static_cast(op->GetAttribute()); auto ser_then_region = GetTsh()->GetRegionByName(attr->then_graph()); auto ser_else_region = GetTsh()->GetRegionByName(attr->else_graph()); if (!ser_then_region || !ser_else_region) { llvm::errs() << 
"ERROR: " << get_string(op) << " region serialization hasn't been implemented\n"; return {}; } mlir::Operation *mlir_op = op_builder->create( loc, output_types, cond_val, input_values); const bool isolated_from_above = mlir_op->hasTrait(); mlir::Region &then_region = mlir_op->getRegion(0); mlir::Region &else_region = mlir_op->getRegion(1); auto curr_region_builder = GetRegionBuilder(); std::vector then_returns, else_returns; if (BuildRegion(ser_then_region, GetTsh(), &then_region, op_builder, loc, then_returns, isolated_from_above, curr_region_builder) .failed()) { return {}; } if (then_returns.size() != mlir_op->getNumResults()) { llvm::errs() << "ERROR: " << get_string(op) << " then_region yield.size() doesn't match cond_if's output size\n"; return {}; } if (BuildRegion(ser_else_region, GetTsh(), &else_region, op_builder, loc, else_returns, isolated_from_above, curr_region_builder) .failed()) { return {}; } if (else_returns.size() != mlir_op->getNumResults()) { llvm::errs() << "ERROR: " << get_string(op) << " else_region yield.size() doesn't match cond_if's output size\n"; return {}; } block->push_back(mlir_op); return std::vector(mlir_op->getResults().begin(), mlir_op->getResults().end()); } template <> std::vector TosaMlirOperatorBuilder::build( TosaSerializationOperator *op) const { std::vector input_values; for (auto idx = 0u; idx < op->GetInputTensorNames().size(); idx++) { input_values.push_back(tensor_map->at(op->GetInputTensorNames().at(idx))); } std::vector output_types; for (auto &name : op->GetInputTensorNames()) { output_types.push_back(tensor_type_map->at(name)); } assert(op->GetAttributeType() == Attribute_WhileLoopAttribute); // double check attribute type TosaWhileLoopAttribute *attr = static_cast(op->GetAttribute()); auto ser_cond_region = GetTsh()->GetRegionByName(attr->cond_graph()); auto ser_body_region = GetTsh()->GetRegionByName(attr->body_graph()); mlir::Operation *mlir_op = op_builder->create(loc, output_types, input_values); const bool 
isolated_from_above = mlir_op->hasTrait(); mlir::Region &cond_region = mlir_op->getRegion(0); mlir::Region &body_region = mlir_op->getRegion(1); auto curr_region_builder = GetRegionBuilder(); std::vector cond_returns, body_returns; if (BuildRegion(ser_cond_region, GetTsh(), &cond_region, op_builder, loc, cond_returns, isolated_from_above, curr_region_builder) .failed()) { return {}; } if (cond_returns.size() != 1) { llvm::errs() << "ERROR: " << get_string(op) << " cond_region yield.size() is not 1\n"; return {}; } if (BuildRegion(ser_body_region, GetTsh(), &body_region, op_builder, loc, body_returns, isolated_from_above, curr_region_builder) .failed()) { return {}; } if (body_returns.size() != mlir_op->getNumResults()) { llvm::errs() << "ERROR: " << get_string(op) << " body_region yield.size() doesn't match while_loop's output size\n"; return {}; } block->push_back(mlir_op); return std::vector(mlir_op->getResults().begin(), mlir_op->getResults().end()); } mlir::LogicalResult TosaMlirBlockBuilder::BuildAllOpsInBlock( std::vector &return_values) { block->clear(); auto loc = GetLocation(); auto op_builder = GetOpBuilder(); auto &tensor_map = GetTensorMap(); std::unordered_set operator_built; std::queue operator_queue; TosaMlirOperatorBuilder tosa_op_builder(op_builder, ser_block, block, loc, this, &tensor_map, &tensor_type_map, &shape_type_map); for (auto ts : ser_block->GetTensors()) { if (ts->GetVariable()) { RegisterVariableTensor(ts); } const auto &ts_name = ts->GetName(); if (ts->GetDtype() == DType::DType_SHAPE) { // ts is tosa.shape type auto shape_rank = ts->GetShape()[0]; shape_type_map[ts_name] = mlir::tosa::shapeType::get(op_builder->getContext(), shape_rank); continue; } mlir::RankedTensorType type; if (BuildTensorType(op_builder, ts, type).failed()) { return mlir::failure(); } tensor_type_map[ts_name] = type; if (ts->GetIsUnranked()) { assert(ts->GetShape().empty()); // unranked tensors should have shape = {} unranked_tensors.insert(ts_name); } } // 
Initialize tensor_map/operator_queue based on block input arguments for (const std::string &block_input_name : ser_block->GetInputs()) { mlir::Type type = tensor_type_map[block_input_name]; if (unranked_tensors.count(block_input_name)) { // recast type as unranked tensor type auto element_type = type.cast().getElementType(); type = mlir::UnrankedTensorType::get(element_type); } auto input_value = block->addArgument(type, loc); if (tensor_map.count(block_input_name)) { llvm::errs() << "ERROR: block input tensor " << block_input_name << " already exists\n"; return mlir::failure(); } tensor_map[block_input_name] = input_value; } for (auto op : ser_block->GetOperators()) { // skip if operator has been built if (operator_built.count(op)) { // this happens when same input appears twice or more in operator, eg, // concat(%0, %0) continue; } operator_built.insert(op); std::vector output_values; if (IsVariableReadOp(op)) { output_values = tosa_op_builder.BuildVariableReadOp(op); } else if (IsVariableWriteOp(op)) { tosa_op_builder.BuildVariableWriteOp(op); } #define DEF_SCHEMA_OPERATOR(SCHEMA_OP_NAME) \ else if (op->GetOp() == Op_##SCHEMA_OP_NAME) { \ output_values = tosa_op_builder.build(op); \ } #include "schema_operator.def" #undef DEF_SCHEMA_OPERATOR else { llvm::errs() << "ERROR: unsupported opcode=" << EnumNamesOp()[op->GetOp()] << "\n"; return mlir::failure(); } if (IsVariableWriteOp(op)) { // the sanity checking below does not apply for variable write op because // it has no output tensors whereas the original identity op has continue; } // Sanity check if number of built mlir::Value is expected if (op->GetOutputTensorNames().size() != output_values.size()) { llvm::errs() << "ERROR: number of built mlir::Value is not matching " "number of operator output tensor\n"; return mlir::failure(); } for (size_t i = 0; i < output_values.size(); i++) { // Sanity check tensor hasn't been built std::string op_output_name = op->GetOutputTensorNames()[i]; if 
(tensor_map.count(op_output_name)) { llvm::errs() << "ERROR: tensor " << op_output_name << " is already built\n"; return mlir::failure(); } tensor_map[op_output_name] = output_values[i]; } } // Construct return values std::vector return_operands; for (const auto &output_name : ser_block->GetOutputs()) { // Sanity check if terminator mlir::Value is built if (!tensor_map.count(output_name)) { llvm::errs() << "ERROR: terminator mlir::Value " << output_name << " is not built in block " << ser_block->GetName() << "\n"; return mlir::failure(); } mlir::Value output_value = tensor_map.at(output_name); return_operands.push_back(output_value); return_values.push_back(output_value); } mlir::Operation *terminator_op; auto parent_op = block->getParentOp(); if (mlir::isa(parent_op)) { terminator_op = op_builder->create(loc, return_operands); } else { terminator_op = op_builder->create(loc, return_operands); } block->push_back(terminator_op); // need topological sorting? return mlir::success(); } mlir::LogicalResult TosaMlirRegionBuilder::BuildAllBlocksInRegion( std::vector &return_values) { for (auto &ser_block : ser_region->GetBlocks()) { auto &block = region->emplaceBlock(); TosaMlirBlockBuilder block_builder(ser_block, this, &block); if (block_builder.BuildAllOpsInBlock(return_values).failed()) { return mlir::failure(); } if (return_values.empty()) { llvm::errs() << "Warning: graph doesn't have return values\n"; } } return mlir::success(); } mlir::LogicalResult buildTosaMlir(mlir::func::FuncOp &func, mlir::MLIRContext &context, tosa::TosaSerializationHandler &tsh, std::vector &main_returns) { mlir::Region *main_region = func.getCallableRegion(); if (!main_region) { llvm::errs() << "Invalid MLIR: doesn't have valid \"main\" region\n"; return mlir::failure(); } TosaSerializationRegion *ser_main_region = tsh.GetRegions().front(); auto loc = func.getLoc(); main_region->takeBody(*main_region); // empty old func body auto op_builder = mlir::OpBuilder(func.getBody()); if 
(BuildRegion(ser_main_region, &tsh, main_region, &op_builder, loc, main_returns) .failed()) { return mlir::failure(); } if (main_returns.empty()) { llvm::errs() << "Warning: graph doesn't have return values\n"; } return mlir::success(); } // Load Tosa Schema into TosaSerializationHandler, required for JSON save/load mlir::LogicalResult loadTosaSchema(tosa::TosaSerializationHandler &tsh) { const char *tosa_schema = tosa_deserialize_schema.c_str(); if (!tosa_schema) { llvm::errs() << "Flatbuffer schema not defined\n"; return mlir::failure(); } if (tsh.LoadFileSchema(tosa_schema)) { llvm::errs() << "Error loading tosa schema file: " << tosa_schema << "\n"; return mlir::failure(); } return mlir::success(); } namespace { mlir::NamedAttribute DefaultEntryFuncitonAttr(mlir::Builder &builder, bool is_input, int count) { std::string names; for (int i = 0; i < count; i++) { std::string name = kDefaultExportedName + "_"; name += (is_input ? kDefaultInputPrefix : kDefaultOutputPrefix); name += std::to_string(i) + ":0"; if (i > 0) { names += ","; } names += name; } return builder.getNamedAttr((is_input ? 
"inputs" : "outputs"), builder.getStringAttr(names)); } // erase all ops in block except for FuncOp void ClearNonFuncOps(mlir::Block *block) { std::vector to_delete; for (auto &op : block->getOperations()) { if (!mlir::isa(op)) { to_delete.push_back(&op); } } for (mlir::Operation *op : to_delete) { op->erase(); } } // erase function attrs and empty function region's body void ResetFunction(mlir::func::FuncOp &function, mlir::MLIRContext &context) { function->setAttrs(mlir::DictionaryAttr::get(&context, {})); mlir::Region *main_region = function.getCallableRegion(); main_region->takeBody(*main_region); } // replace attrs and body of @a to_function and its parent module // by @a from_module and its "main" function mlir::LogicalResult CloneIntoModuleAndFunction( mlir::MLIRContext &context, mlir::func::FuncOp &to_function, mlir::ModuleOp &to_module, mlir::func::FuncOp &from_function, mlir::ModuleOp &from_module) { auto from_block = from_function.getOperation()->getBlock(); auto to_block = to_function.getOperation()->getBlock(); ClearNonFuncOps(to_block); // copy all attrs from new_module to module to_module->setAttrs(from_module->getAttrDictionary()); // erase attrs and body of function ResetFunction(to_function, context); // clone new_func attrs and region into function mlir::IRMapping mapping; from_function.cloneInto(to_function, mapping); // copy variable ops in from_block to to_block // collect variable ops in from_block in reverse order std::vector variable_ops; for (mlir::Operation &op : *from_block) { if (mlir::isa(op)) { variable_ops.push_back(&op); } } auto cloneOptions = mlir::Operation::CloneOptions::all().cloneRegions(false).cloneOperands( false); for (auto iter = variable_ops.rbegin(); iter != variable_ops.rend(); iter++) { auto op = *iter; to_block->push_front(op->clone(mapping, cloneOptions)); } return mlir::success(); } } // namespace namespace mlir { namespace tosa { mlir::OwningOpRef BuildMlirFromTosaFile(const char *file_name, mlir::MLIRContext 
*context, bool file_is_fbs) { TosaSerializationHandler tsh; if (file_is_fbs) { if (tsh.LoadFileTosaFlatbuffer(file_name)) { llvm::errs() << "Fail to load TOSA file " << file_name << "\n"; return nullptr; } } else { // must load tosa schema before loading json file if (loadTosaSchema(tsh).failed()) { return nullptr; } if (tsh.LoadFileJson(file_name)) { llvm::errs() << "Fail to load TOSA JSON file " << file_name << "\n"; return nullptr; } } // create new module auto base_loc = mlir::FileLineColLoc::get(context, file_name, 0, 0); auto module = mlir::ModuleOp::create(base_loc); // set module attributes const auto &tosa_version = tsh.GetVersion().to_string(); std::string tosa_description = file_is_fbs ? kDefaultFBSDescription : kDefaultJSONDescription; auto builder = mlir::Builder(context); module->setAttr("tosa.fbs_version", builder.getStringAttr(tosa_version)); module->setAttr("tosa.description", builder.getStringAttr(tosa_description)); module->setAttr("tf_saved_model.semantics", mlir::UnitAttr::get(context)); // construct function with input and return types llvm::SmallVector ret_types; llvm::SmallVector input_types; auto func_type = builder.getFunctionType(input_types, ret_types); auto func_loc = mlir::NameLoc::get(builder.getStringAttr(kMainFunctionName), base_loc); auto func = mlir::func::FuncOp::create(func_loc, kMainFunctionName, func_type, /* attrs= */ {}); func.addEntryBlock(); // deserialize tosa fbs into function std::vector main_returns; if (buildTosaMlir(func, *context, tsh, main_returns).failed()) { llvm::errs() << "Failed to deserialize flatbuffer " << tosa_deserialize_filename << "\n"; return nullptr; } auto main_args = func.getCallableRegion()->getArguments(); // extract function input types for (auto arg : main_args) { input_types.push_back(arg.getType()); } // extract function return types for (auto ret : main_returns) { ret_types.push_back(ret.getType()); } // set function type with full input and return types func_type = 
builder.getFunctionType(input_types, ret_types); func.setType(func_type); // set function attributes llvm::SmallVector attributes; if (!input_types.empty()) { attributes.push_back(DefaultEntryFuncitonAttr( builder, /* is_input = */ true, /* count = */ input_types.size())); for (int i = 0; i < input_types.size(); i++) { std::string input_i = kDefaultInputPrefix + std::to_string(i); func.setArgAttr(i, "tf_saved_model.index_path", mlir::ArrayAttr::get( context, {mlir::StringAttr::get(context, input_i)})); } } if (!ret_types.empty()) { attributes.push_back(DefaultEntryFuncitonAttr( builder, /* is_input = */ false, /* count = */ ret_types.size())); for (int i = 0; i < ret_types.size(); i++) { std::string output_i = kDefaultOutputPrefix + std::to_string(i); func.setResultAttr( i, "tf_saved_model.index_path", mlir::ArrayAttr::get(context, {mlir::StringAttr::get(context, output_i)})); } } func->setAttr("tf.entry_function", builder.getDictionaryAttr(attributes)); func->setAttr( "tf_saved_model.exported_names", mlir::ArrayAttr::get( context, {mlir::StringAttr::get(context, kDefaultExportedName)})); // deserialize variable ops in the new module just before adding func op if (ConstructVariableOps(module).failed()) { return nullptr; } // add func to module module.push_back(std::move(func)); return mlir::OwningOpRef(module); } namespace { class TosaDeserialize : public TosaDeserializationPassBase { public: void runOnOperation() final { auto function = getOperation(); auto &context = getContext(); auto new_module_ref = BuildMlirFromTosaFile( tosa_deserialize_filename.c_str(), &context, /* file_is_fbs = */ true); if (!new_module_ref) { return signalPassFailure(); } mlir::ModuleOp new_module = *new_module_ref; auto builder = mlir::Builder(&context); auto module = function->getParentOfType(); auto new_function = new_module.lookupSymbol( builder.getStringAttr(kMainFunctionName)); if (!new_function) { llvm::errs() << "Failed to find main function in deserialized module\n"; return 
signalPassFailure(); } if (CloneIntoModuleAndFunction(context, /* to_function = */ function, /* to_module = */ module, /* from_function = */ new_function, /* from_module = */ new_module) .failed()) { return signalPassFailure(); } } }; class TosaDeserializeJSON : public TosaDeserializationJSONPassBase { public: void runOnOperation() final { auto function = getOperation(); auto &context = getContext(); auto new_module_ref = BuildMlirFromTosaFile( tosa_deserialize_filename.c_str(), &context, /* file_is_fbs = */ false); if (!new_module_ref) { return signalPassFailure(); } mlir::ModuleOp new_module = *new_module_ref; auto builder = mlir::Builder(&context); auto module = function->getParentOfType(); auto new_function = new_module.lookupSymbol( builder.getStringAttr(kMainFunctionName)); if (!new_function) { llvm::errs() << "Failed to find main function in deserialized module\n"; return signalPassFailure(); } if (CloneIntoModuleAndFunction(context, /* to_function = */ function, /* to_module = */ module, /* from_function = */ new_function, /* from_module = */ new_module) .failed()) { return signalPassFailure(); } } }; } // anonymous namespace // Creates an instance of the TOSA flatbuffer deserialization pass std::unique_ptr createTosaDeserializePass() { return std::make_unique(); } std::unique_ptr createTosaDeserializeJSONPass() { return std::make_unique(); } static PassRegistration passDeserialize([] { return createTosaDeserializePass(); }); static PassRegistration passDeserializeJSON([] { return createTosaDeserializeJSONPass(); }); } // namespace tosa } // namespace mlir