aboutsummaryrefslogtreecommitdiff
path: root/src/mlia/api.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/mlia/api.py')
-rw-r--r--src/mlia/api.py10
1 files changed, 8 insertions, 2 deletions
diff --git a/src/mlia/api.py b/src/mlia/api.py
index 7adae48..3901f56 100644
--- a/src/mlia/api.py
+++ b/src/mlia/api.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2024, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Module for the API functions."""
from __future__ import annotations
@@ -10,6 +10,7 @@ from typing import Any
from mlia.core.advisor import InferenceAdvisor
from mlia.core.common import AdviceCategory
from mlia.core.context import ExecutionContext
+from mlia.target.registry import get_optimization_profile
from mlia.target.registry import profile
from mlia.target.registry import registry as target_registry
@@ -20,6 +21,7 @@ def get_advice(
target_profile: str,
model: str | Path,
category: set[str],
+ optimization_profile: str | None = None,
optimization_targets: list[dict[str, Any]] | None = None,
context: ExecutionContext | None = None,
backends: list[str] | None = None,
@@ -69,9 +71,9 @@ def get_advice(
target_profile,
model,
optimization_targets=optimization_targets,
+ optimization_profile=optimization_profile,
backends=backends,
)
-
advisor.run(context)
@@ -82,6 +84,10 @@ def get_advisor(
**extra_args: Any,
) -> InferenceAdvisor:
"""Find appropriate advisor for the target."""
+ if extra_args.get("optimization_profile"):
+ extra_args["optimization_profile"] = get_optimization_profile(
+ extra_args["optimization_profile"]
+ )
target = profile(target_profile).target
factory_function = target_registry.items[target].advisor_factory_func
return factory_function(