Diffstat (limited to 'src/mlia/nn/rewrite/core/train.py')
-rw-r--r--  src/mlia/nn/rewrite/core/train.py | 19 ++++++++++---------
1 file changed, 10 insertions(+), 9 deletions(-)
diff --git a/src/mlia/nn/rewrite/core/train.py b/src/mlia/nn/rewrite/core/train.py
index 72b8f48..60c39ae 100644
--- a/src/mlia/nn/rewrite/core/train.py
+++ b/src/mlia/nn/rewrite/core/train.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2023-2024, Arm Limited and/or its affiliates.
 # SPDX-License-Identifier: Apache-2.0
 """Sequential trainer."""
 # pylint: disable=too-many-locals
@@ -23,6 +23,7 @@ from typing import Literal
 import numpy as np
 import tensorflow as tf
 import tensorflow_model_optimization as tfmot
+from keras.api._v2 import keras  # Temporary workaround for now: MLIA-1107
 from numpy.random import Generator

 from mlia.nn.rewrite.core.extract import extract
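Note: this hunk replaces tf.keras references with a direct import of the Keras 2 API. A minimal sketch of why the two names are expected to be interchangeable under TF 2.x with bundled Keras 2 (the keras.api._v2 path is internal and may move between releases):

import tensorflow as tf
from keras.api._v2 import keras  # same workaround as in the hunk (MLIA-1107)

# Under TF 2.x with Keras 2, tf.keras is backed by this module, so the
# two names should resolve to the same classes on affected versions.
print(keras.Model is tf.keras.Model)  # expected: True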
@@ -383,8 +384,8 @@ def train_in_dir(
     model = replace_fn(input_shape, output_shape)

-    optimizer = tf.keras.optimizers.Nadam(learning_rate=train_params.learning_rate)
-    loss_fn = tf.keras.losses.MeanSquaredError()
+    optimizer = keras.optimizers.Nadam(learning_rate=train_params.learning_rate)
+    loss_fn = keras.losses.MeanSquaredError()

     if model_is_quantized:
         model = tfmot.quantization.keras.quantize_model(model)
     model.compile(optimizer=optimizer, loss=loss_fn, metrics=["mae"])
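For context, a self-contained sketch of the compile path touched above, using a hypothetical toy model (the Nadam optimizer, MSE loss, MAE metric and tfmot quantization wrapper mirror the hunk; the shapes and learning rate are made up):

import tensorflow_model_optimization as tfmot
from keras.api._v2 import keras

model = keras.Sequential([keras.layers.Dense(8, input_shape=(4,))])
model = tfmot.quantization.keras.quantize_model(model)  # quantization-aware training wrapper
model.compile(
    optimizer=keras.optimizers.Nadam(learning_rate=1e-3),
    loss=keras.losses.MeanSquaredError(),
    metrics=["mae"],
)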
@@ -403,7 +404,7 @@ def train_in_dir(
                 * (math.cos(math.pi * current_step / train_params.steps) + 1)
                 / 2.0
             )
-            tf.keras.backend.set_value(optimizer.learning_rate, cd_learning_rate)
+            keras.backend.set_value(optimizer.learning_rate, cd_learning_rate)

     def late_decay(
         epoch_step: int, logs: Any  # pylint: disable=unused-argument
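For reference, the cosine schedule above starts at the base rate, halves it at the midpoint and approaches zero on the last step. A standalone sketch with the same formula (hypothetical helper name):

import math

def cosine_decay_lr(step: int, total_steps: int, base_lr: float) -> float:
    # Same expression as cd_learning_rate in the hunk above.
    return base_lr * (math.cos(math.pi * step / total_steps) + 1) / 2.0

# e.g. base_lr=1e-3 over 100 steps:
# step 0 -> 1e-3, step 50 -> 5e-4, step 100 -> 0.0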
@@ -414,16 +415,16 @@ def train_in_dir(
         decay_length = train_params.steps // 5
         decay_fraction = min(steps_remaining, decay_length) / decay_length
         ld_learning_rate = train_params.learning_rate * decay_fraction
-        tf.keras.backend.set_value(optimizer.learning_rate, ld_learning_rate)
+        keras.backend.set_value(optimizer.learning_rate, ld_learning_rate)

     assert train_params.learning_rate_schedule in LEARNING_RATE_SCHEDULES, (
         f'Learning rate schedule "{train_params.learning_rate_schedule}" '
         f"not implemented - expected one of {LEARNING_RATE_SCHEDULES}."
     )

     if train_params.learning_rate_schedule == "cosine":
-        callbacks = [tf.keras.callbacks.LambdaCallback(on_batch_begin=cosine_decay)]
+        callbacks = [keras.callbacks.LambdaCallback(on_batch_begin=cosine_decay)]
     elif train_params.learning_rate_schedule == "late":
-        callbacks = [tf.keras.callbacks.LambdaCallback(on_batch_begin=late_decay)]
+        callbacks = [keras.callbacks.LambdaCallback(on_batch_begin=late_decay)]
     elif train_params.learning_rate_schedule == "constant":
         callbacks = []
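The "late" schedule holds the base rate until the final fifth of training, then decays linearly to zero. A standalone sketch (hypothetical helper name), plus the LambdaCallback wiring the hunk uses; Keras invokes the on_batch_begin hook as f(batch, logs):

from keras.api._v2 import keras

def late_decay_lr(step: int, total_steps: int, base_lr: float) -> float:
    # Same arithmetic as ld_learning_rate in the hunk above.
    steps_remaining = total_steps - step
    decay_length = total_steps // 5
    decay_fraction = min(steps_remaining, decay_length) / decay_length
    return base_lr * decay_fraction

# e.g. total_steps=100, base_lr=1e-3:
# steps 0..80 -> 1e-3, step 90 -> 5e-4, step 100 -> 0.0

callback = keras.callbacks.LambdaCallback(
    on_batch_begin=lambda batch, logs: print(late_decay_lr(batch, 100, 1e-3))
)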
@@ -474,7 +475,7 @@ def train_in_dir(


 def save_as_tflite(
-    keras_model: tf.keras.Model,
+    keras_model: keras.Model,
     filename: str,
     input_name: str,
     input_shape: list,
@@ -485,7 +486,7 @@ def save_as_tflite(
"""Save Keras model as TFLite file."""
@contextmanager
- def fixed_input(keras_model: tf.keras.Model, tmp_shape: list) -> GeneratorType:
+ def fixed_input(keras_model: keras.Model, tmp_shape: list) -> GeneratorType:
"""Fix the input shape of the Keras model temporarily.
This avoids artifacts during conversion to TensorFlow Lite.
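The body of fixed_input lies outside this hunk. One plausible shape for such a context manager, assuming the model exposes a symbolic input whose shape can be pinned with set_shape and restored afterwards (a sketch under those assumptions, not the file's actual implementation):

from contextlib import contextmanager
from types import GeneratorType  # matches the annotation in the hunk
from keras.api._v2 import keras

@contextmanager
def fixed_input(keras_model: keras.Model, tmp_shape: list) -> GeneratorType:
    # Pin the input shape for the duration of the with-block, then restore it.
    orig_shape = keras_model.input.shape
    keras_model.input.set_shape(tmp_shape)
    try:
        yield keras_model
    finally:
        keras_model.input.set_shape(orig_shape)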