aboutsummaryrefslogtreecommitdiff
path: root/tests/test_nn_rewrite_core_utils_numpy_tfrecord.py
blob: d030350cd01da494176cf52285e34cb45ea7280f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for module mlia.nn.rewrite.core.utils.numpy_tfrecord."""
from __future__ import annotations

from pathlib import Path

import pytest
import tensorflow as tf

from mlia.nn.rewrite.core.utils.numpy_tfrecord import make_decode_fn
from mlia.nn.rewrite.core.utils.numpy_tfrecord import numpytf_count
from mlia.nn.rewrite.core.utils.numpy_tfrecord import sample_tfrec


def test_sample_tfrec(test_tfrecord: Path, tmp_path: Path) -> None:
    """Check that sample_tfrec() writes a TFRecord file holding exactly k records."""
    sampled_path = tmp_path / "output.tfrecord"
    sample_tfrec(
        input_file=str(test_tfrecord),
        k=1,  # draw a single record from the fixture file
        output_file=str(sampled_path),
    )

    assert sampled_path.is_file()
    assert numpytf_count(str(sampled_path)) == 1


def test_make_decode_fn(test_tfrecord: Path) -> None:
    """Check make_decode_fn() on an existing record file and on a missing one.

    The returned decoder is applied to the first serialized record of the
    fixture file. The result must be a single-entry dict that maps a string
    name to an int8 tensor. Building a decoder for a non-existent path must
    raise FileNotFoundError.
    """
    record_path = str(test_tfrecord)
    decode_fn = make_decode_fn(record_path)
    first_record = next(iter(tf.data.TFRecordDataset(record_path)))
    decoded = decode_fn(first_record)

    # The test_tfrecord fixture is assumed to hold one int8 tensor per record.
    assert isinstance(decoded, dict)
    assert len(decoded) == 1
    [(name, tensor)] = decoded.items()
    assert isinstance(name, str)
    assert isinstance(tensor, tf.Tensor)
    assert tensor.dtype == tf.int8

    # Appending a suffix yields a path that does not exist on disk.
    missing_path = record_path + "_"
    with pytest.raises(FileNotFoundError):
        make_decode_fn(missing_path)


def test_numpytf_count(test_tfrecord: Path) -> None:
    """Check that numpytf_count(), given a Path, reports the fixture's three records."""
    expected_records = 3
    record_count = numpytf_count(test_tfrecord)
    assert record_count == expected_records