diff options
Diffstat (limited to 'src/mlia/nn/select.py')
-rw-r--r-- | src/mlia/nn/select.py | 8 |
1 files changed, 4 insertions, 4 deletions
diff --git a/src/mlia/nn/select.py b/src/mlia/nn/select.py index 20950cc..81a614f 100644 --- a/src/mlia/nn/select.py +++ b/src/mlia/nn/select.py @@ -10,7 +10,7 @@ from typing import cast from typing import List from typing import NamedTuple -import tensorflow as tf +from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107 from mlia.core.errors import ConfigurationError from mlia.nn.common import Optimizer @@ -91,7 +91,7 @@ class MultiStageOptimizer(Optimizer): def __init__( self, - model: tf.keras.Model, + model: keras.Model, optimizations: list[OptimizerConfiguration], ) -> None: """Init MultiStageOptimizer instance.""" @@ -115,7 +115,7 @@ class MultiStageOptimizer(Optimizer): def get_optimizer( - model: tf.keras.Model | KerasModel | TFLiteModel, + model: keras.Model | KerasModel | TFLiteModel, config: OptimizerConfiguration | OptimizationSettings | list[OptimizationSettings], training_parameters: list[dict | None] | None = None, ) -> Optimizer: @@ -149,7 +149,7 @@ def get_optimizer( def _get_optimizer( - model: tf.keras.Model | Path, + model: keras.Model | Path, optimization_settings: OptimizationSettings | list[OptimizationSettings], training_parameters: list[dict | None] | None = None, ) -> Optimizer: |