diff options
Diffstat (limited to 'tests/test_nn_rewrite_core_train.py')
-rw-r--r-- | tests/test_nn_rewrite_core_train.py | 15 |
1 files changed, 8 insertions, 7 deletions
diff --git a/tests/test_nn_rewrite_core_train.py b/tests/test_nn_rewrite_core_train.py index 7fb6f85..6d24133 100644 --- a/tests/test_nn_rewrite_core_train.py +++ b/tests/test_nn_rewrite_core_train.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2023-2024, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Tests for module mlia.nn.rewrite.core.train.""" # pylint: disable=too-many-arguments @@ -12,6 +12,7 @@ from typing import Any import numpy as np import pytest import tensorflow as tf +from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107 from mlia.nn.rewrite.core.train import augment_fn_twins from mlia.nn.rewrite.core.train import AUGMENTATION_PRESETS @@ -24,7 +25,7 @@ from tests.utils.rewrite import MockTrainingParameters def replace_fully_connected_with_conv( input_shape: Any, output_shape: Any -) -> tf.keras.Model: +) -> keras.Model: """Get a replacement model for the fully connected layer.""" for name, shape in { "Input": input_shape, @@ -33,11 +34,11 @@ def replace_fully_connected_with_conv( if len(shape) != 1: raise RuntimeError(f"{name}: shape (N,) expected, but it is {input_shape}.") - model = tf.keras.Sequential(name="RewriteModel") - model.add(tf.keras.Input(input_shape)) - model.add(tf.keras.layers.Reshape((1, 1, input_shape[0]))) - model.add(tf.keras.layers.Conv2D(filters=output_shape[0], kernel_size=(1, 1))) - model.add(tf.keras.layers.Reshape(output_shape)) + model = keras.Sequential(name="RewriteModel") + model.add(keras.Input(input_shape)) + model.add(keras.layers.Reshape((1, 1, input_shape[0]))) + model.add(keras.layers.Conv2D(filters=output_shape[0], kernel_size=(1, 1))) + model.add(keras.layers.Reshape(output_shape)) return model |