aboutsummaryrefslogtreecommitdiff
path: root/tosa_checker/tosa_checker.h
blob: d7750ea0f1533e91ea9d2d1a53edee4bad923f7a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
/*
SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
SPDX-License-Identifier: Apache-2.0
*/
#ifndef TOSA_CHECKER_H_
#define TOSA_CHECKER_H_

#include <cstdint>
#include <iosfwd>
#include <map>
#include <optional>
#include <string>
#include <utility>
#include <vector>

#include "mlir/include/mlir/IR/BuiltinOps.h"
#include "mlir/include/mlir/IR/MLIRContext.h"
#include "mlir/include/mlir/IR/OwningOpRef.h"

namespace tosa_checker {

// Inspects a model for compatibility with the TOSA MLIR dialect.
//
// Keeps its own MLIRContext plus two modules: the original model (m_model)
// and a second module (m_tosa_model). The private helpers TFLiteFileToMLIR and
// LegalizeTFLToTOSA suggest that the model is read from a TFLite file and then
// legalized to TOSA.
// NOTE(review): the member functions are defined out of line; the
// descriptions below are inferred from names/signatures — confirm against the
// implementation.
class TOSAChecker {
 public:
  // Per-operator report entry returned by GetTOSACompatibilityForOps and
  // GetUsedTOSAOps.
  struct Operator {
    // Sink-style constructor: string/map arguments are taken by value and
    // moved into the members to avoid extra copies.
    Operator(std::string name, std::string location,
             std::map<std::string, std::string> attributes,
             bool is_tosa_compatible, std::string mlir_representation)
        : name(std::move(name)),
          location(std::move(location)),
          attributes(std::move(attributes)),
          is_tosa_compatible(is_tosa_compatible),
          mlir_representation(std::move(mlir_representation)) {}

    // Operator name (presumably the MLIR operation name — confirm).
    std::string name;
    // Textual form of the operator's source location (presumably the MLIR
    // Location — confirm).
    std::string location;
    // Attribute name -> printed attribute value.
    std::map<std::string, std::string> attributes;
    // Whether this operator was found to be TOSA compatible.
    bool is_tosa_compatible;
    // Printed MLIR form of the operator.
    std::string mlir_representation;
  };

  // Builds a checker for the model stored at `model_path`.
  // Explicit: a file path should never silently convert into a checker.
  explicit TOSAChecker(const std::string& model_path);

  // The checker owns an MLIRContext (which is itself non-copyable) and
  // modules whose lifetime is tied to that context, so copying is disallowed.
  TOSAChecker(const TOSAChecker&) = delete;
  TOSAChecker& operator=(const TOSAChecker&) = delete;

  // Returns whether the whole model is TOSA compatible.
  bool IsTOSACompatible();

  // Returns one entry per model operator with its compatibility flag.
  // `elide_large_attrs` requests that large attribute values be elided from
  // the printed output.
  std::vector<Operator> GetTOSACompatibilityForOps(bool elide_large_attrs);

  // Returns the TOSA operators used by the legalized model.
  std::vector<Operator> GetUsedTOSAOps(bool elide_large_attrs);

  // Printed MLIR of the original model / of the TOSA model respectively.
  std::string GetMLIRModelRepresentation(bool elide_large_attrs);
  std::string GetMLIRTOSAModelRepresentation(bool elide_large_attrs);

 private:
  // Print `op` to a string (with and without attribute elision).
  template <typename T>
  static std::string GetMLIRRepresentation(T&& op);

  template <typename T>
  static std::string GetMLIRRepresentation(T&& op, bool elide_large_attrs);

  // Collects TOSA operations contained in `model`.
  static std::vector<mlir::Operation*> GetTOSAOps(mlir::ModuleOp model);

  // Converts an MLIR operation into a report entry.
  static Operator ToOperator(mlir::Operation& op, bool is_tosa_compatible,
                             bool elide_large_attrs);

  // Loads the TFLite file at `model_path` into a module owned by `context`.
  static mlir::OwningOpRef<mlir::ModuleOp> TFLiteFileToMLIR(
      const std::string& model_path, mlir::MLIRContext* context);

  // Runs TFL -> TOSA legalization on `mlir_module` (in place).
  static void LegalizeTFLToTOSA(mlir::ModuleOp mlir_module);

  // Attribute name -> printed value map for `op`.
  static std::map<std::string, std::string> GetAttributes(
      mlir::Operation& op, bool elide_large_attrs);

  // Threshold used when eliding large attributes (presumably an element
  // count — confirm in the implementation).
  static constexpr std::int64_t ELIDE_LARGE_ATTRS_LIMIT = 16;

  // Declaration order matters: members are destroyed in reverse order, so
  // both modules are released before the context they were created in.
  mlir::MLIRContext m_context;
  mlir::OwningOpRef<mlir::ModuleOp> m_model;
  mlir::OwningOpRef<mlir::ModuleOp> m_tosa_model;
};

}  // namespace tosa_checker

// Stream-insertion operator for TOSAChecker::Operator; returns the stream it
// was given. The definition is out of line (presumably it prints the
// operator's fields — confirm).
// NOTE(review): declared at global scope rather than in namespace
// tosa_checker, so argument-dependent lookup will not find it; an
// operator<< declared in a caller's own namespace can hide it. Moving it
// requires moving the definition too.
std::ostream& operator<<(
    std::ostream& stream,
    const tosa_checker::TOSAChecker::Operator& tosa_op);

#endif