aboutsummaryrefslogtreecommitdiff
path: root/tests
diff options
context:
space:
mode:
authorAnnie Tallund <annie.tallund@arm.com>2023-03-15 11:27:08 +0100
committerBenjamin Klimczak <benjamin.klimczak@arm.com>2023-10-11 15:43:14 +0100
commit867f37d643e66c0223457c28f5345f2f21db97f2 (patch)
tree4e3c55896760e24a8b5eadc5176ce7f5586552e1 /tests
parent62768232c5fe4ed6b87136c336b65e13d030e9d4 (diff)
downloadmlia-867f37d643e66c0223457c28f5345f2f21db97f2.tar.gz
Adapt rewrite module to MLIA coding standards
- Fix imports - Update variable names - Refactor helper functions - Add licence headers - Add docstrings - Use f-strings rather than % notation - Create type annotations in rewrite module - Migrate from tqdm to rich progress bar - Use logging module in rewrite module: All print statements are replaced with logging module Resolves: MLIA-831, MLIA-842, MLIA-844, MLIA-846 Signed-off-by: Benjamin Klimczak <benjamin.klimczak@arm.com> Change-Id: Idee37538d72b9f01128a894281a8d10155f7c17c
Diffstat (limited to 'tests')
-rw-r--r--tests/test_nn_rewrite_core_graph_edit_cut.py4
-rw-r--r--tests/test_nn_rewrite_core_graph_edit_record.py4
-rw-r--r--tests/test_nn_rewrite_core_train.py14
3 files changed, 13 insertions, 9 deletions
diff --git a/tests/test_nn_rewrite_core_graph_edit_cut.py b/tests/test_nn_rewrite_core_graph_edit_cut.py
index 914fdfd..7d267ed 100644
--- a/tests/test_nn_rewrite_core_graph_edit_cut.py
+++ b/tests/test_nn_rewrite_core_graph_edit_cut.py
@@ -13,11 +13,11 @@ def test_cut_model(test_tflite_model: Path, tmp_path: Path) -> None:
"""Test the function cut_model()."""
output_file = tmp_path / "out.tflite"
cut_model(
- model_file=test_tflite_model,
+ model_file=str(test_tflite_model),
input_names=["serving_default_input:0"],
output_names=["sequential/flatten/Reshape"],
subgraph_index=0,
- output_file=output_file,
+ output_file=str(output_file),
)
assert output_file.is_file()
diff --git a/tests/test_nn_rewrite_core_graph_edit_record.py b/tests/test_nn_rewrite_core_graph_edit_record.py
index 39aeef5..cd728af 100644
--- a/tests/test_nn_rewrite_core_graph_edit_record.py
+++ b/tests/test_nn_rewrite_core_graph_edit_record.py
@@ -7,7 +7,7 @@ import pytest
import tensorflow as tf
from mlia.nn.rewrite.core.graph_edit.record import record_model
-from mlia.nn.rewrite.core.utils.numpy_tfrecord import NumpyTFReader
+from mlia.nn.rewrite.core.utils.numpy_tfrecord import numpytf_read
@pytest.mark.parametrize("batch_size", (None, 1, 2))
@@ -46,7 +46,7 @@ def test_record_model(
# any of the model outputs
interpreter = tf.lite.Interpreter(str(test_tflite_model))
model_outputs = interpreter.get_output_details()
- dataset = NumpyTFReader(str(output_file))
+ dataset = numpytf_read(str(output_file))
for data in dataset:
for name, tensor in data.items():
assert data_matches_outputs(name, tensor, model_outputs)
diff --git a/tests/test_nn_rewrite_core_train.py b/tests/test_nn_rewrite_core_train.py
index d2bc1e0..3c2ef3e 100644
--- a/tests/test_nn_rewrite_core_train.py
+++ b/tests/test_nn_rewrite_core_train.py
@@ -6,17 +6,21 @@ from __future__ import annotations
from pathlib import Path
from tempfile import TemporaryDirectory
+from typing import Any
import numpy as np
import pytest
import tensorflow as tf
from mlia.nn.rewrite.core.train import augmentation_presets
+from mlia.nn.rewrite.core.train import LearningRateSchedule
from mlia.nn.rewrite.core.train import mixup
from mlia.nn.rewrite.core.train import train
-def replace_fully_connected_with_conv(input_shape, output_shape) -> tf.keras.Model:
+def replace_fully_connected_with_conv(
+ input_shape: Any, output_shape: Any
+) -> tf.keras.Model:
"""Get a replacement model for the fully connected layer."""
for name, shape in {
"Input": input_shape,
@@ -43,7 +47,7 @@ def check_train(
augmentation_preset: tuple[float | None, float | None] = augmentation_presets[
"none"
],
- lr_schedule: str = "cosine",
+ lr_schedule: LearningRateSchedule = "cosine",
use_unmodified_model: bool = False,
num_procs: int = 1,
) -> None:
@@ -60,7 +64,7 @@ def check_train(
output_tensors=["StatefulPartitionedCall:0"],
augment=augmentation_preset,
steps=32,
- lr=1e-3,
+ learning_rate=1e-3,
batch_size=batch_size,
verbose=verbose,
show_progress=show_progress,
@@ -104,7 +108,7 @@ def test_train(
verbose: bool,
show_progress: bool,
augmentation_preset: tuple[float | None, float | None],
- lr_schedule: str,
+ lr_schedule: LearningRateSchedule,
use_unmodified_model: bool,
num_procs: int,
) -> None:
@@ -131,7 +135,7 @@ def test_train_invalid_schedule(
check_train(
tflite_model=test_tflite_model_fp32,
tfrecord=test_tfrecord_fp32,
- lr_schedule="unknown_schedule",
+ lr_schedule="unknown_schedule", # type: ignore
)