aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTom Allsop <tom.allsop@arm.com>2023-02-02 15:42:12 +0000
committerTom Allsop <tom.allsop@arm.com>2023-02-02 15:42:12 +0000
commitf756c5179bafd98cc4e912de8c4177cc4c000519 (patch)
treeb089c00461bc677280ee53e212f98db48158b290
parent6f326467592d282a67e289bb2a2c9ef9c3755a70 (diff)
downloadtosa_checker-f756c5179bafd98cc4e912de8c4177cc4c000519.tar.gz
Add support for building nightly packages
* --nightly option is now added to setup.py * TensorFlow version is now recorded in module Change-Id: Id11c8cb1ee91cfe83a93340b7bcba2667dae62f0
-rw-r--r--docker/Dockerfile2
-rw-r--r--setup.py53
2 files changed, 49 insertions, 6 deletions
diff --git a/docker/Dockerfile b/docker/Dockerfile
index 4481565..4a8a68c 100644
--- a/docker/Dockerfile
+++ b/docker/Dockerfile
@@ -16,7 +16,7 @@ ENV PYTHON_BIN_PATH=/usr/local/bin/python
ENV CI_BUILD_PYTHON=/usr/local/bin/python
ENV CROSSTOOL_PYTHON_INCLUDE_PATH=/usr/local/include/python/
-RUN pip install --no-cache-dir setuptools pybind11 numpy twine keyrings.alt
+RUN pip install --no-cache-dir setuptools pybind11 numpy twine keyrings.alt GitPython
COPY install/install_bazel.sh /install/
RUN /install/install_bazel.sh ${BAZEL_VERSION}
diff --git a/setup.py b/setup.py
index 7fc714e..07baede 100644
--- a/setup.py
+++ b/setup.py
@@ -1,16 +1,18 @@
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
import argparse
+import git
import os
import pathlib
import setuptools
import setuptools.command.build_ext
import shutil
import sys
+import time
TOSA_CHECKER_VERSION = "0.1.0"
-TENSORFLOW_VERSION = "2.9.0"
+DEFAULT_TENSORFLOW_VERSION = "2.9.0"
# Get the TensorFlowâ„¢ source directory passed to the command line (if any).
# If none is given the sources are pulled from the official TF repository.
@@ -26,13 +28,43 @@ argparser.add_argument(
)
argparser.add_argument(
"--tosa_checker_copt",
- help="Build tosa_checker with addtional copt (comma separated string)",
+ help="Build tosa_checker with additional copt (comma separated string)",
default="",
required=False
)
+argparser.add_argument(
+ "--nightly",
+ help="Build tosa_checker as a nightly wheel",
+ action="store_true"
+)
args, unknown = argparser.parse_known_args()
sys.argv = [sys.argv[0]] + unknown
+def increment_version(version_string):
+ version_numbers = version_string.split(".")
+ increment = int(version_numbers[1]) + 1
+
+ version_numbers[1] = str(increment)
+
+ return ".".join(version_numbers)
+
+def get_package_version(nightly=False):
+ if nightly:
+ return "{}.dev{}".format(
+ increment_version(TOSA_CHECKER_VERSION),
+ time.strftime("%Y%m%d")
+ )
+ else:
+ return TOSA_CHECKER_VERSION
+
+def get_repo_version(repo_directory):
+ r = git.repo.Repo(repo_directory)
+ tag = r.git.tag('--points-at')
+
+ if tag:
+ return tag
+ else:
+ return r.head.commit.hexsha
class BazelExtensionModule(setuptools.Extension):
def __init__(
@@ -56,6 +88,7 @@ class BazelBuildExtension(setuptools.command.build_ext.build_ext):
def build_extension(self, ext):
tensorflow_src_dir = args.tensorflow_src_dir
+
if not tensorflow_src_dir:
tensorflow_src_dir = os.path.join(self.build_temp, "tensorflow")
self._clone_tf_repository(
@@ -99,7 +132,17 @@ class BazelBuildExtension(setuptools.command.build_ext.build_ext):
os.makedirs(package_dir, exist_ok=True)
shutil.copyfile(ext.bazel_shared_lib_output, shared_lib_dest_path)
- shutil.copy(os.path.join(ext.library_name, "__init__.py"), package_dir)
+
+ # Get the TensorFlow version this is built with
+ tf_version = get_repo_version(tensorflow_src_dir)
+
+ with open(os.path.join(ext.library_name, "__init__.py"), "r") as f:
+ module_init_file = f.read()
+
+ module_init_file += "__tensorflow_version__ = \"{}\"\n".format(tf_version)
+
+ with open(os.path.join(package_dir, "__init__.py"), "w") as f:
+ f.write(module_init_file)
super().build_extension(ext)
@@ -123,7 +166,7 @@ class BazelBuildExtension(setuptools.command.build_ext.build_ext):
setuptools.setup(
name="tosa-checker",
- version=TOSA_CHECKER_VERSION,
+ version=get_package_version(args.nightly),
description="Tool to check if a ML model is compatible with the TOSA specification",
long_description=(pathlib.Path(__file__).parent / "README.md").read_text(),
long_description_content_type="text/markdown",
@@ -138,7 +181,7 @@ setuptools.setup(
library_name="tosa_checker",
bazel_target="//tosa_checker:tosa_checker",
bazel_shared_lib_output="bazel-bin/tosa_checker/_tosa_checker_wrapper.so",
- tensorflow_version=TENSORFLOW_VERSION,
+ tensorflow_version=DEFAULT_TENSORFLOW_VERSION,
),
],
classifiers=[