aboutsummaryrefslogtreecommitdiff
path: root/tests/test_nn_rewrite_core_extract.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/test_nn_rewrite_core_extract.py')
-rw-r--r--tests/test_nn_rewrite_core_extract.py38
1 files changed, 38 insertions, 0 deletions
diff --git a/tests/test_nn_rewrite_core_extract.py b/tests/test_nn_rewrite_core_extract.py
new file mode 100644
index 0000000..09eca77
--- /dev/null
+++ b/tests/test_nn_rewrite_core_extract.py
@@ -0,0 +1,38 @@
+# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for module mlia.nn.rewrite.core.extract."""
+from __future__ import annotations
+
+from pathlib import Path
+from typing import Any
+from typing import Iterable
+
+import pytest
+
+from mlia.nn.rewrite.core.extract import ExtractPaths
+from mlia.nn.rewrite.core.graph_edit.record import DEQUANT_SUFFIX
+
+
+@pytest.mark.parametrize("dir_path", ("/dev/null", Path("/dev/null")))
+@pytest.mark.parametrize("model_is_quantized", (False, True))
+@pytest.mark.parametrize(
+ ("obj", "func_names", "suffix"),
+ (
+ (ExtractPaths.tflite, ("start", "replace", "end"), ".tflite"),
+ (ExtractPaths.tfrec, ("input", "output", "end"), ".tfrec"),
+ ),
+)
+def test_extract_paths(
+ dir_path: str | Path,
+ model_is_quantized: bool,
+ obj: Any,
+ func_names: Iterable[str],
+ suffix: str,
+) -> None:
+ """Test class ExtractPaths."""
+ for func_name in func_names:
+ func = getattr(obj, func_name)
+ path = func(dir_path, model_is_quantized)
+ assert path == Path(dir_path, path.relative_to(dir_path))
+ assert path.suffix == suffix
+ assert not model_is_quantized or path.stem.endswith(DEQUANT_SUFFIX)