diff options
Diffstat (limited to 'tests/aiet/test_check_model.py')
-rw-r--r-- | tests/aiet/test_check_model.py | 162 |
1 files changed, 162 insertions, 0 deletions
diff --git a/tests/aiet/test_check_model.py b/tests/aiet/test_check_model.py new file mode 100644 index 0000000..4eafe59 --- /dev/null +++ b/tests/aiet/test_check_model.py @@ -0,0 +1,162 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +# pylint: disable=redefined-outer-name,no-self-use +"""Module for testing check_model.py script.""" +from pathlib import Path +from typing import Any + +import pytest +from ethosu.vela.tflite.Model import Model +from ethosu.vela.tflite.OperatorCode import OperatorCode + +from aiet.cli.common import InvalidTFLiteFileError +from aiet.cli.common import ModelOptimisedException +from aiet.resources.tools.vela.check_model import check_custom_codes_for_ethosu +from aiet.resources.tools.vela.check_model import check_model +from aiet.resources.tools.vela.check_model import get_custom_codes_from_operators +from aiet.resources.tools.vela.check_model import get_model_from_file +from aiet.resources.tools.vela.check_model import get_operators_from_model +from aiet.resources.tools.vela.check_model import is_vela_optimised + + +@pytest.fixture(scope="session") +def optimised_tflite_model( + optimised_input_model_file: Path, +) -> Model: + """Return Model instance read from a Vela-optimised TFLite file.""" + return get_model_from_file(optimised_input_model_file) + + +@pytest.fixture(scope="session") +def non_optimised_tflite_model( + non_optimised_input_model_file: Path, +) -> Model: + """Return Model instance read from a Vela-optimised TFLite file.""" + return get_model_from_file(non_optimised_input_model_file) + + +class TestIsVelaOptimised: + """Test class for is_vela_optimised() function.""" + + def test_return_true_when_input_is_optimised( + self, + optimised_tflite_model: Model, + ) -> None: + """Verify True returned when input is optimised model.""" + output = is_vela_optimised(optimised_tflite_model) + + assert output is True + + def test_return_false_when_input_is_not_optimised( + self, + non_optimised_tflite_model: Model, + ) -> None: + """Verify False returned when input is non-optimised model.""" + output = is_vela_optimised(non_optimised_tflite_model) + + assert output is False + + +def test_get_operator_list_returns_correct_instances( + optimised_tflite_model: Model, +) -> None: + """Verify list of OperatorCode instances returned by get_operator_list().""" + operator_list = get_operators_from_model(optimised_tflite_model) + + assert all(isinstance(operator, OperatorCode) for operator in operator_list) + + +class TestGetCustomCodesFromOperators: + """Test the get_custom_codes_from_operators() function.""" + + def test_returns_empty_list_when_input_operators_have_no_custom_codes( + self, monkeypatch: Any + ) -> None: + """Verify function returns empty list when operators have no custom codes.""" + # Mock OperatorCode.CustomCode() function to return None + monkeypatch.setattr( + "ethosu.vela.tflite.OperatorCode.OperatorCode.CustomCode", lambda _: None + ) + + operators = [OperatorCode()] * 3 + + custom_codes = get_custom_codes_from_operators(operators) + + assert custom_codes == [] + + def test_returns_custom_codes_when_input_operators_have_custom_codes( + self, monkeypatch: Any + ) -> None: + """Verify list of bytes objects returned representing the CustomCodes.""" + # Mock OperatorCode.CustomCode() function to return a byte string + monkeypatch.setattr( + "ethosu.vela.tflite.OperatorCode.OperatorCode.CustomCode", + lambda _: b"custom-code", + ) + + operators = [OperatorCode()] * 3 + + custom_codes = get_custom_codes_from_operators(operators) + + assert custom_codes == [b"custom-code", b"custom-code", b"custom-code"] + + +@pytest.mark.parametrize( + "custom_codes, expected_output", + [ + ([b"ethos-u", b"something else"], True), + ([b"custom-code-1", b"custom-code-2"], False), + ], +) +def test_check_list_for_ethosu(custom_codes: list, expected_output: bool) -> None: + """Verify function detects 'ethos-u' bytes in the input list.""" + output = check_custom_codes_for_ethosu(custom_codes) + assert output is expected_output + + +class TestGetModelFromFile: + """Test the get_model_from_file() function.""" + + def test_error_raised_when_input_is_invalid_model_file( + self, + invalid_input_model_file: Path, + ) -> None: + """Verify error thrown when an invalid model file is given.""" + with pytest.raises(InvalidTFLiteFileError): + get_model_from_file(invalid_input_model_file) + + def test_model_instance_returned_when_input_is_valid_model_file( + self, + optimised_input_model_file: Path, + ) -> None: + """Verify file is read successfully and returns model instance.""" + tflite_model = get_model_from_file(optimised_input_model_file) + + assert isinstance(tflite_model, Model) + + +class TestCheckModel: + """Test the check_model() function.""" + + def test_check_model_with_non_optimised_input( + self, + non_optimised_input_model_file: Path, + ) -> None: + """Verify no error occurs for a valid input file.""" + check_model(non_optimised_input_model_file) + + def test_check_model_with_optimised_input( + self, + optimised_input_model_file: Path, + ) -> None: + """Verify that the right exception is raised with already optimised input.""" + with pytest.raises(ModelOptimisedException): + check_model(optimised_input_model_file) + + def test_check_model_with_invalid_input( + self, + invalid_input_model_file: Path, + ) -> None: + """Verify that an exception is raised with invalid input.""" + with pytest.raises(Exception): + check_model(invalid_input_model_file) |