diff options
Diffstat (limited to 'tosa_checker/tosa_checker_pybind11.cc')
-rw-r--r-- | tosa_checker/tosa_checker_pybind11.cc | 79 |
1 files changed, 79 insertions, 0 deletions
diff --git a/tosa_checker/tosa_checker_pybind11.cc b/tosa_checker/tosa_checker_pybind11.cc new file mode 100644 index 0000000..c799817 --- /dev/null +++ b/tosa_checker/tosa_checker_pybind11.cc @@ -0,0 +1,79 @@ +/* +SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +SPDX-License-Identifier: Apache-2.0 +*/ +#include "tosa_checker.h" + +#include <optional> +#include <sstream> +#include <string> + +#include <pybind11/pybind11.h> +#include <pybind11/stl.h> + +PYBIND11_MODULE(_tosa_checker_wrapper, m) { + /** + * tosa_checker::TOSAChecker + */ + pybind11::class_<tosa_checker::TOSAChecker> tosa_checker_class(m, + "TOSAChecker"); + tosa_checker_class.def(pybind11::init<const std::string&>(), + pybind11::arg("model_path")); + + tosa_checker_class.def( + "is_tosa_compatible", + [](tosa_checker::TOSAChecker& tc) { return tc.IsTOSACompatible(); }, + "Check if a model is compatible with the TOSA specification"); + + tosa_checker_class.def( + "_get_tosa_compatibility_for_ops", + [](tosa_checker::TOSAChecker& tc, bool elide_large_elements_attrs) { + return tc.GetTOSACompatibilityForOps(elide_large_elements_attrs); + }, + pybind11::arg("elide_large_elements_attrs") = false, + "Get all the operators of the models with a TOSA compatibility flag for " + "each operator"); + + tosa_checker_class.def( + "_get_used_tosa_ops", + [](tosa_checker::TOSAChecker& tc, bool elide_large_elements_attrs) { + return tc.GetUsedTOSAOps(elide_large_elements_attrs); + }, + pybind11::arg("elide_large_elements_attrs") = false, + "Get the TOSA operators used by the model after its TOSA legalization"); + + tosa_checker_class.def( + "_get_mlir_model_representation", + [](tosa_checker::TOSAChecker& tc, bool elide_large_elements_attrs) { + return tc.GetMLIRModelRepresentation(elide_large_elements_attrs); + }, + pybind11::arg("elide_large_elements_attrs") = false, + "Get the MLIR representation of the model"); + + tosa_checker_class.def( + "_get_mlir_tosa_model_representation", + [](tosa_checker::TOSAChecker& tc, bool elide_large_elements_attrs) { + return tc.GetMLIRTOSAModelRepresentation(elide_large_elements_attrs); + }, + pybind11::arg("elide_large_elements_attrs") = false, + "Get the MLIR representation of the TOSA legalized model"); + + /** + * tosa_checker::TOSAChecker::Operator + */ + pybind11::class_<tosa_checker::TOSAChecker::Operator>(tosa_checker_class, + "_Operator") + .def_readonly("name", &tosa_checker::TOSAChecker::Operator::name) + .def_readonly("location", &tosa_checker::TOSAChecker::Operator::location) + .def_readonly("attributes", + &tosa_checker::TOSAChecker::Operator::attributes) + .def_readonly("is_tosa_compatible", + &tosa_checker::TOSAChecker::Operator::is_tosa_compatible) + .def_readonly("mlir_representation", + &tosa_checker::TOSAChecker::Operator::mlir_representation) + .def("__repr__", [](const tosa_checker::TOSAChecker::Operator& o) { + std::stringstream stream; + stream << o; + return stream.str(); + }); +} |