From a8ee1aee3e674c78a77801d1bf2256881ab6b4b9 Mon Sep 17 00:00:00 2001 From: Dmitrii Agibov Date: Thu, 21 Jul 2022 14:06:03 +0100 Subject: MLIA-549 Refactor API module to support several target profiles - Move target specific details out of API module - Move common logic for workflow event handler into a separate class Change-Id: Ic4a22657b722af1c1fead1d478f606ac57325788 --- tests/test_api.py | 17 +++++++++++++---- tests/test_core_events.py | 2 +- tests/test_core_reporting.py | 22 ++++++++++++++++++++++ tests/test_devices_ethosu_reporters.py | 4 ++-- 4 files changed, 38 insertions(+), 7 deletions(-) (limited to 'tests') diff --git a/tests/test_api.py b/tests/test_api.py index 09bc509..e8df7af 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -7,14 +7,16 @@ from unittest.mock import MagicMock import pytest from mlia.api import get_advice +from mlia.api import get_advisor from mlia.core.common import AdviceCategory from mlia.core.context import Context from mlia.core.context import ExecutionContext +from mlia.devices.ethosu.advisor import EthosUInferenceAdvisor def test_get_advice_no_target_provided(test_keras_model: Path) -> None: """Test getting advice when no target provided.""" - with pytest.raises(Exception, match="Target is not provided"): + with pytest.raises(Exception, match="Target profile is not provided"): get_advice(None, test_keras_model, "all") # type: ignore @@ -78,7 +80,7 @@ def test_get_advice( ) -> None: """Test getting advice with valid parameters.""" advisor_mock = MagicMock() - monkeypatch.setattr("mlia.api._get_advisor", MagicMock(return_value=advisor_mock)) + monkeypatch.setattr("mlia.api.get_advisor", MagicMock(return_value=advisor_mock)) get_advice( "ethos-u55-256", @@ -92,5 +94,12 @@ def test_get_advice( assert isinstance(context, Context) assert context.advice_category == expected_category - assert context.event_handlers is not None - assert context.config_parameters is not None + +def test_get_advisor( + test_keras_model: Path, +) -> None: + """Test function for getting the advisor.""" + ethos_u55_advisor = get_advisor( + ExecutionContext(), "ethos-u55-256", str(test_keras_model) + ) + assert isinstance(ethos_u55_advisor, EthosUInferenceAdvisor) diff --git a/tests/test_core_events.py b/tests/test_core_events.py index faaab7c..a531bab 100644 --- a/tests/test_core_events.py +++ b/tests/test_core_events.py @@ -18,7 +18,7 @@ from mlia.core.events import EventHandler from mlia.core.events import ExecutionFinishedEvent from mlia.core.events import ExecutionStartedEvent from mlia.core.events import stage -from mlia.core.events import SystemEventsHandler +from mlia.core.handlers import SystemEventsHandler @dataclass diff --git a/tests/test_core_reporting.py b/tests/test_core_reporting.py index 2f7ec22..d7a6ade 100644 --- a/tests/test_core_reporting.py +++ b/tests/test_core_reporting.py @@ -2,9 +2,12 @@ # SPDX-License-Identifier: Apache-2.0 """Tests for reporting module.""" from typing import List +from typing import Optional import pytest +from mlia.core._typing import OutputFormat +from mlia.core._typing import PathOrFileLike from mlia.core.reporting import BytesCell from mlia.core.reporting import Cell from mlia.core.reporting import ClockCell @@ -13,6 +16,7 @@ from mlia.core.reporting import CyclesCell from mlia.core.reporting import Format from mlia.core.reporting import NestedReport from mlia.core.reporting import ReportItem +from mlia.core.reporting import resolve_output_format from mlia.core.reporting import SingleRow from mlia.core.reporting import Table from mlia.utils.console import remove_ascii_codes @@ -411,3 +415,21 @@ Single row example: alias="simple_row_example", ) wrong_single_row.to_plain_text() + + +@pytest.mark.parametrize( + "output, expected_output_format", + [ + [None, "plain_text"], + ["", "plain_text"], + ["some_file", "plain_text"], + ["some_format.some_ext", "plain_text"], + ["output.csv", "csv"], + ["output.json", "json"], + ], +) +def test_resolve_output_format( + output: Optional[PathOrFileLike], expected_output_format: OutputFormat +) -> None: + """Test function resolve_output_format.""" + assert resolve_output_format(output) == expected_output_format diff --git a/tests/test_devices_ethosu_reporters.py b/tests/test_devices_ethosu_reporters.py index 0da50e0..a63db1c 100644 --- a/tests/test_devices_ethosu_reporters.py +++ b/tests/test_devices_ethosu_reporters.py @@ -22,7 +22,7 @@ from mlia.devices.ethosu.config import EthosUConfiguration from mlia.devices.ethosu.performance import MemoryUsage from mlia.devices.ethosu.performance import NPUCycles from mlia.devices.ethosu.performance import PerformanceMetrics -from mlia.devices.ethosu.reporters import find_appropriate_formatter +from mlia.devices.ethosu.reporters import ethos_u_formatters from mlia.devices.ethosu.reporters import report_device_details from mlia.devices.ethosu.reporters import report_operators from mlia.devices.ethosu.reporters import report_perf_metrics @@ -410,7 +410,7 @@ def test_get_reporter(tmp_path: Path) -> None: ) output = tmp_path / "output.json" - with get_reporter("json", output, find_appropriate_formatter) as reporter: + with get_reporter("json", output, ethos_u_formatters) as reporter: assert isinstance(reporter, Reporter) with pytest.raises( -- cgit v1.2.1