aboutsummaryrefslogtreecommitdiff
path: root/tosa_checker/tosa_checker_pybind11.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tosa_checker/tosa_checker_pybind11.cc')
-rw-r--r--tosa_checker/tosa_checker_pybind11.cc79
1 files changed, 79 insertions, 0 deletions
diff --git a/tosa_checker/tosa_checker_pybind11.cc b/tosa_checker/tosa_checker_pybind11.cc
new file mode 100644
index 0000000..c799817
--- /dev/null
+++ b/tosa_checker/tosa_checker_pybind11.cc
@@ -0,0 +1,79 @@
+/*
+SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+SPDX-License-Identifier: Apache-2.0
+*/
+#include "tosa_checker.h"
+
+#include <optional>
+#include <sstream>
+#include <string>
+
+#include <pybind11/pybind11.h>
+#include <pybind11/stl.h>
+
+PYBIND11_MODULE(_tosa_checker_wrapper, m) {
+ /**
+ * tosa_checker::TOSAChecker
+ */
+ pybind11::class_<tosa_checker::TOSAChecker> tosa_checker_class(m,
+ "TOSAChecker");
+ tosa_checker_class.def(pybind11::init<const std::string&>(),
+ pybind11::arg("model_path"));
+
+ tosa_checker_class.def(
+ "is_tosa_compatible",
+ [](tosa_checker::TOSAChecker& tc) { return tc.IsTOSACompatible(); },
+ "Check if a model is compatible with the TOSA specification");
+
+ tosa_checker_class.def(
+ "_get_tosa_compatibility_for_ops",
+ [](tosa_checker::TOSAChecker& tc, bool elide_large_elements_attrs) {
+ return tc.GetTOSACompatibilityForOps(elide_large_elements_attrs);
+ },
+ pybind11::arg("elide_large_elements_attrs") = false,
+ "Get all the operators of the models with a TOSA compatibility flag for "
+ "each operator");
+
+ tosa_checker_class.def(
+ "_get_used_tosa_ops",
+ [](tosa_checker::TOSAChecker& tc, bool elide_large_elements_attrs) {
+ return tc.GetUsedTOSAOps(elide_large_elements_attrs);
+ },
+ pybind11::arg("elide_large_elements_attrs") = false,
+ "Get the TOSA operators used by the model after its TOSA legalization");
+
+ tosa_checker_class.def(
+ "_get_mlir_model_representation",
+ [](tosa_checker::TOSAChecker& tc, bool elide_large_elements_attrs) {
+ return tc.GetMLIRModelRepresentation(elide_large_elements_attrs);
+ },
+ pybind11::arg("elide_large_elements_attrs") = false,
+ "Get the MLIR representation of the model");
+
+ tosa_checker_class.def(
+ "_get_mlir_tosa_model_representation",
+ [](tosa_checker::TOSAChecker& tc, bool elide_large_elements_attrs) {
+ return tc.GetMLIRTOSAModelRepresentation(elide_large_elements_attrs);
+ },
+ pybind11::arg("elide_large_elements_attrs") = false,
+ "Get the MLIR representation of the TOSA legalized model");
+
+ /**
+ * tosa_checker::TOSAChecker::Operator
+ */
+ pybind11::class_<tosa_checker::TOSAChecker::Operator>(tosa_checker_class,
+ "_Operator")
+ .def_readonly("name", &tosa_checker::TOSAChecker::Operator::name)
+ .def_readonly("location", &tosa_checker::TOSAChecker::Operator::location)
+ .def_readonly("attributes",
+ &tosa_checker::TOSAChecker::Operator::attributes)
+ .def_readonly("is_tosa_compatible",
+ &tosa_checker::TOSAChecker::Operator::is_tosa_compatible)
+ .def_readonly("mlir_representation",
+ &tosa_checker::TOSAChecker::Operator::mlir_representation)
+ .def("__repr__", [](const tosa_checker::TOSAChecker::Operator& o) {
+ std::stringstream stream;
+ stream << o;
+ return stream.str();
+ });
+}