1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
|
# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for module mlia.nn.rewrite.core.extract."""
from __future__ import annotations
from pathlib import Path
from typing import Any
from typing import Iterable
import pytest
from mlia.nn.rewrite.core.extract import ExtractPaths
from mlia.nn.rewrite.core.graph_edit.record import DEQUANT_SUFFIX
@pytest.mark.parametrize("dir_path", ("/dev/null", Path("/dev/null")))
@pytest.mark.parametrize("model_is_quantized", (False, True))
@pytest.mark.parametrize(
    ("obj", "func_names", "suffix"),
    (
        (ExtractPaths.tflite, ("start", "replace", "end"), ".tflite"),
        (ExtractPaths.tfrec, ("input", "output", "end"), ".tfrec"),
    ),
)
def test_extract_paths(
    dir_path: str | Path,
    model_is_quantized: bool,
    obj: Any,
    func_names: Iterable[str],
    suffix: str,
) -> None:
    """Test class ExtractPaths.

    Each accessor named in ``func_names`` is looked up on ``obj`` and called
    with ``dir_path`` (exercised both as ``str`` and as ``Path``) and
    ``model_is_quantized``. The returned path must:

    * lie inside ``dir_path``,
    * carry the expected file ``suffix``, and
    * have a stem ending in ``DEQUANT_SUFFIX`` if and only if the model is
      quantized.
    """
    for func_name in func_names:
        get_path = getattr(obj, func_name)
        path = get_path(dir_path, model_is_quantized)

        # relative_to() raises ValueError when path is not under dir_path;
        # re-joining the relative part must reproduce the path exactly.
        assert path == Path(dir_path, path.relative_to(dir_path)), func_name
        assert path.suffix == suffix, f"{func_name}: {path}"
        # Check both directions: quantized models must get the
        # de-quantization suffix and non-quantized models must not.
        assert (
            path.stem.endswith(DEQUANT_SUFFIX) == model_is_quantized
        ), f"{func_name}: {path}"
|