aboutsummaryrefslogtreecommitdiff
path: root/tosa_checker/tosa_checker.h
diff options
context:
space:
mode:
Diffstat (limited to 'tosa_checker/tosa_checker.h')
-rw-r--r--tosa_checker/tosa_checker.h82
1 files changed, 82 insertions, 0 deletions
diff --git a/tosa_checker/tosa_checker.h b/tosa_checker/tosa_checker.h
new file mode 100644
index 0000000..d7750ea
--- /dev/null
+++ b/tosa_checker/tosa_checker.h
@@ -0,0 +1,82 @@
+/*
+SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+SPDX-License-Identifier: Apache-2.0
+*/
+#ifndef TOSA_CHECKER_H_
+#define TOSA_CHECKER_H_
+
+#include <map>
+#include <optional>
+#include <string>
+#include <vector>
+
+#include "mlir/include/mlir/IR/BuiltinOps.h"
+#include "mlir/include/mlir/IR/MLIRContext.h"
+#include "mlir/include/mlir/IR/OwningOpRef.h"
+
+namespace tosa_checker {
+
+class TOSAChecker {
+ public:
+ struct Operator {
+ Operator(std::string name, std::string location,
+ std::map<std::string, std::string> attributes,
+ bool is_tosa_compatible, std::string mlir_representation)
+ : name(std::move(name)),
+ location(std::move(location)),
+ attributes(std::move(attributes)),
+ is_tosa_compatible(is_tosa_compatible),
+ mlir_representation(std::move(mlir_representation)) {}
+
+ std::string name;
+ std::string location;
+ std::map<std::string, std::string> attributes;
+ bool is_tosa_compatible;
+ std::string mlir_representation;
+ };
+
+ TOSAChecker(const std::string& model_path);
+
+ bool IsTOSACompatible();
+
+ std::vector<Operator> GetTOSACompatibilityForOps(bool elide_large_attrs);
+
+ std::vector<Operator> GetUsedTOSAOps(bool elide_large_attrs);
+
+ std::string GetMLIRModelRepresentation(bool elide_large_attrs);
+ std::string GetMLIRTOSAModelRepresentation(bool elide_large_attrs);
+
+ private:
+ template <typename T>
+ static std::string GetMLIRRepresentation(T&& op);
+
+ template <typename T>
+ static std::string GetMLIRRepresentation(T&& op, bool elide_large_attrs);
+
+ static std::vector<mlir::Operation*> GetTOSAOps(mlir::ModuleOp model);
+
+ static Operator ToOperator(mlir::Operation& op, bool is_tosa_compatible,
+ bool elide_large_attrs);
+
+ static mlir::OwningOpRef<mlir::ModuleOp> TFLiteFileToMLIR(
+ const std::string& model_path, mlir::MLIRContext* context);
+
+ static void LegalizeTFLToTOSA(mlir::ModuleOp mlir_module);
+
+ static std::map<std::string, std::string> GetAttributes(
+ mlir::Operation& op, bool elide_large_attrs);
+
+ private:
+ static constexpr std::int64_t ELIDE_LARGE_ATTRS_LIMIT = 16;
+
+ mlir::MLIRContext m_context;
+ mlir::OwningOpRef<mlir::ModuleOp> m_model;
+ mlir::OwningOpRef<mlir::ModuleOp> m_tosa_model;
+};
+
+} // namespace tosa_checker
+
+std::ostream& operator<<(std::ostream& os,
+ const tosa_checker::TOSAChecker::Operator& op);
+
+#endif