diff options
author | Nathan Bailey <nathan.bailey@arm.com> | 2024-03-20 08:13:39 +0000 |
---|---|---|
committer | Nathan Bailey <nathan.bailey@arm.com> | 2024-03-28 07:17:32 +0000 |
commit | f3f3ab451968350b8f6df2de7c60b2c2b9320b59 (patch) | |
tree | 05d56c8e41de9b32f8054019a21b78628151310d /tests/test_nn_rewrite_core_train.py | |
parent | 5f063ae1cfbfa2568d2858af0a0ccaf192bb1e8d (diff) | |
download | mlia-f3f3ab451968350b8f6df2de7c60b2c2b9320b59.tar.gz |
feat: Update Vela version
Updates Vela Version to 3.11.0 and TensorFlow version to 2.15.1
Required keras import to change:
from keras.api._v2 import keras needed instead of calling tf.keras
Subsequently tf.keras.X needed to change to keras.X
Resolves: MLIA-1107
Signed-off-by: Nathan Bailey <nathan.bailey@arm.com>
Change-Id: I53bcaa9cdad58b0e6c311c8c6490393d33cb18bc
Diffstat (limited to 'tests/test_nn_rewrite_core_train.py')
-rw-r--r-- | tests/test_nn_rewrite_core_train.py | 15 |
1 files changed, 8 insertions, 7 deletions
diff --git a/tests/test_nn_rewrite_core_train.py b/tests/test_nn_rewrite_core_train.py index 7fb6f85..6d24133 100644 --- a/tests/test_nn_rewrite_core_train.py +++ b/tests/test_nn_rewrite_core_train.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2023-2024, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Tests for module mlia.nn.rewrite.core.train.""" # pylint: disable=too-many-arguments @@ -12,6 +12,7 @@ from typing import Any import numpy as np import pytest import tensorflow as tf +from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107 from mlia.nn.rewrite.core.train import augment_fn_twins from mlia.nn.rewrite.core.train import AUGMENTATION_PRESETS @@ -24,7 +25,7 @@ from tests.utils.rewrite import MockTrainingParameters def replace_fully_connected_with_conv( input_shape: Any, output_shape: Any -) -> tf.keras.Model: +) -> keras.Model: """Get a replacement model for the fully connected layer.""" for name, shape in { "Input": input_shape, @@ -33,11 +34,11 @@ def replace_fully_connected_with_conv( if len(shape) != 1: raise RuntimeError(f"{name}: shape (N,) expected, but it is {input_shape}.") - model = tf.keras.Sequential(name="RewriteModel") - model.add(tf.keras.Input(input_shape)) - model.add(tf.keras.layers.Reshape((1, 1, input_shape[0]))) - model.add(tf.keras.layers.Conv2D(filters=output_shape[0], kernel_size=(1, 1))) - model.add(tf.keras.layers.Reshape(output_shape)) + model = keras.Sequential(name="RewriteModel") + model.add(keras.Input(input_shape)) + model.add(keras.layers.Reshape((1, 1, input_shape[0]))) + model.add(keras.layers.Conv2D(filters=output_shape[0], kernel_size=(1, 1))) + model.add(keras.layers.Reshape(output_shape)) return model |