aboutsummaryrefslogtreecommitdiff
path: root/tosa_checker/tosa_checker.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tosa_checker/tosa_checker.cc')
-rw-r--r--tosa_checker/tosa_checker.cc225
1 files changed, 225 insertions, 0 deletions
diff --git a/tosa_checker/tosa_checker.cc b/tosa_checker/tosa_checker.cc
new file mode 100644
index 0000000..714cab3
--- /dev/null
+++ b/tosa_checker/tosa_checker.cc
@@ -0,0 +1,225 @@
+/*
+SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+SPDX-License-Identifier: Apache-2.0
+*/
+#include "tosa_checker.h"
+
+#include "absl/strings/string_view.h"
+#include "llvm/Support/MemoryBuffer.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/Location.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/Types.h"
+#include "mlir/IR/Verifier.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Support/FileUtilities.h"
+#include "tensorflow/compiler/mlir/lite/flatbuffer_import.h"
+#include "tensorflow/compiler/mlir/tosa/tfl_passes.h"
+
+#include <map>
+#include <memory>
+#include <optional>
+#include <stdexcept>
+#include <string>
+#include <unordered_set>
+#include <vector>
+
+namespace std {
+template <>
+struct hash<mlir::Location> {
+ std::size_t operator()(const mlir::Location &loc) const {
+ return mlir::hash_value(loc);
+ }
+};
+} // namespace std
+
+namespace tosa_checker {
+
+TOSAChecker::TOSAChecker(const std::string &model_path) {
+ m_model = TFLiteFileToMLIR(model_path, &m_context);
+ m_tosa_model = m_model->clone();
+ LegalizeTFLToTOSA(*m_tosa_model);
+}
+
+bool TOSAChecker::IsTOSACompatible() {
+ bool is_tosa_compatible = true;
+ for (auto func : m_tosa_model->getOps<mlir::func::FuncOp>()) {
+ func.walk([&](mlir::Operation *op) {
+ // Ignore func namespace
+ const mlir::Dialect *dialect = op->getDialect();
+ if (!dialect || (!dialect->getNamespace().equals("tosa") &&
+ !dialect->getNamespace().equals("func"))) {
+ is_tosa_compatible = false;
+ return mlir::WalkResult::interrupt();
+ }
+
+ return mlir::WalkResult::advance();
+ });
+ }
+
+ return is_tosa_compatible;
+}
+
+std::vector<TOSAChecker::Operator> TOSAChecker::GetTOSACompatibilityForOps(
+ bool elide_large_attrs) {
+ // Get the locations of all the ops in the legalized model that were not
+ // converted during the TOSA legalization (i.e. the TOSA incompatible ones).
+ std::unordered_set<mlir::Location> tosa_incompatible_locs;
+ for (auto func : m_tosa_model->getOps<mlir::func::FuncOp>()) {
+ func.walk([&](mlir::Operation *op) {
+ // Ignore func namespace
+ const mlir::Dialect *dialect = op->getDialect();
+ if (!dialect || (!dialect->getNamespace().equals("tosa") &&
+ !dialect->getNamespace().equals("func"))) {
+ tosa_incompatible_locs.insert(op->getLoc());
+ }
+ });
+ }
+
+ // We assume that on legalization, the non-legalized ops keep their original
+ // location. If an op location from the original model is in
+ // tosa_incompatible_locs then the op is not tosa compatible, otherwise it is.
+ std::vector<Operator> ops;
+ for (auto func : m_model->getOps<mlir::func::FuncOp>()) {
+ func.walk([&](mlir::Operation *op) {
+ // Ignore func namespace
+ const mlir::Dialect *dialect = op->getDialect();
+ if (!dialect || !dialect->getNamespace().equals("func")) {
+ const bool is_tosa_compatible =
+ tosa_incompatible_locs.find(op->getLoc()) ==
+ tosa_incompatible_locs.end();
+ ops.push_back(ToOperator(*op, is_tosa_compatible, elide_large_attrs));
+ }
+ });
+ }
+
+ return ops;
+}
+
+std::vector<TOSAChecker::Operator> TOSAChecker::GetUsedTOSAOps(
+ bool elide_large_attrs) {
+ std::vector<Operator> tosa_ops;
+ for (mlir::Operation *op : GetTOSAOps(*m_tosa_model)) {
+ const bool is_tosa_compatible = true;
+ tosa_ops.push_back(ToOperator(*op, is_tosa_compatible, elide_large_attrs));
+ }
+
+ return tosa_ops;
+}
+
+std::string TOSAChecker::GetMLIRModelRepresentation(bool elide_large_attrs) {
+ return GetMLIRRepresentation(*m_model, elide_large_attrs);
+}
+
+std::string TOSAChecker::GetMLIRTOSAModelRepresentation(
+ bool elide_large_attrs) {
+ return GetMLIRRepresentation(*m_tosa_model, elide_large_attrs);
+}
+
+template <typename T>
+std::string TOSAChecker::GetMLIRRepresentation(T &&op) {
+ std::string value;
+ llvm::raw_string_ostream value_ostream(value);
+
+ op.print(value_ostream);
+
+ return value;
+}
+
+template <typename T>
+std::string TOSAChecker::GetMLIRRepresentation(T &&op, bool elide_large_attrs) {
+ std::string value;
+ llvm::raw_string_ostream value_ostream(value);
+
+ mlir::OpPrintingFlags flags;
+ if (elide_large_attrs) {
+ flags.elideLargeElementsAttrs(ELIDE_LARGE_ATTRS_LIMIT);
+ }
+ op.print(value_ostream, flags);
+
+ return value;
+}
+
+std::vector<mlir::Operation *> TOSAChecker::GetTOSAOps(mlir::ModuleOp model) {
+ std::vector<mlir::Operation *> tosa_ops;
+ for (auto func : model.getOps<mlir::func::FuncOp>()) {
+ func.walk([&](mlir::Operation *op) {
+ const mlir::Dialect *dialect = op->getDialect();
+ if (dialect && dialect->getNamespace().equals("tosa")) {
+ tosa_ops.push_back(op);
+ }
+ });
+ }
+
+ return tosa_ops;
+}
+
+TOSAChecker::Operator TOSAChecker::ToOperator(mlir::Operation &op,
+ bool is_tosa_compatible,
+ bool elide_large_attrs) {
+ return Operator(op.getName().getStringRef().str(),
+ GetMLIRRepresentation(op.getLoc()),
+ GetAttributes(op, elide_large_attrs), is_tosa_compatible,
+ GetMLIRRepresentation(op, elide_large_attrs));
+}
+
+mlir::OwningOpRef<mlir::ModuleOp> TOSAChecker::TFLiteFileToMLIR(
+ const std::string &model_path, mlir::MLIRContext *context) {
+ std::string error_message;
+ std::unique_ptr<llvm::MemoryBuffer> input =
+ mlir::openInputFile(model_path, &error_message);
+ if (!input) {
+ throw std::runtime_error(error_message);
+ }
+
+ const mlir::FileLineColLoc location =
+ mlir::FileLineColLoc::get(context, input->getBufferIdentifier(), 0, 0);
+
+ auto mlir_module = tflite::FlatBufferToMlir(
+ absl::string_view(input->getBufferStart(), input->getBufferSize()),
+ context, location);
+ if (!mlir_module || mlir::failed(mlir::verify(*mlir_module))) {
+ throw std::runtime_error(
+ "Could not convert the TFLite model to its MLIR representation.");
+ }
+
+ return mlir_module;
+}
+
+void TOSAChecker::LegalizeTFLToTOSA(mlir::ModuleOp mlir_module) {
+ mlir::PassManager pm(mlir_module.getContext(),
+ mlir::OpPassManager::Nesting::Implicit);
+ mlir::tosa::TOSATFLLegalizationPipelineOptions opts;
+ mlir::tosa::createTFLtoTOSALegalizationPipeline(pm, opts);
+ // TODO Don't check for mlir::failed state for now due to some incoherences in
+ // how the legalization report non-convertible ops (sometimes with a hard
+ // fail, sometimes without). The legalization should not return a failed
+ // state if an operator can't be legalized and should leave it in its original
+ // dialect.
+ pm.run(mlir_module);
+}
+
+std::map<std::string, std::string> TOSAChecker::GetAttributes(
+ mlir::Operation &op, bool /*elide_large_attrs*/) {
+ std::map<std::string, std::string> attributes;
+ for (const mlir::NamedAttribute &attr : op.getAttrs()) {
+ attributes.emplace(attr.getName().str(),
+ // TODO Check how to elide large attributes when
+ // converting them to string, mlir::Attribute::print has
+ // no mlir::OpPrintingFlags.
+ GetMLIRRepresentation(attr.getValue()));
+ }
+
+ return attributes;
+}
+
+} // namespace tosa_checker
+
+std::ostream &operator<<(std::ostream &os,
+ const tosa_checker::TOSAChecker::Operator &op) {
+ os << op.mlir_representation;
+
+ return os;
+}