diff options
Diffstat (limited to 'src/mlia/target/tosa/reporters.py')
-rw-r--r-- | src/mlia/target/tosa/reporters.py | 93 |
1 files changed, 92 insertions, 1 deletions
diff --git a/src/mlia/target/tosa/reporters.py b/src/mlia/target/tosa/reporters.py index 01fbb97..283f61f 100644 --- a/src/mlia/target/tosa/reporters.py +++ b/src/mlia/target/tosa/reporters.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Reports module.""" from __future__ import annotations @@ -7,20 +7,41 @@ from typing import Any from typing import Callable from mlia.backend.tosa_checker.compat import Operator +from mlia.backend.tosa_checker.compat import TOSACompatibilityInfo from mlia.core.advice_generation import Advice +from mlia.core.metadata import MLIAMetadata +from mlia.core.metadata import ModelMetadata from mlia.core.reporters import report_advice from mlia.core.reporting import Cell from mlia.core.reporting import Column +from mlia.core.reporting import CompoundReport from mlia.core.reporting import Format from mlia.core.reporting import NestedReport from mlia.core.reporting import Report from mlia.core.reporting import ReportItem from mlia.core.reporting import Table from mlia.target.tosa.config import TOSAConfiguration +from mlia.target.tosa.metadata import TOSAMetadata from mlia.utils.console import style_improvement from mlia.utils.types import is_list_of +class MetadataDisplay: # pylint: disable=too-few-public-methods + """TOSA metadata.""" + + def __init__( + self, + tosa_meta: TOSAMetadata, + mlia_meta: MLIAMetadata, + model_meta: ModelMetadata, + ) -> None: + """Init TOSAMetadata.""" + self.tosa_version = tosa_meta.version + self.mlia_version = mlia_meta.version + self.model_check_sum = model_meta.checksum + self.model_name = model_meta.model_name + + def report_device(device: TOSAConfiguration) -> Report: """Generate report for the device.""" return NestedReport( @@ -32,6 +53,34 @@ def report_device(device: TOSAConfiguration) -> Report: ) +def report_metadata(data: MetadataDisplay) -> Report: + """Generate report for the package version.""" + return NestedReport( + "Metadata", + "metadata", + [ + ReportItem( + "TOSA checker", + alias="tosa-checker", + nested_items=[ReportItem("version", "version", data.tosa_version)], + ), + ReportItem( + "MLIA", + "MLIA", + nested_items=[ReportItem("version", "version", data.mlia_version)], + ), + ReportItem( + "Model", + "Model", + nested_items=[ + ReportItem("name", "name", data.model_name), + ReportItem("checksum", "checksum", data.model_check_sum), + ], + ), + ], + ) + + def report_tosa_operators(ops: list[Operator]) -> Report: """Generate report for the operators.""" return Table( @@ -69,6 +118,42 @@ def report_tosa_operators(ops: list[Operator]) -> Report: ) +def report_tosa_exception(exc: Exception | None) -> Report: + """Generate report for exception thrown by tosa.""" + return NestedReport( + "TOSA exception", + "exception", + [ + ReportItem("TOSA exception", alias="exception", value=repr(exc)), + ], + ) + + +def report_tosa_errors(err: list[str] | None) -> Report: + """Generate report for errors thrown by tosa.""" + message = "".join(err) if err else None + return NestedReport( + "TOSA stderr", + "stderr", + [ + ReportItem( + "TOSA stderr", + alias="stderr", + value=message, + ), + ], + ) + + +def report_tosa_compatibility(compat_info: TOSACompatibilityInfo) -> Report: + """Generate combined report for all compatibility info.""" + report_ops = report_tosa_operators(compat_info.operators) + report_exception = report_tosa_exception(compat_info.exception) + + report_errors = report_tosa_errors(compat_info.errors) + return CompoundReport([report_ops, report_exception, report_errors]) + + def tosa_formatters(data: Any) -> Callable[[Any], Report]: """Find appropriate formatter for the provided data.""" if is_list_of(data, Advice): @@ -77,7 +162,13 @@ def tosa_formatters(data: Any) -> Callable[[Any], Report]: if isinstance(data, TOSAConfiguration): return report_device + if isinstance(data, MetadataDisplay): + return report_metadata + if is_list_of(data, Operator): return report_tosa_operators + if isinstance(data, TOSACompatibilityInfo): + return report_tosa_compatibility + raise Exception(f"Unable to find appropriate formatter for {data}") |