aboutsummaryrefslogtreecommitdiff
path: root/tests/test_nn_rewrite_core_utils_numpy_tfrecord.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/test_nn_rewrite_core_utils_numpy_tfrecord.py')
-rw-r--r--tests/test_nn_rewrite_core_utils_numpy_tfrecord.py25
1 files changed, 25 insertions, 0 deletions
diff --git a/tests/test_nn_rewrite_core_utils_numpy_tfrecord.py b/tests/test_nn_rewrite_core_utils_numpy_tfrecord.py
index 7fc8048..d030350 100644
--- a/tests/test_nn_rewrite_core_utils_numpy_tfrecord.py
+++ b/tests/test_nn_rewrite_core_utils_numpy_tfrecord.py
@@ -5,6 +5,10 @@ from __future__ import annotations
from pathlib import Path
+import pytest
+import tensorflow as tf
+
+from mlia.nn.rewrite.core.utils.numpy_tfrecord import make_decode_fn
from mlia.nn.rewrite.core.utils.numpy_tfrecord import numpytf_count
from mlia.nn.rewrite.core.utils.numpy_tfrecord import sample_tfrec
@@ -16,3 +20,24 @@ def test_sample_tfrec(test_tfrecord: Path, tmp_path: Path) -> None:
sample_tfrec(input_file=str(test_tfrecord), k=1, output_file=str(output_file))
assert output_file.is_file()
assert numpytf_count(str(output_file)) == 1
+
+
+def test_make_decode_fn(test_tfrecord: Path) -> None:
+ """Test function make_decode_fn()."""
+ decode = make_decode_fn(str(test_tfrecord))
+ dataset = tf.data.TFRecordDataset(str(test_tfrecord))
+ features = decode(next(iter(dataset)))
+ assert isinstance(features, dict)
+ assert len(features) == 1
+ key, val = next(iter(features.items()))
+ assert isinstance(key, str)
+ assert isinstance(val, tf.Tensor)
+ assert val.dtype == tf.int8
+
+ with pytest.raises(FileNotFoundError):
+ make_decode_fn(str(test_tfrecord) + "_")
+
+
+def test_numpytf_count(test_tfrecord: Path) -> None:
+ """Test function numpytf_count()."""
+ assert numpytf_count(test_tfrecord) == 3