# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for module backend/corstone/performance."""
from __future__ import annotations
import base64
from pathlib import Path
from typing import Generator
from unittest.mock import MagicMock
import pytest
from mlia.backend.corstone.performance import build_corstone_command
from mlia.backend.corstone.performance import estimate_performance
from mlia.backend.corstone.performance import GenericInferenceOutputParser
from mlia.backend.corstone.performance import get_metrics
from mlia.backend.corstone.performance import PerformanceMetrics
from mlia.backend.errors import BackendExecutionFailed
from mlia.utils.proc import Command
def encode_b64(data: str) -> str:
    """Return ``data`` encoded to bytes and then to base64, as a string."""
    raw_bytes = data.encode()
    encoded_bytes = base64.b64encode(raw_bytes)
    return encoded_bytes.decode()
def valid_fvp_output() -> list[str]:
    """Return valid FVP output that could be successfully parsed.

    The result has three lines: a plain text line, the base64-encoded
    JSON document with the "Inference" profiling samples, and another
    plain text line.
    """
    json_data = """[
{
"profiling_group": "Inference",
"count": 1,
"samples": [
{"name": "NPU IDLE", "value": [2]},
{"name": "NPU AXI0_RD_DATA_BEAT_RECEIVED", "value": [4]},
{"name": "NPU AXI0_WR_DATA_BEAT_WRITTEN", "value": [5]},
{"name": "NPU AXI1_RD_DATA_BEAT_RECEIVED", "value": [6]},
{"name": "NPU ACTIVE", "value": [1]},
{"name": "NPU TOTAL", "value": [3]}
]
}
]"""
    encoded_metrics = encode_b64(json_data)
    return ["some output", encoded_metrics, "some_output"]
def test_generic_inference_output_parser_success() -> None:
    """Check that the parser extracts metrics from valid FVP output."""
    parser = GenericInferenceOutputParser()

    # Feed the parser line by line, the same way it is invoked here
    # for every element of the sample output.
    for output_line in valid_fvp_output():
        parser(output_line)

    actual_metrics = parser.get_metrics()
    expected_metrics = PerformanceMetrics(1, 2, 3, 4, 5, 6)
    assert actual_metrics == expected_metrics
@pytest.mark.parametrize(
    "wrong_fvp_output",
    [[], ["NPU IDLE: 123"], ["123"]],
)
def test_generic_inference_output_parser_failure(wrong_fvp_output: list[str]) -> None:
    """Check that requesting metrics fails when the output has no valid data."""
    parser = GenericInferenceOutputParser()
    for output_line in wrong_fvp_output:
        parser(output_line)

    expected_error = "Unable to parse output and get metrics"
    with pytest.raises(ValueError, match=expected_error):
        parser.get_metrics()
@pytest.mark.parametrize(
    "backend_path, fvp, target, mac, model, profile, expected_command",
    [
        (
            Path("backend_path"),
            "corstone-300",
            "ethos-u55",
            256,
            Path("model.tflite"),
            "default",
            Command(
                [
                    "backend_path/FVP_Corstone_SSE-300_Ethos-U55",
                    "-a",
                    "apps/backends/applications/"
                    "inference_runner-sse-300-22.08.02-ethos-U55-Default-noTA/"
                    "ethos-u-inference_runner.axf",
                    "--data",
                    "model.tflite@0x90000000",
                    "-C",
                    "ethosu.num_macs=256",
                    "-C",
                    "mps3_board.telnetterminal0.start_telnet=0",
                    "-C",
                    "mps3_board.uart0.out_file='-'",
                    "-C",
                    "mps3_board.uart0.shutdown_on_eot=1",
                    "-C",
                    "mps3_board.visualisation.disable-visualisation=1",
                    "--stat",
                ]
            ),
        ),
    ],
)
def test_build_corsone_command(
    monkeypatch: pytest.MonkeyPatch,
    backend_path: Path,
    fvp: str,
    target: str,
    mac: int,
    model: Path,
    profile: str,
    expected_command: Command,
) -> None:
    """Check that build_corstone_command returns the expected command.

    The resources lookup used by the performance module is replaced with
    a stub returning ``Path("apps")`` so the expected application path
    is fixed.
    """

    def fake_mlia_resources() -> Path:
        """Return a fixed resources directory."""
        return Path("apps")

    monkeypatch.setattr(
        "mlia.backend.corstone.performance.get_mlia_resources",
        fake_mlia_resources,
    )

    actual_command = build_corstone_command(
        backend_path, fvp, target, mac, model, profile
    )
    assert actual_command == expected_command
def test_get_metrics_wrong_fvp() -> None:
    """Check that get_metrics raises BackendExecutionFailed for an unknown FVP."""
    expected_error = r"Unable to construct a command line for some_fvp"
    backend_path = Path("backend_path")
    model = Path("model.tflite")

    with pytest.raises(BackendExecutionFailed, match=expected_error):
        get_metrics(backend_path, "some_fvp", "ethos-u55", 256, model)
def test_estimate_performance(monkeypatch: pytest.MonkeyPatch) -> None:
    """Check estimate_performance against a mocked repository and FVP output.

    The backend repository lookup and ``mlia.utils.proc.command_output``
    are both patched, so the sample output from ``valid_fvp_output`` is
    used in place of a real command run.
    """
    backend_settings = (Path("backend_path"), {"profile": "default"})
    mock_repository = MagicMock()
    mock_repository.get_backend_settings.return_value = backend_settings

    def fake_backend_repository() -> MagicMock:
        """Return the mocked backend repository."""
        return mock_repository

    def command_output_mock(_command: Command) -> Generator[str, None, None]:
        """Yield the sample FVP output regardless of the command."""
        yield from valid_fvp_output()

    monkeypatch.setattr(
        "mlia.backend.corstone.performance.get_backend_repository",
        fake_backend_repository,
    )
    monkeypatch.setattr("mlia.utils.proc.command_output", command_output_mock)

    metrics = estimate_performance(
        "ethos-u55", 256, Path("model.tflite"), "corstone-300"
    )

    expected_metrics = PerformanceMetrics(1, 2, 3, 4, 5, 6)
    assert metrics == expected_metrics
    mock_repository.get_backend_settings.assert_called_once()