diff options
Diffstat (limited to 'tests/test_target_ethos_u_advisor.py')
-rw-r--r-- | tests/test_target_ethos_u_advisor.py | 51 |
1 files changed, 46 insertions, 5 deletions
diff --git a/tests/test_target_ethos_u_advisor.py b/tests/test_target_ethos_u_advisor.py index 11aefc7..20131d2 100644 --- a/tests/test_target_ethos_u_advisor.py +++ b/tests/test_target_ethos_u_advisor.py @@ -1,7 +1,11 @@ # SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Tests for Ethos-U MLIA module.""" +from __future__ import annotations + +from contextlib import ExitStack as does_not_raise from pathlib import Path +from typing import Any import pytest @@ -16,16 +20,53 @@ def test_advisor_metadata() -> None: assert EthosUInferenceAdvisor.name() == "ethos_u_inference_advisor" -def test_unsupported_advice_categories(tmp_path: Path, test_tflite_model: Path) -> None: +@pytest.mark.parametrize( + "optimization_targets, expected_error", + [ + [ + [ + { + "optimization_type": "pruning", + "optimization_target": 0.5, + "layers_to_optimize": None, + } + ], + pytest.raises( + Exception, + match="Only 'rewrite' is supported for TensorFlow Lite files.", + ), + ], + [ + [ + { + "optimization_type": "rewrite", + "optimization_target": "fully_connected", + "layers_to_optimize": [ + "MobileNet/avg_pool/AvgPool", + "MobileNet/fc1/BiasAdd", + ], + } + ], + does_not_raise(), + ], + ], +) +def test_unsupported_advice_categories( + tmp_path: Path, + test_tflite_model: Path, + optimization_targets: list[dict[str, Any]], + expected_error: Any, +) -> None: """Test that advisor should throw an exception for unsupported categories.""" - with pytest.raises( - Exception, match="Optimizations are not supported for TensorFlow Lite files." - ): + with expected_error: ctx = ExecutionContext( output_dir=tmp_path, advice_category={AdviceCategory.OPTIMIZATION} ) advisor = configure_and_get_ethosu_advisor( - ctx, "ethos-u55-256", str(test_tflite_model) + ctx, + "ethos-u55-256", + str(test_tflite_model), + optimization_targets=optimization_targets, ) advisor.configure(ctx) |