aboutsummaryrefslogtreecommitdiff
path: root/setup.py
blob: a62cd227b367544087105c5aa0b836643734588f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
import argparse
import os
import pathlib
import setuptools
import setuptools.command.build_ext
import shutil
import sys


# Version of the tosa-checker package, and the TensorFlow release tag (without
# the leading "v") that is cloned when no TensorFlow source directory is given.
TOSA_CHECKER_VERSION = "0.1.0"
TENSORFLOW_VERSION = "2.9.0"

# Get the TensorFlow™ source directory passed on the command line (if any).
# If none is given, the sources are pulled from the official TF repository.
#
# Only our own option is parsed here; every other argument is handed back to
# setuptools through sys.argv. add_help=False leaves "-h/--help" to setuptools
# instead of argparse printing its own help and exiting, and allow_abbrev=False
# stops prefixes of "--tensorflow_src_dir" (e.g. "--tensorflow") from being
# silently consumed instead of being forwarded.
argparser = argparse.ArgumentParser(add_help=False, allow_abbrev=False)
argparser.add_argument(
    "--tensorflow_src_dir", help="TensorFlow source directory path", required=False
)
args, unknown = argparser.parse_known_args()
sys.argv = [sys.argv[0]] + unknown


class BazelExtensionModule(setuptools.Extension):
    """A setuptools Extension whose shared library is produced by Bazel.

    The extension is declared with an empty source list, so setuptools has
    nothing to compile for it. The extra attributes record what the custom
    build_ext command needs: the Bazel target to build, where Bazel writes the
    resulting library, which package directory to populate and which
    TensorFlow release to build against.

    Args:
        py_module_name: name of the Python extension module.
        library_name: name of the Python package directory whose
            ``__init__.py`` is copied next to the built library.
        bazel_target: Bazel label to build.
        bazel_shared_lib_output: path, relative to the current working
            directory, of the shared library produced by the Bazel build.
        tensorflow_version: TensorFlow release tag, without the leading "v".
    """

    def __init__(
        self,
        py_module_name,
        library_name,
        bazel_target,
        bazel_shared_lib_output,
        tensorflow_version,
    ):
        # Bazel, not setuptools, builds this extension: no sources to compile.
        super().__init__(name=py_module_name, sources=[])
        self.tensorflow_version = tensorflow_version
        self.bazel_shared_lib_output = bazel_shared_lib_output
        self.bazel_target = bazel_target
        self.library_name = library_name


class BazelBuildExtension(setuptools.command.build_ext.build_ext):
    """build_ext command that builds extensions with Bazel.

    Rather than compiling sources with the setuptools compiler, each
    BazelExtensionModule is built by invoking its Bazel target against a
    TensorFlow source tree. That tree is either the one given with
    --tensorflow_src_dir or a shallow clone in build_temp. The resulting shared
    library, together with the package's ``__init__.py``, is then copied to
    where setuptools expects the extension to be.
    """

    def build_extension(self, ext):
        """Build ``ext`` with Bazel and copy its outputs into the build tree.

        Args:
            ext: a BazelExtensionModule giving the Bazel target, the path of
                the shared library Bazel produces, the Python package name and
                the TensorFlow version to build against.
        """
        tensorflow_src_dir = args.tensorflow_src_dir
        if not tensorflow_src_dir:
            # Key the cached clone on the TensorFlow version so that changing
            # the version does not silently reuse an older checkout left in
            # build_temp by a previous build.
            tensorflow_src_dir = os.path.join(
                self.build_temp, "tensorflow-v" + ext.tensorflow_version
            )
            self._clone_tf_repository(
                tensorflow_src_dir,
                ext.tensorflow_version,
            )

        self.spawn(
            [
                "bazel",
                "build",
                "-c",
                "opt",
                # FIXME Some of the Bazel targets dependencies we use have
                # a 'friends' visibility, check if our Bazel target can be added
                # to the 'friends' list.
                "--check_visibility=false",
                "--override_repository=org_tensorflow="
                + os.path.abspath(tensorflow_src_dir),
                ext.bazel_target,
            ]
        )

        shared_lib_dest_path = self.get_ext_fullpath(ext.name)
        shared_lib_dest_dir = os.path.dirname(shared_lib_dest_path)
        package_dir = os.path.join(shared_lib_dest_dir, ext.library_name)

        # package_dir is inside shared_lib_dest_dir, so this creates both.
        os.makedirs(package_dir, exist_ok=True)

        shutil.copyfile(ext.bazel_shared_lib_output, shared_lib_dest_path)
        shutil.copy(os.path.join(ext.library_name, "__init__.py"), package_dir)

        # Deliberately no call to super().build_extension(ext). The extension
        # has an empty source list, so the base implementation either skips
        # it as up to date or, with --force, compiles nothing and links the
        # result to shared_lib_dest_path, clobbering the library copied above.

    def _clone_tf_repository(self, tensorflow_src_dir, tensorflow_version):
        """Shallow-clone TensorFlow tag ``v<tensorflow_version>``.

        The clone is written to ``tensorflow_src_dir``. Nothing is done if
        that path already exists.

        NOTE: the existence check does not validate the directory's contents,
        so a clone interrupted part-way is reused as-is; delete it to retry.
        """
        if os.path.exists(tensorflow_src_dir):
            return

        tensorflow_repo = "https://github.com/tensorflow/tensorflow.git"
        self.spawn(
            [
                "git",
                "clone",
                "--depth=1",
                "--branch",
                "v" + tensorflow_version,
                tensorflow_repo,
                tensorflow_src_dir,
            ]
        )


setuptools.setup(
    name="tosa-checker",
    version=TOSA_CHECKER_VERSION,
    description="Tool to check if a ML model is compatible with the TOSA specification",
    # Decode the README explicitly as UTF-8. Without an encoding, read_text()
    # uses the locale's preferred encoding, which raises on non-ASCII README
    # content under a C/POSIX locale or a Windows code page.
    long_description=(pathlib.Path(__file__).parent / "README.md").read_text(
        encoding="utf-8"
    ),
    long_description_content_type="text/markdown",
    author="Arm Limited",
    url="https://git.mlplatform.org/tosa/tosa_checker.git/",
    license="Apache-2.0",
    python_requires=">=3.7",
    # Route build_ext through Bazel instead of the setuptools compiler.
    cmdclass={"build_ext": BazelBuildExtension},
    ext_modules=[
        BazelExtensionModule(
            py_module_name="_tosa_checker_wrapper",
            library_name="tosa_checker",
            bazel_target="//tosa_checker:tosa_checker",
            bazel_shared_lib_output="bazel-bin/tosa_checker/_tosa_checker_wrapper.so",
            tensorflow_version=TENSORFLOW_VERSION,
        ),
    ],
    classifiers=[
        "Development Status :: 3 - Alpha",
        "Intended Audience :: Developers",
        "License :: OSI Approved :: Apache Software License",
        "Programming Language :: Python :: 3",
        "Topic :: Utilities",
    ],
)