From ec0113dd7749991959ae351934eea0c0d8077dcb Mon Sep 17 00:00:00 2001
From: Gunes Bayir
Date: Wed, 9 Nov 2022 09:26:27 +0000
Subject: Optimize Transposed Convolution for CL backend (FP32/16)

This patch optimizes transposed convolution for the CL backend by
rewriting it in a single kernel instead of three (flip_kernel +
upsample + conv). The new kernel skips the upsampling step, which
shrinks the input space of the convolution by a factor of
stride_x * stride_y and results in a significant performance
improvement. It also skips the kernel flipping by traversing the
weights accordingly, thus reducing the memory footprint.

Resolves: COMPMID-5676
Signed-off-by: Gunes Bayir
Change-Id: I8a333212dc7c5f7f0597aa58b0d56d44814baa14
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/8588
Tested-by: Arm Jenkins
Reviewed-by: Gian Marco Iodice
Comments-Addressed: Arm Jenkins
Benchmark: Arm Jenkins
---
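Note: the essence of the single-kernel approach is that the transposed
convolution is evaluated directly in the output space. With deconvolution
padding p, kernel size k and stride s, the kernel is built with the
transformed padding p' = k - p - 1; an output column xo is mapped to the
virtual upsampled coordinate xu = xo - p', the first contributing input
element is xi = ceil(xu / s), and the matching (flipped) weight index is
k - 1 - (xi * s - xu), after which the weights are stepped by -s while the
input is stepped by +1. A minimal scalar 1D model of this index mapping
(illustrative, standalone code; not part of the patch or the library API):

    #include <cmath>
    #include <cstdio>
    #include <vector>

    // One output value of a 1D transposed convolution, computed the way the
    // new kernel does it: no upsampled intermediate, no flipped weight copy.
    float transposed_conv_1d_pixel(const std::vector<float> &src, const std::vector<float> &wei,
                                   int xo, int stride_x, int pad_left)
    {
        const int wei_w            = static_cast<int>(wei.size());
        const int src_w            = static_cast<int>(src.size());
        const int pad_left_flipped = wei_w - pad_left - 1;  // p' = k - p - 1
        const int xu               = xo - pad_left_flipped; // position in the virtual upsampled input
        const int xi0              = static_cast<int>(std::ceil(xu / (float)stride_x)); // first real input element
        const int x_start          = wei_w - (xi0 * stride_x - xu) - 1;                 // matching flipped weight index

        float acc = 0.f;
        for(int xk = x_start, xi = xi0; xk >= 0; xk -= stride_x, ++xi)
        {
            if(xi >= 0 && xi < src_w)
            {
                acc += src[xi] * wei[xk]; // weights are read in flipped order on the fly
            }
        }
        return acc;
    }

    int main()
    {
        // 3 inputs, kernel size 3, stride 2, pad 1 -> (3 - 1) * 2 - 2 * 1 + 3 = 5 outputs
        const std::vector<float> src{ 1.f, 2.f, 3.f };
        const std::vector<float> wei{ 0.5f, 1.f, 0.25f };
        for(int xo = 0; xo < 5; ++xo)
        {
            std::printf("dst[%d] = %.2f\n", xo, transposed_conv_1d_pixel(src, wei, xo, 2, 1));
        }
        return 0;
    }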
 Android.bp                                         |   3 +
 SConscript                                         |   1 +
 arm_compute/core/Types.h                           |   5 +-
 .../runtime/CL/functions/CLDeconvolutionLayer.h    |   7 +-
 filelist.json                                      |   2 +
 .../CL/cl_kernels/nhwc/transposed_convolution.cl   | 253 +++++++++++++++++++++
 src/gpu/cl/ClKernelLibrary.cpp                     |   5 +
 .../cl/kernels/ClTransposedConvolutionKernel.cpp   | 231 +++++++++++++++++++
 src/gpu/cl/kernels/ClTransposedConvolutionKernel.h |  66 ++++++
 src/gpu/cl/operators/ClTransposedConvolution.cpp   |  58 +++++
 src/gpu/cl/operators/ClTransposedConvolution.h     |  90 ++++++++
 src/runtime/CL/functions/CLDeconvolutionLayer.cpp  |  68 +++++-
 tests/datasets/ShapeDatasets.h                     |  16 +-
 tests/validation/CL/DeconvolutionLayer.cpp         |  52 ++++-
 14 files changed, 844 insertions(+), 13 deletions(-)
 create mode 100644 src/core/CL/cl_kernels/nhwc/transposed_convolution.cl
 create mode 100644 src/gpu/cl/kernels/ClTransposedConvolutionKernel.cpp
 create mode 100644 src/gpu/cl/kernels/ClTransposedConvolutionKernel.h
 create mode 100644 src/gpu/cl/operators/ClTransposedConvolution.cpp
 create mode 100644 src/gpu/cl/operators/ClTransposedConvolution.h

diff --git a/Android.bp b/Android.bp
index d02d436fa0..89a7a43060 100644
--- a/Android.bp
+++ b/Android.bp
@@ -119,6 +119,7 @@ opencl_srcs = [
     "src/core/CL/cl_kernels/nhwc/scale.cl",
     "src/core/CL/cl_kernels/nhwc/space_to_batch.cl",
     "src/core/CL/cl_kernels/nhwc/space_to_depth.cl",
+    "src/core/CL/cl_kernels/nhwc/transposed_convolution.cl",
     "src/core/CL/cl_kernels/nhwc/upsample_layer.cl",
     "src/core/CL/cl_kernels/nhwc/winograd_filter_transform.cl",
     "src/core/CL/cl_kernels/nhwc/winograd_input_transform.cl",
@@ -656,6 +657,7 @@ cc_library_static {
        "src/gpu/cl/kernels/ClScaleKernel.cpp",
        "src/gpu/cl/kernels/ClSoftmaxKernel.cpp",
        "src/gpu/cl/kernels/ClTransposeKernel.cpp",
+       "src/gpu/cl/kernels/ClTransposedConvolutionKernel.cpp",
        "src/gpu/cl/kernels/ClWeightsReshapeKernel.cpp",
        "src/gpu/cl/kernels/ClWidthConcatenate2TensorsKernel.cpp",
        "src/gpu/cl/kernels/ClWidthConcatenate4TensorsKernel.cpp",
@@ -707,6 +709,7 @@ cc_library_static {
        "src/gpu/cl/operators/ClSoftmax.cpp",
        "src/gpu/cl/operators/ClSub.cpp",
        "src/gpu/cl/operators/ClTranspose.cpp",
+       "src/gpu/cl/operators/ClTransposedConvolution.cpp",
        "src/gpu/cl/operators/ClWinogradConv2d.cpp",
        "src/gpu/cl/operators/experimental/dynamic_fusion/ClCompositeOperator.cpp",
        "src/runtime/Allocator.cpp",
diff --git a/SConscript b/SConscript
index c05c1fe484..42a03f0a04 100644
--- a/SConscript
+++ b/SConscript
@@ -460,6 +460,7 @@ if env['opencl'] and env['embed_kernels']:
                   'src/core/CL/cl_kernels/nhwc/scale.cl',
                   'src/core/CL/cl_kernels/nhwc/space_to_batch.cl',
                   'src/core/CL/cl_kernels/nhwc/space_to_depth.cl',
+                  'src/core/CL/cl_kernels/nhwc/transposed_convolution.cl',
                   'src/core/CL/cl_kernels/nhwc/upsample_layer.cl',
                   'src/core/CL/cl_kernels/nhwc/winograd_filter_transform.cl',
                   'src/core/CL/cl_kernels/nhwc/winograd_input_transform.cl',
diff --git a/arm_compute/core/Types.h b/arm_compute/core/Types.h
index b0a6475527..d5a4125c88 100644
--- a/arm_compute/core/Types.h
+++ b/arm_compute/core/Types.h
@@ -150,8 +150,9 @@ enum class DepthwiseConvolutionFunction
 /** Available DeconvolutionMethod*/
 enum class DeconvolutionMethod
 {
-    GEMM,   /**< Deconvolution using GEMM */
-    DIRECT, /**< Direct deconvolution */
+    GEMM,          /**< Deconvolution using GEMM */
+    DIRECT,        /**< Direct deconvolution */
+    UPSCALE_CONV2D /**< Deconvolution with Upscaling */
 };
 
 /** Available FuseBatchNormalizationType*/
diff --git a/arm_compute/runtime/CL/functions/CLDeconvolutionLayer.h b/arm_compute/runtime/CL/functions/CLDeconvolutionLayer.h
index 8ad805492d..0c59e2c86d 100644
--- a/arm_compute/runtime/CL/functions/CLDeconvolutionLayer.h
+++ b/arm_compute/runtime/CL/functions/CLDeconvolutionLayer.h
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2017-2021 Arm Limited.
+ * Copyright (c) 2017-2022 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -44,6 +44,8 @@ public:
     /** Default constructor */
     CLDeconvolutionLayer(std::shared_ptr<IMemoryManager> memory_manager = nullptr);
 
+    ~CLDeconvolutionLayer();
+
     /** Set the input, weights, biases and output tensors.
      *
      * Valid data layouts:
@@ -105,6 +107,9 @@ public:
 private:
     std::shared_ptr<IMemoryManager> _memory_manager;
     std::unique_ptr<IFunction>      _function;
+
+    struct Impl;
+    std::unique_ptr<Impl> _impl;
 };
 } // namespace arm_compute
 #endif /* ARM_COMPUTE_CLDECONVOLUTIONLAYER_H */
diff --git a/filelist.json b/filelist.json
index a99dced647..21bc35c644 100644
--- a/filelist.json
+++ b/filelist.json
@@ -337,6 +337,8 @@
       "common": [
         "src/core/CL/kernels/CLDeconvolutionLayerUpsampleKernel.cpp",
         "src/core/CL/kernels/CLDeconvolutionReshapeOutputKernel.cpp",
+        "src/gpu/cl/kernels/ClTransposedConvolutionKernel.cpp",
+        "src/gpu/cl/operators/ClTransposedConvolution.cpp",
         "src/runtime/CL/functions/CLDeconvolutionLayer.cpp",
         "src/runtime/CL/functions/CLDeconvolutionLayerUpsample.cpp",
         "src/runtime/CL/functions/CLDirectDeconvolutionLayer.cpp",
diff --git a/src/core/CL/cl_kernels/nhwc/transposed_convolution.cl b/src/core/CL/cl_kernels/nhwc/transposed_convolution.cl
new file mode 100644
index 0000000000..8872c31229
--- /dev/null
+++ b/src/core/CL/cl_kernels/nhwc/transposed_convolution.cl
@@ -0,0 +1,253 @@
+/*
+ * Copyright (c) 2022 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#include "helpers.h" +#include "tile_helpers.h" + +//! @cond Doxygen_Suppress +/** OpenCL kernel to compute the transposed convolution. + * + * @note Data layout supported: NHWC + * @note Data type supported: F32/F16 + * @note The transposed convolution padding (left and top) must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (e.g. -DPAD_LEFT=2, -DPAD_TOP=2) + * @note The transposed convolution strides must be passed at compile time using -DSTRIDE_X and -DSTRIDE_Y (e.g. -DSTRIDE_X=2, -DSTRIDE_Y=2) + * @note The spatial dimensions of the weights must be passed at compile time using -DWEI_WIDTH and -DWEI_HEIGHT (e.g. -DWEI_WIDTH=9, -DWEI_HEIGHT=9) + * @note The spatial dimensions of the source tensor must be passed at compile time using -DSRC_WIDTH and -DSRC_HEIGHT (e.g. -DSRC_WIDTH=96, -DSRC_HEIGHT=64) + * @note The spatial dimensions of the destination tensor must be passed at compile time using -DDST_WIDTH and -DDST_HEIGHT (e.g. -DDST_WIDTH=96, -DDST_HEIGHT=64) + * @note The channels of the source tensor must be passed at compile time using -DSRC_CHANNELS (e.g. -DSRC_CHANNELS=64) + * @note The channels of the destination tensor must be passed at compile time using -DDST_CHANNELS (e.g. -DDST_CHANNELS=64) + * @note The tensor type (currently only "BUFFER" is supported) of the source tensor must be passed at compile time using -DSRC_TENSOR_TYPE (e.g. -DSRC_TENSOR_TYPE=BUFFER) + * @note The tensor type (currently only "BUFFER" is supported) of the weights tensor must be passed at compile time using -DWEI_TENSOR_TYPE (e.g. -DWEI_TENSOR_TYPE=BUFFER) + * @note The tensor type (currently only "BUFFER" is supported) of the destination tensor must be passed at compile time using -DDST_TENSOR_TYPE (e.g. -DDST_TENSOR_TYPE=BUFFER) + * @note The data type of the source tensor must be passed at compile time using -DSRC_DATA_TYPE (e.g. -DSRC_DATA_TYPE=float) + * @note The data type of the weights tensor must be passed at compile time using -DWEI_DATA_TYPE (e.g. -DWEI_DATA_TYPE=float) + * @note The data type of the destination tensor must be passed at compile time using -DDST_DATA_TYPE (e.g. -DDST_DATA_TYPE=float) + * @note The data type of the accumulators must be passed at compile time using -DACC_DATA_TYPE (e.g. -DACC_DATA_TYPE=float) + * @note The number of M0 rows (width*height) to process must be passed at compile time using -DM0 (e.g. -DM0=2) + * @note The number of N0 output channels to process must be passed at compile time using -DN0 (e.g. -DN0=2) + * @note The number of K0 inner accumulations must be passed at compile time using -DK0 (e.g. -DK0=2) + * @note The size of the partial store block in x must be passed at compile time using -DPARTIAL_N0 (e.g. -DPARTIAL_N0=1) + * @note Only the following configurations of M0, N0 and K0 are currently supported: + * - M0 = 1 + * - N0 = 1 + * - K0 = 2, 3, 4, 8 + * + * + * @param[in] src_ptr Pointer to the source tensor. 
Supported data type: F16/F32
+ * @param[in]  src_stride_y                          Stride of the source tensor in Y dimension (in bytes)
+ * @param[in]  src_stride_z                          Stride of the source tensor in Z dimension (in bytes)
+ * @param[in]  src_stride_w                          Stride of the source tensor in W dimension (in bytes)
+ * @param[in]  src_c                                 The size of the channels (IFM) dimension of the source tensor
+ * @param[in]  src_w                                 The size of the width dimension of the source tensor
+ * @param[in]  src_h                                 The size of the height dimension of the source tensor
+ * @param[in]  src_n                                 The size of the batches dimension of the source tensor
+ * @param[out] dst_ptr                               Pointer to the destination tensor. Supported data type: same as @p src_ptr
+ * @param[in]  dst_stride_y                          Stride of the destination tensor in Y dimension (in bytes)
+ * @param[in]  dst_stride_z                          Stride of the destination tensor in Z dimension (in bytes)
+ * @param[in]  dst_stride_w                          Stride of the destination tensor in W dimension (in bytes)
+ * @param[in]  dst_c                                 The size of the channels (OFM) dimension of the destination tensor
+ * @param[in]  dst_w                                 The size of the width dimension of the destination tensor
+ * @param[in]  dst_h                                 The size of the height dimension of the destination tensor
+ * @param[in]  dst_n                                 The size of the batches dimension of the destination tensor
+ * @param[in]  dst_offset_first_element_in_bytes     The offset of the first element in the destination tensor
+ * @param[in]  wei_ptr                               Pointer to the weights tensor. Supported data type: same as @p src_ptr
+ * @param[in]  wei_stride_y                          Stride of the weights tensor in Y dimension (in bytes)
+ * @param[in]  wei_stride_z                          Stride of the weights tensor in Z dimension (in bytes)
+ * @param[in]  wei_stride_w                          Stride of the weights tensor in W dimension (in bytes)
+ * @param[in]  wei_c                                 The size of the channels (IFM) dimension of the weights tensor
+ * @param[in]  wei_w                                 The size of the width dimension of the weights tensor
+ * @param[in]  wei_h                                 The size of the height dimension of the weights tensor
+ * @param[in]  wei_n                                 The size of the batches (OFM) dimension of the weights tensor
+ * @param[in]  wei_offset_first_element_in_bytes     The offset of the first element in the weights tensor
+ * @param[in]  bia_ptr                               (Optional) Pointer to the bias tensor. Supported data type: same as @p src_ptr (if F32/F16)
+ * @param[in]  bia_stride_x                          (Optional) Stride of the bias tensor in X dimension (in bytes)
+ * @param[in]  bia_step_x                            (Optional) bia_stride_x * number of elements along X processed per workitem (in bytes)
+ * @param[in]  bia_offset_first_element_in_bytes     (Optional) The offset of the first element in the bias tensor
+ */
+//! @endcond
+__kernel void transposed_convolution_nhwc(
+    TENSOR4D_T(src, SRC_TENSOR_TYPE),
+    TENSOR4D_T(dst, DST_TENSOR_TYPE),
+    TENSOR4D_T(wei, WEI_TENSOR_TYPE)
+#if defined(HAS_BIAS)
+    ,
+    VECTOR_DECLARATION(bia)
+#endif // defined(HAS_BIAS)
+)
+{
+    // All the tensor dimensions are passed at compile time.
+    // In case of dynamic tensor support, the following dimensions should be passed as function arguments.
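+    //
+    // The kernel works entirely in the output space, so neither an upsampled
+    // copy of the input nor a flipped copy of the weights is materialized:
+    // - (xu, yu) below is the output coordinate mapped into the virtual
+    //   upsampled input space (shifted by the compile-time pads),
+    // - (xi, yi) is the first real input element contributing to it,
+    //   obtained by ceil-division with the strides,
+    // - (x_start, y_start) is the matching weight index; the accumulation
+    //   loops then step the weights by -STRIDE_X/-STRIDE_Y while stepping
+    //   the input by +1, i.e. the weights are traversed in flipped order
+    //   on the fly.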
+#define _IWEI_WIDTH WEI_WIDTH +#define _IWEI_HEIGHT WEI_HEIGHT +#define _ISRC_WIDTH SRC_WIDTH +#define _ISRC_HEIGHT SRC_HEIGHT +#define _ISRC_CHANNELS SRC_CHANNELS +#define _IDST_WIDTH DST_WIDTH +#define _IDST_HEIGHT DST_HEIGHT +#define _IDST_CHANNELS DST_CHANNELS +#define _IY_MULTIPLIER (_IWEI_WIDTH * _IWEI_HEIGHT) + + const int cout = GET_SPATIAL_IDX(0, N0, PARTIAL_N0); // OFM + const int mout = GET_SPATIAL_IDX(1, M0, 0); // WIDTH x HEIGHT + const int bout = GET_SPATIAL_IDX(2, 1, 0); // BATCH SIZE IDX + + // .v = access the whole vector (OpenCL vector) + // .s[x] = access the vector element at position x (scalar access) + TILE(int, M0, 1, xi); + TILE(int, M0, 1, yi); + TILE(int, M0, 1, xu); + TILE(int, M0, 1, yu); + + // Convert the linear index to coordinate + LOOP_UNROLLING(int, i, 0, 1, M0, + { + xu[i].v = ((mout + i) % _IDST_WIDTH) - PAD_LEFT; + yu[i].v = ((mout + i) / _IDST_WIDTH) - PAD_TOP; + xi[i].v = ceil(xu[i].v / (float)STRIDE_X); + yi[i].v = ceil(yu[i].v / (float)STRIDE_Y); + }) + + // Initialize the accumulators + TILE(ACC_DATA_TYPE, M0, N0, c); + + LOOP_UNROLLING(int, i, 0, 1, M0, + { + c[i].v = 0; + }) + + // Flipped indices + const int x_start = _IWEI_WIDTH - (xi[0].v * STRIDE_X - xu[0].v) - 1; + const int y_start = _IWEI_HEIGHT - (yi[0].v * STRIDE_Y - yu[0].v) - 1; + + for(int yk = y_start, yi_step = 0; yk >= 0; yk -= STRIDE_Y, ++yi_step) + { + for(int xk = x_start, xi_step = 0; xk >= 0; xk -= STRIDE_X, ++xi_step) + { + int weights_y = cout * _IY_MULTIPLIER + yk * _IWEI_WIDTH + xk; + + TILE(int, M0, 1, my); + + LOOP_UNROLLING(int, i, 0, 1, M0, + { + int x_s = xi[i].v + xi_step; + int y_s = yi[i].v + yi_step; + my[i].v = x_s + y_s *_ISRC_WIDTH; + my[i].v = my[i].v + bout * (int)(_ISRC_WIDTH * _ISRC_HEIGHT); + my[i].v = select(-1, my[i].v, x_s >= 0); + my[i].v = select(-1, my[i].v, x_s < _ISRC_WIDTH); + my[i].v = select(-1, my[i].v, y_s >= 0); + my[i].v = select(-1, my[i].v, y_s < _ISRC_HEIGHT); + }) + + int ck = 0; + for(; ck <= (_ISRC_CHANNELS - K0); ck += K0) + { + TILE(SRC_DATA_TYPE, M0, K0, a); + TILE(WEI_DATA_TYPE, N0, K0, b); + + // Initialize tiles + LOOP_UNROLLING(int, i, 0, 1, M0, + { + a[i].v = 0.f; + }) + + LOOP_UNROLLING(int, i, 0, 1, N0, + { + b[i].v = 0.f; + }) + + // Load tile from the src tensor + T_LOAD2D_INDIRECT(SRC_DATA_TYPE, M0, K0, SRC_TENSOR_TYPE, src, bout, yk, xk, ck, _ISRC_WIDTH, _ISRC_HEIGHT, src_stride_y, my, a); + + // Load tile from the weights tensor + T_LOAD(WEI_DATA_TYPE, N0, K0, WEI_TENSOR_TYPE, wei, ck, weights_y, _IY_MULTIPLIER, wei_stride_y, b); + + // Compute the matrix multiplication between two tiles + T_MMUL(SRC_DATA_TYPE, WEI_DATA_TYPE, ACC_DATA_TYPE, M0, N0, K0, NT, T, a, b, c); + } + + // We voluntarily use SRC_CHANNELS rather than _DSRC_CHANNELS + // This #if directive should be removed in case of dynamic tensor support +#if defined(LEFTOVER_LOOP) + // Left-over accumulations + for(; ck < _ISRC_CHANNELS; ++ck) + { + TILE(SRC_DATA_TYPE, M0, 1, a); + TILE(WEI_DATA_TYPE, N0, 1, b); + + // Initialize tiles + LOOP_UNROLLING(int, i, 0, 1, M0, + { + a[i].v = 0.f; + }) + + // Load tile from the src tensor + // The T_LOAD for the left-over elements can only use BUFFER because we load one element per iteration + T_LOAD2D_INDIRECT(SRC_DATA_TYPE, M0, 1, BUFFER, src, bout, yk, xk, ck, _ISRC_WIDTH, _ISRC_HEIGHT, src_stride_y, my, a); + + // Load tile from the weights tensor + // The T_LOAD for the left-over elements can only use BUFFER because we load one element per iteration + T_LOAD(WEI_DATA_TYPE, N0, 1, BUFFER, wei, ck, weights_y, 
_IY_MULTIPLIER, wei_stride_y, b);
+
+                // Compute the matrix multiplication between two tiles
+                T_MMUL(SRC_DATA_TYPE, WEI_DATA_TYPE, ACC_DATA_TYPE, M0, N0, 1, NT, T, a, b, c);
+            }
+#endif // defined(LEFTOVER_LOOP)
+        }
+    }
+
+#if defined(HAS_BIAS)
+    TILE(BIA_DATA_TYPE, 1, N0, bias0);
+
+    T_LOAD(BIA_DATA_TYPE, 1, N0, BUFFER, bia, cout, 0, 1, 0, bias0);
+
+    // c = c + bias[broadcasted]
+    T_ELTWISE_BROADCAST_ADD_X(ACC_DATA_TYPE, M0, N0, c, bias0, c);
+
+#endif // HAS_BIAS
+
+    TILE(uint, M0, 1, dst_indirect_y);
+
+    // Calculate the destination indirect Y
+    LOOP_UNROLLING(int, i, 0, 1, M0,
+    {
+        dst_indirect_y[i].v = (uint)min(mout + i, (int)(_IDST_WIDTH * _IDST_HEIGHT) - 1);
+        dst_indirect_y[i].v += bout * (int)(_IDST_WIDTH * _IDST_HEIGHT);
+    })
+
+    bool x_cond = PARTIAL_N0 != 0 && get_global_id(0) == 0;
+
+    // Store the tile in reverse order so the invalid values are overwritten with the valid ones
+    T_STORE_INDIRECT_WIDTH_SELECT(DST_DATA_TYPE, M0, N0, PARTIAL_N0, DST_TENSOR_TYPE, dst, cout, dst_stride_y, x_cond, c, dst_indirect_y);
+
+#undef _IWEI_WIDTH
+#undef _IWEI_HEIGHT
+#undef _ISRC_WIDTH
+#undef _ISRC_HEIGHT
+#undef _ISRC_CHANNELS
+#undef _IDST_WIDTH
+#undef _IDST_HEIGHT
+#undef _IDST_CHANNELS
+#undef _IY_MULTIPLIER
+}
\ No newline at end of file
diff --git a/src/gpu/cl/ClKernelLibrary.cpp b/src/gpu/cl/ClKernelLibrary.cpp
index 0f08f5d044..4e036399db 100644
--- a/src/gpu/cl/ClKernelLibrary.cpp
+++ b/src/gpu/cl/ClKernelLibrary.cpp
@@ -448,6 +448,7 @@ const std::map<std::string, std::string> ClKernelLibrary::_kernel_program_map =
     { "space_to_batch_nhwc", "nhwc/space_to_batch.cl" },
     { "space_to_batch_static_nhwc", "nhwc/space_to_batch.cl" },
     { "space_to_depth_nhwc", "nhwc/space_to_depth.cl" },
+    { "transposed_convolution_nhwc", "nhwc/transposed_convolution.cl" },
     { "upsample_layer_nhwc", "nhwc/upsample_layer.cl" },
     { "winograd_filter_transform_4x1_3x1_nhwc", "nhwc/winograd_filter_transform.cl" },
     { "winograd_filter_transform_1x4_1x3_nhwc", "nhwc/winograd_filter_transform.cl" },
@@ -941,6 +942,10 @@ const std::map<std::string, std::string> ClKernelLibrary::_program_source_map =
     {
         "nhwc/space_to_depth.cl",
#include "./cl_kernels/nhwc/space_to_depth.clembed"
     },
+    {
+        "nhwc/transposed_convolution.cl",
+#include "./cl_kernels/nhwc/transposed_convolution.clembed"
+    },
     {
         "nhwc/winograd_filter_transform.cl",
diff --git a/src/gpu/cl/kernels/ClTransposedConvolutionKernel.cpp b/src/gpu/cl/kernels/ClTransposedConvolutionKernel.cpp
new file mode 100644
index 0000000000..16c6ad9a9b
--- /dev/null
+++ b/src/gpu/cl/kernels/ClTransposedConvolutionKernel.cpp
@@ -0,0 +1,231 @@
+/*
+ * Copyright (c) 2022 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#include "src/gpu/cl/kernels/ClTransposedConvolutionKernel.h" + +#include "arm_compute/core/CL/ICLTensor.h" +#include "arm_compute/core/utils/misc/ShapeCalculator.h" +#include "src/core/CL/CLValidate.h" +#include "src/core/helpers/AutoConfiguration.h" +#include "src/core/helpers/WindowHelpers.h" +#include "support/Cast.h" + +namespace arm_compute +{ +namespace opencl +{ +namespace kernels +{ +namespace +{ +Status validate_arguments(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *output, + const PadStrideInfo &deconv_info) +{ + ARM_COMPUTE_RETURN_ERROR_ON_F16_UNSUPPORTED(input); + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F16, DataType::F32); + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, weights); + ARM_COMPUTE_RETURN_ERROR_ON_DATA_LAYOUT_NOT_IN(input, DataLayout::NHWC); + ARM_COMPUTE_RETURN_ERROR_ON_DATA_LAYOUT_NOT_IN(weights, DataLayout::NHWC); + + constexpr unsigned int channel_idx = 0; + constexpr unsigned int width_idx = 1; + constexpr unsigned int height_idx = 2; + constexpr unsigned int batch_idx = 3; + + ARM_COMPUTE_RETURN_ERROR_ON_MSG(weights->dimension(channel_idx) != input->dimension(channel_idx), "Weights feature map dimension should match the respective src's one"); + ARM_COMPUTE_RETURN_ERROR_ON_MSG(weights->num_dimensions() > 4, "Weights can be at most 4 dimensional"); + + if(biases != nullptr) + { + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(weights, biases); + ARM_COMPUTE_RETURN_ERROR_ON_MSG(biases->dimension(channel_idx) != weights->dimension(batch_idx), + "Biases size and number of dst feature maps should match"); + ARM_COMPUTE_RETURN_ERROR_ON_MSG(biases->num_dimensions() > 1, "Biases should be one dimensional"); + ARM_COMPUTE_RETURN_ERROR_ON_DATA_LAYOUT_NOT_IN(input, DataLayout::NHWC); + } + + // Checks performed when output is configured + if(output->total_size() != 0) + { + const size_t input_width = input->dimension(width_idx); + const size_t input_height = input->dimension(height_idx); + const size_t weights_width = weights->dimension(width_idx); + const size_t weights_height = weights->dimension(height_idx); + + auto out_dims = deconvolution_output_dimensions(input_width, input_height, weights_width, weights_height, deconv_info); + TensorShape output_shape = misc::shape_calculator::compute_deconvolution_output_shape(out_dims, *input, *weights); + + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DIMENSIONS(output->tensor_shape(), output_shape); + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output); + ARM_COMPUTE_RETURN_ERROR_ON_DATA_LAYOUT_NOT_IN(output, DataLayout::NHWC); + } + + return Status{}; +} +} // namespace + +void ClTransposedConvolutionKernel::configure(const CLCompileContext &compile_context, const ITensorInfo *input, const ITensorInfo *weights, + const ITensorInfo *biases, ITensorInfo *output, const PadStrideInfo &deconv_info) +{ + ARM_COMPUTE_UNUSED(biases, deconv_info); + ARM_COMPUTE_ERROR_ON_NULLPTR(input, weights, output); + + // Perform validation + ARM_COMPUTE_ERROR_THROW_ON(validate(input, weights, biases, output, deconv_info)); + + constexpr unsigned int channel_idx = 0; + constexpr unsigned int width_idx = 1; + constexpr unsigned int height_idx = 2; + + const 
size_t input_channels = input->dimension(channel_idx); // same as weight channels + const size_t input_width = input->dimension(width_idx); + const size_t input_height = input->dimension(height_idx); + const size_t weights_width = weights->dimension(width_idx); + const size_t weights_height = weights->dimension(height_idx); + const size_t output_width = output->dimension(width_idx); + const size_t output_height = output->dimension(height_idx); + const size_t output_channels = output->dimension(channel_idx); + + // Calculate output shape + auto out_dims = deconvolution_output_dimensions(input_width, input_height, weights_width, weights_height, deconv_info); + TensorShape output_shape = misc::shape_calculator::compute_deconvolution_output_shape(out_dims, *input, *weights); + auto_init_if_empty(*output, output_shape, 1, input->data_type(), input->quantization_info()); + + // Calculate updated paddings + // p' = k - p - 1 (k: kernel dimensions) + const uint32_t pad_left = weights_width - deconv_info.pad_left() - 1; + const uint32_t pad_top = weights_height - deconv_info.pad_top() - 1; + + // Configure kernel window + Window win; + output_shape.collapse(2U, 1U); // Collapse width and height into single dimension + + // Create window and update padding + win = calculate_max_window(output_shape, Steps(1, 1)); + ICLKernel::configure_internal(win); + + const std::string kernel_name = "transposed_convolution_nhwc"; + CLBuildOptions build_options; + + const DataType input_data_type = input->data_type(); // Fp32 or Fp16 only + const auto strides = deconv_info.stride(); + + const unsigned int n0 = 1; + const unsigned int m0 = 1; + const unsigned int k0 = adjust_vec_size(input_data_type == DataType::F32 ? 4 : 8, input_channels); + const unsigned int partial_store_n0 = output_channels % n0; + + if(biases != nullptr) + { + build_options.add_option(std::string("-DHAS_BIAS")); + build_options.add_option(std::string("-DBIA_DATA_TYPE=" + get_cl_type_from_data_type(biases->data_type()))); + } + + const auto output_data_type = output->data_type(); + + build_options.add_option("-cl-fast-relaxed-math"); + build_options.add_option("-DSRC_TENSOR_TYPE=BUFFER"); + build_options.add_option("-DSRC_DATA_TYPE=" + get_cl_type_from_data_type(input_data_type)); + build_options.add_option("-DSRC_CHANNELS=" + support::cpp11::to_string(input_channels)); + build_options.add_option("-DSRC_WIDTH=" + support::cpp11::to_string(input_width)); + build_options.add_option("-DSRC_HEIGHT=" + support::cpp11::to_string(input_height)); + build_options.add_option("-DDST_CHANNELS=" + support::cpp11::to_string(output_channels)); + build_options.add_option("-DDST_WIDTH=" + support::cpp11::to_string(output_width)); + build_options.add_option("-DDST_HEIGHT=" + support::cpp11::to_string(output_height)); + build_options.add_option("-DDST_TENSOR_TYPE=BUFFER"); + build_options.add_option("-DDST_DATA_TYPE=" + get_cl_type_from_data_type(output_data_type)); + build_options.add_option("-DWEI_TENSOR_TYPE=BUFFER"); + build_options.add_option("-DWEI_WIDTH=" + support::cpp11::to_string(weights_width)); + build_options.add_option("-DWEI_HEIGHT=" + support::cpp11::to_string(weights_height)); + build_options.add_option("-DWEI_DATA_TYPE=" + get_cl_type_from_data_type(weights->data_type())); + build_options.add_option("-DSTRIDE_X=" + support::cpp11::to_string(strides.first)); + build_options.add_option("-DSTRIDE_Y=" + support::cpp11::to_string(strides.second)); + build_options.add_option("-DPAD_LEFT=" + support::cpp11::to_string(pad_left)); + 
build_options.add_option("-DPAD_TOP=" + support::cpp11::to_string(pad_top)); + build_options.add_option("-DN0=" + support::cpp11::to_string(n0)); + build_options.add_option("-DM0=" + support::cpp11::to_string(m0)); + build_options.add_option("-DK0=" + support::cpp11::to_string(k0)); + build_options.add_option("-DPARTIAL_N0=" + support::cpp11::to_string(partial_store_n0)); + build_options.add_option_if((input_channels % k0) != 0, "-DLEFTOVER_LOOP"); + build_options.add_option("-DACC_DATA_TYPE=" + get_cl_type_from_data_type(input_data_type)); + + if(compile_context.get_ddk_version() >= 30) + { + build_options.add_option("-fregister-allocation=64"); + } + + _kernel = create_kernel(compile_context, kernel_name, build_options.options()); + + // Set config_id for enabling LWS tuning + _config_id = kernel_name; + _config_id += "_"; + _config_id += lower_string(string_from_data_type(input_data_type)); + _config_id += "_"; + _config_id += support::cpp11::to_string(weights_width); + _config_id += "_"; + _config_id += support::cpp11::to_string(strides.first); + _config_id += "_"; + _config_id += support::cpp11::to_string(strides.second); + _config_id += "_"; + _config_id += support::cpp11::to_string(output_width); + _config_id += "_"; + _config_id += support::cpp11::to_string(m0); + _config_id += "_"; + _config_id += support::cpp11::to_string(n0); +} + +Status ClTransposedConvolutionKernel::validate(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, + const ITensorInfo *dst, const PadStrideInfo &deconv_info) +{ + ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(src, weights, biases, dst, deconv_info)); + return Status{}; +} + +void ClTransposedConvolutionKernel::run_op(ITensorPack &tensors, const Window &window, cl::CommandQueue &queue) +{ + ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this); + ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(IKernel::window(), window); + + // Get initial windows + Window slice = window.first_slice_window_3D(); + + const auto src = utils::cast::polymorphic_downcast(tensors.get_const_tensor(TensorType::ACL_SRC_0)); + const auto weights = utils::cast::polymorphic_downcast(tensors.get_const_tensor(TensorType::ACL_SRC_1)); + const auto biases = utils::cast::polymorphic_downcast(tensors.get_const_tensor(TensorType::ACL_SRC_2)); + auto dst = utils::cast::polymorphic_downcast(tensors.get_tensor(TensorType::ACL_DST)); + + unsigned int idx = 0; + add_4d_tensor_nhwc_argument(idx, src); + add_4d_tensor_nhwc_argument(idx, dst); + + add_4d_tensor_nhwc_argument(idx, weights); + if(biases != nullptr) + { + add_1D_tensor_argument(idx, biases, slice); + } + + enqueue(queue, *this, slice, lws_hint()); +} +} // namespace kernels +} // namespace opencl +} // namespace arm_compute diff --git a/src/gpu/cl/kernels/ClTransposedConvolutionKernel.h b/src/gpu/cl/kernels/ClTransposedConvolutionKernel.h new file mode 100644 index 0000000000..d4350dda50 --- /dev/null +++ b/src/gpu/cl/kernels/ClTransposedConvolutionKernel.h @@ -0,0 +1,66 @@ +/* + * Copyright (c) 2022 Arm Limited. 
+ * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#ifndef ARM_COMPUTE_CL_TRANSPOSED_CONVOLUTION_KERNEL_H +#define ARM_COMPUTE_CL_TRANSPOSED_CONVOLUTION_KERNEL_H + +#include "src/core/common/Macros.h" +#include "src/gpu/cl/ClCompileContext.h" +#include "src/gpu/cl/IClKernel.h" + +namespace arm_compute +{ +namespace opencl +{ +namespace kernels +{ +/** OpenCL kernel for transposed convolution. */ +class ClTransposedConvolutionKernel : public IClKernel +{ +public: + ClTransposedConvolutionKernel() = default; + ARM_COMPUTE_DISALLOW_COPY_ALLOW_MOVE(ClTransposedConvolutionKernel); + /** Set the input, weights, biases and output tensors. + * + * Similar to @ref ClTransposedConvolution::configure() + * + */ + void configure(const CLCompileContext &compile_context, const ITensorInfo *input, const ITensorInfo *weights, + const ITensorInfo *biases, ITensorInfo *output, const PadStrideInfo &deconv_info); + /** Static function to check if given info will lead to a valid configuration + * + * Similar to @ref ClTransposedConvolution::configure() + * + * @return a status + */ + static Status validate(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *biases, + const ITensorInfo *output, const PadStrideInfo &deconv_info); + + // Inherited methods overridden: + void run_op(ITensorPack &tensors, const Window &window, cl::CommandQueue &queue) override; +}; +} // namespace kernels +} // namespace opencl +} // namespace arm_compute + +#endif /* ARM_COMPUTE_CL_TRANSPOSED_CONVOLUTION_KERNEL_H */ \ No newline at end of file diff --git a/src/gpu/cl/operators/ClTransposedConvolution.cpp b/src/gpu/cl/operators/ClTransposedConvolution.cpp new file mode 100644 index 0000000000..90dbe7f291 --- /dev/null +++ b/src/gpu/cl/operators/ClTransposedConvolution.cpp @@ -0,0 +1,58 @@ +/* + * Copyright (c) 2022 Arm Limited. 
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#include "src/gpu/cl/operators/ClTransposedConvolution.h"
+
+#include "arm_compute/core/Error.h"
+#include "arm_compute/runtime/CL/CLScheduler.h"
+#include "src/common/utils/Log.h"
+#include "src/gpu/cl/kernels/ClTransposedConvolutionKernel.h"
+
+namespace arm_compute
+{
+namespace opencl
+{
+void ClTransposedConvolution::configure(const CLCompileContext &compile_context, const ITensorInfo *input, const ITensorInfo *weights,
+                                        const ITensorInfo *biases, ITensorInfo *output, const PadStrideInfo &deconv_info)
+{
+    ARM_COMPUTE_ERROR_ON_NULLPTR(input);
+    ARM_COMPUTE_LOG_PARAMS(input, weights, biases, output, deconv_info);
+    auto kernel_object = std::make_unique<kernels::ClTransposedConvolutionKernel>();
+    kernel_object->set_target(CLScheduler::get().target());
+    kernel_object->configure(compile_context, input, weights, biases, output, deconv_info);
+    _transposed_conv_kernel = std::move(kernel_object);
+}
+
+Status ClTransposedConvolution::validate(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *biases,
+                                         const ITensorInfo *output, const PadStrideInfo &deconv_info)
+{
+    ARM_COMPUTE_RETURN_ON_ERROR(kernels::ClTransposedConvolutionKernel::validate(input, weights, biases, output, deconv_info));
+    return Status{};
+}
+
+void ClTransposedConvolution::run(ITensorPack &tensors)
+{
+    CLScheduler::get().enqueue_op(*_transposed_conv_kernel.get(), tensors, false);
+}
+} // namespace opencl
+} // namespace arm_compute
diff --git a/src/gpu/cl/operators/ClTransposedConvolution.h b/src/gpu/cl/operators/ClTransposedConvolution.h
new file mode 100644
index 0000000000..bc04387df5
--- /dev/null
+++ b/src/gpu/cl/operators/ClTransposedConvolution.h
@@ -0,0 +1,90 @@
+/*
+ * Copyright (c) 2022 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#ifndef ARM_COMPUTE_CL_TRANSPOSED_CONVOLUTION_H
+#define ARM_COMPUTE_CL_TRANSPOSED_CONVOLUTION_H
+
+#include "src/gpu/cl/ClCompileContext.h"
+#include "src/gpu/cl/IClKernel.h"
+#include "src/gpu/cl/IClOperator.h"
+namespace arm_compute
+{
+namespace opencl
+{
+/** Basic function to run a transposed convolution layer. This function calls the following OpenCL kernel:
+ *
+ * -# @ref kernels::ClTransposedConvolutionKernel
+ */
+class ClTransposedConvolution : public IClOperator
+{
+public:
+    /** Default constructor */
+    ClTransposedConvolution() = default;
+    /** Default Destructor */
+    ~ClTransposedConvolution() = default;
+    /** Prevent instances of this class from being copied (As this class contains pointers) */
+    ClTransposedConvolution(const ClTransposedConvolution &) = delete;
+    /** Default move constructor */
+    ClTransposedConvolution(ClTransposedConvolution &&) = default;
+    /** Prevent instances of this class from being copied (As this class contains pointers) */
+    ClTransposedConvolution &operator=(const ClTransposedConvolution &) = delete;
+    /** Default move assignment operator */
+    ClTransposedConvolution &operator=(ClTransposedConvolution &&) = default;
+
+    /** Set the input, weights, biases and output tensors.
+     *
+     * @note Only the NHWC data layout is supported
+     *
+     * @param[in]  compile_context The compile context to be used.
+     * @param[in]  input           Input tensor info with dimensions [IFM, width, height, batch]
+     *                             Data types supported: F16/F32.
+     * @param[in]  weights         Weight tensor info with dimensions [IFM, width, height, OFM].
+     *                             Data type supported: Same as @p input
+     * @param[in]  biases          (Optional) Biases tensor info. Biases are 1D tensor with dimension [OFM].
+     *                             Data type supported: Should match @p input data type
+     * @param[out] output          Output tensor info with dimensions [OFM, width, height, batch]
+     *                             The 1st dimension must be equal to the 4th dimension of the @p weights tensor.
+     *                             Data types supported: Same as @p input.
+     * @param[in]  deconv_info     Contains padding and stride information described in @ref PadStrideInfo.
+ *
+ */
+    void configure(const CLCompileContext &compile_context, const ITensorInfo *input, const ITensorInfo *weights,
+                   const ITensorInfo *biases, ITensorInfo *output, const PadStrideInfo &deconv_info);
+    /** Static function to check if given info will lead to a valid configuration
+     *
+     * Similar to ClTransposedConvolution::configure()
+     *
+     * @return a status
+     */
+    static Status validate(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *biases,
+                           const ITensorInfo *output, const PadStrideInfo &deconv_info);
+
+    // Inherited method overridden
+    void run(ITensorPack &tensors) override;
+
+private:
+    std::unique_ptr<IClKernel> _transposed_conv_kernel{ nullptr };
+};
+} // namespace opencl
+} // namespace arm_compute
+#endif /* ARM_COMPUTE_CL_TRANSPOSED_CONVOLUTION_H */
\ No newline at end of file
diff --git a/src/runtime/CL/functions/CLDeconvolutionLayer.cpp b/src/runtime/CL/functions/CLDeconvolutionLayer.cpp
index c348bfcd0c..a4db6d7770 100644
--- a/src/runtime/CL/functions/CLDeconvolutionLayer.cpp
+++ b/src/runtime/CL/functions/CLDeconvolutionLayer.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2017-2021 Arm Limited.
+ * Copyright (c) 2017-2022 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -29,6 +29,8 @@
 #include "arm_compute/core/utils/quantization/AsymmHelpers.h"
 #include "arm_compute/runtime/CL/CLScheduler.h"
 #include "src/core/CL/ICLKernel.h"
+#include "src/gpu/cl/IClOperator.h"
+#include "src/gpu/cl/operators/ClTransposedConvolution.h"
 
 #include "src/common/utils/Log.h"
 
@@ -39,8 +41,19 @@
 using namespace arm_compute;
 using namespace arm_compute::misc::shape_calculator;
 
+struct CLDeconvolutionLayer::Impl
+{
+    const ICLTensor *src{ nullptr };
+    const ICLTensor *weights{ nullptr };
+    const ICLTensor *biases{ nullptr };
+    ICLTensor       *dst{ nullptr };
+    std::unique_ptr<opencl::IClOperator> op{ nullptr };
+};
+
+CLDeconvolutionLayer::~CLDeconvolutionLayer() = default;
+
 CLDeconvolutionLayer::CLDeconvolutionLayer(std::shared_ptr<IMemoryManager> memory_manager)
-    : _memory_manager(std::move(memory_manager)), _function()
+    : _memory_manager(std::move(memory_manager)), _function(), _impl(std::make_unique<Impl>())
 {
 }
 
@@ -59,6 +72,19 @@ void CLDeconvolutionLayer::configure(const CLCompileContext &compile_context, IC
     switch(CLDeconvolutionLayer::get_deconvolution_method(input->info(), weights->info(), nullptr, output->info(), deconv_info, weights_info))
     {
         case DeconvolutionMethod::DIRECT:
+        {
+            auto op = std::make_unique<opencl::ClTransposedConvolution>();
+            op->configure(compile_context, input->info(), weights->info(), bias != nullptr ? bias->info() : nullptr, output->info(), deconv_info);
+
+            _impl->src     = input;
+            _impl->weights = weights;
+            _impl->biases  = bias;
+            _impl->dst     = output;
+
+            _impl->op = std::move(op);
+            break;
+        }
+        case DeconvolutionMethod::UPSCALE_CONV2D:
         {
             auto f = std::make_unique<CLDirectDeconvolutionLayer>();
             f->configure(compile_context, input, weights, bias, output, deconv_info, weights_info);
@@ -85,6 +111,12 @@ Status CLDeconvolutionLayer::validate(const ITensorInfo *input, const ITensorInf
     switch(CLDeconvolutionLayer::get_deconvolution_method(input, weights, bias, output, deconv_info, weights_info))
     {
         case DeconvolutionMethod::DIRECT:
+        {
+            // Validate transposed convolution operator
+            ARM_COMPUTE_RETURN_ON_ERROR(opencl::ClTransposedConvolution::validate(input, weights, bias, output, deconv_info));
+            break;
+        }
+        case DeconvolutionMethod::UPSCALE_CONV2D:
         {
             // Validate direct convolution layer
             ARM_COMPUTE_RETURN_ON_ERROR(CLDirectDeconvolutionLayer::validate(input, weights, bias, output, deconv_info, weights_info));
@@ -109,11 +141,16 @@ DeconvolutionMethod CLDeconvolutionLayer::get_deconvolution_method(const ITensor
 {
     ARM_COMPUTE_UNUSED(output, bias, weights_info);
 
-    if(is_data_type_quantized_per_channel(weights->data_type()))
+    if(input->data_layout() == DataLayout::NHWC && (input->data_type() == DataType::F32 || input->data_type() == DataType::F16))
     {
         return DeconvolutionMethod::DIRECT;
     }
 
+    if(is_data_type_quantized_per_channel(weights->data_type()))
+    {
+        return DeconvolutionMethod::UPSCALE_CONV2D;
+    }
+
     const DataLayout data_layout = input->data_layout();
 
     const size_t idx_w = get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH);
@@ -121,7 +158,7 @@ DeconvolutionMethod CLDeconvolutionLayer::get_deconvolution_method(const ITensor
 
     if(weights->dimension(idx_w) != deconv_info.stride().first || weights->dimension(idx_h) != deconv_info.stride().second)
     {
-        return DeconvolutionMethod::DIRECT;
+        return DeconvolutionMethod::UPSCALE_CONV2D;
     }
 
     return DeconvolutionMethod::GEMM;
@@ -130,10 +167,29 @@ DeconvolutionMethod CLDeconvolutionLayer::get_deconvolution_method(const ITensor
 void CLDeconvolutionLayer::run()
 {
     prepare();
-    _function->run();
+
+    if(_impl->op != nullptr)
+    {
+        // Optimized Operator will be used
+        ITensorPack pack;
+
+        pack.add_tensor(TensorType::ACL_SRC_0, _impl->src);
+        pack.add_tensor(TensorType::ACL_SRC_1, _impl->weights);
+        pack.add_tensor(TensorType::ACL_SRC_2, _impl->biases);
+        pack.add_tensor(TensorType::ACL_DST, _impl->dst);
+
+        _impl->op->run(pack);
+    }
+    else
+    {
+        _function->run();
+    }
 }
 
 void CLDeconvolutionLayer::prepare()
 {
-    _function->prepare();
+    if(_impl->op == nullptr)
+    {
+        _function->prepare();
+    }
 }
diff --git a/tests/datasets/ShapeDatasets.h b/tests/datasets/ShapeDatasets.h
index f0ad9fa693..e4277a981e 100644
--- a/tests/datasets/ShapeDatasets.h
+++ b/tests/datasets/ShapeDatasets.h
@@ -549,7 +549,7 @@ public:
     {
         TensorShape{ 5U, 5U, 7U, 4U, 3U },
         TensorShape{ 5U, 5U, 4U, 13U, 2U },
-        TensorShape{ 5U, 5U, 3U, 5U , 2U},
+        TensorShape{ 5U, 5U, 3U, 5U, 2U },
     })
     {
     }
@@ -718,6 +718,7 @@ public:
     SmallDeconvolutionShapes()
         : ShapeDataset("InputShape",
     {
+        // Multiple Vector Loops for FP32
        TensorShape{ 5U, 4U, 3U, 2U },
        TensorShape{ 5U, 5U, 3U },
        TensorShape{ 11U, 13U, 4U, 3U }
@@ -726,6 +727,19 @@ public:
     }
 };
 
+class SmallDeconvolutionShapesWithLargerChannels final : public ShapeDataset
+{
+public:
+    SmallDeconvolutionShapesWithLargerChannels()
+        : ShapeDataset("InputShape",
+    {
+        // Multiple Vector Loops for all data types
+        TensorShape{ 5U, 5U, 35U }
+    })
+    {
+    }
+};
+
 /** Data
set containing tiny tensor shapes for direct convolution. */ class TinyDirectConvolutionShapes final : public ShapeDataset { diff --git a/tests/validation/CL/DeconvolutionLayer.cpp b/tests/validation/CL/DeconvolutionLayer.cpp index 15962b588d..01d0dd8caa 100644 --- a/tests/validation/CL/DeconvolutionLayer.cpp +++ b/tests/validation/CL/DeconvolutionLayer.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2021 Arm Limited. + * Copyright (c) 2017-2022 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -70,6 +70,10 @@ const auto data3x3_asymm = datasets::SmallDeconvolutionShapes() * framework::dat const auto data3x3_precommit = datasets::SmallDeconvolutionShapes() * framework::dataset::make("StrideX", 1, 2) * framework::dataset::make("StrideY", 1, 2) * framework::dataset::make("PadX", 0, 2) * framework::dataset::make("PadY", 0, 2) * framework::dataset::make("NumKernels", { 3 }); +const auto data3x3_precommit_large_channels = datasets::SmallDeconvolutionShapesWithLargerChannels() * framework::dataset::make("StrideX", 2) * framework::dataset::make("StrideY", + 2) * framework::dataset::make("PadX", 1) + * framework::dataset::make("PadY", 2) * framework::dataset::make("NumKernels", { 5 }); + const auto data2x2_precommit = datasets::SmallDeconvolutionShapes() * framework::dataset::make("StrideX", 2) * framework::dataset::make("StrideY", 2) * framework::dataset::make("PadX", 1) * framework::dataset::make("PadY", 1) * framework::dataset::make("NumKernels", { 3 }); @@ -90,9 +94,15 @@ DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip(zip(zip( framework::dataset::make("InputInfo", { TensorInfo(TensorShape(27U, 13U, 2U), 1, DataType::F32), // Mismatching data type TensorInfo(TensorShape(27U, 13U, 2U), 1, DataType::F32), // Invalid weights shape TensorInfo(TensorShape(27U, 13U, 2U), 1, DataType::F16), // Non supported data type - TensorInfo(TensorShape(27U, 13U, 2U), 1, DataType::F32), // Invalid bias shape + TensorInfo(TensorShape(27U, 13U, 2U), 1, DataType::F32), // Invalid bias shape TensorInfo(TensorShape(13U, 11U, 4U, 3U), 1, DataType::F32), // Window shrink TensorInfo(TensorShape(32U, 16U, 2U), 1, DataType::F32), + TensorInfo(TensorShape(2U, 13U, 27U), 1, DataType::F32, DataLayout::NHWC), // Mismatching data type + TensorInfo(TensorShape(2U, 13U, 27U), 1, DataType::F32, DataLayout::NHWC), // Invalid weights shape + TensorInfo(TensorShape(27U, 13U, 2U), 1, DataType::F16, DataLayout::NHWC), // Non supported data type + TensorInfo(TensorShape(2U, 13U, 27U), 1, DataType::F32, DataLayout::NHWC), // Invalid bias shape + TensorInfo(TensorShape(4U, 11U, 13U, 3U), 1, DataType::F32, DataLayout::NHWC), // Window shrink + TensorInfo(TensorShape(2U, 16U, 32U), 1, DataType::F32, DataLayout::NHWC), }), framework::dataset::make("WeightsInfo", { TensorInfo(TensorShape(3U, 3U, 2U, 2U), 1, DataType::F16), TensorInfo(TensorShape(3U, 3U, 2U, 4U), 1, DataType::F32), @@ -100,6 +110,12 @@ DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip(zip(zip( TensorInfo(TensorShape(3U, 2U, 2U, 2U), 1, DataType::F32), TensorInfo(TensorShape(3U, 3U, 4U), 1, DataType::F32), TensorInfo(TensorShape(1U, 1U, 2U, 4U), 1, DataType::F32), + TensorInfo(TensorShape(2U, 3U, 3U, 2U), 1, DataType::F16, DataLayout::NHWC), + TensorInfo(TensorShape(2U, 3U, 3U, 4U), 1, DataType::F32, DataLayout::NHWC), + TensorInfo(TensorShape(3U, 3U, 2U, 2U), 1, DataType::F16, DataLayout::NHWC), + TensorInfo(TensorShape(2U, 2U, 3U, 2U), 1, DataType::F32, DataLayout::NHWC), + TensorInfo(TensorShape(4U, 3U, 3U), 1, DataType::F32, 
DataLayout::NHWC),
+                                            TensorInfo(TensorShape(2U, 2U, 2U, 4U), 1, DataType::F32, DataLayout::NHWC),
                                           })),
     framework::dataset::make("BiasInfo", { TensorInfo(TensorShape(1U), 1, DataType::F16),
                                            TensorInfo(TensorShape(1U), 1, DataType::F32),
@@ -107,6 +123,12 @@ DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip(zip(zip(
                                            TensorInfo(TensorShape(25U, 11U), 1, DataType::F32),
                                            TensorInfo(TensorShape(1U), 1, DataType::F32),
                                            TensorInfo(TensorShape(4U), 1, DataType::F32),
+                                           TensorInfo(TensorShape(1U), 1, DataType::F16, DataLayout::NHWC),
+                                           TensorInfo(TensorShape(1U), 1, DataType::F32, DataLayout::NHWC),
+                                           TensorInfo(TensorShape(1U), 1, DataType::F32, DataLayout::NHWC),
+                                           TensorInfo(TensorShape(25U, 11U), 1, DataType::F32, DataLayout::NHWC),
+                                           TensorInfo(TensorShape(1U), 1, DataType::F32, DataLayout::NHWC),
+                                           TensorInfo(TensorShape(4U), 1, DataType::F32, DataLayout::NHWC),
                                          })),
     framework::dataset::make("OutputInfo", { TensorInfo(TensorShape(25U, 11U, 2U), 1, DataType::F16),
                                              TensorInfo(TensorShape(25U, 10U, 2U), 1, DataType::F32),
@@ -114,6 +136,12 @@ DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip(zip(zip(
                                              TensorInfo(TensorShape(13U, 13U, 2U), 1, DataType::F32),
                                              TensorInfo(TensorShape(11U, 9U, 1U, 3U), 1, DataType::F32),
                                              TensorInfo(TensorShape(32U, 16U, 4U), 1, DataType::F32),
+                                             TensorInfo(TensorShape(2U, 11U, 25U), 1, DataType::F16, DataLayout::NHWC),
+                                             TensorInfo(TensorShape(2U, 10U, 25U), 1, DataType::F32, DataLayout::NHWC),
+                                             TensorInfo(TensorShape(25U, 11U, 2U), 1, DataType::F32, DataLayout::NHWC),
+                                             TensorInfo(TensorShape(2U, 13U, 13U), 1, DataType::F32, DataLayout::NHWC),
+                                             TensorInfo(TensorShape(1U, 9U, 11U, 3U), 1, DataType::F32, DataLayout::NHWC),
+                                             TensorInfo(TensorShape(4U, 43U, 91U), 1, DataType::F32, DataLayout::NHWC),
                                            })),
     framework::dataset::make("PadStrideInfo", { PadStrideInfo(1, 1, 0, 0),
                                                 PadStrideInfo(1, 1, 0, 0),
@@ -121,8 +149,15 @@ DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip(zip(zip(
                                                 PadStrideInfo(1, 1, 0, 0),
                                                 PadStrideInfo(1, 1, 1, 1),
                                                 PadStrideInfo(1, 1, 0, 0),
+                                                PadStrideInfo(1, 1, 0, 0),
+                                                PadStrideInfo(1, 1, 0, 0),
+                                                PadStrideInfo(1, 1, 0, 0),
+                                                PadStrideInfo(1, 1, 0, 0),
+                                                PadStrideInfo(1, 1, 1, 1),
+                                                PadStrideInfo(3, 3, 2, 2),
                                               })),
-    framework::dataset::make("Expected", { false, false, false, false, false, true })),
+    framework::dataset::make("Expected", { false, false, false, false, false, true,   // NCHW
+                                           false, false, false, false, false, true })), // NHWC
     input_info, weights_info, bias_info, output_info, pad_info, expected)
 {
     bool is_valid = bool(CLDeconvolutionLayer::validate(&input_info.clone()->set_is_resizable(false), &weights_info.clone()->set_is_resizable(false), &bias_info.clone()->set_is_resizable(false), &output_info.clone()->set_is_resizable(false), pad_info));
@@ -171,6 +206,17 @@ FIXTURE_DATA_TEST_CASE(RunSmall, CLDeconvolutionLayerFixture3x3<float>, framewor
     // Validate output
     validate(CLAccessor(_target), _reference, tolerance_fp32);
 }
+
+FIXTURE_DATA_TEST_CASE(RunSmallWithLargeChannels, CLDeconvolutionLayerFixture3x3<float>, framework::DatasetMode::PRECOMMIT,
+                       combine(combine(combine(data3x3_precommit_large_channels,
+                                               framework::dataset::make("DataType", DataType::F32)),
+                                       data_layouts_dataset),
+                               framework::dataset::make("AddBias", { true })))
+{
+    // Validate output
+    validate(CLAccessor(_target), _reference, tolerance_fp32);
+}
+
 FIXTURE_DATA_TEST_CASE(RunAsymm, CLDeconvolutionLayerAsymmFixture3x3<float>, framework::DatasetMode::NIGHTLY, combine(combine(combine(data3x3_asymm,
                                                                                                                       framework::dataset::make("DataType",
DataType::F32)), data_layouts_dataset), -- cgit v1.2.1
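
Note on method selection after this patch: NHWC FP32/FP16 deconvolutions take
the new DIRECT path (the single-kernel ClTransposedConvolution above), while
per-channel quantized weights and the kernel-size != stride cases fall back to
UPSCALE_CONV2D (the legacy upsample + conv CLDirectDeconvolutionLayer), and
only kernel-size == stride cases use GEMM. A condensed, standalone sketch of
the dispatch in CLDeconvolutionLayer::get_deconvolution_method() (types and
names simplified for illustration; not the library's API):

    #include <cstdio>

    enum class DeconvolutionMethod { GEMM, DIRECT, UPSCALE_CONV2D };

    struct Desc
    {
        bool nhwc;                  // data layout is NHWC
        bool fp;                    // data type is F32 or F16
        bool per_channel_quantized; // weights are per-channel quantized
        bool kernel_equals_stride;  // kernel WxH equals stride_x x stride_y
    };

    DeconvolutionMethod select(const Desc &d)
    {
        if(d.nhwc && d.fp)
        {
            return DeconvolutionMethod::DIRECT; // new single-kernel transposed convolution
        }
        if(d.per_channel_quantized)
        {
            return DeconvolutionMethod::UPSCALE_CONV2D; // legacy upsample + conv path
        }
        if(!d.kernel_equals_stride)
        {
            return DeconvolutionMethod::UPSCALE_CONV2D;
        }
        return DeconvolutionMethod::GEMM;
    }

    int main()
    {
        const Desc fp32_nhwc{ true, true, false, false };
        std::printf("FP32/NHWC -> %s\n", select(fp32_nhwc) == DeconvolutionMethod::DIRECT ? "DIRECT" : "other");
        return 0;
    }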