From f3f3ab451968350b8f6df2de7c60b2c2b9320b59 Mon Sep 17 00:00:00 2001 From: Nathan Bailey Date: Wed, 20 Mar 2024 08:13:39 +0000 Subject: feat: Update Vela version Updates Vela Version to 3.11.0 and TensorFlow version to 2.15.1 Required keras import to change: from keras.api._v2 import keras needed instead of calling tf.keras Subsequently tf.keras.X needed to change to keras.X Resolves: MLIA-1107 Signed-off-by: Nathan Bailey Change-Id: I53bcaa9cdad58b0e6c311c8c6490393d33cb18bc --- src/mlia/nn/rewrite/core/train.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) (limited to 'src/mlia/nn/rewrite/core/train.py') diff --git a/src/mlia/nn/rewrite/core/train.py b/src/mlia/nn/rewrite/core/train.py index 72b8f48..60c39ae 100644 --- a/src/mlia/nn/rewrite/core/train.py +++ b/src/mlia/nn/rewrite/core/train.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2023-2024, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Sequential trainer.""" # pylint: disable=too-many-locals @@ -23,6 +23,7 @@ from typing import Literal import numpy as np import tensorflow as tf import tensorflow_model_optimization as tfmot +from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107 from numpy.random import Generator from mlia.nn.rewrite.core.extract import extract @@ -383,8 +384,8 @@ def train_in_dir( model = replace_fn(input_shape, output_shape) - optimizer = tf.keras.optimizers.Nadam(learning_rate=train_params.learning_rate) - loss_fn = tf.keras.losses.MeanSquaredError() + optimizer = keras.optimizers.Nadam(learning_rate=train_params.learning_rate) + loss_fn = keras.losses.MeanSquaredError() if model_is_quantized: model = tfmot.quantization.keras.quantize_model(model) model.compile(optimizer=optimizer, loss=loss_fn, metrics=["mae"]) @@ -403,7 +404,7 @@ def train_in_dir( * (math.cos(math.pi * current_step / train_params.steps) + 1) / 2.0 ) - tf.keras.backend.set_value(optimizer.learning_rate, cd_learning_rate) + keras.backend.set_value(optimizer.learning_rate, cd_learning_rate) def late_decay( epoch_step: int, logs: Any # pylint: disable=unused-argument @@ -414,16 +415,16 @@ def train_in_dir( decay_length = train_params.steps // 5 decay_fraction = min(steps_remaining, decay_length) / decay_length ld_learning_rate = train_params.learning_rate * decay_fraction - tf.keras.backend.set_value(optimizer.learning_rate, ld_learning_rate) + keras.backend.set_value(optimizer.learning_rate, ld_learning_rate) assert train_params.learning_rate_schedule in LEARNING_RATE_SCHEDULES, ( f'Learning rate schedule "{train_params.learning_rate_schedule}" ' f"not implemented - expected one of {LEARNING_RATE_SCHEDULES}." ) if train_params.learning_rate_schedule == "cosine": - callbacks = [tf.keras.callbacks.LambdaCallback(on_batch_begin=cosine_decay)] + callbacks = [keras.callbacks.LambdaCallback(on_batch_begin=cosine_decay)] elif train_params.learning_rate_schedule == "late": - callbacks = [tf.keras.callbacks.LambdaCallback(on_batch_begin=late_decay)] + callbacks = [keras.callbacks.LambdaCallback(on_batch_begin=late_decay)] elif train_params.learning_rate_schedule == "constant": callbacks = [] @@ -474,7 +475,7 @@ def train_in_dir( def save_as_tflite( - keras_model: tf.keras.Model, + keras_model: keras.Model, filename: str, input_name: str, input_shape: list, @@ -485,7 +486,7 @@ def save_as_tflite( """Save Keras model as TFLite file.""" @contextmanager - def fixed_input(keras_model: tf.keras.Model, tmp_shape: list) -> GeneratorType: + def fixed_input(keras_model: keras.Model, tmp_shape: list) -> GeneratorType: """Fix the input shape of the Keras model temporarily. This avoids artifacts during conversion to TensorFlow Lite. -- cgit v1.2.1