author     Thibaut Goetghebuer-Planchon <thibaut.goetghebuer-planchon@arm.com>    2022-07-06 10:23:22 +0100
committer  Thibaut Goetghebuer-Planchon <thibaut.goetghebuer-planchon@arm.com>    2022-08-18 14:35:45 +0100
commit     52dacd6556d60815253d4e4938e218ea3d8084a2
tree       4c470c567da6f70f65987d5af161bf4f950d107b
parent     cc5d89eea4ff3dc398cac3b6025450f48ac20c1e
Initial commit (tag: 0.1.0-rc.1)
Change-Id: I2fb0933d595a6ede6417d09dd905ef72d6c60c9b
-rw-r--r--  .bazelrc                               |   8
-rw-r--r--  .bazelversion                          |   2
-rw-r--r--  .bazelversion.license                  |   2
-rw-r--r--  .gitignore                             |  17
-rw-r--r--  LICENSES/Apache-2.0.txt                | 177
-rw-r--r--  MANIFEST.in                            |   4
-rw-r--r--  README.md                              |  54
-rw-r--r--  RELEASES.md                            |  20
-rw-r--r--  WORKSPACE                              |  32
-rw-r--r--  docker/Dockerfile                      |  26
-rw-r--r--  docker/README.md                       |  52
-rwxr-xr-x  docker/install/install_bazel.sh        |  24
-rw-r--r--  examples/tosa_checker.ipynb            | 117
-rw-r--r--  examples/tosa_checker.ipynb.license    |   3
-rw-r--r--  setup.py                               | 127
-rw-r--r--  tests/requirements.txt                 |   4
-rw-r--r--  tests/test_tosa_checker.py             | 216
-rw-r--r--  tosa_checker/BUILD                     |  35
-rw-r--r--  tosa_checker/__init__.py               |   8
-rw-r--r--  tosa_checker/tosa_checker.cc           | 225
-rw-r--r--  tosa_checker/tosa_checker.h            |  82
-rw-r--r--  tosa_checker/tosa_checker_pybind11.cc  |  79
22 files changed, 1314 insertions, 0 deletions
diff --git a/.bazelrc b/.bazelrc
new file mode 100644
index 0000000..97fd61c
--- /dev/null
+++ b/.bazelrc
@@ -0,0 +1,8 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+
+# Flag to enable remote config
+common --experimental_repo_remote_exec
+common --cxxopt=-std=c++17
+common --host_cxxopt=-std=c++17
+common --copt=-w
diff --git a/.bazelversion b/.bazelversion
new file mode 100644
index 0000000..10803dc
--- /dev/null
+++ b/.bazelversion
@@ -0,0 +1,2 @@
+5.1.1
+# The version must be the same as the one in the TensorFlow version we build against
diff --git a/.bazelversion.license b/.bazelversion.license
new file mode 100644
index 0000000..487e9d8
--- /dev/null
+++ b/.bazelversion.license
@@ -0,0 +1,2 @@
+SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+SPDX-License-Identifier: Apache-2.0
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..343775f
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,17 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+*.egg-info
+*.pyc
+*~
+.eggs/
+\.coverage
+\.eggs
+__pycache__
+build/
+dist/
+*.egg-info
+bazel-bin/*
+bazel-out/*
+bazel-tosa_checker/*
+bazel-testlogs/*
+.pytest_cache/*
diff --git a/LICENSES/Apache-2.0.txt b/LICENSES/Apache-2.0.txt
new file mode 100644
index 0000000..f433b1a
--- /dev/null
+++ b/LICENSES/Apache-2.0.txt
@@ -0,0 +1,177 @@
+
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
diff --git a/MANIFEST.in b/MANIFEST.in
new file mode 100644
index 0000000..74d9656
--- /dev/null
+++ b/MANIFEST.in
@@ -0,0 +1,4 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+
+include README.md
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..361ac8b
--- /dev/null
+++ b/README.md
@@ -0,0 +1,54 @@
+<!---
+SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+SPDX-License-Identifier: Apache-2.0
+--->
+# TOSA Checker
+
+## Overview
+
+The TOSA Checker is a tool that provides an easy way to ensure that a TensorFlow™ Lite model is compatible with the [TOSA specification](https://www.mlplatform.org/tosa).
+
+The project is currently in alpha; some stability issues may still be found.
+
+## Prerequisites
+
+To build the TOSA Checker the following are required:
+
+* Ubuntu® 20.04.03 LTS (the TOSA Checker has been tested on this OS, but should work on other Linux® distributions)
+* Bazel™ (Bazelisk is an easy way to install the correct version of Bazel™)
+* Git™
+* Python® >=3.7 (python, python-dev and python-pip)
+
+## Building and installation
+
+pip install -U pip setuptools wheel
+pip install -U numpy
+python setup.py bdist_wheel
+pip install dist/tosa_checker-0.1.0-cp38-cp38-linux_x86_64.whl
+
+## Docker™ image
+
+A Docker™ image to build the TOSA Checker is also provided. More information can be found on [docker/README.md](docker/README.md).
+
+## Usage
+
+```python
+>>> import tosa_checker as tc
+>>> checker = tc.TOSAChecker(model_path="model.tflite")
+>>> checker.is_tosa_compatible()
+True
+```
+
+## License
+
+The TOSA Checker is licensed under the [Apache License 2.0](LICENSES/Apache-2.0.txt).
+
+## Trademarks and Copyrights
+
+Bazel™ is a trademark of Google® LLC.
+Git™ is a trademark of Software Freedom Conservancy.
+Linux® is the registered trademark of Linus Torvalds in the U.S. and elsewhere.
+Python® is a registered trademark of the PSF.
+Ubuntu® is a registered trademark of Canonical.
+TensorFlow™ is a trademark of Google® LLC.
+Docker™ is a trademark of Docker, Inc.
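
Beyond `is_tosa_compatible()`, the bindings added later in this commit (see `tosa_checker_pybind11.cc` and the tests) also expose a per-operator report. A minimal sketch of how that can be used, assuming a `model.tflite` file exists on disk:

```python
import tosa_checker as tc

# Run the compatibility check once for the whole model.
checker = tc.TOSAChecker(model_path="model.tflite")
print("TOSA compatible:", checker.is_tosa_compatible())

# List every operator of the model together with its compatibility flag.
for op in checker._get_tosa_compatibility_for_ops():
    status = "compatible" if op.is_tosa_compatible else "NOT compatible"
    print(f"{op.name}: {status}")
```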
diff --git a/RELEASES.md b/RELEASES.md
new file mode 100644
index 0000000..ec666ee
--- /dev/null
+++ b/RELEASES.md
@@ -0,0 +1,20 @@
+<!---
+SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+SPDX-License-Identifier: Apache-2.0
+--->
+# Release 0.1.0
+
+First release of the TOSA Checker tool. The goal of the tool is to provide an easy way to check if a TensorFlow™ Lite model is compatible with the [TOSA specification](https://www.mlplatform.org/tosa).
+
+The tool is provided as a Python® package and can be used as follows:
+
+```python
+>>> import tosa_checker as tc
+>>> checker = tc.TOSAChecker("model.tflite")
+>>> checker.is_tosa_compatible()
+True
+```
+
+The tool is currently in alpha; the feature set is limited and some stability issues may exist.
+
+Future versions may extend the functionality provided, and support for other frameworks will be added.
diff --git a/WORKSPACE b/WORKSPACE
new file mode 100644
index 0000000..b8a0fba
--- /dev/null
+++ b/WORKSPACE
@@ -0,0 +1,32 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+workspace(name = "tosa_checker")
+
+load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
+http_archive(
+ name = "pybind11",
+ build_file = "@pybind11_bazel//:pybind11.BUILD",
+ strip_prefix = "pybind11-2.9.2",
+ urls = ["https://github.com/pybind/pybind11/archive/v2.9.2.tar.gz"],
+)
+http_archive(
+ name = "pybind11_bazel",
+ strip_prefix = "pybind11_bazel-72cbbf1fbc830e487e3012862b7b720001b70672",
+ urls = ["https://github.com/pybind/pybind11_bazel/archive/72cbbf1fbc830e487e3012862b7b720001b70672.zip"],
+)
+
+load("@pybind11_bazel//:python_configure.bzl", "python_configure")
+python_configure(name = "local_config_python")
+
+load("@org_tensorflow//tensorflow:workspace3.bzl", "tf_workspace3")
+tf_workspace3()
+
+load("@org_tensorflow//tensorflow:workspace2.bzl", "tf_workspace2")
+tf_workspace2()
+
+load("@org_tensorflow//tensorflow:workspace1.bzl", "tf_workspace1")
+tf_workspace1()
+
+load("@org_tensorflow//tensorflow:workspace0.bzl", "tf_workspace0")
+tf_workspace0()
+
diff --git a/docker/Dockerfile b/docker/Dockerfile
new file mode 100644
index 0000000..4481565
--- /dev/null
+++ b/docker/Dockerfile
@@ -0,0 +1,26 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+FROM quay.io/pypa/manylinux2014_x86_64
+
+ARG PYTHON_VERSION=3.9
+ARG TENSORFLOW_VERSION=2.9.0
+ARG BAZEL_VERSION=5.1.1
+
+RUN ln -s /opt/_internal/cpython-$PYTHON_VERSION*/bin/python3 /usr/local/bin/python3 && \
+ ln -s /opt/_internal/cpython-$PYTHON_VERSION*/bin/python /usr/local/bin/python && \
+ ln -s /opt/_internal/cpython-$PYTHON_VERSION*/bin/pip3 /usr/local/bin/pip3 && \
+ ln -s /opt/_internal/cpython-$PYTHON_VERSION*/bin/pip /usr/local/bin/pip && \
+ ln -s /opt/_internal/cpython-$PYTHON_VERSION*/include/python${PYTHON_VERSION}/ /usr/local/include/python
+
+ENV PYTHON_BIN_PATH=/usr/local/bin/python
+ENV CI_BUILD_PYTHON=/usr/local/bin/python
+ENV CROSSTOOL_PYTHON_INCLUDE_PATH=/usr/local/include/python/
+
+RUN pip install --no-cache-dir setuptools pybind11 numpy twine keyrings.alt
+
+COPY install/install_bazel.sh /install/
+RUN /install/install_bazel.sh ${BAZEL_VERSION}
+
+# CACHE_STOP is used to rerun the commands below; otherwise the clone would be cached and would not pull the most recent version
+ARG CACHE_STOP=1
+RUN git clone --depth=1 https://github.com/tensorflow/tensorflow.git --branch v${TENSORFLOW_VERSION} /tensorflow_src
diff --git a/docker/README.md b/docker/README.md
new file mode 100644
index 0000000..653b9a2
--- /dev/null
+++ b/docker/README.md
@@ -0,0 +1,52 @@
+<!---
+SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+SPDX-License-Identifier: Apache-2.0
+--->
+# Docker™ image
+
+This directory contains utilities to build and test the TOSA Checker.
+
+## How to build the TOSA Checker manylinux wheel with a Docker™ image
+
+To create a Docker™ image that can build the TOSA Checker for Python® 3.9 on various Linux® distributions, run the following command:
+
+```console
+docker build . -t tc-cp39-manylinux --build-arg PYTHON_VERSION=3.9 -f Dockerfile
+```
+
+The TensorFlow™ source code is automatically downloaded and is located in the `/tensorflow_src` directory.
+
+The command to run the container is:
+
+```console
+docker run -it -v <tosa_checker source code on your machine>:/tosa_checker tc-cp39-manylinux
+```
+
+Now run the following commands inside the container to build a `tosa_checker` Python® wheel:
+
+```console
+cd tosa_checker
+python3 setup.py --tensorflow_src_dir /tensorflow_src bdist_wheel
+```
+The `tosa_checker` wheel can be found in the `/dist` directory.
+
+Generate the new manylinux wheel from the `tosa_checker` wheel:
+```console
+auditwheel repair dist/<tosa_checker>.whl -w dist/
+```
+The `tosa_checker` manylinux wheel can now be found in the `/dist` directory.
+
+Install the `tosa_checker` manylinux wheel:
+```console
+pip install dist/<tosa_checker-manylinux>.whl
+```
+
+## Trademarks and Copyrights
+
+Python® is a registered trademark of the PSF.
+Linux® is the registered trademark of Linus Torvalds in the U.S. and other countries.
+Ubuntu® is a registered trademark of Canonical.
+TensorFlow™ is a trademark of Google® LLC.
+Docker™ is a trademark of Docker, Inc.
+
+
diff --git a/docker/install/install_bazel.sh b/docker/install/install_bazel.sh
new file mode 100755
index 0000000..687ed13
--- /dev/null
+++ b/docker/install/install_bazel.sh
@@ -0,0 +1,24 @@
+#!/usr/bin/env bash
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+BAZEL_VERSION="$1"
+shift
+
+set +e
+local_bazel_ver=$(bazel version 2>&1 | grep -i label | awk '{print $3}')
+
+if [[ "$local_bazel_ver" == "$BAZEL_VERSION" ]]; then
+ exit 0
+fi
+
+set -e
+
+# Install Bazel™
+mkdir -p /bazel
+cd /bazel
+if [[ ! -f "bazel-$BAZEL_VERSION-installer-linux-x86_64.sh" ]]; then
+ curl -fSsL -O https://github.com/bazelbuild/bazel/releases/download/$BAZEL_VERSION/bazel-$BAZEL_VERSION-installer-linux-x86_64.sh
+fi
+chmod +x /bazel/bazel-*.sh
+/bazel/bazel-$BAZEL_VERSION-installer-linux-x86_64.sh
+rm -f /bazel/bazel-$BAZEL_VERSION-installer-linux-x86_64.sh
diff --git a/examples/tosa_checker.ipynb b/examples/tosa_checker.ipynb
new file mode 100644
index 0000000..8fe1f26
--- /dev/null
+++ b/examples/tosa_checker.ipynb
@@ -0,0 +1,117 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Import the modules needed to create a test model and run the TOSA Checker."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import tosa_checker as tc\n",
+ "import tensorflow as tf\n",
+ "import tempfile\n",
+ "import os"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Create a simple model that is compatible with the TOSA specification."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "INFO:tensorflow:Assets written to: /tmp/tmpxc09cs65/assets\n"
+ ]
+ }
+ ],
+ "source": [
+ "input = tf.keras.layers.Input(shape=(16,))\n",
+ "x = tf.keras.layers.Dense(8, activation=\"relu\")(input)\n",
+ "model = tf.keras.models.Model(inputs=[input], outputs=x)\n",
+ "converter = tf.lite.TFLiteConverter.from_keras_model(model)\n",
+ "tflite_model = converter.convert()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Save this model in `.tflite` format. Note that the TOSA Checker only accepts models in this format currently."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "_, tflite_file = tempfile.mkstemp('.tflite')\n",
+ "with open(tflite_file, \"wb\") as f:\n",
+ " f.write(tflite_model)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Use the TOSA Checker to check this model."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Is model TOSA compatible ? True\n"
+ ]
+ }
+ ],
+ "source": [
+ "checker = tc.TOSAChecker(model_path=tflite_file)\n",
+ "result = checker.is_tosa_compatible()\n",
+ "print(\"Is model TOSA compatible ? {}\".format(result))"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3.8.0 ('tosa_checker': venv)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.8.0"
+ },
+ "orig_nbformat": 4
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
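
The notebook above walks through the flow cell by cell; condensed into a standalone script (assuming TensorFlow™ and the `tosa_checker` wheel are installed), the same steps look roughly like this:

```python
import tempfile

import tensorflow as tf
import tosa_checker as tc

# Build a small Keras model that is expected to be TOSA compatible.
inp = tf.keras.layers.Input(shape=(16,))
out = tf.keras.layers.Dense(8, activation="relu")(inp)
model = tf.keras.models.Model(inputs=[inp], outputs=out)

# The checker only accepts .tflite files, so convert and write to disk first.
tflite_model = tf.lite.TFLiteConverter.from_keras_model(model).convert()
_, tflite_file = tempfile.mkstemp(".tflite")
with open(tflite_file, "wb") as f:
    f.write(tflite_model)

# Run the TOSA compatibility check on the saved model.
checker = tc.TOSAChecker(model_path=tflite_file)
print("Is model TOSA compatible? {}".format(checker.is_tosa_compatible()))
```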
diff --git a/examples/tosa_checker.ipynb.license b/examples/tosa_checker.ipynb.license
new file mode 100644
index 0000000..bd657a1
--- /dev/null
+++ b/examples/tosa_checker.ipynb.license
@@ -0,0 +1,3 @@
+SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+SPDX-License-Identifier: Apache-2.0
+
diff --git a/setup.py b/setup.py
new file mode 100644
index 0000000..8ea8232
--- /dev/null
+++ b/setup.py
@@ -0,0 +1,127 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+import argparse
+import os
+import setuptools
+import setuptools.command.build_ext
+import shutil
+import sys
+
+
+TOSA_CHECKER_VERSION = "0.1.0"
+TENSORFLOW_VERSION = "2.9.0"
+
+# Get the TensorFlow™ source directory passed to the command line (if any).
+# If none is given the sources are pulled from the official TF repository.
+argparser = argparse.ArgumentParser()
+argparser.add_argument(
+ "--tensorflow_src_dir", help="TensorFlow source directory path", required=False
+)
+args, unknown = argparser.parse_known_args()
+sys.argv = [sys.argv[0]] + unknown
+
+
+class BazelExtensionModule(setuptools.Extension):
+ def __init__(
+ self,
+ py_module_name,
+ library_name,
+ bazel_target,
+ bazel_shared_lib_output,
+ tensorflow_version,
+ ):
+ super().__init__(py_module_name, sources=[])
+ self.library_name = library_name
+ self.bazel_target = bazel_target
+ self.bazel_shared_lib_output = bazel_shared_lib_output
+ self.tensorflow_version = tensorflow_version
+
+
+class BazelBuildExtension(setuptools.command.build_ext.build_ext):
+ """Override build_extension to build the library with bazel and copying it
+ beforehand."""
+
+ def build_extension(self, ext):
+ tensorflow_src_dir = args.tensorflow_src_dir
+ if not tensorflow_src_dir:
+ tensorflow_src_dir = os.path.join(self.build_temp, "tensorflow")
+ self._clone_tf_repository(
+ tensorflow_src_dir,
+ ext.tensorflow_version,
+ )
+
+ self.spawn(
+ [
+ "bazel",
+ "build",
+ "-c",
+ "opt",
+                # FIXME Some of the Bazel target dependencies we use have
+                # 'friends' visibility; check if our Bazel target can be added
+                # to the 'friends' list.
+ "--check_visibility=false",
+ "--override_repository=org_tensorflow="
+ + os.path.abspath(tensorflow_src_dir),
+ ext.bazel_target,
+ ]
+ )
+
+ shared_lib_dest_path = self.get_ext_fullpath(ext.name)
+ shared_lib_dest_dir = os.path.dirname(shared_lib_dest_path)
+ package_dir = os.path.join(shared_lib_dest_dir, ext.library_name)
+
+ os.makedirs(shared_lib_dest_dir, exist_ok=True)
+ os.makedirs(package_dir, exist_ok=True)
+
+ shutil.copyfile(ext.bazel_shared_lib_output, shared_lib_dest_path)
+ shutil.copy(os.path.join(ext.library_name, "__init__.py"), package_dir)
+
+ super().build_extension(ext)
+
+ def _clone_tf_repository(self, tensorflow_src_dir, tensorflow_version):
+ if os.path.exists(tensorflow_src_dir):
+ return
+
+ tensorflow_repo = "https://github.com/tensorflow/tensorflow.git"
+ self.spawn(
+ [
+ "git",
+ "clone",
+ "--depth=1",
+ "--branch",
+ "v" + tensorflow_version,
+ tensorflow_repo,
+ tensorflow_src_dir,
+ ]
+ )
+
+
+setuptools.setup(
+ name="tosa-checker",
+ version=TOSA_CHECKER_VERSION,
+ description="Tool to check if a ML model is compatible with the TOSA specification",
+    long_description=open("README.md", encoding="utf-8").read(),
+ long_description_content_type="text/markdown",
+ author="Arm Limited",
+ url="https://git.mlplatform.org/tosa/tosa_checker.git/",
+ license="Apache-2.0",
+ license_files="LICENSES/*",
+ python_requires=">=3.7",
+ cmdclass={"build_ext": BazelBuildExtension},
+ ext_modules=[
+ BazelExtensionModule(
+ py_module_name="_tosa_checker_wrapper",
+ library_name="tosa_checker",
+ bazel_target="//tosa_checker:tosa_checker",
+ bazel_shared_lib_output="bazel-bin/tosa_checker/_tosa_checker_wrapper.so",
+ tensorflow_version=TENSORFLOW_VERSION,
+ ),
+ ],
+ classifiers=[
+ "Development Status :: 3 - Alpha",
+ "Intended Audience :: Developers",
+ "License :: OSI Approved :: Apache Software License",
+ "Programming Language :: Python :: 3",
+ "Topic :: Utilities",
+ ],
+)
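
`setup.py` above accepts a custom `--tensorflow_src_dir` flag alongside the usual setuptools arguments by parsing the known flags first and handing the remainder back to setuptools via `sys.argv`. A self-contained sketch of that pattern (the script below is illustrative, not part of the package):

```python
import argparse
import sys

# Parse only the flags this script understands; leave everything else in
# sys.argv so setuptools (or any later consumer) can handle its own options.
parser = argparse.ArgumentParser()
parser.add_argument("--tensorflow_src_dir", required=False)
args, unknown = parser.parse_known_args()
sys.argv = [sys.argv[0]] + unknown

print("custom flag:", args.tensorflow_src_dir)
print("argv left for setuptools:", sys.argv[1:])
```

Invoking it as `python example.py --tensorflow_src_dir /tensorflow_src bdist_wheel` would report the custom flag and leave `bdist_wheel` in `sys.argv` for setuptools to process.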
diff --git a/tests/requirements.txt b/tests/requirements.txt
new file mode 100644
index 0000000..9984726
--- /dev/null
+++ b/tests/requirements.txt
@@ -0,0 +1,4 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+tensorflow==2.9.1
+pytest==7.1.2
diff --git a/tests/test_tosa_checker.py b/tests/test_tosa_checker.py
new file mode 100644
index 0000000..eb49e65
--- /dev/null
+++ b/tests/test_tosa_checker.py
@@ -0,0 +1,216 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+import os
+import pytest
+import tensorflow as tf
+import tempfile
+import tosa_checker
+
+
+@pytest.fixture(scope="module")
+def build_tosa_non_compat_model():
+ num_boxes = 6
+ max_output_size = 5
+ iou_threshold = 0.5
+ score_threshold = 0.1
+
+ def non_max_suppression(x):
+ boxes = x[0]
+ scores = x[1]
+ output = tf.image.non_max_suppression_with_scores(
+ boxes[0],
+ scores[0],
+ max_output_size=max_output_size,
+ iou_threshold=iou_threshold,
+ score_threshold=score_threshold,
+ soft_nms_sigma=1.0,
+ )
+ return output
+
+ boxes_in = tf.keras.layers.Input(
+ shape=(num_boxes, 4), batch_size=1, dtype=tf.float32, name="boxes"
+ )
+ scores_in = tf.keras.layers.Input(
+ shape=(num_boxes), batch_size=1, dtype=tf.float32, name="scores"
+ )
+ outputs = tf.keras.layers.Lambda(non_max_suppression)([boxes_in, scores_in])
+ model = tf.keras.models.Model(inputs=[boxes_in, scores_in], outputs=outputs)
+
+ return model
+
+
+@pytest.fixture(scope="module")
+def build_tosa_compat_model():
+ input = tf.keras.layers.Input(shape=(16,))
+ x = tf.keras.layers.Dense(8, activation="relu")(input)
+ model = tf.keras.models.Model(inputs=[input], outputs=x)
+ return model
+
+
+def create_tflite(model):
+ converter = tf.lite.TFLiteConverter.from_keras_model(model)
+ tflite_model = converter.convert()
+ return tflite_model
+
+
+@pytest.fixture(scope="module")
+def non_compat_file(build_tosa_non_compat_model):
+ tflite_model = create_tflite(build_tosa_non_compat_model)
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ file = os.path.join(tmp_dir, "test.tflite")
+ open(file, "wb").write(tflite_model)
+ yield file
+
+
+@pytest.fixture(scope="module")
+def compat_file(build_tosa_compat_model):
+ tflite_model = create_tflite(build_tosa_compat_model)
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ file = os.path.join(tmp_dir, "test.tflite")
+ open(file, "wb").write(tflite_model)
+ yield file
+
+
+class TestTosaCompatibilityTool:
+ def test_bad_tflite_file(self):
+ make_bad_tfile = os.path.join(tempfile.mkdtemp(), "test.tflite")
+ open(make_bad_tfile, "wb").write("bad tflite file".encode("ASCII"))
+
+ with pytest.raises(RuntimeError):
+ checker = tosa_checker.TOSAChecker(model_path=make_bad_tfile)
+
+ def test_tosa_non_compat_model(self, non_compat_file):
+ checker = tosa_checker.TOSAChecker(model_path=non_compat_file)
+ tosa_compatible = checker.is_tosa_compatible()
+ assert tosa_compatible == False
+
+ ops = checker._get_tosa_compatibility_for_ops()
+ assert type(ops) == list
+ assert [[op.name, op.is_tosa_compatible] for op in ops] == [
+ ["tfl.pseudo_const", True],
+ ["tfl.pseudo_const", True],
+ ["tfl.pseudo_const", True],
+ ["tfl.strided_slice", True],
+ ["tfl.pseudo_const", True],
+ ["tfl.pseudo_const", True],
+ ["tfl.pseudo_const", True],
+ ["tfl.strided_slice", True],
+ ["tfl.pseudo_const", True],
+ ["tfl.pseudo_const", True],
+ ["tfl.pseudo_const", True],
+ ["tfl.pseudo_const", True],
+ ["tfl.non_max_suppression_v5", False],
+ ]
+
+ tosa_ops = checker._get_used_tosa_ops()
+ assert type(tosa_ops) == list
+ assert [[op.name, op.is_tosa_compatible] for op in tosa_ops] == [
+ ["tosa.const", True],
+ ["tosa.const", True],
+ ["tosa.const", True],
+ ["tosa.const", True],
+ ["tosa.reshape", True],
+ ["tosa.reshape", True],
+ ]
+
+ def test_tosa_compat_model(self, compat_file):
+ checker = tosa_checker.TOSAChecker(model_path=compat_file)
+ tosa_compatible = checker.is_tosa_compatible()
+ assert tosa_compatible == True
+
+ ops = checker._get_tosa_compatibility_for_ops()
+ assert type(ops) == list
+ assert [[op.name, op.is_tosa_compatible] for op in ops] == [
+ ["tfl.pseudo_const", True],
+ ["tfl.no_value", True],
+ ["tfl.fully_connected", True],
+ ]
+
+ tosa_ops = checker._get_used_tosa_ops()
+ assert type(tosa_ops) == list
+ assert [[op.name, op.is_tosa_compatible] for op in tosa_ops] == [
+ ["tosa.const", True],
+ ["tosa.const", True],
+ ["tosa.fully_connected", True],
+ ["tosa.clamp", True],
+ ]
+
+ def test_tosa_non_compat_model_mlir_representation(self, non_compat_file):
+ checker = tosa_checker.TOSAChecker(model_path=non_compat_file)
+
+ tfl_mlir_representation = checker._get_mlir_model_representation(
+ elide_large_elements_attrs=True
+ )
+ expected_mlir_representation = """\
+module attributes {tf_saved_model.semantics, tfl.description = "MLIR Converted.", tfl.schema_version = 3 : i32} {
+ func @main(%arg0: tensor<1x6x4xf32> {tf_saved_model.index_path = ["boxes"]}, %arg1: tensor<1x6xf32> {tf_saved_model.index_path = ["scores"]}) -> (tensor<?xf32> {tf_saved_model.index_path = ["lambda_1"]}, tensor<?xi32> {tf_saved_model.index_path = ["lambda"]}) attributes {tf.entry_function = {inputs = "serving_default_boxes:0,serving_default_scores:0", outputs = "PartitionedCall:1,PartitionedCall:0"}, tf_saved_model.exported_names = ["serving_default"]} {
+ %0 = "tfl.pseudo_const"() {value = dense<0> : tensor<3xi32>} : () -> tensor<3xi32>
+ %1 = "tfl.pseudo_const"() {value = dense<[1, 6, 4]> : tensor<3xi32>} : () -> tensor<3xi32>
+ %2 = "tfl.pseudo_const"() {value = dense<1> : tensor<3xi32>} : () -> tensor<3xi32>
+ %3 = "tfl.strided_slice"(%arg0, %0, %1, %2) {begin_mask = 6 : i32, ellipsis_mask = 0 : i32, end_mask = 6 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 1 : i32} : (tensor<1x6x4xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<6x4xf32>
+ %4 = "tfl.pseudo_const"() {value = dense<0> : tensor<2xi32>} : () -> tensor<2xi32>
+ %5 = "tfl.pseudo_const"() {value = dense<[1, 6]> : tensor<2xi32>} : () -> tensor<2xi32>
+ %6 = "tfl.pseudo_const"() {value = dense<1> : tensor<2xi32>} : () -> tensor<2xi32>
+ %7 = "tfl.strided_slice"(%arg1, %4, %5, %6) {begin_mask = 2 : i32, ellipsis_mask = 0 : i32, end_mask = 2 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 1 : i32} : (tensor<1x6xf32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<6xf32>
+ %8 = "tfl.pseudo_const"() {value = dense<5> : tensor<i32>} : () -> tensor<i32>
+ %9 = "tfl.pseudo_const"() {value = dense<5.000000e-01> : tensor<f32>} : () -> tensor<f32>
+ %10 = "tfl.pseudo_const"() {value = dense<1.000000e-01> : tensor<f32>} : () -> tensor<f32>
+ %11 = "tfl.pseudo_const"() {value = dense<1.000000e+00> : tensor<f32>} : () -> tensor<f32>
+ %selected_indices, %selected_scores, %valid_outputs = "tfl.non_max_suppression_v5"(%3, %7, %8, %9, %10, %11) : (tensor<6x4xf32>, tensor<6xf32>, tensor<i32>, tensor<f32>, tensor<f32>, tensor<f32>) -> (tensor<?xi32>, tensor<?xf32>, tensor<*xi32>)
+ return %selected_scores, %selected_indices : tensor<?xf32>, tensor<?xi32>
+ }
+}
+"""
+ assert tfl_mlir_representation == expected_mlir_representation
+
+ tosa_mlir_representation = checker._get_mlir_tosa_model_representation(
+ elide_large_elements_attrs=True
+ )
+ expected_tosa_mlir_representation = """\
+module attributes {tf_saved_model.semantics, tfl.description = "MLIR Converted.", tfl.schema_version = 3 : i32} {
+ func @main(%arg0: tensor<1x6x4xf32> {tf_saved_model.index_path = ["boxes"]}, %arg1: tensor<1x6xf32> {tf_saved_model.index_path = ["scores"]}) -> (tensor<?xf32> {tf_saved_model.index_path = ["lambda_1"]}, tensor<?xi32> {tf_saved_model.index_path = ["lambda"]}) attributes {tf.entry_function = {inputs = "serving_default_boxes:0,serving_default_scores:0", outputs = "PartitionedCall:1,PartitionedCall:0"}, tf_saved_model.exported_names = ["serving_default"]} {
+ %0 = "tosa.const"() {value = dense<5> : tensor<i32>} : () -> tensor<i32>
+ %1 = "tosa.const"() {value = dense<5.000000e-01> : tensor<f32>} : () -> tensor<f32>
+ %2 = "tosa.const"() {value = dense<1.000000e-01> : tensor<f32>} : () -> tensor<f32>
+ %3 = "tosa.const"() {value = dense<1.000000e+00> : tensor<f32>} : () -> tensor<f32>
+ %4 = "tosa.reshape"(%arg0) {new_shape = [6, 4]} : (tensor<1x6x4xf32>) -> tensor<6x4xf32>
+ %5 = "tosa.reshape"(%arg1) {new_shape = [6]} : (tensor<1x6xf32>) -> tensor<6xf32>
+ %selected_indices, %selected_scores, %valid_outputs = "tfl.non_max_suppression_v5"(%4, %5, %0, %1, %2, %3) : (tensor<6x4xf32>, tensor<6xf32>, tensor<i32>, tensor<f32>, tensor<f32>, tensor<f32>) -> (tensor<?xi32>, tensor<?xf32>, tensor<*xi32>)
+ return %selected_scores, %selected_indices : tensor<?xf32>, tensor<?xi32>
+ }
+}
+"""
+ assert tosa_mlir_representation == expected_tosa_mlir_representation
+
+ def test_tosa_compat_model_mlir_representation(self, compat_file):
+ checker = tosa_checker.TOSAChecker(model_path=compat_file)
+ tfl_mlir_representation = checker._get_mlir_model_representation(
+ elide_large_elements_attrs=True
+ )
+ expected_mlir_representation = """\
+module attributes {tf_saved_model.semantics, tfl.description = "MLIR Converted.", tfl.schema_version = 3 : i32} {
+ func @main(%arg0: tensor<?x16xf32> {tf_saved_model.index_path = ["input_1"]}) -> (tensor<?x8xf32> {tf_saved_model.index_path = ["dense"]}) attributes {tf.entry_function = {inputs = "serving_default_input_1:0", outputs = "StatefulPartitionedCall:0"}, tf_saved_model.exported_names = ["serving_default"]} {
+ %0 = "tfl.pseudo_const"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<8x16xf32>} : () -> tensor<8x16xf32>
+ %1 = "tfl.no_value"() {value} : () -> none
+ %2 = "tfl.fully_connected"(%arg0, %0, %1) {asymmetric_quantize_inputs = false, fused_activation_function = "RELU", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<?x16xf32>, tensor<8x16xf32>, none) -> tensor<?x8xf32>
+ return %2 : tensor<?x8xf32>
+ }
+}
+"""
+ assert tfl_mlir_representation == expected_mlir_representation
+
+ tosa_mlir_representation = checker._get_mlir_tosa_model_representation(
+ elide_large_elements_attrs=True
+ )
+ expected_tosa_mlir_representation = """\
+module attributes {tf_saved_model.semantics, tfl.description = "MLIR Converted.", tfl.schema_version = 3 : i32} {
+ func @main(%arg0: tensor<?x16xf32> {tf_saved_model.index_path = ["input_1"]}) -> (tensor<?x8xf32> {tf_saved_model.index_path = ["dense"]}) attributes {tf.entry_function = {inputs = "serving_default_input_1:0", outputs = "StatefulPartitionedCall:0"}, tf_saved_model.exported_names = ["serving_default"]} {
+ %0 = "tosa.const"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<8x16xf32>} : () -> tensor<8x16xf32>
+ %1 = "tosa.const"() {value = dense<0.000000e+00> : tensor<8xf32>} : () -> tensor<8xf32>
+ %2 = "tosa.fully_connected"(%arg0, %0, %1) : (tensor<?x16xf32>, tensor<8x16xf32>, tensor<8xf32>) -> tensor<?x8xf32>
+ %3 = "tosa.clamp"(%2) {max_fp = 3.40282347E+38 : f32, max_int = 2147483647 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} : (tensor<?x8xf32>) -> tensor<?x8xf32>
+ return %3 : tensor<?x8xf32>
+ }
+}
+"""
+ assert tosa_mlir_representation == expected_tosa_mlir_representation
diff --git a/tosa_checker/BUILD b/tosa_checker/BUILD
new file mode 100644
index 0000000..8c1c32d
--- /dev/null
+++ b/tosa_checker/BUILD
@@ -0,0 +1,35 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+load("@pybind11_bazel//:build_defs.bzl", "pybind_extension")
+
+cc_library(
+ name = "tosa_checker_lib",
+ srcs = ["tosa_checker.cc"],
+ hdrs = ["tosa_checker.h"],
+ deps = [
+ "@llvm-project//mlir:MlirTranslateMain",
+ "@org_tensorflow//tensorflow/compiler/mlir/lite:flatbuffer_translate_lib",
+ "@org_tensorflow//tensorflow/compiler/mlir/tosa:tfl_passes",
+ "@pybind11",
+ ],
+)
+
+pybind_extension(
+ name = "_tosa_checker_wrapper",
+ srcs = [
+ "tosa_checker_pybind11.cc",
+ ],
+ deps = [
+ ":tosa_checker_lib",
+ ],
+)
+
+py_library(
+ name = "tosa_checker",
+ srcs = [
+ "__init__.py",
+ ],
+ srcs_version = "PY3",
+ visibility = ["//visibility:public"],
+ data = ["//tosa_checker:_tosa_checker_wrapper.so"],
+)
diff --git a/tosa_checker/__init__.py b/tosa_checker/__init__.py
new file mode 100644
index 0000000..ce76797
--- /dev/null
+++ b/tosa_checker/__init__.py
@@ -0,0 +1,8 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+
+"""the package provides a way to check if a TFLite model is compatible with the TOSA specification."""
+
+from _tosa_checker_wrapper import *
+
+__version__ = "0.1.0"
diff --git a/tosa_checker/tosa_checker.cc b/tosa_checker/tosa_checker.cc
new file mode 100644
index 0000000..714cab3
--- /dev/null
+++ b/tosa_checker/tosa_checker.cc
@@ -0,0 +1,225 @@
+/*
+SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+SPDX-License-Identifier: Apache-2.0
+*/
+#include "tosa_checker.h"
+
+#include "absl/strings/string_view.h"
+#include "llvm/Support/MemoryBuffer.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/Location.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/Types.h"
+#include "mlir/IR/Verifier.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Support/FileUtilities.h"
+#include "tensorflow/compiler/mlir/lite/flatbuffer_import.h"
+#include "tensorflow/compiler/mlir/tosa/tfl_passes.h"
+
+#include <map>
+#include <memory>
+#include <optional>
+#include <stdexcept>
+#include <string>
+#include <unordered_set>
+#include <vector>
+
+namespace std {
+template <>
+struct hash<mlir::Location> {
+ std::size_t operator()(const mlir::Location &loc) const {
+ return mlir::hash_value(loc);
+ }
+};
+} // namespace std
+
+namespace tosa_checker {
+
+TOSAChecker::TOSAChecker(const std::string &model_path) {
+ m_model = TFLiteFileToMLIR(model_path, &m_context);
+ m_tosa_model = m_model->clone();
+ LegalizeTFLToTOSA(*m_tosa_model);
+}
+
+bool TOSAChecker::IsTOSACompatible() {
+ bool is_tosa_compatible = true;
+ for (auto func : m_tosa_model->getOps<mlir::func::FuncOp>()) {
+ func.walk([&](mlir::Operation *op) {
+ // Ignore func namespace
+ const mlir::Dialect *dialect = op->getDialect();
+ if (!dialect || (!dialect->getNamespace().equals("tosa") &&
+ !dialect->getNamespace().equals("func"))) {
+ is_tosa_compatible = false;
+ return mlir::WalkResult::interrupt();
+ }
+
+ return mlir::WalkResult::advance();
+ });
+ }
+
+ return is_tosa_compatible;
+}
+
+std::vector<TOSAChecker::Operator> TOSAChecker::GetTOSACompatibilityForOps(
+ bool elide_large_attrs) {
+ // Get the locations of all the ops in the legalized model that were not
+ // converted during the TOSA legalization (i.e. the TOSA incompatible ones).
+ std::unordered_set<mlir::Location> tosa_incompatible_locs;
+ for (auto func : m_tosa_model->getOps<mlir::func::FuncOp>()) {
+ func.walk([&](mlir::Operation *op) {
+ // Ignore func namespace
+ const mlir::Dialect *dialect = op->getDialect();
+ if (!dialect || (!dialect->getNamespace().equals("tosa") &&
+ !dialect->getNamespace().equals("func"))) {
+ tosa_incompatible_locs.insert(op->getLoc());
+ }
+ });
+ }
+
+ // We assume that on legalization, the non-legalized ops keep their original
+ // location. If an op location from the original model is in
+ // tosa_incompatible_locs then the op is not tosa compatible, otherwise it is.
+ std::vector<Operator> ops;
+ for (auto func : m_model->getOps<mlir::func::FuncOp>()) {
+ func.walk([&](mlir::Operation *op) {
+ // Ignore func namespace
+ const mlir::Dialect *dialect = op->getDialect();
+ if (!dialect || !dialect->getNamespace().equals("func")) {
+ const bool is_tosa_compatible =
+ tosa_incompatible_locs.find(op->getLoc()) ==
+ tosa_incompatible_locs.end();
+ ops.push_back(ToOperator(*op, is_tosa_compatible, elide_large_attrs));
+ }
+ });
+ }
+
+ return ops;
+}
+
+std::vector<TOSAChecker::Operator> TOSAChecker::GetUsedTOSAOps(
+ bool elide_large_attrs) {
+ std::vector<Operator> tosa_ops;
+ for (mlir::Operation *op : GetTOSAOps(*m_tosa_model)) {
+ const bool is_tosa_compatible = true;
+ tosa_ops.push_back(ToOperator(*op, is_tosa_compatible, elide_large_attrs));
+ }
+
+ return tosa_ops;
+}
+
+std::string TOSAChecker::GetMLIRModelRepresentation(bool elide_large_attrs) {
+ return GetMLIRRepresentation(*m_model, elide_large_attrs);
+}
+
+std::string TOSAChecker::GetMLIRTOSAModelRepresentation(
+ bool elide_large_attrs) {
+ return GetMLIRRepresentation(*m_tosa_model, elide_large_attrs);
+}
+
+template <typename T>
+std::string TOSAChecker::GetMLIRRepresentation(T &&op) {
+ std::string value;
+ llvm::raw_string_ostream value_ostream(value);
+
+ op.print(value_ostream);
+
+ return value;
+}
+
+template <typename T>
+std::string TOSAChecker::GetMLIRRepresentation(T &&op, bool elide_large_attrs) {
+ std::string value;
+ llvm::raw_string_ostream value_ostream(value);
+
+ mlir::OpPrintingFlags flags;
+ if (elide_large_attrs) {
+ flags.elideLargeElementsAttrs(ELIDE_LARGE_ATTRS_LIMIT);
+ }
+ op.print(value_ostream, flags);
+
+ return value;
+}
+
+std::vector<mlir::Operation *> TOSAChecker::GetTOSAOps(mlir::ModuleOp model) {
+ std::vector<mlir::Operation *> tosa_ops;
+ for (auto func : model.getOps<mlir::func::FuncOp>()) {
+ func.walk([&](mlir::Operation *op) {
+ const mlir::Dialect *dialect = op->getDialect();
+ if (dialect && dialect->getNamespace().equals("tosa")) {
+ tosa_ops.push_back(op);
+ }
+ });
+ }
+
+ return tosa_ops;
+}
+
+TOSAChecker::Operator TOSAChecker::ToOperator(mlir::Operation &op,
+ bool is_tosa_compatible,
+ bool elide_large_attrs) {
+ return Operator(op.getName().getStringRef().str(),
+ GetMLIRRepresentation(op.getLoc()),
+ GetAttributes(op, elide_large_attrs), is_tosa_compatible,
+ GetMLIRRepresentation(op, elide_large_attrs));
+}
+
+mlir::OwningOpRef<mlir::ModuleOp> TOSAChecker::TFLiteFileToMLIR(
+ const std::string &model_path, mlir::MLIRContext *context) {
+ std::string error_message;
+ std::unique_ptr<llvm::MemoryBuffer> input =
+ mlir::openInputFile(model_path, &error_message);
+ if (!input) {
+ throw std::runtime_error(error_message);
+ }
+
+ const mlir::FileLineColLoc location =
+ mlir::FileLineColLoc::get(context, input->getBufferIdentifier(), 0, 0);
+
+ auto mlir_module = tflite::FlatBufferToMlir(
+ absl::string_view(input->getBufferStart(), input->getBufferSize()),
+ context, location);
+ if (!mlir_module || mlir::failed(mlir::verify(*mlir_module))) {
+ throw std::runtime_error(
+ "Could not convert the TFLite model to its MLIR representation.");
+ }
+
+ return mlir_module;
+}
+
+void TOSAChecker::LegalizeTFLToTOSA(mlir::ModuleOp mlir_module) {
+ mlir::PassManager pm(mlir_module.getContext(),
+ mlir::OpPassManager::Nesting::Implicit);
+ mlir::tosa::TOSATFLLegalizationPipelineOptions opts;
+ mlir::tosa::createTFLtoTOSALegalizationPipeline(pm, opts);
+  // TODO Don't check the mlir::failed state for now due to some inconsistencies
+  // in how the legalization reports non-convertible ops (sometimes with a hard
+  // failure, sometimes without). The legalization should not return a failed
+  // state if an operator can't be legalized; it should leave the op in its
+  // original dialect.
+ pm.run(mlir_module);
+}
+
+std::map<std::string, std::string> TOSAChecker::GetAttributes(
+ mlir::Operation &op, bool /*elide_large_attrs*/) {
+ std::map<std::string, std::string> attributes;
+ for (const mlir::NamedAttribute &attr : op.getAttrs()) {
+ attributes.emplace(attr.getName().str(),
+                       // TODO Check how to elide large attributes when
+                       // converting them to a string; mlir::Attribute::print
+                       // takes no mlir::OpPrintingFlags.
+ GetMLIRRepresentation(attr.getValue()));
+ }
+
+ return attributes;
+}
+
+} // namespace tosa_checker
+
+std::ostream &operator<<(std::ostream &os,
+ const tosa_checker::TOSAChecker::Operator &op) {
+ os << op.mlir_representation;
+
+ return os;
+}
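
The heuristic in `GetTOSACompatibilityForOps` above works as follows: collect the source locations of ops that are still outside the `tosa`/`func` dialects after legalization, then flag an op from the original model as incompatible when its location appears in that set. A small Python sketch of the same idea, using placeholder (name, location) pairs rather than real MLIR objects:

```python
# Hypothetical flattened views of the original and legalized modules.
original_ops = [("tfl.strided_slice", "loc1"), ("tfl.non_max_suppression_v5", "loc2")]
legalized_ops = [("tosa.reshape", "loc1"), ("tfl.non_max_suppression_v5", "loc2")]

# Locations of ops left un-legalized are the TOSA-incompatible ones.
incompatible_locs = {
    loc for name, loc in legalized_ops
    if not name.startswith(("tosa.", "func."))
}

# An original op is compatible iff everything at its location was legalized.
report = [(name, loc not in incompatible_locs) for name, loc in original_ops]
print(report)  # [('tfl.strided_slice', True), ('tfl.non_max_suppression_v5', False)]
```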
diff --git a/tosa_checker/tosa_checker.h b/tosa_checker/tosa_checker.h
new file mode 100644
index 0000000..d7750ea
--- /dev/null
+++ b/tosa_checker/tosa_checker.h
@@ -0,0 +1,82 @@
+/*
+SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+SPDX-License-Identifier: Apache-2.0
+*/
+#ifndef TOSA_CHECKER_H_
+#define TOSA_CHECKER_H_
+
+#include <map>
+#include <optional>
+#include <string>
+#include <vector>
+
+#include "mlir/include/mlir/IR/BuiltinOps.h"
+#include "mlir/include/mlir/IR/MLIRContext.h"
+#include "mlir/include/mlir/IR/OwningOpRef.h"
+
+namespace tosa_checker {
+
+class TOSAChecker {
+ public:
+ struct Operator {
+ Operator(std::string name, std::string location,
+ std::map<std::string, std::string> attributes,
+ bool is_tosa_compatible, std::string mlir_representation)
+ : name(std::move(name)),
+ location(std::move(location)),
+ attributes(std::move(attributes)),
+ is_tosa_compatible(is_tosa_compatible),
+ mlir_representation(std::move(mlir_representation)) {}
+
+ std::string name;
+ std::string location;
+ std::map<std::string, std::string> attributes;
+ bool is_tosa_compatible;
+ std::string mlir_representation;
+ };
+
+ TOSAChecker(const std::string& model_path);
+
+ bool IsTOSACompatible();
+
+ std::vector<Operator> GetTOSACompatibilityForOps(bool elide_large_attrs);
+
+ std::vector<Operator> GetUsedTOSAOps(bool elide_large_attrs);
+
+ std::string GetMLIRModelRepresentation(bool elide_large_attrs);
+ std::string GetMLIRTOSAModelRepresentation(bool elide_large_attrs);
+
+ private:
+ template <typename T>
+ static std::string GetMLIRRepresentation(T&& op);
+
+ template <typename T>
+ static std::string GetMLIRRepresentation(T&& op, bool elide_large_attrs);
+
+ static std::vector<mlir::Operation*> GetTOSAOps(mlir::ModuleOp model);
+
+ static Operator ToOperator(mlir::Operation& op, bool is_tosa_compatible,
+ bool elide_large_attrs);
+
+ static mlir::OwningOpRef<mlir::ModuleOp> TFLiteFileToMLIR(
+ const std::string& model_path, mlir::MLIRContext* context);
+
+ static void LegalizeTFLToTOSA(mlir::ModuleOp mlir_module);
+
+ static std::map<std::string, std::string> GetAttributes(
+ mlir::Operation& op, bool elide_large_attrs);
+
+ private:
+ static constexpr std::int64_t ELIDE_LARGE_ATTRS_LIMIT = 16;
+
+ mlir::MLIRContext m_context;
+ mlir::OwningOpRef<mlir::ModuleOp> m_model;
+ mlir::OwningOpRef<mlir::ModuleOp> m_tosa_model;
+};
+
+} // namespace tosa_checker
+
+std::ostream& operator<<(std::ostream& os,
+ const tosa_checker::TOSAChecker::Operator& op);
+
+#endif
diff --git a/tosa_checker/tosa_checker_pybind11.cc b/tosa_checker/tosa_checker_pybind11.cc
new file mode 100644
index 0000000..c799817
--- /dev/null
+++ b/tosa_checker/tosa_checker_pybind11.cc
@@ -0,0 +1,79 @@
+/*
+SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+SPDX-License-Identifier: Apache-2.0
+*/
+#include "tosa_checker.h"
+
+#include <optional>
+#include <sstream>
+#include <string>
+
+#include <pybind11/pybind11.h>
+#include <pybind11/stl.h>
+
+PYBIND11_MODULE(_tosa_checker_wrapper, m) {
+ /**
+ * tosa_checker::TOSAChecker
+ */
+ pybind11::class_<tosa_checker::TOSAChecker> tosa_checker_class(m,
+ "TOSAChecker");
+ tosa_checker_class.def(pybind11::init<const std::string&>(),
+ pybind11::arg("model_path"));
+
+ tosa_checker_class.def(
+ "is_tosa_compatible",
+ [](tosa_checker::TOSAChecker& tc) { return tc.IsTOSACompatible(); },
+ "Check if a model is compatible with the TOSA specification");
+
+ tosa_checker_class.def(
+ "_get_tosa_compatibility_for_ops",
+ [](tosa_checker::TOSAChecker& tc, bool elide_large_elements_attrs) {
+ return tc.GetTOSACompatibilityForOps(elide_large_elements_attrs);
+ },
+ pybind11::arg("elide_large_elements_attrs") = false,
+ "Get all the operators of the models with a TOSA compatibility flag for "
+ "each operator");
+
+ tosa_checker_class.def(
+ "_get_used_tosa_ops",
+ [](tosa_checker::TOSAChecker& tc, bool elide_large_elements_attrs) {
+ return tc.GetUsedTOSAOps(elide_large_elements_attrs);
+ },
+ pybind11::arg("elide_large_elements_attrs") = false,
+ "Get the TOSA operators used by the model after its TOSA legalization");
+
+ tosa_checker_class.def(
+ "_get_mlir_model_representation",
+ [](tosa_checker::TOSAChecker& tc, bool elide_large_elements_attrs) {
+ return tc.GetMLIRModelRepresentation(elide_large_elements_attrs);
+ },
+ pybind11::arg("elide_large_elements_attrs") = false,
+ "Get the MLIR representation of the model");
+
+ tosa_checker_class.def(
+ "_get_mlir_tosa_model_representation",
+ [](tosa_checker::TOSAChecker& tc, bool elide_large_elements_attrs) {
+ return tc.GetMLIRTOSAModelRepresentation(elide_large_elements_attrs);
+ },
+ pybind11::arg("elide_large_elements_attrs") = false,
+ "Get the MLIR representation of the TOSA legalized model");
+
+ /**
+ * tosa_checker::TOSAChecker::Operator
+ */
+ pybind11::class_<tosa_checker::TOSAChecker::Operator>(tosa_checker_class,
+ "_Operator")
+ .def_readonly("name", &tosa_checker::TOSAChecker::Operator::name)
+ .def_readonly("location", &tosa_checker::TOSAChecker::Operator::location)
+ .def_readonly("attributes",
+ &tosa_checker::TOSAChecker::Operator::attributes)
+ .def_readonly("is_tosa_compatible",
+ &tosa_checker::TOSAChecker::Operator::is_tosa_compatible)
+ .def_readonly("mlir_representation",
+ &tosa_checker::TOSAChecker::Operator::mlir_representation)
+ .def("__repr__", [](const tosa_checker::TOSAChecker::Operator& o) {
+ std::stringstream stream;
+ stream << o;
+ return stream.str();
+ });
+}
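
For completeness, a short sketch of exercising the inspection bindings defined above from Python (assuming a `model.tflite` file exists; large constants are elided for readability):

```python
import tosa_checker as tc

checker = tc.TOSAChecker(model_path="model.tflite")

# MLIR of the imported TFLite model.
print(checker._get_mlir_model_representation(elide_large_elements_attrs=True))

# MLIR after the TFL-to-TOSA legalization pipeline has run; any op left
# outside the tosa/func dialects here is what makes is_tosa_compatible()
# return False.
print(checker._get_mlir_tosa_model_representation(elide_large_elements_attrs=True))

# TOSA operators actually used by the legalized model.
for op in checker._get_used_tosa_ops():
    print(op.name)
```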