From 52dacd6556d60815253d4e4938e218ea3d8084a2 Mon Sep 17 00:00:00 2001 From: Thibaut Goetghebuer-Planchon Date: Wed, 6 Jul 2022 10:23:22 +0100 Subject: Initial commit Change-Id: I2fb0933d595a6ede6417d09dd905ef72d6c60c9b --- setup.py | 127 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 127 insertions(+) create mode 100644 setup.py (limited to 'setup.py') diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..8ea8232 --- /dev/null +++ b/setup.py @@ -0,0 +1,127 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +import argparse +import os +import setuptools +import setuptools.command.build_ext +import shutil +import sys + + +TOSA_CHECKER_VERSION = "0.1.0" +TENSORFLOW_VERSION = "2.9.0" + +# Get the TensorFlowâ„¢ source directory passed to the command line (if any). +# If none is given the sources are pulled from the official TF repository. +argparser = argparse.ArgumentParser() +argparser.add_argument( + "--tensorflow_src_dir", help="TensorFlow source directory path", required=False +) +args, unknown = argparser.parse_known_args() +sys.argv = [sys.argv[0]] + unknown + + +class BazelExtensionModule(setuptools.Extension): + def __init__( + self, + py_module_name, + library_name, + bazel_target, + bazel_shared_lib_output, + tensorflow_version, + ): + super().__init__(py_module_name, sources=[]) + self.library_name = library_name + self.bazel_target = bazel_target + self.bazel_shared_lib_output = bazel_shared_lib_output + self.tensorflow_version = tensorflow_version + + +class BazelBuildExtension(setuptools.command.build_ext.build_ext): + """Override build_extension to build the library with bazel and copying it + beforehand.""" + + def build_extension(self, ext): + tensorflow_src_dir = args.tensorflow_src_dir + if not tensorflow_src_dir: + tensorflow_src_dir = os.path.join(self.build_temp, "tensorflow") + self._clone_tf_repository( + tensorflow_src_dir, + ext.tensorflow_version, + ) + + self.spawn( + [ + "bazel", + "build", + "-c", + "opt", + # FIXME Some of the Bazel targets dependencies we use have + # a 'friends' visibility, check if our Bazel target can be added + # to the 'friends' list. + "--check_visibility=false", + "--override_repository=org_tensorflow=" + + os.path.abspath(tensorflow_src_dir), + ext.bazel_target, + ] + ) + + shared_lib_dest_path = self.get_ext_fullpath(ext.name) + shared_lib_dest_dir = os.path.dirname(shared_lib_dest_path) + package_dir = os.path.join(shared_lib_dest_dir, ext.library_name) + + os.makedirs(shared_lib_dest_dir, exist_ok=True) + os.makedirs(package_dir, exist_ok=True) + + shutil.copyfile(ext.bazel_shared_lib_output, shared_lib_dest_path) + shutil.copy(os.path.join(ext.library_name, "__init__.py"), package_dir) + + super().build_extension(ext) + + def _clone_tf_repository(self, tensorflow_src_dir, tensorflow_version): + if os.path.exists(tensorflow_src_dir): + return + + tensorflow_repo = "https://github.com/tensorflow/tensorflow.git" + self.spawn( + [ + "git", + "clone", + "--depth=1", + "--branch", + "v" + tensorflow_version, + tensorflow_repo, + tensorflow_src_dir, + ] + ) + + +setuptools.setup( + name="tosa-checker", + version=TOSA_CHECKER_VERSION, + description="Tool to check if a ML model is compatible with the TOSA specification", + long_description="file: README.md", + long_description_content_type="text/markdown", + author="Arm Limited", + url="https://git.mlplatform.org/tosa/tosa_checker.git/", + license="Apache-2.0", + license_files="LICENSES/*", + python_requires=">=3.7", + cmdclass={"build_ext": BazelBuildExtension}, + ext_modules=[ + BazelExtensionModule( + py_module_name="_tosa_checker_wrapper", + library_name="tosa_checker", + bazel_target="//tosa_checker:tosa_checker", + bazel_shared_lib_output="bazel-bin/tosa_checker/_tosa_checker_wrapper.so", + tensorflow_version=TENSORFLOW_VERSION, + ), + ], + classifiers=[ + "Development Status :: 3 - Alpha", + "Intended Audience :: Developers", + "License :: OSI Approved :: Apache Software License", + "Programming Language :: Python :: 3", + "Topic :: Utilities", + ], +) -- cgit v1.2.1