# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Tests for TOSA data analysis module.""" from __future__ import annotations import pytest from mlia.backend.tosa_checker.compat import TOSACompatibilityInfo from mlia.core.common import DataItem from mlia.core.data_analysis import Fact from mlia.nn.tensorflow.tflite_compat import TFLiteCompatibilityInfo from mlia.nn.tensorflow.tflite_compat import TFLiteCompatibilityStatus from mlia.nn.tensorflow.tflite_compat import TFLiteConversionError from mlia.nn.tensorflow.tflite_compat import TFLiteConversionErrorCode from mlia.target.common.reporters import ModelHasCustomOperators from mlia.target.common.reporters import ModelIsNotTFLiteCompatible from mlia.target.common.reporters import TFLiteCompatibilityCheckFailed from mlia.target.tosa.data_analysis import ModelIsNotTOSACompatible from mlia.target.tosa.data_analysis import ModelIsTOSACompatible from mlia.target.tosa.data_analysis import TOSADataAnalyzer @pytest.mark.parametrize( "input_data, expected_facts", [ [ TOSACompatibilityInfo(True, []), [ModelIsTOSACompatible()], ], [ TOSACompatibilityInfo(False, []), [ModelIsNotTOSACompatible()], ], [ TFLiteCompatibilityInfo(status=TFLiteCompatibilityStatus.COMPATIBLE), [], ], [ TFLiteCompatibilityInfo( status=TFLiteCompatibilityStatus.MODEL_WITH_CUSTOM_OP_ERROR ), [ModelHasCustomOperators()], ], [ TFLiteCompatibilityInfo(status=TFLiteCompatibilityStatus.UNKNOWN_ERROR), [TFLiteCompatibilityCheckFailed()], ], [ TFLiteCompatibilityInfo( status=TFLiteCompatibilityStatus.TFLITE_CONVERSION_ERROR ), [ModelIsNotTFLiteCompatible(custom_ops=[], flex_ops=[])], ], [ TFLiteCompatibilityInfo( status=TFLiteCompatibilityStatus.TFLITE_CONVERSION_ERROR, conversion_errors=[ TFLiteConversionError( "error", TFLiteConversionErrorCode.NEEDS_CUSTOM_OPS, "custom_op1", [], ), TFLiteConversionError( "error", TFLiteConversionErrorCode.NEEDS_FLEX_OPS, "flex_op1", [], ), ], ), [ ModelIsNotTFLiteCompatible( custom_ops=["custom_op1"], flex_ops=["flex_op1"], ) ], ], ], ) def test_tosa_data_analyzer(input_data: DataItem, expected_facts: list[Fact]) -> None: """Test TOSA data analyzer.""" analyzer = TOSADataAnalyzer() analyzer.analyze_data(input_data) assert analyzer.get_analyzed_data() == expected_facts