diff options
author | Nathan Bailey <nathan.bailey@arm.com> | 2024-02-15 14:50:58 +0000 |
---|---|---|
committer | Nathan Bailey <nathan.bailey@arm.com> | 2024-03-14 15:45:40 +0000 |
commit | 0b552d2ae47da4fb9c16d2a59d6ebe12c8307771 (patch) | |
tree | 09b40b939acbe0bcf02dcc77a7ed7ce4aba94322 /tests/test_target_registry.py | |
parent | 09b272be6e88d84a30cb89fb71f3fc3c64d20d2e (diff) | |
download | mlia-0b552d2ae47da4fb9c16d2a59d6ebe12c8307771.tar.gz |
feat: Enable rewrite parameterisation
Enables user to provide a toml or default profile to change training settings for rewrite optimization
Resolves: MLIA-1004
Signed-off-by: Nathan Bailey <nathan.bailey@arm.com>
Change-Id: I3bf9f44b9a2062fb71ef36eb32c9a69edcc48061
Diffstat (limited to 'tests/test_target_registry.py')
-rw-r--r-- | tests/test_target_registry.py | 29 |
1 files changed, 26 insertions, 3 deletions
diff --git a/tests/test_target_registry.py b/tests/test_target_registry.py index ca1ad82..120d0f5 100644 --- a/tests/test_target_registry.py +++ b/tests/test_target_registry.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2024, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Tests for the target registry module.""" from __future__ import annotations @@ -6,9 +6,11 @@ from __future__ import annotations import pytest from mlia.core.common import AdviceCategory -from mlia.target.config import get_builtin_profile_path +from mlia.target.config import get_builtin_optimization_profile_path +from mlia.target.config import get_builtin_target_profile_path from mlia.target.registry import all_supported_backends from mlia.target.registry import default_backends +from mlia.target.registry import get_optimization_profile from mlia.target.registry import is_supported from mlia.target.registry import profile from mlia.target.registry import registry @@ -146,6 +148,27 @@ def test_profile(target_profile: str) -> None: assert target_profile.startswith(cfg.target) # Test loading the file directly - profile_file = get_builtin_profile_path(target_profile) + profile_file = get_builtin_target_profile_path(target_profile) cfg = profile(profile_file) assert target_profile.startswith(cfg.target) + + +@pytest.mark.parametrize("optimization_profile", ["optimization"]) +def test_optimization_profile(optimization_profile: str) -> None: + """Test function optimization_profile().""" + + get_optimization_profile(optimization_profile) + + profile_file = get_builtin_optimization_profile_path(optimization_profile) + get_optimization_profile(profile_file) + + +@pytest.mark.parametrize("optimization_profile", ["non_valid_file"]) +def test_optimization_profile_non_valid_file(optimization_profile: str) -> None: + """Test function optimization_profile().""" + with pytest.raises( + ValueError, + match=f"optimization Profile '{optimization_profile}' is neither " + "a valid built-in optimization profile name or a valid file path.", + ): + get_optimization_profile(optimization_profile) |