From f3f3ab451968350b8f6df2de7c60b2c2b9320b59 Mon Sep 17 00:00:00 2001 From: Nathan Bailey Date: Wed, 20 Mar 2024 08:13:39 +0000 Subject: feat: Update Vela version Updates Vela Version to 3.11.0 and TensorFlow version to 2.15.1 Required keras import to change: from keras.api._v2 import keras needed instead of calling tf.keras Subsequently tf.keras.X needed to change to keras.X Resolves: MLIA-1107 Signed-off-by: Nathan Bailey Change-Id: I53bcaa9cdad58b0e6c311c8c6490393d33cb18bc --- tests/test_nn_rewrite_core_train.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) (limited to 'tests/test_nn_rewrite_core_train.py') diff --git a/tests/test_nn_rewrite_core_train.py b/tests/test_nn_rewrite_core_train.py index 7fb6f85..6d24133 100644 --- a/tests/test_nn_rewrite_core_train.py +++ b/tests/test_nn_rewrite_core_train.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2023-2024, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Tests for module mlia.nn.rewrite.core.train.""" # pylint: disable=too-many-arguments @@ -12,6 +12,7 @@ from typing import Any import numpy as np import pytest import tensorflow as tf +from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107 from mlia.nn.rewrite.core.train import augment_fn_twins from mlia.nn.rewrite.core.train import AUGMENTATION_PRESETS @@ -24,7 +25,7 @@ from tests.utils.rewrite import MockTrainingParameters def replace_fully_connected_with_conv( input_shape: Any, output_shape: Any -) -> tf.keras.Model: +) -> keras.Model: """Get a replacement model for the fully connected layer.""" for name, shape in { "Input": input_shape, @@ -33,11 +34,11 @@ def replace_fully_connected_with_conv( if len(shape) != 1: raise RuntimeError(f"{name}: shape (N,) expected, but it is {input_shape}.") - model = tf.keras.Sequential(name="RewriteModel") - model.add(tf.keras.Input(input_shape)) - model.add(tf.keras.layers.Reshape((1, 1, input_shape[0]))) - model.add(tf.keras.layers.Conv2D(filters=output_shape[0], kernel_size=(1, 1))) - model.add(tf.keras.layers.Reshape(output_shape)) + model = keras.Sequential(name="RewriteModel") + model.add(keras.Input(input_shape)) + model.add(keras.layers.Reshape((1, 1, input_shape[0]))) + model.add(keras.layers.Conv2D(filters=output_shape[0], kernel_size=(1, 1))) + model.add(keras.layers.Reshape(output_shape)) return model -- cgit v1.2.1