aboutsummaryrefslogtreecommitdiff
path: root/src/mlia/nn/select.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/mlia/nn/select.py')
-rw-r--r--src/mlia/nn/select.py8
1 files changed, 4 insertions, 4 deletions
diff --git a/src/mlia/nn/select.py b/src/mlia/nn/select.py
index 20950cc..81a614f 100644
--- a/src/mlia/nn/select.py
+++ b/src/mlia/nn/select.py
@@ -10,7 +10,7 @@ from typing import cast
from typing import List
from typing import NamedTuple
-import tensorflow as tf
+from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107
from mlia.core.errors import ConfigurationError
from mlia.nn.common import Optimizer
@@ -91,7 +91,7 @@ class MultiStageOptimizer(Optimizer):
def __init__(
self,
- model: tf.keras.Model,
+ model: keras.Model,
optimizations: list[OptimizerConfiguration],
) -> None:
"""Init MultiStageOptimizer instance."""
@@ -115,7 +115,7 @@ class MultiStageOptimizer(Optimizer):
def get_optimizer(
- model: tf.keras.Model | KerasModel | TFLiteModel,
+ model: keras.Model | KerasModel | TFLiteModel,
config: OptimizerConfiguration | OptimizationSettings | list[OptimizationSettings],
training_parameters: list[dict | None] | None = None,
) -> Optimizer:
@@ -149,7 +149,7 @@ def get_optimizer(
def _get_optimizer(
- model: tf.keras.Model | Path,
+ model: keras.Model | Path,
optimization_settings: OptimizationSettings | list[OptimizationSettings],
training_parameters: list[dict | None] | None = None,
) -> Optimizer: