aboutsummaryrefslogtreecommitdiff
path: root/tests/test_nn_rewrite_core_utils_numpy_tfrecord.py
blob: 7fc80480f9e27a1777919dc64e38f58ec1588422 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for module mlia.nn.rewrite.core.utils.numpy_tfrecord."""
from __future__ import annotations

from pathlib import Path

from mlia.nn.rewrite.core.utils.numpy_tfrecord import numpytf_count
from mlia.nn.rewrite.core.utils.numpy_tfrecord import sample_tfrec


def test_sample_tfrec(test_tfrecord: Path, tmp_path: Path) -> None:
    """Test function sample_tfrec()."""
    output_file = tmp_path / "output.tfrecord"
    # Sample 1 sample from test_tfrecord
    sample_tfrec(input_file=str(test_tfrecord), k=1, output_file=str(output_file))
    assert output_file.is_file()
    assert numpytf_count(str(output_file)) == 1