# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Build script for the TOSA Checker Python package.

The native ``_tosa_checker_wrapper`` extension is built with Bazel against a
TensorFlow source tree instead of being compiled by setuptools. Pass
``--tensorflow_src_dir=<path>`` to use an existing TensorFlow checkout;
otherwise the tagged TensorFlow release is shallow-cloned into the build
temporary directory (unless that directory already exists).
"""
import argparse
import os
import setuptools
import setuptools.command.build_ext
import shutil
import sys

TOSA_CHECKER_VERSION = "0.1.0"
TENSORFLOW_VERSION = "2.9.0"

# Get the TensorFlow™ source directory passed on the command line (if any).
# If none is given, the sources are pulled from the official TF repository.
# The option is consumed here and removed from sys.argv so that setuptools
# never sees it. add_help=False lets -h/--help fall through to setuptools
# instead of this parser printing only its own usage and exiting.
argparser = argparse.ArgumentParser(add_help=False)
argparser.add_argument(
    "--tensorflow_src_dir", help="TensorFlow source directory path", required=False
)
args, unknown = argparser.parse_known_args()
sys.argv = [sys.argv[0]] + unknown


class BazelExtensionModule(setuptools.Extension):
    """Extension description for a shared library produced by a Bazel target.

    Args:
        py_module_name: Python module name of the extension (also used by
            ``build_ext`` to compute the destination file path).
        library_name: Name of the Python package directory whose
            ``__init__.py`` is copied next to the built library.
        bazel_target: Bazel label to build, e.g. ``//pkg:target``.
        bazel_shared_lib_output: Path (relative to the current working
            directory) of the shared library Bazel produces.
        tensorflow_version: TensorFlow release to clone when no source
            directory is supplied on the command line.
    """

    def __init__(
        self,
        py_module_name,
        library_name,
        bazel_target,
        bazel_shared_lib_output,
        tensorflow_version,
    ):
        # No sources: compilation is delegated entirely to Bazel.
        super().__init__(py_module_name, sources=[])
        self.library_name = library_name
        self.bazel_target = bazel_target
        self.bazel_shared_lib_output = bazel_shared_lib_output
        self.tensorflow_version = tensorflow_version


class BazelBuildExtension(setuptools.command.build_ext.build_ext):
    """``build_ext`` command that builds extensions with Bazel.

    For each extension, the Bazel target is built against the TensorFlow
    sources, the resulting shared library and the package ``__init__.py`` are
    copied into the build tree, and then the base implementation is invoked.
    """

    def build_extension(self, ext):
        """Build ``ext`` with Bazel and copy its outputs into the build tree."""
        tensorflow_src_dir = args.tensorflow_src_dir
        if not tensorflow_src_dir:
            tensorflow_src_dir = os.path.join(self.build_temp, "tensorflow")
            self._clone_tf_repository(tensorflow_src_dir, ext.tensorflow_version)

        self.spawn(
            [
                "bazel",
                "build",
                "-c",
                "opt",
                # FIXME Some of the Bazel targets dependencies we use have
                # a 'friends' visibility, check if our Bazel target can be added
                # to the 'friends' list.
                "--check_visibility=false",
                "--override_repository=org_tensorflow="
                + os.path.abspath(tensorflow_src_dir),
                ext.bazel_target,
            ]
        )

        shared_lib_dest_path = self.get_ext_fullpath(ext.name)
        shared_lib_dest_dir = os.path.dirname(shared_lib_dest_path)
        package_dir = os.path.join(shared_lib_dest_dir, ext.library_name)
        # package_dir is nested inside shared_lib_dest_dir, so this creates both.
        os.makedirs(package_dir, exist_ok=True)
        shutil.copyfile(ext.bazel_shared_lib_output, shared_lib_dest_path)
        shutil.copy(os.path.join(ext.library_name, "__init__.py"), package_dir)

        # NOTE(review): the extension has no sources and its output file now
        # exists, so the base implementation is presumably expected to treat it
        # as up to date. With `build_ext --force` it may instead try to link an
        # empty object list over the copied library - confirm this call is
        # still needed.
        super().build_extension(ext)

    def _clone_tf_repository(self, tensorflow_src_dir, tensorflow_version):
        """Shallow-clone TensorFlow tag ``v<tensorflow_version>``.

        Does nothing if ``tensorflow_src_dir`` already exists, so a previous
        clone in the build directory is reused as-is.
        """
        if os.path.exists(tensorflow_src_dir):
            return

        tensorflow_repo = "https://github.com/tensorflow/tensorflow.git"
        self.spawn(
            [
                "git",
                "clone",
                "--depth=1",
                "--branch",
                "v" + tensorflow_version,
                tensorflow_repo,
                tensorflow_src_dir,
            ]
        )


def _read_long_description():
    """Return the text of the README.md that sits next to this script.

    ``file: README.md`` is setup.cfg directive syntax; passed to setup() it
    would be published verbatim, so the file is read here instead.
    """
    readme_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "README.md")
    with open(readme_path, encoding="utf-8") as readme:
        return readme.read()


setuptools.setup(
    name="tosa-checker",
    version=TOSA_CHECKER_VERSION,
    description="Tool to check if a ML model is compatible with the TOSA specification",
    long_description=_read_long_description(),
    long_description_content_type="text/markdown",
    author="Arm Limited",
    url="https://git.mlplatform.org/tosa/tosa_checker.git/",
    license="Apache-2.0",
    # Must be a list of glob patterns: a bare string is iterated character by
    # character, and its "*" element would match every file in the project root.
    license_files=["LICENSES/*"],
    python_requires=">=3.7",
    cmdclass={"build_ext": BazelBuildExtension},
    ext_modules=[
        BazelExtensionModule(
            py_module_name="_tosa_checker_wrapper",
            library_name="tosa_checker",
            bazel_target="//tosa_checker:tosa_checker",
            bazel_shared_lib_output="bazel-bin/tosa_checker/_tosa_checker_wrapper.so",
            tensorflow_version=TENSORFLOW_VERSION,
        ),
    ],
    classifiers=[
        "Development Status :: 3 - Alpha",
        "Intended Audience :: Developers",
        "License :: OSI Approved :: Apache Software License",
        "Programming Language :: Python :: 3",
        "Topic :: Utilities",
    ],
)