aboutsummaryrefslogtreecommitdiff
path: root/tests/test_nn_rewrite_library_helper_functions.py
blob: a0dd7b9a5c8639e1335f2f51b1e457e221118b15 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
# SPDX-FileCopyrightText: Copyright 2024, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for module mlia.nn.rewrite.library.helper_functions."""
from __future__ import annotations

from contextlib import ExitStack as does_not_raise
from typing import Any

import numpy as np
import pytest
from keras.api._v2 import keras  # Temporary workaround for now: MLIA-1107

from mlia.nn.rewrite.library.helper_functions import ACTIVATION_FUNCTION_LIST
from mlia.nn.rewrite.library.helper_functions import compute_conv2d_parameters
from mlia.nn.rewrite.library.helper_functions import get_activation_function


def compute_conv_output(
    input_data: np.ndarray, input_shape: np.ndarray, conv_parameters: dict[str, Any]
) -> np.ndarray:
    """Run a single Conv2D layer on the input and return its output shape.

    A throwaway sequential model is assembled from an input layer of
    ``input_shape`` followed by a ``Conv2D`` layer configured with
    ``conv_parameters``. The model is called on ``input_data`` and the shape
    of the result, minus its leading dimension, is returned as an array.
    """
    layers = [
        keras.layers.InputLayer(input_shape=input_shape),
        keras.layers.Conv2D(**conv_parameters),
    ]
    model = keras.Sequential(layers)
    result = model(input_data)
    # Drop the first axis so the result is comparable with a per-sample shape.
    per_sample_shape = result.shape[1:]
    return np.array(per_sample_shape)


@pytest.mark.parametrize(
    "input_shape, output_shape, kernel_size",
    [
        # Square 32x32 inputs with the 3x3 kernel.
        (np.array([32, 32, 3]), np.array([16, 16, 3]), [3, 3]),
        (np.array([32, 32, 3]), np.array([8, 8, 3]), [3, 3]),
        (np.array([32, 32, 3]), np.array([8, 16, 3]), [3, 3]),
        # Non-square 25x10 inputs with the 3x3 kernel.
        (np.array([25, 10, 3]), np.array([13, 5, 3]), [3, 3]),
        (np.array([25, 10, 3]), np.array([7, 5, 3]), [3, 3]),
        (np.array([25, 10, 3]), np.array([6, 4, 3]), [3, 3]),
        (np.array([25, 10, 3]), np.array([5, 5, 3]), [3, 3]),
        # Same input/output shapes, other kernel sizes.
        (np.array([32, 32, 3]), np.array([16, 16, 3]), [1, 3]),
        (np.array([32, 32, 3]), np.array([16, 16, 3]), [1, 1]),
        (np.array([32, 32, 3]), np.array([16, 16, 3]), [5, 5]),
    ],
)
def test_compute_conv2d_parameters(
    input_shape: np.ndarray, output_shape: np.ndarray, kernel_size: list[int]
) -> None:
    """Check the computed Conv2D parameters yield the requested output shape.

    The parameters returned by compute_conv2d_parameters are fed into a real
    Conv2D layer (via compute_conv_output) and the resulting shape is compared
    element-wise with the requested ``output_shape``.
    """
    params = compute_conv2d_parameters(
        input_shape=input_shape,
        output_shape=output_shape,
        kernel_size_input=kernel_size,
    )

    # Only the shape of the result is checked, so the values are arbitrary;
    # a size-1 leading axis is prepended to the input shape.
    sample = np.random.rand(1, *input_shape)
    actual_shape = compute_conv_output(sample, input_shape, params)

    shapes_match = np.equal(actual_shape, output_shape)
    assert shapes_match.all()


@pytest.mark.parametrize(
    "activation, expected_function_type, expected_extra_args, expected_error",
    [
        ("relu", keras.layers.ReLU, {}, does_not_raise()),
        ("relu6", keras.layers.ReLU, {"max_value": 6}, does_not_raise()),
        ("none", None, {}, does_not_raise()),
        (
            "wrong_key",
            # Type and extra args are placeholders here: the call is expected
            # to raise before any of the assertions below are reached.
            keras.layers.Identity,
            {},
            pytest.raises(
                KeyError,
                # ``match`` is a regular expression. The backslash in front of
                # the interpolated list escapes its first character
                # (presumably the "[" of a list repr - confirm against
                # ACTIVATION_FUNCTION_LIST in helper_functions).
                match=(
                    "Expected activation function to be "
                    rf"in \{ACTIVATION_FUNCTION_LIST}\, found wrong_key"
                ),
            ),
        ),
    ],
)
def test_get_activation_functions(
    activation: str,
    expected_function_type: type,
    expected_extra_args: dict,
    expected_error: Any,
) -> None:
    """Check get_activation_function returns the expected layer and extra args.

    For each activation key the returned callable is instantiated with the
    returned extra arguments and its type is compared with
    ``expected_function_type``; an expected type of ``None`` means no
    activation callable must be returned. ``expected_error`` is a context
    manager that is either a no-op or a ``pytest.raises`` expectation.
    """
    with expected_error:
        activation_function, activation_function_extra_args = get_activation_function(
            activation
        )
        # Branch on the expectation rather than on the returned value:
        # otherwise an unexpected ``None`` for e.g. "relu" would take the
        # "is None" path and the test would pass vacuously.
        if expected_function_type is None:
            assert activation_function is None
        else:
            assert (
                activation_function is not None
            ), f"No activation function returned for '{activation}'"
            assert isinstance(
                activation_function(**activation_function_extra_args),
                expected_function_type,
            )
        assert expected_extra_args == activation_function_extra_args