aboutsummaryrefslogtreecommitdiff
path: root/src/aiet/resources/tools/vela/run_vela.py
blob: 2c1b0bec00257589ff3703143ddf16da7f38e5b5 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Wrapper to only run Vela when the input is not already optimised."""
import shutil
import subprocess
from pathlib import Path
from typing import Tuple

import click

from aiet.cli.common import ModelOptimisedException
from aiet.resources.tools.vela.check_model import check_model


def vela_output_model_path(input_model: str, output_dir: str) -> Path:
    """Build the path of the Vela output file for ``input_model``.

    The file name is the input's stem followed by ``_vela`` and the input's
    own suffix; the result is located directly inside ``output_dir``.
    """
    source = Path(input_model)
    optimised_name = source.stem + "_vela" + source.suffix
    return Path(output_dir).joinpath(optimised_name)


def execute_vela(vela_args: Tuple, output_dir: Path, input_model: str) -> None:
    """Invoke the ``vela`` executable as a child process.

    The command line consists of the forwarded ``vela_args``, followed by
    ``--output-dir`` with ``output_dir``, and finally ``input_model``.
    ``subprocess.run`` is called with ``check=True``, so a non-zero exit
    status raises ``subprocess.CalledProcessError``.
    """
    command = [
        "vela",
        *vela_args,
        # Re-add parsed out_dir to arguments
        "--output-dir",
        str(output_dir),
        input_model,
    ]
    subprocess.run(command, check=True)


@click.command(context_settings=dict(ignore_unknown_options=True))
@click.option(
    "--input-model",
    "-i",
    type=click.Path(exists=True, file_okay=True, readable=True),
    required=True,
)
@click.option("--output-model", "-o", type=click.Path(), required=True)
# Collect the remaining arguments to be directly forwarded to Vela
@click.argument("vela-args", nargs=-1, type=click.UNPROCESSED)
def run_vela(input_model: str, output_model: str, vela_args: Tuple) -> None:
    """Check input, run Vela (if needed) and copy optimised file to destination."""
    # NOTE: the docstring above is kept to a single line on purpose; click
    # uses a command's docstring as its --help text.
    #
    # Flow:
    #   1. check_model() raises ModelOptimisedException when the input is
    #      already Vela-optimised; in that case the input itself is copied.
    #   2. Otherwise Vela is executed, writing into the directory that
    #      contains ``output_model``, and its output file is copied.
    #   3. The selected file is copied to ``output_model``.
    # Any failure terminates the process with a non-zero exit status.
    output_dir = Path(output_model).parent
    try:
        check_model(input_model)  # raises an exception if already Vela-optimised
        execute_vela(vela_args, output_dir, input_model)
        print("Vela optimisation complete.")
        src_model = vela_output_model_path(input_model, str(output_dir))
    except ModelOptimisedException as ex:
        # Input already optimized: copy input file to destination path and return
        print(f"Input already vela-optimised.\n{ex}")
        src_model = Path(input_model)
    except subprocess.CalledProcessError as ex:
        # Propagate Vela's own exit status to the caller.
        print(ex)
        raise SystemExit(ex.returncode) from ex

    try:
        shutil.copyfile(src_model, output_model)
    except OSError as ex:  # also covers shutil.SameFileError (an OSError subclass)
        print(ex)
        # ``errno`` is None for errors raised without an OS error code (e.g.
        # shutil.SameFileError). SystemExit(None) means exit status 0, which
        # would report the failed copy as success, so fall back to 1.
        raise SystemExit(ex.errno or 1) from ex


def main() -> None:
    """Entry point of the run_vela application."""
    # run_vela is a click command, so it is called here without arguments
    # (hence the pylint suppression on the call).
    run_vela()  # pylint: disable=no-value-for-parameter