diff options
Diffstat (limited to 'tests/test_nn_rewrite_core_utils_numpy_tfrecord.py')
-rw-r--r-- | tests/test_nn_rewrite_core_utils_numpy_tfrecord.py | 18 |
1 files changed, 18 insertions, 0 deletions
diff --git a/tests/test_nn_rewrite_core_utils_numpy_tfrecord.py b/tests/test_nn_rewrite_core_utils_numpy_tfrecord.py new file mode 100644 index 0000000..7fc8048 --- /dev/null +++ b/tests/test_nn_rewrite_core_utils_numpy_tfrecord.py @@ -0,0 +1,18 @@ +# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Tests for module mlia.nn.rewrite.core.utils.numpy_tfrecord.""" +from __future__ import annotations + +from pathlib import Path + +from mlia.nn.rewrite.core.utils.numpy_tfrecord import numpytf_count +from mlia.nn.rewrite.core.utils.numpy_tfrecord import sample_tfrec + + +def test_sample_tfrec(test_tfrecord: Path, tmp_path: Path) -> None: + """Test function sample_tfrec().""" + output_file = tmp_path / "output.tfrecord" + # Sample 1 sample from test_tfrecord + sample_tfrec(input_file=str(test_tfrecord), k=1, output_file=str(output_file)) + assert output_file.is_file() + assert numpytf_count(str(output_file)) == 1 |