/*
SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
SPDX-License-Identifier: Apache-2.0
*/

#include "tosa_checker.h"

#include "absl/strings/string_view.h"
#include "llvm/Support/MemoryBuffer.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/Types.h"
#include "mlir/IR/Verifier.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/FileUtilities.h"
#include "tensorflow/compiler/mlir/lite/flatbuffer_import.h"
#include "tensorflow/compiler/mlir/tosa/tfl_passes.h"

#include <cstddef>
#include <map>
#include <memory>
#include <ostream>
#include <stdexcept>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>

namespace std {
// Allows mlir::Location to be stored in std::unordered_set, which is used to
// collect the locations of TOSA-incompatible ops.
template <>
struct hash<mlir::Location> {
  std::size_t operator()(const mlir::Location &loc) const {
    return mlir::hash_value(loc);
  }
};
}  // namespace std

namespace tosa_checker {

namespace {

// Returns true if `op` belongs to the `func` dialect. Such ops are skipped by
// the per-op compatibility checks below.
bool IsFuncDialectOp(mlir::Operation &op) {
  const mlir::Dialect *dialect = op.getDialect();
  return dialect && dialect->getNamespace() == "func";
}

// Returns true if `op` belongs to the `tosa` dialect.
bool IsTOSADialectOp(mlir::Operation &op) {
  const mlir::Dialect *dialect = op.getDialect();
  return dialect && dialect->getNamespace() == "tosa";
}

}  // namespace

// Loads the TFLite flatbuffer at `model_path` into MLIR (m_model) and keeps a
// clone of it legalized towards TOSA (m_tosa_model).
// Throws std::runtime_error if the model cannot be read or converted.
TOSAChecker::TOSAChecker(const std::string &model_path) {
  m_model = TFLiteFileToMLIR(model_path, &m_context);
  m_tosa_model = m_model->clone();
  LegalizeTFLToTOSA(*m_tosa_model);
}

// Returns true if every non-`func` op of the legalized model is a TOSA
// compatible op (see IsTOSACompatibleOp).
bool TOSAChecker::IsTOSACompatible() {
  for (auto func : m_tosa_model->getOps<mlir::func::FuncOp>()) {
    const mlir::WalkResult result = func.walk([&](mlir::Operation *op) {
      if (!IsFuncDialectOp(*op) && !IsTOSACompatibleOp(*op)) {
        return mlir::WalkResult::interrupt();
      }
      return mlir::WalkResult::advance();
    });
    // interrupt() only stops the walk of the current function; stop scanning
    // the remaining functions too once the answer is known.
    if (result.wasInterrupted()) {
      return false;
    }
  }
  return true;
}

// Returns one Operator per non-`func` op of the ORIGINAL (non-legalized)
// model, each flagged with whether the legalization managed to convert it.
std::vector<TOSAChecker::Operator> TOSAChecker::GetTOSACompatibilityForOps(
    bool elide_large_attrs) {
  // Get the locations of all the ops in the legalized model that were not
  // converted during the TOSA legalization (i.e. the TOSA incompatible ones).
  std::unordered_set<mlir::Location> tosa_incompatible_locs;
  for (auto func : m_tosa_model->getOps<mlir::func::FuncOp>()) {
    func.walk([&](mlir::Operation *op) {
      if (!IsFuncDialectOp(*op) && !IsTOSACompatibleOp(*op)) {
        tosa_incompatible_locs.insert(op->getLoc());
      }
    });
  }

  // We assume that on legalization, the non-legalized ops keep their original
  // location. If an op location from the original model is in
  // tosa_incompatible_locs then the op is not tosa compatible, otherwise it
  // is.
  std::vector<Operator> ops;
  for (auto func : m_model->getOps<mlir::func::FuncOp>()) {
    func.walk([&](mlir::Operation *op) {
      if (IsFuncDialectOp(*op)) {
        return;
      }
      const bool is_tosa_compatible =
          tosa_incompatible_locs.count(op->getLoc()) == 0;
      ops.push_back(ToOperator(*op, is_tosa_compatible, elide_large_attrs));
    });
  }
  return ops;
}

// Returns one Operator per `tosa` dialect op present in the legalized model.
std::vector<TOSAChecker::Operator> TOSAChecker::GetUsedTOSAOps(
    bool elide_large_attrs) {
  const std::vector<mlir::Operation *> used_ops = GetTOSAOps(*m_tosa_model);

  std::vector<Operator> tosa_ops;
  tosa_ops.reserve(used_ops.size());
  for (mlir::Operation *op : used_ops) {
    const bool is_tosa_compatible = IsTOSACompatibleOp(*op);
    tosa_ops.push_back(ToOperator(*op, is_tosa_compatible, elide_large_attrs));
  }
  return tosa_ops;
}

// Textual MLIR of the original (non-legalized) model.
std::string TOSAChecker::GetMLIRModelRepresentation(bool elide_large_attrs) {
  return GetMLIRRepresentation(*m_model, elide_large_attrs);
}

// Textual MLIR of the TOSA-legalized model.
std::string TOSAChecker::GetMLIRTOSAModelRepresentation(
    bool elide_large_attrs) {
  return GetMLIRRepresentation(*m_tosa_model, elide_large_attrs);
}

// An op is TOSA compatible if it belongs to the `tosa` dialect and is not
// `tosa.custom`.
bool TOSAChecker::IsTOSACompatibleOp(mlir::Operation &op) {
  if (!IsTOSADialectOp(op)) {
    return false;
  }
  // Due to the opaque nature of the tosa.custom operator, a TOSA compliant
  // system may not be able to run a model with such operators. We
  // consider these models as TOSA incompatible.
  return op.getName().getStringRef() != "tosa.custom";
}

// Prints `op` (anything exposing `print(llvm::raw_ostream &)`, e.g. a
// location or an attribute) into a string.
template <typename T>
std::string TOSAChecker::GetMLIRRepresentation(T &&op) {
  std::string value;
  llvm::raw_string_ostream value_ostream(value);
  op.print(value_ostream);
  // Make sure everything printed has reached `value` before returning it.
  value_ostream.flush();
  return value;
}

// Prints `op` into a string with mlir::OpPrintingFlags, eliding elements
// attributes larger than ELIDE_LARGE_ATTRS_LIMIT when `elide_large_attrs` is
// set.
template <typename T>
std::string TOSAChecker::GetMLIRRepresentation(T &&op,
                                               bool elide_large_attrs) {
  std::string value;
  llvm::raw_string_ostream value_ostream(value);

  mlir::OpPrintingFlags flags;
  if (elide_large_attrs) {
    flags.elideLargeElementsAttrs(ELIDE_LARGE_ATTRS_LIMIT);
  }

  op.print(value_ostream, flags);
  // Make sure everything printed has reached `value` before returning it.
  value_ostream.flush();
  return value;
}

// Collects every `tosa` dialect op nested in the functions of `model`, in walk
// order.
std::vector<mlir::Operation *> TOSAChecker::GetTOSAOps(mlir::ModuleOp model) {
  std::vector<mlir::Operation *> tosa_ops;
  for (auto func : model.getOps<mlir::func::FuncOp>()) {
    func.walk([&](mlir::Operation *op) {
      if (IsTOSADialectOp(*op)) {
        tosa_ops.push_back(op);
      }
    });
  }
  return tosa_ops;
}

// Builds the Operator record of `op`: name, printed location, attributes,
// compatibility flag and printed MLIR form.
TOSAChecker::Operator TOSAChecker::ToOperator(mlir::Operation &op,
                                              bool is_tosa_compatible,
                                              bool elide_large_attrs) {
  return Operator(op.getName().getStringRef().str(),
                  GetMLIRRepresentation(op.getLoc()),
                  GetAttributes(op, elide_large_attrs), is_tosa_compatible,
                  GetMLIRRepresentation(op, elide_large_attrs));
}

// Reads the TFLite flatbuffer at `model_path` and converts it to an MLIR
// module owned by `context`.
// Throws std::runtime_error if the file cannot be opened, or if the
// conversion or the verification of the resulting module fails.
mlir::OwningOpRef<mlir::ModuleOp> TOSAChecker::TFLiteFileToMLIR(
    const std::string &model_path, mlir::MLIRContext *context) {
  std::string error_message;
  std::unique_ptr<llvm::MemoryBuffer> input =
      mlir::openInputFile(model_path, &error_message);
  if (!input) {
    throw std::runtime_error(error_message);
  }

  const mlir::FileLineColLoc location = mlir::FileLineColLoc::get(
      context, input->getBufferIdentifier(), 0, 0);
  auto mlir_module = tflite::FlatBufferToMlir(
      absl::string_view(input->getBufferStart(), input->getBufferSize()),
      context, location);
  if (!mlir_module || mlir::failed(mlir::verify(*mlir_module))) {
    throw std::runtime_error(
        "Could not convert the TFLite model to its MLIR representation.");
  }

  return mlir_module;
}

// Runs the TFLite -> TOSA legalization pipeline in place on `mlir_module`.
void TOSAChecker::LegalizeTFLToTOSA(mlir::ModuleOp mlir_module) {
  mlir::PassManager pm(mlir_module.getContext(),
                       mlir_module.getOperationName(),
                       mlir::OpPassManager::Nesting::Implicit);
  mlir::tosa::TOSATFLLegalizationPipelineOptions opts;
  mlir::tosa::createTFLtoTOSALegalizationPipeline(pm, opts);

  // TODO Don't check for mlir::failed state for now due to some incoherences
  // in how the legalization report non-convertible ops (sometimes with a hard
  // fail, sometimes without). The legalization should not return a failed
  // state if an operator can't be legalized and should leave it in its
  // original dialect.
  // The result is discarded on purpose (see above); the cast makes that
  // explicit for the [[nodiscard]] LogicalResult.
  (void)pm.run(mlir_module);
}

// Maps each attribute name of `op` to the textual form of its value.
std::map<std::string, std::string> TOSAChecker::GetAttributes(
    mlir::Operation &op, bool /*elide_large_attrs*/) {
  std::map<std::string, std::string> attributes;
  for (const mlir::NamedAttribute &attr : op.getAttrs()) {
    attributes.emplace(attr.getName().str(),
                       // TODO Check how to elide large attributes when
                       // converting them to string, mlir::Attribute::print has
                       // no mlir::OpPrintingFlags.
                       GetMLIRRepresentation(attr.getValue()));
  }
  return attributes;
}

}  // namespace tosa_checker

// Streams the textual MLIR representation of `op`.
std::ostream &operator<<(std::ostream &os,
                         const tosa_checker::TOSAChecker::Operator &op) {
  os << op.mlir_representation;
  return os;
}