aboutsummaryrefslogtreecommitdiff
path: root/tests/aiet/test_check_model.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/aiet/test_check_model.py')
-rw-r--r--tests/aiet/test_check_model.py162
1 files changed, 0 insertions, 162 deletions
diff --git a/tests/aiet/test_check_model.py b/tests/aiet/test_check_model.py
deleted file mode 100644
index 4eafe59..0000000
--- a/tests/aiet/test_check_model.py
+++ /dev/null
@@ -1,162 +0,0 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
-# SPDX-License-Identifier: Apache-2.0
-# pylint: disable=redefined-outer-name,no-self-use
-"""Module for testing check_model.py script."""
-from pathlib import Path
-from typing import Any
-
-import pytest
-from ethosu.vela.tflite.Model import Model
-from ethosu.vela.tflite.OperatorCode import OperatorCode
-
-from aiet.cli.common import InvalidTFLiteFileError
-from aiet.cli.common import ModelOptimisedException
-from aiet.resources.tools.vela.check_model import check_custom_codes_for_ethosu
-from aiet.resources.tools.vela.check_model import check_model
-from aiet.resources.tools.vela.check_model import get_custom_codes_from_operators
-from aiet.resources.tools.vela.check_model import get_model_from_file
-from aiet.resources.tools.vela.check_model import get_operators_from_model
-from aiet.resources.tools.vela.check_model import is_vela_optimised
-
-
-@pytest.fixture(scope="session")
-def optimised_tflite_model(
- optimised_input_model_file: Path,
-) -> Model:
- """Return Model instance read from a Vela-optimised TFLite file."""
- return get_model_from_file(optimised_input_model_file)
-
-
-@pytest.fixture(scope="session")
-def non_optimised_tflite_model(
- non_optimised_input_model_file: Path,
-) -> Model:
- """Return Model instance read from a Vela-optimised TFLite file."""
- return get_model_from_file(non_optimised_input_model_file)
-
-
-class TestIsVelaOptimised:
- """Test class for is_vela_optimised() function."""
-
- def test_return_true_when_input_is_optimised(
- self,
- optimised_tflite_model: Model,
- ) -> None:
- """Verify True returned when input is optimised model."""
- output = is_vela_optimised(optimised_tflite_model)
-
- assert output is True
-
- def test_return_false_when_input_is_not_optimised(
- self,
- non_optimised_tflite_model: Model,
- ) -> None:
- """Verify False returned when input is non-optimised model."""
- output = is_vela_optimised(non_optimised_tflite_model)
-
- assert output is False
-
-
-def test_get_operator_list_returns_correct_instances(
- optimised_tflite_model: Model,
-) -> None:
- """Verify list of OperatorCode instances returned by get_operator_list()."""
- operator_list = get_operators_from_model(optimised_tflite_model)
-
- assert all(isinstance(operator, OperatorCode) for operator in operator_list)
-
-
-class TestGetCustomCodesFromOperators:
- """Test the get_custom_codes_from_operators() function."""
-
- def test_returns_empty_list_when_input_operators_have_no_custom_codes(
- self, monkeypatch: Any
- ) -> None:
- """Verify function returns empty list when operators have no custom codes."""
- # Mock OperatorCode.CustomCode() function to return None
- monkeypatch.setattr(
- "ethosu.vela.tflite.OperatorCode.OperatorCode.CustomCode", lambda _: None
- )
-
- operators = [OperatorCode()] * 3
-
- custom_codes = get_custom_codes_from_operators(operators)
-
- assert custom_codes == []
-
- def test_returns_custom_codes_when_input_operators_have_custom_codes(
- self, monkeypatch: Any
- ) -> None:
- """Verify list of bytes objects returned representing the CustomCodes."""
- # Mock OperatorCode.CustomCode() function to return a byte string
- monkeypatch.setattr(
- "ethosu.vela.tflite.OperatorCode.OperatorCode.CustomCode",
- lambda _: b"custom-code",
- )
-
- operators = [OperatorCode()] * 3
-
- custom_codes = get_custom_codes_from_operators(operators)
-
- assert custom_codes == [b"custom-code", b"custom-code", b"custom-code"]
-
-
-@pytest.mark.parametrize(
- "custom_codes, expected_output",
- [
- ([b"ethos-u", b"something else"], True),
- ([b"custom-code-1", b"custom-code-2"], False),
- ],
-)
-def test_check_list_for_ethosu(custom_codes: list, expected_output: bool) -> None:
- """Verify function detects 'ethos-u' bytes in the input list."""
- output = check_custom_codes_for_ethosu(custom_codes)
- assert output is expected_output
-
-
-class TestGetModelFromFile:
- """Test the get_model_from_file() function."""
-
- def test_error_raised_when_input_is_invalid_model_file(
- self,
- invalid_input_model_file: Path,
- ) -> None:
- """Verify error thrown when an invalid model file is given."""
- with pytest.raises(InvalidTFLiteFileError):
- get_model_from_file(invalid_input_model_file)
-
- def test_model_instance_returned_when_input_is_valid_model_file(
- self,
- optimised_input_model_file: Path,
- ) -> None:
- """Verify file is read successfully and returns model instance."""
- tflite_model = get_model_from_file(optimised_input_model_file)
-
- assert isinstance(tflite_model, Model)
-
-
-class TestCheckModel:
- """Test the check_model() function."""
-
- def test_check_model_with_non_optimised_input(
- self,
- non_optimised_input_model_file: Path,
- ) -> None:
- """Verify no error occurs for a valid input file."""
- check_model(non_optimised_input_model_file)
-
- def test_check_model_with_optimised_input(
- self,
- optimised_input_model_file: Path,
- ) -> None:
- """Verify that the right exception is raised with already optimised input."""
- with pytest.raises(ModelOptimisedException):
- check_model(optimised_input_model_file)
-
- def test_check_model_with_invalid_input(
- self,
- invalid_input_model_file: Path,
- ) -> None:
- """Verify that an exception is raised with invalid input."""
- with pytest.raises(Exception):
- check_model(invalid_input_model_file)