From 52dacd6556d60815253d4e4938e218ea3d8084a2 Mon Sep 17 00:00:00 2001 From: Thibaut Goetghebuer-Planchon Date: Wed, 6 Jul 2022 10:23:22 +0100 Subject: Initial commit Change-Id: I2fb0933d595a6ede6417d09dd905ef72d6c60c9b --- .bazelrc | 8 ++ .bazelversion | 2 + .bazelversion.license | 2 + .gitignore | 17 +++ LICENSES/Apache-2.0.txt | 177 ++++++++++++++++++++++++++ MANIFEST.in | 4 + README.md | 54 ++++++++ RELEASES.md | 20 +++ WORKSPACE | 32 +++++ docker/Dockerfile | 26 ++++ docker/README.md | 52 ++++++++ docker/install/install_bazel.sh | 24 ++++ examples/tosa_checker.ipynb | 117 ++++++++++++++++++ examples/tosa_checker.ipynb.license | 3 + setup.py | 127 +++++++++++++++++++ tests/requirements.txt | 4 + tests/test_tosa_checker.py | 216 ++++++++++++++++++++++++++++++++ tosa_checker/BUILD | 35 ++++++ tosa_checker/__init__.py | 8 ++ tosa_checker/tosa_checker.cc | 225 ++++++++++++++++++++++++++++++++++ tosa_checker/tosa_checker.h | 82 +++++++++++++ tosa_checker/tosa_checker_pybind11.cc | 79 ++++++++++++ 22 files changed, 1314 insertions(+) create mode 100644 .bazelrc create mode 100644 .bazelversion create mode 100644 .bazelversion.license create mode 100644 .gitignore create mode 100644 LICENSES/Apache-2.0.txt create mode 100644 MANIFEST.in create mode 100644 README.md create mode 100644 RELEASES.md create mode 100644 WORKSPACE create mode 100644 docker/Dockerfile create mode 100644 docker/README.md create mode 100755 docker/install/install_bazel.sh create mode 100644 examples/tosa_checker.ipynb create mode 100644 examples/tosa_checker.ipynb.license create mode 100644 setup.py create mode 100644 tests/requirements.txt create mode 100644 tests/test_tosa_checker.py create mode 100644 tosa_checker/BUILD create mode 100644 tosa_checker/__init__.py create mode 100644 tosa_checker/tosa_checker.cc create mode 100644 tosa_checker/tosa_checker.h create mode 100644 tosa_checker/tosa_checker_pybind11.cc diff --git a/.bazelrc b/.bazelrc new file mode 100644 index 0000000..97fd61c --- /dev/null +++ b/.bazelrc @@ -0,0 +1,8 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 + +# Flag to enable remote config +common --experimental_repo_remote_exec +common --cxxopt=-std=c++17 +common --host_cxxopt=-std=c++17 +common --copt=-w diff --git a/.bazelversion b/.bazelversion new file mode 100644 index 0000000..10803dc --- /dev/null +++ b/.bazelversion @@ -0,0 +1,2 @@ +5.1.1 +# The version must be the same as the one in the TensorFlow version we build against diff --git a/.bazelversion.license b/.bazelversion.license new file mode 100644 index 0000000..487e9d8 --- /dev/null +++ b/.bazelversion.license @@ -0,0 +1,2 @@ +SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +SPDX-License-Identifier: Apache-2.0 diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..343775f --- /dev/null +++ b/.gitignore @@ -0,0 +1,17 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +*.egg-info +*.pyc +*~ +.eggs/ +\.coverage +\.eggs +__pycache__ +build/ +dist/ +*.egg-info +bazel-bin/* +bazel-out/* +bazel-tosa_checker/* +bazel-testlogs/* +.pytest_cache/* diff --git a/LICENSES/Apache-2.0.txt b/LICENSES/Apache-2.0.txt new file mode 100644 index 0000000..f433b1a --- /dev/null +++ b/LICENSES/Apache-2.0.txt @@ -0,0 +1,177 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS diff --git a/MANIFEST.in b/MANIFEST.in new file mode 100644 index 0000000..74d9656 --- /dev/null +++ b/MANIFEST.in @@ -0,0 +1,4 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 + +include README.md diff --git a/README.md b/README.md new file mode 100644 index 0000000..361ac8b --- /dev/null +++ b/README.md @@ -0,0 +1,54 @@ + +# TOSA Checker + +## Overview + +The TOSA Checker is a tool that provides an easy way to ensure that a TensorFlow™ Lite model is compatible with the [TOSA specification](https://www.mlplatform.org/tosa). + +The project is currently in alpha, some stability issues may still be found. + +## Prerequisites + +To build the TOSA Checker the following are required: + +* Ubuntu® 20.04.03 LTS (the TOSA Checker has been tested on this OS, but should work on other Linux® OS) +* Bazel™ (Bazelisk is an easy way to install the correct version of Bazel™) +* Git™ +* Python® >=3.7 (python, python-dev and python-pip) + +## Building and installation + +pip install -U pip setuptools wheel +pip install -U numpy +python setup.py bdist_wheel +pip install dist/tosa_checker-0.1.0-cp38-cp38-linux_x86_64.whl + +## Docker™ image + +A Docker™ image to build the TOSA Checker is also provided. More information can be found on [docker/README.md](docker/README.md). + +## Usage + +```python +>>> import tosa_checker as tc +>>> checker = tc.TOSAChecker(model_path="model.tflite") +>>> checker.is_tosa_compatible() +True +``` + +## License + +TOSA Checker is licensed under [Apache License 2.0]. + +## Trademarks and Copyrights + +Bazel™ is a trademark of Google® LLC. +Git™ is a trademark of Software Freedom Conservancy. +Linux® is the registered trademark of Linus Torvalds in the U.S. and elsewhere. +Python® is a registered trademark of the PSF. +Ubuntu® is a registered trademark of Canonical. +Tensorflow™ is a trademark of Google® LLC. +Docker™ is a trademark of Docker, Inc. diff --git a/RELEASES.md b/RELEASES.md new file mode 100644 index 0000000..ec666ee --- /dev/null +++ b/RELEASES.md @@ -0,0 +1,20 @@ + +# Release 0.1.0 + +First release of the TOSA Checker tool. The goal of the tool is to provide an easy way to check if a TensorFlow™ Lite model is compatible with the [TOSA specification](https://www.mlplatform.org/tosa). + +The tool is provided as a Python® package and can be used as follow: + +```python +>>> import tosa_checker as tc +>>> checker = tc.TOSAChecker("model.tflite") +>>> checker.is_tosa_compatible() +True +``` + +The tool is currently in alpha, the features set is limited and some stability issues may exist. + +Future versions may extend the functionalities provided and support for other frameworks will be added. diff --git a/WORKSPACE b/WORKSPACE new file mode 100644 index 0000000..b8a0fba --- /dev/null +++ b/WORKSPACE @@ -0,0 +1,32 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +workspace(name = "tosa_checker") + +load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") +http_archive( + name = "pybind11", + build_file = "@pybind11_bazel//:pybind11.BUILD", + strip_prefix = "pybind11-2.9.2", + urls = ["https://github.com/pybind/pybind11/archive/v2.9.2.tar.gz"], +) +http_archive( + name = "pybind11_bazel", + strip_prefix = "pybind11_bazel-72cbbf1fbc830e487e3012862b7b720001b70672", + urls = ["https://github.com/pybind/pybind11_bazel/archive/72cbbf1fbc830e487e3012862b7b720001b70672.zip"], +) + +load("@pybind11_bazel//:python_configure.bzl", "python_configure") +python_configure(name = "local_config_python") + +load("@org_tensorflow//tensorflow:workspace3.bzl", "tf_workspace3") +tf_workspace3() + +load("@org_tensorflow//tensorflow:workspace2.bzl", "tf_workspace2") +tf_workspace2() + +load("@org_tensorflow//tensorflow:workspace1.bzl", "tf_workspace1") +tf_workspace1() + +load("@org_tensorflow//tensorflow:workspace0.bzl", "tf_workspace0") +tf_workspace0() + diff --git a/docker/Dockerfile b/docker/Dockerfile new file mode 100644 index 0000000..4481565 --- /dev/null +++ b/docker/Dockerfile @@ -0,0 +1,26 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +FROM quay.io/pypa/manylinux2014_x86_64 + +ARG PYTHON_VERSION=3.9 +ARG TENSORFLOW_VERSION=2.9.0 +ARG BAZEL_VERSION=5.1.1 + +RUN ln -s /opt/_internal/cpython-$PYTHON_VERSION*/bin/python3 /usr/local/bin/python3 && \ + ln -s /opt/_internal/cpython-$PYTHON_VERSION*/bin/python /usr/local/bin/python && \ + ln -s /opt/_internal/cpython-$PYTHON_VERSION*/bin/pip3 /usr/local/bin/pip3 && \ + ln -s /opt/_internal/cpython-$PYTHON_VERSION*/bin/pip /usr/local/bin/pip && \ + ln -s /opt/_internal/cpython-$PYTHON_VERSION*/include/python${PYTHON_VERSION}/ /usr/local/include/python + +ENV PYTHON_BIN_PATH=/usr/local/bin/python +ENV CI_BUILD_PYTHON=/usr/local/bin/python +ENV CROSSTOOL_PYTHON_INCLUDE_PATH=/usr/local/include/python/ + +RUN pip install --no-cache-dir setuptools pybind11 numpy twine keyrings.alt + +COPY install/install_bazel.sh /install/ +RUN /install/install_bazel.sh ${BAZEL_VERSION} + +# CACHE_STOP is used to rerun future commands, otherwise the cloning will be cached and will not pull the most recent version +ARG CACHE_STOP=1 +RUN git clone --depth=1 https://github.com/tensorflow/tensorflow.git --branch v${TENSORFLOW_VERSION} /tensorflow_src diff --git a/docker/README.md b/docker/README.md new file mode 100644 index 0000000..653b9a2 --- /dev/null +++ b/docker/README.md @@ -0,0 +1,52 @@ + +# Docker™ image + +This directory contains different utilities to build/test the TOSA Checker. + +## How to build the TOSA Checker manylinux wheel with a Docker™ image + +To create a Docker™ image for the TOSA Checker to build it for Python® 3.9 on various Linux® distributions, please run the following command: + +```console +docker build . -t tc-cp39-manylinux --build-arg PYTHON_VERSION=3.9 -f Dockerfile +``` + +The TensorFlow™ source code is automatically downloaded and is located in the `/tensorflow_src` directory. + +The command to run the container is: + +```console +docker run -it -v :/tosa_checker tc-cp39-manylinux +``` + +Now call the following command to build a `tosa_checker` Python® wheel inside of the container: + +```console +cd tosa_checker +python3 setup.py --tensorflow_src_dir /tensorflow_src bdist_wheel +``` +The `tosa_checker` wheel can be found in the `/dist` directory. + +Generate the new manylinux wheel from the `tosa_checker` wheel: +```console +auditwheel repair dist/.whl -w dist/ +``` +The `tosa_checker` manylinux wheel can now be found in the `/dist` directory. + +Install the `tosa_checker` manylinux wheel: +```console +pip install dist/.whl +``` + +## Trademarks and Copyrights + +Python® is a registered trademark of the PSF. +Linux® is the registered trademark of Linus Torvalds in the U.S. and other countries. +Ubuntu® is a registered trademark of Canonical. +TensorFlow™ is a trademark of Google® LLC. +Docker™ is a trademark of Docker, Inc. + + diff --git a/docker/install/install_bazel.sh b/docker/install/install_bazel.sh new file mode 100755 index 0000000..687ed13 --- /dev/null +++ b/docker/install/install_bazel.sh @@ -0,0 +1,24 @@ +#!/usr/bin/env bash +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +BAZEL_VERSION="$1" +shift + +set +e +local_bazel_ver=$(bazel version 2>&1 | grep -i label | awk '{print $3}') + +if [[ "$local_bazel_ver" == "$BAZEL_VERSION" ]]; then + exit 0 +fi + +set -e + +# Install Bazel™ +mkdir -p /bazel +cd /bazel +if [[ ! -f "bazel-$BAZEL_VERSION-installer-linux-x86_64.sh" ]]; then + curl -fSsL -O https://github.com/bazelbuild/bazel/releases/download/$BAZEL_VERSION/bazel-$BAZEL_VERSION-installer-linux-x86_64.sh +fi +chmod +x /bazel/bazel-*.sh +/bazel/bazel-$BAZEL_VERSION-installer-linux-x86_64.sh +rm -f /bazel/bazel-$BAZEL_VERSION-installer-linux-x86_64.sh diff --git a/examples/tosa_checker.ipynb b/examples/tosa_checker.ipynb new file mode 100644 index 0000000..8fe1f26 --- /dev/null +++ b/examples/tosa_checker.ipynb @@ -0,0 +1,117 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Import the modules needed to create a test model and run the TOSA Checker." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import tosa_checker as tc\n", + "import tensorflow as tf\n", + "import tempfile\n", + "import os" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Create a simple model that is compatible with the TOSA specification." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "INFO:tensorflow:Assets written to: /tmp/tmpxc09cs65/assets\n" + ] + } + ], + "source": [ + "input = tf.keras.layers.Input(shape=(16,))\n", + "x = tf.keras.layers.Dense(8, activation=\"relu\")(input)\n", + "model = tf.keras.models.Model(inputs=[input], outputs=x)\n", + "converter = tf.lite.TFLiteConverter.from_keras_model(model)\n", + "tflite_model = converter.convert()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Save this model in `.tflite` format. Note that the TOSA Checker only accepts models in this format currently." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "_, tflite_file = tempfile.mkstemp('.tflite')\n", + "with open(tflite_file, \"wb\") as f:\n", + " f.write(tflite_model)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Use the TOSA Checker to check this model." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Is model TOSA compatible ? True\n" + ] + } + ], + "source": [ + "checker = tc.TOSAChecker(model_path=tflite_file)\n", + "result = checker.is_tosa_compatible()\n", + "print(\"Is model TOSA compatible ? {}\".format(result))" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3.8.0 ('tosa_checker': venv)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.0" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/examples/tosa_checker.ipynb.license b/examples/tosa_checker.ipynb.license new file mode 100644 index 0000000..bd657a1 --- /dev/null +++ b/examples/tosa_checker.ipynb.license @@ -0,0 +1,3 @@ +SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +SPDX-License-Identifier: Apache-2.0 + diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..8ea8232 --- /dev/null +++ b/setup.py @@ -0,0 +1,127 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +import argparse +import os +import setuptools +import setuptools.command.build_ext +import shutil +import sys + + +TOSA_CHECKER_VERSION = "0.1.0" +TENSORFLOW_VERSION = "2.9.0" + +# Get the TensorFlow™ source directory passed to the command line (if any). +# If none is given the sources are pulled from the official TF repository. +argparser = argparse.ArgumentParser() +argparser.add_argument( + "--tensorflow_src_dir", help="TensorFlow source directory path", required=False +) +args, unknown = argparser.parse_known_args() +sys.argv = [sys.argv[0]] + unknown + + +class BazelExtensionModule(setuptools.Extension): + def __init__( + self, + py_module_name, + library_name, + bazel_target, + bazel_shared_lib_output, + tensorflow_version, + ): + super().__init__(py_module_name, sources=[]) + self.library_name = library_name + self.bazel_target = bazel_target + self.bazel_shared_lib_output = bazel_shared_lib_output + self.tensorflow_version = tensorflow_version + + +class BazelBuildExtension(setuptools.command.build_ext.build_ext): + """Override build_extension to build the library with bazel and copying it + beforehand.""" + + def build_extension(self, ext): + tensorflow_src_dir = args.tensorflow_src_dir + if not tensorflow_src_dir: + tensorflow_src_dir = os.path.join(self.build_temp, "tensorflow") + self._clone_tf_repository( + tensorflow_src_dir, + ext.tensorflow_version, + ) + + self.spawn( + [ + "bazel", + "build", + "-c", + "opt", + # FIXME Some of the Bazel targets dependencies we use have + # a 'friends' visibility, check if our Bazel target can be added + # to the 'friends' list. + "--check_visibility=false", + "--override_repository=org_tensorflow=" + + os.path.abspath(tensorflow_src_dir), + ext.bazel_target, + ] + ) + + shared_lib_dest_path = self.get_ext_fullpath(ext.name) + shared_lib_dest_dir = os.path.dirname(shared_lib_dest_path) + package_dir = os.path.join(shared_lib_dest_dir, ext.library_name) + + os.makedirs(shared_lib_dest_dir, exist_ok=True) + os.makedirs(package_dir, exist_ok=True) + + shutil.copyfile(ext.bazel_shared_lib_output, shared_lib_dest_path) + shutil.copy(os.path.join(ext.library_name, "__init__.py"), package_dir) + + super().build_extension(ext) + + def _clone_tf_repository(self, tensorflow_src_dir, tensorflow_version): + if os.path.exists(tensorflow_src_dir): + return + + tensorflow_repo = "https://github.com/tensorflow/tensorflow.git" + self.spawn( + [ + "git", + "clone", + "--depth=1", + "--branch", + "v" + tensorflow_version, + tensorflow_repo, + tensorflow_src_dir, + ] + ) + + +setuptools.setup( + name="tosa-checker", + version=TOSA_CHECKER_VERSION, + description="Tool to check if a ML model is compatible with the TOSA specification", + long_description="file: README.md", + long_description_content_type="text/markdown", + author="Arm Limited", + url="https://git.mlplatform.org/tosa/tosa_checker.git/", + license="Apache-2.0", + license_files="LICENSES/*", + python_requires=">=3.7", + cmdclass={"build_ext": BazelBuildExtension}, + ext_modules=[ + BazelExtensionModule( + py_module_name="_tosa_checker_wrapper", + library_name="tosa_checker", + bazel_target="//tosa_checker:tosa_checker", + bazel_shared_lib_output="bazel-bin/tosa_checker/_tosa_checker_wrapper.so", + tensorflow_version=TENSORFLOW_VERSION, + ), + ], + classifiers=[ + "Development Status :: 3 - Alpha", + "Intended Audience :: Developers", + "License :: OSI Approved :: Apache Software License", + "Programming Language :: Python :: 3", + "Topic :: Utilities", + ], +) diff --git a/tests/requirements.txt b/tests/requirements.txt new file mode 100644 index 0000000..9984726 --- /dev/null +++ b/tests/requirements.txt @@ -0,0 +1,4 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +tensorflow==2.9.1 +pytest==7.1.2 diff --git a/tests/test_tosa_checker.py b/tests/test_tosa_checker.py new file mode 100644 index 0000000..eb49e65 --- /dev/null +++ b/tests/test_tosa_checker.py @@ -0,0 +1,216 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +import os +import pytest +import tensorflow as tf +import tempfile +import tosa_checker + + +@pytest.fixture(scope="module") +def build_tosa_non_compat_model(): + num_boxes = 6 + max_output_size = 5 + iou_threshold = 0.5 + score_threshold = 0.1 + + def non_max_suppression(x): + boxes = x[0] + scores = x[1] + output = tf.image.non_max_suppression_with_scores( + boxes[0], + scores[0], + max_output_size=max_output_size, + iou_threshold=iou_threshold, + score_threshold=score_threshold, + soft_nms_sigma=1.0, + ) + return output + + boxes_in = tf.keras.layers.Input( + shape=(num_boxes, 4), batch_size=1, dtype=tf.float32, name="boxes" + ) + scores_in = tf.keras.layers.Input( + shape=(num_boxes), batch_size=1, dtype=tf.float32, name="scores" + ) + outputs = tf.keras.layers.Lambda(non_max_suppression)([boxes_in, scores_in]) + model = tf.keras.models.Model(inputs=[boxes_in, scores_in], outputs=outputs) + + return model + + +@pytest.fixture(scope="module") +def build_tosa_compat_model(): + input = tf.keras.layers.Input(shape=(16,)) + x = tf.keras.layers.Dense(8, activation="relu")(input) + model = tf.keras.models.Model(inputs=[input], outputs=x) + return model + + +def create_tflite(model): + converter = tf.lite.TFLiteConverter.from_keras_model(model) + tflite_model = converter.convert() + return tflite_model + + +@pytest.fixture(scope="module") +def non_compat_file(build_tosa_non_compat_model): + tflite_model = create_tflite(build_tosa_non_compat_model) + with tempfile.TemporaryDirectory() as tmp_dir: + file = os.path.join(tmp_dir, "test.tflite") + open(file, "wb").write(tflite_model) + yield file + + +@pytest.fixture(scope="module") +def compat_file(build_tosa_compat_model): + tflite_model = create_tflite(build_tosa_compat_model) + with tempfile.TemporaryDirectory() as tmp_dir: + file = os.path.join(tmp_dir, "test.tflite") + open(file, "wb").write(tflite_model) + yield file + + +class TestTosaCompatibilityTool: + def test_bad_tflite_file(self): + make_bad_tfile = os.path.join(tempfile.mkdtemp(), "test.tflite") + open(make_bad_tfile, "wb").write("bad tflite file".encode("ASCII")) + + with pytest.raises(RuntimeError): + checker = tosa_checker.TOSAChecker(model_path=make_bad_tfile) + + def test_tosa_non_compat_model(self, non_compat_file): + checker = tosa_checker.TOSAChecker(model_path=non_compat_file) + tosa_compatible = checker.is_tosa_compatible() + assert tosa_compatible == False + + ops = checker._get_tosa_compatibility_for_ops() + assert type(ops) == list + assert [[op.name, op.is_tosa_compatible] for op in ops] == [ + ["tfl.pseudo_const", True], + ["tfl.pseudo_const", True], + ["tfl.pseudo_const", True], + ["tfl.strided_slice", True], + ["tfl.pseudo_const", True], + ["tfl.pseudo_const", True], + ["tfl.pseudo_const", True], + ["tfl.strided_slice", True], + ["tfl.pseudo_const", True], + ["tfl.pseudo_const", True], + ["tfl.pseudo_const", True], + ["tfl.pseudo_const", True], + ["tfl.non_max_suppression_v5", False], + ] + + tosa_ops = checker._get_used_tosa_ops() + assert type(tosa_ops) == list + assert [[op.name, op.is_tosa_compatible] for op in tosa_ops] == [ + ["tosa.const", True], + ["tosa.const", True], + ["tosa.const", True], + ["tosa.const", True], + ["tosa.reshape", True], + ["tosa.reshape", True], + ] + + def test_tosa_compat_model(self, compat_file): + checker = tosa_checker.TOSAChecker(model_path=compat_file) + tosa_compatible = checker.is_tosa_compatible() + assert tosa_compatible == True + + ops = checker._get_tosa_compatibility_for_ops() + assert type(ops) == list + assert [[op.name, op.is_tosa_compatible] for op in ops] == [ + ["tfl.pseudo_const", True], + ["tfl.no_value", True], + ["tfl.fully_connected", True], + ] + + tosa_ops = checker._get_used_tosa_ops() + assert type(tosa_ops) == list + assert [[op.name, op.is_tosa_compatible] for op in tosa_ops] == [ + ["tosa.const", True], + ["tosa.const", True], + ["tosa.fully_connected", True], + ["tosa.clamp", True], + ] + + def test_tosa_non_compat_model_mlir_representation(self, non_compat_file): + checker = tosa_checker.TOSAChecker(model_path=non_compat_file) + + tfl_mlir_representation = checker._get_mlir_model_representation( + elide_large_elements_attrs=True + ) + expected_mlir_representation = """\ +module attributes {tf_saved_model.semantics, tfl.description = "MLIR Converted.", tfl.schema_version = 3 : i32} { + func @main(%arg0: tensor<1x6x4xf32> {tf_saved_model.index_path = ["boxes"]}, %arg1: tensor<1x6xf32> {tf_saved_model.index_path = ["scores"]}) -> (tensor {tf_saved_model.index_path = ["lambda_1"]}, tensor {tf_saved_model.index_path = ["lambda"]}) attributes {tf.entry_function = {inputs = "serving_default_boxes:0,serving_default_scores:0", outputs = "PartitionedCall:1,PartitionedCall:0"}, tf_saved_model.exported_names = ["serving_default"]} { + %0 = "tfl.pseudo_const"() {value = dense<0> : tensor<3xi32>} : () -> tensor<3xi32> + %1 = "tfl.pseudo_const"() {value = dense<[1, 6, 4]> : tensor<3xi32>} : () -> tensor<3xi32> + %2 = "tfl.pseudo_const"() {value = dense<1> : tensor<3xi32>} : () -> tensor<3xi32> + %3 = "tfl.strided_slice"(%arg0, %0, %1, %2) {begin_mask = 6 : i32, ellipsis_mask = 0 : i32, end_mask = 6 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 1 : i32} : (tensor<1x6x4xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<6x4xf32> + %4 = "tfl.pseudo_const"() {value = dense<0> : tensor<2xi32>} : () -> tensor<2xi32> + %5 = "tfl.pseudo_const"() {value = dense<[1, 6]> : tensor<2xi32>} : () -> tensor<2xi32> + %6 = "tfl.pseudo_const"() {value = dense<1> : tensor<2xi32>} : () -> tensor<2xi32> + %7 = "tfl.strided_slice"(%arg1, %4, %5, %6) {begin_mask = 2 : i32, ellipsis_mask = 0 : i32, end_mask = 2 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 1 : i32} : (tensor<1x6xf32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<6xf32> + %8 = "tfl.pseudo_const"() {value = dense<5> : tensor} : () -> tensor + %9 = "tfl.pseudo_const"() {value = dense<5.000000e-01> : tensor} : () -> tensor + %10 = "tfl.pseudo_const"() {value = dense<1.000000e-01> : tensor} : () -> tensor + %11 = "tfl.pseudo_const"() {value = dense<1.000000e+00> : tensor} : () -> tensor + %selected_indices, %selected_scores, %valid_outputs = "tfl.non_max_suppression_v5"(%3, %7, %8, %9, %10, %11) : (tensor<6x4xf32>, tensor<6xf32>, tensor, tensor, tensor, tensor) -> (tensor, tensor, tensor<*xi32>) + return %selected_scores, %selected_indices : tensor, tensor + } +} +""" + assert tfl_mlir_representation == expected_mlir_representation + + tosa_mlir_representation = checker._get_mlir_tosa_model_representation( + elide_large_elements_attrs=True + ) + expected_tosa_mlir_representation = """\ +module attributes {tf_saved_model.semantics, tfl.description = "MLIR Converted.", tfl.schema_version = 3 : i32} { + func @main(%arg0: tensor<1x6x4xf32> {tf_saved_model.index_path = ["boxes"]}, %arg1: tensor<1x6xf32> {tf_saved_model.index_path = ["scores"]}) -> (tensor {tf_saved_model.index_path = ["lambda_1"]}, tensor {tf_saved_model.index_path = ["lambda"]}) attributes {tf.entry_function = {inputs = "serving_default_boxes:0,serving_default_scores:0", outputs = "PartitionedCall:1,PartitionedCall:0"}, tf_saved_model.exported_names = ["serving_default"]} { + %0 = "tosa.const"() {value = dense<5> : tensor} : () -> tensor + %1 = "tosa.const"() {value = dense<5.000000e-01> : tensor} : () -> tensor + %2 = "tosa.const"() {value = dense<1.000000e-01> : tensor} : () -> tensor + %3 = "tosa.const"() {value = dense<1.000000e+00> : tensor} : () -> tensor + %4 = "tosa.reshape"(%arg0) {new_shape = [6, 4]} : (tensor<1x6x4xf32>) -> tensor<6x4xf32> + %5 = "tosa.reshape"(%arg1) {new_shape = [6]} : (tensor<1x6xf32>) -> tensor<6xf32> + %selected_indices, %selected_scores, %valid_outputs = "tfl.non_max_suppression_v5"(%4, %5, %0, %1, %2, %3) : (tensor<6x4xf32>, tensor<6xf32>, tensor, tensor, tensor, tensor) -> (tensor, tensor, tensor<*xi32>) + return %selected_scores, %selected_indices : tensor, tensor + } +} +""" + assert tosa_mlir_representation == expected_tosa_mlir_representation + + def test_tosa_compat_model_mlir_representation(self, compat_file): + checker = tosa_checker.TOSAChecker(model_path=compat_file) + tfl_mlir_representation = checker._get_mlir_model_representation( + elide_large_elements_attrs=True + ) + expected_mlir_representation = """\ +module attributes {tf_saved_model.semantics, tfl.description = "MLIR Converted.", tfl.schema_version = 3 : i32} { + func @main(%arg0: tensor {tf_saved_model.index_path = ["input_1"]}) -> (tensor {tf_saved_model.index_path = ["dense"]}) attributes {tf.entry_function = {inputs = "serving_default_input_1:0", outputs = "StatefulPartitionedCall:0"}, tf_saved_model.exported_names = ["serving_default"]} { + %0 = "tfl.pseudo_const"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<8x16xf32>} : () -> tensor<8x16xf32> + %1 = "tfl.no_value"() {value} : () -> none + %2 = "tfl.fully_connected"(%arg0, %0, %1) {asymmetric_quantize_inputs = false, fused_activation_function = "RELU", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor, tensor<8x16xf32>, none) -> tensor + return %2 : tensor + } +} +""" + assert tfl_mlir_representation == expected_mlir_representation + + tosa_mlir_representation = checker._get_mlir_tosa_model_representation( + elide_large_elements_attrs=True + ) + expected_tosa_mlir_representation = """\ +module attributes {tf_saved_model.semantics, tfl.description = "MLIR Converted.", tfl.schema_version = 3 : i32} { + func @main(%arg0: tensor {tf_saved_model.index_path = ["input_1"]}) -> (tensor {tf_saved_model.index_path = ["dense"]}) attributes {tf.entry_function = {inputs = "serving_default_input_1:0", outputs = "StatefulPartitionedCall:0"}, tf_saved_model.exported_names = ["serving_default"]} { + %0 = "tosa.const"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<8x16xf32>} : () -> tensor<8x16xf32> + %1 = "tosa.const"() {value = dense<0.000000e+00> : tensor<8xf32>} : () -> tensor<8xf32> + %2 = "tosa.fully_connected"(%arg0, %0, %1) : (tensor, tensor<8x16xf32>, tensor<8xf32>) -> tensor + %3 = "tosa.clamp"(%2) {max_fp = 3.40282347E+38 : f32, max_int = 2147483647 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} : (tensor) -> tensor + return %3 : tensor + } +} +""" + assert tosa_mlir_representation == expected_tosa_mlir_representation diff --git a/tosa_checker/BUILD b/tosa_checker/BUILD new file mode 100644 index 0000000..8c1c32d --- /dev/null +++ b/tosa_checker/BUILD @@ -0,0 +1,35 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +load("@pybind11_bazel//:build_defs.bzl", "pybind_extension") + +cc_library( + name = "tosa_checker_lib", + srcs = ["tosa_checker.cc"], + hdrs = ["tosa_checker.h"], + deps = [ + "@llvm-project//mlir:MlirTranslateMain", + "@org_tensorflow//tensorflow/compiler/mlir/lite:flatbuffer_translate_lib", + "@org_tensorflow//tensorflow/compiler/mlir/tosa:tfl_passes", + "@pybind11", + ], +) + +pybind_extension( + name = "_tosa_checker_wrapper", + srcs = [ + "tosa_checker_pybind11.cc", + ], + deps = [ + ":tosa_checker_lib", + ], +) + +py_library( + name = "tosa_checker", + srcs = [ + "__init__.py", + ], + srcs_version = "PY3", + visibility = ["//visibility:public"], + data = ["//tosa_checker:_tosa_checker_wrapper.so"], +) diff --git a/tosa_checker/__init__.py b/tosa_checker/__init__.py new file mode 100644 index 0000000..ce76797 --- /dev/null +++ b/tosa_checker/__init__.py @@ -0,0 +1,8 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 + +"""the package provides a way to check if a TFLite model is compatible with the TOSA specification.""" + +from _tosa_checker_wrapper import * + +__version__ = "0.1.0" diff --git a/tosa_checker/tosa_checker.cc b/tosa_checker/tosa_checker.cc new file mode 100644 index 0000000..714cab3 --- /dev/null +++ b/tosa_checker/tosa_checker.cc @@ -0,0 +1,225 @@ +/* +SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +SPDX-License-Identifier: Apache-2.0 +*/ +#include "tosa_checker.h" + +#include "absl/strings/string_view.h" +#include "llvm/Support/MemoryBuffer.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/Types.h" +#include "mlir/IR/Verifier.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Support/FileUtilities.h" +#include "tensorflow/compiler/mlir/lite/flatbuffer_import.h" +#include "tensorflow/compiler/mlir/tosa/tfl_passes.h" + +#include +#include +#include +#include +#include +#include +#include + +namespace std { +template <> +struct hash { + std::size_t operator()(const mlir::Location &loc) const { + return mlir::hash_value(loc); + } +}; +} // namespace std + +namespace tosa_checker { + +TOSAChecker::TOSAChecker(const std::string &model_path) { + m_model = TFLiteFileToMLIR(model_path, &m_context); + m_tosa_model = m_model->clone(); + LegalizeTFLToTOSA(*m_tosa_model); +} + +bool TOSAChecker::IsTOSACompatible() { + bool is_tosa_compatible = true; + for (auto func : m_tosa_model->getOps()) { + func.walk([&](mlir::Operation *op) { + // Ignore func namespace + const mlir::Dialect *dialect = op->getDialect(); + if (!dialect || (!dialect->getNamespace().equals("tosa") && + !dialect->getNamespace().equals("func"))) { + is_tosa_compatible = false; + return mlir::WalkResult::interrupt(); + } + + return mlir::WalkResult::advance(); + }); + } + + return is_tosa_compatible; +} + +std::vector TOSAChecker::GetTOSACompatibilityForOps( + bool elide_large_attrs) { + // Get the locations of all the ops in the legalized model that were not + // converted during the TOSA legalization (i.e. the TOSA incompatible ones). + std::unordered_set tosa_incompatible_locs; + for (auto func : m_tosa_model->getOps()) { + func.walk([&](mlir::Operation *op) { + // Ignore func namespace + const mlir::Dialect *dialect = op->getDialect(); + if (!dialect || (!dialect->getNamespace().equals("tosa") && + !dialect->getNamespace().equals("func"))) { + tosa_incompatible_locs.insert(op->getLoc()); + } + }); + } + + // We assume that on legalization, the non-legalized ops keep their original + // location. If an op location from the original model is in + // tosa_incompatible_locs then the op is not tosa compatible, otherwise it is. + std::vector ops; + for (auto func : m_model->getOps()) { + func.walk([&](mlir::Operation *op) { + // Ignore func namespace + const mlir::Dialect *dialect = op->getDialect(); + if (!dialect || !dialect->getNamespace().equals("func")) { + const bool is_tosa_compatible = + tosa_incompatible_locs.find(op->getLoc()) == + tosa_incompatible_locs.end(); + ops.push_back(ToOperator(*op, is_tosa_compatible, elide_large_attrs)); + } + }); + } + + return ops; +} + +std::vector TOSAChecker::GetUsedTOSAOps( + bool elide_large_attrs) { + std::vector tosa_ops; + for (mlir::Operation *op : GetTOSAOps(*m_tosa_model)) { + const bool is_tosa_compatible = true; + tosa_ops.push_back(ToOperator(*op, is_tosa_compatible, elide_large_attrs)); + } + + return tosa_ops; +} + +std::string TOSAChecker::GetMLIRModelRepresentation(bool elide_large_attrs) { + return GetMLIRRepresentation(*m_model, elide_large_attrs); +} + +std::string TOSAChecker::GetMLIRTOSAModelRepresentation( + bool elide_large_attrs) { + return GetMLIRRepresentation(*m_tosa_model, elide_large_attrs); +} + +template +std::string TOSAChecker::GetMLIRRepresentation(T &&op) { + std::string value; + llvm::raw_string_ostream value_ostream(value); + + op.print(value_ostream); + + return value; +} + +template +std::string TOSAChecker::GetMLIRRepresentation(T &&op, bool elide_large_attrs) { + std::string value; + llvm::raw_string_ostream value_ostream(value); + + mlir::OpPrintingFlags flags; + if (elide_large_attrs) { + flags.elideLargeElementsAttrs(ELIDE_LARGE_ATTRS_LIMIT); + } + op.print(value_ostream, flags); + + return value; +} + +std::vector TOSAChecker::GetTOSAOps(mlir::ModuleOp model) { + std::vector tosa_ops; + for (auto func : model.getOps()) { + func.walk([&](mlir::Operation *op) { + const mlir::Dialect *dialect = op->getDialect(); + if (dialect && dialect->getNamespace().equals("tosa")) { + tosa_ops.push_back(op); + } + }); + } + + return tosa_ops; +} + +TOSAChecker::Operator TOSAChecker::ToOperator(mlir::Operation &op, + bool is_tosa_compatible, + bool elide_large_attrs) { + return Operator(op.getName().getStringRef().str(), + GetMLIRRepresentation(op.getLoc()), + GetAttributes(op, elide_large_attrs), is_tosa_compatible, + GetMLIRRepresentation(op, elide_large_attrs)); +} + +mlir::OwningOpRef TOSAChecker::TFLiteFileToMLIR( + const std::string &model_path, mlir::MLIRContext *context) { + std::string error_message; + std::unique_ptr input = + mlir::openInputFile(model_path, &error_message); + if (!input) { + throw std::runtime_error(error_message); + } + + const mlir::FileLineColLoc location = + mlir::FileLineColLoc::get(context, input->getBufferIdentifier(), 0, 0); + + auto mlir_module = tflite::FlatBufferToMlir( + absl::string_view(input->getBufferStart(), input->getBufferSize()), + context, location); + if (!mlir_module || mlir::failed(mlir::verify(*mlir_module))) { + throw std::runtime_error( + "Could not convert the TFLite model to its MLIR representation."); + } + + return mlir_module; +} + +void TOSAChecker::LegalizeTFLToTOSA(mlir::ModuleOp mlir_module) { + mlir::PassManager pm(mlir_module.getContext(), + mlir::OpPassManager::Nesting::Implicit); + mlir::tosa::TOSATFLLegalizationPipelineOptions opts; + mlir::tosa::createTFLtoTOSALegalizationPipeline(pm, opts); + // TODO Don't check for mlir::failed state for now due to some incoherences in + // how the legalization report non-convertible ops (sometimes with a hard + // fail, sometimes without). The legalization should not return a failed + // state if an operator can't be legalized and should leave it in its original + // dialect. + pm.run(mlir_module); +} + +std::map TOSAChecker::GetAttributes( + mlir::Operation &op, bool /*elide_large_attrs*/) { + std::map attributes; + for (const mlir::NamedAttribute &attr : op.getAttrs()) { + attributes.emplace(attr.getName().str(), + // TODO Check how to elide large attributes when + // converting them to string, mlir::Attribute::print has + // no mlir::OpPrintingFlags. + GetMLIRRepresentation(attr.getValue())); + } + + return attributes; +} + +} // namespace tosa_checker + +std::ostream &operator<<(std::ostream &os, + const tosa_checker::TOSAChecker::Operator &op) { + os << op.mlir_representation; + + return os; +} diff --git a/tosa_checker/tosa_checker.h b/tosa_checker/tosa_checker.h new file mode 100644 index 0000000..d7750ea --- /dev/null +++ b/tosa_checker/tosa_checker.h @@ -0,0 +1,82 @@ +/* +SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +SPDX-License-Identifier: Apache-2.0 +*/ +#ifndef TOSA_CHECKER_H_ +#define TOSA_CHECKER_H_ + +#include +#include +#include +#include + +#include "mlir/include/mlir/IR/BuiltinOps.h" +#include "mlir/include/mlir/IR/MLIRContext.h" +#include "mlir/include/mlir/IR/OwningOpRef.h" + +namespace tosa_checker { + +class TOSAChecker { + public: + struct Operator { + Operator(std::string name, std::string location, + std::map attributes, + bool is_tosa_compatible, std::string mlir_representation) + : name(std::move(name)), + location(std::move(location)), + attributes(std::move(attributes)), + is_tosa_compatible(is_tosa_compatible), + mlir_representation(std::move(mlir_representation)) {} + + std::string name; + std::string location; + std::map attributes; + bool is_tosa_compatible; + std::string mlir_representation; + }; + + TOSAChecker(const std::string& model_path); + + bool IsTOSACompatible(); + + std::vector GetTOSACompatibilityForOps(bool elide_large_attrs); + + std::vector GetUsedTOSAOps(bool elide_large_attrs); + + std::string GetMLIRModelRepresentation(bool elide_large_attrs); + std::string GetMLIRTOSAModelRepresentation(bool elide_large_attrs); + + private: + template + static std::string GetMLIRRepresentation(T&& op); + + template + static std::string GetMLIRRepresentation(T&& op, bool elide_large_attrs); + + static std::vector GetTOSAOps(mlir::ModuleOp model); + + static Operator ToOperator(mlir::Operation& op, bool is_tosa_compatible, + bool elide_large_attrs); + + static mlir::OwningOpRef TFLiteFileToMLIR( + const std::string& model_path, mlir::MLIRContext* context); + + static void LegalizeTFLToTOSA(mlir::ModuleOp mlir_module); + + static std::map GetAttributes( + mlir::Operation& op, bool elide_large_attrs); + + private: + static constexpr std::int64_t ELIDE_LARGE_ATTRS_LIMIT = 16; + + mlir::MLIRContext m_context; + mlir::OwningOpRef m_model; + mlir::OwningOpRef m_tosa_model; +}; + +} // namespace tosa_checker + +std::ostream& operator<<(std::ostream& os, + const tosa_checker::TOSAChecker::Operator& op); + +#endif diff --git a/tosa_checker/tosa_checker_pybind11.cc b/tosa_checker/tosa_checker_pybind11.cc new file mode 100644 index 0000000..c799817 --- /dev/null +++ b/tosa_checker/tosa_checker_pybind11.cc @@ -0,0 +1,79 @@ +/* +SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +SPDX-License-Identifier: Apache-2.0 +*/ +#include "tosa_checker.h" + +#include +#include +#include + +#include +#include + +PYBIND11_MODULE(_tosa_checker_wrapper, m) { + /** + * tosa_checker::TOSAChecker + */ + pybind11::class_ tosa_checker_class(m, + "TOSAChecker"); + tosa_checker_class.def(pybind11::init(), + pybind11::arg("model_path")); + + tosa_checker_class.def( + "is_tosa_compatible", + [](tosa_checker::TOSAChecker& tc) { return tc.IsTOSACompatible(); }, + "Check if a model is compatible with the TOSA specification"); + + tosa_checker_class.def( + "_get_tosa_compatibility_for_ops", + [](tosa_checker::TOSAChecker& tc, bool elide_large_elements_attrs) { + return tc.GetTOSACompatibilityForOps(elide_large_elements_attrs); + }, + pybind11::arg("elide_large_elements_attrs") = false, + "Get all the operators of the models with a TOSA compatibility flag for " + "each operator"); + + tosa_checker_class.def( + "_get_used_tosa_ops", + [](tosa_checker::TOSAChecker& tc, bool elide_large_elements_attrs) { + return tc.GetUsedTOSAOps(elide_large_elements_attrs); + }, + pybind11::arg("elide_large_elements_attrs") = false, + "Get the TOSA operators used by the model after its TOSA legalization"); + + tosa_checker_class.def( + "_get_mlir_model_representation", + [](tosa_checker::TOSAChecker& tc, bool elide_large_elements_attrs) { + return tc.GetMLIRModelRepresentation(elide_large_elements_attrs); + }, + pybind11::arg("elide_large_elements_attrs") = false, + "Get the MLIR representation of the model"); + + tosa_checker_class.def( + "_get_mlir_tosa_model_representation", + [](tosa_checker::TOSAChecker& tc, bool elide_large_elements_attrs) { + return tc.GetMLIRTOSAModelRepresentation(elide_large_elements_attrs); + }, + pybind11::arg("elide_large_elements_attrs") = false, + "Get the MLIR representation of the TOSA legalized model"); + + /** + * tosa_checker::TOSAChecker::Operator + */ + pybind11::class_(tosa_checker_class, + "_Operator") + .def_readonly("name", &tosa_checker::TOSAChecker::Operator::name) + .def_readonly("location", &tosa_checker::TOSAChecker::Operator::location) + .def_readonly("attributes", + &tosa_checker::TOSAChecker::Operator::attributes) + .def_readonly("is_tosa_compatible", + &tosa_checker::TOSAChecker::Operator::is_tosa_compatible) + .def_readonly("mlir_representation", + &tosa_checker::TOSAChecker::Operator::mlir_representation) + .def("__repr__", [](const tosa_checker::TOSAChecker::Operator& o) { + std::stringstream stream; + stream << o; + return stream.str(); + }); +} -- cgit v1.2.1