From 327225d3b2f716d5c62d801a7fafc7d377521f34 Mon Sep 17 00:00:00 2001 From: Manuel Bottini Date: Tue, 13 Apr 2021 13:09:30 +0100 Subject: Port NEDirectConvolutionLayer to new API Partially resolves: COMPMID-4009 Change-Id: I19ffb61c5c4541134a5028677d2d81228740e454 Signed-off-by: Manuel Bottini Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/5419 Tested-by: Arm Jenkins Comments-Addressed: Arm Jenkins Reviewed-by: SiCong Li Reviewed-by: Georgios Pinitas Reviewed-by: Michele Di Giorgio --- Android.bp | 5 +- .../NEON/functions/NEDirectConvolutionLayer.h | 29 +- docs/00_introduction.dox | 10 +- src/core/NEON/NEKernels.h | 2 - .../kernels/NEDirectConvolutionLayerKernel.cpp | 1382 ------------------- .../NEON/kernels/NEDirectConvolutionLayerKernel.h | 109 -- .../NEDirectConvolutionLayerOutputStageKernel.cpp | 504 ------- .../NEDirectConvolutionLayerOutputStageKernel.h | 102 -- .../cpu/kernels/CpuDirectConvolutionKernel.cpp | 1385 ++++++++++++++++++++ src/core/cpu/kernels/CpuDirectConvolutionKernel.h | 100 ++ .../CpuDirectConvolutionOutputStageKernel.h | 93 ++ .../kernels/CpuDirectConvolutionStageKernel.cpp | 514 ++++++++ .../NEON/functions/NEDirectConvolutionLayer.cpp | 109 +- src/runtime/cpu/operators/CpuDirectConvolution.cpp | 147 +++ src/runtime/cpu/operators/CpuDirectConvolution.h | 121 ++ 15 files changed, 2401 insertions(+), 2211 deletions(-) delete mode 100644 src/core/NEON/kernels/NEDirectConvolutionLayerKernel.cpp delete mode 100644 src/core/NEON/kernels/NEDirectConvolutionLayerKernel.h delete mode 100644 src/core/NEON/kernels/NEDirectConvolutionLayerOutputStageKernel.cpp delete mode 100644 src/core/NEON/kernels/NEDirectConvolutionLayerOutputStageKernel.h create mode 100644 src/core/cpu/kernels/CpuDirectConvolutionKernel.cpp create mode 100644 src/core/cpu/kernels/CpuDirectConvolutionKernel.h create mode 100644 src/core/cpu/kernels/CpuDirectConvolutionOutputStageKernel.h create mode 100644 src/core/cpu/kernels/CpuDirectConvolutionStageKernel.cpp create mode 100644 src/runtime/cpu/operators/CpuDirectConvolution.cpp create mode 100644 src/runtime/cpu/operators/CpuDirectConvolution.h diff --git a/Android.bp b/Android.bp index 7e2da8fe53..17281a49d1 100644 --- a/Android.bp +++ b/Android.bp @@ -176,8 +176,6 @@ cc_library_static { "src/core/NEON/kernels/NEDepthConvertLayerKernel.cpp", "src/core/NEON/kernels/NEDepthToSpaceLayerKernel.cpp", "src/core/NEON/kernels/NEDepthwiseConvolutionLayerNativeKernel.cpp", - "src/core/NEON/kernels/NEDirectConvolutionLayerKernel.cpp", - "src/core/NEON/kernels/NEDirectConvolutionLayerOutputStageKernel.cpp", "src/core/NEON/kernels/NEFFTDigitReverseKernel.cpp", "src/core/NEON/kernels/NEFFTRadixStageKernel.cpp", "src/core/NEON/kernels/NEFFTScaleKernel.cpp", @@ -301,6 +299,8 @@ cc_library_static { "src/core/cpu/kernels/CpuConcatenateWidthKernel.cpp", "src/core/cpu/kernels/CpuCopyKernel.cpp", "src/core/cpu/kernels/CpuDequantizationKernel.cpp", + "src/core/cpu/kernels/CpuDirectConvolutionKernel.cpp", + "src/core/cpu/kernels/CpuDirectConvolutionStageKernel.cpp", "src/core/cpu/kernels/CpuElementwiseKernel.cpp", "src/core/cpu/kernels/CpuElementwiseUnaryKernel.cpp", "src/core/cpu/kernels/CpuFillKernel.cpp", @@ -630,6 +630,7 @@ cc_library_static { "src/runtime/cpu/operators/CpuConcatenate.cpp", "src/runtime/cpu/operators/CpuCopy.cpp", "src/runtime/cpu/operators/CpuDequantization.cpp", + "src/runtime/cpu/operators/CpuDirectConvolution.cpp", "src/runtime/cpu/operators/CpuElementwise.cpp", "src/runtime/cpu/operators/CpuElementwiseUnary.cpp", 
"src/runtime/cpu/operators/CpuFill.cpp", diff --git a/arm_compute/runtime/NEON/functions/NEDirectConvolutionLayer.h b/arm_compute/runtime/NEON/functions/NEDirectConvolutionLayer.h index 86914fa0bc..fc4017e635 100644 --- a/arm_compute/runtime/NEON/functions/NEDirectConvolutionLayer.h +++ b/arm_compute/runtime/NEON/functions/NEDirectConvolutionLayer.h @@ -28,24 +28,18 @@ #include "arm_compute/runtime/IFunction.h" #include "arm_compute/runtime/IMemoryManager.h" #include "arm_compute/runtime/MemoryGroup.h" -#include "arm_compute/runtime/NEON/functions/NEActivationLayer.h" -#include "arm_compute/runtime/Tensor.h" #include namespace arm_compute { -class NEDirectConvolutionLayerOutputStageKernel; -class NEDirectConvolutionLayerKernel; -class NEFillBorderKernel; - +class ITensor; +class ITensorInfo; /** Function to run the direct convolution. * - * This function calls the following kernels: + * This function calls the following: * - * -# @ref NEFillBorderKernel for the input - * -# @ref NEDirectConvolutionLayerOutputStageKernel - * -# @ref NEDirectConvolutionLayerKernel + * -# @ref cpu::CpuDirectConvolution */ class NEDirectConvolutionLayer : public IFunction { @@ -108,16 +102,9 @@ public: void run() override; private: - MemoryGroup _memory_group; - std::unique_ptr _output_stage_kernel; - std::unique_ptr _conv_kernel; - std::unique_ptr _input_border_handler; - NEActivationLayer _activationlayer_function; - Tensor _accumulator; - bool _has_bias; - bool _is_activationlayer_enabled; - unsigned int _dim_split; - bool _is_padding_required; + struct Impl; + std::shared_ptr _memory_manager; + std::unique_ptr _impl; }; -} +} // namespace arm_compute #endif /* ARM_COMPUTE_NEDIRECTCONVOLUTIONLAYER_H */ diff --git a/docs/00_introduction.dox b/docs/00_introduction.dox index 9f6af6da50..efc2963f6e 100644 --- a/docs/00_introduction.dox +++ b/docs/00_introduction.dox @@ -259,7 +259,7 @@ v20.11 Public major release - NENonMaximaSuppression3x3Kernel - @ref NERemapKernel - @ref NEGEMMInterleave4x4Kernel - - @ref NEDirectConvolutionLayerKernel + - NEDirectConvolutionLayerKernel - NEScaleKernel - NELocallyConnectedMatrixMultiplyKernel - @ref NEGEMMLowpOffsetContributionKernel @@ -269,7 +269,7 @@ v20.11 Public major release - @ref NEDepthwiseConvolutionLayerNativeKernel - @ref NEGEMMLowpMatrixMultiplyKernel - @ref NEGEMMMatrixMultiplyKernel - - @ref NEDirectConvolutionLayerOutputStageKernel + - NEDirectConvolutionLayerOutputStageKernel - @ref NEReductionOperationKernel - @ref NEGEMMLowpMatrixAReductionKernel - @ref NEGEMMLowpMatrixBReductionKernel @@ -682,7 +682,7 @@ v20.02 Public major release - @ref NEConvolutionLayer - @ref NEDepthwiseConvolutionLayer - NEDepthwiseConvolutionLayer3x3Kernel - - @ref NEDirectConvolutionLayerOutputStageKernel + - NEDirectConvolutionLayerOutputStageKernel - @ref NEElementwiseComparison - @ref NEElementwiseMax - @ref NEElementwiseMin @@ -1214,7 +1214,7 @@ v18.01 Public maintenance release - GCGEMMTranspose1xWKernel - GCIm2ColKernel - Refactored Arm® Neon™ Winograd (NEWinogradLayerKernel) - - Added @ref NEDirectConvolutionLayerOutputStageKernel + - Added NEDirectConvolutionLayerOutputStageKernel - Added QASYMM8 support to the following Arm® Neon™ kernels: - NEDepthwiseConvolutionLayer3x3Kernel - @ref NEFillBorderKernel @@ -1338,7 +1338,7 @@ v17.06 Public major release - New Arm® Neon™ kernels / functions: - @ref NEBatchNormalizationLayerKernel / @ref NEBatchNormalizationLayer - NEDepthConcatenateLayerKernel / NEDepthConcatenateLayer - - @ref NEDirectConvolutionLayerKernel / @ref 
NEDirectConvolutionLayer + - NEDirectConvolutionLayerKernel / @ref NEDirectConvolutionLayer - NELocallyConnectedMatrixMultiplyKernel / NELocallyConnectedLayer - @ref NEWeightsReshapeKernel / @ref NEConvolutionLayerReshapeWeights diff --git a/src/core/NEON/NEKernels.h b/src/core/NEON/NEKernels.h index 59884e2d05..264f521be2 100644 --- a/src/core/NEON/NEKernels.h +++ b/src/core/NEON/NEKernels.h @@ -39,8 +39,6 @@ #include "src/core/NEON/kernels/NEDepthConvertLayerKernel.h" #include "src/core/NEON/kernels/NEDepthToSpaceLayerKernel.h" #include "src/core/NEON/kernels/NEDepthwiseConvolutionLayerNativeKernel.h" -#include "src/core/NEON/kernels/NEDirectConvolutionLayerKernel.h" -#include "src/core/NEON/kernels/NEDirectConvolutionLayerOutputStageKernel.h" #include "src/core/NEON/kernels/NEFFTDigitReverseKernel.h" #include "src/core/NEON/kernels/NEFFTRadixStageKernel.h" #include "src/core/NEON/kernels/NEFFTScaleKernel.h" diff --git a/src/core/NEON/kernels/NEDirectConvolutionLayerKernel.cpp b/src/core/NEON/kernels/NEDirectConvolutionLayerKernel.cpp deleted file mode 100644 index 98b76c7db3..0000000000 --- a/src/core/NEON/kernels/NEDirectConvolutionLayerKernel.cpp +++ /dev/null @@ -1,1382 +0,0 @@ -/* - * Copyright (c) 2017-2021 Arm Limited. - * - * SPDX-License-Identifier: MIT - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to - * deal in the Software without restriction, including without limitation the - * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or - * sell copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. 
- */ -#include "src/core/NEON/kernels/NEDirectConvolutionLayerKernel.h" - -#include "src/core/NEON/kernels/detail/NEDirectConvolutionDetail.h" -#include "src/core/NEON/wrapper/wrapper.h" - -#include "arm_compute/core/Error.h" -#include "arm_compute/core/Helpers.h" -#include "arm_compute/core/IAccessWindow.h" -#include "arm_compute/core/ITensor.h" -#include "arm_compute/core/Types.h" -#include "arm_compute/core/Utils.h" -#include "arm_compute/core/Validate.h" -#include "arm_compute/core/utils/misc/ShapeCalculator.h" -#include "src/core/AccessWindowStatic.h" -#include "src/core/CPP/Validate.h" -#include "src/core/NEON/NEFixedPoint.h" -#include "src/core/helpers/AutoConfiguration.h" -#include "src/core/helpers/WindowHelpers.h" - -#include - -using namespace arm_compute::detail; - -namespace arm_compute -{ -namespace -{ -#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC -template -float16x8_t internal_vld1q(const float16_t *in); - -template <> -float16x8_t internal_vld1q<1>(const float16_t *in) -{ - return vld1q_f16(in); -} - -template <> -float16x8_t internal_vld1q<2>(const float16_t *in) -{ - const float16x8x2_t tmp = vld2q_f16(in); - return tmp.val[0]; -} - -template <> -float16x8_t internal_vld1q<3>(const float16_t *in) -{ - const float16x8x3_t tmp = vld3q_f16(in); - return tmp.val[0]; -} - -inline float16x8_t internal_vdupq_n(float16_t v) -{ - return vdupq_n_f16(v); -} - -inline void internal_vst1q(float16_t *p, const float16x8_t &v) -{ - vst1q_f16(p, v); -} - -float16x8_t internal_vmull(const float16x8_t &x, const float16x8_t &y) -{ - return vmulq_f16(x, y); -} - -inline float16x8_t internal_vmlal(const float16x8_t &x, const float16x8_t &y, const float16x8_t &z) -{ - return vaddq_f16(x, vmulq_f16(y, z)); -} -#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */ - -template -float32x4_t internal_vld1q(const float *in); - -template <> -float32x4_t internal_vld1q<1>(const float *in) -{ - return vld1q_f32(in); -} - -template <> -float32x4_t internal_vld1q<2>(const float *in) -{ - const float32x4x2_t tmp = vld2q_f32(in); - return tmp.val[0]; -} - -template <> -float32x4_t internal_vld1q<3>(const float *in) -{ - const float32x4x3_t tmp = vld3q_f32(in); - return tmp.val[0]; -} - -inline float32x4_t internal_vdupq_n(float v) -{ - return vdupq_n_f32(v); -} - -inline void internal_vst1q(float *p, const float32x4_t &v) -{ - vst1q_f32(p, v); -} - -float32x4_t internal_vmull(const float32x4_t &x, const float32x4_t &y) -{ - return vmulq_f32(x, y); -} - -inline float32x4_t internal_vmlal(const float32x4_t &x, const float32x4_t &y, const float32x4_t &z) -{ - return vmlaq_f32(x, y, z); -} - -constexpr int small_tensor_size_optim = 8; -inline bool run_optim_small_tensor_info(const ITensorInfo *t) -{ - return t->dimension(Window::DimX) <= small_tensor_size_optim && t->dimension(Window::DimY) <= small_tensor_size_optim; -} - -inline bool run_optim_small_tensor(const ITensor *t) -{ - return run_optim_small_tensor_info(t->info()); -} - -// Optimized convolver for 1x1 kernels used only where input width and height are both <= 8 -// For big Z as in Input=7x7x832, this implementation is faster than the general code becuase it doesn't need to -// store intermidiate results in memory. Temporary results are stored in SIMD registers directly and then written to the output buffer. 
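The stride-templated loads above implement a strided gather with de-interleaving NEON loads: vld2q_f32/vld3q_f32 read 8 or 12 contiguous floats and only val[0] (every 2nd or 3rd element) is kept, which is exactly what a convolution stride of 2 or 3 needs. A minimal standalone sketch of the same idea, with a hypothetical helper name strided_load4:

#include <arm_neon.h>

// Load 4 floats taken every `stride`-th element from `in` (stride 1, 2 or 3),
// mirroring the internal_vld1q<stride>() specialisations in the kernel above.
template <unsigned int stride>
inline float32x4_t strided_load4(const float *in);

template <>
inline float32x4_t strided_load4<1>(const float *in)
{
    return vld1q_f32(in); // contiguous: elements 0,1,2,3
}

template <>
inline float32x4_t strided_load4<2>(const float *in)
{
    // vld2q_f32 de-interleaves 8 floats into {even, odd} lanes; val[0] holds elements 0,2,4,6
    return vld2q_f32(in).val[0];
}

template <>
inline float32x4_t strided_load4<3>(const float *in)
{
    // vld3q_f32 de-interleaves 12 floats; val[0] holds elements 0,3,6,9
    return vld3q_f32(in).val[0];
}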
-template -class convolver_w1x1_i8x8_f32 -{ -public: - static void convolve(const Window &window, const ITensor *input, const ITensor *weights, ITensor *output, const PadStrideInfo &conv_info) - { - ARM_COMPUTE_ERROR_ON(input->info()->dimension(Window::DimX) > small_tensor_size_optim); - ARM_COMPUTE_ERROR_ON(input->info()->dimension(Window::DimY) > small_tensor_size_optim); - - const int input_stride_x = input->info()->strides_in_bytes().x(); - const int input_stride_y = input->info()->strides_in_bytes().y(); - const int input_stride_z = input->info()->strides_in_bytes().z(); - const int output_stride_y = output->info()->strides_in_bytes().y(); - const int output_stride_z = output->info()->strides_in_bytes().z(); - const int kernel_stride_z = weights->info()->strides_in_bytes().z(); - const int kernel_stride_w = weights->info()->strides_in_bytes()[3]; - const int output_h = output->info()->dimension(1); - const int range_z = window.z().end() - window.z().start(); - const int kernel_depth = weights->info()->dimension(Window::DimZ); - const unsigned int conv_stride_y = std::get<1>(conv_info.stride()); - const unsigned int conv_pad_left = conv_info.pad_left(); - const unsigned int conv_pad_top = conv_info.pad_top(); - - // setup output window for the iterator - Window window_out = window; - window_out.set(Window::DimX, Window::Dimension(0, output->info()->dimension(Window::DimX), output->info()->dimension(Window::DimX))); - window_out.set(Window::DimY, Window::Dimension(0, output->info()->dimension(Window::DimY), output->info()->dimension(Window::DimY))); - window_out.set(Window::DimZ, Window::Dimension(window.z().start(), window.z().end(), range_z)); - - // setup input window for the iterator - Window window_in = window; - // we just want execute_window_loop to iterate over the higher dimensions (>3), so we set the first 3 dimensions to 0 - window_in.set(Window::DimX, Window::Dimension(0, 0, 0)); - window_in.set(Window::DimY, Window::Dimension(0, 0, 0)); - window_in.set(Window::DimZ, Window::Dimension(0, 0, 0)); - - Window window_k = calculate_max_window(*weights->info(), Steps(1u)); - Iterator out(output, window_out); - Iterator in(input, window_in); - Iterator k(weights, window_k); - - const uint8_t *k_ptr = k.ptr(); - - execute_window_loop(window_out, [&](const Coordinates & id) - { - const uint8_t *input_ptr = in.ptr() - conv_pad_left * input_stride_x - conv_pad_top * input_stride_y; - uint8_t *out_ptr = out.ptr(); - int ih = 0; - int oh = 0; - std::array accum0 = { vdupq_n_f32(0), vdupq_n_f32(0), vdupq_n_f32(0), vdupq_n_f32(0), vdupq_n_f32(0), vdupq_n_f32(0), vdupq_n_f32(0), vdupq_n_f32(0) }; - std::array accum1 = { vdupq_n_f32(0), vdupq_n_f32(0), vdupq_n_f32(0), vdupq_n_f32(0), vdupq_n_f32(0), vdupq_n_f32(0), vdupq_n_f32(0), vdupq_n_f32(0) }; - for(int oz = 0; oz < range_z; ++oz) - { - accum0[0] = accum0[1] = accum0[2] = accum0[3] = accum0[4] = accum0[5] = accum0[6] = accum0[7] = vdupq_n_f32(0.f); - accum1[0] = accum1[1] = accum1[2] = accum1[3] = accum1[4] = accum1[5] = accum1[6] = accum1[7] = vdupq_n_f32(0.f); - auto p_out_base = out_ptr + oz * output_stride_z; - for(int p = 0; p < kernel_depth; ++p) - { - const auto k_val = reinterpret_cast(k_ptr + p * kernel_stride_z + (id.z() + oz) * kernel_stride_w); - const auto vk0 = internal_vdupq_n(*k_val); - for(ih = 0, oh = 0; oh < output_h; ++oh, ih += conv_stride_y) - { - const int offset_xy = ih * input_stride_y; - auto in_val = reinterpret_cast(input_ptr + p * input_stride_z + offset_xy); - auto v_in0 = internal_vld1q(in_val); - auto 
v_in1 = internal_vld1q(in_val + 4); - accum0[oh] = vmlaq_f32(accum0[oh], vk0, v_in0); - accum1[oh] = vmlaq_f32(accum1[oh], vk0, v_in1); - } - } - for(oh = 0; oh < output_h; ++oh) - { - auto p_out = reinterpret_cast(p_out_base + oh * output_stride_y); - vst1q_f32(p_out, accum0[oh]); - vst1q_f32(p_out + 4, accum1[oh]); - } - } - }, - in, out); - } -}; - -template -class convolver_1x1 -{ -public: - static void convolve(const Window &window, unsigned int num_elems_read_per_iteration, unsigned int num_elems_written_per_iteration, - const ITensor *input, const ITensor *weights, ITensor *output, const PadStrideInfo &conv_info) - { - const int input_stride_x = input->info()->strides_in_bytes().x(); - const int input_stride_y = input->info()->strides_in_bytes().y(); - const int input_stride_z = input->info()->strides_in_bytes().z(); - const int output_stride_y = output->info()->strides_in_bytes().y(); - const int output_stride_z = output->info()->strides_in_bytes().z(); - const int kernel_stride_z = weights->info()->strides_in_bytes().z(); - const int kernel_stride_w = weights->info()->strides_in_bytes()[3]; - const int output_w = output->info()->dimension(0); - const int output_h = output->info()->dimension(1); - const int range_z = window.z().end() - window.z().start(); - const int kernel_depth = weights->info()->dimension(Window::DimZ); - const unsigned int conv_stride_y = std::get<1>(conv_info.stride()); - const unsigned int conv_pad_left = conv_info.pad_left(); - const unsigned int conv_pad_top = conv_info.pad_top(); - - // setup output window for the iterator - Window window_out = window; - window_out.set(Window::DimX, Window::Dimension(0, output->info()->dimension(Window::DimX), output->info()->dimension(Window::DimX))); - window_out.set(Window::DimY, Window::Dimension(0, output->info()->dimension(Window::DimY), output->info()->dimension(Window::DimY))); - window_out.set(Window::DimZ, Window::Dimension(window.z().start(), window.z().end(), range_z)); - - // setup input window for the iterator - Window window_in = window; - // we just want execute_window_loop to iterate over the higher dimensions (>3), so we set the first 3 dimensions to 0 - window_in.set(Window::DimX, Window::Dimension(0, 0, 0)); - window_in.set(Window::DimY, Window::Dimension(0, 0, 0)); - window_in.set(Window::DimZ, Window::Dimension(0, 0, 0)); - - Window window_k = calculate_max_window(*weights->info(), Steps(1u)); - Iterator out(output, window_out); - Iterator in(input, window_in); - Iterator k(weights, window_k); - - const uint8_t *k_ptr = k.ptr(); - - execute_window_loop(window_out, [&](const Coordinates & id) - { - /* - For a detailed explanation on how the algorithm works refer to template <> class convolver_3x3<1> - */ - const uint8_t *input_ptr = in.ptr() - conv_pad_left * input_stride_x - conv_pad_top * input_stride_y; - uint8_t *out_ptr = out.ptr(); - int ih = 0; - int oh = 0; - for(int oz = 0; oz < range_z; ++oz) - { - auto p_out_base = out_ptr + oz * output_stride_z; - // Step 1 - { - const auto k_val = reinterpret_cast(k_ptr + 0 * kernel_stride_z + (id.z() + oz) * kernel_stride_w); - const auto vk = internal_vdupq_n(*k_val); - for(ih = 0, oh = 0; oh < output_h; ++oh, ih += conv_stride_y) - { - const int offset_xy = ih * input_stride_y; - auto in_val = reinterpret_cast(input_ptr + (0 * input_stride_z + offset_xy)); - auto p_out = reinterpret_cast(p_out_base + oh * output_stride_y); - for(int ow = 0; ow < output_w; ow += num_elems_written_per_iteration, in_val += num_elems_read_per_iteration, p_out += 
num_elems_written_per_iteration) - { - internal_vst1q(p_out, internal_vmull(vk, internal_vld1q(in_val))); - } - } - } - - // Step 2 - for(int p = 1; p < kernel_depth; ++p) - { - const auto k_val = reinterpret_cast(k_ptr + p * kernel_stride_z + (id.z() + oz) * kernel_stride_w); - const auto vk = internal_vdupq_n(*k_val); - for(ih = 0, oh = 0; oh < output_h; ++oh, ih += conv_stride_y) - { - const int offset_xy = ih * input_stride_y; - auto in_val = reinterpret_cast(input_ptr + p * input_stride_z + offset_xy); - auto p_out = reinterpret_cast(p_out_base + oh * output_stride_y); - for(int ow = 0; ow < output_w; ow += num_elems_written_per_iteration, in_val += num_elems_read_per_iteration, p_out += num_elems_written_per_iteration) - { - internal_vst1q(p_out, internal_vmlal(internal_vld1q<1>(p_out), vk, internal_vld1q(in_val))); - } - } - } - } - }, - in, out); - } -}; - -template -float32x4x2_t convolve_5x5(const float *in_0, const float *in_1, const float *in_2, const float *in_3, const float *in_4, - const float *m0, const float *m1, const float *m2, const float *m3, const float *m4); - -inline float32x4x3_t load_matrix_hi(const float *const m0, const float *const m1, const float *const m2) -{ - const float32x4x3_t m00 = - { - { - vld1q_dup_f32(m0), - vld1q_dup_f32(m1), - vld1q_dup_f32(m2) - } - }; - return m00; -} - -inline float32x4x2_t load_matrix_lo(const float *const m3, const float *const m4) -{ - const float32x4x2_t m00 = - { - { - vld1q_dup_f32(m3), - vld1q_dup_f32(m4) - } - }; - return m00; -} - -inline float32x4x3_t load_input(const float *const in) -{ - const float32x4x3_t vin = - { - { - vld1q_f32(in), - vld1q_f32(in + 4), - vld1q_f32(in + 8) - } - }; - return vin; -} - -template <> -inline float32x4x2_t convolve_5x5<1>(const float *in_0, const float *in_1, const float *in_2, const float *in_3, const float *in_4, - const float *m0, const float *m1, const float *m2, const float *m3, const float *m4) -{ - const float32x4x3_t vin0 = load_input(in_0); - const float32x4x3_t vin1 = load_input(in_1); - const float32x4x3_t vin2 = load_input(in_2); - const float32x4x3_t vin3 = load_input(in_3); - const float32x4x3_t vin4 = load_input(in_4); - const float32x4x3_t m00 = load_matrix_hi(m0, 1 + m0, 2 + m0); - const float32x4x2_t m01 = load_matrix_lo(3 + m0, 4 + m0); - const float32x4x3_t m10 = load_matrix_hi(m1, 1 + m1, 2 + m1); - const float32x4x2_t m11 = load_matrix_lo(3 + m1, 4 + m1); - const float32x4x3_t m20 = load_matrix_hi(m2, 1 + m2, 2 + m2); - const float32x4x2_t m21 = load_matrix_lo(3 + m2, 4 + m2); - const float32x4x3_t m30 = load_matrix_hi(m3, 1 + m3, 2 + m3); - const float32x4x2_t m31 = load_matrix_lo(3 + m3, 4 + m3); - const float32x4x3_t m40 = load_matrix_hi(m4, 1 + m4, 2 + m4); - const float32x4x2_t m41 = load_matrix_lo(3 + m4, 4 + m4); - - float32x4x2_t out = - { - { - vmulq_f32(vin0.val[0], m00.val[0]), - vmulq_f32(vin0.val[1], m00.val[0]) - } - }; - - out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin0.val[0], vin0.val[1], 1), m00.val[1]); - out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin0.val[0], vin0.val[1], 2), m00.val[2]); - out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin0.val[0], vin0.val[1], 3), m01.val[0]); - out.val[0] = vmlaq_f32(out.val[0], vin0.val[1], m01.val[1]); - - out.val[0] = vmlaq_f32(out.val[0], vin1.val[0], m10.val[0]); - out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin1.val[0], vin1.val[1], 1), m10.val[1]); - out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin1.val[0], vin1.val[1], 2), m10.val[2]); - out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin1.val[0], 
vin1.val[1], 3), m11.val[0]); - out.val[0] = vmlaq_f32(out.val[0], vin1.val[1], m11.val[1]); - - out.val[0] = vmlaq_f32(out.val[0], vin2.val[0], m20.val[0]); - out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin2.val[0], vin2.val[1], 1), m20.val[1]); - out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin2.val[0], vin2.val[1], 2), m20.val[2]); - out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin2.val[0], vin2.val[1], 3), m21.val[0]); - out.val[0] = vmlaq_f32(out.val[0], vin2.val[1], m21.val[1]); - - out.val[0] = vmlaq_f32(out.val[0], vin3.val[0], m30.val[0]); - out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin3.val[0], vin3.val[1], 1), m30.val[1]); - out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin3.val[0], vin3.val[1], 2), m30.val[2]); - out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin3.val[0], vin3.val[1], 3), m31.val[0]); - out.val[0] = vmlaq_f32(out.val[0], vin3.val[1], m31.val[1]); - - out.val[0] = vmlaq_f32(out.val[0], vin4.val[0], m40.val[0]); - out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin4.val[0], vin4.val[1], 1), m40.val[1]); - out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin4.val[0], vin4.val[1], 2), m40.val[2]); - out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin4.val[0], vin4.val[1], 3), m41.val[0]); - out.val[0] = vmlaq_f32(out.val[0], vin4.val[1], m41.val[1]); - - out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin0.val[1], vin0.val[2], 1), m00.val[1]); - out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin0.val[1], vin0.val[2], 2), m00.val[2]); - out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin0.val[1], vin0.val[2], 3), m01.val[0]); - out.val[1] = vmlaq_f32(out.val[1], vin0.val[2], m01.val[1]); - - out.val[1] = vmlaq_f32(out.val[1], vin1.val[1], m10.val[0]); - out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin1.val[1], vin1.val[2], 1), m10.val[1]); - out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin1.val[1], vin1.val[2], 2), m10.val[2]); - out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin1.val[1], vin1.val[2], 3), m11.val[0]); - out.val[1] = vmlaq_f32(out.val[1], vin1.val[2], m11.val[1]); - - out.val[1] = vmlaq_f32(out.val[1], vin2.val[1], m20.val[0]); - out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin2.val[1], vin2.val[2], 1), m20.val[1]); - out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin2.val[1], vin2.val[2], 2), m20.val[2]); - out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin2.val[1], vin2.val[2], 3), m21.val[0]); - out.val[1] = vmlaq_f32(out.val[1], vin2.val[2], m21.val[1]); - - out.val[1] = vmlaq_f32(out.val[1], vin3.val[1], m30.val[0]); - out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin3.val[1], vin3.val[2], 1), m30.val[1]); - out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin3.val[1], vin3.val[2], 2), m30.val[2]); - out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin3.val[1], vin3.val[2], 3), m31.val[0]); - out.val[1] = vmlaq_f32(out.val[1], vin3.val[2], m31.val[1]); - - out.val[1] = vmlaq_f32(out.val[1], vin4.val[1], m40.val[0]); - out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin4.val[1], vin4.val[2], 1), m40.val[1]); - out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin4.val[1], vin4.val[2], 2), m40.val[2]); - out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin4.val[1], vin4.val[2], 3), m41.val[0]); - out.val[1] = vmlaq_f32(out.val[1], vin4.val[2], m41.val[1]); - - return out; -} - -template <> -inline float32x4x2_t convolve_5x5<2>(const float *in_0, const float *in_1, const float *in_2, const float *in_3, const float *in_4, - const float *m0, const float *m1, const float *m2, const float *m3, const float *m4) -{ - float32x4x2_t out = convolve_5x5<1>(in_0, in_1, in_2, in_3, in_4, m0, m1, m2, m3, 
m4); - out.val[0] = vsetq_lane_f32(vgetq_lane_f32(out.val[0], 2), out.val[0], 1); - out.val[0] = vsetq_lane_f32(vgetq_lane_f32(out.val[1], 0), out.val[0], 2); - out.val[0] = vsetq_lane_f32(vgetq_lane_f32(out.val[1], 2), out.val[0], 3); - return out; -} - -template <> -inline float32x4x2_t convolve_5x5<3>(const float *in_0, const float *in_1, const float *in_2, const float *in_3, const float *in_4, - const float *m0, const float *m1, const float *m2, const float *m3, const float *m4) -{ - float32x4x2_t out = convolve_5x5<1>(in_0, in_1, in_2, in_3, in_4, m0, m1, m2, m3, m4); - out.val[0] = vsetq_lane_f32(vgetq_lane_f32(out.val[0], 3), out.val[0], 1); - return out; -} - -template -class convolver_3x3 -{ -public: - static void convolve(const Window &window, unsigned int num_elems_read_per_iteration, unsigned int num_elems_written_per_iteration, - const ITensor *input, const ITensor *weights, ITensor *output, const PadStrideInfo &conv_info) - { - ARM_COMPUTE_UNUSED(num_elems_read_per_iteration); - const int input_stride_x = input->info()->strides_in_bytes().x(); - const int input_stride_y = input->info()->strides_in_bytes().y(); - const int input_stride_z = input->info()->strides_in_bytes().z(); - const int output_stride_y = output->info()->strides_in_bytes().y(); - const int output_stride_z = output->info()->strides_in_bytes().z(); - const int kernel_stride_x = weights->info()->strides_in_bytes().x(); - const int kernel_stride_y = weights->info()->strides_in_bytes().y(); - const int kernel_stride_z = weights->info()->strides_in_bytes().z(); - const int kernel_stride_w = weights->info()->strides_in_bytes()[3]; - const int output_w = output->info()->dimension(0); - const int output_h = output->info()->dimension(1); - const int num_planes_z = window.z().end() - window.z().start(); - const int delta_input = get_input_num_elems_processed(num_elems_written_per_iteration, stridex); - const int kernel_depth = weights->info()->dimension(Window::DimZ); - const unsigned int conv_stride_y = std::get<1>(conv_info.stride()); - const unsigned int conv_pad_left = conv_info.pad_left(); - const unsigned int conv_pad_top = conv_info.pad_top(); - - // setup output window for the iterator - Window window_out = window; - window_out.set(Window::DimX, Window::Dimension(0, output->info()->dimension(Window::DimX), output->info()->dimension(Window::DimX))); - window_out.set(Window::DimY, Window::Dimension(0, output->info()->dimension(Window::DimY), output->info()->dimension(Window::DimY))); - window_out.set(Window::DimZ, Window::Dimension(window.z().start(), window.z().end(), num_planes_z)); - - // setup input window for the iterator - Window window_in = window; - // we just want execute_window_loop to iterate over the higher dimensions (>3), so we set the first 3 dimensions to 0 - window_in.set(Window::DimX, Window::Dimension(0, 0, 0)); - window_in.set(Window::DimY, Window::Dimension(0, 0, 0)); - window_in.set(Window::DimZ, Window::Dimension(0, 0, 0)); - - Window window_k = calculate_max_window(*weights->info(), Steps(1u)); - - Iterator out(output, window_out); - Iterator in(input, window_in); - Iterator k(weights, window_k); - - const uint8_t *k_ptr = k.ptr(); - - execute_window_loop(window_out, [&](const Coordinates & id) - { - const uint8_t *input_ptr = in.ptr() - conv_pad_left * input_stride_x - conv_pad_top * input_stride_y; - uint8_t *out_ptr = out.ptr(); - int ih = 0; - int oh = 0; - /* - Each thread executing this kernel computes one or more output's volume planes. 
- - Let's say the 3rd dimension of the output volume is 32, the first thread will compute the output for Z = [0,7], the second thread will compute the output for Z = [8,15], - the third thread [16,24] and the fourth thread [25,31]. - - The algorithm outer loop iterates over Z, P, Y, X where P is the depth/3rd dimension of each kernel. This order is not arbitrary, the main benefit of this - is that we setup the neon registers containing the kernel's values only once and then compute each XY using the preloaded registers as opposed as doing this for every XY value. - - The algorithm does not require allocating any additional memory amd computes the results directly in-place in two stages: - 1) Convolve plane 0 with kernel 0 and initialize the corresponding output plane with these values. - 2) Convolve the remaining planes and accumulate the results in the output's plane which has been initialized in step 1. - */ - for(int oz = 0; oz < num_planes_z; ++oz) - { - const int zoffset = id.z() + oz; - uint8_t *p_out_base = out_ptr + oz * output_stride_z; - // Step 1 - { - const auto ptr_k_r0 = reinterpret_cast(k_ptr + 0 * kernel_stride_z + zoffset * kernel_stride_w + 0 * kernel_stride_y + 0 * kernel_stride_x); - const auto ptr_k_r1 = reinterpret_cast(k_ptr + 0 * kernel_stride_z + zoffset * kernel_stride_w + 1 * kernel_stride_y + 0 * kernel_stride_x); - const auto ptr_k_r2 = reinterpret_cast(k_ptr + 0 * kernel_stride_z + zoffset * kernel_stride_w + 2 * kernel_stride_y + 0 * kernel_stride_x); - const auto vk_r0 = load_matrix_row(ptr_k_r0); - const auto vk_r1 = load_matrix_row(ptr_k_r1); - const auto vk_r2 = load_matrix_row(ptr_k_r2); - for(ih = 0, oh = 0; oh < output_h; ++oh, ih += conv_stride_y) - { - auto in_top = reinterpret_cast(input_ptr + 0 * input_stride_z + (ih + 0) * input_stride_y); - auto in_mid = reinterpret_cast(input_ptr + 0 * input_stride_z + (ih + 1) * input_stride_y); - auto in_low = reinterpret_cast(input_ptr + 0 * input_stride_z + (ih + 2) * input_stride_y); - auto p_out = reinterpret_cast(p_out_base + oh * output_stride_y); - for(int ow = 0; ow < output_w; ow += num_elems_written_per_iteration, - in_top += delta_input, in_mid += delta_input, in_low += delta_input, p_out += num_elems_written_per_iteration) - { - convolve_3x3(in_top, in_mid, in_low, p_out, vk_r0, vk_r1, vk_r2, stridex); - } - } - } - // Step 2 - for(int p = 1; p < kernel_depth; ++p) - { - const uint8_t *ptr_k_base = k_ptr + p * kernel_stride_z + zoffset * kernel_stride_w; - const uint8_t *input_base = input_ptr + p * input_stride_z; - const auto ptr_k_r0 = reinterpret_cast(ptr_k_base); - const auto ptr_k_r1 = reinterpret_cast(ptr_k_base + kernel_stride_y); - const auto ptr_k_r2 = reinterpret_cast(ptr_k_base + kernel_stride_y * 2); - const auto vk_r0 = load_matrix_row(ptr_k_r0); - const auto vk_r1 = load_matrix_row(ptr_k_r1); - const auto vk_r2 = load_matrix_row(ptr_k_r2); - for(ih = 0, oh = 0; oh < output_h; ++oh, ih += conv_stride_y) - { - auto in_top = reinterpret_cast(input_base + (ih + 0) * input_stride_y); - auto in_mid = reinterpret_cast(input_base + (ih + 1) * input_stride_y); - auto in_low = reinterpret_cast(input_base + (ih + 2) * input_stride_y); - auto p_out = reinterpret_cast(p_out_base + oh * output_stride_y); - for(int ow = 0; ow < output_w; ow += num_elems_written_per_iteration, - in_top += delta_input, in_mid += delta_input, in_low += delta_input, p_out += num_elems_written_per_iteration) - { - convolve_3x3(in_top, in_mid, in_low, p_out, vk_r0, vk_r1, vk_r2, stridex); - } - } - } - } - }, - in, out); 
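The two-step scheme described in the comment above (initialise each output plane from kernel plane 0, then accumulate the remaining planes in place, so no intermediate buffer is needed) can be summarised with a scalar reference sketch. The flat indexing, the names and the stride-1/valid-region assumption are illustrative only, not the kernel's actual layout:

#include <vector>

// Scalar reference of the in-place two-step accumulation used by convolver_3x3:
// out(x,y) = sum over planes p of conv2d(in_planes[p], k_planes[p])(x,y),
// assuming stride 1 and a valid region (in_w >= out_w + 2, in rows >= out_h + 2).
void direct_conv_two_step(const std::vector<std::vector<float>> &in_planes, // [depth][in rows * in_w]
                          const std::vector<std::vector<float>> &k_planes,  // [depth][3 * 3]
                          std::vector<float> &out, int in_w, int out_w, int out_h)
{
    auto conv_at = [&](const std::vector<float> &in, const std::vector<float> &k, int x, int y)
    {
        float acc = 0.f;
        for(int ky = 0; ky < 3; ++ky)
            for(int kx = 0; kx < 3; ++kx)
                acc += in[(y + ky) * in_w + (x + kx)] * k[ky * 3 + kx];
        return acc;
    };

    // Step 1: initialise the output plane with the contribution of plane 0
    for(int y = 0; y < out_h; ++y)
        for(int x = 0; x < out_w; ++x)
            out[y * out_w + x] = conv_at(in_planes[0], k_planes[0], x, y);

    // Step 2: accumulate the remaining planes directly into the same output plane
    for(size_t p = 1; p < in_planes.size(); ++p)
        for(int y = 0; y < out_h; ++y)
            for(int x = 0; x < out_w; ++x)
                out[y * out_w + x] += conv_at(in_planes[p], k_planes[p], x, y);
}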
- } -}; - -template -class convolver_5x5 -{ -public: - static void convolve(const Window &window, unsigned int num_elems_read_per_iteration, unsigned int num_elems_written_per_iteration, - const ITensor *input, const ITensor *weights, ITensor *output, const PadStrideInfo &conv_info) - { - ARM_COMPUTE_UNUSED(num_elems_read_per_iteration); - const int input_stride_x = input->info()->strides_in_bytes().x(); - const int input_stride_y = input->info()->strides_in_bytes().y(); - const int input_stride_z = input->info()->strides_in_bytes().z(); - const int output_stride_y = output->info()->strides_in_bytes().y(); - const int output_stride_z = output->info()->strides_in_bytes().z(); - const int kernel_stride_x = weights->info()->strides_in_bytes().x(); - const int kernel_stride_y = weights->info()->strides_in_bytes().y(); - const int kernel_stride_z = weights->info()->strides_in_bytes().z(); - const int kernel_stride_w = weights->info()->strides_in_bytes()[3]; - const int output_w = output->info()->dimension(0); - const int output_h = output->info()->dimension(1); - const int num_planes_z = window.z().end() - window.z().start(); - const int delta_input = get_input_num_elems_processed(num_elems_written_per_iteration, stridex); - const int kernel_depth = weights->info()->dimension(Window::DimZ); - const unsigned int conv_stride_y = std::get<1>(conv_info.stride()); - const unsigned int conv_pad_left = conv_info.pad_left(); - const unsigned int conv_pad_top = conv_info.pad_top(); - - // setup output window for the iterator - Window window_out = window; - window_out.set(Window::DimX, Window::Dimension(0, output->info()->dimension(Window::DimX), output->info()->dimension(Window::DimX))); - window_out.set(Window::DimY, Window::Dimension(0, output->info()->dimension(Window::DimY), output->info()->dimension(Window::DimY))); - window_out.set(Window::DimZ, Window::Dimension(window.z().start(), window.z().end(), num_planes_z)); - - // setup input window for the iterator - Window window_in = window; - // we just want execute_window_loop to iterate over the higher dimensions (>3), so we set the first 3 dimensions to 0 - window_in.set(Window::DimX, Window::Dimension(0, 0, 0)); - window_in.set(Window::DimY, Window::Dimension(0, 0, 0)); - window_in.set(Window::DimZ, Window::Dimension(0, 0, 0)); - - Window window_k = calculate_max_window(*weights->info(), Steps(1u)); - - Iterator out(output, window_out); - Iterator in(input, window_in); - Iterator k(weights, window_k); - - const uint8_t *k_ptr = k.ptr(); - - execute_window_loop(window_out, [&](const Coordinates & id) - { - const uint8_t *input_ptr = in.ptr() - conv_pad_left * input_stride_x - conv_pad_top * input_stride_y; - uint8_t *out_ptr = out.ptr(); - int ih = 0; - int oh = 0; - for(int oz = 0; oz < num_planes_z; ++oz) - { - const int zoffset = id.z() + oz; - uint8_t *p_out_base = out_ptr + oz * output_stride_z; - // Step 1 - { - const auto ptr_k_r0 = reinterpret_cast(k_ptr + 0 * kernel_stride_z + zoffset * kernel_stride_w + 0 * kernel_stride_y + 0 * kernel_stride_x); - const auto ptr_k_r1 = reinterpret_cast(k_ptr + 0 * kernel_stride_z + zoffset * kernel_stride_w + 1 * kernel_stride_y + 0 * kernel_stride_x); - const auto ptr_k_r2 = reinterpret_cast(k_ptr + 0 * kernel_stride_z + zoffset * kernel_stride_w + 2 * kernel_stride_y + 0 * kernel_stride_x); - const auto ptr_k_r3 = reinterpret_cast(k_ptr + 0 * kernel_stride_z + zoffset * kernel_stride_w + 3 * kernel_stride_y + 0 * kernel_stride_x); - const auto ptr_k_r4 = reinterpret_cast(k_ptr + 0 * kernel_stride_z + 
zoffset * kernel_stride_w + 4 * kernel_stride_y + 0 * kernel_stride_x); - for(ih = 0, oh = 0; oh < output_h; ++oh, ih += conv_stride_y) - { - auto in_0 = reinterpret_cast(input_ptr + 0 * input_stride_z + (ih + 0) * input_stride_y); - auto in_1 = reinterpret_cast(input_ptr + 0 * input_stride_z + (ih + 1) * input_stride_y); - auto in_2 = reinterpret_cast(input_ptr + 0 * input_stride_z + (ih + 2) * input_stride_y); - auto in_3 = reinterpret_cast(input_ptr + 0 * input_stride_z + (ih + 3) * input_stride_y); - auto in_4 = reinterpret_cast(input_ptr + 0 * input_stride_z + (ih + 4) * input_stride_y); - auto p_out = reinterpret_cast(p_out_base + oh * output_stride_y); - for(int ow = 0; ow < output_w; ow += num_elems_written_per_iteration, - in_0 += delta_input, in_1 += delta_input, in_2 += delta_input, in_3 += delta_input, in_4 += delta_input, p_out += num_elems_written_per_iteration) - { - auto vres = convolve_5x5(in_0, in_1, in_2, in_3, in_4, ptr_k_r0, ptr_k_r1, ptr_k_r2, ptr_k_r3, ptr_k_r4); - store_results(p_out, vres); - } - } - } - // Step 2 - for(int p = 1; p < kernel_depth; ++p) - { - const auto ptr_k_r0 = reinterpret_cast(k_ptr + p * kernel_stride_z + zoffset * kernel_stride_w + 0 * kernel_stride_y + 0 * kernel_stride_x); - const auto ptr_k_r1 = reinterpret_cast(k_ptr + p * kernel_stride_z + zoffset * kernel_stride_w + 1 * kernel_stride_y + 0 * kernel_stride_x); - const auto ptr_k_r2 = reinterpret_cast(k_ptr + p * kernel_stride_z + zoffset * kernel_stride_w + 2 * kernel_stride_y + 0 * kernel_stride_x); - const auto ptr_k_r3 = reinterpret_cast(k_ptr + p * kernel_stride_z + zoffset * kernel_stride_w + 3 * kernel_stride_y + 0 * kernel_stride_x); - const auto ptr_k_r4 = reinterpret_cast(k_ptr + p * kernel_stride_z + zoffset * kernel_stride_w + 4 * kernel_stride_y + 0 * kernel_stride_x); - - for(ih = 0, oh = 0; oh < output_h; ++oh, ih += conv_stride_y) - { - auto in_0 = reinterpret_cast(input_ptr + p * input_stride_z + (ih + 0) * input_stride_y); - auto in_1 = reinterpret_cast(input_ptr + p * input_stride_z + (ih + 1) * input_stride_y); - auto in_2 = reinterpret_cast(input_ptr + p * input_stride_z + (ih + 2) * input_stride_y); - auto in_3 = reinterpret_cast(input_ptr + p * input_stride_z + (ih + 3) * input_stride_y); - auto in_4 = reinterpret_cast(input_ptr + p * input_stride_z + (ih + 4) * input_stride_y); - auto p_out = reinterpret_cast(p_out_base + oh * output_stride_y); - for(int ow = 0; ow < output_w; ow += num_elems_written_per_iteration, - in_0 += delta_input, in_1 += delta_input, in_2 += delta_input, in_3 += delta_input, in_4 += delta_input, p_out += num_elems_written_per_iteration) - { - auto vres = convolve_5x5(in_0, in_1, in_2, in_3, in_4, ptr_k_r0, ptr_k_r1, ptr_k_r2, ptr_k_r3, ptr_k_r4); - accumulate_results(p_out, vres); - } - } - } - } - }, - in, out); - } -}; - -float vreduce(const float32x4_t &v) -{ - auto v0 = wrapper::vgethigh(v); - auto v1 = wrapper::vgetlow(v); - auto v_out = wrapper::vadd(v0, v1); - - float a = wrapper::vgetlane(v_out, 0); - float b = wrapper::vgetlane(v_out, 1); - return a + b; -} - -template -inline void convolve_1x1(const Window &window, unsigned int num_elems_read_per_iteration, unsigned int num_elems_written_per_iteration, - const ITensor *input, const ITensor *weights, ITensor *output, const PadStrideInfo &conv_info) -{ - const unsigned int conv_stride_x = std::get<0>(conv_info.stride()); - switch(conv_stride_x) - { - case 1: - convolver_1x1::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, input, weights, output, 
conv_info); - break; - case 2: - convolver_1x1::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, input, weights, output, conv_info); - break; - case 3: - convolver_1x1::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, input, weights, output, conv_info); - break; - default: - ARM_COMPUTE_ERROR("Not implemented"); - } -} - -template <> -inline void convolve_1x1(const Window &window, unsigned int num_elems_read_per_iteration, unsigned int num_elems_written_per_iteration, - const ITensor *input, const ITensor *weights, ITensor *output, const PadStrideInfo &conv_info) -{ - const unsigned int conv_stride_x = std::get<0>(conv_info.stride()); - if(run_optim_small_tensor(input)) - { - switch(conv_stride_x) - { - case 1: - convolver_w1x1_i8x8_f32<1>::convolve(window, input, weights, output, conv_info); - break; - case 2: - convolver_w1x1_i8x8_f32<2>::convolve(window, input, weights, output, conv_info); - break; - case 3: - convolver_w1x1_i8x8_f32<3>::convolve(window, input, weights, output, conv_info); - break; - default: - ARM_COMPUTE_ERROR("Not implemented"); - } - } - else - { - switch(conv_stride_x) - { - case 1: - convolver_1x1::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, input, weights, output, conv_info); - break; - case 2: - convolver_1x1::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, input, weights, output, conv_info); - break; - case 3: - convolver_1x1::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, input, weights, output, conv_info); - break; - default: - ARM_COMPUTE_ERROR("Not implemented"); - } - } -} - -template -inline void convolve_3x3(const Window &window, unsigned int num_elems_read_per_iteration, unsigned int num_elems_written_per_iteration, - const ITensor *input, const ITensor *weights, ITensor *output, const PadStrideInfo &conv_info) -{ - const unsigned int conv_stride_x = std::get<0>(conv_info.stride()); - switch(conv_stride_x) - { - case 1: - convolver_3x3::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, input, weights, output, conv_info); - break; - case 2: - convolver_3x3::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, input, weights, output, conv_info); - break; - case 3: - convolver_3x3::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, input, weights, output, conv_info); - break; - default: - ARM_COMPUTE_ERROR("Not implemented"); - } -} - -template -inline void convolve_5x5(const Window &window, unsigned int num_elems_read_per_iteration, unsigned int num_elems_written_per_iteration, - const ITensor *input, const ITensor *weights, ITensor *output, const PadStrideInfo &conv_info) -{ - const unsigned int conv_stride_x = std::get<0>(conv_info.stride()); - switch(conv_stride_x) - { - case 1: - convolver_5x5::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, input, weights, output, conv_info); - break; - case 2: - convolver_5x5::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, input, weights, output, conv_info); - break; - case 3: - convolver_5x5::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, input, weights, output, conv_info); - break; - default: - ARM_COMPUTE_ERROR("Not implemented"); - } -} - -Status validate_arguments(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *output, const PadStrideInfo 
&conv_info) -{ - ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, weights, output); - ARM_COMPUTE_RETURN_ERROR_ON(input->data_layout() == DataLayout::UNKNOWN); - ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(input); - ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F16, DataType::F32); - ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, weights); - - const DataLayout data_layout = input->data_layout(); - const int width_idx = get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH); - const int height_idx = get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT); - const int channel_idx = get_data_layout_dimension_index(data_layout, DataLayoutDimension::CHANNEL); - - ARM_COMPUTE_RETURN_ERROR_ON_MSG(std::get<0>(conv_info.stride()) > 3, "Strides larger than 3 not supported."); - ARM_COMPUTE_RETURN_ERROR_ON(weights->dimension(channel_idx) != input->dimension(channel_idx)); - ARM_COMPUTE_RETURN_ERROR_ON(weights->dimension(width_idx) != weights->dimension(height_idx)); - ARM_COMPUTE_RETURN_ERROR_ON(weights->num_dimensions() > 4); - ARM_COMPUTE_RETURN_ERROR_ON(data_layout == DataLayout::NHWC && input->data_type() != DataType::F32); - ARM_COMPUTE_RETURN_ERROR_ON((weights->dimension(width_idx) > 3) && (input->data_type() == DataType::F16)); - - // Checks performed when output is configured - if(output->total_size() != 0) - { - TensorShape output_shape = misc::shape_calculator::compute_deep_convolution_shape(*input, *weights, conv_info); - - DataType data_type = input->data_type(); - - ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DIMENSIONS(output->tensor_shape(), output_shape); - ARM_COMPUTE_RETURN_ERROR_ON(output->data_type() != data_type); - } - - return Status{}; -} - -std::pair validate_and_configure_window(ITensorInfo *input, ITensorInfo *weights, ITensorInfo *output, const PadStrideInfo &conv_info, unsigned int &num_weight_elems_read_per_row, - unsigned int &num_elems_read_per_iteration, unsigned int &num_elems_written_per_iteration, BorderSize &border_size) -{ - ARM_COMPUTE_ERROR_ON(input->data_layout() == DataLayout::UNKNOWN); - - const DataLayout data_layout = input->data_layout(); - const int width_idx = get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH); - - // Calculate right and bottom border - unsigned int kernel_size = weights->dimension(width_idx); - const int conv_stride_x = std::get<0>(conv_info.stride()); - const int conv_stride_y = std::get<1>(conv_info.stride()); - const int input_width = input->dimension(width_idx); - - Window win{}; - bool window_changed = false; - - if(data_layout == DataLayout::NCHW) - { - switch(kernel_size) - { - case 1: - { - switch(input->data_type()) - { -#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC - case DataType::F16: - num_elems_written_per_iteration = 8; - break; -#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */ - case DataType::F32: - if(run_optim_small_tensor_info(input)) - { - num_elems_written_per_iteration = 8; - } - else - { - num_elems_written_per_iteration = 4; - } - break; - default: - ARM_COMPUTE_ERROR("Data type not supported."); - break; - } - num_weight_elems_read_per_row = kernel_size; - num_elems_read_per_iteration = conv_stride_x * num_elems_written_per_iteration; - break; - } - case 3: - switch(input->data_type()) - { - case DataType::F32: - num_weight_elems_read_per_row = 4 + kernel_size - 1; - num_elems_read_per_iteration = 12; - num_elems_written_per_iteration = 16 >> conv_stride_x; - break; -#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC - case 
DataType::F16: - num_weight_elems_read_per_row = 8 + kernel_size - 1; - num_elems_read_per_iteration = 24; - num_elems_written_per_iteration = 32 >> conv_stride_x; - break; -#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */ - default: - ARM_COMPUTE_ERROR("Data type not supported."); - break; - } - break; - case 5: - { - switch(input->data_type()) - { - case DataType::F32: - num_weight_elems_read_per_row = 4 + kernel_size - 1; - num_elems_read_per_iteration = 12; - num_elems_written_per_iteration = 16 >> conv_stride_x; - break; - default: - ARM_COMPUTE_ERROR("Data type not supported."); - break; - } - } - break; - default: - { - ARM_COMPUTE_ERROR("Not implemented"); - break; - } - } - - // Calculate right pad - int start_x = kernel_size / 2 - static_cast(conv_info.pad_left()); - int end_x = ceil_to_multiple(static_cast(output->dimension(0)), num_elems_written_per_iteration) * conv_stride_x; - int upper_bound_w = ceil_to_multiple(start_x + end_x, num_elems_read_per_iteration) - input_width; - - // Calculate border - const unsigned int conv_pad_left = conv_info.pad_left(); - const unsigned int conv_pad_top = conv_info.pad_top(); - const unsigned int conv_pad_right = std::max(upper_bound_w, 0); - const unsigned int conv_pad_bottom = conv_info.pad_bottom(); - - border_size.left = conv_pad_left; - border_size.top = conv_pad_top; - border_size.right = conv_pad_right; - border_size.bottom = conv_pad_bottom; - - // Configure window - win = calculate_max_window(*output, Steps(num_elems_written_per_iteration)); - - AccessWindowRectangle input_access(input, -conv_pad_left, -conv_pad_top, - num_elems_read_per_iteration, kernel_size, - conv_stride_x, conv_stride_y); - AccessWindowStatic weights_access(weights, 0, 0, num_weight_elems_read_per_row, kernel_size); - AccessWindowHorizontal output_access(output, 0, num_elems_written_per_iteration); - window_changed = update_window_and_padding(win, input_access, weights_access, output_access); - output_access.set_valid_region(win, ValidRegion(Coordinates(), output->tensor_shape())); - } - else - { - // Configure window NHWC without any padding - win = calculate_max_window(*output, Steps()); - } - - Status err = (window_changed) ? 
ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Insufficient Padding!") : Status{}; - return std::make_pair(err, win); -} - -bool have_zero_x_internal_padding(ITensorInfo *input, ITensorInfo *weights) -{ - return (input->padding().left == 0 && weights->padding().left == 0 && input->padding().right == 0 && weights->padding().right == 0); -} - -} // namespace - -template -void NEDirectConvolutionLayerKernel::convolve_nhwc_optimized(const Window &window) -{ - // This function assumes that input and weights have not padding in channel - - // Declare useful types - using vtype = wrapper::traits::neon_bitvector; - using vector_type = typename vtype::type; - using tag_type = typename vtype::tag_type; - - // Scalar quantities - const int element_size = _input->info()->element_size(); - const int input_stride_w = _input->info()->strides_in_bytes().y() / element_size; - const int input_stride_h = _input->info()->strides_in_bytes().z() / element_size; - const int input_stride_n = _input->info()->strides_in_bytes()[3] / element_size; - const int input_dim_w = _input->info()->dimension(1); - const int input_dim_h = _input->info()->dimension(2); - - const int output_stride_c = _output->info()->strides_in_bytes().x(); - - const unsigned int kernel_stride_w = _weights->info()->strides_in_bytes().y() / element_size; - const unsigned int kernel_stride_h = _weights->info()->strides_in_bytes().z() / element_size; - const int kernel_dim_w = _weights->info()->dimension(1); - const int kernel_dim_h = _weights->info()->dimension(2); - - const int conv_pad_top = _conv_info.pad_top(); - const int conv_pad_left = _conv_info.pad_left(); - const int conv_stride_w = std::get<0>(_conv_info.stride()); - const int conv_stride_h = std::get<1>(_conv_info.stride()); - - // Setup input window for the output iterator - Window window_out = window; - window_out.set(Window::DimX, Window::Dimension(0, 1, 1)); - - // Setup input window for the weights iterator - Window window_w = calculate_max_window(*_weights->info(), Steps()); - window_w.set(Window::DimX, Window::Dimension(0, 1, 1)); - window_w.set(Window::DimY, Window::Dimension(0, 1, 1)); - window_w.set(Window::DimZ, Window::Dimension(0, 1, 1)); - - Iterator out(_output, window_out); - Iterator wei(_weights, window_w); - - constexpr int num_elems_read_per_iteration = 16 / sizeof(T); - /* - * This implementation parallelize the full WC plane of input and weights by - * treating them as series of elements. So for example, a 3x3 weights and - * floating point vector operations of 4 elements per time, the first 3 - * channel elements of the first row would be taken and additionally the first - * element of the second row. The 9 elements in each single WC weight plane - * would require 2 4-element vector operations and a last single element operation. - * - * This works since when we create the input vector to multiply with the weights, - * the exact required elements are loaded in the same order. Therefore the - * multiplication works on the correct input/weight elements. - */ - execute_window_loop(window_out, [&](const Coordinates & id) - { - /* - * In here we create theoretical indexes which then we validate for both - * inputs and weights. - * As a reminder, this loop take each output point in NHW, C is treated - * in the weights loop. 
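With no padding between W and C, each kernel row reduces to a dot product over one contiguous run of floats: 4-wide multiply-accumulate, a horizontal reduction equivalent to the kernel's vreduce(), then a scalar tail for the leftover elements. A reduced standalone sketch of that inner loop, with illustrative names:

#include <arm_neon.h>
#include <cstddef>

// Dot product of `n` contiguous floats, as used for each flattened WC run when
// input and weights have no padding between W and C.
inline float dot_contiguous(const float *in, const float *w, size_t n)
{
    float32x4_t acc = vdupq_n_f32(0.f);
    size_t      i   = 0;
    for(; i + 4 <= n; i += 4)
    {
        // vector multiply-accumulate, the out_temp_vec accumulation above
        acc = vmlaq_f32(acc, vld1q_f32(w + i), vld1q_f32(in + i));
    }
    // Horizontal reduction of the 4 accumulator lanes (equivalent to vreduce())
    const float32x2_t sum2   = vadd_f32(vget_low_f32(acc), vget_high_f32(acc));
    float             result = vget_lane_f32(sum2, 0) + vget_lane_f32(sum2, 1);
    for(; i < n; ++i)
    {
        result += in[i] * w[i]; // scalar tail, mirrors the index_wc < index_wc_end loop
    }
    return result;
}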
- */ - // We are computing the theoretical starting input starting points - const int in_w_start_t = static_cast(id.y()) * conv_stride_w - conv_pad_left; - const int in_h_start_t = static_cast(id.z()) * conv_stride_h - conv_pad_top; - const int in_w_end_t = in_w_start_t + kernel_dim_w; - const int in_h_end_t = in_h_start_t + kernel_dim_h; - - // We are computing the valid initial and ending input points by checking the borders - const int in_w_start = std::max(in_w_start_t, 0); - const int in_h_start = std::max(in_h_start_t, 0); - const int in_w_end = std::min(in_w_end_t, input_dim_w); - const int in_h_end = std::min(in_h_end_t, input_dim_h); - - // We use the input points to select the valid weight points to use - const int index_wc_start = (in_w_start - in_w_start_t) * kernel_stride_w; - const int index_h_start = in_h_start - in_h_start_t; - const int index_wc_end = (kernel_dim_w - (in_w_end_t - in_w_end)) * kernel_stride_w; - const int index_h_end = kernel_dim_h - (in_h_end_t - in_h_end); - - execute_window_loop(window_w, [&](const Coordinates & id_w) - { - /* - * This is the loop in the weights, and it goes along N (the batches) - * As a reminder, the batches of the weights are translated into the - * channels of the output - */ - const T *in_ptr_row = reinterpret_cast(_input->buffer() + _input->info()->offset_first_element_in_bytes()) - + id[3] * input_stride_n + in_w_start * input_stride_w + in_h_start * input_stride_h; - const T *weights_ptr_row = reinterpret_cast(wei.ptr()) + index_h_start * kernel_stride_h; - uint8_t *out_ptr = out.ptr() + id_w[3] * output_stride_c; - - T out_temp = static_cast(0); - for(int index_h = index_h_start; index_h < index_h_end; ++index_h, in_ptr_row += input_stride_h, weights_ptr_row += kernel_stride_h) - { - const T *in_ptr_mover = in_ptr_row; - int index_wc = index_wc_start; - vector_type out_temp_vec = wrapper::vdup_n(static_cast(0), tag_type()); - for(; index_wc <= index_wc_end - num_elems_read_per_iteration; index_wc += num_elems_read_per_iteration, in_ptr_mover += num_elems_read_per_iteration) - { - const auto src_vec = wrapper::vloadq(in_ptr_mover); - const auto w_vec = wrapper::vloadq(weights_ptr_row + index_wc); - out_temp_vec = wrapper::vmla(out_temp_vec, w_vec, src_vec); - } - out_temp += vreduce(out_temp_vec); - for(; index_wc < index_wc_end; ++index_wc, ++in_ptr_mover) - { - const auto src_val = *(in_ptr_mover); - const auto w_val = *(weights_ptr_row + index_wc); - out_temp += src_val * w_val; - } - } - *(reinterpret_cast(out_ptr)) = out_temp; - }, - wei); - }, - out); -} - -template -void NEDirectConvolutionLayerKernel::convolve_nhwc(const Window &window) -{ - // Declare useful types - using vtype = wrapper::traits::neon_bitvector; - using vector_type = typename vtype::type; - using tag_type = typename vtype::tag_type; - - // Scalar quantities - const int element_size = _input->info()->element_size(); - const int input_stride_w = _input->info()->strides_in_bytes().y() / element_size; - const int input_stride_h = _input->info()->strides_in_bytes().z() / element_size; - const int input_stride_n = _input->info()->strides_in_bytes()[3] / element_size; - const int input_dim_w = _input->info()->dimension(1); - const int input_dim_h = _input->info()->dimension(2); - - const int output_stride_c = _output->info()->strides_in_bytes().x(); - - const unsigned int kernel_stride_w = _weights->info()->strides_in_bytes().y() / element_size; - const unsigned int kernel_stride_h = _weights->info()->strides_in_bytes().z() / element_size; - const int 
kernel_dim_w = _weights->info()->dimension(1); - const int kernel_dim_h = _weights->info()->dimension(2); - - const int conv_pad_top = _conv_info.pad_top(); - const int conv_pad_left = _conv_info.pad_left(); - const int conv_stride_w = std::get<0>(_conv_info.stride()); - const int conv_stride_h = std::get<1>(_conv_info.stride()); - - // Setup input window for the output iterator - Window window_out = window; - window_out.set(Window::DimX, Window::Dimension(0, 1, 1)); - - // Setup input window for the weights iterator - Window window_w = calculate_max_window(*_weights->info(), Steps()); - window_w.set(Window::DimX, Window::Dimension(0, 1, 1)); - window_w.set(Window::DimY, Window::Dimension(0, 1, 1)); - window_w.set(Window::DimZ, Window::Dimension(0, 1, 1)); - - Iterator out(_output, window_out); - Iterator wei(_weights, window_w); - - constexpr int num_elems_read_per_iteration = 16 / sizeof(T); - - execute_window_loop(window_out, [&](const Coordinates & id) - { - // We are computing the theoretical starting input starting points - const int in_w_start_t = static_cast(id.y()) * conv_stride_w - conv_pad_left; - const int in_h_start_t = static_cast(id.z()) * conv_stride_h - conv_pad_top; - const int in_w_end_t = in_w_start_t + kernel_dim_w; - const int in_h_end_t = in_h_start_t + kernel_dim_h; - - // We are computing the valid initial and ending input points by checking the borders - const int in_w_start = std::max(in_w_start_t, 0); - const int in_h_start = std::max(in_h_start_t, 0); - const int in_w_end = std::min(in_w_end_t, input_dim_w); - const int in_h_end = std::min(in_h_end_t, input_dim_h); - - // We use the input points to select the valid weight points to use - const int wei_w_start = in_w_start - in_w_start_t; - const int wei_h_start = in_h_start - in_h_start_t; - const int wei_w_end = kernel_dim_w - (in_w_end_t - in_w_end); - const int wei_h_end = kernel_dim_h - (in_h_end_t - in_h_end); - - const int index_c_end = _weights->info()->dimension(0); - const T *const in_ptr_start = reinterpret_cast(_input->buffer() + _input->info()->offset_first_element_in_bytes()) + id[3] * input_stride_n; - - execute_window_loop(window_w, [&](const Coordinates & id_w) - { - const T *const weights_ptr_start = reinterpret_cast(wei.ptr()); - uint8_t *out_ptr = out.ptr() + id_w[3] * output_stride_c; - - T out_temp = static_cast(0); - for(int index_wei_h = wei_h_start, index_in_h = in_h_start; index_wei_h < wei_h_end; ++index_wei_h, ++index_in_h) - { - const T *const in_ptr_row = in_ptr_start + index_in_h * input_stride_h; - const T *const weights_ptr_row = weights_ptr_start + index_wei_h * kernel_stride_h; - for(int index_wei_w = wei_w_start, index_in_w = in_w_start; index_wei_w < wei_w_end; ++index_wei_w, ++index_in_w) - { - const T *in_ptr_mover = in_ptr_row + index_in_w * input_stride_w; - const T *weights_ptr_mover = weights_ptr_row + index_wei_w * kernel_stride_w; - int index_c = 0; - vector_type out_temp_vec = wrapper::vdup_n(static_cast(0), tag_type()); - for(; index_c <= index_c_end - num_elems_read_per_iteration; index_c += num_elems_read_per_iteration, in_ptr_mover += num_elems_read_per_iteration, weights_ptr_mover += num_elems_read_per_iteration) - { - const auto src_vec = wrapper::vloadq(in_ptr_mover); - const auto w_vec = wrapper::vloadq(weights_ptr_mover); - out_temp_vec = wrapper::vmla(out_temp_vec, w_vec, src_vec); - } - out_temp += vreduce(out_temp_vec); - for(; index_c < index_c_end; ++index_c, ++in_ptr_mover, ++weights_ptr_mover) - { - const auto src_val = *(in_ptr_mover); - const auto 
w_val = *(weights_ptr_mover); - out_temp += src_val * w_val; - } - } - } - *(reinterpret_cast(out_ptr)) = out_temp; - }, - wei); - }, - out); -} - -NEDirectConvolutionLayerKernel::NEDirectConvolutionLayerKernel() - : _input(nullptr), _weights(nullptr), _output(nullptr), _conv_info(), _border_size(0), _kernel_size(0), _num_weight_elems_read_per_row(0), _num_elems_read_per_iteration(0), - _num_elems_written_per_iteration(0), _data_layout() -{ -} - -BorderSize NEDirectConvolutionLayerKernel::border_size() const -{ - return _border_size; -} - -void NEDirectConvolutionLayerKernel::configure(const ITensor *input, const ITensor *weights, ITensor *output, const PadStrideInfo &conv_info) -{ - ARM_COMPUTE_ERROR_ON_NULLPTR(input, weights, output); - - _input = input; - _weights = weights; - _output = output; - _conv_info = conv_info; - _data_layout = _input->info()->data_layout(); - _kernel_size = weights->info()->dimension(get_data_layout_dimension_index(_data_layout, DataLayoutDimension::WIDTH)); - - const unsigned int conv_pad_left = conv_info.pad_left(); - const unsigned int conv_pad_top = conv_info.pad_top(); - const unsigned int conv_pad_right = conv_info.pad_right(); - const unsigned int conv_pad_bottom = conv_info.pad_bottom(); - if(_data_layout == DataLayout::NCHW) - { - _border_size = BorderSize(conv_pad_top, conv_pad_right, conv_pad_bottom, conv_pad_left); - } - else - { - _border_size = BorderSize(0); - } - - // Get convolved dimensions - TensorShape output_shape = misc::shape_calculator::compute_deep_convolution_shape(*input->info(), *weights->info(), conv_info); - - DataType data_type = input->info()->data_type(); - - // Output auto inizialitation if not yet initialized - auto_init_if_empty(*output->info(), output_shape, 1, data_type); - - // Perform validation step - ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input->info(), weights->info(), output->info(), conv_info)); - - // Configure kernel window - auto win_config = validate_and_configure_window(input->info(), weights->info(), output->info(), conv_info, _num_weight_elems_read_per_row, - _num_elems_read_per_iteration, _num_elems_written_per_iteration, _border_size); - ARM_COMPUTE_ERROR_THROW_ON(win_config.first); - INEKernel::configure(win_config.second); -} - -Status NEDirectConvolutionLayerKernel::validate(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *output, const PadStrideInfo &conv_info) -{ - unsigned int num_weight_elems_read_per_row = 0; - unsigned int num_elems_read_per_iteration = 0; - unsigned int num_elems_written_per_iteration = 0; - BorderSize border_size = {}; - ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input, weights, output, conv_info)); - ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window(input->clone().get(), - weights->clone().get(), - output->clone().get(), - conv_info, - num_weight_elems_read_per_row, - num_elems_read_per_iteration, - num_elems_written_per_iteration, - border_size) - .first); - - return Status{}; -} - -void NEDirectConvolutionLayerKernel::run(const Window &window, const ThreadInfo &info) -{ - ARM_COMPUTE_UNUSED(info); - ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this); - ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window); - ARM_COMPUTE_ERROR_ON(_input->buffer() == nullptr); - - const int kernel_size = _weights->info()->dimension(get_data_layout_dimension_index(_data_layout, DataLayoutDimension::WIDTH)); - - if(_data_layout == DataLayout::NCHW) - { - switch(kernel_size) - { - case 1: - { - switch(_input->info()->data_type()) - { - case 
DataType::F32: - convolve_1x1(window, _num_elems_read_per_iteration, _num_elems_written_per_iteration, _input, _weights, _output, _conv_info); - break; -#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC - case DataType::F16: - convolve_1x1(window, _num_elems_read_per_iteration, _num_elems_written_per_iteration, _input, _weights, _output, _conv_info); - break; -#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */ - default: - ARM_COMPUTE_ERROR("Data type not supported"); - break; - } - break; - } - case 3: - { - switch(_input->info()->data_type()) - { - case DataType::F32: - convolve_3x3(window, _num_elems_read_per_iteration, _num_elems_written_per_iteration, _input, _weights, _output, _conv_info); - break; -#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC - case DataType::F16: - convolve_3x3(window, _num_elems_read_per_iteration, _num_elems_written_per_iteration, _input, _weights, _output, _conv_info); - break; -#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */ - default: - ARM_COMPUTE_ERROR("Data type not supported"); - break; - } - break; - } - case 5: - { - switch(_input->info()->data_type()) - { - case DataType::F32: - convolve_5x5(window, _num_elems_read_per_iteration, _num_elems_written_per_iteration, _input, _weights, _output, _conv_info); - break; - default: - ARM_COMPUTE_ERROR("Data type not supported"); - break; - } - break; - } - default: - { - ARM_COMPUTE_ERROR("Only kernel sizes 1x1, 3x3 and 5x5 are supported."); - break; - } - } - } - else - { - switch(_input->info()->data_type()) - { - case DataType::F32: - { - if(have_zero_x_internal_padding(_input->info(), _weights->info())) - { - convolve_nhwc_optimized(window); - } - else - { - convolve_nhwc(window); - } - break; - } - default: - ARM_COMPUTE_ERROR("Data type not supported"); - break; - } - } -} -} // namespace arm_compute diff --git a/src/core/NEON/kernels/NEDirectConvolutionLayerKernel.h b/src/core/NEON/kernels/NEDirectConvolutionLayerKernel.h deleted file mode 100644 index 259eb683f6..0000000000 --- a/src/core/NEON/kernels/NEDirectConvolutionLayerKernel.h +++ /dev/null @@ -1,109 +0,0 @@ -/* - * Copyright (c) 2017-2021 Arm Limited. - * - * SPDX-License-Identifier: MIT - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to - * deal in the Software without restriction, including without limitation the - * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or - * sell copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - */ -#ifndef ARM_COMPUTE_NEDIRECTCONVOLUTIONLAYERKERNEL_H -#define ARM_COMPUTE_NEDIRECTCONVOLUTIONLAYERKERNEL_H - -#include "src/core/NEON/INEKernel.h" - -namespace arm_compute -{ -class ITensor; - -/** Interface for the kernel to perform Direct Convolution Layer. 
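 *
 * The convolution is computed directly over the input, with no im2col step.
 * The NCHW path handles 1x1 and 3x3 kernels in F32/F16 and 5x5 kernels in F32,
 * while the NHWC path handles generic kernel sizes for F32, switching to an
 * optimized variant when neither the input nor the weights carry padding along X.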
*/ -class NEDirectConvolutionLayerKernel : public INEKernel -{ -public: - const char *name() const override - { - return "NEDirectConvolutionLayerKernel"; - } - /** Default constructor */ - NEDirectConvolutionLayerKernel(); - /** Prevent instances of this class from being copied (As this class contains pointers) */ - NEDirectConvolutionLayerKernel(const NEDirectConvolutionLayerKernel &) = delete; - /** Prevent instances of this class from being copied (As this class contains pointers) */ - NEDirectConvolutionLayerKernel &operator=(const NEDirectConvolutionLayerKernel &) = delete; - /** Allow instances of this class to be moved */ - NEDirectConvolutionLayerKernel(NEDirectConvolutionLayerKernel &&) = default; - /** Allow instances of this class to be moved */ - NEDirectConvolutionLayerKernel &operator=(NEDirectConvolutionLayerKernel &&) = default; - /** Default destructor */ - ~NEDirectConvolutionLayerKernel() = default; - /** Set the input, weights, and output tensors. - * - * @note: DirectConvolution only works in the following configurations: - * 1x1 convolution with stride_x = 1/2/3, stride_y = 1/2/3 - * 3x3 convolution with stride_x = 1/2/3, stride_y = 1/2/3 - * - * @param[in] input The input tensor to convolve. 3 lower dimensions represent a single input [width, height, IFM], - * while every optional dimension from 4 and above represent a batch of inputs. Data types supported: F16/F32. - * @param[in] weights Weights tensor. Weights are 4D tensor with dimensions [kernel_x, kernel_y, IFM, OFM]. - * The 3rd dimension must be the same as the input's volume 3rd dimension. - * Data type supported:Same as @p input. - * @param[out] output Output tensor. - * The 3rd dimensions must be equal to the 4th dimension of the @p kernels tensor. Data types supported: F16/F32 - * @param[in] conv_info Contains padding and stride information described in @ref PadStrideInfo. - */ - void configure(const ITensor *input, const ITensor *weights, ITensor *output, const PadStrideInfo &conv_info); - /** Static function to check if given info will lead to a valid configuration of @ref NEDirectConvolutionLayerKernel - * - * @param[in] input The input tensor to convolve. 3 lower dimensions represent a single input [width, height, IFM], - * while every optional dimension from 4 and above represent a batch of inputs. Data types supported: F16/F32. - * @param[in] weights Weights tensor. Weights are 4D tensor with dimensions [kernel_x, kernel_y, IFM, OFM]. - * The 3rd dimension must be the same as the input's volume 3rd dimension. - * Data type supported:Same as @p input. - * @param[in] output Output tensor. - * The 3rd dimensions must be equal to the 4th dimension of the @p kernels tensor. Data types supported: F16/F32 - * @param[in] conv_info Contains padding and stride information described in @ref PadStrideInfo. 
- * - * @return a status - */ - static Status validate(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *output, const PadStrideInfo &conv_info); - - // Inherited methods overridden: - void run(const Window &window, const ThreadInfo &info) override; - BorderSize border_size() const override; - -private: - /* Template function for optimized convolution NHWC */ - template - void convolve_nhwc_optimized(const Window &window); - - /* Template function for convolution NHWC */ - template - void convolve_nhwc(const Window &window); - - const ITensor *_input; - const ITensor *_weights; - ITensor *_output; - PadStrideInfo _conv_info; - BorderSize _border_size; - unsigned int _kernel_size; - unsigned int _num_weight_elems_read_per_row; - unsigned int _num_elems_read_per_iteration; - unsigned int _num_elems_written_per_iteration; - DataLayout _data_layout; -}; -} // namespace arm_compute -#endif /*ARM_COMPUTE_NEDIRECTCONVOLUTIONLAYERKERNEL_H */ diff --git a/src/core/NEON/kernels/NEDirectConvolutionLayerOutputStageKernel.cpp b/src/core/NEON/kernels/NEDirectConvolutionLayerOutputStageKernel.cpp deleted file mode 100644 index 3597045bd5..0000000000 --- a/src/core/NEON/kernels/NEDirectConvolutionLayerOutputStageKernel.cpp +++ /dev/null @@ -1,504 +0,0 @@ -/* - * Copyright (c) 2017-2021 Arm Limited. - * - * SPDX-License-Identifier: MIT - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to - * deal in the Software without restriction, including without limitation the - * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or - * sell copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. 
- */ -#include "src/core/NEON/kernels/NEDirectConvolutionLayerOutputStageKernel.h" - -#include "arm_compute/core/Error.h" -#include "arm_compute/core/Helpers.h" -#include "arm_compute/core/ITensor.h" -#include "arm_compute/core/Types.h" -#include "arm_compute/core/Validate.h" -#include "arm_compute/core/Window.h" -#include "arm_compute/core/utils/misc/Traits.h" -#include "src/core/AccessWindowStatic.h" -#include "src/core/CPP/Validate.h" -#include "src/core/NEON/NEAsymm.h" -#include "src/core/NEON/NEFixedPoint.h" -#include "src/core/NEON/wrapper/wrapper.h" -#include "src/core/helpers/AutoConfiguration.h" -#include "src/core/helpers/WindowHelpers.h" - -#include -#include -#include - -namespace arm_compute -{ -namespace -{ -Status validate_arguments(const ITensorInfo *input, const ITensorInfo *bias, const ITensorInfo *output, - const DirectConvolutionLayerOutputStageKernelInfo &info) -{ - ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input); - ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(input); - ARM_COMPUTE_RETURN_ERROR_ON(input->data_layout() == DataLayout::UNKNOWN); - ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F16, DataType::S32, DataType::F32); - - if(bias != nullptr) - { - ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, bias); - ARM_COMPUTE_RETURN_ERROR_ON(bias->dimension(0) != input->dimension(get_data_layout_dimension_index(input->data_layout(), DataLayoutDimension::CHANNEL))); - ARM_COMPUTE_RETURN_ERROR_ON(bias->num_dimensions() > 1); - } - - if(input->data_type() == DataType::S32) - { - ARM_COMPUTE_RETURN_ERROR_ON_MSG(output == nullptr, "In-place computation not allowed for quantized output"); - } - - // Checks performed when output is configured - if((output != nullptr) && (output->total_size() != 0)) - { - if(is_data_type_float(input->data_type())) - { - ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output); - } - else - { - ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::QASYMM8, DataType::QASYMM8_SIGNED); - } - ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input, output); - } - else if(input->data_type() == DataType::S32) - { - // In case of quantized computation and unconfigured output, the output data type must be provided through DirectConvolutionLayerOutputStageKernelInfo - ARM_COMPUTE_RETURN_ERROR_ON((info.output_data_type != DataType::QASYMM8) && (info.output_data_type != DataType::QASYMM8_SIGNED)); - } - - return Status{}; -} - -template -typename std::enable_if::value, void>::type -output_stage_nchw(ITensor *input, const ITensor *bias, const Window &window, ITensor *output, - int result_fixedpoint_multiplier, int result_shift, int result_offset_after_shift, bool has_bias) -{ - /** SIMD vector tag type. 
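 * (dispatch tag: it selects the wrapper::vdup_n overload that produces a full
 * 128-bit vector of the element type T)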
*/ - using ExactTagType = typename wrapper::traits::neon_bitvector_tag_t; - - ARM_COMPUTE_ERROR_ON(input->info()->data_layout() == DataLayout::UNKNOWN); - ARM_COMPUTE_UNUSED(result_fixedpoint_multiplier); - ARM_COMPUTE_UNUSED(result_shift); - ARM_COMPUTE_UNUSED(result_offset_after_shift); - - const int window_start_x = window.x().start(); - const int window_end_x = window.x().end(); - const int window_step_x = 16 / input->info()->element_size(); - Window win = window; - win.set(Window::DimX, Window::Dimension(0, 1, 1)); - - Iterator in(input, win); - Iterator out(output, win); - execute_window_loop(win, [&](const Coordinates & id) - { - int x = window_start_x; - for(; x <= (window_end_x - window_step_x); x += window_step_x) - { - // Get bias and pointer to input - const auto in_ptr = reinterpret_cast(in.ptr()) + x; - auto v_in = wrapper::vloadq(in_ptr); - - // Accumulate bias - if(has_bias) - { - const auto vb = wrapper::vdup_n(*reinterpret_cast(bias->ptr_to_element(Coordinates(id.z()))), ExactTagType{}); - v_in = wrapper::vadd(v_in, vb); - } - - const auto out_ptr = reinterpret_cast(out.ptr()) + x; - wrapper::vstore(out_ptr, v_in); - } - - // Left-overs loop - for(; x < window_end_x; ++x) - { - // Get bias and pointer to input - auto s_in = *(reinterpret_cast(in.ptr()) + x); - - // Accumulate bias - if(has_bias) - { - const auto b = *reinterpret_cast(bias->ptr_to_element(Coordinates(id.z()))); - s_in += b; - } - - *(reinterpret_cast(out.ptr()) + x) = s_in; - } - - }, - in, out); -} - -template -typename std::enable_if::value, void>::type -output_stage_nhwc(ITensor *input, const ITensor *bias, const Window &window, ITensor *output, - int result_fixedpoint_multiplier, int result_shift, int result_offset_after_shift, bool has_bias) -{ - ARM_COMPUTE_UNUSED(result_fixedpoint_multiplier); - ARM_COMPUTE_UNUSED(result_shift); - ARM_COMPUTE_UNUSED(result_offset_after_shift); - - Window window_bias = window; - window_bias.set(Window::DimX, Window::Dimension(0, 1, 1)); - window_bias.set(Window::DimY, Window::Dimension(0, 0, 0)); - window_bias.set(Window::DimZ, Window::Dimension(0, 0, 0)); - window_bias.set(3, Window::Dimension(0, 0, 0)); - - const int window_start_x = window.x().start(); - const int window_end_x = window.x().end(); - const int window_step_x = 16 / input->info()->element_size(); - Window win = window; - win.set(Window::DimX, Window::Dimension(0, 1, 1)); - - Iterator in(input, win); - Iterator bi(bias, window_bias); - Iterator out(output, win); - - execute_window_loop(win, [&](const Coordinates &) - { - int x = window_start_x; - for(; x <= (window_end_x - window_step_x); x += window_step_x) - { - // Get bias and pointer to input - const auto in_ptr = reinterpret_cast(in.ptr()); - auto v_in = wrapper::vloadq(in_ptr + x); - - // Accumulate bias - if(has_bias) - { - const auto bias_ptr = reinterpret_cast(bi.ptr()) + x; - v_in = wrapper::vadd(v_in, wrapper::vloadq(bias_ptr)); - } - - const auto out_ptr = reinterpret_cast(out.ptr()); - wrapper::vstore(out_ptr + x, v_in); - } - - // Left-overs loop - for(; x < window_end_x; ++x) - { - // Get bias and pointer to input - auto s_in = *(reinterpret_cast(in.ptr()) + x); - - // Accumulate bias - if(has_bias) - { - const auto bias_ptr = reinterpret_cast(bi.ptr()) + x; - s_in += *bias_ptr; - } - - const auto out_ptr = reinterpret_cast(out.ptr()); - *(out_ptr + x) = s_in; - } - }, - in, bi, out); -} - -// Quantized case -template < typename TOut, typename std::enable_if < std::is_same::value || std::is_same::value, int >::type = 0 > -void 
output_stage_nchw(ITensor *input, const ITensor *bias, const Window &window, ITensor *output, - int result_fixedpoint_multiplier, int result_shift, int result_offset_after_shift, bool has_bias) -{ - using VectorType = typename wrapper::traits::neon_bitvector_t; - using TagType = typename wrapper::traits::neon_bitvector_tag_t; - - const int32x4_t result_offset_after_shift_s32 = vdupq_n_s32(result_offset_after_shift); - - const VectorType min = wrapper::vdup_n(std::numeric_limits::lowest(), TagType{}); - const VectorType max = wrapper::vdup_n(std::numeric_limits::max(), TagType{}); - - const int window_start_x = window.x().start(); - const int window_end_x = window.x().end(); - const int window_step_x = 16 / input->info()->element_size(); - Window win = window; - win.set(Window::DimX, Window::Dimension(0, 1, 1)); - - Iterator in(input, win); - Iterator out(output, win); - - execute_window_loop(win, [&](const Coordinates & id) - { - - int x = window_start_x; - for(; x <= (window_end_x - window_step_x); x += window_step_x) - { - // Get bias and pointer to input - const auto in_ptr = reinterpret_cast(in.ptr()) + x; - int32x4x4_t v_in = - { - { - wrapper::vloadq(in_ptr), - wrapper::vloadq(in_ptr + 4), - wrapper::vloadq(in_ptr + 8), - wrapper::vloadq(in_ptr + 12) - } - }; - - // Accumulate bias - if(has_bias) - { - const auto vb = wrapper::vdup_n(*reinterpret_cast(bias->ptr_to_element(Coordinates(id.z()))), TagType{}); - v_in = - { - { - wrapper::vadd(v_in.val[0], vb), - wrapper::vadd(v_in.val[1], vb), - wrapper::vadd(v_in.val[2], vb), - wrapper::vadd(v_in.val[3], vb) - } - }; - } - - const auto out_ptr = reinterpret_cast(out.ptr()) + x; - wrapper::vstore(out_ptr, finalize_quantization(v_in, result_fixedpoint_multiplier, result_shift, result_offset_after_shift_s32, - min, max, false)); - } - - // Left-overs loop - for(; x < window_end_x; ++x) - { - // Get bias and pointer to input - int32_t s_in = *(reinterpret_cast(in.ptr()) + x); - - // Accumulate bias - if(has_bias) - { - const auto b = *reinterpret_cast(bias->ptr_to_element(Coordinates(id.z()))); - s_in += b; - } - - const auto out_ptr = reinterpret_cast(out.ptr()) + x; - *out_ptr = finalize_quantization(s_in, result_fixedpoint_multiplier, result_shift, result_offset_after_shift, - std::numeric_limits::lowest(), std::numeric_limits::max(), false); - } - }, - in, out); -} -template < typename TOut, typename std::enable_if < std::is_same::value || std::is_same::value, int >::type = 0 > -void output_stage_nhwc(ITensor *input, const ITensor *bias, const Window &window, ITensor *output, - int result_fixedpoint_multiplier, int result_shift, int result_offset_after_shift, bool has_bias) -{ - using VectorType = typename wrapper::traits::neon_bitvector_t; - using TagType = typename wrapper::traits::neon_bitvector_tag_t; - - const int32x4_t result_offset_after_shift_s32 = vdupq_n_s32(result_offset_after_shift); - - const VectorType min = wrapper::vdup_n(std::numeric_limits::lowest(), TagType{}); - const VectorType max = wrapper::vdup_n(std::numeric_limits::max(), TagType{}); - - Window window_bias = window; - window_bias.set(Window::DimX, Window::Dimension(0, 1, 1)); - window_bias.set(Window::DimY, Window::Dimension(0, 0, 0)); - window_bias.set(Window::DimZ, Window::Dimension(0, 0, 0)); - window_bias.set(3, Window::Dimension(0, 0, 0)); - - const int window_start_x = window.x().start(); - const int window_end_x = window.x().end(); - const int window_step_x = 16 / input->info()->element_size(); - Window win = window; - win.set(Window::DimX, 
Window::Dimension(0, 1, 1)); - - Iterator in(input, win); - Iterator bi(bias, window_bias); - Iterator out(output, win); - - execute_window_loop(win, [&](const Coordinates &) - { - int x = window_start_x; - for(; x <= (window_end_x - window_step_x); x += window_step_x) - { - // Get bias and pointer to input - const auto in_ptr = reinterpret_cast(in.ptr()) + x; - int32x4x4_t v_in = - { - { - wrapper::vloadq(in_ptr), - wrapper::vloadq(in_ptr + 4), - wrapper::vloadq(in_ptr + 8), - wrapper::vloadq(in_ptr + 12), - } - }; - - // Accumulate bias - if(has_bias) - { - const auto bias_ptr = reinterpret_cast(bi.ptr()) + x; - - wrapper::vadd(v_in.val[0], wrapper::vloadq(bias_ptr)); - wrapper::vadd(v_in.val[1], wrapper::vloadq(bias_ptr + 4)); - wrapper::vadd(v_in.val[2], wrapper::vloadq(bias_ptr + 8)); - wrapper::vadd(v_in.val[3], wrapper::vloadq(bias_ptr + 12)); - } - - const auto out_ptr = reinterpret_cast(out.ptr()) + x; - wrapper::vstore(out_ptr, finalize_quantization(v_in, result_fixedpoint_multiplier, result_shift, result_offset_after_shift_s32, min, max, false)); - } - - // Left-overs loop - for(; x < window_end_x; ++x) - { - // Get bias and pointer to input - const auto in_ptr = reinterpret_cast(in.ptr()) + x; - int32_t s_in = *in_ptr; - - // Accumulate bias - if(has_bias) - { - const auto bias_ptr = reinterpret_cast(bi.ptr()) + x; - s_in += *bias_ptr; - } - - const auto out_ptr = reinterpret_cast(out.ptr()) + x; - *out_ptr = finalize_quantization(s_in, result_fixedpoint_multiplier, result_shift, result_offset_after_shift, - std::numeric_limits::lowest(), std::numeric_limits::max(), false); - } - }, - in, bi, out); -} -} // namespace - -NEDirectConvolutionLayerOutputStageKernel::NEDirectConvolutionLayerOutputStageKernel() - : _func(nullptr), _input(nullptr), _bias(nullptr), _output(nullptr), _result_fixedpoint_multiplier(0), _result_shift(0), _result_offset_after_shift(0) -{ -} - -void NEDirectConvolutionLayerOutputStageKernel::configure(ITensor *input, const ITensor *bias, ITensor *output, - const DirectConvolutionLayerOutputStageKernelInfo &info) -{ - // Perform validation step - ARM_COMPUTE_ERROR_ON_NULLPTR(input); - ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input->info(), (bias == nullptr) ? nullptr : bias->info(), (output == nullptr) ? nullptr : output->info(), info)); - - _func = nullptr; - _bias = bias; - _input = input; - _output = (output != nullptr) ? output : input; - _result_fixedpoint_multiplier = info.result_fixedpoint_multiplier; - _result_shift = info.result_shift; - _result_offset_after_shift = info.result_offset_after_shift; - - // Auto-initialize output output if required - if(output != nullptr && output->info() != nullptr) - { - // Work out expected output data type - const DataType output_dt = (input->info()->data_type() == DataType::S32) ? info.output_data_type : DataType::S32; - // Output tensor auto initialization if not yet initialized - auto_init_if_empty(*output->info(), input->info()->clone()->set_data_type(output_dt)); - } - - Window win = calculate_max_window(*input->info(), Steps()); - - INEKernel::configure(win); - - const bool is_qasymm8_signed = (output != nullptr) ? 
is_data_type_quantized_asymmetric_signed(output->info()->data_type()) : false; - - // Set appropriate function - if(input->info()->data_layout() == DataLayout::NCHW) - { - switch(input->info()->data_type()) - { - case DataType::S32: - { - if(is_qasymm8_signed) - { - _func = &output_stage_nchw; - } - else - { - _func = &output_stage_nchw; - } - break; - } -#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC - case DataType::F16: - { - _func = &output_stage_nchw; - break; - } -#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */ - case DataType::F32: - { - _func = &output_stage_nchw; - break; - } - default: - { - ARM_COMPUTE_ERROR("Unsupported combination of types among the inputs."); - } - } - } - else - { - switch(input->info()->data_type()) - { - case DataType::S32: - { - if(is_qasymm8_signed) - { - _func = &output_stage_nhwc; - } - else - { - _func = &output_stage_nhwc; - } - break; - } -#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC - case DataType::F16: - { - _func = &output_stage_nhwc; - break; - } -#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */ - case DataType::F32: - { - _func = &output_stage_nhwc; - break; - } - default: - { - ARM_COMPUTE_ERROR("Unsupported combination of types among the inputs."); - } - } - } -} - -Status NEDirectConvolutionLayerOutputStageKernel::validate(const ITensorInfo *input, const ITensorInfo *bias, const ITensorInfo *output, - const DirectConvolutionLayerOutputStageKernelInfo &info) -{ - ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input, bias, output, info)); - - return Status{}; -} - -void NEDirectConvolutionLayerOutputStageKernel::run(const Window &window, const ThreadInfo &info) -{ - ARM_COMPUTE_UNUSED(info); - ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this); - ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window); - ARM_COMPUTE_ERROR_ON(_func == nullptr); - - const bool has_bias = _bias != nullptr; - (*_func)(_input, _bias, window, _output, _result_fixedpoint_multiplier, _result_shift, _result_offset_after_shift, has_bias); -} -} // namespace arm_compute diff --git a/src/core/NEON/kernels/NEDirectConvolutionLayerOutputStageKernel.h b/src/core/NEON/kernels/NEDirectConvolutionLayerOutputStageKernel.h deleted file mode 100644 index 8f7eeb05b2..0000000000 --- a/src/core/NEON/kernels/NEDirectConvolutionLayerOutputStageKernel.h +++ /dev/null @@ -1,102 +0,0 @@ -/* - * Copyright (c) 2017-2021 Arm Limited. - * - * SPDX-License-Identifier: MIT - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to - * deal in the Software without restriction, including without limitation the - * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or - * sell copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. 
- */ -#ifndef ARM_COMPUTE_NEDIRECTCONVOLUTIONLAYEROUTPUTSTAGEKERNEL_H -#define ARM_COMPUTE_NEDIRECTCONVOLUTIONLAYEROUTPUTSTAGEKERNEL_H - -#include "arm_compute/core/KernelDescriptors.h" -#include "src/core/NEON/INEKernel.h" - -namespace arm_compute -{ -class ITensor; -/** Kernel to accumulate the biases, if provided, or downscale in case of quantized input. - * - * @note We assume bias to be shared - * @note For quantized computations (i.e. @p input of S32 type) the output data type for auto-initialization must be passed as part - * of the @ref DirectConvolutionLayerOutputStageKernelInfo. - */ -class NEDirectConvolutionLayerOutputStageKernel : public INEKernel -{ -public: - const char *name() const override - { - return "NEDirectConvolutionLayerOutputStageKernel"; - } - /** Default constructor */ - NEDirectConvolutionLayerOutputStageKernel(); - /** Prevent instances of this class from being copied (As this class contains pointers) */ - NEDirectConvolutionLayerOutputStageKernel(const NEDirectConvolutionLayerOutputStageKernel &) = delete; - /** Prevent instances of this class from being copied (As this class contains pointers) */ - NEDirectConvolutionLayerOutputStageKernel &operator=(const NEDirectConvolutionLayerOutputStageKernel &) = delete; - /** Allow instances of this class to be moved */ - NEDirectConvolutionLayerOutputStageKernel(NEDirectConvolutionLayerOutputStageKernel &&) = default; - /** Allow instances of this class to be moved */ - NEDirectConvolutionLayerOutputStageKernel &operator=(NEDirectConvolutionLayerOutputStageKernel &&) = default; - /** Default destructor */ - ~NEDirectConvolutionLayerOutputStageKernel() = default; - /** Set the accumulate buffer and the biases of the kernel. - * - * @param[in, out] input Input to add the bias to. If @p output is not specified then accumulation is done in-place. - * Data type supported: F16/F32/S32 - * @param[in] bias (Optional) The shared bias tensor to add. It must be 1D Tensor. Data type supported: Same as @p input - * @param[out] output (Optional) If the output tensor is specified the accumulation is done out-of-place. (Defaults to nullptr) - * Note that in-place computation is only supported for F16/F32. For S32 this must not be nullptr. - * Data type supported: F16/F32 or QASYMM8/QASYMM8_SIGNED if @p input is S32 - * @param[in] info (Optional) DirectConvolutionLayerOutputStageKernel descriptor metadata - */ - void configure(ITensor *input, const ITensor *bias = nullptr, ITensor *output = nullptr, - const DirectConvolutionLayerOutputStageKernelInfo &info = DirectConvolutionLayerOutputStageKernelInfo()); - /** Static function to check if given info will lead to a valid configuration of @ref NEDirectConvolutionLayerOutputStageKernel - * - * @param[in] input Input to add the bias to. If @p output is not specified then accumulation is done in-place. - * Data type supported: F16/F32/S32 - * @param[in] bias (Optional) The shared bias tensor to add. It must be 1D Tensor. Data type supported: Same as @p input - * @param[in] output (Optional) If the output tensor is specified the accumulation is done out-of-place. (Defaults to nullptr) - * Note that in-place computation is only supported for F16/F32. For S32 this must not be nullptr. 
- * Data type supported: F16/F32 or QASYMM8/QASYMM8_SIGNED if @p input is S32 - * @param[in] info (Optional) DirectConvolutionLayerOutputStageKernel descriptor metadata - * - * @return a status - */ - static Status validate(const ITensorInfo *input, const ITensorInfo *bias = nullptr, const ITensorInfo *output = nullptr, - const DirectConvolutionLayerOutputStageKernelInfo &info = DirectConvolutionLayerOutputStageKernelInfo()); - - // Inherited methods overridden: - void run(const Window &window, const ThreadInfo &info) override; - -private: - using OutputStageKernel = void(ITensor *input, const ITensor *bias, const Window &window, ITensor *output, - int result_fixedpoint_multiplier, int result_shift, int result_offset_after_shift, bool has_bias); - -private: - OutputStageKernel *_func; - ITensor *_input; - const ITensor *_bias; - ITensor *_output; - int _result_fixedpoint_multiplier; - int _result_shift; - int _result_offset_after_shift; -}; -} // namespace arm_compute -#endif /*ARM_COMPUTE_NEDIRECTCONVOLUTIONLAYEROUTPUTSTAGEKERNEL_H */ diff --git a/src/core/cpu/kernels/CpuDirectConvolutionKernel.cpp b/src/core/cpu/kernels/CpuDirectConvolutionKernel.cpp new file mode 100644 index 0000000000..4f46eb2bf6 --- /dev/null +++ b/src/core/cpu/kernels/CpuDirectConvolutionKernel.cpp @@ -0,0 +1,1385 @@ +/* + * Copyright (c) 2017-2021 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ */ +#include "src/core/cpu/kernels/CpuDirectConvolutionKernel.h" + +#include "src/core/NEON/kernels/detail/NEDirectConvolutionDetail.h" +#include "src/core/NEON/wrapper/wrapper.h" + +#include "arm_compute/core/Error.h" +#include "arm_compute/core/Helpers.h" +#include "arm_compute/core/IAccessWindow.h" +#include "arm_compute/core/ITensor.h" +#include "arm_compute/core/Types.h" +#include "arm_compute/core/Utils.h" +#include "arm_compute/core/Validate.h" +#include "arm_compute/core/utils/misc/ShapeCalculator.h" +#include "src/core/AccessWindowStatic.h" +#include "src/core/CPP/Validate.h" +#include "src/core/NEON/NEFixedPoint.h" +#include "src/core/helpers/AutoConfiguration.h" +#include "src/core/helpers/WindowHelpers.h" + +#include + +using namespace arm_compute::detail; + +namespace arm_compute +{ +namespace cpu +{ +namespace kernels +{ +namespace +{ +#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +template +float16x8_t internal_vld1q(const float16_t *in); + +template <> +float16x8_t internal_vld1q<1>(const float16_t *in) +{ + return vld1q_f16(in); +} + +template <> +float16x8_t internal_vld1q<2>(const float16_t *in) +{ + const float16x8x2_t tmp = vld2q_f16(in); + return tmp.val[0]; +} + +template <> +float16x8_t internal_vld1q<3>(const float16_t *in) +{ + const float16x8x3_t tmp = vld3q_f16(in); + return tmp.val[0]; +} + +inline float16x8_t internal_vdupq_n(float16_t v) +{ + return vdupq_n_f16(v); +} + +inline void internal_vst1q(float16_t *p, const float16x8_t &v) +{ + vst1q_f16(p, v); +} + +float16x8_t internal_vmull(const float16x8_t &x, const float16x8_t &y) +{ + return vmulq_f16(x, y); +} + +inline float16x8_t internal_vmlal(const float16x8_t &x, const float16x8_t &y, const float16x8_t &z) +{ + return vaddq_f16(x, vmulq_f16(y, z)); +} +#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */ + +template +float32x4_t internal_vld1q(const float *in); + +template <> +float32x4_t internal_vld1q<1>(const float *in) +{ + return vld1q_f32(in); +} + +template <> +float32x4_t internal_vld1q<2>(const float *in) +{ + const float32x4x2_t tmp = vld2q_f32(in); + return tmp.val[0]; +} + +template <> +float32x4_t internal_vld1q<3>(const float *in) +{ + const float32x4x3_t tmp = vld3q_f32(in); + return tmp.val[0]; +} + +inline float32x4_t internal_vdupq_n(float v) +{ + return vdupq_n_f32(v); +} + +inline void internal_vst1q(float *p, const float32x4_t &v) +{ + vst1q_f32(p, v); +} + +float32x4_t internal_vmull(const float32x4_t &x, const float32x4_t &y) +{ + return vmulq_f32(x, y); +} + +inline float32x4_t internal_vmlal(const float32x4_t &x, const float32x4_t &y, const float32x4_t &z) +{ + return vmlaq_f32(x, y, z); +} + +constexpr int small_tensor_size_optim = 8; +inline bool run_optim_small_tensor_info(const ITensorInfo *t) +{ + return t->dimension(Window::DimX) <= small_tensor_size_optim && t->dimension(Window::DimY) <= small_tensor_size_optim; +} + +inline bool run_optim_small_tensor(const ITensor *t) +{ + return run_optim_small_tensor_info(t->info()); +} + +// Optimized convolver for 1x1 kernels used only where input width and height are both <= 8 +// For big Z as in Input=7x7x832, this implementation is faster than the general code becuase it doesn't need to +// store intermidiate results in memory. Temporary results are stored in SIMD registers directly and then written to the output buffer. 
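// A minimal sketch of the register-accumulation idea described above, reduced to
// a 1x1 kernel producing one output channel over an 8x8 plane with unit stride,
// no padding and contiguous float buffers (illustrative only; the class below
// additionally handles strides, padding and the window decomposition):

#include <arm_neon.h>

void conv1x1_8x8_sketch(const float *src, const float *weights, float *dst, int kernel_depth)
{
    // One 8x8 output plane held entirely in registers: two float32x4 per row
    float32x4_t acc_lo[8], acc_hi[8];
    for(int r = 0; r < 8; ++r)
    {
        acc_lo[r] = vdupq_n_f32(0.f);
        acc_hi[r] = vdupq_n_f32(0.f);
    }
    for(int c = 0; c < kernel_depth; ++c)
    {
        const float32x4_t w     = vdupq_n_f32(weights[c]); // the single 1x1 weight of channel c
        const float      *plane = src + c * 8 * 8;
        for(int r = 0; r < 8; ++r)
        {
            acc_lo[r] = vmlaq_f32(acc_lo[r], w, vld1q_f32(plane + r * 8));
            acc_hi[r] = vmlaq_f32(acc_hi[r], w, vld1q_f32(plane + r * 8 + 4));
        }
    }
    // Single store pass once every input channel has been accumulated
    for(int r = 0; r < 8; ++r)
    {
        vst1q_f32(dst + r * 8,     acc_lo[r]);
        vst1q_f32(dst + r * 8 + 4, acc_hi[r]);
    }
}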
+template +class convolver_w1x1_i8x8_f32 +{ +public: + static void convolve(const Window &window, const ITensor *src, const ITensor *weights, ITensor *dst, const PadStrideInfo &conv_info) + { + ARM_COMPUTE_ERROR_ON(src->info()->dimension(Window::DimX) > small_tensor_size_optim); + ARM_COMPUTE_ERROR_ON(src->info()->dimension(Window::DimY) > small_tensor_size_optim); + + const int input_stride_x = src->info()->strides_in_bytes().x(); + const int input_stride_y = src->info()->strides_in_bytes().y(); + const int input_stride_z = src->info()->strides_in_bytes().z(); + const int output_stride_y = dst->info()->strides_in_bytes().y(); + const int output_stride_z = dst->info()->strides_in_bytes().z(); + const int kernel_stride_z = weights->info()->strides_in_bytes().z(); + const int kernel_stride_w = weights->info()->strides_in_bytes()[3]; + const int output_h = dst->info()->dimension(1); + const int range_z = window.z().end() - window.z().start(); + const int kernel_depth = weights->info()->dimension(Window::DimZ); + const unsigned int conv_stride_y = std::get<1>(conv_info.stride()); + const unsigned int conv_pad_left = conv_info.pad_left(); + const unsigned int conv_pad_top = conv_info.pad_top(); + + // setup output window for the iterator + Window window_out = window; + window_out.set(Window::DimX, Window::Dimension(0, dst->info()->dimension(Window::DimX), dst->info()->dimension(Window::DimX))); + window_out.set(Window::DimY, Window::Dimension(0, dst->info()->dimension(Window::DimY), dst->info()->dimension(Window::DimY))); + window_out.set(Window::DimZ, Window::Dimension(window.z().start(), window.z().end(), range_z)); + + // setup input window for the iterator + Window window_in = window; + // we just want execute_window_loop to iterate over the higher dimensions (>3), so we set the first 3 dimensions to 0 + window_in.set(Window::DimX, Window::Dimension(0, 0, 0)); + window_in.set(Window::DimY, Window::Dimension(0, 0, 0)); + window_in.set(Window::DimZ, Window::Dimension(0, 0, 0)); + + Window window_k = calculate_max_window(*weights->info(), Steps(1u)); + Iterator out(dst, window_out); + Iterator in(src, window_in); + Iterator k(weights, window_k); + + const uint8_t *k_ptr = k.ptr(); + + execute_window_loop(window_out, [&](const Coordinates & id) + { + const uint8_t *input_ptr = in.ptr() - conv_pad_left * input_stride_x - conv_pad_top * input_stride_y; + uint8_t *out_ptr = out.ptr(); + int ih = 0; + int oh = 0; + std::array accum0 = { vdupq_n_f32(0), vdupq_n_f32(0), vdupq_n_f32(0), vdupq_n_f32(0), vdupq_n_f32(0), vdupq_n_f32(0), vdupq_n_f32(0), vdupq_n_f32(0) }; + std::array accum1 = { vdupq_n_f32(0), vdupq_n_f32(0), vdupq_n_f32(0), vdupq_n_f32(0), vdupq_n_f32(0), vdupq_n_f32(0), vdupq_n_f32(0), vdupq_n_f32(0) }; + for(int oz = 0; oz < range_z; ++oz) + { + accum0[0] = accum0[1] = accum0[2] = accum0[3] = accum0[4] = accum0[5] = accum0[6] = accum0[7] = vdupq_n_f32(0.f); + accum1[0] = accum1[1] = accum1[2] = accum1[3] = accum1[4] = accum1[5] = accum1[6] = accum1[7] = vdupq_n_f32(0.f); + auto p_out_base = out_ptr + oz * output_stride_z; + for(int p = 0; p < kernel_depth; ++p) + { + const auto k_val = reinterpret_cast(k_ptr + p * kernel_stride_z + (id.z() + oz) * kernel_stride_w); + const auto vk0 = internal_vdupq_n(*k_val); + for(ih = 0, oh = 0; oh < output_h; ++oh, ih += conv_stride_y) + { + const int offset_xy = ih * input_stride_y; + auto in_val = reinterpret_cast(input_ptr + p * input_stride_z + offset_xy); + auto v_in0 = internal_vld1q(in_val); + auto v_in1 = internal_vld1q(in_val + 4); + 
accum0[oh] = vmlaq_f32(accum0[oh], vk0, v_in0); + accum1[oh] = vmlaq_f32(accum1[oh], vk0, v_in1); + } + } + for(oh = 0; oh < output_h; ++oh) + { + auto p_out = reinterpret_cast(p_out_base + oh * output_stride_y); + vst1q_f32(p_out, accum0[oh]); + vst1q_f32(p_out + 4, accum1[oh]); + } + } + }, + in, out); + } +}; + +template +class convolver_1x1 +{ +public: + static void convolve(const Window &window, unsigned int num_elems_read_per_iteration, unsigned int num_elems_written_per_iteration, + const ITensor *src, const ITensor *weights, ITensor *dst, const PadStrideInfo &conv_info) + { + const int input_stride_x = src->info()->strides_in_bytes().x(); + const int input_stride_y = src->info()->strides_in_bytes().y(); + const int input_stride_z = src->info()->strides_in_bytes().z(); + const int output_stride_y = dst->info()->strides_in_bytes().y(); + const int output_stride_z = dst->info()->strides_in_bytes().z(); + const int kernel_stride_z = weights->info()->strides_in_bytes().z(); + const int kernel_stride_w = weights->info()->strides_in_bytes()[3]; + const int output_w = dst->info()->dimension(0); + const int output_h = dst->info()->dimension(1); + const int range_z = window.z().end() - window.z().start(); + const int kernel_depth = weights->info()->dimension(Window::DimZ); + const unsigned int conv_stride_y = std::get<1>(conv_info.stride()); + const unsigned int conv_pad_left = conv_info.pad_left(); + const unsigned int conv_pad_top = conv_info.pad_top(); + + // setup output window for the iterator + Window window_out = window; + window_out.set(Window::DimX, Window::Dimension(0, dst->info()->dimension(Window::DimX), dst->info()->dimension(Window::DimX))); + window_out.set(Window::DimY, Window::Dimension(0, dst->info()->dimension(Window::DimY), dst->info()->dimension(Window::DimY))); + window_out.set(Window::DimZ, Window::Dimension(window.z().start(), window.z().end(), range_z)); + + // setup input window for the iterator + Window window_in = window; + // we just want execute_window_loop to iterate over the higher dimensions (>3), so we set the first 3 dimensions to 0 + window_in.set(Window::DimX, Window::Dimension(0, 0, 0)); + window_in.set(Window::DimY, Window::Dimension(0, 0, 0)); + window_in.set(Window::DimZ, Window::Dimension(0, 0, 0)); + + Window window_k = calculate_max_window(*weights->info(), Steps(1u)); + Iterator out(dst, window_out); + Iterator in(src, window_in); + Iterator k(weights, window_k); + + const uint8_t *k_ptr = k.ptr(); + + execute_window_loop(window_out, [&](const Coordinates & id) + { + /* + For a detailed explanation on how the algorithm works refer to template <> class convolver_3x3<1> + */ + const uint8_t *input_ptr = in.ptr() - conv_pad_left * input_stride_x - conv_pad_top * input_stride_y; + uint8_t *out_ptr = out.ptr(); + int ih = 0; + int oh = 0; + for(int oz = 0; oz < range_z; ++oz) + { + auto p_out_base = out_ptr + oz * output_stride_z; + // Step 1 + { + const auto k_val = reinterpret_cast(k_ptr + 0 * kernel_stride_z + (id.z() + oz) * kernel_stride_w); + const auto vk = internal_vdupq_n(*k_val); + for(ih = 0, oh = 0; oh < output_h; ++oh, ih += conv_stride_y) + { + const int offset_xy = ih * input_stride_y; + auto in_val = reinterpret_cast(input_ptr + (0 * input_stride_z + offset_xy)); + auto p_out = reinterpret_cast(p_out_base + oh * output_stride_y); + for(int ow = 0; ow < output_w; ow += num_elems_written_per_iteration, in_val += num_elems_read_per_iteration, p_out += num_elems_written_per_iteration) + { + internal_vst1q(p_out, internal_vmull(vk, 
internal_vld1q(in_val))); + } + } + } + + // Step 2 + for(int p = 1; p < kernel_depth; ++p) + { + const auto k_val = reinterpret_cast(k_ptr + p * kernel_stride_z + (id.z() + oz) * kernel_stride_w); + const auto vk = internal_vdupq_n(*k_val); + for(ih = 0, oh = 0; oh < output_h; ++oh, ih += conv_stride_y) + { + const int offset_xy = ih * input_stride_y; + auto in_val = reinterpret_cast(input_ptr + p * input_stride_z + offset_xy); + auto p_out = reinterpret_cast(p_out_base + oh * output_stride_y); + for(int ow = 0; ow < output_w; ow += num_elems_written_per_iteration, in_val += num_elems_read_per_iteration, p_out += num_elems_written_per_iteration) + { + internal_vst1q(p_out, internal_vmlal(internal_vld1q<1>(p_out), vk, internal_vld1q(in_val))); + } + } + } + } + }, + in, out); + } +}; + +template +float32x4x2_t convolve_5x5(const float *in_0, const float *in_1, const float *in_2, const float *in_3, const float *in_4, + const float *m0, const float *m1, const float *m2, const float *m3, const float *m4); + +inline float32x4x3_t load_matrix_hi(const float *const m0, const float *const m1, const float *const m2) +{ + const float32x4x3_t m00 = + { + { + vld1q_dup_f32(m0), + vld1q_dup_f32(m1), + vld1q_dup_f32(m2) + } + }; + return m00; +} + +inline float32x4x2_t load_matrix_lo(const float *const m3, const float *const m4) +{ + const float32x4x2_t m00 = + { + { + vld1q_dup_f32(m3), + vld1q_dup_f32(m4) + } + }; + return m00; +} + +inline float32x4x3_t load_input(const float *const in) +{ + const float32x4x3_t vin = + { + { + vld1q_f32(in), + vld1q_f32(in + 4), + vld1q_f32(in + 8) + } + }; + return vin; +} + +template <> +inline float32x4x2_t convolve_5x5<1>(const float *in_0, const float *in_1, const float *in_2, const float *in_3, const float *in_4, + const float *m0, const float *m1, const float *m2, const float *m3, const float *m4) +{ + const float32x4x3_t vin0 = load_input(in_0); + const float32x4x3_t vin1 = load_input(in_1); + const float32x4x3_t vin2 = load_input(in_2); + const float32x4x3_t vin3 = load_input(in_3); + const float32x4x3_t vin4 = load_input(in_4); + const float32x4x3_t m00 = load_matrix_hi(m0, 1 + m0, 2 + m0); + const float32x4x2_t m01 = load_matrix_lo(3 + m0, 4 + m0); + const float32x4x3_t m10 = load_matrix_hi(m1, 1 + m1, 2 + m1); + const float32x4x2_t m11 = load_matrix_lo(3 + m1, 4 + m1); + const float32x4x3_t m20 = load_matrix_hi(m2, 1 + m2, 2 + m2); + const float32x4x2_t m21 = load_matrix_lo(3 + m2, 4 + m2); + const float32x4x3_t m30 = load_matrix_hi(m3, 1 + m3, 2 + m3); + const float32x4x2_t m31 = load_matrix_lo(3 + m3, 4 + m3); + const float32x4x3_t m40 = load_matrix_hi(m4, 1 + m4, 2 + m4); + const float32x4x2_t m41 = load_matrix_lo(3 + m4, 4 + m4); + + float32x4x2_t out = + { + { + vmulq_f32(vin0.val[0], m00.val[0]), + vmulq_f32(vin0.val[1], m00.val[0]) + } + }; + + out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin0.val[0], vin0.val[1], 1), m00.val[1]); + out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin0.val[0], vin0.val[1], 2), m00.val[2]); + out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin0.val[0], vin0.val[1], 3), m01.val[0]); + out.val[0] = vmlaq_f32(out.val[0], vin0.val[1], m01.val[1]); + + out.val[0] = vmlaq_f32(out.val[0], vin1.val[0], m10.val[0]); + out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin1.val[0], vin1.val[1], 1), m10.val[1]); + out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin1.val[0], vin1.val[1], 2), m10.val[2]); + out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin1.val[0], vin1.val[1], 3), m11.val[0]); + out.val[0] = vmlaq_f32(out.val[0], vin1.val[1], 
m11.val[1]); + + out.val[0] = vmlaq_f32(out.val[0], vin2.val[0], m20.val[0]); + out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin2.val[0], vin2.val[1], 1), m20.val[1]); + out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin2.val[0], vin2.val[1], 2), m20.val[2]); + out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin2.val[0], vin2.val[1], 3), m21.val[0]); + out.val[0] = vmlaq_f32(out.val[0], vin2.val[1], m21.val[1]); + + out.val[0] = vmlaq_f32(out.val[0], vin3.val[0], m30.val[0]); + out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin3.val[0], vin3.val[1], 1), m30.val[1]); + out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin3.val[0], vin3.val[1], 2), m30.val[2]); + out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin3.val[0], vin3.val[1], 3), m31.val[0]); + out.val[0] = vmlaq_f32(out.val[0], vin3.val[1], m31.val[1]); + + out.val[0] = vmlaq_f32(out.val[0], vin4.val[0], m40.val[0]); + out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin4.val[0], vin4.val[1], 1), m40.val[1]); + out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin4.val[0], vin4.val[1], 2), m40.val[2]); + out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin4.val[0], vin4.val[1], 3), m41.val[0]); + out.val[0] = vmlaq_f32(out.val[0], vin4.val[1], m41.val[1]); + + out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin0.val[1], vin0.val[2], 1), m00.val[1]); + out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin0.val[1], vin0.val[2], 2), m00.val[2]); + out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin0.val[1], vin0.val[2], 3), m01.val[0]); + out.val[1] = vmlaq_f32(out.val[1], vin0.val[2], m01.val[1]); + + out.val[1] = vmlaq_f32(out.val[1], vin1.val[1], m10.val[0]); + out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin1.val[1], vin1.val[2], 1), m10.val[1]); + out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin1.val[1], vin1.val[2], 2), m10.val[2]); + out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin1.val[1], vin1.val[2], 3), m11.val[0]); + out.val[1] = vmlaq_f32(out.val[1], vin1.val[2], m11.val[1]); + + out.val[1] = vmlaq_f32(out.val[1], vin2.val[1], m20.val[0]); + out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin2.val[1], vin2.val[2], 1), m20.val[1]); + out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin2.val[1], vin2.val[2], 2), m20.val[2]); + out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin2.val[1], vin2.val[2], 3), m21.val[0]); + out.val[1] = vmlaq_f32(out.val[1], vin2.val[2], m21.val[1]); + + out.val[1] = vmlaq_f32(out.val[1], vin3.val[1], m30.val[0]); + out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin3.val[1], vin3.val[2], 1), m30.val[1]); + out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin3.val[1], vin3.val[2], 2), m30.val[2]); + out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin3.val[1], vin3.val[2], 3), m31.val[0]); + out.val[1] = vmlaq_f32(out.val[1], vin3.val[2], m31.val[1]); + + out.val[1] = vmlaq_f32(out.val[1], vin4.val[1], m40.val[0]); + out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin4.val[1], vin4.val[2], 1), m40.val[1]); + out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin4.val[1], vin4.val[2], 2), m40.val[2]); + out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin4.val[1], vin4.val[2], 3), m41.val[0]); + out.val[1] = vmlaq_f32(out.val[1], vin4.val[2], m41.val[1]); + + return out; +} + +template <> +inline float32x4x2_t convolve_5x5<2>(const float *in_0, const float *in_1, const float *in_2, const float *in_3, const float *in_4, + const float *m0, const float *m1, const float *m2, const float *m3, const float *m4) +{ + float32x4x2_t out = convolve_5x5<1>(in_0, in_1, in_2, in_3, in_4, m0, m1, m2, m3, m4); + out.val[0] = vsetq_lane_f32(vgetq_lane_f32(out.val[0], 2), out.val[0], 
1); + out.val[0] = vsetq_lane_f32(vgetq_lane_f32(out.val[1], 0), out.val[0], 2); + out.val[0] = vsetq_lane_f32(vgetq_lane_f32(out.val[1], 2), out.val[0], 3); + return out; +} + +template <> +inline float32x4x2_t convolve_5x5<3>(const float *in_0, const float *in_1, const float *in_2, const float *in_3, const float *in_4, + const float *m0, const float *m1, const float *m2, const float *m3, const float *m4) +{ + float32x4x2_t out = convolve_5x5<1>(in_0, in_1, in_2, in_3, in_4, m0, m1, m2, m3, m4); + out.val[0] = vsetq_lane_f32(vgetq_lane_f32(out.val[0], 3), out.val[0], 1); + return out; +} + +template +class convolver_3x3 +{ +public: + static void convolve(const Window &window, unsigned int num_elems_read_per_iteration, unsigned int num_elems_written_per_iteration, + const ITensor *src, const ITensor *weights, ITensor *dst, const PadStrideInfo &conv_info) + { + ARM_COMPUTE_UNUSED(num_elems_read_per_iteration); + const int input_stride_x = src->info()->strides_in_bytes().x(); + const int input_stride_y = src->info()->strides_in_bytes().y(); + const int input_stride_z = src->info()->strides_in_bytes().z(); + const int output_stride_y = dst->info()->strides_in_bytes().y(); + const int output_stride_z = dst->info()->strides_in_bytes().z(); + const int kernel_stride_x = weights->info()->strides_in_bytes().x(); + const int kernel_stride_y = weights->info()->strides_in_bytes().y(); + const int kernel_stride_z = weights->info()->strides_in_bytes().z(); + const int kernel_stride_w = weights->info()->strides_in_bytes()[3]; + const int output_w = dst->info()->dimension(0); + const int output_h = dst->info()->dimension(1); + const int num_planes_z = window.z().end() - window.z().start(); + const int delta_input = get_input_num_elems_processed(num_elems_written_per_iteration, stridex); + const int kernel_depth = weights->info()->dimension(Window::DimZ); + const unsigned int conv_stride_y = std::get<1>(conv_info.stride()); + const unsigned int conv_pad_left = conv_info.pad_left(); + const unsigned int conv_pad_top = conv_info.pad_top(); + + // setup output window for the iterator + Window window_out = window; + window_out.set(Window::DimX, Window::Dimension(0, dst->info()->dimension(Window::DimX), dst->info()->dimension(Window::DimX))); + window_out.set(Window::DimY, Window::Dimension(0, dst->info()->dimension(Window::DimY), dst->info()->dimension(Window::DimY))); + window_out.set(Window::DimZ, Window::Dimension(window.z().start(), window.z().end(), num_planes_z)); + + // setup input window for the iterator + Window window_in = window; + // we just want execute_window_loop to iterate over the higher dimensions (>3), so we set the first 3 dimensions to 0 + window_in.set(Window::DimX, Window::Dimension(0, 0, 0)); + window_in.set(Window::DimY, Window::Dimension(0, 0, 0)); + window_in.set(Window::DimZ, Window::Dimension(0, 0, 0)); + + Window window_k = calculate_max_window(*weights->info(), Steps(1u)); + + Iterator out(dst, window_out); + Iterator in(src, window_in); + Iterator k(weights, window_k); + + const uint8_t *k_ptr = k.ptr(); + + execute_window_loop(window_out, [&](const Coordinates & id) + { + const uint8_t *input_ptr = in.ptr() - conv_pad_left * input_stride_x - conv_pad_top * input_stride_y; + uint8_t *out_ptr = out.ptr(); + int ih = 0; + int oh = 0; + /* + Each thread executing this kernel computes one or more output's volume planes. 
+ + Let's say the 3rd dimension of the output volume is 32: the first thread will compute the output for Z = [0,7], the second thread will compute the output for Z = [8,15], + the third thread [16,23] and the fourth thread [24,31]. + + The algorithm outer loop iterates over Z, P, Y, X where P is the depth/3rd dimension of each kernel. This order is not arbitrary, the main benefit of this + is that we set up the Neon registers containing the kernel's values only once and then compute each XY using the preloaded registers, as opposed to doing this for every XY value. + + The algorithm does not require allocating any additional memory and computes the results directly in-place in two stages: + 1) Convolve plane 0 with kernel 0 and initialize the corresponding output plane with these values. + 2) Convolve the remaining planes and accumulate the results in the output plane which has been initialized in step 1. + */ + for(int oz = 0; oz < num_planes_z; ++oz) + { + const int zoffset = id.z() + oz; + uint8_t *p_out_base = out_ptr + oz * output_stride_z; + // Step 1 + { + const auto ptr_k_r0 = reinterpret_cast(k_ptr + 0 * kernel_stride_z + zoffset * kernel_stride_w + 0 * kernel_stride_y + 0 * kernel_stride_x); + const auto ptr_k_r1 = reinterpret_cast(k_ptr + 0 * kernel_stride_z + zoffset * kernel_stride_w + 1 * kernel_stride_y + 0 * kernel_stride_x); + const auto ptr_k_r2 = reinterpret_cast(k_ptr + 0 * kernel_stride_z + zoffset * kernel_stride_w + 2 * kernel_stride_y + 0 * kernel_stride_x); + const auto vk_r0 = load_matrix_row(ptr_k_r0); + const auto vk_r1 = load_matrix_row(ptr_k_r1); + const auto vk_r2 = load_matrix_row(ptr_k_r2); + for(ih = 0, oh = 0; oh < output_h; ++oh, ih += conv_stride_y) + { + auto in_top = reinterpret_cast(input_ptr + 0 * input_stride_z + (ih + 0) * input_stride_y); + auto in_mid = reinterpret_cast(input_ptr + 0 * input_stride_z + (ih + 1) * input_stride_y); + auto in_low = reinterpret_cast(input_ptr + 0 * input_stride_z + (ih + 2) * input_stride_y); + auto p_out = reinterpret_cast(p_out_base + oh * output_stride_y); + for(int ow = 0; ow < output_w; ow += num_elems_written_per_iteration, + in_top += delta_input, in_mid += delta_input, in_low += delta_input, p_out += num_elems_written_per_iteration) + { + convolve_3x3(in_top, in_mid, in_low, p_out, vk_r0, vk_r1, vk_r2, stridex); + } + } + } + // Step 2 + for(int p = 1; p < kernel_depth; ++p) + { + const uint8_t *ptr_k_base = k_ptr + p * kernel_stride_z + zoffset * kernel_stride_w; + const uint8_t *input_base = input_ptr + p * input_stride_z; + const auto ptr_k_r0 = reinterpret_cast(ptr_k_base); + const auto ptr_k_r1 = reinterpret_cast(ptr_k_base + kernel_stride_y); + const auto ptr_k_r2 = reinterpret_cast(ptr_k_base + kernel_stride_y * 2); + const auto vk_r0 = load_matrix_row(ptr_k_r0); + const auto vk_r1 = load_matrix_row(ptr_k_r1); + const auto vk_r2 = load_matrix_row(ptr_k_r2); + for(ih = 0, oh = 0; oh < output_h; ++oh, ih += conv_stride_y) + { + auto in_top = reinterpret_cast(input_base + (ih + 0) * input_stride_y); + auto in_mid = reinterpret_cast(input_base + (ih + 1) * input_stride_y); + auto in_low = reinterpret_cast(input_base + (ih + 2) * input_stride_y); + auto p_out = reinterpret_cast(p_out_base + oh * output_stride_y); + for(int ow = 0; ow < output_w; ow += num_elems_written_per_iteration, + in_top += delta_input, in_mid += delta_input, in_low += delta_input, p_out += num_elems_written_per_iteration) + { + convolve_3x3(in_top, in_mid, in_low, p_out, vk_r0, vk_r1, vk_r2, stridex); + } + } + } + } + }, + in, out);
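As a reading aid (editorial addition, not part of the committed patch): the two-stage, in-place scheme described in the comment above can be summarized by the sketch below; out_plane(), in_plane(), k_plane() and conv2d_plane() are hypothetical helpers used purely for illustration.

    for(int oz = 0; oz < num_planes_z; ++oz)  // output planes owned by this thread
    {
        // Step 1: kernel plane 0 initializes the output plane, so no extra accumulator is needed
        out_plane(oz) = conv2d_plane(in_plane(0), k_plane(oz, 0));
        // Step 2: every remaining kernel plane accumulates into the same output plane, in place
        for(int p = 1; p < kernel_depth; ++p)
        {
            out_plane(oz) += conv2d_plane(in_plane(p), k_plane(oz, p));
        }
    }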
+ } +}; + +template +class convolver_5x5 +{ +public: + static void convolve(const Window &window, unsigned int num_elems_read_per_iteration, unsigned int num_elems_written_per_iteration, + const ITensor *src, const ITensor *weights, ITensor *dst, const PadStrideInfo &conv_info) + { + ARM_COMPUTE_UNUSED(num_elems_read_per_iteration); + const int input_stride_x = src->info()->strides_in_bytes().x(); + const int input_stride_y = src->info()->strides_in_bytes().y(); + const int input_stride_z = src->info()->strides_in_bytes().z(); + const int output_stride_y = dst->info()->strides_in_bytes().y(); + const int output_stride_z = dst->info()->strides_in_bytes().z(); + const int kernel_stride_x = weights->info()->strides_in_bytes().x(); + const int kernel_stride_y = weights->info()->strides_in_bytes().y(); + const int kernel_stride_z = weights->info()->strides_in_bytes().z(); + const int kernel_stride_w = weights->info()->strides_in_bytes()[3]; + const int output_w = dst->info()->dimension(0); + const int output_h = dst->info()->dimension(1); + const int num_planes_z = window.z().end() - window.z().start(); + const int delta_input = get_input_num_elems_processed(num_elems_written_per_iteration, stridex); + const int kernel_depth = weights->info()->dimension(Window::DimZ); + const unsigned int conv_stride_y = std::get<1>(conv_info.stride()); + const unsigned int conv_pad_left = conv_info.pad_left(); + const unsigned int conv_pad_top = conv_info.pad_top(); + + // setup output window for the iterator + Window window_out = window; + window_out.set(Window::DimX, Window::Dimension(0, dst->info()->dimension(Window::DimX), dst->info()->dimension(Window::DimX))); + window_out.set(Window::DimY, Window::Dimension(0, dst->info()->dimension(Window::DimY), dst->info()->dimension(Window::DimY))); + window_out.set(Window::DimZ, Window::Dimension(window.z().start(), window.z().end(), num_planes_z)); + + // setup input window for the iterator + Window window_in = window; + // we just want execute_window_loop to iterate over the higher dimensions (>3), so we set the first 3 dimensions to 0 + window_in.set(Window::DimX, Window::Dimension(0, 0, 0)); + window_in.set(Window::DimY, Window::Dimension(0, 0, 0)); + window_in.set(Window::DimZ, Window::Dimension(0, 0, 0)); + + Window window_k = calculate_max_window(*weights->info(), Steps(1u)); + + Iterator out(dst, window_out); + Iterator in(src, window_in); + Iterator k(weights, window_k); + + const uint8_t *k_ptr = k.ptr(); + + execute_window_loop(window_out, [&](const Coordinates & id) + { + const uint8_t *input_ptr = in.ptr() - conv_pad_left * input_stride_x - conv_pad_top * input_stride_y; + uint8_t *out_ptr = out.ptr(); + int ih = 0; + int oh = 0; + for(int oz = 0; oz < num_planes_z; ++oz) + { + const int zoffset = id.z() + oz; + uint8_t *p_out_base = out_ptr + oz * output_stride_z; + // Step 1 + { + const auto ptr_k_r0 = reinterpret_cast(k_ptr + 0 * kernel_stride_z + zoffset * kernel_stride_w + 0 * kernel_stride_y + 0 * kernel_stride_x); + const auto ptr_k_r1 = reinterpret_cast(k_ptr + 0 * kernel_stride_z + zoffset * kernel_stride_w + 1 * kernel_stride_y + 0 * kernel_stride_x); + const auto ptr_k_r2 = reinterpret_cast(k_ptr + 0 * kernel_stride_z + zoffset * kernel_stride_w + 2 * kernel_stride_y + 0 * kernel_stride_x); + const auto ptr_k_r3 = reinterpret_cast(k_ptr + 0 * kernel_stride_z + zoffset * kernel_stride_w + 3 * kernel_stride_y + 0 * kernel_stride_x); + const auto ptr_k_r4 = reinterpret_cast(k_ptr + 0 * kernel_stride_z + zoffset * kernel_stride_w + 4 * 
kernel_stride_y + 0 * kernel_stride_x); + for(ih = 0, oh = 0; oh < output_h; ++oh, ih += conv_stride_y) + { + auto in_0 = reinterpret_cast(input_ptr + 0 * input_stride_z + (ih + 0) * input_stride_y); + auto in_1 = reinterpret_cast(input_ptr + 0 * input_stride_z + (ih + 1) * input_stride_y); + auto in_2 = reinterpret_cast(input_ptr + 0 * input_stride_z + (ih + 2) * input_stride_y); + auto in_3 = reinterpret_cast(input_ptr + 0 * input_stride_z + (ih + 3) * input_stride_y); + auto in_4 = reinterpret_cast(input_ptr + 0 * input_stride_z + (ih + 4) * input_stride_y); + auto p_out = reinterpret_cast(p_out_base + oh * output_stride_y); + for(int ow = 0; ow < output_w; ow += num_elems_written_per_iteration, + in_0 += delta_input, in_1 += delta_input, in_2 += delta_input, in_3 += delta_input, in_4 += delta_input, p_out += num_elems_written_per_iteration) + { + auto vres = convolve_5x5(in_0, in_1, in_2, in_3, in_4, ptr_k_r0, ptr_k_r1, ptr_k_r2, ptr_k_r3, ptr_k_r4); + store_results(p_out, vres); + } + } + } + // Step 2 + for(int p = 1; p < kernel_depth; ++p) + { + const auto ptr_k_r0 = reinterpret_cast(k_ptr + p * kernel_stride_z + zoffset * kernel_stride_w + 0 * kernel_stride_y + 0 * kernel_stride_x); + const auto ptr_k_r1 = reinterpret_cast(k_ptr + p * kernel_stride_z + zoffset * kernel_stride_w + 1 * kernel_stride_y + 0 * kernel_stride_x); + const auto ptr_k_r2 = reinterpret_cast(k_ptr + p * kernel_stride_z + zoffset * kernel_stride_w + 2 * kernel_stride_y + 0 * kernel_stride_x); + const auto ptr_k_r3 = reinterpret_cast(k_ptr + p * kernel_stride_z + zoffset * kernel_stride_w + 3 * kernel_stride_y + 0 * kernel_stride_x); + const auto ptr_k_r4 = reinterpret_cast(k_ptr + p * kernel_stride_z + zoffset * kernel_stride_w + 4 * kernel_stride_y + 0 * kernel_stride_x); + + for(ih = 0, oh = 0; oh < output_h; ++oh, ih += conv_stride_y) + { + auto in_0 = reinterpret_cast(input_ptr + p * input_stride_z + (ih + 0) * input_stride_y); + auto in_1 = reinterpret_cast(input_ptr + p * input_stride_z + (ih + 1) * input_stride_y); + auto in_2 = reinterpret_cast(input_ptr + p * input_stride_z + (ih + 2) * input_stride_y); + auto in_3 = reinterpret_cast(input_ptr + p * input_stride_z + (ih + 3) * input_stride_y); + auto in_4 = reinterpret_cast(input_ptr + p * input_stride_z + (ih + 4) * input_stride_y); + auto p_out = reinterpret_cast(p_out_base + oh * output_stride_y); + for(int ow = 0; ow < output_w; ow += num_elems_written_per_iteration, + in_0 += delta_input, in_1 += delta_input, in_2 += delta_input, in_3 += delta_input, in_4 += delta_input, p_out += num_elems_written_per_iteration) + { + auto vres = convolve_5x5(in_0, in_1, in_2, in_3, in_4, ptr_k_r0, ptr_k_r1, ptr_k_r2, ptr_k_r3, ptr_k_r4); + accumulate_results(p_out, vres); + } + } + } + } + }, + in, out); + } +}; + +float vreduce(const float32x4_t &v) +{ + auto v0 = wrapper::vgethigh(v); + auto v1 = wrapper::vgetlow(v); + auto v_out = wrapper::vadd(v0, v1); + + float a = wrapper::vgetlane(v_out, 0); + float b = wrapper::vgetlane(v_out, 1); + return a + b; +} + +template +inline void convolve_1x1(const Window &window, unsigned int num_elems_read_per_iteration, unsigned int num_elems_written_per_iteration, + const ITensor *src, const ITensor *weights, ITensor *dst, const PadStrideInfo &conv_info) +{ + const unsigned int conv_stride_x = std::get<0>(conv_info.stride()); + switch(conv_stride_x) + { + case 1: + convolver_1x1::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, src, weights, dst, conv_info); + break; + case 2: + 
convolver_1x1::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, src, weights, dst, conv_info); + break; + case 3: + convolver_1x1::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, src, weights, dst, conv_info); + break; + default: + ARM_COMPUTE_ERROR("Not implemented"); + } +} + +template <> +inline void convolve_1x1(const Window &window, unsigned int num_elems_read_per_iteration, unsigned int num_elems_written_per_iteration, + const ITensor *src, const ITensor *weights, ITensor *dst, const PadStrideInfo &conv_info) +{ + const unsigned int conv_stride_x = std::get<0>(conv_info.stride()); + if(run_optim_small_tensor(src)) + { + switch(conv_stride_x) + { + case 1: + convolver_w1x1_i8x8_f32<1>::convolve(window, src, weights, dst, conv_info); + break; + case 2: + convolver_w1x1_i8x8_f32<2>::convolve(window, src, weights, dst, conv_info); + break; + case 3: + convolver_w1x1_i8x8_f32<3>::convolve(window, src, weights, dst, conv_info); + break; + default: + ARM_COMPUTE_ERROR("Not implemented"); + } + } + else + { + switch(conv_stride_x) + { + case 1: + convolver_1x1::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, src, weights, dst, conv_info); + break; + case 2: + convolver_1x1::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, src, weights, dst, conv_info); + break; + case 3: + convolver_1x1::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, src, weights, dst, conv_info); + break; + default: + ARM_COMPUTE_ERROR("Not implemented"); + } + } +} + +template +inline void convolve_3x3(const Window &window, unsigned int num_elems_read_per_iteration, unsigned int num_elems_written_per_iteration, + const ITensor *src, const ITensor *weights, ITensor *dst, const PadStrideInfo &conv_info) +{ + const unsigned int conv_stride_x = std::get<0>(conv_info.stride()); + switch(conv_stride_x) + { + case 1: + convolver_3x3::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, src, weights, dst, conv_info); + break; + case 2: + convolver_3x3::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, src, weights, dst, conv_info); + break; + case 3: + convolver_3x3::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, src, weights, dst, conv_info); + break; + default: + ARM_COMPUTE_ERROR("Not implemented"); + } +} + +template +inline void convolve_5x5(const Window &window, unsigned int num_elems_read_per_iteration, unsigned int num_elems_written_per_iteration, + const ITensor *src, const ITensor *weights, ITensor *dst, const PadStrideInfo &conv_info) +{ + const unsigned int conv_stride_x = std::get<0>(conv_info.stride()); + switch(conv_stride_x) + { + case 1: + convolver_5x5::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, src, weights, dst, conv_info); + break; + case 2: + convolver_5x5::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, src, weights, dst, conv_info); + break; + case 3: + convolver_5x5::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, src, weights, dst, conv_info); + break; + default: + ARM_COMPUTE_ERROR("Not implemented"); + } +} + +Status validate_arguments(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *dst, const PadStrideInfo &conv_info) +{ + ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(src, weights, dst); + ARM_COMPUTE_RETURN_ERROR_ON(src->data_layout() == 
DataLayout::UNKNOWN); + ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(src); + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(src, 1, DataType::F16, DataType::F32); + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(src, weights); + + const DataLayout data_layout = src->data_layout(); + const int width_idx = get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH); + const int height_idx = get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT); + const int channel_idx = get_data_layout_dimension_index(data_layout, DataLayoutDimension::CHANNEL); + + ARM_COMPUTE_RETURN_ERROR_ON_MSG(std::get<0>(conv_info.stride()) > 3, "Strides larger than 3 not supported."); + ARM_COMPUTE_RETURN_ERROR_ON(weights->dimension(channel_idx) != src->dimension(channel_idx)); + ARM_COMPUTE_RETURN_ERROR_ON(weights->dimension(width_idx) != weights->dimension(height_idx)); + ARM_COMPUTE_RETURN_ERROR_ON(weights->num_dimensions() > 4); + ARM_COMPUTE_RETURN_ERROR_ON(data_layout == DataLayout::NHWC && src->data_type() != DataType::F32); + ARM_COMPUTE_RETURN_ERROR_ON((weights->dimension(width_idx) > 3) && (src->data_type() == DataType::F16)); + + // Checks performed when output is configured + if(dst->total_size() != 0) + { + TensorShape output_shape = misc::shape_calculator::compute_deep_convolution_shape(*src, *weights, conv_info); + + DataType data_type = src->data_type(); + + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DIMENSIONS(dst->tensor_shape(), output_shape); + ARM_COMPUTE_RETURN_ERROR_ON(dst->data_type() != data_type); + } + + return Status{}; +} + +std::pair validate_and_configure_window(ITensorInfo *src, ITensorInfo *weights, ITensorInfo *dst, const PadStrideInfo &conv_info, unsigned int &num_weight_elems_read_per_row, + unsigned int &num_elems_read_per_iteration, unsigned int &num_elems_written_per_iteration, BorderSize &border_size) +{ + ARM_COMPUTE_ERROR_ON(src->data_layout() == DataLayout::UNKNOWN); + + const DataLayout data_layout = src->data_layout(); + const int width_idx = get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH); + + // Calculate right and bottom border + unsigned int kernel_size = weights->dimension(width_idx); + const int conv_stride_x = std::get<0>(conv_info.stride()); + const int conv_stride_y = std::get<1>(conv_info.stride()); + const int input_width = src->dimension(width_idx); + + Window win{}; + bool window_changed = false; + + if(data_layout == DataLayout::NCHW) + { + switch(kernel_size) + { + case 1: + { + switch(src->data_type()) + { +#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + case DataType::F16: + num_elems_written_per_iteration = 8; + break; +#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */ + case DataType::F32: + if(run_optim_small_tensor_info(src)) + { + num_elems_written_per_iteration = 8; + } + else + { + num_elems_written_per_iteration = 4; + } + break; + default: + ARM_COMPUTE_ERROR("Data type not supported."); + break; + } + num_weight_elems_read_per_row = kernel_size; + num_elems_read_per_iteration = conv_stride_x * num_elems_written_per_iteration; + break; + } + case 3: + switch(src->data_type()) + { + case DataType::F32: + num_weight_elems_read_per_row = 4 + kernel_size - 1; + num_elems_read_per_iteration = 12; + num_elems_written_per_iteration = 16 >> conv_stride_x; + break; +#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + case DataType::F16: + num_weight_elems_read_per_row = 8 + kernel_size - 1; + num_elems_read_per_iteration = 24; + num_elems_written_per_iteration = 32 >> conv_stride_x; + break; +#endif 
/* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */ + default: + ARM_COMPUTE_ERROR("Data type not supported."); + break; + } + break; + case 5: + { + switch(src->data_type()) + { + case DataType::F32: + num_weight_elems_read_per_row = 4 + kernel_size - 1; + num_elems_read_per_iteration = 12; + num_elems_written_per_iteration = 16 >> conv_stride_x; + break; + default: + ARM_COMPUTE_ERROR("Data type not supported."); + break; + } + } + break; + default: + { + ARM_COMPUTE_ERROR("Not implemented"); + break; + } + } + + // Calculate right pad + int start_x = kernel_size / 2 - static_cast(conv_info.pad_left()); + int end_x = ceil_to_multiple(static_cast(dst->dimension(0)), num_elems_written_per_iteration) * conv_stride_x; + int upper_bound_w = ceil_to_multiple(start_x + end_x, num_elems_read_per_iteration) - input_width; + + // Calculate border + const unsigned int conv_pad_left = conv_info.pad_left(); + const unsigned int conv_pad_top = conv_info.pad_top(); + const unsigned int conv_pad_right = std::max(upper_bound_w, 0); + const unsigned int conv_pad_bottom = conv_info.pad_bottom(); + + border_size.left = conv_pad_left; + border_size.top = conv_pad_top; + border_size.right = conv_pad_right; + border_size.bottom = conv_pad_bottom; + + // Configure window + win = calculate_max_window(*dst, Steps(num_elems_written_per_iteration)); + + AccessWindowRectangle input_access(src, -conv_pad_left, -conv_pad_top, + num_elems_read_per_iteration, kernel_size, + conv_stride_x, conv_stride_y); + AccessWindowStatic weights_access(weights, 0, 0, num_weight_elems_read_per_row, kernel_size); + AccessWindowHorizontal output_access(dst, 0, num_elems_written_per_iteration); + window_changed = update_window_and_padding(win, input_access, weights_access, output_access); + output_access.set_valid_region(win, ValidRegion(Coordinates(), dst->tensor_shape())); + } + else + { + // Configure window NHWC without any padding + win = calculate_max_window(*dst, Steps()); + } + + Status err = (window_changed) ? 
ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Insufficient Padding!") : Status{}; + return std::make_pair(err, win); +} + +bool have_zero_x_internal_padding(ITensorInfo *src, ITensorInfo *weights) +{ + return (src->padding().left == 0 && weights->padding().left == 0 && src->padding().right == 0 && weights->padding().right == 0); +} + +} // namespace + +template +void CpuDirectConvolutionKernel::convolve_nhwc_optimized(const Window &window, const ITensor *src, const ITensor *weights, ITensor *dst) +{ + // This function assumes that input and weights have no padding in the channel dimension + + // Declare useful types + using vtype = wrapper::traits::neon_bitvector; + using vector_type = typename vtype::type; + using tag_type = typename vtype::tag_type; + + // Scalar quantities + const int element_size = src->info()->element_size(); + const int input_stride_w = src->info()->strides_in_bytes().y() / element_size; + const int input_stride_h = src->info()->strides_in_bytes().z() / element_size; + const int input_stride_n = src->info()->strides_in_bytes()[3] / element_size; + const int input_dim_w = src->info()->dimension(1); + const int input_dim_h = src->info()->dimension(2); + + const int output_stride_c = dst->info()->strides_in_bytes().x(); + + const unsigned int kernel_stride_w = weights->info()->strides_in_bytes().y() / element_size; + const unsigned int kernel_stride_h = weights->info()->strides_in_bytes().z() / element_size; + const int kernel_dim_w = weights->info()->dimension(1); + const int kernel_dim_h = weights->info()->dimension(2); + + const int conv_pad_top = _conv_info.pad_top(); + const int conv_pad_left = _conv_info.pad_left(); + const int conv_stride_w = std::get<0>(_conv_info.stride()); + const int conv_stride_h = std::get<1>(_conv_info.stride()); + + // Setup input window for the output iterator + Window window_out = window; + window_out.set(Window::DimX, Window::Dimension(0, 1, 1)); + + // Setup input window for the weights iterator + Window window_w = calculate_max_window(*weights->info(), Steps()); + window_w.set(Window::DimX, Window::Dimension(0, 1, 1)); + window_w.set(Window::DimY, Window::Dimension(0, 1, 1)); + window_w.set(Window::DimZ, Window::Dimension(0, 1, 1)); + + Iterator out(dst, window_out); + Iterator wei(weights, window_w); + + constexpr int num_elems_read_per_iteration = 16 / sizeof(T); + /* + * This implementation parallelizes the full WC plane of input and weights by + * treating them as a series of elements. For example, with a 3x3 weight tensor and + * floating-point vector operations of 4 elements at a time, the first 3 + * channel elements of the first row would be taken together with the first + * element of the second row. The 9 elements in each single WC weight plane + * would then require two 4-element vector operations and a final single-element operation. + * + * This works because, when we create the input vector to multiply with the weights, + * the exact required elements are loaded in the same order. Therefore the + * multiplication works on the correct input/weight elements. + */ + execute_window_loop(window_out, [&](const Coordinates & id) + { + /* + * Here we create theoretical indexes which we then validate for both + * inputs and weights. + * As a reminder, this loop takes each output point in NHW; C is handled + * in the weights loop.
+ */ + // We are computing the theoretical starting input starting points + const int in_w_start_t = static_cast(id.y()) * conv_stride_w - conv_pad_left; + const int in_h_start_t = static_cast(id.z()) * conv_stride_h - conv_pad_top; + const int in_w_end_t = in_w_start_t + kernel_dim_w; + const int in_h_end_t = in_h_start_t + kernel_dim_h; + + // We are computing the valid initial and ending input points by checking the borders + const int in_w_start = std::max(in_w_start_t, 0); + const int in_h_start = std::max(in_h_start_t, 0); + const int in_w_end = std::min(in_w_end_t, input_dim_w); + const int in_h_end = std::min(in_h_end_t, input_dim_h); + + // We use the input points to select the valid weight points to use + const int index_wc_start = (in_w_start - in_w_start_t) * kernel_stride_w; + const int index_h_start = in_h_start - in_h_start_t; + const int index_wc_end = (kernel_dim_w - (in_w_end_t - in_w_end)) * kernel_stride_w; + const int index_h_end = kernel_dim_h - (in_h_end_t - in_h_end); + + execute_window_loop(window_w, [&](const Coordinates & id_w) + { + /* + * This is the loop in the weights, and it goes along N (the batches) + * As a reminder, the batches of the weights are translated into the + * channels of the output + */ + const T *in_ptr_row = reinterpret_cast(src->buffer() + src->info()->offset_first_element_in_bytes()) + + id[3] * input_stride_n + in_w_start * input_stride_w + in_h_start * input_stride_h; + const T *weights_ptr_row = reinterpret_cast(wei.ptr()) + index_h_start * kernel_stride_h; + uint8_t *out_ptr = out.ptr() + id_w[3] * output_stride_c; + + T out_temp = static_cast(0); + for(int index_h = index_h_start; index_h < index_h_end; ++index_h, in_ptr_row += input_stride_h, weights_ptr_row += kernel_stride_h) + { + const T *in_ptr_mover = in_ptr_row; + int index_wc = index_wc_start; + vector_type out_temp_vec = wrapper::vdup_n(static_cast(0), tag_type()); + for(; index_wc <= index_wc_end - num_elems_read_per_iteration; index_wc += num_elems_read_per_iteration, in_ptr_mover += num_elems_read_per_iteration) + { + const auto src_vec = wrapper::vloadq(in_ptr_mover); + const auto w_vec = wrapper::vloadq(weights_ptr_row + index_wc); + out_temp_vec = wrapper::vmla(out_temp_vec, w_vec, src_vec); + } + out_temp += vreduce(out_temp_vec); + for(; index_wc < index_wc_end; ++index_wc, ++in_ptr_mover) + { + const auto src_val = *(in_ptr_mover); + const auto w_val = *(weights_ptr_row + index_wc); + out_temp += src_val * w_val; + } + } + *(reinterpret_cast(out_ptr)) = out_temp; + }, + wei); + }, + out); +} + +template +void CpuDirectConvolutionKernel::convolve_nhwc(const Window &window, const ITensor *src, const ITensor *weights, ITensor *dst) +{ + // Declare useful types + using vtype = wrapper::traits::neon_bitvector; + using vector_type = typename vtype::type; + using tag_type = typename vtype::tag_type; + + // Scalar quantities + const int element_size = src->info()->element_size(); + const int input_stride_w = src->info()->strides_in_bytes().y() / element_size; + const int input_stride_h = src->info()->strides_in_bytes().z() / element_size; + const int input_stride_n = src->info()->strides_in_bytes()[3] / element_size; + const int input_dim_w = src->info()->dimension(1); + const int input_dim_h = src->info()->dimension(2); + + const int output_stride_c = dst->info()->strides_in_bytes().x(); + + const unsigned int kernel_stride_w = weights->info()->strides_in_bytes().y() / element_size; + const unsigned int kernel_stride_h = weights->info()->strides_in_bytes().z() / 
element_size; + const int kernel_dim_w = weights->info()->dimension(1); + const int kernel_dim_h = weights->info()->dimension(2); + + const int conv_pad_top = _conv_info.pad_top(); + const int conv_pad_left = _conv_info.pad_left(); + const int conv_stride_w = std::get<0>(_conv_info.stride()); + const int conv_stride_h = std::get<1>(_conv_info.stride()); + + // Setup input window for the output iterator + Window window_out = window; + window_out.set(Window::DimX, Window::Dimension(0, 1, 1)); + + // Setup input window for the weights iterator + Window window_w = calculate_max_window(*weights->info(), Steps()); + window_w.set(Window::DimX, Window::Dimension(0, 1, 1)); + window_w.set(Window::DimY, Window::Dimension(0, 1, 1)); + window_w.set(Window::DimZ, Window::Dimension(0, 1, 1)); + + Iterator out(dst, window_out); + Iterator wei(weights, window_w); + + constexpr int num_elems_read_per_iteration = 16 / sizeof(T); + + execute_window_loop(window_out, [&](const Coordinates & id) + { + // We are computing the theoretical starting input starting points + const int in_w_start_t = static_cast(id.y()) * conv_stride_w - conv_pad_left; + const int in_h_start_t = static_cast(id.z()) * conv_stride_h - conv_pad_top; + const int in_w_end_t = in_w_start_t + kernel_dim_w; + const int in_h_end_t = in_h_start_t + kernel_dim_h; + + // We are computing the valid initial and ending input points by checking the borders + const int in_w_start = std::max(in_w_start_t, 0); + const int in_h_start = std::max(in_h_start_t, 0); + const int in_w_end = std::min(in_w_end_t, input_dim_w); + const int in_h_end = std::min(in_h_end_t, input_dim_h); + + // We use the input points to select the valid weight points to use + const int wei_w_start = in_w_start - in_w_start_t; + const int wei_h_start = in_h_start - in_h_start_t; + const int wei_w_end = kernel_dim_w - (in_w_end_t - in_w_end); + const int wei_h_end = kernel_dim_h - (in_h_end_t - in_h_end); + + const int index_c_end = weights->info()->dimension(0); + const T *const in_ptr_start = reinterpret_cast(src->buffer() + src->info()->offset_first_element_in_bytes()) + id[3] * input_stride_n; + + execute_window_loop(window_w, [&](const Coordinates & id_w) + { + const T *const weights_ptr_start = reinterpret_cast(wei.ptr()); + uint8_t *out_ptr = out.ptr() + id_w[3] * output_stride_c; + + T out_temp = static_cast(0); + for(int index_wei_h = wei_h_start, index_in_h = in_h_start; index_wei_h < wei_h_end; ++index_wei_h, ++index_in_h) + { + const T *const in_ptr_row = in_ptr_start + index_in_h * input_stride_h; + const T *const weights_ptr_row = weights_ptr_start + index_wei_h * kernel_stride_h; + for(int index_wei_w = wei_w_start, index_in_w = in_w_start; index_wei_w < wei_w_end; ++index_wei_w, ++index_in_w) + { + const T *in_ptr_mover = in_ptr_row + index_in_w * input_stride_w; + const T *weights_ptr_mover = weights_ptr_row + index_wei_w * kernel_stride_w; + int index_c = 0; + vector_type out_temp_vec = wrapper::vdup_n(static_cast(0), tag_type()); + for(; index_c <= index_c_end - num_elems_read_per_iteration; index_c += num_elems_read_per_iteration, in_ptr_mover += num_elems_read_per_iteration, weights_ptr_mover += num_elems_read_per_iteration) + { + const auto src_vec = wrapper::vloadq(in_ptr_mover); + const auto w_vec = wrapper::vloadq(weights_ptr_mover); + out_temp_vec = wrapper::vmla(out_temp_vec, w_vec, src_vec); + } + out_temp += vreduce(out_temp_vec); + for(; index_c < index_c_end; ++index_c, ++in_ptr_mover, ++weights_ptr_mover) + { + const auto src_val = *(in_ptr_mover); + 
const auto w_val = *(weights_ptr_mover); + out_temp += src_val * w_val; + } + } + } + *(reinterpret_cast(out_ptr)) = out_temp; + }, + wei); + }, + out); +} + +BorderSize CpuDirectConvolutionKernel::border_size() const +{ + return _border_size; +} + +void CpuDirectConvolutionKernel::configure(ITensorInfo *src, ITensorInfo *weights, ITensorInfo *dst, const PadStrideInfo &conv_info) +{ + ARM_COMPUTE_ERROR_ON_NULLPTR(src, weights, dst); + + _conv_info = conv_info; + _data_layout = src->data_layout(); + _kernel_size = weights->dimension(get_data_layout_dimension_index(_data_layout, DataLayoutDimension::WIDTH)); + + const unsigned int conv_pad_left = conv_info.pad_left(); + const unsigned int conv_pad_top = conv_info.pad_top(); + const unsigned int conv_pad_right = conv_info.pad_right(); + const unsigned int conv_pad_bottom = conv_info.pad_bottom(); + if(_data_layout == DataLayout::NCHW) + { + _border_size = BorderSize(conv_pad_top, conv_pad_right, conv_pad_bottom, conv_pad_left); + } + else + { + _border_size = BorderSize(0); + } + + // Get convolved dimensions + TensorShape output_shape = misc::shape_calculator::compute_deep_convolution_shape(*src, *weights, conv_info); + + DataType data_type = src->data_type(); + + // Output auto inizialitation if not yet initialized + auto_init_if_empty(*dst, output_shape, 1, data_type); + + // Perform validation step + ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(src, weights, dst, conv_info)); + + // Configure kernel window + auto win_config = validate_and_configure_window(src, weights, dst, conv_info, _num_weight_elems_read_per_row, + _num_elems_read_per_iteration, _num_elems_written_per_iteration, _border_size); + ARM_COMPUTE_ERROR_THROW_ON(win_config.first); + ICpuKernel::configure(win_config.second); +} + +Status CpuDirectConvolutionKernel::validate(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *dst, const PadStrideInfo &conv_info) +{ + unsigned int num_weight_elems_read_per_row = 0; + unsigned int num_elems_read_per_iteration = 0; + unsigned int num_elems_written_per_iteration = 0; + BorderSize border_size = {}; + ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(src, weights, dst, conv_info)); + ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window(src->clone().get(), + weights->clone().get(), + dst->clone().get(), + conv_info, + num_weight_elems_read_per_row, + num_elems_read_per_iteration, + num_elems_written_per_iteration, + border_size) + .first); + + return Status{}; +} + +void CpuDirectConvolutionKernel::run_op(ITensorPack &tensors, const Window &window, const ThreadInfo &info) +{ + ARM_COMPUTE_UNUSED(info); + ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this); + ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(ICpuKernel::window(), window); + + auto src = tensors.get_const_tensor(TensorType::ACL_SRC_0); + auto weights = tensors.get_const_tensor(TensorType::ACL_SRC_1); + auto dst = tensors.get_tensor(TensorType::ACL_DST); + const int kernel_size = weights->info()->dimension(get_data_layout_dimension_index(_data_layout, DataLayoutDimension::WIDTH)); + + if(_data_layout == DataLayout::NCHW) + { + switch(kernel_size) + { + case 1: + { + switch(src->info()->data_type()) + { + case DataType::F32: + convolve_1x1(window, _num_elems_read_per_iteration, _num_elems_written_per_iteration, src, weights, dst, _conv_info); + break; +#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + case DataType::F16: + convolve_1x1(window, _num_elems_read_per_iteration, _num_elems_written_per_iteration, src, weights, dst, _conv_info); + break; +#endif /* 
__ARM_FEATURE_FP16_VECTOR_ARITHMETIC */ + default: + ARM_COMPUTE_ERROR("Data type not supported"); + break; + } + break; + } + case 3: + { + switch(src->info()->data_type()) + { + case DataType::F32: + convolve_3x3(window, _num_elems_read_per_iteration, _num_elems_written_per_iteration, src, weights, dst, _conv_info); + break; +#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + case DataType::F16: + convolve_3x3(window, _num_elems_read_per_iteration, _num_elems_written_per_iteration, src, weights, dst, _conv_info); + break; +#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */ + default: + ARM_COMPUTE_ERROR("Data type not supported"); + break; + } + break; + } + case 5: + { + switch(src->info()->data_type()) + { + case DataType::F32: + convolve_5x5(window, _num_elems_read_per_iteration, _num_elems_written_per_iteration, src, weights, dst, _conv_info); + break; + default: + ARM_COMPUTE_ERROR("Data type not supported"); + break; + } + break; + } + default: + { + ARM_COMPUTE_ERROR("Only kernel sizes 1x1, 3x3 and 5x5 are supported."); + break; + } + } + } + else + { + switch(src->info()->data_type()) + { + case DataType::F32: + { + if(have_zero_x_internal_padding(src->info(), weights->info())) + { + convolve_nhwc_optimized(window, src, weights, dst); + } + else + { + convolve_nhwc(window, src, weights, dst); + } + break; + } + default: + ARM_COMPUTE_ERROR("Data type not supported"); + break; + } + } +} +const char *CpuDirectConvolutionKernel::name() const +{ + return "CpuDirectConvolutionLayerKernel"; +} +} // namespace kernels +} // namespace cpu +} // namespace arm_compute diff --git a/src/core/cpu/kernels/CpuDirectConvolutionKernel.h b/src/core/cpu/kernels/CpuDirectConvolutionKernel.h new file mode 100644 index 0000000000..fb8218394b --- /dev/null +++ b/src/core/cpu/kernels/CpuDirectConvolutionKernel.h @@ -0,0 +1,100 @@ +/* + * Copyright (c) 2017-2021 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#ifndef ARM_COMPUTE_CPU_DIRECTCONVOLUTION_KERNEL_H +#define ARM_COMPUTE_CPU_DIRECTCONVOLUTION_KERNEL_H + +#include "src/core/common/Macros.h" +#include "src/core/cpu/ICpuKernel.h" + +namespace arm_compute +{ +class ITensor; +namespace cpu +{ +namespace kernels +{ +/** Interface for the kernel to perform Direct Convolution Layer. 
*/ +class CpuDirectConvolutionKernel : public ICpuKernel +{ +public: + /** Default constructor */ + CpuDirectConvolutionKernel() = default; + ARM_COMPUTE_DISALLOW_COPY_ALLOW_MOVE(CpuDirectConvolutionKernel); + /** Set the input, weights, and output tensors. + * + * @note: DirectConvolution only works in the following configurations: + * 1x1 convolution with stride_x = 1/2/3, stride_y = 1/2/3 + * 3x3 convolution with stride_x = 1/2/3, stride_y = 1/2/3 + * + * @param[in] src The input tensor to convolve. 3 lower dimensions represent a single input [width, height, IFM], + * while every optional dimension from 4 and above represents a batch of inputs. Data types supported: F16/F32. + * @param[in] weights Weights tensor. Weights are a 4D tensor with dimensions [kernel_x, kernel_y, IFM, OFM]. + * The 3rd dimension must be the same as the input's volume 3rd dimension. + * Data type supported: Same as @p src. + * @param[out] dst Output tensor. + * The 3rd dimension must be equal to the 4th dimension of the @p weights tensor. Data types supported: F16/F32 + * @param[in] conv_info Contains padding and stride information described in @ref PadStrideInfo. + */ + void configure(ITensorInfo *src, ITensorInfo *weights, ITensorInfo *dst, const PadStrideInfo &conv_info); + /** Static function to check if given info will lead to a valid configuration of @ref CpuDirectConvolutionKernel + * + * @param[in] src The input tensor to convolve. 3 lower dimensions represent a single input [width, height, IFM], + * while every optional dimension from 4 and above represents a batch of inputs. Data types supported: F16/F32. + * @param[in] weights Weights tensor. Weights are a 4D tensor with dimensions [kernel_x, kernel_y, IFM, OFM]. + * The 3rd dimension must be the same as the input's volume 3rd dimension. + * Data type supported: Same as @p src. + * @param[in] dst Output tensor. + * The 3rd dimension must be equal to the 4th dimension of the @p weights tensor. Data types supported: F16/F32 + * @param[in] conv_info Contains padding and stride information described in @ref PadStrideInfo. 
+ * + * @return a status + */ + static Status validate(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *dst, const PadStrideInfo &conv_info); + + // Inherited methods overridden: + void run_op(ITensorPack &tensors, const Window &window, const ThreadInfo &info) override; + const char *name() const override; + BorderSize border_size() const override; + +private: + /* Template function for optimized convolution NHWC */ + template + void convolve_nhwc_optimized(const Window &window, const ITensor *src, const ITensor *weights, ITensor *dst); + + /* Template function for convolution NHWC */ + template + void convolve_nhwc(const Window &window, const ITensor *src, const ITensor *weights, ITensor *dst); + + PadStrideInfo _conv_info{}; + BorderSize _border_size{}; + unsigned int _kernel_size{ 0 }; + unsigned int _num_weight_elems_read_per_row{ 0 }; + unsigned int _num_elems_read_per_iteration{ 0 }; + unsigned int _num_elems_written_per_iteration{ 0 }; + DataLayout _data_layout{ DataLayout::UNKNOWN }; +}; +} // namespace kernels +} // namespace cpu +} // namespace arm_compute +#endif /*ARM_COMPUTE_CPU_DIRECTCONVOLUTION_KERNEL_H */ diff --git a/src/core/cpu/kernels/CpuDirectConvolutionOutputStageKernel.h b/src/core/cpu/kernels/CpuDirectConvolutionOutputStageKernel.h new file mode 100644 index 0000000000..9eeab194cb --- /dev/null +++ b/src/core/cpu/kernels/CpuDirectConvolutionOutputStageKernel.h @@ -0,0 +1,93 @@ +/* + * Copyright (c) 2017-2021 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#ifndef ARM_COMPUTE_CPU_DIRECTCONVOLUTION_OUTPUTSTAGE_KERNEL_H +#define ARM_COMPUTE_CPU_DIRECTCONVOLUTION_OUTPUTSTAGE_KERNEL_H + +#include "arm_compute/core/KernelDescriptors.h" +#include "src/core/common/Macros.h" +#include "src/core/cpu/ICpuKernel.h" + +namespace arm_compute +{ +class ITensor; +namespace cpu +{ +namespace kernels +{ +/** Kernel to accumulate the biases, if provided, or downscale in case of quantized input. + * + * @note We assume bias to be shared + * @note For quantized computations (i.e. @p src of S32 type) the output data type for auto-initialization must be passed as part + * of the @ref DirectConvolutionLayerOutputStageKernelInfo. 
+ */ +class CpuDirectConvolutionOutputStageKernel : public ICpuKernel +{ +public: + /** Default constructor */ + CpuDirectConvolutionOutputStageKernel() = default; + ARM_COMPUTE_DISALLOW_COPY_ALLOW_MOVE(CpuDirectConvolutionOutputStageKernel); + /** Set the accumulate buffer and the biases of the kernel. + * + * @param[in, out] src Input to add the bias to. If @p output is not specified then accumulation is done in-place. + * Data type supported: F16/F32/S32 + * @param[in] bias (Optional) The shared bias tensor to add. It must be 1D Tensor. Data type supported: Same as @p src + * @param[out] dst (Optional) If the output tensor is specified the accumulation is done out-of-place. (Defaults to nullptr) + * Note that in-place computation is only supported for F16/F32. For S32 this must not be nullptr. + * Data type supported: F16/F32 or QASYMM8/QASYMM8_SIGNED if @p src is S32 + * @param[in] info (Optional) DirectConvolutionLayerOutputStageKernel descriptor metadata + */ + void configure(ITensorInfo *src, const ITensorInfo *bias = nullptr, ITensorInfo *dst = nullptr, + const DirectConvolutionLayerOutputStageKernelInfo &info = DirectConvolutionLayerOutputStageKernelInfo()); + /** Static function to check if given info will lead to a valid configuration of @ref CpuDirectConvolutionOutputStageKernel + * + * @param[in] src Input to add the bias to. If @p output is not specified then accumulation is done in-place. + * Data type supported: F16/F32/S32 + * @param[in] bias (Optional) The shared bias tensor to add. It must be 1D Tensor. Data type supported: Same as @p src + * @param[in] dst (Optional) If the output tensor is specified the accumulation is done out-of-place. (Defaults to nullptr) + * Note that in-place computation is only supported for F16/F32. For S32 this must not be nullptr. + * Data type supported: F16/F32 or QASYMM8/QASYMM8_SIGNED if @p src is S32 + * @param[in] info (Optional) DirectConvolutionLayerOutputStageKernel descriptor metadata + * + * @return a status + */ + static Status validate(const ITensorInfo *src, const ITensorInfo *bias = nullptr, const ITensorInfo *dst = nullptr, + const DirectConvolutionLayerOutputStageKernelInfo &info = DirectConvolutionLayerOutputStageKernelInfo()); + + // Inherited methods overridden: + void run_op(ITensorPack &tensors, const Window &window, const ThreadInfo &info) override; + const char *name() const override; + +private: + using OutputStageKernel = void(ITensor *src, const ITensor *bias, const Window &window, ITensor *dst, + int result_fixedpoint_multiplier, int result_shift, int result_offset_after_shift); + + OutputStageKernel *_func{ nullptr }; + int _result_fixedpoint_multiplier{ 0 }; + int _result_shift{ 0 }; + int _result_offset_after_shift{ 0 }; +}; +} // namespace kernels +} // namespace cpu +} // namespace arm_compute +#endif /*ARM_COMPUTE_CPU_DIRECTCONVOLUTION_OUTPUTSTAGE_KERNEL_H */ diff --git a/src/core/cpu/kernels/CpuDirectConvolutionStageKernel.cpp b/src/core/cpu/kernels/CpuDirectConvolutionStageKernel.cpp new file mode 100644 index 0000000000..d955b0b461 --- /dev/null +++ b/src/core/cpu/kernels/CpuDirectConvolutionStageKernel.cpp @@ -0,0 +1,514 @@ +/* + * Copyright (c) 2017-2021 Arm Limited. 
+ * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#include "src/core/cpu/kernels/CpuDirectConvolutionOutputStageKernel.h" + +#include "arm_compute/core/Error.h" +#include "arm_compute/core/Helpers.h" +#include "arm_compute/core/ITensor.h" +#include "arm_compute/core/Types.h" +#include "arm_compute/core/Validate.h" +#include "arm_compute/core/Window.h" +#include "arm_compute/core/utils/misc/Traits.h" +#include "src/core/AccessWindowStatic.h" +#include "src/core/CPP/Validate.h" +#include "src/core/NEON/NEAsymm.h" +#include "src/core/NEON/NEFixedPoint.h" +#include "src/core/NEON/wrapper/wrapper.h" +#include "src/core/helpers/AutoConfiguration.h" +#include "src/core/helpers/WindowHelpers.h" + +#include +#include +#include + +namespace arm_compute +{ +namespace cpu +{ +namespace kernels +{ +namespace +{ +Status validate_arguments(const ITensorInfo *src, const ITensorInfo *bias, const ITensorInfo *dst, + const DirectConvolutionLayerOutputStageKernelInfo &info) +{ + ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(src); + ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(src); + ARM_COMPUTE_RETURN_ERROR_ON(src->data_layout() == DataLayout::UNKNOWN); + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(src, 1, DataType::F16, DataType::S32, DataType::F32); + + if(bias != nullptr) + { + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(src, bias); + ARM_COMPUTE_RETURN_ERROR_ON(bias->dimension(0) != src->dimension(get_data_layout_dimension_index(src->data_layout(), DataLayoutDimension::CHANNEL))); + ARM_COMPUTE_RETURN_ERROR_ON(bias->num_dimensions() > 1); + } + + if(src->data_type() == DataType::S32) + { + ARM_COMPUTE_RETURN_ERROR_ON_MSG(dst == nullptr, "In-place computation not allowed for quantized output"); + } + + // Checks performed when output is configured + if((dst != nullptr) && (dst->total_size() != 0)) + { + if(is_data_type_float(src->data_type())) + { + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(src, dst); + } + else + { + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(dst, 1, DataType::QASYMM8, DataType::QASYMM8_SIGNED); + } + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(src, dst); + } + else if(src->data_type() == DataType::S32) + { + // In case of quantized computation and unconfigured output, the output data type must be provided through DirectConvolutionLayerOutputStageKernelInfo + ARM_COMPUTE_RETURN_ERROR_ON((info.output_data_type != DataType::QASYMM8) && (info.output_data_type != DataType::QASYMM8_SIGNED)); + } + 
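For context (editorial addition, not part of the committed patch): when the accumulator is S32, the caller describes the requantization through the kernel-info descriptor checked above. A minimal sketch of how such a descriptor might be filled, using the fields referenced in this file and purely illustrative values, is:

    DirectConvolutionLayerOutputStageKernelInfo info{};
    info.result_fixedpoint_multiplier = 1073741824;        // illustrative requantization multiplier
    info.result_shift                 = 1;                 // illustrative right shift
    info.result_offset_after_shift    = 10;                // illustrative output offset (zero point)
    info.output_data_type             = DataType::QASYMM8; // required when src is S32
    // ...then passed to CpuDirectConvolutionOutputStageKernel::configure(src, bias, dst, info)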
+ return Status{}; +} + +template +typename std::enable_if::value, void>::type +output_stage_nchw(ITensor *src, const ITensor *bias, const Window &window, ITensor *dst, + int result_fixedpoint_multiplier, int result_shift, int result_offset_after_shift) +{ + const bool has_bias = bias != nullptr; + /** SIMD vector tag type. */ + using ExactTagType = typename wrapper::traits::neon_bitvector_tag_t; + + ARM_COMPUTE_ERROR_ON(src->info()->data_layout() == DataLayout::UNKNOWN); + ARM_COMPUTE_UNUSED(result_fixedpoint_multiplier); + ARM_COMPUTE_UNUSED(result_shift); + ARM_COMPUTE_UNUSED(result_offset_after_shift); + + const int window_start_x = window.x().start(); + const int window_end_x = window.x().end(); + const int window_step_x = 16 / src->info()->element_size(); + Window win = window; + win.set(Window::DimX, Window::Dimension(0, 1, 1)); + + Iterator in(src, win); + Iterator out(dst, win); + execute_window_loop(win, [&](const Coordinates & id) + { + int x = window_start_x; + for(; x <= (window_end_x - window_step_x); x += window_step_x) + { + // Get bias and pointer to input + const auto in_ptr = reinterpret_cast(in.ptr()) + x; + auto v_in = wrapper::vloadq(in_ptr); + + // Accumulate bias + if(has_bias) + { + const auto vb = wrapper::vdup_n(*reinterpret_cast(bias->ptr_to_element(Coordinates(id.z()))), ExactTagType{}); + v_in = wrapper::vadd(v_in, vb); + } + + const auto out_ptr = reinterpret_cast(out.ptr()) + x; + wrapper::vstore(out_ptr, v_in); + } + + // Left-overs loop + for(; x < window_end_x; ++x) + { + // Get bias and pointer to input + auto s_in = *(reinterpret_cast(in.ptr()) + x); + + // Accumulate bias + if(has_bias) + { + const auto b = *reinterpret_cast(bias->ptr_to_element(Coordinates(id.z()))); + s_in += b; + } + + *(reinterpret_cast(out.ptr()) + x) = s_in; + } + + }, + in, out); +} + +template +typename std::enable_if::value, void>::type +output_stage_nhwc(ITensor *src, const ITensor *bias, const Window &window, ITensor *dst, + int result_fixedpoint_multiplier, int result_shift, int result_offset_after_shift) +{ + const bool has_bias = bias != nullptr; + ARM_COMPUTE_UNUSED(result_fixedpoint_multiplier); + ARM_COMPUTE_UNUSED(result_shift); + ARM_COMPUTE_UNUSED(result_offset_after_shift); + + Window window_bias = window; + window_bias.set(Window::DimX, Window::Dimension(0, 1, 1)); + window_bias.set(Window::DimY, Window::Dimension(0, 0, 0)); + window_bias.set(Window::DimZ, Window::Dimension(0, 0, 0)); + window_bias.set(3, Window::Dimension(0, 0, 0)); + + const int window_start_x = window.x().start(); + const int window_end_x = window.x().end(); + const int window_step_x = 16 / src->info()->element_size(); + Window win = window; + win.set(Window::DimX, Window::Dimension(0, 1, 1)); + + Iterator in(src, win); + Iterator bi(bias, window_bias); + Iterator out(dst, win); + + execute_window_loop(win, [&](const Coordinates &) + { + int x = window_start_x; + for(; x <= (window_end_x - window_step_x); x += window_step_x) + { + // Get bias and pointer to input + const auto in_ptr = reinterpret_cast(in.ptr()); + auto v_in = wrapper::vloadq(in_ptr + x); + + // Accumulate bias + if(has_bias) + { + const auto bias_ptr = reinterpret_cast(bi.ptr()) + x; + v_in = wrapper::vadd(v_in, wrapper::vloadq(bias_ptr)); + } + + const auto out_ptr = reinterpret_cast(out.ptr()); + wrapper::vstore(out_ptr + x, v_in); + } + + // Left-overs loop + for(; x < window_end_x; ++x) + { + // Get bias and pointer to input + auto s_in = *(reinterpret_cast(in.ptr()) + x); + + // Accumulate bias + if(has_bias) + { + const 
auto bias_ptr = reinterpret_cast(bi.ptr()) + x; + s_in += *bias_ptr; + } + + const auto out_ptr = reinterpret_cast(out.ptr()); + *(out_ptr + x) = s_in; + } + }, + in, bi, out); +} + +// Quantized case +template < typename TOut, typename std::enable_if < std::is_same::value || std::is_same::value, int >::type = 0 > +void output_stage_nchw(ITensor *src, const ITensor *bias, const Window &window, ITensor *dst, + int result_fixedpoint_multiplier, int result_shift, int result_offset_after_shift) +{ + const bool has_bias = bias != nullptr; + using VectorType = typename wrapper::traits::neon_bitvector_t; + using TagType = typename wrapper::traits::neon_bitvector_tag_t; + + const int32x4_t result_offset_after_shift_s32 = vdupq_n_s32(result_offset_after_shift); + + const VectorType min = wrapper::vdup_n(std::numeric_limits::lowest(), TagType{}); + const VectorType max = wrapper::vdup_n(std::numeric_limits::max(), TagType{}); + + const int window_start_x = window.x().start(); + const int window_end_x = window.x().end(); + const int window_step_x = 16 / src->info()->element_size(); + Window win = window; + win.set(Window::DimX, Window::Dimension(0, 1, 1)); + + Iterator in(src, win); + Iterator out(dst, win); + + execute_window_loop(win, [&](const Coordinates & id) + { + + int x = window_start_x; + for(; x <= (window_end_x - window_step_x); x += window_step_x) + { + // Get bias and pointer to input + const auto in_ptr = reinterpret_cast(in.ptr()) + x; + int32x4x4_t v_in = + { + { + wrapper::vloadq(in_ptr), + wrapper::vloadq(in_ptr + 4), + wrapper::vloadq(in_ptr + 8), + wrapper::vloadq(in_ptr + 12) + } + }; + + // Accumulate bias + if(has_bias) + { + const auto vb = wrapper::vdup_n(*reinterpret_cast(bias->ptr_to_element(Coordinates(id.z()))), TagType{}); + v_in = + { + { + wrapper::vadd(v_in.val[0], vb), + wrapper::vadd(v_in.val[1], vb), + wrapper::vadd(v_in.val[2], vb), + wrapper::vadd(v_in.val[3], vb) + } + }; + } + + const auto out_ptr = reinterpret_cast(out.ptr()) + x; + wrapper::vstore(out_ptr, finalize_quantization(v_in, result_fixedpoint_multiplier, result_shift, result_offset_after_shift_s32, + min, max, false)); + } + + // Left-overs loop + for(; x < window_end_x; ++x) + { + // Get bias and pointer to input + int32_t s_in = *(reinterpret_cast(in.ptr()) + x); + + // Accumulate bias + if(has_bias) + { + const auto b = *reinterpret_cast(bias->ptr_to_element(Coordinates(id.z()))); + s_in += b; + } + + const auto out_ptr = reinterpret_cast(out.ptr()) + x; + *out_ptr = finalize_quantization(s_in, result_fixedpoint_multiplier, result_shift, result_offset_after_shift, + std::numeric_limits::lowest(), std::numeric_limits::max(), false); + } + }, + in, out); +} +template < typename TOut, typename std::enable_if < std::is_same::value || std::is_same::value, int >::type = 0 > +void output_stage_nhwc(ITensor *src, const ITensor *bias, const Window &window, ITensor *dst, + int result_fixedpoint_multiplier, int result_shift, int result_offset_after_shift) +{ + const bool has_bias = bias != nullptr; + using VectorType = typename wrapper::traits::neon_bitvector_t; + using TagType = typename wrapper::traits::neon_bitvector_tag_t; + + const int32x4_t result_offset_after_shift_s32 = vdupq_n_s32(result_offset_after_shift); + + const VectorType min = wrapper::vdup_n(std::numeric_limits::lowest(), TagType{}); + const VectorType max = wrapper::vdup_n(std::numeric_limits::max(), TagType{}); + + Window window_bias = window; + window_bias.set(Window::DimX, Window::Dimension(0, 1, 1)); + window_bias.set(Window::DimY, 
Window::Dimension(0, 0, 0)); + window_bias.set(Window::DimZ, Window::Dimension(0, 0, 0)); + window_bias.set(3, Window::Dimension(0, 0, 0)); + + const int window_start_x = window.x().start(); + const int window_end_x = window.x().end(); + const int window_step_x = 16 / src->info()->element_size(); + Window win = window; + win.set(Window::DimX, Window::Dimension(0, 1, 1)); + + Iterator in(src, win); + Iterator bi(bias, window_bias); + Iterator out(dst, win); + + execute_window_loop(win, [&](const Coordinates &) + { + int x = window_start_x; + for(; x <= (window_end_x - window_step_x); x += window_step_x) + { + // Get bias and pointer to input + const auto in_ptr = reinterpret_cast(in.ptr()) + x; + int32x4x4_t v_in = + { + { + wrapper::vloadq(in_ptr), + wrapper::vloadq(in_ptr + 4), + wrapper::vloadq(in_ptr + 8), + wrapper::vloadq(in_ptr + 12), + } + }; + + // Accumulate bias (the sums must be assigned back into v_in) + if(has_bias) + { + const auto bias_ptr = reinterpret_cast(bi.ptr()) + x; + + v_in.val[0] = wrapper::vadd(v_in.val[0], wrapper::vloadq(bias_ptr)); + v_in.val[1] = wrapper::vadd(v_in.val[1], wrapper::vloadq(bias_ptr + 4)); + v_in.val[2] = wrapper::vadd(v_in.val[2], wrapper::vloadq(bias_ptr + 8)); + v_in.val[3] = wrapper::vadd(v_in.val[3], wrapper::vloadq(bias_ptr + 12)); + } + + const auto out_ptr = reinterpret_cast(out.ptr()) + x; + wrapper::vstore(out_ptr, finalize_quantization(v_in, result_fixedpoint_multiplier, result_shift, result_offset_after_shift_s32, min, max, false)); + } + + // Left-overs loop + for(; x < window_end_x; ++x) + { + // Get bias and pointer to input + const auto in_ptr = reinterpret_cast(in.ptr()) + x; + int32_t s_in = *in_ptr; + + // Accumulate bias + if(has_bias) + { + const auto bias_ptr = reinterpret_cast(bi.ptr()) + x; + s_in += *bias_ptr; + } + + const auto out_ptr = reinterpret_cast(out.ptr()) + x; + *out_ptr = finalize_quantization(s_in, result_fixedpoint_multiplier, result_shift, result_offset_after_shift, + std::numeric_limits::lowest(), std::numeric_limits::max(), false); + } + }, + in, bi, out); +} +} // namespace + +void CpuDirectConvolutionOutputStageKernel::configure(ITensorInfo *src, const ITensorInfo *bias, ITensorInfo *dst, + const DirectConvolutionLayerOutputStageKernelInfo &info) +{ + ARM_COMPUTE_UNUSED(bias); + // Perform validation step + ARM_COMPUTE_ERROR_ON_NULLPTR(src); + ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(src, bias, dst, info)); + + _func = nullptr; + _result_fixedpoint_multiplier = info.result_fixedpoint_multiplier; + _result_shift = info.result_shift; + _result_offset_after_shift = info.result_offset_after_shift; + + // Auto-initialize output if required + if(dst != nullptr) + { + // Work out expected output data type + const DataType output_dt = (src->data_type() == DataType::S32) ? info.output_data_type : DataType::S32; + // Output tensor auto initialization if not yet initialized + auto_init_if_empty(*dst, src->clone()->set_data_type(output_dt)); + } + + Window win = calculate_max_window(*src, Steps()); + + ICpuKernel::configure(win); + + const bool is_qasymm8_signed = (dst != nullptr) ?
is_data_type_quantized_asymmetric_signed(dst->data_type()) : false; + + // Set appropriate function + if(src->data_layout() == DataLayout::NCHW) + { + switch(src->data_type()) + { + case DataType::S32: + { + if(is_qasymm8_signed) + { + _func = &output_stage_nchw; + } + else + { + _func = &output_stage_nchw; + } + break; + } +#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + case DataType::F16: + { + _func = &output_stage_nchw; + break; + } +#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */ + case DataType::F32: + { + _func = &output_stage_nchw; + break; + } + default: + { + ARM_COMPUTE_ERROR("Unsupported combination of types among the inputs."); + } + } + } + else + { + switch(src->data_type()) + { + case DataType::S32: + { + if(is_qasymm8_signed) + { + _func = &output_stage_nhwc; + } + else + { + _func = &output_stage_nhwc; + } + break; + } +#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + case DataType::F16: + { + _func = &output_stage_nhwc; + break; + } +#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */ + case DataType::F32: + { + _func = &output_stage_nhwc; + break; + } + default: + { + ARM_COMPUTE_ERROR("Unsupported combination of types among the inputs."); + } + } + } +} + +Status CpuDirectConvolutionOutputStageKernel::validate(const ITensorInfo *src, const ITensorInfo *bias, const ITensorInfo *dst, + const DirectConvolutionLayerOutputStageKernelInfo &info) +{ + ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(src, bias, dst, info)); + return Status{}; +} + +void CpuDirectConvolutionOutputStageKernel::run_op(ITensorPack &tensors, const Window &window, const ThreadInfo &info) +{ + ARM_COMPUTE_UNUSED(info); + ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this); + ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(ICpuKernel::window(), window); + ARM_COMPUTE_ERROR_ON(_func == nullptr); + + auto src = tensors.get_tensor(TensorType::ACL_SRC_0); + auto bias = tensors.get_const_tensor(TensorType::ACL_SRC_1); + auto dst = tensors.get_tensor(TensorType::ACL_DST); + + (*_func)(src, bias, window, dst, _result_fixedpoint_multiplier, _result_shift, _result_offset_after_shift); +} + +const char *CpuDirectConvolutionOutputStageKernel::name() const +{ + return "CpuDirectConvolutionOutputStageKernel"; +} +} // namespace kernels +} // namespace cpu +} // namespace arm_compute diff --git a/src/runtime/NEON/functions/NEDirectConvolutionLayer.cpp b/src/runtime/NEON/functions/NEDirectConvolutionLayer.cpp index a953edc78f..73834381c6 100644 --- a/src/runtime/NEON/functions/NEDirectConvolutionLayer.cpp +++ b/src/runtime/NEON/functions/NEDirectConvolutionLayer.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2020 Arm Limited. + * Copyright (c) 2017-2021 Arm Limited. 
 *
 * SPDX-License-Identifier: MIT
 *
@@ -27,107 +27,48 @@
 #include "arm_compute/core/Utils.h"
 #include "arm_compute/core/Validate.h"
 #include "arm_compute/runtime/NEON/NEScheduler.h"
-#include "src/core/NEON/kernels/NEDirectConvolutionLayerKernel.h"
-#include "src/core/NEON/kernels/NEDirectConvolutionLayerOutputStageKernel.h"
-#include "src/core/NEON/kernels/NEFillBorderKernel.h"
+#include "src/runtime/cpu/operators/CpuDirectConvolution.h"
 
 namespace arm_compute
 {
-NEDirectConvolutionLayer::~NEDirectConvolutionLayer() = default;
+struct NEDirectConvolutionLayer::Impl
+{
+    ITensor       *src{ nullptr };
+    const ITensor *weights{ nullptr };
+    const ITensor *bias{ nullptr };
+    ITensor       *dst{ nullptr };
+    std::unique_ptr<cpu::CpuDirectConvolution> op{ nullptr };
+};
 
 NEDirectConvolutionLayer::NEDirectConvolutionLayer(std::shared_ptr<IMemoryManager> memory_manager)
-    : _memory_group(std::move(memory_manager)), _output_stage_kernel(), _conv_kernel(), _input_border_handler(), _activationlayer_function(), _accumulator(), _has_bias(false),
-      _is_activationlayer_enabled(false), _dim_split(Window::DimZ), _is_padding_required()
+    : _memory_manager(std::move(memory_manager)), _impl(std::make_unique<Impl>())
 {
 }
+NEDirectConvolutionLayer::~NEDirectConvolutionLayer() = default;
 
 void NEDirectConvolutionLayer::configure(ITensor *input, const ITensor *weights, const ITensor *bias, ITensor *output, const PadStrideInfo &conv_info, const ActivationLayerInfo &act_info)
 {
-    ARM_COMPUTE_ERROR_ON(input->info()->data_layout() == DataLayout::UNKNOWN);
-    _output_stage_kernel  = std::make_unique<NEDirectConvolutionLayerOutputStageKernel>();
-    _conv_kernel          = std::make_unique<NEDirectConvolutionLayerKernel>();
-    _input_border_handler = std::make_unique<NEFillBorderKernel>();
-
-    // Free accumulator
-    if(_accumulator.buffer() != nullptr)
-    {
-        _accumulator.allocator()->free();
-    }
-
-    _dim_split = input->info()->data_layout() == DataLayout::NCHW ? Window::DimZ : Window::DimY;
-
-    // Check if bias should be added in the convolution result
-    _has_bias = (bias != nullptr);
-
-    _conv_kernel->configure(input, weights, output, conv_info);
-    if(_has_bias)
-    {
-        _output_stage_kernel->configure(output, bias);
-    }
-    _is_padding_required = !_conv_kernel->border_size().empty();
-
-    if(_is_padding_required)
-    {
-        // Add zero padding XY
-        _input_border_handler->configure(input, _conv_kernel->border_size(), BorderMode::CONSTANT, PixelValue(static_cast<float>(0.f)));
-    }
-
-    //Configure Activation Layer
-    _is_activationlayer_enabled = act_info.enabled();
-    if(_is_activationlayer_enabled)
-    {
-        _activationlayer_function.configure(output, nullptr, act_info);
-    }
+    _impl->src     = input;
+    _impl->weights = weights;
+    _impl->bias    = bias;
+    _impl->dst     = output;
+    _impl->op      = std::make_unique<cpu::CpuDirectConvolution>(_memory_manager);
+    _impl->op->configure(input->info(), weights->info(), (bias != nullptr ? bias->info() : nullptr), output->info(), conv_info, act_info);
 }
 
 Status NEDirectConvolutionLayer::validate(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *bias, const ITensorInfo *output, const PadStrideInfo &conv_info,
                                           const ActivationLayerInfo &act_info)
 {
-    ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, weights, output);
-
-    // output might not be initialized since it can be an intermediate tensor of another layer
-    DataType   data_type = input->data_type();
-    TensorInfo accumulator(output->clone()->set_is_resizable(true).reset_padding().set_data_type(data_type));
-
-    // Validate Convolution kernel
-    ARM_COMPUTE_RETURN_ON_ERROR(NEDirectConvolutionLayerKernel::validate(input, weights, &accumulator, conv_info));
-
-    if(bias != nullptr)
-    {
-        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(weights, bias);
-        ARM_COMPUTE_RETURN_ERROR_ON_MSG(bias->dimension(0) != weights->dimension(3),
-                                        "Biases size and number of input feature maps should match");
-        ARM_COMPUTE_RETURN_ERROR_ON_MSG(bias->num_dimensions() > 1, "Biases should be one dimensional");
-    }
-
-    // Validate bias kernel
-    ARM_COMPUTE_RETURN_ON_ERROR(NEDirectConvolutionLayerOutputStageKernel::validate(&accumulator, bias, output));
-
-    if(act_info.enabled())
-    {
-        ARM_COMPUTE_RETURN_ON_ERROR(NEActivationLayer::validate(output, nullptr, act_info));
-    }
-
-    return Status{};
+    return cpu::CpuDirectConvolution::validate(input, weights, bias, output, conv_info, act_info);
 }
 
 void NEDirectConvolutionLayer::run()
 {
-    MemoryGroupResourceScope scope_mg(_memory_group);
-
-    if(_is_padding_required)
-    {
-        NEScheduler::get().schedule(_input_border_handler.get(), Window::DimZ);
-    }
-    NEScheduler::get().schedule(_conv_kernel.get(), _dim_split);
-    if(_has_bias)
-    {
-        NEScheduler::get().schedule(_output_stage_kernel.get(), Window::DimY);
-    }
-
-    if(_is_activationlayer_enabled)
-    {
-        _activationlayer_function.run();
-    }
+    ITensorPack pack;
+    pack.add_tensor(TensorType::ACL_SRC_0, _impl->src);
+    pack.add_tensor(TensorType::ACL_SRC_1, _impl->weights);
+    pack.add_tensor(TensorType::ACL_SRC_2, _impl->bias);
+    pack.add_tensor(TensorType::ACL_DST, _impl->dst);
+    _impl->op->run(pack);
 }
 } // namespace arm_compute
diff --git a/src/runtime/cpu/operators/CpuDirectConvolution.cpp b/src/runtime/cpu/operators/CpuDirectConvolution.cpp
new file mode 100644
index 0000000000..33f79603e8
--- /dev/null
+++ b/src/runtime/cpu/operators/CpuDirectConvolution.cpp
@@ -0,0 +1,147 @@
+/*
+ * Copyright (c) 2021 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#include "src/runtime/cpu/operators/CpuDirectConvolution.h"
+
+#include "arm_compute/core/PixelValue.h"
+#include "arm_compute/core/Utils.h"
+#include "arm_compute/core/Validate.h"
+#include "arm_compute/runtime/NEON/NEScheduler.h"
+
+namespace arm_compute
+{
+namespace cpu
+{
+CpuDirectConvolution::~CpuDirectConvolution() = default;
+
+CpuDirectConvolution::CpuDirectConvolution(std::shared_ptr<IMemoryManager> memory_manager)
+    : _memory_group(std::move(memory_manager)), _output_stage_kernel(), _conv_kernel(), _input_border_handler(), _activationlayer_function(), _accumulator(), _has_bias(false),
+      _is_activationlayer_enabled(false), _dim_split(Window::DimZ), _is_padding_required()
+{
+}
+
+void CpuDirectConvolution::configure(ITensorInfo *src, ITensorInfo *weights, const ITensorInfo *bias, ITensorInfo *dst, const PadStrideInfo &conv_info, const ActivationLayerInfo &act_info)
+{
+    ARM_COMPUTE_ERROR_ON(src->data_layout() == DataLayout::UNKNOWN);
+    _output_stage_kernel  = std::make_unique<kernels::CpuDirectConvolutionOutputStageKernel>();
+    _conv_kernel          = std::make_unique<kernels::CpuDirectConvolutionKernel>();
+    _input_border_handler = std::make_unique<NEFillBorderKernel>();
+
+    // Free accumulator
+    if(_accumulator.buffer() != nullptr)
+    {
+        _accumulator.allocator()->free();
+    }
+
+    _dim_split = src->data_layout() == DataLayout::NCHW ? Window::DimZ : Window::DimY;
+
+    // Check if bias should be added in the convolution result
+    _has_bias = (bias != nullptr);
+
+    _conv_kernel->configure(src, weights, dst, conv_info);
+    if(_has_bias)
+    {
+        _output_stage_kernel->configure(dst, bias);
+    }
+    _is_padding_required = !_conv_kernel->border_size().empty();
+
+    if(_is_padding_required)
+    {
+        // Add zero padding XY
+        _input_border_handler->configure(src, _conv_kernel->border_size(), BorderMode::CONSTANT, PixelValue(static_cast<float>(0.f)));
+    }
+
+    //Configure Activation Layer
+    _is_activationlayer_enabled = act_info.enabled();
+    if(_is_activationlayer_enabled)
+    {
+        _activationlayer_function = std::make_unique<CpuActivation>();
+        _activationlayer_function->configure(dst, dst, act_info);
+    }
+}
+
+Status CpuDirectConvolution::validate(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *bias, const ITensorInfo *dst, const PadStrideInfo &conv_info,
+                                      const ActivationLayerInfo &act_info)
+{
+    ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(src, weights, dst);
+
+    // output might not be initialized since it can be an intermediate tensor of another layer
+    DataType   data_type = src->data_type();
+    TensorInfo accumulator(dst->clone()->set_is_resizable(true).reset_padding().set_data_type(data_type));
+
+    // Validate Convolution kernel
+    ARM_COMPUTE_RETURN_ON_ERROR(kernels::CpuDirectConvolutionKernel::validate(src, weights, &accumulator, conv_info));
+
+    if(bias != nullptr)
+    {
+        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(weights, bias);
+        ARM_COMPUTE_RETURN_ERROR_ON_MSG(bias->dimension(0) != weights->dimension(3),
+                                        "Biases size and number of input feature maps should match");
+        ARM_COMPUTE_RETURN_ERROR_ON_MSG(bias->num_dimensions() > 1, "Biases should be one dimensional");
+    }
+
+    // Validate bias kernel
+    ARM_COMPUTE_RETURN_ON_ERROR(kernels::CpuDirectConvolutionOutputStageKernel::validate(&accumulator, bias, dst));
+
+    if(act_info.enabled())
+    {
+        ARM_COMPUTE_RETURN_ON_ERROR(CpuActivation::validate(dst, nullptr, act_info));
+    }
+
+    return Status{};
+}
+
+void CpuDirectConvolution::run(ITensorPack &tensors)
+{
+    MemoryGroupResourceScope scope_mg(_memory_group);
+
+    auto src  = tensors.get_tensor(TensorType::ACL_SRC_0);
+    auto bias = tensors.get_const_tensor(TensorType::ACL_SRC_2);
+    auto dst  = tensors.get_tensor(TensorType::ACL_DST);
+
+    if(_is_padding_required)
+    {
+        ITensorPack pack;
+        pack.add_tensor(TensorType::ACL_SRC_DST, src);
+        NEScheduler::get().schedule_op(_input_border_handler.get(), Window::DimZ, _input_border_handler->window(), pack);
+    }
+    NEScheduler::get().schedule_op(_conv_kernel.get(), _dim_split, _conv_kernel->window(), tensors);
+    if(_has_bias)
+    {
+        ITensorPack pack;
+        pack.add_tensor(TensorType::ACL_SRC_0, dst);
+        pack.add_tensor(TensorType::ACL_SRC_1, bias);
+        pack.add_tensor(TensorType::ACL_DST, dst);
+        NEScheduler::get().schedule_op(_output_stage_kernel.get(), Window::DimY, _output_stage_kernel->window(), pack);
+    }
+
+    if(_is_activationlayer_enabled)
+    {
+        ITensorPack pack;
+        pack.add_tensor(TensorType::ACL_SRC, dst);
+        pack.add_tensor(TensorType::ACL_DST, dst);
+        _activationlayer_function->run(pack);
+    }
+}
+} // namespace cpu
+} // namespace arm_compute
diff --git a/src/runtime/cpu/operators/CpuDirectConvolution.h b/src/runtime/cpu/operators/CpuDirectConvolution.h
new file mode 100644
index 0000000000..0635e087fd
--- /dev/null
+++ b/src/runtime/cpu/operators/CpuDirectConvolution.h
@@ -0,0 +1,121 @@
+/*
+ * Copyright (c) 2021 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#ifndef ARM_COMPUTE_CPU_DIRECTCONVOLUTION_H
+#define ARM_COMPUTE_CPU_DIRECTCONVOLUTION_H
+
+#include "arm_compute/core/ITensorInfo.h"
+#include "arm_compute/core/Types.h"
+#include "arm_compute/core/experimental/Types.h"
+#include "arm_compute/runtime/IMemoryManager.h"
+#include "arm_compute/runtime/MemoryGroup.h"
+#include "arm_compute/runtime/NEON/functions/NEActivationLayer.h"
+#include "arm_compute/runtime/Tensor.h"
+#include "src/core/NEON/kernels/NEFillBorderKernel.h"
+#include "src/core/cpu/ICpuKernel.h"
+#include "src/core/cpu/kernels/CpuDirectConvolutionKernel.h"
+#include "src/core/cpu/kernels/CpuDirectConvolutionOutputStageKernel.h"
+#include "src/runtime/cpu/ICpuOperator.h"
+#include "src/runtime/cpu/operators/CpuActivation.h"
+
+#include <memory>
+
+namespace arm_compute
+{
+namespace cpu
+{
+/** Function to run the direct convolution.
+ *
+ * This function calls the following kernels:
+ *
+ * -# @ref NEFillBorderKernel for the input
+ * -# @ref kernels::CpuDirectConvolutionOutputStageKernel
+ * -# @ref kernels::CpuDirectConvolutionKernel
+ */
+class CpuDirectConvolution : public ICpuOperator
+{
+public:
+    /** Constructor */
+    CpuDirectConvolution(std::shared_ptr<IMemoryManager> memory_manager = nullptr);
+    /** Destructor */
+    ~CpuDirectConvolution();
+    /** Set the input, weights, biases and output tensors.
+     *
+     * @note: DirectConvolution only works in the following configurations:
+     *        1x1 convolution with stride_x = 1/2/3, stride_y = 1/2/3 data type = F16/F32
+     *        3x3 convolution with stride_x = 1/2/3, stride_y = 1/2/3 data type = F16/F32
+     *        5x5 convolution with stride_x = 1/2/3, stride_y = 1/2/3 data type = F32
+     *
+     * @param[in, out] src       Input tensor info. Data types supported: F16/F32.
+     * @param[in]      weights   Set of kernels to convolve the input volume.
+     *                           Supported sizes: 1x1, 3x3 and 5x5.
+     *                           The 3rd dimension must be the same as the input's volume 3rd dimension.
+     *                           Data type supported: Same as @p src.
+     * @param[in]      bias      Set of biases. Can be nullptr. Data type supported: Same as @p src.
+     * @param[out]     dst       Output tensor info.
+     *                           The 3rd dimension must be equal to the 4th dimension of the @p kernels tensor. Data types supported: Same as @p src.
+     * @param[in]      conv_info Contains padding and stride information described in @ref PadStrideInfo.
+     * @param[in]      act_info  (Optional) Activation layer information in case of a fused activation.
+     */
+    void configure(ITensorInfo *src, ITensorInfo *weights, const ITensorInfo *bias, ITensorInfo *dst, const PadStrideInfo &conv_info, const ActivationLayerInfo &act_info = ActivationLayerInfo());
+    /** Static function to check if given info will lead to a valid configuration of @ref NEDirectConvolutionLayer
+     *
+     * @note: DirectConvolution only works in the following configurations:
+     *        1x1 convolution with stride_x = 1/2/3, stride_y = 1/2/3 data type = F16/F32
+     *        3x3 convolution with stride_x = 1/2/3, stride_y = 1/2/3 data type = F16/F32
+     *        5x5 convolution with stride_x = 1/2/3, stride_y = 1/2/3 data type = F32
+     *
+     * @param[in] src       Input tensor info. Data types supported: F16/F32.
+     * @param[in] weights   Set of kernels to convolve the input volume.
+     *                      Supported sizes: 1x1, 3x3 and 5x5.
+     *                      The 3rd dimension must be the same as the input's volume 3rd dimension.
+     *                      Data type supported: Same as @p src.
+     * @param[in] bias      Set of biases. Can be nullptr. Data type supported: Same as @p src.
+     * @param[in] dst       Output tensor info.
+     *                      The 3rd dimension must be equal to the 4th dimension of the @p kernels tensor. Data types supported: Same as @p src.
+     * @param[in] conv_info Contains padding and stride information described in @ref PadStrideInfo.
+     * @param[in] act_info  (Optional) Activation layer information in case of a fused activation.
+     *
+     * @return a status
+     */
+    static Status validate(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *bias, const ITensorInfo *dst, const PadStrideInfo &conv_info,
+                           const ActivationLayerInfo &act_info = ActivationLayerInfo());
+
+    // Inherited methods overridden:
+    void run(ITensorPack &tensors) override;
+
+private:
+    MemoryGroup _memory_group;
+    std::unique_ptr<kernels::CpuDirectConvolutionOutputStageKernel> _output_stage_kernel;
+    std::unique_ptr<kernels::CpuDirectConvolutionKernel>            _conv_kernel;
+    std::unique_ptr<NEFillBorderKernel>                             _input_border_handler;
+    std::unique_ptr<CpuActivation>                                  _activationlayer_function;
+    Tensor       _accumulator;
+    bool         _has_bias{ false };
+    bool         _is_activationlayer_enabled{ false };
+    unsigned int _dim_split{ 0 };
+    bool         _is_padding_required{ false };
+};
+} // namespace cpu
+} // namespace arm_compute
#endif /* ARM_COMPUTE_CPU_DIRECTCONVOLUTION_H */
-- 
cgit v1.2.1
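Example usage after this port (illustrative only, not part of the patch): the public ITensor-based API of NEDirectConvolutionLayer is unchanged, so existing callers keep working, while internally configure() forwards the ITensorInfo objects to cpu::CpuDirectConvolution and run() dispatches through an ITensorPack. The tensor shapes, the 3x3 stride-1 configuration and the fused ReLU below are assumptions chosen for brevity.

    #include "arm_compute/core/Types.h"
    #include "arm_compute/runtime/NEON/functions/NEDirectConvolutionLayer.h"
    #include "arm_compute/runtime/Tensor.h"

    using namespace arm_compute;

    int main()
    {
        // Illustrative shapes (NCHW): 32x32 input with 8 channels, 16 filters of 3x3, per-filter bias.
        Tensor src, weights, bias, dst;
        src.allocator()->init(TensorInfo(TensorShape(32U, 32U, 8U), 1, DataType::F32));
        weights.allocator()->init(TensorInfo(TensorShape(3U, 3U, 8U, 16U), 1, DataType::F32));
        bias.allocator()->init(TensorInfo(TensorShape(16U), 1, DataType::F32));
        dst.allocator()->init(TensorInfo(TensorShape(32U, 32U, 16U), 1, DataType::F32));

        const PadStrideInfo       conv_info(1, 1, 1, 1); // stride 1, pad 1 -> same spatial size
        const ActivationLayerInfo act_info(ActivationLayerInfo::ActivationFunction::RELU);

        // Validation and configuration still happen on tensor infos only.
        const Status status = NEDirectConvolutionLayer::validate(src.info(), weights.info(), bias.info(), dst.info(), conv_info, act_info);
        if(status.error_code() == ErrorCode::OK)
        {
            NEDirectConvolutionLayer conv;
            conv.configure(&src, &weights, &bias, &dst, conv_info, act_info);

            // Backing memory is only needed at run time.
            src.allocator()->allocate();
            weights.allocator()->allocate();
            bias.allocator()->allocate();
            dst.allocator()->allocate();

            // ... fill src/weights/bias ...
            conv.run(); // internally builds an ITensorPack and calls CpuDirectConvolution::run()
        }
        return 0;
    }

The design point this illustrates is the one the patch is moving towards: the operator (cpu::CpuDirectConvolution) is configured purely from ITensorInfo objects and receives the actual tensors only at run time through the ITensorPack, which keeps the operator itself stateless with respect to tensor memory.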