blob: 7fc80480f9e27a1777919dc64e38f58ec1588422 (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
|
# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for module mlia.nn.rewrite.core.utils.numpy_tfrecord."""
from __future__ import annotations
from pathlib import Path
from mlia.nn.rewrite.core.utils.numpy_tfrecord import numpytf_count
from mlia.nn.rewrite.core.utils.numpy_tfrecord import sample_tfrec
def test_sample_tfrec(test_tfrecord: Path, tmp_path: Path) -> None:
"""Test function sample_tfrec()."""
output_file = tmp_path / "output.tfrecord"
# Sample 1 sample from test_tfrecord
sample_tfrec(input_file=str(test_tfrecord), k=1, output_file=str(output_file))
assert output_file.is_file()
assert numpytf_count(str(output_file)) == 1
|