aboutsummaryrefslogtreecommitdiff
path: root/src/mlia/nn/rewrite/core/rewrite.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/mlia/nn/rewrite/core/rewrite.py')
-rw-r--r--src/mlia/nn/rewrite/core/rewrite.py9
1 files changed, 5 insertions, 4 deletions
diff --git a/src/mlia/nn/rewrite/core/rewrite.py b/src/mlia/nn/rewrite/core/rewrite.py
index 8658991..c7d13ba 100644
--- a/src/mlia/nn/rewrite/core/rewrite.py
+++ b/src/mlia/nn/rewrite/core/rewrite.py
@@ -12,7 +12,7 @@ from typing import Any
from typing import Callable
from typing import cast
-import tensorflow as tf
+from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107
from mlia.core.errors import ConfigurationError
from mlia.core.reporting import Column
@@ -25,8 +25,9 @@ from mlia.nn.rewrite.core.train import TrainingParameters
from mlia.nn.tensorflow.config import TFLiteModel
from mlia.utils.registry import Registry
+
logger = logging.getLogger(__name__)
-RewriteCallable = Callable[[Any, Any], tf.keras.Model]
+RewriteCallable = Callable[[Any, Any], keras.Model]
class Rewrite:
@@ -37,7 +38,7 @@ class Rewrite:
self.name = name
self.function = rewrite_fn
- def __call__(self, input_shape: Any, output_shape: Any) -> tf.keras.Model:
+ def __call__(self, input_shape: Any, output_shape: Any) -> keras.Model:
"""Perform the rewrite operation using the configured function."""
try:
return self.function(input_shape, output_shape)
@@ -52,7 +53,7 @@ class DynamicallyLoadedRewrite(Rewrite):
def __init__(self, name: str, function_name: str):
"""Initialize."""
- def load_and_run(input_shape: Any, output_shape: Any) -> tf.keras.Model:
+ def load_and_run(input_shape: Any, output_shape: Any) -> keras.Model:
"""Load the function from a file dynamically."""
self.load_function(function_name)
return self.function(input_shape, output_shape)