Diffstat (limited to 'tests')
-rw-r--r--  tests/test_api.py                                                    |   6
-rw-r--r--  tests/test_backend_tosa_compat.py                                    |   4
-rw-r--r--  tests/test_cli_main.py                                               |  22
-rw-r--r--  tests/test_cli_options.py                                            |  60
-rw-r--r--  tests/test_core_context.py                                           |   9
-rw-r--r--  tests/test_core_logging.py (renamed from tests/test_cli_logging.py)  |  17
-rw-r--r--  tests/test_core_reporting.py                                         | 105
-rw-r--r--  tests/test_target_ethos_u_reporters.py                               | 124
8 files changed, 138 insertions(+), 209 deletions(-)
diff --git a/tests/test_api.py b/tests/test_api.py
index 0bbc3ae..251d5ac 100644
--- a/tests/test_api.py
+++ b/tests/test_api.py
@@ -20,7 +20,11 @@ from mlia.target.tosa.advisor import TOSAInferenceAdvisor
def test_get_advice_no_target_provided(test_keras_model: Path) -> None:
"""Test getting advice when no target provided."""
with pytest.raises(Exception, match="Target profile is not provided"):
- get_advice(None, test_keras_model, {"compatibility"}) # type: ignore
+ get_advice(
+ None, # type:ignore
+ test_keras_model,
+ {"compatibility"},
+ )
def test_get_advice_wrong_category(test_keras_model: Path) -> None:
diff --git a/tests/test_backend_tosa_compat.py b/tests/test_backend_tosa_compat.py
index 5a80b4b..0b6eaf5 100644
--- a/tests/test_backend_tosa_compat.py
+++ b/tests/test_backend_tosa_compat.py
@@ -27,7 +27,7 @@ def replace_get_tosa_checker_with_mock(
def test_compatibility_check_should_fail_if_checker_not_available(
- monkeypatch: pytest.MonkeyPatch, test_tflite_model: Path
+ monkeypatch: pytest.MonkeyPatch, test_tflite_model: str | Path
) -> None:
"""Test that compatibility check should fail if TOSA checker is not available."""
replace_get_tosa_checker_with_mock(monkeypatch, None)
@@ -71,7 +71,7 @@ def test_compatibility_check_should_fail_if_checker_not_available(
)
def test_get_tosa_compatibility_info(
monkeypatch: pytest.MonkeyPatch,
- test_tflite_model: Path,
+ test_tflite_model: str | Path,
is_tosa_compatible: bool,
operators: Any,
exception: Exception | None,
diff --git a/tests/test_cli_main.py b/tests/test_cli_main.py
index 5a9c0c9..9db5341 100644
--- a/tests/test_cli_main.py
+++ b/tests/test_cli_main.py
@@ -5,7 +5,6 @@ from __future__ import annotations
import argparse
from functools import wraps
-from pathlib import Path
from typing import Any
from typing import Callable
from unittest.mock import ANY
@@ -122,8 +121,6 @@ def wrap_mock_command(mock: MagicMock, command: Callable) -> Callable:
model="sample_model.tflite",
compatibility=False,
performance=False,
- output=None,
- json=False,
backend=None,
),
],
@@ -135,8 +132,6 @@ def wrap_mock_command(mock: MagicMock, command: Callable) -> Callable:
model="sample_model.tflite",
compatibility=False,
performance=False,
- output=None,
- json=False,
backend=None,
),
],
@@ -153,8 +148,6 @@ def wrap_mock_command(mock: MagicMock, command: Callable) -> Callable:
ctx=ANY,
target_profile="ethos-u55-256",
model="sample_model.h5",
- output=None,
- json=False,
compatibility=True,
performance=True,
backend=None,
@@ -167,9 +160,6 @@ def wrap_mock_command(mock: MagicMock, command: Callable) -> Callable:
"--performance",
"--target-profile",
"ethos-u55-256",
- "--output",
- "result.json",
- "--json",
],
call(
ctx=ANY,
@@ -177,8 +167,6 @@ def wrap_mock_command(mock: MagicMock, command: Callable) -> Callable:
model="sample_model.h5",
performance=True,
compatibility=False,
- output=Path("result.json"),
- json=True,
backend=None,
),
],
@@ -196,8 +184,6 @@ def wrap_mock_command(mock: MagicMock, command: Callable) -> Callable:
model="sample_model.h5",
compatibility=False,
performance=True,
- output=None,
- json=False,
backend=None,
),
],
@@ -218,8 +204,6 @@ def wrap_mock_command(mock: MagicMock, command: Callable) -> Callable:
clustering=True,
pruning_target=None,
clustering_target=None,
- output=None,
- json=False,
backend=None,
),
],
@@ -244,8 +228,6 @@ def wrap_mock_command(mock: MagicMock, command: Callable) -> Callable:
clustering=True,
pruning_target=0.5,
clustering_target=32,
- output=None,
- json=False,
backend=None,
),
],
@@ -267,8 +249,6 @@ def wrap_mock_command(mock: MagicMock, command: Callable) -> Callable:
clustering=False,
pruning_target=None,
clustering_target=None,
- output=None,
- json=False,
backend=["some_backend"],
),
],
@@ -286,8 +266,6 @@ def wrap_mock_command(mock: MagicMock, command: Callable) -> Callable:
model="sample_model.h5",
compatibility=True,
performance=False,
- output=None,
- json=False,
backend=None,
),
],
diff --git a/tests/test_cli_options.py b/tests/test_cli_options.py
index a889a93..94c3111 100644
--- a/tests/test_cli_options.py
+++ b/tests/test_cli_options.py
@@ -5,16 +5,14 @@ from __future__ import annotations
import argparse
from contextlib import ExitStack as does_not_raise
-from pathlib import Path
from typing import Any
import pytest
-from mlia.cli.options import add_output_options
+from mlia.cli.options import get_output_format
from mlia.cli.options import get_target_profile_opts
from mlia.cli.options import parse_optimization_parameters
-from mlia.cli.options import parse_output_parameters
-from mlia.core.common import FormattedFilePath
+from mlia.core.typing import OutputFormat
@pytest.mark.parametrize(
@@ -164,54 +162,24 @@ def test_get_target_opts(args: dict | None, expected_opts: list[str]) -> None:
@pytest.mark.parametrize(
- "output_parameters, expected_path",
- [
- [["--output", "report.json"], "report.json"],
- [["--output", "REPORT.JSON"], "REPORT.JSON"],
- [["--output", "some_folder/report.json"], "some_folder/report.json"],
- ],
-)
-def test_output_options(output_parameters: list[str], expected_path: str) -> None:
- """Test output options resolving."""
- parser = argparse.ArgumentParser()
- add_output_options(parser)
-
- args = parser.parse_args(output_parameters)
- assert str(args.output) == expected_path
-
-
-@pytest.mark.parametrize(
- "path, json, expected_error, output",
+ "args, expected_output_format",
[
[
- None,
- True,
- pytest.raises(
- argparse.ArgumentError,
- match=r"To enable JSON output you need to specify the output path. "
- r"\(e.g. --output out.json --json\)",
- ),
- None,
+ {},
+ "plain_text",
],
- [None, False, does_not_raise(), None],
[
- Path("test_path"),
- False,
- does_not_raise(),
- FormattedFilePath(Path("test_path"), "plain_text"),
+ {"json": True},
+ "json",
],
[
- Path("test_path"),
- True,
- does_not_raise(),
- FormattedFilePath(Path("test_path"), "json"),
+ {"json": False},
+ "plain_text",
],
],
)
-def test_parse_output_parameters(
- path: Path | None, json: bool, expected_error: Any, output: FormattedFilePath | None
-) -> None:
- """Test parsing for output parameters."""
- with expected_error:
- formatted_output = parse_output_parameters(path, json)
- assert formatted_output == output
+def test_get_output_format(args: dict, expected_output_format: OutputFormat) -> None:
+ """Test get_output_format function."""
+ arguments = argparse.Namespace(**args)
+ output_format = get_output_format(arguments)
+ assert output_format == expected_output_format
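Judging by the parametrized cases above, get_output_format reads a boolean json attribute from the parsed arguments and falls back to plain_text when it is absent. A minimal usage sketch under that assumption (the --json flag wiring shown here is illustrative, not part of this diff):

import argparse

from mlia.cli.options import get_output_format

parser = argparse.ArgumentParser()
# Illustrative flag wiring; the tests above only exercise the resolution step.
parser.add_argument("--json", action="store_true")

args = parser.parse_args(["--json"])
print(get_output_format(args))  # "json"; "plain_text" when --json is omitted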
diff --git a/tests/test_core_context.py b/tests/test_core_context.py
index dcdbef3..0e7145f 100644
--- a/tests/test_core_context.py
+++ b/tests/test_core_context.py
@@ -58,6 +58,7 @@ def test_execution_context(tmpdir: str) -> None:
verbose=True,
logs_dir="logs_directory",
models_dir="models_directory",
+ output_format="json",
)
assert context.advice_category == category
@@ -68,12 +69,14 @@ def test_execution_context(tmpdir: str) -> None:
expected_model_path = Path(tmpdir) / "models_directory/sample.model"
assert context.get_model_path("sample.model") == expected_model_path
assert context.verbose is True
+ assert context.output_format == "json"
assert str(context) == (
f"ExecutionContext: "
f"working_dir={tmpdir}, "
"advice_category={'COMPATIBILITY'}, "
"config_parameters={'param': 'value'}, "
- "verbose=True"
+ "verbose=True, "
+ "output_format=json"
)
context_with_default_params = ExecutionContext(working_dir=tmpdir)
@@ -88,11 +91,13 @@ def test_execution_context(tmpdir: str) -> None:
default_model_path = context_with_default_params.get_model_path("sample.model")
expected_default_model_path = Path(tmpdir) / "models/sample.model"
assert default_model_path == expected_default_model_path
+ assert context_with_default_params.output_format == "plain_text"
expected_str = (
f"ExecutionContext: working_dir={tmpdir}, "
"advice_category={'COMPATIBILITY'}, "
"config_parameters=None, "
- "verbose=False"
+ "verbose=False, "
+ "output_format=plain_text"
)
assert str(context_with_default_params) == expected_str
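The assertions above indicate that ExecutionContext now accepts an output_format argument, defaults it to plain_text, and includes it in the string representation. A minimal sketch, assuming ExecutionContext is imported from mlia.core.context as the test subject suggests:

from mlia.core.context import ExecutionContext

ctx = ExecutionContext(working_dir="/tmp/mlia", output_format="json")
assert ctx.output_format == "json"

default_ctx = ExecutionContext(working_dir="/tmp/mlia")
assert default_ctx.output_format == "plain_text"  # default, per the test above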
diff --git a/tests/test_cli_logging.py b/tests/test_core_logging.py
index 1e2cc85..e021e26 100644
--- a/tests/test_cli_logging.py
+++ b/tests/test_core_logging.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for the module cli.logging."""
from __future__ import annotations
@@ -8,7 +8,7 @@ from pathlib import Path
import pytest
-from mlia.cli.logging import setup_logging
+from mlia.core.logging import setup_logging
from tests.utils.logging import clear_loggers
@@ -33,20 +33,21 @@ def teardown_function() -> None:
(
None,
True,
- """mlia.backend.manager - backends debug
-cli info
-mlia.cli - cli debug
+ """mlia.backend.manager - DEBUG - backends debug
+mlia.cli - INFO - cli info
+mlia.cli - DEBUG - cli debug
""",
None,
),
(
"logs",
True,
- """mlia.backend.manager - backends debug
-cli info
-mlia.cli - cli debug
+ """mlia.backend.manager - DEBUG - backends debug
+mlia.cli - INFO - cli info
+mlia.cli - DEBUG - cli debug
""",
"""mlia.backend.manager - DEBUG - backends debug
+mlia.cli - INFO - cli info
mlia.cli - DEBUG - cli debug
""",
),
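The expected log lines above show console output moving to a "<logger> - <LEVEL> - <message>" layout, with the log directory and verbosity driven by the two leading parameters of each case. A minimal sketch of calling the relocated helper; the keyword names are inferred from those parameters and are an assumption:

from mlia.core.logging import setup_logging

# Assumed keywords: a directory for the file handler and a verbosity switch.
setup_logging(logs_dir="logs", verbose=True)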
diff --git a/tests/test_core_reporting.py b/tests/test_core_reporting.py
index 71eaf85..7a68d4b 100644
--- a/tests/test_core_reporting.py
+++ b/tests/test_core_reporting.py
@@ -3,9 +3,13 @@
"""Tests for reporting module."""
from __future__ import annotations
-import io
import json
from enum import Enum
+from unittest.mock import ANY
+from unittest.mock import call
+from unittest.mock import MagicMock
+from unittest.mock import Mock
+from unittest.mock import patch
import numpy as np
import pytest
@@ -14,13 +18,15 @@ from mlia.core.reporting import BytesCell
from mlia.core.reporting import Cell
from mlia.core.reporting import ClockCell
from mlia.core.reporting import Column
+from mlia.core.reporting import CustomJSONEncoder
from mlia.core.reporting import CyclesCell
from mlia.core.reporting import Format
-from mlia.core.reporting import json_reporter
+from mlia.core.reporting import JSONReporter
from mlia.core.reporting import NestedReport
from mlia.core.reporting import ReportItem
from mlia.core.reporting import SingleRow
from mlia.core.reporting import Table
+from mlia.core.reporting import TextReporter
from mlia.utils.console import remove_ascii_codes
@@ -364,10 +370,9 @@ def test_custom_json_serialization() -> None:
alias="sample_table",
)
- output = io.StringIO()
- json_reporter(table, output)
+ output = json.dumps(table.to_json(), indent=4, cls=CustomJSONEncoder)
- assert json.loads(output.getvalue()) == {
+ assert json.loads(output) == {
"sample_table": [
{"column1": "value1"},
{"column1": 10.0},
@@ -375,3 +380,93 @@ def test_custom_json_serialization() -> None:
{"column1": 10},
]
}
+
+
+class TestTextReporter:
+ """Test TextReporter methods."""
+
+ def test_text_reporter(self) -> None:
+ """Test TextReporter."""
+ format_resolver = MagicMock()
+ reporter = TextReporter(format_resolver)
+ assert reporter.output_format == "plain_text"
+
+ def test_submit(self) -> None:
+ """Test TextReporter submit."""
+ format_resolver = MagicMock()
+ reporter = TextReporter(format_resolver)
+ reporter.submit("test")
+ assert reporter.data == [("test", ANY)]
+
+ reporter.submit("test2", delay_print=True)
+ assert reporter.data == [("test", ANY), ("test2", ANY)]
+ assert reporter.delayed == [("test2", ANY)]
+
+ def test_print_delayed(self) -> None:
+ """Test TextReporter print_delayed."""
+ with patch(
+ "mlia.core.reporting.TextReporter.produce_report"
+ ) as mock_produce_report:
+ format_resolver = MagicMock()
+ reporter = TextReporter(format_resolver)
+ reporter.submit("test", delay_print=True)
+ reporter.print_delayed()
+ assert reporter.data == [("test", ANY)]
+ assert not reporter.delayed
+ mock_produce_report.assert_called()
+
+ def test_produce_report(self) -> None:
+ """Test TextReporter produce_report."""
+ format_resolver = MagicMock()
+ reporter = TextReporter(format_resolver)
+
+ with patch("mlia.core.reporting.logger") as mock_logger:
+ mock_formatter = MagicMock()
+ reporter.produce_report("test", mock_formatter)
+ mock_formatter.assert_has_calls([call("test"), call().to_plain_text()])
+ mock_logger.info.assert_called()
+
+
+class TestJSONReporter:
+ """Test JSONReporter methods."""
+
+ def test_text_reporter(self) -> None:
+ """Test JSONReporter."""
+ format_resolver = MagicMock()
+ reporter = JSONReporter(format_resolver)
+ assert reporter.output_format == "json"
+
+ def test_submit(self) -> None:
+ """Test JSONReporter submit."""
+ format_resolver = MagicMock()
+ reporter = JSONReporter(format_resolver)
+ reporter.submit("test")
+ assert reporter.data == [("test", ANY)]
+
+ reporter.submit("test2")
+ assert reporter.data == [("test", ANY), ("test2", ANY)]
+
+ def test_generate_report(self) -> None:
+ """Test JSONReporter generate_report."""
+ format_resolver = MagicMock()
+ reporter = JSONReporter(format_resolver)
+ reporter.submit("test")
+
+ with patch(
+ "mlia.core.reporting.JSONReporter.produce_report"
+ ) as mock_produce_report:
+ reporter.generate_report()
+ mock_produce_report.assert_called()
+
+ @patch("builtins.print")
+ def test_produce_report(self, mock_print: Mock) -> None:
+ """Test JSONReporter produce_report."""
+ format_resolver = MagicMock()
+ reporter = JSONReporter(format_resolver)
+
+ with patch("json.dumps") as mock_dumps:
+ mock_formatter = MagicMock()
+ reporter.produce_report("test", mock_formatter)
+ mock_formatter.assert_has_calls([call("test"), call().to_json()])
+ mock_dumps.assert_called()
+ mock_print.assert_called()
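The new test classes above pin down the reporter interface: both reporters are built from a formatter resolver, collect submitted data, and render each entry through the resolved formatter's to_plain_text() or to_json(). A small end-to-end sketch with a stub resolver and report object standing in for real formatters (the resolver contract is inferred from the mock call patterns above):

from mlia.core.reporting import JSONReporter, TextReporter


class StubReport:
    """Stand-in report object exposing the methods the reporters call."""

    def __init__(self, item):
        self.item = item

    def to_plain_text(self, **kwargs):
        return str(self.item)

    def to_json(self, **kwargs):
        return {"item": self.item}


def resolver(data):
    """Stub resolver: return a formatter (callable) for the submitted data."""
    return StubReport


text_reporter = TextReporter(resolver)
text_reporter.submit("some data", delay_print=True)
text_reporter.print_delayed()  # renders delayed entries via to_plain_text()

json_reporter = JSONReporter(resolver)
json_reporter.submit("some data")
json_reporter.generate_report()  # prints the collected entries as JSON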
diff --git a/tests/test_target_ethos_u_reporters.py b/tests/test_target_ethos_u_reporters.py
index 7f372bf..ee7ea52 100644
--- a/tests/test_target_ethos_u_reporters.py
+++ b/tests/test_target_ethos_u_reporters.py
@@ -1,106 +1,21 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for reports module."""
from __future__ import annotations
-import json
-import sys
-from contextlib import ExitStack as doesnt_raise
-from pathlib import Path
-from typing import Any
-from typing import Callable
-from typing import Literal
-
import pytest
from mlia.backend.vela.compat import NpuSupported
from mlia.backend.vela.compat import Operator
-from mlia.backend.vela.compat import Operators
-from mlia.core.reporting import get_reporter
-from mlia.core.reporting import produce_report
from mlia.core.reporting import Report
-from mlia.core.reporting import Reporter
from mlia.core.reporting import Table
from mlia.target.ethos_u.config import EthosUConfiguration
-from mlia.target.ethos_u.performance import MemoryUsage
-from mlia.target.ethos_u.performance import NPUCycles
-from mlia.target.ethos_u.performance import PerformanceMetrics
-from mlia.target.ethos_u.reporters import ethos_u_formatters
from mlia.target.ethos_u.reporters import report_device_details
from mlia.target.ethos_u.reporters import report_operators
-from mlia.target.ethos_u.reporters import report_perf_metrics
from mlia.utils.console import remove_ascii_codes
@pytest.mark.parametrize(
- "data, formatters",
- [
- (
- [Operator("test_operator", "test_type", NpuSupported(False, []))],
- [report_operators],
- ),
- (
- PerformanceMetrics(
- EthosUConfiguration("ethos-u55-256"),
- NPUCycles(0, 0, 0, 0, 0, 0),
- MemoryUsage(0, 0, 0, 0, 0),
- ),
- [report_perf_metrics],
- ),
- ],
-)
-@pytest.mark.parametrize(
- "fmt, output, expected_error",
- [
- [
- "unknown_format",
- sys.stdout,
- pytest.raises(Exception, match="Unknown format unknown_format"),
- ],
- [
- "plain_text",
- sys.stdout,
- doesnt_raise(),
- ],
- [
- "json",
- sys.stdout,
- doesnt_raise(),
- ],
- [
- "plain_text",
- "report.txt",
- doesnt_raise(),
- ],
- [
- "json",
- "report.json",
- doesnt_raise(),
- ],
- ],
-)
-def test_report(
- data: Any,
- formatters: list[Callable],
- fmt: Literal["plain_text", "json"],
- output: Any,
- expected_error: Any,
- tmp_path: Path,
-) -> None:
- """Test report function."""
- if is_file := isinstance(output, str):
- output = tmp_path / output
-
- for formatter in formatters:
- with expected_error:
- produce_report(data, formatter, fmt, output)
-
- if is_file:
- assert output.is_file()
- assert output.stat().st_size > 0
-
-
-@pytest.mark.parametrize(
"ops, expected_plain_text, expected_json_dict",
[
(
@@ -314,40 +229,3 @@ def test_report_device_details(
json_dict = report.to_json()
assert json_dict == expected_json_dict
-
-
-def test_get_reporter(tmp_path: Path) -> None:
- """Test reporter functionality."""
- ops = Operators(
- [
- Operator(
- "npu_supported",
- "op_type",
- NpuSupported(True, []),
- ),
- ]
- )
-
- output = tmp_path / "output.json"
- with get_reporter("json", output, ethos_u_formatters) as reporter:
- assert isinstance(reporter, Reporter)
-
- with pytest.raises(
- Exception, match="Unable to find appropriate formatter for some_data"
- ):
- reporter.submit("some_data")
-
- reporter.submit(ops)
-
- with open(output, encoding="utf-8") as file:
- json_data = json.load(file)
-
- assert json_data == {
- "operators_stats": [
- {
- "npu_unsupported_ratio": 0.0,
- "num_of_npu_supported_operators": 1,
- "num_of_operators": 1,
- }
- ]
- }
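With the generic get_reporter path removed here, the remaining tests drive the Ethos-U formatters directly: report_operators and report_device_details return Report/Table objects exposing to_plain_text() and to_json(). A minimal sketch of that direct use, reusing the operator data from the deleted test:

from mlia.backend.vela.compat import NpuSupported, Operator
from mlia.target.ethos_u.reporters import report_operators

report = report_operators(
    [Operator("npu_supported", "op_type", NpuSupported(True, []))]
)
print(report.to_plain_text())  # table view
print(report.to_json())        # dict representation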