From f3f3ab451968350b8f6df2de7c60b2c2b9320b59 Mon Sep 17 00:00:00 2001 From: Nathan Bailey Date: Wed, 20 Mar 2024 08:13:39 +0000 Subject: feat: Update Vela version Updates Vela Version to 3.11.0 and TensorFlow version to 2.15.1 Required keras import to change: from keras.api._v2 import keras needed instead of calling tf.keras Subsequently tf.keras.X needed to change to keras.X Resolves: MLIA-1107 Signed-off-by: Nathan Bailey Change-Id: I53bcaa9cdad58b0e6c311c8c6490393d33cb18bc --- src/mlia/nn/rewrite/core/rewrite.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) (limited to 'src/mlia/nn/rewrite/core/rewrite.py') diff --git a/src/mlia/nn/rewrite/core/rewrite.py b/src/mlia/nn/rewrite/core/rewrite.py index 8658991..c7d13ba 100644 --- a/src/mlia/nn/rewrite/core/rewrite.py +++ b/src/mlia/nn/rewrite/core/rewrite.py @@ -12,7 +12,7 @@ from typing import Any from typing import Callable from typing import cast -import tensorflow as tf +from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107 from mlia.core.errors import ConfigurationError from mlia.core.reporting import Column @@ -25,8 +25,9 @@ from mlia.nn.rewrite.core.train import TrainingParameters from mlia.nn.tensorflow.config import TFLiteModel from mlia.utils.registry import Registry + logger = logging.getLogger(__name__) -RewriteCallable = Callable[[Any, Any], tf.keras.Model] +RewriteCallable = Callable[[Any, Any], keras.Model] class Rewrite: @@ -37,7 +38,7 @@ class Rewrite: self.name = name self.function = rewrite_fn - def __call__(self, input_shape: Any, output_shape: Any) -> tf.keras.Model: + def __call__(self, input_shape: Any, output_shape: Any) -> keras.Model: """Perform the rewrite operation using the configured function.""" try: return self.function(input_shape, output_shape) @@ -52,7 +53,7 @@ class DynamicallyLoadedRewrite(Rewrite): def __init__(self, name: str, function_name: str): """Initialize.""" - def load_and_run(input_shape: Any, output_shape: Any) -> tf.keras.Model: + def load_and_run(input_shape: Any, output_shape: Any) -> keras.Model: """Load the function from a file dynamically.""" self.load_function(function_name) return self.function(input_shape, output_shape) -- cgit v1.2.1