aboutsummaryrefslogtreecommitdiff
path: root/setup.py
diff options
context:
space:
mode:
Diffstat (limited to 'setup.py')
-rw-r--r--setup.py127
1 files changed, 127 insertions, 0 deletions
diff --git a/setup.py b/setup.py
new file mode 100644
index 0000000..8ea8232
--- /dev/null
+++ b/setup.py
@@ -0,0 +1,127 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+import argparse
+import os
+import setuptools
+import setuptools.command.build_ext
+import shutil
+import sys
+
+
+TOSA_CHECKER_VERSION = "0.1.0"
+TENSORFLOW_VERSION = "2.9.0"
+
+# Get the TensorFlowâ„¢ source directory passed to the command line (if any).
+# If none is given the sources are pulled from the official TF repository.
+argparser = argparse.ArgumentParser()
+argparser.add_argument(
+ "--tensorflow_src_dir", help="TensorFlow source directory path", required=False
+)
+args, unknown = argparser.parse_known_args()
+sys.argv = [sys.argv[0]] + unknown
+
+
+class BazelExtensionModule(setuptools.Extension):
+ def __init__(
+ self,
+ py_module_name,
+ library_name,
+ bazel_target,
+ bazel_shared_lib_output,
+ tensorflow_version,
+ ):
+ super().__init__(py_module_name, sources=[])
+ self.library_name = library_name
+ self.bazel_target = bazel_target
+ self.bazel_shared_lib_output = bazel_shared_lib_output
+ self.tensorflow_version = tensorflow_version
+
+
+class BazelBuildExtension(setuptools.command.build_ext.build_ext):
+ """Override build_extension to build the library with bazel and copying it
+ beforehand."""
+
+ def build_extension(self, ext):
+ tensorflow_src_dir = args.tensorflow_src_dir
+ if not tensorflow_src_dir:
+ tensorflow_src_dir = os.path.join(self.build_temp, "tensorflow")
+ self._clone_tf_repository(
+ tensorflow_src_dir,
+ ext.tensorflow_version,
+ )
+
+ self.spawn(
+ [
+ "bazel",
+ "build",
+ "-c",
+ "opt",
+ # FIXME Some of the Bazel targets dependencies we use have
+ # a 'friends' visibility, check if our Bazel target can be added
+ # to the 'friends' list.
+ "--check_visibility=false",
+ "--override_repository=org_tensorflow="
+ + os.path.abspath(tensorflow_src_dir),
+ ext.bazel_target,
+ ]
+ )
+
+ shared_lib_dest_path = self.get_ext_fullpath(ext.name)
+ shared_lib_dest_dir = os.path.dirname(shared_lib_dest_path)
+ package_dir = os.path.join(shared_lib_dest_dir, ext.library_name)
+
+ os.makedirs(shared_lib_dest_dir, exist_ok=True)
+ os.makedirs(package_dir, exist_ok=True)
+
+ shutil.copyfile(ext.bazel_shared_lib_output, shared_lib_dest_path)
+ shutil.copy(os.path.join(ext.library_name, "__init__.py"), package_dir)
+
+ super().build_extension(ext)
+
+ def _clone_tf_repository(self, tensorflow_src_dir, tensorflow_version):
+ if os.path.exists(tensorflow_src_dir):
+ return
+
+ tensorflow_repo = "https://github.com/tensorflow/tensorflow.git"
+ self.spawn(
+ [
+ "git",
+ "clone",
+ "--depth=1",
+ "--branch",
+ "v" + tensorflow_version,
+ tensorflow_repo,
+ tensorflow_src_dir,
+ ]
+ )
+
+
+setuptools.setup(
+ name="tosa-checker",
+ version=TOSA_CHECKER_VERSION,
+ description="Tool to check if a ML model is compatible with the TOSA specification",
+ long_description="file: README.md",
+ long_description_content_type="text/markdown",
+ author="Arm Limited",
+ url="https://git.mlplatform.org/tosa/tosa_checker.git/",
+ license="Apache-2.0",
+ license_files="LICENSES/*",
+ python_requires=">=3.7",
+ cmdclass={"build_ext": BazelBuildExtension},
+ ext_modules=[
+ BazelExtensionModule(
+ py_module_name="_tosa_checker_wrapper",
+ library_name="tosa_checker",
+ bazel_target="//tosa_checker:tosa_checker",
+ bazel_shared_lib_output="bazel-bin/tosa_checker/_tosa_checker_wrapper.so",
+ tensorflow_version=TENSORFLOW_VERSION,
+ ),
+ ],
+ classifiers=[
+ "Development Status :: 3 - Alpha",
+ "Intended Audience :: Developers",
+ "License :: OSI Approved :: Apache Software License",
+ "Programming Language :: Python :: 3",
+ "Topic :: Utilities",
+ ],
+)