aboutsummaryrefslogtreecommitdiff
path: root/tests/aiet/test_check_model.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/aiet/test_check_model.py')
-rw-r--r--tests/aiet/test_check_model.py162
1 files changed, 162 insertions, 0 deletions
diff --git a/tests/aiet/test_check_model.py b/tests/aiet/test_check_model.py
new file mode 100644
index 0000000..4eafe59
--- /dev/null
+++ b/tests/aiet/test_check_model.py
@@ -0,0 +1,162 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+# pylint: disable=redefined-outer-name,no-self-use
+"""Module for testing check_model.py script."""
+from pathlib import Path
+from typing import Any
+
+import pytest
+from ethosu.vela.tflite.Model import Model
+from ethosu.vela.tflite.OperatorCode import OperatorCode
+
+from aiet.cli.common import InvalidTFLiteFileError
+from aiet.cli.common import ModelOptimisedException
+from aiet.resources.tools.vela.check_model import check_custom_codes_for_ethosu
+from aiet.resources.tools.vela.check_model import check_model
+from aiet.resources.tools.vela.check_model import get_custom_codes_from_operators
+from aiet.resources.tools.vela.check_model import get_model_from_file
+from aiet.resources.tools.vela.check_model import get_operators_from_model
+from aiet.resources.tools.vela.check_model import is_vela_optimised
+
+
+@pytest.fixture(scope="session")
+def optimised_tflite_model(
+ optimised_input_model_file: Path,
+) -> Model:
+ """Return Model instance read from a Vela-optimised TFLite file."""
+ return get_model_from_file(optimised_input_model_file)
+
+
+@pytest.fixture(scope="session")
+def non_optimised_tflite_model(
+ non_optimised_input_model_file: Path,
+) -> Model:
+ """Return Model instance read from a Vela-optimised TFLite file."""
+ return get_model_from_file(non_optimised_input_model_file)
+
+
+class TestIsVelaOptimised:
+ """Test class for is_vela_optimised() function."""
+
+ def test_return_true_when_input_is_optimised(
+ self,
+ optimised_tflite_model: Model,
+ ) -> None:
+ """Verify True returned when input is optimised model."""
+ output = is_vela_optimised(optimised_tflite_model)
+
+ assert output is True
+
+ def test_return_false_when_input_is_not_optimised(
+ self,
+ non_optimised_tflite_model: Model,
+ ) -> None:
+ """Verify False returned when input is non-optimised model."""
+ output = is_vela_optimised(non_optimised_tflite_model)
+
+ assert output is False
+
+
+def test_get_operator_list_returns_correct_instances(
+ optimised_tflite_model: Model,
+) -> None:
+ """Verify list of OperatorCode instances returned by get_operator_list()."""
+ operator_list = get_operators_from_model(optimised_tflite_model)
+
+ assert all(isinstance(operator, OperatorCode) for operator in operator_list)
+
+
+class TestGetCustomCodesFromOperators:
+ """Test the get_custom_codes_from_operators() function."""
+
+ def test_returns_empty_list_when_input_operators_have_no_custom_codes(
+ self, monkeypatch: Any
+ ) -> None:
+ """Verify function returns empty list when operators have no custom codes."""
+ # Mock OperatorCode.CustomCode() function to return None
+ monkeypatch.setattr(
+ "ethosu.vela.tflite.OperatorCode.OperatorCode.CustomCode", lambda _: None
+ )
+
+ operators = [OperatorCode()] * 3
+
+ custom_codes = get_custom_codes_from_operators(operators)
+
+ assert custom_codes == []
+
+ def test_returns_custom_codes_when_input_operators_have_custom_codes(
+ self, monkeypatch: Any
+ ) -> None:
+ """Verify list of bytes objects returned representing the CustomCodes."""
+ # Mock OperatorCode.CustomCode() function to return a byte string
+ monkeypatch.setattr(
+ "ethosu.vela.tflite.OperatorCode.OperatorCode.CustomCode",
+ lambda _: b"custom-code",
+ )
+
+ operators = [OperatorCode()] * 3
+
+ custom_codes = get_custom_codes_from_operators(operators)
+
+ assert custom_codes == [b"custom-code", b"custom-code", b"custom-code"]
+
+
+@pytest.mark.parametrize(
+ "custom_codes, expected_output",
+ [
+ ([b"ethos-u", b"something else"], True),
+ ([b"custom-code-1", b"custom-code-2"], False),
+ ],
+)
+def test_check_list_for_ethosu(custom_codes: list, expected_output: bool) -> None:
+ """Verify function detects 'ethos-u' bytes in the input list."""
+ output = check_custom_codes_for_ethosu(custom_codes)
+ assert output is expected_output
+
+
+class TestGetModelFromFile:
+ """Test the get_model_from_file() function."""
+
+ def test_error_raised_when_input_is_invalid_model_file(
+ self,
+ invalid_input_model_file: Path,
+ ) -> None:
+ """Verify error thrown when an invalid model file is given."""
+ with pytest.raises(InvalidTFLiteFileError):
+ get_model_from_file(invalid_input_model_file)
+
+ def test_model_instance_returned_when_input_is_valid_model_file(
+ self,
+ optimised_input_model_file: Path,
+ ) -> None:
+ """Verify file is read successfully and returns model instance."""
+ tflite_model = get_model_from_file(optimised_input_model_file)
+
+ assert isinstance(tflite_model, Model)
+
+
+class TestCheckModel:
+ """Test the check_model() function."""
+
+ def test_check_model_with_non_optimised_input(
+ self,
+ non_optimised_input_model_file: Path,
+ ) -> None:
+ """Verify no error occurs for a valid input file."""
+ check_model(non_optimised_input_model_file)
+
+ def test_check_model_with_optimised_input(
+ self,
+ optimised_input_model_file: Path,
+ ) -> None:
+ """Verify that the right exception is raised with already optimised input."""
+ with pytest.raises(ModelOptimisedException):
+ check_model(optimised_input_model_file)
+
+ def test_check_model_with_invalid_input(
+ self,
+ invalid_input_model_file: Path,
+ ) -> None:
+ """Verify that an exception is raised with invalid input."""
+ with pytest.raises(Exception):
+ check_model(invalid_input_model_file)