blob: d030350cd01da494176cf52285e34cb45ea7280f (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
|
# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for module mlia.nn.rewrite.core.utils.numpy_tfrecord."""
from __future__ import annotations
from pathlib import Path
import pytest
import tensorflow as tf
from mlia.nn.rewrite.core.utils.numpy_tfrecord import make_decode_fn
from mlia.nn.rewrite.core.utils.numpy_tfrecord import numpytf_count
from mlia.nn.rewrite.core.utils.numpy_tfrecord import sample_tfrec
def test_sample_tfrec(test_tfrecord: Path, tmp_path: Path) -> None:
    """Check that sample_tfrec() writes a file holding exactly k records."""
    sampled_path = tmp_path / "output.tfrecord"

    # Draw a single record (k=1) from the fixture file into a fresh file.
    sample_tfrec(
        input_file=str(test_tfrecord),
        k=1,
        output_file=str(sampled_path),
    )

    assert sampled_path.is_file()
    assert numpytf_count(str(sampled_path)) == 1
def test_make_decode_fn(test_tfrecord: Path) -> None:
    """Check that make_decode_fn() yields a decoder for the file's records."""
    record_path = str(test_tfrecord)
    decode_record = make_decode_fn(record_path)

    # Decode only the first serialized record from the file.
    first_record = next(iter(tf.data.TFRecordDataset(record_path)))
    decoded = decode_record(first_record)

    # The decoder is expected to return a one-entry mapping of
    # feature name -> int8 tensor for this fixture.
    assert isinstance(decoded, dict)
    assert len(decoded) == 1
    [(name, tensor)] = decoded.items()
    assert isinstance(name, str)
    assert isinstance(tensor, tf.Tensor)
    assert tensor.dtype == tf.int8

    # A path with no matching file must raise FileNotFoundError.
    with pytest.raises(FileNotFoundError):
        make_decode_fn(record_path + "_")
def test_numpytf_count(test_tfrecord: Path) -> None:
    """Test function numpytf_count().

    The count is checked for both a ``str`` path (the form the other tests in
    this module pass to the numpy_tfrecord helpers) and a ``Path`` object (the
    form this test originally exercised), so neither calling convention can
    regress unnoticed.
    """
    # NOTE(review): 3 is the number of records the ``test_tfrecord`` fixture
    # is expected to contain -- confirm against the fixture definition.
    expected_count = 3
    for filename in (str(test_tfrecord), test_tfrecord):
        assert numpytf_count(filename) == expected_count
|