aboutsummaryrefslogtreecommitdiff
path: root/tosa_checker/tosa_checker_pybind11.cc
blob: c799817baf109dd433b3ada87c773aa6277561a2 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
/*
SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
SPDX-License-Identifier: Apache-2.0
*/
#include "tosa_checker.h"

#include <optional>
#include <sstream>
#include <string>

#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

/**
 * Python bindings for the TOSA checker.
 *
 * Exposes tosa_checker::TOSAChecker as `TOSAChecker` and its nested
 * tosa_checker::TOSAChecker::Operator as `TOSAChecker._Operator` in the
 * `_tosa_checker_wrapper` extension module.
 *
 * Both pybind11::class_ objects are created before any method is bound.
 * pybind11 renders a function's signature docstring at the moment `.def()`
 * is called, and a parameter/return type that is not registered yet at that
 * point shows up under its C++ type name instead of its Python name.
 * Registering `_Operator` first keeps the generated signatures Python-facing.
 */
PYBIND11_MODULE(_tosa_checker_wrapper, m) {
  using tosa_checker::TOSAChecker;
  using Operator = tosa_checker::TOSAChecker::Operator;

  // Register every exposed type up front (see the note above the macro).
  pybind11::class_<TOSAChecker> tosa_checker_class(m, "TOSAChecker");
  // `_Operator` is nested in the `TOSAChecker` Python class scope.
  pybind11::class_<Operator> operator_class(tosa_checker_class, "_Operator");

  /**
   * tosa_checker::TOSAChecker::Operator
   *
   * Read-only view of the C++ struct fields, plus a __repr__ built from the
   * Operator stream-insertion operator (declared outside this file).
   */
  operator_class.def_readonly("name", &Operator::name)
      .def_readonly("location", &Operator::location)
      .def_readonly("attributes", &Operator::attributes)
      .def_readonly("is_tosa_compatible", &Operator::is_tosa_compatible)
      .def_readonly("mlir_representation", &Operator::mlir_representation)
      .def("__repr__", [](const Operator& o) {
        std::ostringstream stream;
        stream << o;
        return stream.str();
      });

  /**
   * tosa_checker::TOSAChecker
   */
  // Constructed from the path of the model to check.
  tosa_checker_class.def(pybind11::init<const std::string&>(),
                         pybind11::arg("model_path"));

  tosa_checker_class.def(
      "is_tosa_compatible",
      [](TOSAChecker& tc) { return tc.IsTOSACompatible(); },
      "Check if a model is compatible with the TOSA specification");

  // The `_get_*` methods below each forward `elide_large_elements_attrs`
  // unchanged to the matching C++ method; on the Python side it defaults to
  // False. Their leading underscore marks them as private to the Python
  // wrapper package (NOTE(review): presumably re-exported there — confirm).
  tosa_checker_class.def(
      "_get_tosa_compatibility_for_ops",
      [](TOSAChecker& tc, bool elide_large_elements_attrs) {
        return tc.GetTOSACompatibilityForOps(elide_large_elements_attrs);
      },
      pybind11::arg("elide_large_elements_attrs") = false,
      "Get all the operators of the model with a TOSA compatibility flag for "
      "each operator");

  tosa_checker_class.def(
      "_get_used_tosa_ops",
      [](TOSAChecker& tc, bool elide_large_elements_attrs) {
        return tc.GetUsedTOSAOps(elide_large_elements_attrs);
      },
      pybind11::arg("elide_large_elements_attrs") = false,
      "Get the TOSA operators used by the model after its TOSA legalization");

  tosa_checker_class.def(
      "_get_mlir_model_representation",
      [](TOSAChecker& tc, bool elide_large_elements_attrs) {
        return tc.GetMLIRModelRepresentation(elide_large_elements_attrs);
      },
      pybind11::arg("elide_large_elements_attrs") = false,
      "Get the MLIR representation of the model");

  tosa_checker_class.def(
      "_get_mlir_tosa_model_representation",
      [](TOSAChecker& tc, bool elide_large_elements_attrs) {
        return tc.GetMLIRTOSAModelRepresentation(elide_large_elements_attrs);
      },
      pybind11::arg("elide_large_elements_attrs") = false,
      "Get the MLIR representation of the TOSA legalized model");
}