diff options
Diffstat (limited to 'tests/test_nn_tensorflow_tflite_metrics.py')
-rw-r--r-- | tests/test_nn_tensorflow_tflite_metrics.py | 21 |
1 files changed, 11 insertions, 10 deletions
diff --git a/tests/test_nn_tensorflow_tflite_metrics.py b/tests/test_nn_tensorflow_tflite_metrics.py index e8d7c09..cbb1b63 100644 --- a/tests/test_nn_tensorflow_tflite_metrics.py +++ b/tests/test_nn_tensorflow_tflite_metrics.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2024, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Test for module utils/tflite_metrics.""" from __future__ import annotations @@ -12,26 +12,27 @@ from typing import Generator import numpy as np import pytest import tensorflow as tf +from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107 from mlia.nn.tensorflow.tflite_metrics import ReportClusterMode from mlia.nn.tensorflow.tflite_metrics import TFLiteMetrics -def _sample_keras_model() -> tf.keras.Model: +def _sample_keras_model() -> keras.Model: # Create a sample model - keras_model = tf.keras.Sequential( + keras_model = keras.Sequential( [ - tf.keras.Input(shape=(8, 8, 3)), - tf.keras.layers.Conv2D(4, 3), - tf.keras.layers.DepthwiseConv2D(3), - tf.keras.layers.Flatten(), - tf.keras.layers.Dense(8), + keras.Input(shape=(8, 8, 3)), + keras.layers.Conv2D(4, 3), + keras.layers.DepthwiseConv2D(3), + keras.layers.Flatten(), + keras.layers.Dense(8), ] ) return keras_model -def _sparse_binary_keras_model() -> tf.keras.Model: +def _sparse_binary_keras_model() -> keras.Model: def get_sparse_weights(shape: list[int]) -> np.ndarray: weights = np.zeros(shape) with np.nditer(weights, op_flags=[["writeonly"]]) as weight_it: @@ -43,7 +44,7 @@ def _sparse_binary_keras_model() -> tf.keras.Model: keras_model = _sample_keras_model() # Assign weights to have 0.5 sparsity for layer in keras_model.layers: - if not isinstance(layer, tf.keras.layers.Flatten): + if not isinstance(layer, keras.layers.Flatten): weight = layer.weights[0] weight.assign(get_sparse_weights(weight.shape)) print(layer) |