aboutsummaryrefslogtreecommitdiff
path: root/src/mlia/nn/tensorflow
diff options
context:
space:
mode:
authorNathan Bailey <nathan.bailey@arm.com>2024-03-20 08:13:39 +0000
committerNathan Bailey <nathan.bailey@arm.com>2024-03-28 07:17:32 +0000
commitf3f3ab451968350b8f6df2de7c60b2c2b9320b59 (patch)
tree05d56c8e41de9b32f8054019a21b78628151310d /src/mlia/nn/tensorflow
parent5f063ae1cfbfa2568d2858af0a0ccaf192bb1e8d (diff)
downloadmlia-f3f3ab451968350b8f6df2de7c60b2c2b9320b59.tar.gz
feat: Update Vela version
Updates Vela Version to 3.11.0 and TensorFlow version to 2.15.1 Required keras import to change: from keras.api._v2 import keras needed instead of calling tf.keras Subsequently tf.keras.X needed to change to keras.X Resolves: MLIA-1107 Signed-off-by: Nathan Bailey <nathan.bailey@arm.com> Change-Id: I53bcaa9cdad58b0e6c311c8c6490393d33cb18bc
Diffstat (limited to 'src/mlia/nn/tensorflow')
-rw-r--r--src/mlia/nn/tensorflow/config.py6
-rw-r--r--src/mlia/nn/tensorflow/optimizations/clustering.py16
-rw-r--r--src/mlia/nn/tensorflow/optimizations/pruning.py31
-rw-r--r--src/mlia/nn/tensorflow/tflite_convert.py18
-rw-r--r--src/mlia/nn/tensorflow/utils.py5
5 files changed, 39 insertions, 37 deletions
diff --git a/src/mlia/nn/tensorflow/config.py b/src/mlia/nn/tensorflow/config.py
index 44fbaef..c6fae1c 100644
--- a/src/mlia/nn/tensorflow/config.py
+++ b/src/mlia/nn/tensorflow/config.py
@@ -15,6 +15,7 @@ from typing import List
import numpy as np
import tensorflow as tf
+from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107
from mlia.core.context import Context
from mlia.nn.tensorflow.optimizations.quantization import dequantize
@@ -30,6 +31,7 @@ from mlia.nn.tensorflow.utils import is_saved_model
from mlia.nn.tensorflow.utils import is_tflite_model
from mlia.utils.logging import log_action
+
logger = logging.getLogger(__name__)
@@ -57,10 +59,10 @@ class KerasModel(ModelConfiguration):
Supports all models supported by Keras API: saved model, H5, HDF5
"""
- def get_keras_model(self) -> tf.keras.Model:
+ def get_keras_model(self) -> keras.Model:
"""Return associated Keras model."""
try:
- keras_model = tf.keras.models.load_model(self.model_path)
+ keras_model = keras.models.load_model(self.model_path)
except OSError as err:
raise RuntimeError(
f"Unable to load model content in {self.model_path}. "
diff --git a/src/mlia/nn/tensorflow/optimizations/clustering.py b/src/mlia/nn/tensorflow/optimizations/clustering.py
index f9018b3..8e7c4a2 100644
--- a/src/mlia/nn/tensorflow/optimizations/clustering.py
+++ b/src/mlia/nn/tensorflow/optimizations/clustering.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2024, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""
Contains class Clusterer that clusters unique weights per layer to a specified number.
@@ -12,8 +12,8 @@ from __future__ import annotations
from dataclasses import dataclass
from typing import Any
-import tensorflow as tf
import tensorflow_model_optimization as tfmot
+from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107
from tensorflow_model_optimization.python.core.clustering.keras.experimental import ( # pylint: disable=no-name-in-module
cluster as experimental_cluster,
)
@@ -50,7 +50,7 @@ class Clusterer(Optimizer):
"""
def __init__(
- self, model: tf.keras.Model, optimizer_configuration: ClusteringConfiguration
+ self, model: keras.Model, optimizer_configuration: ClusteringConfiguration
):
"""Init Clusterer instance."""
self.model = model
@@ -69,8 +69,8 @@ class Clusterer(Optimizer):
}
def _apply_clustering_to_layer(
- self, layer: tf.keras.layers.Layer
- ) -> tf.keras.layers.Layer:
+ self, layer: keras.layers.Layer
+ ) -> keras.layers.Layer:
layers_to_optimize = self.optimizer_configuration.layers_to_optimize
assert layers_to_optimize, "List of the layers to optimize is empty"
@@ -81,7 +81,7 @@ class Clusterer(Optimizer):
return experimental_cluster.cluster_weights(layer, **clustering_params)
def _init_for_clustering(self) -> None:
- # Use `tf.keras.models.clone_model` to apply `apply_clustering_to_layer`
+ # Use `keras.models.clone_model` to apply `apply_clustering_to_layer`
# to the layers of the model
if not self.optimizer_configuration.layers_to_optimize:
clustering_params = self._setup_clustering_params()
@@ -89,7 +89,7 @@ class Clusterer(Optimizer):
self.model, **clustering_params
)
else:
- clustered_model = tf.keras.models.clone_model(
+ clustered_model = keras.models.clone_model(
self.model, clone_function=self._apply_clustering_to_layer
)
@@ -103,6 +103,6 @@ class Clusterer(Optimizer):
self._init_for_clustering()
self._strip_clustering()
- def get_model(self) -> tf.keras.Model:
+ def get_model(self) -> keras.Model:
"""Get model."""
return self.model
diff --git a/src/mlia/nn/tensorflow/optimizations/pruning.py b/src/mlia/nn/tensorflow/optimizations/pruning.py
index a30b301..866e209 100644
--- a/src/mlia/nn/tensorflow/optimizations/pruning.py
+++ b/src/mlia/nn/tensorflow/optimizations/pruning.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2024, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""
Contains class Pruner to prune a model to a specified sparsity.
@@ -15,8 +15,8 @@ from dataclasses import dataclass
from typing import Any
import numpy as np
-import tensorflow as tf
import tensorflow_model_optimization as tfmot
+from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107
from tensorflow_model_optimization.python.core.sparsity.keras import ( # pylint: disable=no-name-in-module
prune_registry,
)
@@ -27,6 +27,7 @@ from tensorflow_model_optimization.python.core.sparsity.keras import ( # pylint
from mlia.nn.common import Optimizer
from mlia.nn.common import OptimizerConfiguration
+
logger = logging.getLogger(__name__)
@@ -58,7 +59,7 @@ class PrunableLayerPolicy(tfmot.sparsity.keras.PruningPolicy):
are compatible with the pruning API, and that the model supports pruning.
"""
- def allow_pruning(self, layer: tf.keras.layers.Layer) -> Any:
+ def allow_pruning(self, layer: keras.layers.Layer) -> Any:
"""Allow pruning only for layers that are prunable.
Checks the PruneRegistry in TensorFlow Model Optimization Toolkit.
@@ -71,13 +72,13 @@ class PrunableLayerPolicy(tfmot.sparsity.keras.PruningPolicy):
return layer_is_supported
- def ensure_model_supports_pruning(self, model: tf.keras.Model) -> None:
+ def ensure_model_supports_pruning(self, model: keras.Model) -> None:
"""Ensure that the model contains only supported layers."""
# Check whether the model is a Keras model.
- if not isinstance(model, tf.keras.Model):
+ if not isinstance(model, keras.Model):
raise ValueError(
"Models that are not part of the \
- tf.keras.Model base class \
+ keras.Model base class \
are not supported currently."
)
@@ -99,7 +100,7 @@ class Pruner(Optimizer):
"""
def __init__(
- self, model: tf.keras.Model, optimizer_configuration: PruningConfiguration
+ self, model: keras.Model, optimizer_configuration: PruningConfiguration
):
"""Init Pruner instance."""
self.model = model
@@ -132,9 +133,7 @@ class Pruner(Optimizer):
),
}
- def _apply_pruning_to_layer(
- self, layer: tf.keras.layers.Layer
- ) -> tf.keras.layers.Layer:
+ def _apply_pruning_to_layer(self, layer: keras.layers.Layer) -> keras.layers.Layer:
layers_to_optimize = self.optimizer_configuration.layers_to_optimize
assert layers_to_optimize, "List of the layers to optimize is empty"
@@ -145,7 +144,7 @@ class Pruner(Optimizer):
return tfmot.sparsity.keras.prune_low_magnitude(layer, **pruning_params)
def _init_for_pruning(self) -> None:
- # Use `tf.keras.models.clone_model` to apply `apply_pruning_to_layer`
+ # Use `keras.models.clone_model` to apply `apply_pruning_to_layer`
# to the layers of the model
if not self.optimizer_configuration.layers_to_optimize:
pruning_params = self._setup_pruning_params()
@@ -153,14 +152,14 @@ class Pruner(Optimizer):
self.model, pruning_policy=PrunableLayerPolicy(), **pruning_params
)
else:
- prunable_model = tf.keras.models.clone_model(
+ prunable_model = keras.models.clone_model(
self.model, clone_function=self._apply_pruning_to_layer
)
self.model = prunable_model
def _train_pruning(self) -> None:
- loss_fn = tf.keras.losses.MeanAbsolutePercentageError()
+ loss_fn = keras.losses.MeanAbsolutePercentageError()
self.model.compile(optimizer="adam", loss=loss_fn, metrics=["accuracy"])
# Model callbacks
@@ -183,8 +182,8 @@ class Pruner(Optimizer):
continue
for weight in layer.layer.get_prunable_weights():
- nonzero_weights = np.count_nonzero(tf.keras.backend.get_value(weight))
- all_weights = tf.keras.backend.get_value(weight).size
+ nonzero_weights = np.count_nonzero(keras.backend.get_value(weight))
+ all_weights = keras.backend.get_value(weight).size
# Types need to be ignored for this function call because
# np.testing.assert_approx_equal does not have type annotation while the
@@ -205,6 +204,6 @@ class Pruner(Optimizer):
self._assert_sparsity_reached()
self._strip_pruning()
- def get_model(self) -> tf.keras.Model:
+ def get_model(self) -> keras.Model:
"""Get model."""
return self.model
diff --git a/src/mlia/nn/tensorflow/tflite_convert.py b/src/mlia/nn/tensorflow/tflite_convert.py
index d3a833a..29839d6 100644
--- a/src/mlia/nn/tensorflow/tflite_convert.py
+++ b/src/mlia/nn/tensorflow/tflite_convert.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2024, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Support module to call TFLiteConverter."""
from __future__ import annotations
@@ -14,6 +14,7 @@ from typing import Iterable
import numpy as np
import tensorflow as tf
+from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107
from mlia.nn.tensorflow.utils import get_tf_tensor_shape
from mlia.nn.tensorflow.utils import is_keras_model
@@ -23,6 +24,7 @@ from mlia.utils.logging import redirect_output
from mlia.utils.proc import Command
from mlia.utils.proc import command_output
+
logger = logging.getLogger(__name__)
@@ -40,21 +42,21 @@ def representative_dataset(
def get_tflite_converter(
- model: tf.keras.Model | str | Path, quantized: bool = False
+ model: keras.Model | str | Path, quantized: bool = False
) -> tf.lite.TFLiteConverter:
"""Configure TensorFlow Lite converter for the provided model."""
if isinstance(model, (str, Path)):
# converter's methods accept string as input parameter
model = str(model)
- if isinstance(model, tf.keras.Model):
+ if isinstance(model, keras.Model):
converter = tf.lite.TFLiteConverter.from_keras_model(model)
input_shape = model.input_shape
elif isinstance(model, str) and is_saved_model(model):
converter = tf.lite.TFLiteConverter.from_saved_model(model)
input_shape = get_tf_tensor_shape(model)
elif isinstance(model, str) and is_keras_model(model):
- keras_model = tf.keras.models.load_model(model)
+ keras_model = keras.models.load_model(model)
input_shape = keras_model.input_shape
converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)
else:
@@ -70,9 +72,7 @@ def get_tflite_converter(
return converter
-def convert_to_tflite_bytes(
- model: tf.keras.Model | str, quantized: bool = False
-) -> bytes:
+def convert_to_tflite_bytes(model: keras.Model | str, quantized: bool = False) -> bytes:
"""Convert Keras model to TensorFlow Lite."""
converter = get_tflite_converter(model, quantized)
@@ -83,7 +83,7 @@ def convert_to_tflite_bytes(
def _convert_to_tflite(
- model: tf.keras.Model | str,
+ model: keras.Model | str,
quantized: bool = False,
output_path: Path | None = None,
) -> bytes:
@@ -97,7 +97,7 @@ def _convert_to_tflite(
def convert_to_tflite(
- model: tf.keras.Model | str,
+ model: keras.Model | str,
quantized: bool = False,
output_path: Path | None = None,
input_path: Path | None = None,
diff --git a/src/mlia/nn/tensorflow/utils.py b/src/mlia/nn/tensorflow/utils.py
index 1612447..3ac5064 100644
--- a/src/mlia/nn/tensorflow/utils.py
+++ b/src/mlia/nn/tensorflow/utils.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2024, Arm Limited and/or its affiliates.
# SPDX-FileCopyrightText: Copyright The TensorFlow Authors. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
"""Collection of useful functions for optimizations."""
@@ -8,6 +8,7 @@ from pathlib import Path
from typing import Any
import tensorflow as tf
+from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107
def get_tf_tensor_shape(model: str) -> list:
@@ -30,7 +31,7 @@ def get_tf_tensor_shape(model: str) -> list:
def save_keras_model(
- model: tf.keras.Model, save_path: str | Path, include_optimizer: bool = True
+ model: keras.Model, save_path: str | Path, include_optimizer: bool = True
) -> None:
"""Save Keras model at provided path."""
model.save(save_path, include_optimizer=include_optimizer)