aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/utils.py
blob: ee501f3f00b5c9b3b309717c910acf31cc3c0d14 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
# SPDX-FileCopyrightText: Copyright 2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Description:
# Contains various utility functions used across the codebase.
from __future__ import annotations

import collections
import inspect


def progress_print(
    enabled: bool,
    message: str,
    progress_counter: int = -1,
    progress_total: int | collections.abc.Sized = 0,
    progress_granularity: float = 0.20,
):
    """Print a message, optionally followed by "counter/total" progress information.

    Every printed line is prefixed with the name of the function that called ``progress_print``. When a message
    is given, the prefix and the message are separated by ": ".

    Progress information is only used when ``progress_counter`` is greater than -1, the total is greater than 0
    and ``progress_granularity`` lies strictly between 0 and 1. In that mode a line is printed only every
    ``int(total * progress_granularity)`` steps (at least every step) and when the counter reaches the displayed
    total; all other calls print nothing. When any of those conditions is not met, the plain message is printed on
    every call.

    :param enabled: boolean indicating whether message should be printed.
    :param message: message to be printed
    :param progress_counter: the value of the incremental counter that indicates the progress
    :param progress_total: integer value or sized data structure to use to extract the total number of elements that
                           progress is measured against. For a sized data structure the displayed total is its
                           length minus one, i.e. the last index of a zero-based counter
    :param progress_granularity: floating point percentage indicating how often progress information should be printed

    Example
    -------
    def example_function(verbose_progress: bool = True):
        a_list = [x for x in range(101)]
        for index, value in enumerate(a_list):
            progress_print(verbose_progress,
                           message="Processing",
                           progress_counter=index,
                           progress_total=a_list,
                           progress_granularity=0.25)

    **Output**
    example_function: Processing 0/100
    example_function: Processing 25/100
    example_function: Processing 50/100
    example_function: Processing 75/100
    example_function: Processing 100/100
    """
    if not enabled:
        return

    display_total = progress_total
    # If a sized collection is provided, extract its size to use as progress total
    if isinstance(progress_total, collections.abc.Sized):
        progress_total = len(progress_total)
        display_total = progress_total - 1

    text = message
    # Print progress information with "counter/total" information
    if progress_counter > -1 and progress_total > 0 and 0 < progress_granularity < 1:
        # Extract progress frequency and ensure it is not equal to 0 (avoid zero division)
        progress_frequency = int(progress_total * progress_granularity) or 1
        # Regular interval hit; the upper bound keeps an interval hit from landing too close to the final print
        on_interval = (
            progress_counter % progress_frequency == 0 and progress_counter <= progress_total - progress_frequency
        )
        if not on_interval and progress_counter != display_total:
            # Nothing to print for this step, so skip the caller lookup below entirely
            return
        text = f"{message} {progress_counter}/{display_total}"

    # Get calling function name. Only the direct caller's frame is needed, so read it from the current frame
    # instead of materialising the whole stack with inspect.stack(). The current frame is not bound to a local
    # name to avoid creating a frame reference cycle.
    caller_frame = getattr(inspect.currentframe(), "f_back", None)
    context_str = caller_frame.f_code.co_name if caller_frame is not None else ""
    if message:
        context_str += ": "

    print(f"{context_str}{text}")


def calc_resize_factor(ifm_width: int, stride_x: int) -> tuple[int, int]:
    """Compute resize factor for strided Conv2D optimization.

    :param ifm_width: width of the IFM
    :param stride_x: stride in the x direction
    :return: tuple of (resize factor, optimised stride), where the optimised stride is ``stride_x`` integer-divided
             by the resize factor
    """
    # When the IFM width is divisible by the stride, the stride itself is the resize factor
    factor = stride_x

    if ifm_width % stride_x:
        # Otherwise fall back to "no resize" unless the stride can be reduced to a stride supported by HW using
        # a resize factor that divides the IFM width.
        # E.g.: IFM width = 133, stride = 14, filter width = 7 can be
        #       optimised to IFM width = 19, stride = 2, filter width = 7 using
        #       a resize factor of 7. The final stride is 2 which is
        #       supported by the hardware.
        factor = 1
        # Strides supported by HW, tried in order; the first one that fits wins
        for hw_stride in (2, 3):
            candidate, remainder = divmod(stride_x, hw_stride)
            # The HW stride must divide the original stride, and the resulting factor must divide the IFM width
            if remainder == 0 and ifm_width % candidate == 0:
                factor = candidate
                break

    return factor, stride_x // factor