diff options
Diffstat (limited to 'tosa_checker')
-rw-r--r-- | tosa_checker/BUILD | 35 | ||||
-rw-r--r-- | tosa_checker/__init__.py | 8 | ||||
-rw-r--r-- | tosa_checker/tosa_checker.cc | 225 | ||||
-rw-r--r-- | tosa_checker/tosa_checker.h | 82 | ||||
-rw-r--r-- | tosa_checker/tosa_checker_pybind11.cc | 79 |
5 files changed, 429 insertions, 0 deletions
diff --git a/tosa_checker/BUILD b/tosa_checker/BUILD new file mode 100644 index 0000000..8c1c32d --- /dev/null +++ b/tosa_checker/BUILD @@ -0,0 +1,35 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +load("@pybind11_bazel//:build_defs.bzl", "pybind_extension") + +cc_library( + name = "tosa_checker_lib", + srcs = ["tosa_checker.cc"], + hdrs = ["tosa_checker.h"], + deps = [ + "@llvm-project//mlir:MlirTranslateMain", + "@org_tensorflow//tensorflow/compiler/mlir/lite:flatbuffer_translate_lib", + "@org_tensorflow//tensorflow/compiler/mlir/tosa:tfl_passes", + "@pybind11", + ], +) + +pybind_extension( + name = "_tosa_checker_wrapper", + srcs = [ + "tosa_checker_pybind11.cc", + ], + deps = [ + ":tosa_checker_lib", + ], +) + +py_library( + name = "tosa_checker", + srcs = [ + "__init__.py", + ], + srcs_version = "PY3", + visibility = ["//visibility:public"], + data = ["//tosa_checker:_tosa_checker_wrapper.so"], +) diff --git a/tosa_checker/__init__.py b/tosa_checker/__init__.py new file mode 100644 index 0000000..ce76797 --- /dev/null +++ b/tosa_checker/__init__.py @@ -0,0 +1,8 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 + +"""the package provides a way to check if a TFLite model is compatible with the TOSA specification.""" + +from _tosa_checker_wrapper import * + +__version__ = "0.1.0" diff --git a/tosa_checker/tosa_checker.cc b/tosa_checker/tosa_checker.cc new file mode 100644 index 0000000..714cab3 --- /dev/null +++ b/tosa_checker/tosa_checker.cc @@ -0,0 +1,225 @@ +/* +SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +SPDX-License-Identifier: Apache-2.0 +*/ +#include "tosa_checker.h" + +#include "absl/strings/string_view.h" +#include "llvm/Support/MemoryBuffer.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/Types.h" +#include "mlir/IR/Verifier.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Support/FileUtilities.h" +#include "tensorflow/compiler/mlir/lite/flatbuffer_import.h" +#include "tensorflow/compiler/mlir/tosa/tfl_passes.h" + +#include <map> +#include <memory> +#include <optional> +#include <stdexcept> +#include <string> +#include <unordered_set> +#include <vector> + +namespace std { +template <> +struct hash<mlir::Location> { + std::size_t operator()(const mlir::Location &loc) const { + return mlir::hash_value(loc); + } +}; +} // namespace std + +namespace tosa_checker { + +TOSAChecker::TOSAChecker(const std::string &model_path) { + m_model = TFLiteFileToMLIR(model_path, &m_context); + m_tosa_model = m_model->clone(); + LegalizeTFLToTOSA(*m_tosa_model); +} + +bool TOSAChecker::IsTOSACompatible() { + bool is_tosa_compatible = true; + for (auto func : m_tosa_model->getOps<mlir::func::FuncOp>()) { + func.walk([&](mlir::Operation *op) { + // Ignore func namespace + const mlir::Dialect *dialect = op->getDialect(); + if (!dialect || (!dialect->getNamespace().equals("tosa") && + !dialect->getNamespace().equals("func"))) { + is_tosa_compatible = false; + return mlir::WalkResult::interrupt(); + } + + return mlir::WalkResult::advance(); + }); + } + + return is_tosa_compatible; +} + +std::vector<TOSAChecker::Operator> TOSAChecker::GetTOSACompatibilityForOps( + bool elide_large_attrs) { + // Get the locations of all the ops in the legalized model that were not + // converted during the TOSA legalization (i.e. the TOSA incompatible ones). + std::unordered_set<mlir::Location> tosa_incompatible_locs; + for (auto func : m_tosa_model->getOps<mlir::func::FuncOp>()) { + func.walk([&](mlir::Operation *op) { + // Ignore func namespace + const mlir::Dialect *dialect = op->getDialect(); + if (!dialect || (!dialect->getNamespace().equals("tosa") && + !dialect->getNamespace().equals("func"))) { + tosa_incompatible_locs.insert(op->getLoc()); + } + }); + } + + // We assume that on legalization, the non-legalized ops keep their original + // location. If an op location from the original model is in + // tosa_incompatible_locs then the op is not tosa compatible, otherwise it is. + std::vector<Operator> ops; + for (auto func : m_model->getOps<mlir::func::FuncOp>()) { + func.walk([&](mlir::Operation *op) { + // Ignore func namespace + const mlir::Dialect *dialect = op->getDialect(); + if (!dialect || !dialect->getNamespace().equals("func")) { + const bool is_tosa_compatible = + tosa_incompatible_locs.find(op->getLoc()) == + tosa_incompatible_locs.end(); + ops.push_back(ToOperator(*op, is_tosa_compatible, elide_large_attrs)); + } + }); + } + + return ops; +} + +std::vector<TOSAChecker::Operator> TOSAChecker::GetUsedTOSAOps( + bool elide_large_attrs) { + std::vector<Operator> tosa_ops; + for (mlir::Operation *op : GetTOSAOps(*m_tosa_model)) { + const bool is_tosa_compatible = true; + tosa_ops.push_back(ToOperator(*op, is_tosa_compatible, elide_large_attrs)); + } + + return tosa_ops; +} + +std::string TOSAChecker::GetMLIRModelRepresentation(bool elide_large_attrs) { + return GetMLIRRepresentation(*m_model, elide_large_attrs); +} + +std::string TOSAChecker::GetMLIRTOSAModelRepresentation( + bool elide_large_attrs) { + return GetMLIRRepresentation(*m_tosa_model, elide_large_attrs); +} + +template <typename T> +std::string TOSAChecker::GetMLIRRepresentation(T &&op) { + std::string value; + llvm::raw_string_ostream value_ostream(value); + + op.print(value_ostream); + + return value; +} + +template <typename T> +std::string TOSAChecker::GetMLIRRepresentation(T &&op, bool elide_large_attrs) { + std::string value; + llvm::raw_string_ostream value_ostream(value); + + mlir::OpPrintingFlags flags; + if (elide_large_attrs) { + flags.elideLargeElementsAttrs(ELIDE_LARGE_ATTRS_LIMIT); + } + op.print(value_ostream, flags); + + return value; +} + +std::vector<mlir::Operation *> TOSAChecker::GetTOSAOps(mlir::ModuleOp model) { + std::vector<mlir::Operation *> tosa_ops; + for (auto func : model.getOps<mlir::func::FuncOp>()) { + func.walk([&](mlir::Operation *op) { + const mlir::Dialect *dialect = op->getDialect(); + if (dialect && dialect->getNamespace().equals("tosa")) { + tosa_ops.push_back(op); + } + }); + } + + return tosa_ops; +} + +TOSAChecker::Operator TOSAChecker::ToOperator(mlir::Operation &op, + bool is_tosa_compatible, + bool elide_large_attrs) { + return Operator(op.getName().getStringRef().str(), + GetMLIRRepresentation(op.getLoc()), + GetAttributes(op, elide_large_attrs), is_tosa_compatible, + GetMLIRRepresentation(op, elide_large_attrs)); +} + +mlir::OwningOpRef<mlir::ModuleOp> TOSAChecker::TFLiteFileToMLIR( + const std::string &model_path, mlir::MLIRContext *context) { + std::string error_message; + std::unique_ptr<llvm::MemoryBuffer> input = + mlir::openInputFile(model_path, &error_message); + if (!input) { + throw std::runtime_error(error_message); + } + + const mlir::FileLineColLoc location = + mlir::FileLineColLoc::get(context, input->getBufferIdentifier(), 0, 0); + + auto mlir_module = tflite::FlatBufferToMlir( + absl::string_view(input->getBufferStart(), input->getBufferSize()), + context, location); + if (!mlir_module || mlir::failed(mlir::verify(*mlir_module))) { + throw std::runtime_error( + "Could not convert the TFLite model to its MLIR representation."); + } + + return mlir_module; +} + +void TOSAChecker::LegalizeTFLToTOSA(mlir::ModuleOp mlir_module) { + mlir::PassManager pm(mlir_module.getContext(), + mlir::OpPassManager::Nesting::Implicit); + mlir::tosa::TOSATFLLegalizationPipelineOptions opts; + mlir::tosa::createTFLtoTOSALegalizationPipeline(pm, opts); + // TODO Don't check for mlir::failed state for now due to some incoherences in + // how the legalization report non-convertible ops (sometimes with a hard + // fail, sometimes without). The legalization should not return a failed + // state if an operator can't be legalized and should leave it in its original + // dialect. + pm.run(mlir_module); +} + +std::map<std::string, std::string> TOSAChecker::GetAttributes( + mlir::Operation &op, bool /*elide_large_attrs*/) { + std::map<std::string, std::string> attributes; + for (const mlir::NamedAttribute &attr : op.getAttrs()) { + attributes.emplace(attr.getName().str(), + // TODO Check how to elide large attributes when + // converting them to string, mlir::Attribute::print has + // no mlir::OpPrintingFlags. + GetMLIRRepresentation(attr.getValue())); + } + + return attributes; +} + +} // namespace tosa_checker + +std::ostream &operator<<(std::ostream &os, + const tosa_checker::TOSAChecker::Operator &op) { + os << op.mlir_representation; + + return os; +} diff --git a/tosa_checker/tosa_checker.h b/tosa_checker/tosa_checker.h new file mode 100644 index 0000000..d7750ea --- /dev/null +++ b/tosa_checker/tosa_checker.h @@ -0,0 +1,82 @@ +/* +SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +SPDX-License-Identifier: Apache-2.0 +*/ +#ifndef TOSA_CHECKER_H_ +#define TOSA_CHECKER_H_ + +#include <map> +#include <optional> +#include <string> +#include <vector> + +#include "mlir/include/mlir/IR/BuiltinOps.h" +#include "mlir/include/mlir/IR/MLIRContext.h" +#include "mlir/include/mlir/IR/OwningOpRef.h" + +namespace tosa_checker { + +class TOSAChecker { + public: + struct Operator { + Operator(std::string name, std::string location, + std::map<std::string, std::string> attributes, + bool is_tosa_compatible, std::string mlir_representation) + : name(std::move(name)), + location(std::move(location)), + attributes(std::move(attributes)), + is_tosa_compatible(is_tosa_compatible), + mlir_representation(std::move(mlir_representation)) {} + + std::string name; + std::string location; + std::map<std::string, std::string> attributes; + bool is_tosa_compatible; + std::string mlir_representation; + }; + + TOSAChecker(const std::string& model_path); + + bool IsTOSACompatible(); + + std::vector<Operator> GetTOSACompatibilityForOps(bool elide_large_attrs); + + std::vector<Operator> GetUsedTOSAOps(bool elide_large_attrs); + + std::string GetMLIRModelRepresentation(bool elide_large_attrs); + std::string GetMLIRTOSAModelRepresentation(bool elide_large_attrs); + + private: + template <typename T> + static std::string GetMLIRRepresentation(T&& op); + + template <typename T> + static std::string GetMLIRRepresentation(T&& op, bool elide_large_attrs); + + static std::vector<mlir::Operation*> GetTOSAOps(mlir::ModuleOp model); + + static Operator ToOperator(mlir::Operation& op, bool is_tosa_compatible, + bool elide_large_attrs); + + static mlir::OwningOpRef<mlir::ModuleOp> TFLiteFileToMLIR( + const std::string& model_path, mlir::MLIRContext* context); + + static void LegalizeTFLToTOSA(mlir::ModuleOp mlir_module); + + static std::map<std::string, std::string> GetAttributes( + mlir::Operation& op, bool elide_large_attrs); + + private: + static constexpr std::int64_t ELIDE_LARGE_ATTRS_LIMIT = 16; + + mlir::MLIRContext m_context; + mlir::OwningOpRef<mlir::ModuleOp> m_model; + mlir::OwningOpRef<mlir::ModuleOp> m_tosa_model; +}; + +} // namespace tosa_checker + +std::ostream& operator<<(std::ostream& os, + const tosa_checker::TOSAChecker::Operator& op); + +#endif diff --git a/tosa_checker/tosa_checker_pybind11.cc b/tosa_checker/tosa_checker_pybind11.cc new file mode 100644 index 0000000..c799817 --- /dev/null +++ b/tosa_checker/tosa_checker_pybind11.cc @@ -0,0 +1,79 @@ +/* +SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +SPDX-License-Identifier: Apache-2.0 +*/ +#include "tosa_checker.h" + +#include <optional> +#include <sstream> +#include <string> + +#include <pybind11/pybind11.h> +#include <pybind11/stl.h> + +PYBIND11_MODULE(_tosa_checker_wrapper, m) { + /** + * tosa_checker::TOSAChecker + */ + pybind11::class_<tosa_checker::TOSAChecker> tosa_checker_class(m, + "TOSAChecker"); + tosa_checker_class.def(pybind11::init<const std::string&>(), + pybind11::arg("model_path")); + + tosa_checker_class.def( + "is_tosa_compatible", + [](tosa_checker::TOSAChecker& tc) { return tc.IsTOSACompatible(); }, + "Check if a model is compatible with the TOSA specification"); + + tosa_checker_class.def( + "_get_tosa_compatibility_for_ops", + [](tosa_checker::TOSAChecker& tc, bool elide_large_elements_attrs) { + return tc.GetTOSACompatibilityForOps(elide_large_elements_attrs); + }, + pybind11::arg("elide_large_elements_attrs") = false, + "Get all the operators of the models with a TOSA compatibility flag for " + "each operator"); + + tosa_checker_class.def( + "_get_used_tosa_ops", + [](tosa_checker::TOSAChecker& tc, bool elide_large_elements_attrs) { + return tc.GetUsedTOSAOps(elide_large_elements_attrs); + }, + pybind11::arg("elide_large_elements_attrs") = false, + "Get the TOSA operators used by the model after its TOSA legalization"); + + tosa_checker_class.def( + "_get_mlir_model_representation", + [](tosa_checker::TOSAChecker& tc, bool elide_large_elements_attrs) { + return tc.GetMLIRModelRepresentation(elide_large_elements_attrs); + }, + pybind11::arg("elide_large_elements_attrs") = false, + "Get the MLIR representation of the model"); + + tosa_checker_class.def( + "_get_mlir_tosa_model_representation", + [](tosa_checker::TOSAChecker& tc, bool elide_large_elements_attrs) { + return tc.GetMLIRTOSAModelRepresentation(elide_large_elements_attrs); + }, + pybind11::arg("elide_large_elements_attrs") = false, + "Get the MLIR representation of the TOSA legalized model"); + + /** + * tosa_checker::TOSAChecker::Operator + */ + pybind11::class_<tosa_checker::TOSAChecker::Operator>(tosa_checker_class, + "_Operator") + .def_readonly("name", &tosa_checker::TOSAChecker::Operator::name) + .def_readonly("location", &tosa_checker::TOSAChecker::Operator::location) + .def_readonly("attributes", + &tosa_checker::TOSAChecker::Operator::attributes) + .def_readonly("is_tosa_compatible", + &tosa_checker::TOSAChecker::Operator::is_tosa_compatible) + .def_readonly("mlir_representation", + &tosa_checker::TOSAChecker::Operator::mlir_representation) + .def("__repr__", [](const tosa_checker::TOSAChecker::Operator& o) { + std::stringstream stream; + stream << o; + return stream.str(); + }); +} |