From 918a9fb4aa4be23ca4261c241e9e52acc42f9bb3 Mon Sep 17 00:00:00 2001 From: Gunes Bayir Date: Tue, 15 Feb 2022 11:40:13 +0000 Subject: Add Pool3d reference implementation This patch - adds the reference implementation for the 3D pooling layer - supports FP32/FP16 and INT8/UINT8 types - adds a function to calculate the output shape for 3D pooling - adds a new type for describing pool 3d info (Pool3DInfo) Resolves: COMPMID-4659 Change-Id: I22a18fa30625c98fa827ef1b50781db6893ba9c4 Signed-off-by: Gunes Bayir Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/7219 Reviewed-by: Gian Marco Iodice Comments-Addressed: Arm Jenkins Tested-by: Arm Jenkins --- arm_compute/core/utils/misc/ShapeCalculator.h | 59 ++++++++++++++++++++++++++- 1 file changed, 58 insertions(+), 1 deletion(-) (limited to 'arm_compute/core/utils') diff --git a/arm_compute/core/utils/misc/ShapeCalculator.h b/arm_compute/core/utils/misc/ShapeCalculator.h index 3e8b024f82..ee4fe0c02f 100644 --- a/arm_compute/core/utils/misc/ShapeCalculator.h +++ b/arm_compute/core/utils/misc/ShapeCalculator.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2021 Arm Limited. + * Copyright (c) 2017-2022 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -1460,6 +1460,63 @@ inline TensorShape compute_conv3d_shape(const TensorShape &src, const TensorShap return output_shape; } +/** Calculate the output pool3d shape of a tensor + * + * @param[in] src Input tensor info + * @param[in] pool3d_info Pooling layer info + * + * @return the calculated shape + */ +inline TensorShape compute_pool3d_shape(const TensorShape &src, Pool3DInfo pool3d_info) +{ + TensorShape output_shape{ src }; + + const int idx_width = 1; + const int idx_height = 2; + const int idx_depth = 3; + const int pool_size_width = pool3d_info.is_global_pooling ? src[idx_width] : pool3d_info.pool_size.width; + const int pool_size_height = pool3d_info.is_global_pooling ? src[idx_height] : pool3d_info.pool_size.height; + const int pool_size_depth = pool3d_info.is_global_pooling ? src[idx_depth] : pool3d_info.pool_size.depth; + const int pool_stride_width = pool3d_info.strides.width; + const int pool_stride_height = pool3d_info.strides.height; + const int pool_stride_depth = pool3d_info.strides.depth; + + int output_width_size = 0; + int output_height_size = 0; + int output_depth_size = 0; + + const size_t pad_left = pool3d_info.padding.left; + const size_t pad_right = pool3d_info.padding.right; + const size_t pad_top = pool3d_info.padding.top; + const size_t pad_bottom = pool3d_info.padding.bottom; + const size_t pad_front = pool3d_info.padding.front; + const size_t pad_back = pool3d_info.padding.back; + + switch(pool3d_info.round_type) + { + case DimensionRoundingType::FLOOR: + output_width_size = static_cast(std::floor((static_cast(src[idx_width] + pad_left + pad_right - pool_size_width)) / pool_stride_width) + 1); + output_height_size = static_cast(std::floor((static_cast(src[idx_height] + pad_top + pad_bottom - pool_size_height)) / pool_stride_height) + 1); + output_depth_size = static_cast(std::floor((static_cast(src[idx_depth] + pad_front + pad_back - pool_size_depth)) / pool_stride_depth) + 1); + break; + case DimensionRoundingType::CEIL: + output_width_size = static_cast(std::ceil((static_cast(src[idx_width] + pad_left + pad_right - pool_size_width)) / pool_stride_width) + 1); + output_height_size = static_cast(std::ceil((static_cast(src[idx_height] + pad_top + pad_bottom - pool_size_height)) / pool_stride_height) + 1); + output_depth_size = static_cast(std::ceil((static_cast(src[idx_depth] + pad_front + pad_back - pool_size_depth)) / pool_stride_depth) + 1); + break; + default: + ARM_COMPUTE_ERROR("Unsupported rounding type"); + } + + ARM_COMPUTE_ERROR_ON_MSG((output_width_size < 1 || output_height_size < 1 || output_depth_size < 1), "Calculated output dimension size is invalid"); + + output_shape.set(idx_width, static_cast(output_width_size)); + output_shape.set(idx_height, static_cast(output_height_size)); + output_shape.set(idx_depth, static_cast(output_depth_size)); + + return output_shape; +} + inline TensorShape compute_gather_shape(const TensorShape &input_shape, const TensorShape &indices_shape, uint32_t actual_axis) { ARM_COMPUTE_ERROR_ON(indices_shape.num_dimensions() > 1); -- cgit v1.2.1