path: root/src/mlia/nn/rewrite
author     Nathan Bailey <nathan.bailey@arm.com>    2024-03-20 08:13:39 +0000
committer  Nathan Bailey <nathan.bailey@arm.com>    2024-03-28 07:17:32 +0000
commit     f3f3ab451968350b8f6df2de7c60b2c2b9320b59 (patch)
tree       05d56c8e41de9b32f8054019a21b78628151310d /src/mlia/nn/rewrite
parent     5f063ae1cfbfa2568d2858af0a0ccaf192bb1e8d (diff)
download   mlia-f3f3ab451968350b8f6df2de7c60b2c2b9320b59.tar.gz
feat: Update Vela version
Updates the Vela version to 3.11.0 and the TensorFlow version to 2.15.1.

This required the Keras import to change: from keras.api._v2 import keras is
needed instead of calling tf.keras. Subsequently, tf.keras.X references needed
to change to keras.X.

Resolves: MLIA-1107

Signed-off-by: Nathan Bailey <nathan.bailey@arm.com>
Change-Id: I53bcaa9cdad58b0e6c311c8c6490393d33cb18bc
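A minimal sketch of the import change described above; the model and layer below are illustrative examples, not code from this patch:

    # Before: Keras reached through the TensorFlow namespace.
    # import tensorflow as tf
    # model = tf.keras.Sequential([tf.keras.layers.Dense(10)])

    # After: explicit Keras import (temporary workaround tracked under MLIA-1107),
    # with every tf.keras.X reference becoming keras.X.
    from keras.api._v2 import keras

    model = keras.Sequential([keras.layers.Dense(10)])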
Diffstat (limited to 'src/mlia/nn/rewrite')
-rw-r--r--   src/mlia/nn/rewrite/core/rewrite.py     |  9
-rw-r--r--   src/mlia/nn/rewrite/core/train.py       | 19
-rw-r--r--   src/mlia/nn/rewrite/library/fc_layer.py | 14
3 files changed, 22 insertions, 20 deletions
diff --git a/src/mlia/nn/rewrite/core/rewrite.py b/src/mlia/nn/rewrite/core/rewrite.py
index 8658991..c7d13ba 100644
--- a/src/mlia/nn/rewrite/core/rewrite.py
+++ b/src/mlia/nn/rewrite/core/rewrite.py
@@ -12,7 +12,7 @@ from typing import Any
from typing import Callable
from typing import cast
-import tensorflow as tf
+from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107
from mlia.core.errors import ConfigurationError
from mlia.core.reporting import Column
@@ -25,8 +25,9 @@ from mlia.nn.rewrite.core.train import TrainingParameters
from mlia.nn.tensorflow.config import TFLiteModel
from mlia.utils.registry import Registry
+
logger = logging.getLogger(__name__)
-RewriteCallable = Callable[[Any, Any], tf.keras.Model]
+RewriteCallable = Callable[[Any, Any], keras.Model]
class Rewrite:
@@ -37,7 +38,7 @@ class Rewrite:
self.name = name
self.function = rewrite_fn
- def __call__(self, input_shape: Any, output_shape: Any) -> tf.keras.Model:
+ def __call__(self, input_shape: Any, output_shape: Any) -> keras.Model:
"""Perform the rewrite operation using the configured function."""
try:
return self.function(input_shape, output_shape)
@@ -52,7 +53,7 @@ class DynamicallyLoadedRewrite(Rewrite):
def __init__(self, name: str, function_name: str):
"""Initialize."""
- def load_and_run(input_shape: Any, output_shape: Any) -> tf.keras.Model:
+ def load_and_run(input_shape: Any, output_shape: Any) -> keras.Model:
"""Load the function from a file dynamically."""
self.load_function(function_name)
return self.function(input_shape, output_shape)
diff --git a/src/mlia/nn/rewrite/core/train.py b/src/mlia/nn/rewrite/core/train.py
index 72b8f48..60c39ae 100644
--- a/src/mlia/nn/rewrite/core/train.py
+++ b/src/mlia/nn/rewrite/core/train.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2023-2024, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Sequential trainer."""
# pylint: disable=too-many-locals
@@ -23,6 +23,7 @@ from typing import Literal
import numpy as np
import tensorflow as tf
import tensorflow_model_optimization as tfmot
+from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107
from numpy.random import Generator
from mlia.nn.rewrite.core.extract import extract
@@ -383,8 +384,8 @@ def train_in_dir(
model = replace_fn(input_shape, output_shape)
- optimizer = tf.keras.optimizers.Nadam(learning_rate=train_params.learning_rate)
- loss_fn = tf.keras.losses.MeanSquaredError()
+ optimizer = keras.optimizers.Nadam(learning_rate=train_params.learning_rate)
+ loss_fn = keras.losses.MeanSquaredError()
if model_is_quantized:
model = tfmot.quantization.keras.quantize_model(model)
model.compile(optimizer=optimizer, loss=loss_fn, metrics=["mae"])
@@ -403,7 +404,7 @@ def train_in_dir(
* (math.cos(math.pi * current_step / train_params.steps) + 1)
/ 2.0
)
- tf.keras.backend.set_value(optimizer.learning_rate, cd_learning_rate)
+ keras.backend.set_value(optimizer.learning_rate, cd_learning_rate)
def late_decay(
epoch_step: int, logs: Any # pylint: disable=unused-argument
@@ -414,16 +415,16 @@ def train_in_dir(
decay_length = train_params.steps // 5
decay_fraction = min(steps_remaining, decay_length) / decay_length
ld_learning_rate = train_params.learning_rate * decay_fraction
- tf.keras.backend.set_value(optimizer.learning_rate, ld_learning_rate)
+ keras.backend.set_value(optimizer.learning_rate, ld_learning_rate)
assert train_params.learning_rate_schedule in LEARNING_RATE_SCHEDULES, (
f'Learning rate schedule "{train_params.learning_rate_schedule}" '
f"not implemented - expected one of {LEARNING_RATE_SCHEDULES}."
)
if train_params.learning_rate_schedule == "cosine":
- callbacks = [tf.keras.callbacks.LambdaCallback(on_batch_begin=cosine_decay)]
+ callbacks = [keras.callbacks.LambdaCallback(on_batch_begin=cosine_decay)]
elif train_params.learning_rate_schedule == "late":
- callbacks = [tf.keras.callbacks.LambdaCallback(on_batch_begin=late_decay)]
+ callbacks = [keras.callbacks.LambdaCallback(on_batch_begin=late_decay)]
elif train_params.learning_rate_schedule == "constant":
callbacks = []
@@ -474,7 +475,7 @@ def train_in_dir(
def save_as_tflite(
- keras_model: tf.keras.Model,
+ keras_model: keras.Model,
filename: str,
input_name: str,
input_shape: list,
@@ -485,7 +486,7 @@ def save_as_tflite(
"""Save Keras model as TFLite file."""
@contextmanager
- def fixed_input(keras_model: tf.keras.Model, tmp_shape: list) -> GeneratorType:
+ def fixed_input(keras_model: keras.Model, tmp_shape: list) -> GeneratorType:
"""Fix the input shape of the Keras model temporarily.
This avoids artifacts during conversion to TensorFlow Lite.
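For reference, the cosine-decay schedule wired up above can be reproduced in isolation roughly as follows. The helper name and its closure-based step counter are illustrative assumptions; the patched train_in_dir tracks the current step itself:

    import math

    from keras.api._v2 import keras  # Same temporary workaround as the patch (MLIA-1107)

    def make_cosine_decay_callback(
        optimizer: keras.optimizers.Optimizer, initial_lr: float, total_steps: int
    ) -> keras.callbacks.LambdaCallback:
        """Build a per-batch callback that applies cosine decay to the optimizer."""
        state = {"step": 0}

        def cosine_decay(batch, logs=None):  # pylint: disable=unused-argument
            # Same formula as train_in_dir: lr * (cos(pi * step / total_steps) + 1) / 2
            learning_rate = (
                initial_lr * (math.cos(math.pi * state["step"] / total_steps) + 1) / 2.0
            )
            keras.backend.set_value(optimizer.learning_rate, learning_rate)
            state["step"] += 1

        return keras.callbacks.LambdaCallback(on_batch_begin=cosine_decay)

    # Usage (illustrative): model.fit(..., callbacks=[make_cosine_decay_callback(opt, 1e-3, 48000)])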
diff --git a/src/mlia/nn/rewrite/library/fc_layer.py b/src/mlia/nn/rewrite/library/fc_layer.py
index 2480500..041ce85 100644
--- a/src/mlia/nn/rewrite/library/fc_layer.py
+++ b/src/mlia/nn/rewrite/library/fc_layer.py
@@ -1,18 +1,18 @@
-# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2023-2024, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Example rewrite with one fully connected layer."""
from typing import Any
-import tensorflow as tf
+from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107
-def get_keras_model(input_shape: Any, output_shape: Any) -> tf.keras.Model:
+def get_keras_model(input_shape: Any, output_shape: Any) -> keras.Model:
"""Generate TensorFlow Lite model for rewrite."""
- model = tf.keras.Sequential(
+ model = keras.Sequential(
(
- tf.keras.layers.InputLayer(input_shape=input_shape),
- tf.keras.layers.Reshape([-1]),
- tf.keras.layers.Dense(output_shape),
+ keras.layers.InputLayer(input_shape=input_shape),
+ keras.layers.Reshape([-1]),
+ keras.layers.Dense(output_shape),
)
)
return model
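As a usage sketch (the shapes and the rewrite name below are made up for illustration and are not part of this patch), the example layer can be built either directly or through the updated Rewrite wrapper, both of which now return keras.Model instead of tf.keras.Model:

    from mlia.nn.rewrite.core.rewrite import Rewrite
    from mlia.nn.rewrite.library.fc_layer import get_keras_model

    # Direct call: flattens the input and maps it to the requested output size.
    model = get_keras_model(input_shape=(28, 28, 1), output_shape=10)

    # Through the Rewrite wrapper updated in this patch.
    rewrite = Rewrite("example-fc", get_keras_model)
    same_architecture = rewrite(input_shape=(28, 28, 1), output_shape=10)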