diff options
author | Dmitrii Agibov <dmitrii.agibov@arm.com> | 2022-10-24 15:08:08 +0100 |
---|---|---|
committer | Dmitrii Agibov <dmitrii.agibov@arm.com> | 2022-10-26 17:08:13 +0100 |
commit | 58a65fee574c00329cf92b387a6d2513dcbf6100 (patch) | |
tree | 47e3185f78b4298ab029785ddee68456e44cac10 /src/mlia/devices/cortexa/reporters.py | |
parent | 9d34cb72d45a6d0a2ec1063ebf32536c1efdba75 (diff) | |
download | mlia-58a65fee574c00329cf92b387a6d2513dcbf6100.tar.gz |
MLIA-433 Add TensorFlow Lite compatibility check
- Add ability to intercept low level TensorFlow output
- Produce advice for the models that could not be
converted to the TensorFlow Lite format
- Refactor utility functions for TensorFlow Lite
conversion
- Add TensorFlow Lite compatibility checker
Change-Id: I47d120d2619ced7b143bc92c5184515b81c0220d
Diffstat (limited to 'src/mlia/devices/cortexa/reporters.py')
-rw-r--r-- | src/mlia/devices/cortexa/reporters.py | 108 |
1 files changed, 102 insertions, 6 deletions
diff --git a/src/mlia/devices/cortexa/reporters.py b/src/mlia/devices/cortexa/reporters.py index 076b9ca..a55caba 100644 --- a/src/mlia/devices/cortexa/reporters.py +++ b/src/mlia/devices/cortexa/reporters.py @@ -7,25 +7,118 @@ from typing import Any from typing import Callable from mlia.core.advice_generation import Advice +from mlia.core.reporters import report_advice +from mlia.core.reporting import Cell +from mlia.core.reporting import Column +from mlia.core.reporting import Format +from mlia.core.reporting import NestedReport from mlia.core.reporting import Report +from mlia.core.reporting import ReportItem +from mlia.core.reporting import Table from mlia.devices.cortexa.config import CortexAConfiguration from mlia.devices.cortexa.operators import Operator +from mlia.nn.tensorflow.tflite_compat import TFLiteCompatibilityInfo +from mlia.utils.console import style_improvement from mlia.utils.types import is_list_of def report_device(device: CortexAConfiguration) -> Report: """Generate report for the device.""" - raise NotImplementedError() + return NestedReport( + "Device information", + "device", + [ + ReportItem("Target", alias="target", value=device.target), + ], + ) -def report_advice(advice: list[Advice]) -> Report: - """Generate report for the advice.""" - raise NotImplementedError() +def report_tflite_compatiblity(compat_info: TFLiteCompatibilityInfo) -> Report: + """Generate report for the TensorFlow Lite compatibility information.""" + if compat_info.conversion_errors: + return Table( + [ + Column("#", only_for=["plain_text"]), + Column("Operator", alias="operator"), + Column( + "Operator location", + alias="operator_location", + fmt=Format(wrap_width=25), + ), + Column("Error code", alias="error_code"), + Column( + "Error message", alias="error_message", fmt=Format(wrap_width=25) + ), + ], + [ + ( + index + 1, + err.operator, + ", ".join(err.location), + err.code.name, + err.message, + ) + for index, err in enumerate(compat_info.conversion_errors) + ], + name="TensorFlow Lite conversion errors", + alias="tensorflow_lite_conversion_errors", + ) + return Table( + columns=[ + Column("Reason", alias="reason"), + Column( + "Exception details", + alias="exception_details", + fmt=Format(wrap_width=40), + ), + ], + rows=[ + ( + "TensorFlow Lite compatibility check failed with exception", + str(compat_info.conversion_exception), + ), + ], + name="TensorFlow Lite compatibility errors", + alias="tflite_compatibility", + ) -def report_cortex_a_operators(operators: list[Operator]) -> Report: + +def report_cortex_a_operators(ops: list[Operator]) -> Report: """Generate report for the operators.""" - raise NotImplementedError() + return Table( + [ + Column("#", only_for=["plain_text"]), + Column( + "Operator location", + alias="operator_location", + fmt=Format(wrap_width=30), + ), + Column("Operator name", alias="operator_name", fmt=Format(wrap_width=20)), + Column( + "Cortex-A compatibility", + alias="cortex_a_compatible", + fmt=Format(wrap_width=25), + ), + ], + [ + ( + index + 1, + op.location, + op.name, + Cell( + op.is_cortex_a_compatible, + Format( + style=style_improvement(op.is_cortex_a_compatible), + str_fmt=lambda v: "Compatible" if v else "Not compatible", + ), + ), + ) + for index, op in enumerate(ops) + ], + name="Operators", + alias="operators", + ) def cortex_a_formatters(data: Any) -> Callable[[Any], Report]: @@ -36,6 +129,9 @@ def cortex_a_formatters(data: Any) -> Callable[[Any], Report]: if isinstance(data, CortexAConfiguration): return report_device + if isinstance(data, TFLiteCompatibilityInfo): + return report_tflite_compatiblity + if is_list_of(data, Operator): return report_cortex_a_operators |