diff options
Diffstat (limited to 'tosa_checker/tosa_checker.h')
-rw-r--r-- | tosa_checker/tosa_checker.h | 82 |
1 files changed, 82 insertions, 0 deletions
diff --git a/tosa_checker/tosa_checker.h b/tosa_checker/tosa_checker.h new file mode 100644 index 0000000..d7750ea --- /dev/null +++ b/tosa_checker/tosa_checker.h @@ -0,0 +1,82 @@ +/* +SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +SPDX-License-Identifier: Apache-2.0 +*/ +#ifndef TOSA_CHECKER_H_ +#define TOSA_CHECKER_H_ + +#include <map> +#include <optional> +#include <string> +#include <vector> + +#include "mlir/include/mlir/IR/BuiltinOps.h" +#include "mlir/include/mlir/IR/MLIRContext.h" +#include "mlir/include/mlir/IR/OwningOpRef.h" + +namespace tosa_checker { + +class TOSAChecker { + public: + struct Operator { + Operator(std::string name, std::string location, + std::map<std::string, std::string> attributes, + bool is_tosa_compatible, std::string mlir_representation) + : name(std::move(name)), + location(std::move(location)), + attributes(std::move(attributes)), + is_tosa_compatible(is_tosa_compatible), + mlir_representation(std::move(mlir_representation)) {} + + std::string name; + std::string location; + std::map<std::string, std::string> attributes; + bool is_tosa_compatible; + std::string mlir_representation; + }; + + TOSAChecker(const std::string& model_path); + + bool IsTOSACompatible(); + + std::vector<Operator> GetTOSACompatibilityForOps(bool elide_large_attrs); + + std::vector<Operator> GetUsedTOSAOps(bool elide_large_attrs); + + std::string GetMLIRModelRepresentation(bool elide_large_attrs); + std::string GetMLIRTOSAModelRepresentation(bool elide_large_attrs); + + private: + template <typename T> + static std::string GetMLIRRepresentation(T&& op); + + template <typename T> + static std::string GetMLIRRepresentation(T&& op, bool elide_large_attrs); + + static std::vector<mlir::Operation*> GetTOSAOps(mlir::ModuleOp model); + + static Operator ToOperator(mlir::Operation& op, bool is_tosa_compatible, + bool elide_large_attrs); + + static mlir::OwningOpRef<mlir::ModuleOp> TFLiteFileToMLIR( + const std::string& model_path, mlir::MLIRContext* context); + + static void LegalizeTFLToTOSA(mlir::ModuleOp mlir_module); + + static std::map<std::string, std::string> GetAttributes( + mlir::Operation& op, bool elide_large_attrs); + + private: + static constexpr std::int64_t ELIDE_LARGE_ATTRS_LIMIT = 16; + + mlir::MLIRContext m_context; + mlir::OwningOpRef<mlir::ModuleOp> m_model; + mlir::OwningOpRef<mlir::ModuleOp> m_tosa_model; +}; + +} // namespace tosa_checker + +std::ostream& operator<<(std::ostream& os, + const tosa_checker::TOSAChecker::Operator& op); + +#endif |