aboutsummaryrefslogtreecommitdiff
path: root/src/mlia/nn/tensorflow/optimizations/clustering.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/mlia/nn/tensorflow/optimizations/clustering.py')
-rw-r--r--src/mlia/nn/tensorflow/optimizations/clustering.py16
1 files changed, 8 insertions, 8 deletions
diff --git a/src/mlia/nn/tensorflow/optimizations/clustering.py b/src/mlia/nn/tensorflow/optimizations/clustering.py
index f9018b3..8e7c4a2 100644
--- a/src/mlia/nn/tensorflow/optimizations/clustering.py
+++ b/src/mlia/nn/tensorflow/optimizations/clustering.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2024, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""
Contains class Clusterer that clusters unique weights per layer to a specified number.
@@ -12,8 +12,8 @@ from __future__ import annotations
from dataclasses import dataclass
from typing import Any
-import tensorflow as tf
import tensorflow_model_optimization as tfmot
+from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107
from tensorflow_model_optimization.python.core.clustering.keras.experimental import ( # pylint: disable=no-name-in-module
cluster as experimental_cluster,
)
@@ -50,7 +50,7 @@ class Clusterer(Optimizer):
"""
def __init__(
- self, model: tf.keras.Model, optimizer_configuration: ClusteringConfiguration
+ self, model: keras.Model, optimizer_configuration: ClusteringConfiguration
):
"""Init Clusterer instance."""
self.model = model
@@ -69,8 +69,8 @@ class Clusterer(Optimizer):
}
def _apply_clustering_to_layer(
- self, layer: tf.keras.layers.Layer
- ) -> tf.keras.layers.Layer:
+ self, layer: keras.layers.Layer
+ ) -> keras.layers.Layer:
layers_to_optimize = self.optimizer_configuration.layers_to_optimize
assert layers_to_optimize, "List of the layers to optimize is empty"
@@ -81,7 +81,7 @@ class Clusterer(Optimizer):
return experimental_cluster.cluster_weights(layer, **clustering_params)
def _init_for_clustering(self) -> None:
- # Use `tf.keras.models.clone_model` to apply `apply_clustering_to_layer`
+ # Use `keras.models.clone_model` to apply `apply_clustering_to_layer`
# to the layers of the model
if not self.optimizer_configuration.layers_to_optimize:
clustering_params = self._setup_clustering_params()
@@ -89,7 +89,7 @@ class Clusterer(Optimizer):
self.model, **clustering_params
)
else:
- clustered_model = tf.keras.models.clone_model(
+ clustered_model = keras.models.clone_model(
self.model, clone_function=self._apply_clustering_to_layer
)
@@ -103,6 +103,6 @@ class Clusterer(Optimizer):
self._init_for_clustering()
self._strip_clustering()
- def get_model(self) -> tf.keras.Model:
+ def get_model(self) -> keras.Model:
"""Get model."""
return self.model