aboutsummaryrefslogtreecommitdiff
path: root/src/mlia/nn/rewrite/core/rewrite.py
diff options
context:
space:
mode:
authorNathan Bailey <nathan.bailey@arm.com>2024-03-20 08:13:39 +0000
committerNathan Bailey <nathan.bailey@arm.com>2024-03-28 07:17:32 +0000
commitf3f3ab451968350b8f6df2de7c60b2c2b9320b59 (patch)
tree05d56c8e41de9b32f8054019a21b78628151310d /src/mlia/nn/rewrite/core/rewrite.py
parent5f063ae1cfbfa2568d2858af0a0ccaf192bb1e8d (diff)
downloadmlia-f3f3ab451968350b8f6df2de7c60b2c2b9320b59.tar.gz
feat: Update Vela version
Updates Vela Version to 3.11.0 and TensorFlow version to 2.15.1 Required keras import to change: from keras.api._v2 import keras needed instead of calling tf.keras Subsequently tf.keras.X needed to change to keras.X Resolves: MLIA-1107 Signed-off-by: Nathan Bailey <nathan.bailey@arm.com> Change-Id: I53bcaa9cdad58b0e6c311c8c6490393d33cb18bc
Diffstat (limited to 'src/mlia/nn/rewrite/core/rewrite.py')
-rw-r--r--src/mlia/nn/rewrite/core/rewrite.py9
1 files changed, 5 insertions, 4 deletions
diff --git a/src/mlia/nn/rewrite/core/rewrite.py b/src/mlia/nn/rewrite/core/rewrite.py
index 8658991..c7d13ba 100644
--- a/src/mlia/nn/rewrite/core/rewrite.py
+++ b/src/mlia/nn/rewrite/core/rewrite.py
@@ -12,7 +12,7 @@ from typing import Any
from typing import Callable
from typing import cast
-import tensorflow as tf
+from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107
from mlia.core.errors import ConfigurationError
from mlia.core.reporting import Column
@@ -25,8 +25,9 @@ from mlia.nn.rewrite.core.train import TrainingParameters
from mlia.nn.tensorflow.config import TFLiteModel
from mlia.utils.registry import Registry
+
logger = logging.getLogger(__name__)
-RewriteCallable = Callable[[Any, Any], tf.keras.Model]
+RewriteCallable = Callable[[Any, Any], keras.Model]
class Rewrite:
@@ -37,7 +38,7 @@ class Rewrite:
self.name = name
self.function = rewrite_fn
- def __call__(self, input_shape: Any, output_shape: Any) -> tf.keras.Model:
+ def __call__(self, input_shape: Any, output_shape: Any) -> keras.Model:
"""Perform the rewrite operation using the configured function."""
try:
return self.function(input_shape, output_shape)
@@ -52,7 +53,7 @@ class DynamicallyLoadedRewrite(Rewrite):
def __init__(self, name: str, function_name: str):
"""Initialize."""
- def load_and_run(input_shape: Any, output_shape: Any) -> tf.keras.Model:
+ def load_and_run(input_shape: Any, output_shape: Any) -> keras.Model:
"""Load the function from a file dynamically."""
self.load_function(function_name)
return self.function(input_shape, output_shape)