diff options
Diffstat (limited to 'python/pyarmnn/setup.py')
-rwxr-xr-x[-rw-r--r--] | python/pyarmnn/setup.py | 147 |
1 files changed, 109 insertions, 38 deletions
diff --git a/python/pyarmnn/setup.py b/python/pyarmnn/setup.py index 5f81088150..1c84e6152a 100644..100755 --- a/python/pyarmnn/setup.py +++ b/python/pyarmnn/setup.py @@ -1,8 +1,18 @@ +#!/usr/bin/env python3 # Copyright © 2020 Arm Ltd. All rights reserved. +# Copyright 2020 NXP # SPDX-License-Identifier: MIT +"""Python bindings for Arm NN + +PyArmNN is a python extension for Arm NN SDK providing an interface similar to Arm NN C++ API. +""" +__version__ = None +__arm_ml_version__ = None + import logging import os import sys +import subprocess from functools import lru_cache from pathlib import Path from itertools import chain @@ -14,20 +24,21 @@ from setuptools.command.build_ext import build_ext logger = logging.Logger(__name__) -__version__ = None -__arm_ml_version__ = None +DOCLINES = __doc__.split("\n") +LIB_ENV_NAME = "ARMNN_LIB" +INCLUDE_ENV_NAME = "ARMNN_INCLUDE" def check_armnn_version(*args): pass +__current_dir = os.path.dirname(os.path.realpath(__file__)) -exec(open(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'src', 'pyarmnn', '_version.py')).read()) +exec(open(os.path.join(__current_dir, 'src', 'pyarmnn', '_version.py')).read()) class ExtensionPriorityBuilder(build_py): - """ - Runs extension builder before other stages. Otherwise generated files are not included to the distribution. + """Runs extension builder before other stages. Otherwise generated files are not included to the distribution. """ def run(self): @@ -36,6 +47,8 @@ class ExtensionPriorityBuilder(build_py): class ArmnnVersionCheckerExtBuilder(build_ext): + """Builds an extension (i.e. wrapper). Additionally checks for version. + """ def __init__(self, dist): super().__init__(dist) @@ -60,49 +73,84 @@ class ArmnnVersionCheckerExtBuilder(build_ext): super().copy_extensions_to_source() -def linux_gcc_lib_search(): +def linux_gcc_name(): + """Returns the name of the `gcc` compiler. Might happen that we are cross-compiling and the + compiler has a longer name. + + Args: + None + + Returns: + str: Name of the `gcc` compiler or None """ - Calls the `gcc` to get linker default system paths. + cc_env = os.getenv('CC') + if cc_env is not None: + if subprocess.Popen([cc_env, "--version"], stdout=subprocess.DEVNULL): + return cc_env + return "gcc" if subprocess.Popen(["gcc", "--version"], stdout=subprocess.DEVNULL) else None + + +def linux_gcc_lib_search(gcc_compiler_name: str = linux_gcc_name()): + """Calls the `gcc` to get linker default system paths. + + Args: + gcc_compiler_name(str): Name of the GCC compiler + Returns: - list of paths + list: A list of paths. + + Raises: + RuntimeError: If unable to find GCC. """ - cmd = 'gcc --print-search-dirs | grep libraries' - cmd_res = os.popen(cmd).read() - cmd_res = cmd_res.split('=') - if len(cmd_res) > 1: - return tuple(cmd_res[1].split(':')) - return None + if gcc_compiler_name is None: + raise RuntimeError("Unable to find gcc compiler") + cmd1 = subprocess.Popen([gcc_compiler_name, "--print-search-dirs"], stdout=subprocess.PIPE) + cmd2 = subprocess.Popen(["grep", "libraries"], stdin=cmd1.stdout, + stdout=subprocess.PIPE, stderr=subprocess.DEVNULL) + cmd1.stdout.close() + out, _ = cmd2.communicate() + out = out.decode("utf-8").split('=') + return tuple(out[1].split(':')) if len(out) > 0 else None + +def find_includes(armnn_include_env: str = INCLUDE_ENV_NAME): + """Searches for ArmNN includes. + + Args: + armnn_include_env(str): Environmental variable to use as path. -def find_includes(armnn_include_env: str = 'ARMNN_INCLUDE'): - armnn_include_path = os.getenv(armnn_include_env, '') - return [armnn_include_path] if armnn_include_path else ['/usr/local/include', '/usr/include'] + Returns: + list: A list of paths to include. + """ + armnn_include_path = os.getenv(armnn_include_env) + if armnn_include_path is not None and os.path.exists(armnn_include_path): + armnn_include_path = [armnn_include_path] + else: + armnn_include_path = ['/usr/local/include', '/usr/include'] + return armnn_include_path @lru_cache(maxsize=1) def find_armnn(lib_name: str, optional: bool = False, - armnn_libs_env: str = 'ARMNN_LIB', + armnn_libs_env: str = LIB_ENV_NAME, default_lib_search: tuple = linux_gcc_lib_search()): - """ - Searches for ArmNN installation on the local machine. + """Searches for ArmNN installation on the local machine. Args: - lib_name: lib name to find - optional: Do not fail if optional. Default is False - fail if library was not found. - armnn_include_env: custom environment variable pointing to ArmNN headers, default is 'ARMNN_INCLUDE' - armnn_libs_env: custom environment variable pointing to ArmNN libraries location, default is 'ARMNN_LIBS' - default_lib_search: list of paths to search for ArmNN if not found within path provided by 'ARMNN_LIBS' + lib_name(str): Lib name to find. + optional(bool): Do not fail if optional. Default is False - fail if library was not found. + armnn_libs_env(str): Custom environment variable pointing to ArmNN libraries location, default is 'ARMNN_LIBS' + default_lib_search(tuple): list of paths to search for ArmNN if not found within path provided by 'ARMNN_LIBS' env variable - Returns: - tuple containing name of the armnn libs, paths to the libs - """ - - armnn_lib_path = os.getenv(armnn_libs_env, "") - - lib_search = [armnn_lib_path] if armnn_lib_path else default_lib_search + tuple: Contains name of the armnn libs, paths to the libs. + Raises: + RuntimeError: If armnn libs are not found. + """ + armnn_lib_path = os.getenv(armnn_libs_env) + lib_search = [armnn_lib_path] if armnn_lib_path is not None else default_lib_search armnn_libs = dict(map(lambda path: (':{}'.format(path.name), path), chain.from_iterable(map(lambda lib_path: Path(lib_path).glob(lib_name), lib_search)))) @@ -117,8 +165,7 @@ def find_armnn(lib_name: str, class LazyArmnnFinderExtension(Extension): - """ - Derived from `Extension` this class adds ArmNN libraries search on the user's machine. + """Derived from `Extension` this class adds ArmNN libraries search on the user's machine. SWIG options and compilation flags are updated with relevant ArmNN libraries files locations (-L) and headers (-I). Search for ArmNN is executed only when attributes include_dirs, library_dirs, runtime_library_dirs, libraries or @@ -195,6 +242,7 @@ class LazyArmnnFinderExtension(Extension): def __hash__(self): return self.name.__hash__() + if __name__ == '__main__': # mandatory extensions pyarmnn_module = LazyArmnnFinderExtension('pyarmnn._generated._pyarmnn', @@ -232,11 +280,30 @@ if __name__ == '__main__': setup( name='pyarmnn', version=__version__, - author='Arm ltd', + author='Arm Ltd, NXP Semiconductors', author_email='support@linaro.org', - description='Arm NN python wrapper', - url='https://www.arm.com', + description=DOCLINES[0], + long_description="\n".join(DOCLINES[2:]), + url='https://mlplatform.org/', license='MIT', + keywords='armnn neural network machine learning', + classifiers=[ + 'Development Status :: 3 - Alpha', + 'Intended Audience :: Developers', + 'Intended Audience :: Education', + 'Intended Audience :: Science/Research', + 'License :: OSI Approved :: MIT License', + 'Programming Language :: Python :: 3', + 'Programming Language :: Python :: 3 :: Only', + 'Programming Language :: Python :: 3.6', + 'Programming Language :: Python :: 3.7', + 'Programming Language :: Python :: 3.8', + 'Topic :: Scientific/Engineering', + 'Topic :: Scientific/Engineering :: Artificial Intelligence', + 'Topic :: Software Development', + 'Topic :: Software Development :: Libraries', + 'Topic :: Software Development :: Libraries :: Python Modules', + ], package_dir={'': 'src'}, packages=[ 'pyarmnn', @@ -245,8 +312,12 @@ if __name__ == '__main__': 'pyarmnn._tensor', 'pyarmnn._utilities' ], + data_files=[('', ['LICENSE'])], python_requires='>=3.5', install_requires=['numpy'], - cmdclass={'build_py': ExtensionPriorityBuilder, 'build_ext': ArmnnVersionCheckerExtBuilder}, + cmdclass={ + 'build_py': ExtensionPriorityBuilder, + 'build_ext': ArmnnVersionCheckerExtBuilder + }, ext_modules=extensions_to_build ) |