diff options
Diffstat (limited to 'set_up_default_resources.py')
-rwxr-xr-x | set_up_default_resources.py | 219 |
1 files changed, 140 insertions, 79 deletions
diff --git a/set_up_default_resources.py b/set_up_default_resources.py index f5cd0ac..7ed9e97 100755 --- a/set_up_default_resources.py +++ b/set_up_default_resources.py @@ -1,5 +1,5 @@ #!/usr/bin/env python3 -# SPDX-FileCopyrightText: Copyright 2021-2023 Arm Limited and/or its affiliates <open-source-office@arm.com> +# SPDX-FileCopyrightText: Copyright 2021-2024 Arm Limited and/or its affiliates <open-source-office@arm.com> # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -31,7 +31,6 @@ import urllib.request import venv from argparse import ArgumentParser from argparse import ArgumentTypeError -from collections import namedtuple from dataclasses import dataclass from pathlib import Path from urllib.error import URLError @@ -56,18 +55,29 @@ valid_npu_config_names = [ # Default NPU configurations (these are always run when the models are optimised) default_npu_config_names = [valid_npu_config_names[2], valid_npu_config_names[4]] -# NPU config named tuple -NPUConfig = namedtuple( - "NPUConfig", - [ - "config_name", - "memory_mode", - "system_config", - "ethos_u_npu_id", - "ethos_u_config_id", - "arena_cache_size", - ], -) +# The internal SRAM size for Corstone-300 implementation on MPS3 specified by AN552 +# The internal SRAM size for Corstone-310 implementation on MPS3 specified by AN555 +# is 4MB, but we are content with the 2MB specified below. +MPS3_MAX_SRAM_SZ = 2 * 1024 * 1024 # 2 MiB (2 banks of 1 MiB each) + +default_use_case_resources_path = (Path(__file__).parent.resolve() + / 'scripts' / 'py' / 'use_case_resources.json') + +default_requirements_path = (Path(__file__).parent.resolve() + / 'scripts' / 'py' / 'requirements.txt') + + +@dataclass(frozen=True) +class NpuConfig: + """ + Represent an NPU configuration for Vela + """ + config_name: str + memory_mode: str + system_config: str + ethos_u_npu_id: str + ethos_u_config_id: str + arena_cache_size: str @dataclass(frozen=True) @@ -90,36 +100,70 @@ class UseCase: resources: typing.List[UseCaseResource] -# The internal SRAM size for Corstone-300 implementation on MPS3 specified by AN552 -# The internal SRAM size for Corstone-310 implementation on MPS3 specified by AN555 -# is 4MB, but we are content with the 2MB specified below. -MPS3_MAX_SRAM_SZ = 2 * 1024 * 1024 # 2 MiB (2 banks of 1 MiB each) - - -def load_use_case_resources(current_file_dir: Path) -> typing.List[UseCase]: +@dataclass(frozen=True) +class SetupArgs: + """ + Args used to set up the project. + + Attributes: + run_vela_on_models (bool) : Whether to run Vela on the downloaded models + additional_npu_config_names (list) : List of strings of Ethos-U NPU configs. + use_case_names (list) : List of names of use cases to set up resources for + (default is all). + arena_cache_size (int) : Specifies arena cache size in bytes. If a value + greater than 0 is provided, this will be taken + as the cache size. If 0, the default values, as per + the NPU config requirements, are used. + check_clean_folder (bool) : Indicates whether the resources folder needs to + be checked for updates and cleaned. + additional_requirements_file (str) : Path to a requirements.txt file if + additional packages need to be + installed. + use_case_resources_file (str) : Path to a JSON file containing the use case + metadata resources. + """ + run_vela_on_models: bool = False + additional_npu_config_names: typing.List[str] = () + use_case_names: typing.List[str] = () + arena_cache_size: int = 0 + check_clean_folder: bool = False + additional_requirements_file: Path = "" + use_case_resources_file: Path = "" + + +def load_use_case_resources( + use_case_resources_file: Path, + use_case_names: typing.List[str] = () +) -> typing.List[UseCase]: """ Load use case metadata resources Parameters ---------- - current_file_dir: Directory of the current script - + use_case_resources_file : Path to a JSON file containing the use case + metadata resources. + use_case_names : List of named use cases to restrict + resource loading to. Returns ------- The use cases resources object parsed to a dict """ - resources_path = current_file_dir / "scripts" / "py" / "use_case_resources.json" - with open(resources_path, encoding="utf8") as f: - use_cases = json.load(f) - return [ + with open(use_case_resources_file, encoding="utf8") as f: + parsed_use_cases = json.load(f) + use_cases = ( UseCase( name=u["name"], url_prefix=u["url_prefix"], resources=[UseCaseResource(**r) for r in u["resources"]], ) - for u in use_cases - ] + for u in parsed_use_cases + ) + + if len(use_case_names) == 0: + return list(use_cases) + + return [uc for uc in use_cases if uc.name in use_case_names] def call_command(command: str, verbose: bool = True) -> str: @@ -147,7 +191,7 @@ def call_command(command: str, verbose: bool = True) -> str: def get_default_npu_config_from_name( config_name: str, arena_cache_size: int = 0 -) -> typing.Optional[NPUConfig]: +) -> typing.Optional[NpuConfig]: """ Gets the file suffix for the TFLite file from the `accelerator_config` string. @@ -190,7 +234,7 @@ def get_default_npu_config_from_name( for i, string_id in enumerate(strings_ids): if config_name.startswith(string_id): npu_config_id = config_name.replace(string_id, prefix_ids[i]) - return NPUConfig( + return NpuConfig( config_name=config_name, memory_mode=memory_modes[i], system_config=system_configs[i], @@ -332,7 +376,7 @@ def download_resources( def run_vela( - config: NPUConfig, + config: NpuConfig, env_activate_cmd: str, model: Path, config_file: Path, @@ -555,56 +599,50 @@ def set_up_python_venv( def update_metadata( metadata_dict: typing.Dict, setup_script_hash: str, - json_uc_res: typing.List[UseCase], + use_case_resources: typing.List[UseCase], metadata_file_path: Path ): """ Update the metadata file - @param metadata_dict: The metadata dictionary to update - @param setup_script_hash: The setup script hash - @param json_uc_res: The use case resources metadata - @param metadata_file_path The metadata file path + @param metadata_dict : The metadata dictionary to update + @param setup_script_hash : The setup script hash + @param use_case_resources : The use case resources metadata + @param metadata_file_path : The metadata file path """ metadata_dict["ethosu_vela_version"] = VELA_VERSION metadata_dict["set_up_script_md5sum"] = setup_script_hash.strip("\n") - metadata_dict["resources_info"] = [dataclasses.asdict(uc) for uc in json_uc_res] + metadata_dict["resources_info"] = [dataclasses.asdict(uc) for uc in use_case_resources] with open(metadata_file_path, "w", encoding="utf8") as metadata_file: json.dump(metadata_dict, metadata_file, indent=4) -def set_up_resources( - run_vela_on_models: bool = False, - additional_npu_config_names: tuple = (), - arena_cache_size: int = 0, - check_clean_folder: bool = False, - additional_requirements_file: Path = "" -) -> Path: +def get_default_use_cases_names() -> typing.List[str]: + """ + Get the names of the default use cases + + :return : List of use case names as strings + """ + use_case_resources = load_use_case_resources(default_use_case_resources_path) + return [uc.name for uc in use_case_resources] + + +def set_up_resources(args: SetupArgs) -> Path: """ Helpers function that retrieve the output from a command. Parameters: ---------- - run_vela_on_models (bool): Specifies if run vela on downloaded models. - additional_npu_config_names(list): list of strings of Ethos-U NPU configs. - arena_cache_size (int): Specifies arena cache size in bytes. If a value - greater than 0 is provided, this will be taken - as the cache size. If 0, the default values, as per - the NPU config requirements, are used. - check_clean_folder (bool): Indicates whether the resources folder needs to - be checked for updates and cleaned. - additional_requirements_file (str): Path to a requirements.txt file if - additional packages need to be - installed. + args (SetupArgs) : Arguments used to set up the project. Returns ------- - Tuple of pair of Paths: (download_directory_path, virtual_env_path) + Tuple of pairs of Paths: (download_directory_path, virtual_env_path) - download_directory_path: Root of the directory where the resources have been downloaded to. - virtual_env_path: Path to the root of virtual environment. + download_directory_path : Root of the directory where the resources have been downloaded to. + virtual_env_path : Path to the root of virtual environment. """ # Paths. current_file_dir = Path(__file__).parent.resolve() @@ -619,29 +657,32 @@ def set_up_resources( ) logging.info("Using Python version: %s", sys.version_info) - json_uc_res = load_use_case_resources(current_file_dir) + use_case_resources = load_use_case_resources( + args.use_case_resources_file, + args.use_case_names + ) setup_script_hash = get_md5sum_for_file(Path(__file__).resolve()) metadata_dict, setup_script_hash_verified = initialize_resources_directory( download_dir, - check_clean_folder, + args.check_clean_folder, metadata_file_path, setup_script_hash ) env_path, env_activate = set_up_python_venv( download_dir, - additional_requirements_file + args.additional_requirements_file ) # 2. Download models logging.info("Downloading resources.") - for use_case in json_uc_res: + for use_case in use_case_resources: download_resources( use_case, metadata_dict, download_dir, - check_clean_folder, + args.check_clean_folder, setup_script_hash_verified ) @@ -653,14 +694,16 @@ def set_up_resources( # # Note: To avoid to run vela twice on the same model, it's supposed that # downloaded model names don't contain the 'vela' word. - if run_vela_on_models is True: + if args.run_vela_on_models is True: # Consolidate all config names while discarding duplicates: run_vela_on_all_models( current_file_dir, download_dir, env_activate, - arena_cache_size, - npu_config_names=list(set(default_npu_config_names + list(additional_npu_config_names))) + args.arena_cache_size, + npu_config_names=list( + set(default_npu_config_names + list(args.additional_npu_config_names)) + ) ) # 4. Collect and write metadata @@ -668,7 +711,7 @@ def set_up_resources( update_metadata( metadata_dict, setup_script_hash.strip("\n"), - json_uc_res, + use_case_resources, metadata_file_path ) @@ -690,6 +733,14 @@ if __name__ == "__main__": action="append", ) parser.add_argument( + "--use-case", + help=f"""Only set up resources for the specified use case (can specify multiple times). + Valid values are: {get_default_use_cases_names()} + """, + default=[], + action="append", + ) + parser.add_argument( "--arena-cache-size", help="Arena cache size in bytes (if overriding the defaults)", type=int, @@ -704,24 +755,34 @@ if __name__ == "__main__": "--requirements-file", help="Path to requirements.txt file to install additional packages", type=str, - default=Path(__file__).parent.resolve() / 'scripts' / 'py' / 'requirements.txt' + default=default_requirements_path + ) + parser.add_argument( + "--use-case-resources-file", + help="Path to the use case resources file", + type=str, + default=default_use_case_resources_path ) - args = parser.parse_args() + parsed_args = parser.parse_args() - if args.arena_cache_size < 0: + if parsed_args.arena_cache_size < 0: raise ArgumentTypeError("Arena cache size cannot not be less than 0") - if not Path(args.requirements_file).is_file(): - raise ArgumentTypeError(f"Invalid requirements file: {args.requirements_file}") + if not Path(parsed_args.requirements_file).is_file(): + raise ArgumentTypeError(f"Invalid requirements file: {parsed_args.requirements_file}") logging.basicConfig(filename="log_build_default.log", level=logging.DEBUG) logging.getLogger().addHandler(logging.StreamHandler(sys.stdout)) - set_up_resources( - not args.skip_vela, - args.additional_ethos_u_config_name, - args.arena_cache_size, - args.clean, - args.requirements_file, + setup_args = SetupArgs( + run_vela_on_models=not parsed_args.skip_vela, + additional_npu_config_names=parsed_args.additional_ethos_u_config_name, + use_case_names=parsed_args.use_case, + arena_cache_size=parsed_args.arena_cache_size, + check_clean_folder=parsed_args.clean, + additional_requirements_file=parsed_args.requirements_file, + use_case_resources_file=parsed_args.use_case_resources_file, ) + + set_up_resources(setup_args) |