diff options
Diffstat (limited to 'tests/test_nn_rewrite_library_helper_functions.py')
-rw-r--r-- | tests/test_nn_rewrite_library_helper_functions.py | 103 |
1 files changed, 103 insertions, 0 deletions
diff --git a/tests/test_nn_rewrite_library_helper_functions.py b/tests/test_nn_rewrite_library_helper_functions.py new file mode 100644 index 0000000..a0dd7b9 --- /dev/null +++ b/tests/test_nn_rewrite_library_helper_functions.py @@ -0,0 +1,103 @@ +# SPDX-FileCopyrightText: Copyright 2024, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Tests for module mlia.nn.rewrite.library.helper_functions.""" +from __future__ import annotations + +from contextlib import ExitStack as does_not_raise +from typing import Any + +import numpy as np +import pytest +from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107 + +from mlia.nn.rewrite.library.helper_functions import ACTIVATION_FUNCTION_LIST +from mlia.nn.rewrite.library.helper_functions import compute_conv2d_parameters +from mlia.nn.rewrite.library.helper_functions import get_activation_function + + +def compute_conv_output( + input_data: np.ndarray, input_shape: np.ndarray, conv_parameters: dict[str, Any] +) -> np.ndarray: + """Compute the output of a conv layer for testing.""" + test_model = keras.Sequential( + [ + keras.layers.InputLayer(input_shape=input_shape), + keras.layers.Conv2D(**conv_parameters), + ] + ) + output = test_model(input_data) + return np.array(output.shape[1:]) + + +@pytest.mark.parametrize( + "input_shape, output_shape, kernel_size", + [ + (np.array([32, 32, 3]), np.array([16, 16, 3]), [3, 3]), + (np.array([32, 32, 3]), np.array([8, 8, 3]), [3, 3]), + (np.array([32, 32, 3]), np.array([8, 16, 3]), [3, 3]), + (np.array([25, 10, 3]), np.array([13, 5, 3]), [3, 3]), + (np.array([25, 10, 3]), np.array([7, 5, 3]), [3, 3]), + (np.array([25, 10, 3]), np.array([6, 4, 3]), [3, 3]), + (np.array([25, 10, 3]), np.array([5, 5, 3]), [3, 3]), + (np.array([32, 32, 3]), np.array([16, 16, 3]), [1, 3]), + (np.array([32, 32, 3]), np.array([16, 16, 3]), [1, 1]), + (np.array([32, 32, 3]), np.array([16, 16, 3]), [5, 5]), + ], +) +def test_compute_conv2d_parameters( + input_shape: np.ndarray, output_shape: np.ndarray, kernel_size: list[int] +) -> None: + """Test to check compute_conv2d_parameters works as expected.""" + conv_parameters = compute_conv2d_parameters( + input_shape=input_shape, + output_shape=output_shape, + kernel_size_input=kernel_size, + ) + computed_output_shape = compute_conv_output( + np.random.rand(1, *input_shape), input_shape, conv_parameters + ) + assert np.equal(computed_output_shape, output_shape).all() + + +@pytest.mark.parametrize( + "activation, expected_function_type, expected_extra_args, expected_error", + [ + ("relu", keras.layers.ReLU, {}, does_not_raise()), + ("relu6", keras.layers.ReLU, {"max_value": 6}, does_not_raise()), + ("none", None, {}, does_not_raise()), + ( + "wrong_key", + keras.layers.Identity, + {}, + pytest.raises( + KeyError, + match=( + "Expected activation function to be " + rf"in \{ACTIVATION_FUNCTION_LIST}\, found wrong_key" + ), + ), + ), + ], +) +def test_get_activation_functions( + activation: str, + expected_function_type: type, + expected_extra_args: dict, + expected_error: Any, +) -> None: + """ + Check the get_activation_function returns + the expected layer and extra arguments. + """ + with expected_error: + activation_function, activation_function_extra_args = get_activation_function( + activation + ) + if activation_function: + assert isinstance( + activation_function(**activation_function_extra_args), + expected_function_type, + ) + else: + assert activation_function is None + assert expected_extra_args == activation_function_extra_args |