diff options
Diffstat (limited to 'ethosu/vela/utils.py')
-rw-r--r-- | ethosu/vela/utils.py | 29 |
1 files changed, 29 insertions, 0 deletions
diff --git a/ethosu/vela/utils.py b/ethosu/vela/utils.py index 6a368979..11c253c0 100644 --- a/ethosu/vela/utils.py +++ b/ethosu/vela/utils.py @@ -84,3 +84,32 @@ def progress_print( return print(f"{context_str}{message}") + + +def calc_resize_factor(ifm_width: int, stride_x: int) -> tuple[int, int]: + """Compute resize factor for strided Conv2D optimization.""" + # Define strides that are supported by HW + hw_supported_strides = (2, 3) + resize_factor = stride_x + + if ifm_width % resize_factor != 0: + # In case it is not divisible, check if the resize factor is + # divisible by any of the hw_supported_strides. If it is, re-compute + # the resize factor to be the value that leads us to + # reach a hw supported stride. The IFM width needs to be divisible by the new stride. + # E.g.: IFM width = 133, stride = 14, filter width = 7 can be + # optimised to IFM width = 19, stride = 2, filter width = 7 using + # a resize factor of 7. The final stride is 2 which is + # supported by the hardware. + + # Filter strides that can be obtained from current stride + divisible_strides = (x for x in hw_supported_strides if resize_factor % x == 0) + # Remove strides that are not IFM width divisors + divisor_strides = (x for x in divisible_strides if ifm_width % (stride_x // x) == 0) + # Compute new resize factor based on chosen stride + new_resize_factor = resize_factor // next(divisor_strides, 1) + resize_factor = new_resize_factor if resize_factor != new_resize_factor else 1 + + optimised_stride = stride_x // resize_factor + + return resize_factor, optimised_stride |