1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
|
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Wrapper to only run Vela when the input is not already optimised."""
import shutil
import subprocess
from pathlib import Path
from typing import Tuple
import click
from aiet.cli.common import ModelOptimisedException
from aiet.resources.tools.vela.check_model import check_model
def vela_output_model_path(input_model: str, output_dir: str) -> Path:
    """Build the path of the file Vela produces for ``input_model``.

    The file name is the input's stem followed by ``_vela`` and the input's
    original suffix; the file is located inside ``output_dir``.
    """
    source = Path(input_model)
    vela_file_name = f"{source.stem}_vela{source.suffix}"
    return Path(output_dir).joinpath(vela_file_name)
def execute_vela(vela_args: Tuple, output_dir: Path, input_model: str) -> None:
    """Run the ``vela`` executable as a child process.

    The command line is ``vela``, then the forwarded ``vela_args``, then
    ``--output-dir <output_dir>``, and finally ``input_model``.

    Raises:
        subprocess.CalledProcessError: if the process exits with a non-zero
            status (``check=True``).
    """
    command = [
        "vela",
        *vela_args,
        # Re-add parsed out_dir to arguments
        "--output-dir",
        str(output_dir),
        input_model,
    ]
    subprocess.run(command, check=True)
@click.command(context_settings=dict(ignore_unknown_options=True))
@click.option(
    "--input-model",
    "-i",
    type=click.Path(exists=True, file_okay=True, readable=True),
    required=True,
)
@click.option("--output-model", "-o", type=click.Path(), required=True)
# Collect the remaining arguments to be directly forwarded to Vela
@click.argument("vela-args", nargs=-1, type=click.UNPROCESSED)
def run_vela(input_model: str, output_model: str, vela_args: Tuple) -> None:
    """Check input, run Vela (if needed) and copy optimised file to destination."""
    # NOTE: click uses the docstring above as the command's --help text, so
    # implementation notes are kept in comments rather than in the docstring.
    #
    # Flow:
    #   1. check_model() raises ModelOptimisedException when the input is
    #      already Vela-optimised; in that case the input itself is copied.
    #   2. Otherwise Vela runs with the output file's directory as its
    #      --output-dir, and the file it produced there is copied.
    #   3. The selected file is copied to ``output_model``.
    # Failures terminate the command through SystemExit.
    output_dir = Path(output_model).parent

    try:
        check_model(input_model)  # raises an exception if already Vela-optimised
        execute_vela(vela_args, output_dir, input_model)
        print("Vela optimisation complete.")
        src_model = vela_output_model_path(input_model, str(output_dir))
    except ModelOptimisedException as ex:
        # Input already optimized: copy input file to destination path and return
        print(f"Input already vela-optimised.\n{ex}")
        src_model = Path(input_model)
    except subprocess.CalledProcessError as ex:
        # Propagate Vela's own exit status to the caller of this command.
        print(ex)
        raise SystemExit(ex.returncode) from ex

    try:
        shutil.copyfile(src_model, output_model)
    except shutil.SameFileError as ex:
        # Source and destination are one and the same file, so the model is
        # already at the requested location. This case has always ended with
        # exit status 0 (SameFileError carries no errno); keep that explicit.
        print(ex)
        raise SystemExit(0) from ex
    except OSError as ex:
        print(ex)
        # Not every OSError carries an errno (e.g. shutil.SpecialFileError);
        # fall back to 1 so that a failed copy never reports success.
        raise SystemExit(ex.errno or 1) from ex
def main() -> None:
    """Entry point of the Vela wrapper command-line application."""
    # The click command takes its parameters from the command line, which is
    # why the call below passes none and pylint's check is silenced.
    run_vela()  # pylint: disable=no-value-for-parameter
|