diff options
Diffstat (limited to 'tests/aiet/test_check_model.py')
-rw-r--r-- | tests/aiet/test_check_model.py | 162 |
1 files changed, 0 insertions, 162 deletions
diff --git a/tests/aiet/test_check_model.py b/tests/aiet/test_check_model.py deleted file mode 100644 index 4eafe59..0000000 --- a/tests/aiet/test_check_model.py +++ /dev/null @@ -1,162 +0,0 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. -# SPDX-License-Identifier: Apache-2.0 -# pylint: disable=redefined-outer-name,no-self-use -"""Module for testing check_model.py script.""" -from pathlib import Path -from typing import Any - -import pytest -from ethosu.vela.tflite.Model import Model -from ethosu.vela.tflite.OperatorCode import OperatorCode - -from aiet.cli.common import InvalidTFLiteFileError -from aiet.cli.common import ModelOptimisedException -from aiet.resources.tools.vela.check_model import check_custom_codes_for_ethosu -from aiet.resources.tools.vela.check_model import check_model -from aiet.resources.tools.vela.check_model import get_custom_codes_from_operators -from aiet.resources.tools.vela.check_model import get_model_from_file -from aiet.resources.tools.vela.check_model import get_operators_from_model -from aiet.resources.tools.vela.check_model import is_vela_optimised - - -@pytest.fixture(scope="session") -def optimised_tflite_model( - optimised_input_model_file: Path, -) -> Model: - """Return Model instance read from a Vela-optimised TFLite file.""" - return get_model_from_file(optimised_input_model_file) - - -@pytest.fixture(scope="session") -def non_optimised_tflite_model( - non_optimised_input_model_file: Path, -) -> Model: - """Return Model instance read from a Vela-optimised TFLite file.""" - return get_model_from_file(non_optimised_input_model_file) - - -class TestIsVelaOptimised: - """Test class for is_vela_optimised() function.""" - - def test_return_true_when_input_is_optimised( - self, - optimised_tflite_model: Model, - ) -> None: - """Verify True returned when input is optimised model.""" - output = is_vela_optimised(optimised_tflite_model) - - assert output is True - - def test_return_false_when_input_is_not_optimised( - self, - non_optimised_tflite_model: Model, - ) -> None: - """Verify False returned when input is non-optimised model.""" - output = is_vela_optimised(non_optimised_tflite_model) - - assert output is False - - -def test_get_operator_list_returns_correct_instances( - optimised_tflite_model: Model, -) -> None: - """Verify list of OperatorCode instances returned by get_operator_list().""" - operator_list = get_operators_from_model(optimised_tflite_model) - - assert all(isinstance(operator, OperatorCode) for operator in operator_list) - - -class TestGetCustomCodesFromOperators: - """Test the get_custom_codes_from_operators() function.""" - - def test_returns_empty_list_when_input_operators_have_no_custom_codes( - self, monkeypatch: Any - ) -> None: - """Verify function returns empty list when operators have no custom codes.""" - # Mock OperatorCode.CustomCode() function to return None - monkeypatch.setattr( - "ethosu.vela.tflite.OperatorCode.OperatorCode.CustomCode", lambda _: None - ) - - operators = [OperatorCode()] * 3 - - custom_codes = get_custom_codes_from_operators(operators) - - assert custom_codes == [] - - def test_returns_custom_codes_when_input_operators_have_custom_codes( - self, monkeypatch: Any - ) -> None: - """Verify list of bytes objects returned representing the CustomCodes.""" - # Mock OperatorCode.CustomCode() function to return a byte string - monkeypatch.setattr( - "ethosu.vela.tflite.OperatorCode.OperatorCode.CustomCode", - lambda _: b"custom-code", - ) - - operators = [OperatorCode()] * 3 - - custom_codes = get_custom_codes_from_operators(operators) - - assert custom_codes == [b"custom-code", b"custom-code", b"custom-code"] - - -@pytest.mark.parametrize( - "custom_codes, expected_output", - [ - ([b"ethos-u", b"something else"], True), - ([b"custom-code-1", b"custom-code-2"], False), - ], -) -def test_check_list_for_ethosu(custom_codes: list, expected_output: bool) -> None: - """Verify function detects 'ethos-u' bytes in the input list.""" - output = check_custom_codes_for_ethosu(custom_codes) - assert output is expected_output - - -class TestGetModelFromFile: - """Test the get_model_from_file() function.""" - - def test_error_raised_when_input_is_invalid_model_file( - self, - invalid_input_model_file: Path, - ) -> None: - """Verify error thrown when an invalid model file is given.""" - with pytest.raises(InvalidTFLiteFileError): - get_model_from_file(invalid_input_model_file) - - def test_model_instance_returned_when_input_is_valid_model_file( - self, - optimised_input_model_file: Path, - ) -> None: - """Verify file is read successfully and returns model instance.""" - tflite_model = get_model_from_file(optimised_input_model_file) - - assert isinstance(tflite_model, Model) - - -class TestCheckModel: - """Test the check_model() function.""" - - def test_check_model_with_non_optimised_input( - self, - non_optimised_input_model_file: Path, - ) -> None: - """Verify no error occurs for a valid input file.""" - check_model(non_optimised_input_model_file) - - def test_check_model_with_optimised_input( - self, - optimised_input_model_file: Path, - ) -> None: - """Verify that the right exception is raised with already optimised input.""" - with pytest.raises(ModelOptimisedException): - check_model(optimised_input_model_file) - - def test_check_model_with_invalid_input( - self, - invalid_input_model_file: Path, - ) -> None: - """Verify that an exception is raised with invalid input.""" - with pytest.raises(Exception): - check_model(invalid_input_model_file) |