aboutsummaryrefslogtreecommitdiff
path: root/tosa_checker
diff options
context:
space:
mode:
Diffstat (limited to 'tosa_checker')
-rw-r--r--tosa_checker/BUILD35
-rw-r--r--tosa_checker/__init__.py8
-rw-r--r--tosa_checker/tosa_checker.cc225
-rw-r--r--tosa_checker/tosa_checker.h82
-rw-r--r--tosa_checker/tosa_checker_pybind11.cc79
5 files changed, 429 insertions, 0 deletions
diff --git a/tosa_checker/BUILD b/tosa_checker/BUILD
new file mode 100644
index 0000000..8c1c32d
--- /dev/null
+++ b/tosa_checker/BUILD
@@ -0,0 +1,35 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+load("@pybind11_bazel//:build_defs.bzl", "pybind_extension")
+
+cc_library(
+ name = "tosa_checker_lib",
+ srcs = ["tosa_checker.cc"],
+ hdrs = ["tosa_checker.h"],
+ deps = [
+ "@llvm-project//mlir:MlirTranslateMain",
+ "@org_tensorflow//tensorflow/compiler/mlir/lite:flatbuffer_translate_lib",
+ "@org_tensorflow//tensorflow/compiler/mlir/tosa:tfl_passes",
+ "@pybind11",
+ ],
+)
+
+pybind_extension(
+ name = "_tosa_checker_wrapper",
+ srcs = [
+ "tosa_checker_pybind11.cc",
+ ],
+ deps = [
+ ":tosa_checker_lib",
+ ],
+)
+
+py_library(
+ name = "tosa_checker",
+ srcs = [
+ "__init__.py",
+ ],
+ srcs_version = "PY3",
+ visibility = ["//visibility:public"],
+ data = ["//tosa_checker:_tosa_checker_wrapper.so"],
+)
diff --git a/tosa_checker/__init__.py b/tosa_checker/__init__.py
new file mode 100644
index 0000000..ce76797
--- /dev/null
+++ b/tosa_checker/__init__.py
@@ -0,0 +1,8 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+
+"""the package provides a way to check if a TFLite model is compatible with the TOSA specification."""
+
+from _tosa_checker_wrapper import *
+
+__version__ = "0.1.0"
diff --git a/tosa_checker/tosa_checker.cc b/tosa_checker/tosa_checker.cc
new file mode 100644
index 0000000..714cab3
--- /dev/null
+++ b/tosa_checker/tosa_checker.cc
@@ -0,0 +1,225 @@
+/*
+SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+SPDX-License-Identifier: Apache-2.0
+*/
+#include "tosa_checker.h"
+
+#include "absl/strings/string_view.h"
+#include "llvm/Support/MemoryBuffer.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/Location.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/Types.h"
+#include "mlir/IR/Verifier.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Support/FileUtilities.h"
+#include "tensorflow/compiler/mlir/lite/flatbuffer_import.h"
+#include "tensorflow/compiler/mlir/tosa/tfl_passes.h"
+
+#include <map>
+#include <memory>
+#include <optional>
+#include <stdexcept>
+#include <string>
+#include <unordered_set>
+#include <vector>
+
+namespace std {
+template <>
+struct hash<mlir::Location> {
+ std::size_t operator()(const mlir::Location &loc) const {
+ return mlir::hash_value(loc);
+ }
+};
+} // namespace std
+
+namespace tosa_checker {
+
+TOSAChecker::TOSAChecker(const std::string &model_path) {
+ m_model = TFLiteFileToMLIR(model_path, &m_context);
+ m_tosa_model = m_model->clone();
+ LegalizeTFLToTOSA(*m_tosa_model);
+}
+
+bool TOSAChecker::IsTOSACompatible() {
+ bool is_tosa_compatible = true;
+ for (auto func : m_tosa_model->getOps<mlir::func::FuncOp>()) {
+ func.walk([&](mlir::Operation *op) {
+ // Ignore func namespace
+ const mlir::Dialect *dialect = op->getDialect();
+ if (!dialect || (!dialect->getNamespace().equals("tosa") &&
+ !dialect->getNamespace().equals("func"))) {
+ is_tosa_compatible = false;
+ return mlir::WalkResult::interrupt();
+ }
+
+ return mlir::WalkResult::advance();
+ });
+ }
+
+ return is_tosa_compatible;
+}
+
+std::vector<TOSAChecker::Operator> TOSAChecker::GetTOSACompatibilityForOps(
+ bool elide_large_attrs) {
+ // Get the locations of all the ops in the legalized model that were not
+ // converted during the TOSA legalization (i.e. the TOSA incompatible ones).
+ std::unordered_set<mlir::Location> tosa_incompatible_locs;
+ for (auto func : m_tosa_model->getOps<mlir::func::FuncOp>()) {
+ func.walk([&](mlir::Operation *op) {
+ // Ignore func namespace
+ const mlir::Dialect *dialect = op->getDialect();
+ if (!dialect || (!dialect->getNamespace().equals("tosa") &&
+ !dialect->getNamespace().equals("func"))) {
+ tosa_incompatible_locs.insert(op->getLoc());
+ }
+ });
+ }
+
+ // We assume that on legalization, the non-legalized ops keep their original
+ // location. If an op location from the original model is in
+ // tosa_incompatible_locs then the op is not tosa compatible, otherwise it is.
+ std::vector<Operator> ops;
+ for (auto func : m_model->getOps<mlir::func::FuncOp>()) {
+ func.walk([&](mlir::Operation *op) {
+ // Ignore func namespace
+ const mlir::Dialect *dialect = op->getDialect();
+ if (!dialect || !dialect->getNamespace().equals("func")) {
+ const bool is_tosa_compatible =
+ tosa_incompatible_locs.find(op->getLoc()) ==
+ tosa_incompatible_locs.end();
+ ops.push_back(ToOperator(*op, is_tosa_compatible, elide_large_attrs));
+ }
+ });
+ }
+
+ return ops;
+}
+
+std::vector<TOSAChecker::Operator> TOSAChecker::GetUsedTOSAOps(
+ bool elide_large_attrs) {
+ std::vector<Operator> tosa_ops;
+ for (mlir::Operation *op : GetTOSAOps(*m_tosa_model)) {
+ const bool is_tosa_compatible = true;
+ tosa_ops.push_back(ToOperator(*op, is_tosa_compatible, elide_large_attrs));
+ }
+
+ return tosa_ops;
+}
+
+std::string TOSAChecker::GetMLIRModelRepresentation(bool elide_large_attrs) {
+ return GetMLIRRepresentation(*m_model, elide_large_attrs);
+}
+
+std::string TOSAChecker::GetMLIRTOSAModelRepresentation(
+ bool elide_large_attrs) {
+ return GetMLIRRepresentation(*m_tosa_model, elide_large_attrs);
+}
+
+template <typename T>
+std::string TOSAChecker::GetMLIRRepresentation(T &&op) {
+ std::string value;
+ llvm::raw_string_ostream value_ostream(value);
+
+ op.print(value_ostream);
+
+ return value;
+}
+
+template <typename T>
+std::string TOSAChecker::GetMLIRRepresentation(T &&op, bool elide_large_attrs) {
+ std::string value;
+ llvm::raw_string_ostream value_ostream(value);
+
+ mlir::OpPrintingFlags flags;
+ if (elide_large_attrs) {
+ flags.elideLargeElementsAttrs(ELIDE_LARGE_ATTRS_LIMIT);
+ }
+ op.print(value_ostream, flags);
+
+ return value;
+}
+
+std::vector<mlir::Operation *> TOSAChecker::GetTOSAOps(mlir::ModuleOp model) {
+ std::vector<mlir::Operation *> tosa_ops;
+ for (auto func : model.getOps<mlir::func::FuncOp>()) {
+ func.walk([&](mlir::Operation *op) {
+ const mlir::Dialect *dialect = op->getDialect();
+ if (dialect && dialect->getNamespace().equals("tosa")) {
+ tosa_ops.push_back(op);
+ }
+ });
+ }
+
+ return tosa_ops;
+}
+
+TOSAChecker::Operator TOSAChecker::ToOperator(mlir::Operation &op,
+ bool is_tosa_compatible,
+ bool elide_large_attrs) {
+ return Operator(op.getName().getStringRef().str(),
+ GetMLIRRepresentation(op.getLoc()),
+ GetAttributes(op, elide_large_attrs), is_tosa_compatible,
+ GetMLIRRepresentation(op, elide_large_attrs));
+}
+
+mlir::OwningOpRef<mlir::ModuleOp> TOSAChecker::TFLiteFileToMLIR(
+ const std::string &model_path, mlir::MLIRContext *context) {
+ std::string error_message;
+ std::unique_ptr<llvm::MemoryBuffer> input =
+ mlir::openInputFile(model_path, &error_message);
+ if (!input) {
+ throw std::runtime_error(error_message);
+ }
+
+ const mlir::FileLineColLoc location =
+ mlir::FileLineColLoc::get(context, input->getBufferIdentifier(), 0, 0);
+
+ auto mlir_module = tflite::FlatBufferToMlir(
+ absl::string_view(input->getBufferStart(), input->getBufferSize()),
+ context, location);
+ if (!mlir_module || mlir::failed(mlir::verify(*mlir_module))) {
+ throw std::runtime_error(
+ "Could not convert the TFLite model to its MLIR representation.");
+ }
+
+ return mlir_module;
+}
+
+void TOSAChecker::LegalizeTFLToTOSA(mlir::ModuleOp mlir_module) {
+ mlir::PassManager pm(mlir_module.getContext(),
+ mlir::OpPassManager::Nesting::Implicit);
+ mlir::tosa::TOSATFLLegalizationPipelineOptions opts;
+ mlir::tosa::createTFLtoTOSALegalizationPipeline(pm, opts);
+ // TODO Don't check for mlir::failed state for now due to some incoherences in
+ // how the legalization report non-convertible ops (sometimes with a hard
+ // fail, sometimes without). The legalization should not return a failed
+ // state if an operator can't be legalized and should leave it in its original
+ // dialect.
+ pm.run(mlir_module);
+}
+
+std::map<std::string, std::string> TOSAChecker::GetAttributes(
+ mlir::Operation &op, bool /*elide_large_attrs*/) {
+ std::map<std::string, std::string> attributes;
+ for (const mlir::NamedAttribute &attr : op.getAttrs()) {
+ attributes.emplace(attr.getName().str(),
+ // TODO Check how to elide large attributes when
+ // converting them to string, mlir::Attribute::print has
+ // no mlir::OpPrintingFlags.
+ GetMLIRRepresentation(attr.getValue()));
+ }
+
+ return attributes;
+}
+
+} // namespace tosa_checker
+
+std::ostream &operator<<(std::ostream &os,
+ const tosa_checker::TOSAChecker::Operator &op) {
+ os << op.mlir_representation;
+
+ return os;
+}
diff --git a/tosa_checker/tosa_checker.h b/tosa_checker/tosa_checker.h
new file mode 100644
index 0000000..d7750ea
--- /dev/null
+++ b/tosa_checker/tosa_checker.h
@@ -0,0 +1,82 @@
+/*
+SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+SPDX-License-Identifier: Apache-2.0
+*/
+#ifndef TOSA_CHECKER_H_
+#define TOSA_CHECKER_H_
+
+#include <map>
+#include <optional>
+#include <string>
+#include <vector>
+
+#include "mlir/include/mlir/IR/BuiltinOps.h"
+#include "mlir/include/mlir/IR/MLIRContext.h"
+#include "mlir/include/mlir/IR/OwningOpRef.h"
+
+namespace tosa_checker {
+
+class TOSAChecker {
+ public:
+ struct Operator {
+ Operator(std::string name, std::string location,
+ std::map<std::string, std::string> attributes,
+ bool is_tosa_compatible, std::string mlir_representation)
+ : name(std::move(name)),
+ location(std::move(location)),
+ attributes(std::move(attributes)),
+ is_tosa_compatible(is_tosa_compatible),
+ mlir_representation(std::move(mlir_representation)) {}
+
+ std::string name;
+ std::string location;
+ std::map<std::string, std::string> attributes;
+ bool is_tosa_compatible;
+ std::string mlir_representation;
+ };
+
+ TOSAChecker(const std::string& model_path);
+
+ bool IsTOSACompatible();
+
+ std::vector<Operator> GetTOSACompatibilityForOps(bool elide_large_attrs);
+
+ std::vector<Operator> GetUsedTOSAOps(bool elide_large_attrs);
+
+ std::string GetMLIRModelRepresentation(bool elide_large_attrs);
+ std::string GetMLIRTOSAModelRepresentation(bool elide_large_attrs);
+
+ private:
+ template <typename T>
+ static std::string GetMLIRRepresentation(T&& op);
+
+ template <typename T>
+ static std::string GetMLIRRepresentation(T&& op, bool elide_large_attrs);
+
+ static std::vector<mlir::Operation*> GetTOSAOps(mlir::ModuleOp model);
+
+ static Operator ToOperator(mlir::Operation& op, bool is_tosa_compatible,
+ bool elide_large_attrs);
+
+ static mlir::OwningOpRef<mlir::ModuleOp> TFLiteFileToMLIR(
+ const std::string& model_path, mlir::MLIRContext* context);
+
+ static void LegalizeTFLToTOSA(mlir::ModuleOp mlir_module);
+
+ static std::map<std::string, std::string> GetAttributes(
+ mlir::Operation& op, bool elide_large_attrs);
+
+ private:
+ static constexpr std::int64_t ELIDE_LARGE_ATTRS_LIMIT = 16;
+
+ mlir::MLIRContext m_context;
+ mlir::OwningOpRef<mlir::ModuleOp> m_model;
+ mlir::OwningOpRef<mlir::ModuleOp> m_tosa_model;
+};
+
+} // namespace tosa_checker
+
+std::ostream& operator<<(std::ostream& os,
+ const tosa_checker::TOSAChecker::Operator& op);
+
+#endif
diff --git a/tosa_checker/tosa_checker_pybind11.cc b/tosa_checker/tosa_checker_pybind11.cc
new file mode 100644
index 0000000..c799817
--- /dev/null
+++ b/tosa_checker/tosa_checker_pybind11.cc
@@ -0,0 +1,79 @@
+/*
+SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+SPDX-License-Identifier: Apache-2.0
+*/
+#include "tosa_checker.h"
+
+#include <optional>
+#include <sstream>
+#include <string>
+
+#include <pybind11/pybind11.h>
+#include <pybind11/stl.h>
+
+PYBIND11_MODULE(_tosa_checker_wrapper, m) {
+ /**
+ * tosa_checker::TOSAChecker
+ */
+ pybind11::class_<tosa_checker::TOSAChecker> tosa_checker_class(m,
+ "TOSAChecker");
+ tosa_checker_class.def(pybind11::init<const std::string&>(),
+ pybind11::arg("model_path"));
+
+ tosa_checker_class.def(
+ "is_tosa_compatible",
+ [](tosa_checker::TOSAChecker& tc) { return tc.IsTOSACompatible(); },
+ "Check if a model is compatible with the TOSA specification");
+
+ tosa_checker_class.def(
+ "_get_tosa_compatibility_for_ops",
+ [](tosa_checker::TOSAChecker& tc, bool elide_large_elements_attrs) {
+ return tc.GetTOSACompatibilityForOps(elide_large_elements_attrs);
+ },
+ pybind11::arg("elide_large_elements_attrs") = false,
+ "Get all the operators of the models with a TOSA compatibility flag for "
+ "each operator");
+
+ tosa_checker_class.def(
+ "_get_used_tosa_ops",
+ [](tosa_checker::TOSAChecker& tc, bool elide_large_elements_attrs) {
+ return tc.GetUsedTOSAOps(elide_large_elements_attrs);
+ },
+ pybind11::arg("elide_large_elements_attrs") = false,
+ "Get the TOSA operators used by the model after its TOSA legalization");
+
+ tosa_checker_class.def(
+ "_get_mlir_model_representation",
+ [](tosa_checker::TOSAChecker& tc, bool elide_large_elements_attrs) {
+ return tc.GetMLIRModelRepresentation(elide_large_elements_attrs);
+ },
+ pybind11::arg("elide_large_elements_attrs") = false,
+ "Get the MLIR representation of the model");
+
+ tosa_checker_class.def(
+ "_get_mlir_tosa_model_representation",
+ [](tosa_checker::TOSAChecker& tc, bool elide_large_elements_attrs) {
+ return tc.GetMLIRTOSAModelRepresentation(elide_large_elements_attrs);
+ },
+ pybind11::arg("elide_large_elements_attrs") = false,
+ "Get the MLIR representation of the TOSA legalized model");
+
+ /**
+ * tosa_checker::TOSAChecker::Operator
+ */
+ pybind11::class_<tosa_checker::TOSAChecker::Operator>(tosa_checker_class,
+ "_Operator")
+ .def_readonly("name", &tosa_checker::TOSAChecker::Operator::name)
+ .def_readonly("location", &tosa_checker::TOSAChecker::Operator::location)
+ .def_readonly("attributes",
+ &tosa_checker::TOSAChecker::Operator::attributes)
+ .def_readonly("is_tosa_compatible",
+ &tosa_checker::TOSAChecker::Operator::is_tosa_compatible)
+ .def_readonly("mlir_representation",
+ &tosa_checker::TOSAChecker::Operator::mlir_representation)
+ .def("__repr__", [](const tosa_checker::TOSAChecker::Operator& o) {
+ std::stringstream stream;
+ stream << o;
+ return stream.str();
+ });
+}