diff options
Diffstat (limited to 'tests/test_nn_rewrite_core_extract.py')
-rw-r--r-- | tests/test_nn_rewrite_core_extract.py | 38 |
1 files changed, 38 insertions, 0 deletions
diff --git a/tests/test_nn_rewrite_core_extract.py b/tests/test_nn_rewrite_core_extract.py new file mode 100644 index 0000000..09eca77 --- /dev/null +++ b/tests/test_nn_rewrite_core_extract.py @@ -0,0 +1,38 @@ +# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Tests for module mlia.nn.rewrite.core.extract.""" +from __future__ import annotations + +from pathlib import Path +from typing import Any +from typing import Iterable + +import pytest + +from mlia.nn.rewrite.core.extract import ExtractPaths +from mlia.nn.rewrite.core.graph_edit.record import DEQUANT_SUFFIX + + +@pytest.mark.parametrize("dir_path", ("/dev/null", Path("/dev/null"))) +@pytest.mark.parametrize("model_is_quantized", (False, True)) +@pytest.mark.parametrize( + ("obj", "func_names", "suffix"), + ( + (ExtractPaths.tflite, ("start", "replace", "end"), ".tflite"), + (ExtractPaths.tfrec, ("input", "output", "end"), ".tfrec"), + ), +) +def test_extract_paths( + dir_path: str | Path, + model_is_quantized: bool, + obj: Any, + func_names: Iterable[str], + suffix: str, +) -> None: + """Test class ExtractPaths.""" + for func_name in func_names: + func = getattr(obj, func_name) + path = func(dir_path, model_is_quantized) + assert path == Path(dir_path, path.relative_to(dir_path)) + assert path.suffix == suffix + assert not model_is_quantized or path.stem.endswith(DEQUANT_SUFFIX) |