diff options
author | Nathan Bailey <nathan.bailey@arm.com> | 2024-05-15 08:12:30 +0100 |
---|---|---|
committer | Nathan Bailey <nathan.bailey@arm.com> | 2024-05-21 16:51:15 +0100 |
commit | 856111bcaef76c60303bdf2ae7cbf718d93d1df4 (patch) | |
tree | d955901817194e48e478f751140bd3c1741d1834 /src/mlia/nn/rewrite/library/helper_functions.py | |
parent | 0d3cc76284f9311c99169b568570d767f5b0aeb6 (diff) | |
download | mlia-856111bcaef76c60303bdf2ae7cbf718d93d1df4.tar.gz |
feat: Implement the conv2D rewrites for int8 and fp32 models
Enable clustering and fully connected rewrites for conv2D layers.
Resolves: MLIA-1159 and MLIA-1160
Signed-off-by: Nathan Bailey <nathan.bailey@arm.com>
Change-Id: I640b8a7e79e455b12fb68d02ac1c33213b8de9c6
Diffstat (limited to 'src/mlia/nn/rewrite/library/helper_functions.py')
-rw-r--r-- | src/mlia/nn/rewrite/library/helper_functions.py | 32 |
1 files changed, 32 insertions, 0 deletions
diff --git a/src/mlia/nn/rewrite/library/helper_functions.py b/src/mlia/nn/rewrite/library/helper_functions.py new file mode 100644 index 0000000..4f08170 --- /dev/null +++ b/src/mlia/nn/rewrite/library/helper_functions.py @@ -0,0 +1,32 @@ +# SPDX-FileCopyrightText: Copyright 2024, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Helper functions for the rewrite library.""" +import math +from typing import Any + +import numpy as np + + +def compute_conv2d_parameters( + input_shape: np.ndarray, output_shape: np.ndarray +) -> dict[str, Any]: + """Compute needed kernel size and strides for a given input and output_shape.""" + input_shape = input_shape.tolist() + output_shape = output_shape.tolist() + assert len(input_shape) == 3 + assert len(output_shape) == 3 + num_filters = (output_shape[-1] - input_shape[-1]) + input_shape[-1] + padding = "valid" + kernel_size = (3, 3) + stride_h = round(input_shape[0] / output_shape[0]) + check_output_size_h = math.floor((input_shape[0] - kernel_size[0]) / stride_h) + 1 + stride_w = round(input_shape[1] / output_shape[1]) + check_output_size_w = math.floor((input_shape[1] - kernel_size[1]) / stride_w) + 1 + if check_output_size_h != output_shape[0] or check_output_size_w != output_shape[1]: + padding = "same" + return { + "filters": num_filters, + "kernel_size": kernel_size, + "padding": padding, + "strides": (stride_h, stride_w), + } |