aboutsummaryrefslogtreecommitdiff
path: root/tests/test_nn_rewrite_core_train.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/test_nn_rewrite_core_train.py')
-rw-r--r--tests/test_nn_rewrite_core_train.py14
1 files changed, 9 insertions, 5 deletions
diff --git a/tests/test_nn_rewrite_core_train.py b/tests/test_nn_rewrite_core_train.py
index d2bc1e0..3c2ef3e 100644
--- a/tests/test_nn_rewrite_core_train.py
+++ b/tests/test_nn_rewrite_core_train.py
@@ -6,17 +6,21 @@ from __future__ import annotations
from pathlib import Path
from tempfile import TemporaryDirectory
+from typing import Any
import numpy as np
import pytest
import tensorflow as tf
from mlia.nn.rewrite.core.train import augmentation_presets
+from mlia.nn.rewrite.core.train import LearningRateSchedule
from mlia.nn.rewrite.core.train import mixup
from mlia.nn.rewrite.core.train import train
-def replace_fully_connected_with_conv(input_shape, output_shape) -> tf.keras.Model:
+def replace_fully_connected_with_conv(
+ input_shape: Any, output_shape: Any
+) -> tf.keras.Model:
"""Get a replacement model for the fully connected layer."""
for name, shape in {
"Input": input_shape,
@@ -43,7 +47,7 @@ def check_train(
augmentation_preset: tuple[float | None, float | None] = augmentation_presets[
"none"
],
- lr_schedule: str = "cosine",
+ lr_schedule: LearningRateSchedule = "cosine",
use_unmodified_model: bool = False,
num_procs: int = 1,
) -> None:
@@ -60,7 +64,7 @@ def check_train(
output_tensors=["StatefulPartitionedCall:0"],
augment=augmentation_preset,
steps=32,
- lr=1e-3,
+ learning_rate=1e-3,
batch_size=batch_size,
verbose=verbose,
show_progress=show_progress,
@@ -104,7 +108,7 @@ def test_train(
verbose: bool,
show_progress: bool,
augmentation_preset: tuple[float | None, float | None],
- lr_schedule: str,
+ lr_schedule: LearningRateSchedule,
use_unmodified_model: bool,
num_procs: int,
) -> None:
@@ -131,7 +135,7 @@ def test_train_invalid_schedule(
check_train(
tflite_model=test_tflite_model_fp32,
tfrecord=test_tfrecord_fp32,
- lr_schedule="unknown_schedule",
+ lr_schedule="unknown_schedule", # type: ignore
)