# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Build script for the tosa-checker Python package.

The native extension is built with Bazel against a TensorFlow™ source tree
(either supplied on the command line or shallow-cloned from GitHub) and the
resulting shared library is copied to where setuptools expects it.
"""
import argparse
import git
import os
import re
import setuptools
import setuptools.command.build_ext
import shutil
import sys
import time

TOSA_CHECKER_VERSION = "0.2.0"
DEFAULT_TENSORFLOW_VERSION = "2.13.0"

# Get the TensorFlow™ source directory passed to the command line (if any).
# If none is given the sources are pulled from the official TF repository.
argparser = argparse.ArgumentParser()
argparser.add_argument(
    "--tensorflow_src_dir", help="TensorFlow source directory path", required=False
)
argparser.add_argument(
    "--sanitizer",
    help="Build using a sanitizer (choose from asan or ubsan)",
    choices=["asan", "ubsan"],
    required=False,
)
argparser.add_argument(
    "--tosa_checker_copt",
    help="Build tosa_checker with additional copt (comma separated string)",
    default="",
    required=False,
)
argparser.add_argument(
    "--nightly", help="Build tosa_checker as a nightly wheel", action="store_true"
)
args, unknown = argparser.parse_known_args()
# Drop the options consumed above so setuptools only sees the arguments it knows.
sys.argv = [sys.argv[0]] + unknown


def increment_version(version_string):
    """Return *version_string* with its minor component bumped by one.

    Components after the minor one are reset to "0" (e.g. "0.2.1" -> "0.3.0").
    Requires at least two dot-separated numeric components.
    """
    version_numbers = version_string.split(".")
    version_numbers[1] = str(int(version_numbers[1]) + 1)
    # Bumping the minor version resets every lower-order component.
    version_numbers[2:] = ["0"] * len(version_numbers[2:])
    return ".".join(version_numbers)


def get_package_version(nightly=False):
    """Return the package version string.

    For nightly builds this is the next minor version with a ".devYYYYMMDD"
    suffix (so it sorts before that release); otherwise TOSA_CHECKER_VERSION.
    """
    if nightly:
        return "{}.dev{}".format(
            increment_version(TOSA_CHECKER_VERSION), time.strftime("%Y%m%d")
        )
    return TOSA_CHECKER_VERSION


def get_repo_version(repo_directory):
    """Return a version identifier for the git checkout at *repo_directory*.

    This is a tag pointing at HEAD if there is one, otherwise the full SHA of
    the HEAD commit. If several tags point at HEAD only the first one listed
    is returned, so the result is always a single line (it is embedded in a
    generated Python string literal).
    """
    with git.repo.Repo(repo_directory) as repo:
        tags = repo.git.tag("--points-at").splitlines()
        if tags:
            return tags[0]
        return repo.head.commit.hexsha


class BazelExtensionModule(setuptools.Extension):
    """Extension description for a shared library produced by a Bazel target.

    It has no C sources of its own; the build is delegated to Bazel by
    BazelBuildExtension.
    """

    def __init__(
        self,
        py_module_name,
        library_name,
        bazel_target,
        bazel_shared_lib_output,
        tensorflow_version,
    ):
        super().__init__(py_module_name, sources=[])
        # Name of the Python package directory the extension belongs to.
        self.library_name = library_name
        # Bazel label passed to `bazel build`.
        self.bazel_target = bazel_target
        # Path (relative to the working directory) of the library Bazel produces.
        self.bazel_shared_lib_output = bazel_shared_lib_output
        # TensorFlow release (without the leading "v") cloned when no source dir is given.
        self.tensorflow_version = tensorflow_version


class BazelBuildExtension(setuptools.command.build_ext.build_ext):
    """Override build_extension to build the library with bazel and copying it beforehand."""

    def build_extension(self, ext):
        """Build *ext* with Bazel, copy the library and generate the package __init__.py."""
        tensorflow_src_dir = args.tensorflow_src_dir
        if not tensorflow_src_dir:
            tensorflow_src_dir = os.path.join(self.build_temp, "tensorflow")
            self._clone_tf_repository(tensorflow_src_dir, ext.tensorflow_version)

        commands = ["bazel", "build"]
        if args.sanitizer:
            commands += ["--config={}".format(args.sanitizer)]
        if args.tosa_checker_copt:
            commands += [
                "--per_file_copt=tosa_checker/tosa_checker.*@{}".format(
                    args.tosa_checker_copt
                )
            ]
        commands += [
            # FIXME Some of the Bazel targets dependencies we use have
            # a 'friends' visibility, check if our Bazel target can be added
            # to the 'friends' list.
            "-c",
            "opt",
            "--check_visibility=false",
            "--override_repository=org_tensorflow={}".format(
                os.path.abspath(tensorflow_src_dir)
            ),
            ext.bazel_target,
        ]
        self.spawn(commands)

        shared_lib_dest_path = self.get_ext_fullpath(ext.name)
        shared_lib_dest_dir = os.path.dirname(shared_lib_dest_path)
        package_dir = os.path.join(shared_lib_dest_dir, ext.library_name)
        os.makedirs(shared_lib_dest_dir, exist_ok=True)
        os.makedirs(package_dir, exist_ok=True)
        shutil.copyfile(ext.bazel_shared_lib_output, shared_lib_dest_path)

        # Get the TensorFlow version this is built with
        tf_version = get_repo_version(tensorflow_src_dir)
        with open(
            os.path.join(ext.library_name, "__init__.py"), "r", encoding="utf-8"
        ) as f:
            module_init_file = f.read()
        # Make sure the appended assignment starts on its own line.
        if module_init_file and not module_init_file.endswith("\n"):
            module_init_file += "\n"
        module_init_file += "__tensorflow_version__ = \"{}\"\n".format(tf_version)
        with open(os.path.join(package_dir, "__init__.py"), "w", encoding="utf-8") as f:
            f.write(module_init_file)

        super().build_extension(ext)

    def _clone_tf_repository(self, tensorflow_src_dir, tensorflow_version):
        """Shallow-clone tag v<tensorflow_version> into *tensorflow_src_dir*.

        Does nothing if the directory already exists (e.g. a previous build).
        """
        if os.path.exists(tensorflow_src_dir):
            return

        tensorflow_repo = "https://github.com/tensorflow/tensorflow.git"
        self.spawn(
            [
                "git",
                "clone",
                "--depth=1",
                "--branch",
                "v" + tensorflow_version,
                tensorflow_repo,
                tensorflow_src_dir,
            ]
        )


def get_long_description(nightly):
    """Return README.md contents for use as the package long description.

    For release (non-nightly) builds, markdown links whose target is a file
    that exists next to this script are rewritten to absolute links into the
    tagged source tree on review.mlplatform.org. Only the link target is
    rewritten; the link text is left untouched.
    """
    # Read the contents of README.md file
    this_directory = os.path.abspath(os.path.dirname(__file__))
    with open(os.path.join(this_directory, "README.md"), encoding="utf-8") as f:
        long_description = f.read()

    if nightly:
        return long_description

    # Replace relative links to existing files with absolute links to https://review.mlplatform.org
    url = f"https://review.mlplatform.org/plugins/gitiles/tosa/tosa_checker/+/refs/tags/{TOSA_CHECKER_VERSION}/"

    def _absolutize_link(match):
        # Groups: "[text](" , link target, ")".
        opening, link, closing = match.groups()
        # If the link is a file that exists, replace it with the web link to the file instead
        if os.path.exists(os.path.join(this_directory, link)):
            return f"{opening}{url}{link}{closing}"
        return match.group(0)

    # Find all markdown links that match the format: [text](link)
    return re.sub(r"(\[.+?\]\()(.+?)(\))", _absolutize_link, long_description)


setuptools.setup(
    name="tosa-checker",
    version=get_package_version(args.nightly),
    description="Tool to check if a ML model is compatible with the TOSA specification",
    long_description=get_long_description(args.nightly),
    long_description_content_type="text/markdown",
    author="Arm Limited",
    url="https://git.mlplatform.org/tosa/tosa_checker.git/",
    license="Apache-2.0",
    python_requires=">=3.8",
    cmdclass={"build_ext": BazelBuildExtension},
    ext_modules=[
        BazelExtensionModule(
            py_module_name="_tosa_checker_wrapper",
            library_name="tosa_checker",
            bazel_target="//tosa_checker:tosa_checker",
            bazel_shared_lib_output="bazel-bin/tosa_checker/_tosa_checker_wrapper.so",
            tensorflow_version=DEFAULT_TENSORFLOW_VERSION,
        ),
    ],
    classifiers=[
        "Development Status :: 3 - Alpha",
        "Intended Audience :: Developers",
        "License :: OSI Approved :: Apache Software License",
        "Programming Language :: Python :: 3",
        "Topic :: Utilities",
    ],
)