diff options
Diffstat (limited to 'tosa_checker/tosa_checker.cc')
-rw-r--r-- | tosa_checker/tosa_checker.cc | 225 |
1 files changed, 225 insertions, 0 deletions
diff --git a/tosa_checker/tosa_checker.cc b/tosa_checker/tosa_checker.cc new file mode 100644 index 0000000..714cab3 --- /dev/null +++ b/tosa_checker/tosa_checker.cc @@ -0,0 +1,225 @@ +/* +SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +SPDX-License-Identifier: Apache-2.0 +*/ +#include "tosa_checker.h" + +#include "absl/strings/string_view.h" +#include "llvm/Support/MemoryBuffer.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/Types.h" +#include "mlir/IR/Verifier.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Support/FileUtilities.h" +#include "tensorflow/compiler/mlir/lite/flatbuffer_import.h" +#include "tensorflow/compiler/mlir/tosa/tfl_passes.h" + +#include <map> +#include <memory> +#include <optional> +#include <stdexcept> +#include <string> +#include <unordered_set> +#include <vector> + +namespace std { +template <> +struct hash<mlir::Location> { + std::size_t operator()(const mlir::Location &loc) const { + return mlir::hash_value(loc); + } +}; +} // namespace std + +namespace tosa_checker { + +TOSAChecker::TOSAChecker(const std::string &model_path) { + m_model = TFLiteFileToMLIR(model_path, &m_context); + m_tosa_model = m_model->clone(); + LegalizeTFLToTOSA(*m_tosa_model); +} + +bool TOSAChecker::IsTOSACompatible() { + bool is_tosa_compatible = true; + for (auto func : m_tosa_model->getOps<mlir::func::FuncOp>()) { + func.walk([&](mlir::Operation *op) { + // Ignore func namespace + const mlir::Dialect *dialect = op->getDialect(); + if (!dialect || (!dialect->getNamespace().equals("tosa") && + !dialect->getNamespace().equals("func"))) { + is_tosa_compatible = false; + return mlir::WalkResult::interrupt(); + } + + return mlir::WalkResult::advance(); + }); + } + + return is_tosa_compatible; +} + +std::vector<TOSAChecker::Operator> TOSAChecker::GetTOSACompatibilityForOps( + bool elide_large_attrs) { + // Get the locations of all the ops in the legalized model that were not + // converted during the TOSA legalization (i.e. the TOSA incompatible ones). + std::unordered_set<mlir::Location> tosa_incompatible_locs; + for (auto func : m_tosa_model->getOps<mlir::func::FuncOp>()) { + func.walk([&](mlir::Operation *op) { + // Ignore func namespace + const mlir::Dialect *dialect = op->getDialect(); + if (!dialect || (!dialect->getNamespace().equals("tosa") && + !dialect->getNamespace().equals("func"))) { + tosa_incompatible_locs.insert(op->getLoc()); + } + }); + } + + // We assume that on legalization, the non-legalized ops keep their original + // location. If an op location from the original model is in + // tosa_incompatible_locs then the op is not tosa compatible, otherwise it is. + std::vector<Operator> ops; + for (auto func : m_model->getOps<mlir::func::FuncOp>()) { + func.walk([&](mlir::Operation *op) { + // Ignore func namespace + const mlir::Dialect *dialect = op->getDialect(); + if (!dialect || !dialect->getNamespace().equals("func")) { + const bool is_tosa_compatible = + tosa_incompatible_locs.find(op->getLoc()) == + tosa_incompatible_locs.end(); + ops.push_back(ToOperator(*op, is_tosa_compatible, elide_large_attrs)); + } + }); + } + + return ops; +} + +std::vector<TOSAChecker::Operator> TOSAChecker::GetUsedTOSAOps( + bool elide_large_attrs) { + std::vector<Operator> tosa_ops; + for (mlir::Operation *op : GetTOSAOps(*m_tosa_model)) { + const bool is_tosa_compatible = true; + tosa_ops.push_back(ToOperator(*op, is_tosa_compatible, elide_large_attrs)); + } + + return tosa_ops; +} + +std::string TOSAChecker::GetMLIRModelRepresentation(bool elide_large_attrs) { + return GetMLIRRepresentation(*m_model, elide_large_attrs); +} + +std::string TOSAChecker::GetMLIRTOSAModelRepresentation( + bool elide_large_attrs) { + return GetMLIRRepresentation(*m_tosa_model, elide_large_attrs); +} + +template <typename T> +std::string TOSAChecker::GetMLIRRepresentation(T &&op) { + std::string value; + llvm::raw_string_ostream value_ostream(value); + + op.print(value_ostream); + + return value; +} + +template <typename T> +std::string TOSAChecker::GetMLIRRepresentation(T &&op, bool elide_large_attrs) { + std::string value; + llvm::raw_string_ostream value_ostream(value); + + mlir::OpPrintingFlags flags; + if (elide_large_attrs) { + flags.elideLargeElementsAttrs(ELIDE_LARGE_ATTRS_LIMIT); + } + op.print(value_ostream, flags); + + return value; +} + +std::vector<mlir::Operation *> TOSAChecker::GetTOSAOps(mlir::ModuleOp model) { + std::vector<mlir::Operation *> tosa_ops; + for (auto func : model.getOps<mlir::func::FuncOp>()) { + func.walk([&](mlir::Operation *op) { + const mlir::Dialect *dialect = op->getDialect(); + if (dialect && dialect->getNamespace().equals("tosa")) { + tosa_ops.push_back(op); + } + }); + } + + return tosa_ops; +} + +TOSAChecker::Operator TOSAChecker::ToOperator(mlir::Operation &op, + bool is_tosa_compatible, + bool elide_large_attrs) { + return Operator(op.getName().getStringRef().str(), + GetMLIRRepresentation(op.getLoc()), + GetAttributes(op, elide_large_attrs), is_tosa_compatible, + GetMLIRRepresentation(op, elide_large_attrs)); +} + +mlir::OwningOpRef<mlir::ModuleOp> TOSAChecker::TFLiteFileToMLIR( + const std::string &model_path, mlir::MLIRContext *context) { + std::string error_message; + std::unique_ptr<llvm::MemoryBuffer> input = + mlir::openInputFile(model_path, &error_message); + if (!input) { + throw std::runtime_error(error_message); + } + + const mlir::FileLineColLoc location = + mlir::FileLineColLoc::get(context, input->getBufferIdentifier(), 0, 0); + + auto mlir_module = tflite::FlatBufferToMlir( + absl::string_view(input->getBufferStart(), input->getBufferSize()), + context, location); + if (!mlir_module || mlir::failed(mlir::verify(*mlir_module))) { + throw std::runtime_error( + "Could not convert the TFLite model to its MLIR representation."); + } + + return mlir_module; +} + +void TOSAChecker::LegalizeTFLToTOSA(mlir::ModuleOp mlir_module) { + mlir::PassManager pm(mlir_module.getContext(), + mlir::OpPassManager::Nesting::Implicit); + mlir::tosa::TOSATFLLegalizationPipelineOptions opts; + mlir::tosa::createTFLtoTOSALegalizationPipeline(pm, opts); + // TODO Don't check for mlir::failed state for now due to some incoherences in + // how the legalization report non-convertible ops (sometimes with a hard + // fail, sometimes without). The legalization should not return a failed + // state if an operator can't be legalized and should leave it in its original + // dialect. + pm.run(mlir_module); +} + +std::map<std::string, std::string> TOSAChecker::GetAttributes( + mlir::Operation &op, bool /*elide_large_attrs*/) { + std::map<std::string, std::string> attributes; + for (const mlir::NamedAttribute &attr : op.getAttrs()) { + attributes.emplace(attr.getName().str(), + // TODO Check how to elide large attributes when + // converting them to string, mlir::Attribute::print has + // no mlir::OpPrintingFlags. + GetMLIRRepresentation(attr.getValue())); + } + + return attributes; +} + +} // namespace tosa_checker + +std::ostream &operator<<(std::ostream &os, + const tosa_checker::TOSAChecker::Operator &op) { + os << op.mlir_representation; + + return os; +} |