diff options
author | Nathan Bailey <nathan.bailey@arm.com> | 2024-02-15 14:50:58 +0000 |
---|---|---|
committer | Nathan Bailey <nathan.bailey@arm.com> | 2024-03-14 15:45:40 +0000 |
commit | 0b552d2ae47da4fb9c16d2a59d6ebe12c8307771 (patch) | |
tree | 09b40b939acbe0bcf02dcc77a7ed7ce4aba94322 /src/mlia/api.py | |
parent | 09b272be6e88d84a30cb89fb71f3fc3c64d20d2e (diff) | |
download | mlia-0b552d2ae47da4fb9c16d2a59d6ebe12c8307771.tar.gz |
feat: Enable rewrite parameterisation
Enables user to provide a toml or default profile to change training settings for rewrite optimization
Resolves: MLIA-1004
Signed-off-by: Nathan Bailey <nathan.bailey@arm.com>
Change-Id: I3bf9f44b9a2062fb71ef36eb32c9a69edcc48061
Diffstat (limited to 'src/mlia/api.py')
-rw-r--r-- | src/mlia/api.py | 10 |
1 files changed, 8 insertions, 2 deletions
diff --git a/src/mlia/api.py b/src/mlia/api.py index 7adae48..3901f56 100644 --- a/src/mlia/api.py +++ b/src/mlia/api.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2024, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Module for the API functions.""" from __future__ import annotations @@ -10,6 +10,7 @@ from typing import Any from mlia.core.advisor import InferenceAdvisor from mlia.core.common import AdviceCategory from mlia.core.context import ExecutionContext +from mlia.target.registry import get_optimization_profile from mlia.target.registry import profile from mlia.target.registry import registry as target_registry @@ -20,6 +21,7 @@ def get_advice( target_profile: str, model: str | Path, category: set[str], + optimization_profile: str | None = None, optimization_targets: list[dict[str, Any]] | None = None, context: ExecutionContext | None = None, backends: list[str] | None = None, @@ -69,9 +71,9 @@ def get_advice( target_profile, model, optimization_targets=optimization_targets, + optimization_profile=optimization_profile, backends=backends, ) - advisor.run(context) @@ -82,6 +84,10 @@ def get_advisor( **extra_args: Any, ) -> InferenceAdvisor: """Find appropriate advisor for the target.""" + if extra_args.get("optimization_profile"): + extra_args["optimization_profile"] = get_optimization_profile( + extra_args["optimization_profile"] + ) target = profile(target_profile).target factory_function = target_registry.items[target].advisor_factory_func return factory_function( |