From 87350f47084d2b69daa11c3b1c7eb47e02260063 Mon Sep 17 00:00:00 2001 From: Manuel Bottini Date: Tue, 15 Sep 2020 13:03:34 +0100 Subject: COMPMID-3144: Remove padding from NEDirectConvolutionLayerKernel Change-Id: I22b907eebfbe037e6e1c7bf604172f4709a9cbed Signed-off-by: Manuel Bottini Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/4082 Tested-by: Arm Jenkins Comments-Addressed: Arm Jenkins Reviewed-by: Michele Di Giorgio Reviewed-by: Georgios Pinitas --- .../NEON/kernels/NEDirectConvolutionLayerKernel.h | 10 +- .../NEON/functions/NEDirectConvolutionLayer.h | 3 +- .../kernels/NEDirectConvolutionLayerKernel.cpp | 675 ++++++++------------- .../NEON/functions/NEDirectConvolutionLayer.cpp | 19 +- tests/validation/NEON/DirectConvolutionLayer.cpp | 52 +- .../fixtures/DirectConvolutionLayerFixture.h | 3 - 6 files changed, 314 insertions(+), 448 deletions(-) diff --git a/arm_compute/core/NEON/kernels/NEDirectConvolutionLayerKernel.h b/arm_compute/core/NEON/kernels/NEDirectConvolutionLayerKernel.h index 4cb9c90a1a..c927aff1eb 100644 --- a/arm_compute/core/NEON/kernels/NEDirectConvolutionLayerKernel.h +++ b/arm_compute/core/NEON/kernels/NEDirectConvolutionLayerKernel.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2019 Arm Limited. + * Copyright (c) 2017-2020 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -86,6 +86,14 @@ public: BorderSize border_size() const override; private: + /* Template function for optimized convolution NHWC */ + template + void convolve_nhwc_optimized(const Window &window); + + /* Template function for convolution NHWC */ + template + void convolve_nhwc(const Window &window); + const ITensor *_input; const ITensor *_weights; ITensor *_output; diff --git a/arm_compute/runtime/NEON/functions/NEDirectConvolutionLayer.h b/arm_compute/runtime/NEON/functions/NEDirectConvolutionLayer.h index 9b18f645bd..d1c811c363 100644 --- a/arm_compute/runtime/NEON/functions/NEDirectConvolutionLayer.h +++ b/arm_compute/runtime/NEON/functions/NEDirectConvolutionLayer.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2019 Arm Limited. + * Copyright (c) 2017-2020 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -106,6 +106,7 @@ private: bool _has_bias; bool _is_activationlayer_enabled; unsigned int _dim_split; + bool _is_padding_required; }; } #endif /* ARM_COMPUTE_NEDIRECTCONVOLUTIONLAYER_H */ diff --git a/src/core/NEON/kernels/NEDirectConvolutionLayerKernel.cpp b/src/core/NEON/kernels/NEDirectConvolutionLayerKernel.cpp index ac1d6aec8f..c22fa6a2b3 100644 --- a/src/core/NEON/kernels/NEDirectConvolutionLayerKernel.cpp +++ b/src/core/NEON/kernels/NEDirectConvolutionLayerKernel.cpp @@ -40,9 +40,10 @@ #include -using namespace arm_compute; using namespace arm_compute::detail; +namespace arm_compute +{ namespace { #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC @@ -472,117 +473,6 @@ inline float32x4x2_t convolve_5x5<3>(const float *in_0, const float *in_1, const return out; } -template -class convolver_nhwc -{ -public: - static void convolve(const Window &window, uint32_t kernel_size, unsigned int num_elems_read_per_iteration, - const ITensor *input, const ITensor *weights, ITensor *output, const PadStrideInfo &conv_info) - { - const int input_width = input->info()->dimension(0); - const int input_depth = input->info()->dimension(2); - const int input_stride_x = input->info()->strides_in_bytes().x(); - const int input_stride_y = input->info()->strides_in_bytes().y(); - const int input_stride_z = input->info()->strides_in_bytes().z(); - const int output_stride_x = output->info()->strides_in_bytes().x(); - const int kernel_stride_x = weights->info()->strides_in_bytes().x(); - const int kernel_stride_y = weights->info()->strides_in_bytes().y(); - const int kernel_stride_z = weights->info()->strides_in_bytes().z(); - const int conv_pad_top = conv_info.pad_top(); - const unsigned int conv_stride_x = std::get<0>(conv_info.stride()); - const unsigned int conv_stride_y = std::get<1>(conv_info.stride()); - const T1 zero = 0; - - // Setup input window for the input iterator - Window window_in = window; - window_in.set(Window::DimX, Window::Dimension(0, 0, 0)); - window_in.set(Window::DimY, Window::Dimension(0, 0, 0)); - window_in.set(Window::DimZ, Window::Dimension(0, 0, 0)); - - // Setup input window for the output iterator - Window window_out = window; - window_out.set(Window::DimX, Window::Dimension(0, 1, 1)); - - // Setup input window for the weights iterator - Window window_k = calculate_max_window(*weights->info(), Steps()); - window_k.set(Window::DimX, Window::Dimension(0, 1, 1)); - window_k.set(Window::DimY, Window::Dimension(0, 1, 1)); - window_k.set(Window::DimZ, Window::Dimension(0, 1, 1)); - window_k.set(3, Window::Dimension(0, weights->info()->dimension(3), 1)); - - Iterator in(input, window_in); - Iterator out(output, window_out); - Iterator k(weights, window_k); - - execute_window_loop(window_k, [&](const Coordinates & id_k) - { - execute_window_loop(window_out, [&](const Coordinates & id) - { - const auto in_y = static_cast(id.y() * conv_stride_x - conv_info.pad_left()); - const auto in_z = static_cast(id.z() * conv_stride_y - conv_pad_top); - - const uint8_t *in_ptr = in.ptr() + in_y * input_stride_y + in_z * input_stride_z; - uint8_t *out_ptr = out.ptr() + id_k[3] * output_stride_x; - - T1 out_val = 0; - - auto in_addr_base0 = in_ptr; - auto we_addr_base0 = k.ptr(); - - for(uint32_t z = 0; z < kernel_size; ++z, in_addr_base0 += input_stride_z, we_addr_base0 += kernel_stride_z) - { - const int in_z = id.z() * conv_stride_y + z - conv_pad_top; - - if(in_z >= 0 && in_z < input_depth) // If false, pad top/bottom - { - auto in_addr_base1 = in_addr_base0; - auto we_addr_base1 = we_addr_base0; - - for(uint32_t y = 0; y < kernel_size; ++y, in_addr_base1 += input_stride_y, we_addr_base1 += kernel_stride_y) - { - auto out_values = internal_vdupq_n(zero); - - int x = 0; - int no_leftover = input_width - num_elems_read_per_iteration; - - for(; x < no_leftover; x += num_elems_read_per_iteration) - { - const auto in_addr = reinterpret_cast(in_addr_base1 + x * input_stride_x); - const auto in_values = internal_vld1q<1>(in_addr); - - const auto we_addr = reinterpret_cast(we_addr_base1 + x * kernel_stride_x); - const auto we_values = internal_vld1q<1>(we_addr); - - out_values = internal_vmlal(out_values, in_values, we_values); - } - - auto carry_addition = wrapper::vpadd(wrapper::vgethigh(out_values), wrapper::vgetlow(out_values)); - carry_addition = wrapper::vpadd(carry_addition, carry_addition); - out_val += wrapper::vgetlane(carry_addition, 0); - - // Leftover - for(; x < input_width; ++x) - { - const auto in_addr = reinterpret_cast(in_addr_base1 + x * input_stride_x); - const auto in_value = *(in_addr); - - const auto we_addr = reinterpret_cast(we_addr_base1 + x * kernel_stride_x); - const auto we_value = *(we_addr); - - out_val += in_value * we_value; - } - } - } - } - - *(reinterpret_cast(out_ptr)) = out_val; - }, - in, out); - }, - k); - } -}; - template class convolver_3x3 { @@ -815,76 +705,6 @@ public: } }; -inline void convolve_row1x9_nhwc(const float *row_ptr, const float *weights_ptr, size_t src_stride_y, size_t weights_stride_y, - float32x4_t &acc0, float32x4_t &acc1, float32x4_t &acc2, float32x4_t &acc3) -{ - // Load 4 channels for each of the 12 inputs values along the same X spatial dimension - const float32x4_t src0 = wrapper::vloadq(row_ptr); - const float32x4_t src1 = wrapper::vloadq(row_ptr + 1 * src_stride_y); - const float32x4_t src2 = wrapper::vloadq(row_ptr + 2 * src_stride_y); - const float32x4_t src3 = wrapper::vloadq(row_ptr + 3 * src_stride_y); - const float32x4_t src4 = wrapper::vloadq(row_ptr + 4 * src_stride_y); - const float32x4_t src5 = wrapper::vloadq(row_ptr + 5 * src_stride_y); - const float32x4_t src6 = wrapper::vloadq(row_ptr + 6 * src_stride_y); - const float32x4_t src7 = wrapper::vloadq(row_ptr + 7 * src_stride_y); - const float32x4_t src8 = wrapper::vloadq(row_ptr + 8 * src_stride_y); - const float32x4_t src9 = wrapper::vloadq(row_ptr + 9 * src_stride_y); - const float32x4_t src10 = wrapper::vloadq(row_ptr + 10 * src_stride_y); - const float32x4_t src11 = wrapper::vloadq(row_ptr + 11 * src_stride_y); - - // Load 4 channels for each of the 9 weights values along the same X spatial dimension - const float32x4_t w0 = wrapper::vloadq(weights_ptr); - const float32x4_t w1 = wrapper::vloadq(weights_ptr + 1 * weights_stride_y); - const float32x4_t w2 = wrapper::vloadq(weights_ptr + 2 * weights_stride_y); - const float32x4_t w3 = wrapper::vloadq(weights_ptr + 3 * weights_stride_y); - const float32x4_t w4 = wrapper::vloadq(weights_ptr + 4 * weights_stride_y); - const float32x4_t w5 = wrapper::vloadq(weights_ptr + 5 * weights_stride_y); - const float32x4_t w6 = wrapper::vloadq(weights_ptr + 6 * weights_stride_y); - const float32x4_t w7 = wrapper::vloadq(weights_ptr + 7 * weights_stride_y); - const float32x4_t w8 = wrapper::vloadq(weights_ptr + 8 * weights_stride_y); - - // Store 4 channels for each of the 4 output values along the same X spatial dimension - acc0 = wrapper::vmla(acc0, w0, src0); - acc0 = wrapper::vmla(acc0, w1, src1); - acc0 = wrapper::vmla(acc0, w2, src2); - acc0 = wrapper::vmla(acc0, w3, src3); - acc0 = wrapper::vmla(acc0, w4, src4); - acc0 = wrapper::vmla(acc0, w5, src5); - acc0 = wrapper::vmla(acc0, w6, src6); - acc0 = wrapper::vmla(acc0, w7, src7); - acc0 = wrapper::vmla(acc0, w8, src8); - - acc1 = wrapper::vmla(acc1, w0, src1); - acc1 = wrapper::vmla(acc1, w1, src2); - acc1 = wrapper::vmla(acc1, w2, src3); - acc1 = wrapper::vmla(acc1, w3, src4); - acc1 = wrapper::vmla(acc1, w4, src5); - acc1 = wrapper::vmla(acc1, w5, src6); - acc1 = wrapper::vmla(acc1, w6, src7); - acc1 = wrapper::vmla(acc1, w7, src8); - acc1 = wrapper::vmla(acc1, w8, src9); - - acc2 = wrapper::vmla(acc2, w0, src2); - acc2 = wrapper::vmla(acc2, w1, src3); - acc2 = wrapper::vmla(acc2, w2, src4); - acc2 = wrapper::vmla(acc2, w3, src5); - acc2 = wrapper::vmla(acc2, w4, src6); - acc2 = wrapper::vmla(acc2, w5, src7); - acc2 = wrapper::vmla(acc2, w6, src8); - acc2 = wrapper::vmla(acc2, w7, src9); - acc2 = wrapper::vmla(acc2, w8, src10); - - acc3 = wrapper::vmla(acc3, w0, src3); - acc3 = wrapper::vmla(acc3, w1, src4); - acc3 = wrapper::vmla(acc3, w2, src5); - acc3 = wrapper::vmla(acc3, w3, src6); - acc3 = wrapper::vmla(acc3, w4, src7); - acc3 = wrapper::vmla(acc3, w5, src8); - acc3 = wrapper::vmla(acc3, w6, src9); - acc3 = wrapper::vmla(acc3, w7, src10); - acc3 = wrapper::vmla(acc3, w8, src11); -} - float vreduce(const float32x4_t &v) { auto v0 = wrapper::vgethigh(v); @@ -896,175 +716,6 @@ float vreduce(const float32x4_t &v) return a + b; } -template -class convolver_9x9_nhwc -{ -public: - static void convolve(const Window &window, unsigned int num_elems_read_per_iteration, - const ITensor *input, const ITensor *weights, ITensor *output, const PadStrideInfo &conv_info) - { - // Declare useful types - using vector_type = typename V::type; - using scalar_type = typename V::scalar_type; - using tag_type = typename V::tag_type; - - // Scalar quantities - const int element_size = input->info()->element_size(); - const int input_width = input->info()->dimension(0); - const int input_depth = input->info()->dimension(2); - const int input_stride_y = input->info()->strides_in_bytes().y() / element_size; - const int input_stride_z = input->info()->strides_in_bytes().z() / element_size; - const int input_stride_w = input->info()->strides_in_bytes()[3]; - const int output_stride_x = output->info()->strides_in_bytes().x(); - const int output_stride_y = output->info()->strides_in_bytes().y(); - const int kernel_stride_y = weights->info()->strides_in_bytes().y() / element_size; - const int kernel_stride_z = weights->info()->strides_in_bytes().z() / element_size; - const unsigned int conv_stride_y = std::get<1>(conv_info.stride()); - const unsigned int conv_pad_top = conv_info.pad_top(); - const unsigned int conv_pad_left = conv_info.pad_left(); - - // Setup input window for the input iterator - Window window_in = window; - window_in.set(Window::DimX, Window::Dimension(0, 0, 0)); - window_in.set(Window::DimY, Window::Dimension(0, 0, 0)); - window_in.set(Window::DimZ, Window::Dimension(0, 0, 0)); - - // Setup input window for the output iterator - Window window_out = window; - window_out.set(Window::DimX, Window::Dimension(0, 1, 1)); - - // Setup input window for the weights iterator - Window window_k = calculate_max_window(*weights->info(), Steps()); - window_k.set(Window::DimX, Window::Dimension(0, 1, 1)); - window_k.set(Window::DimY, Window::Dimension(0, 1, 1)); - window_k.set(Window::DimZ, Window::Dimension(0, 1, 1)); - window_k.set(3, Window::Dimension(0, weights->info()->dimension(3), 1)); - - Iterator in(input, window_in); - Iterator out(output, window_out); - Iterator k(weights, window_k); - - // Calculate the max_offset. - // max_offset is the offset for the last NOT valid value in the Z dimension (spatial dimension Y for NHWC) - // |******************| - // | pad_top | - // |******************| - // | | - // | plane0 | - // | batch0 | - // |__________________| - // |******************| Batch 0 - // | pad_bottom | - // | pad_top | - // |******************| - // | | - // | plane1 | - // | batch0 | - // |__________________|-----> max_offset - // |******************| - // | pad_bottom | - // | pad_top | - // |******************| - // | | - // | plane0 | - // | batch1 | - // |__________________| - // |******************| Batch 1 - // | pad_bottom | - // | pad_top | - // |******************| - // | | - // | plane1 | - // | batch1 | - // |__________________| - // | pad_bottom | - // |******************| - const int64_t max_offset = input_stride_z * input_depth - (input->info()->padding().bottom + input->info()->padding().top) * input_stride_y; - execute_window_loop(window_k, [&](const Coordinates & id_k) // loop on the batch size - { - - execute_window_loop(window_out, [&](const Coordinates & id) - { - const auto y_offset = int(id.y() - conv_pad_left) * input_stride_y; - - // Buffer pointers - const scalar_type *in_ptr = reinterpret_cast(input->buffer() + input->info()->offset_first_element_in_bytes() + id[3] * input_stride_w); - const scalar_type *weights_ptr = reinterpret_cast(k.ptr()); - uint8_t *out_ptr = out.ptr() + id_k[3] * output_stride_x; - - // Output elements - vector_type out0 = wrapper::vdup_n(scalar_type(0), tag_type()); - vector_type out1 = wrapper::vdup_n(scalar_type(0), tag_type()); - vector_type out2 = wrapper::vdup_n(scalar_type(0), tag_type()); - vector_type out3 = wrapper::vdup_n(scalar_type(0), tag_type()); - - // Reduce along the feature maps - for(int x = 0; x < input_width; x += num_elems_read_per_iteration) - { - // z == 0 - auto in_z = static_cast(id.z() * conv_stride_y - conv_pad_top); - in_z = std::min(static_cast(in_z), static_cast(input_depth)); - auto offset = y_offset + in_z * input_stride_z; - offset = std::min(offset, max_offset); - convolve_row1x9_nhwc(in_ptr + offset + x, weights_ptr + 0 * kernel_stride_z + x, input_stride_y, kernel_stride_y, out0, out1, out2, out3); - - // z == 1 - in_z = static_cast(id.z() * conv_stride_y - conv_pad_top + 1); - in_z = std::min(static_cast(in_z), static_cast(input_depth)); - offset = y_offset + in_z * input_stride_z; - offset = std::min(offset, max_offset); - convolve_row1x9_nhwc(in_ptr + offset + x, weights_ptr + 1 * kernel_stride_z + x, input_stride_y, kernel_stride_y, out0, out1, out2, out3); - - // z == 2 - in_z = static_cast(id.z() * conv_stride_y - conv_pad_top + 2); - in_z = std::min(static_cast(in_z), static_cast(input_depth)); - offset = y_offset + in_z * input_stride_z; - offset = std::min(offset, max_offset); - convolve_row1x9_nhwc(in_ptr + offset + x, weights_ptr + 2 * kernel_stride_z + x, input_stride_y, kernel_stride_y, out0, out1, out2, out3); - - // z == 3 - in_z = static_cast(id.z() * conv_stride_y - conv_pad_top + 3); - offset = y_offset + in_z * input_stride_z; - offset = std::min(offset, max_offset); - convolve_row1x9_nhwc(in_ptr + offset + x, weights_ptr + 3 * kernel_stride_z + x, input_stride_y, kernel_stride_y, out0, out1, out2, out3); - - // z == 4 - in_z = static_cast(id.z() * conv_stride_y - conv_pad_top + 4); - offset = y_offset + in_z * input_stride_z; - convolve_row1x9_nhwc(in_ptr + offset + x, weights_ptr + 4 * kernel_stride_z + x, input_stride_y, kernel_stride_y, out0, out1, out2, out3); - - // z == 5 - offset += input_stride_z; - offset = std::min(offset, max_offset); - convolve_row1x9_nhwc(in_ptr + offset + x, weights_ptr + 5 * kernel_stride_z + x, input_stride_y, kernel_stride_y, out0, out1, out2, out3); - - // z == 6 - offset += input_stride_z; - offset = std::min(offset, max_offset); - convolve_row1x9_nhwc(in_ptr + offset + x, weights_ptr + 6 * kernel_stride_z + x, input_stride_y, kernel_stride_y, out0, out1, out2, out3); - - // z == 7 - offset += input_stride_z; - offset = std::min(offset, max_offset); - convolve_row1x9_nhwc(in_ptr + offset + x, weights_ptr + 7 * kernel_stride_z + x, input_stride_y, kernel_stride_y, out0, out1, out2, out3); - - // z == 8 - offset += input_stride_z; - offset = std::min(offset, max_offset); - convolve_row1x9_nhwc(in_ptr + offset + x, weights_ptr + 8 * kernel_stride_z + x, input_stride_y, kernel_stride_y, out0, out1, out2, out3); - } - - *(reinterpret_cast(out_ptr + 0 * output_stride_y)) = vreduce(out0); - *(reinterpret_cast(out_ptr + 1 * output_stride_y)) = vreduce(out1); - *(reinterpret_cast(out_ptr + 2 * output_stride_y)) = vreduce(out2); - *(reinterpret_cast(out_ptr + 3 * output_stride_y)) = vreduce(out3); - }, - in, out); - }, - k); - } -}; - template inline void convolve_1x1(const Window &window, unsigned int num_elems_read_per_iteration, unsigned int num_elems_written_per_iteration, const ITensor *input, const ITensor *weights, ITensor *output, const PadStrideInfo &conv_info) @@ -1169,21 +820,6 @@ inline void convolve_5x5(const Window &window, unsigned int num_elems_read_per_i } } -template -inline void convolve_9x9_nhwc(const Window &window, unsigned int num_elems_read_per_iteration, - const ITensor *input, const ITensor *weights, ITensor *output, const PadStrideInfo &conv_info) -{ - const unsigned int conv_stride_x = std::get<0>(conv_info.stride()); - switch(conv_stride_x) - { - case 1: - convolver_9x9_nhwc::convolve(window, num_elems_read_per_iteration, input, weights, output, conv_info); - break; - default: - ARM_COMPUTE_ERROR("Not implemented"); - } -} - Status validate_arguments(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *output, const PadStrideInfo &conv_info) { ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, weights, output); @@ -1337,68 +973,248 @@ std::pair validate_and_configure_window(ITensorInfo *input, ITen } else { - if(kernel_size == 9) - { - border_size.left = 0; - border_size.top = conv_info.pad_left(); + // Configure window NHWC without any padding + win = calculate_max_window(*output, Steps()); + Coordinates coord; + coord.set_num_dimensions(output->num_dimensions()); + output->set_valid_region(ValidRegion(coord, output->tensor_shape())); + } - const int num_elems_read_per_iteration_x = 4; - const int num_elems_written_per_iteration_x = 1; - const int num_elems_read_per_iteration_y = 12; - const int num_elems_written_per_iteration_y = 4; + Status err = (window_changed) ? ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Insufficient Padding!") : Status{}; + return std::make_pair(err, win); +} - num_elems_read_per_iteration = num_elems_read_per_iteration_x; - num_elems_written_per_iteration = num_elems_written_per_iteration_x; +bool have_zero_x_internal_padding(ITensorInfo *input, ITensorInfo *weights) +{ + return (input->padding().left == 0 && weights->padding().left == 0 && input->padding().right == 0 && weights->padding().right == 0); +} - border_size.right = num_elems_read_per_iteration_x; - if((conv_info.pad_bottom() != 0) || (conv_info.pad_top() != 0)) - { - // If bottom or top padding are set, we need to read num_elems_read_per_iteration_y rows to zero. - // Since num_elems_read_per_iteration_y is always greater than conv_info.pad_right() we can set - // the bottom padding to num_elems_read_per_iteration_y - border_size.bottom = num_elems_read_per_iteration_y; - } - else if(conv_info.pad_right() != 0) - { - // Convetional border padding. Fill the bottom paddings so that we can read in batch of num_elems_read_per_iteration_y - border_size.bottom = ceil_to_multiple(input->dimension(1) + conv_info.pad_right(), num_elems_read_per_iteration_y) - input->dimension(1); - } - else +} // namespace + +template +void NEDirectConvolutionLayerKernel::convolve_nhwc_optimized(const Window &window) +{ + // This function assumes that input and weights have not padding in channel + + // Declare useful types + using vtype = wrapper::traits::neon_bitvector; + using vector_type = typename vtype::type; + using tag_type = typename vtype::tag_type; + + // Scalar quantities + const int element_size = _input->info()->element_size(); + const int input_stride_w = _input->info()->strides_in_bytes().y() / element_size; + const int input_stride_h = _input->info()->strides_in_bytes().z() / element_size; + const int input_stride_n = _input->info()->strides_in_bytes()[3] / element_size; + const int input_dim_w = _input->info()->dimension(1); + const int input_dim_h = _input->info()->dimension(2); + + const int output_stride_c = _output->info()->strides_in_bytes().x(); + + const unsigned int kernel_stride_w = _weights->info()->strides_in_bytes().y() / element_size; + const unsigned int kernel_stride_h = _weights->info()->strides_in_bytes().z() / element_size; + const int kernel_dim_w = _weights->info()->dimension(1); + const int kernel_dim_h = _weights->info()->dimension(2); + + const int conv_pad_top = _conv_info.pad_top(); + const int conv_pad_left = _conv_info.pad_left(); + const int conv_stride_w = std::get<0>(_conv_info.stride()); + const int conv_stride_h = std::get<1>(_conv_info.stride()); + + // Setup input window for the output iterator + Window window_out = window; + window_out.set(Window::DimX, Window::Dimension(0, 1, 1)); + + // Setup input window for the weights iterator + Window window_w = calculate_max_window(*_weights->info(), Steps()); + window_w.set(Window::DimX, Window::Dimension(0, 1, 1)); + window_w.set(Window::DimY, Window::Dimension(0, 1, 1)); + window_w.set(Window::DimZ, Window::Dimension(0, 1, 1)); + + Iterator out(_output, window_out); + Iterator wei(_weights, window_w); + + constexpr int num_elems_read_per_iteration = 16 / sizeof(T); + /* + * This implementation parallelize the full WC plane of input and weights by + * treating them as series of elements. So for example, a 3x3 weights and + * floating point vector operations of 4 elements per time, the first 3 + * channel elements of the first row would be taken and additionally the first + * element of the second row. The 9 elements in each single WC weight plane + * would require 2 4-element vector operations and a last single element operation. + * + * This works since when we create the input vector to multiply with the weights, + * the exact required elements are loaded in the same order. Therefore the + * multiplication works on the correct input/weight elements. + */ + execute_window_loop(window_out, [&](const Coordinates & id) + { + /* + * In here we create theoretical indexes which then we validate for both + * inputs and weights. + * As a reminder, this loop take each output point in NHW, C is treated + * in the weights loop. + */ + // We are computing the theoretical starting input starting points + const int in_w_start_t = static_cast(id.y()) * conv_stride_w - conv_pad_left; + const int in_h_start_t = static_cast(id.z()) * conv_stride_h - conv_pad_top; + const int in_w_end_t = in_w_start_t + kernel_dim_w; + const int in_h_end_t = in_h_start_t + kernel_dim_h; + + // We are computing the valid initial and ending input points by checking the borders + const int in_w_start = std::max(in_w_start_t, 0); + const int in_h_start = std::max(in_h_start_t, 0); + const int in_w_end = std::min(in_w_end_t, input_dim_w); + const int in_h_end = std::min(in_h_end_t, input_dim_h); + + // We use the input points to select the valid weight points to use + const int index_wc_start = (in_w_start - in_w_start_t) * kernel_stride_w; + const int index_h_start = in_h_start - in_h_start_t; + const int index_wc_end = (kernel_dim_w - (in_w_end_t - in_w_end)) * kernel_stride_w; + const int index_h_end = kernel_dim_h - (in_h_end_t - in_h_end); + + execute_window_loop(window_w, [&](const Coordinates & id_w) + { + /* + * This is the loop in the weights, and it goes along N (the batches) + * As a reminder, the batches of the weights are translated into the + * channels of the output + */ + const T *in_ptr_row = reinterpret_cast(_input->buffer() + _input->info()->offset_first_element_in_bytes()) + + id[3] * input_stride_n + in_w_start * input_stride_w + in_h_start * input_stride_h; + const T *weights_ptr_row = reinterpret_cast(wei.ptr()) + index_h_start * kernel_stride_h; + uint8_t *out_ptr = out.ptr() + id_w[3] * output_stride_c; + + T out_temp = static_cast(0); + for(int index_h = index_h_start; index_h < index_h_end; ++index_h, in_ptr_row += input_stride_h, weights_ptr_row += kernel_stride_h) { - // No padding - border_size.bottom = 0; + const T *in_ptr_mover = in_ptr_row; + int index_wc = index_wc_start; + vector_type out_temp_vec = wrapper::vdup_n(static_cast(0), tag_type()); + for(; index_wc <= index_wc_end - num_elems_read_per_iteration; index_wc += num_elems_read_per_iteration, in_ptr_mover += num_elems_read_per_iteration) + { + const auto src_vec = wrapper::vloadq(in_ptr_mover); + const auto w_vec = wrapper::vloadq(weights_ptr_row + index_wc); + out_temp_vec = wrapper::vmla(out_temp_vec, w_vec, src_vec); + } + out_temp += vreduce(out_temp_vec); + for(; index_wc < index_wc_end; ++index_wc, ++in_ptr_mover) + { + const auto src_val = *(in_ptr_mover); + const auto w_val = *(weights_ptr_row + index_wc); + out_temp += src_val * w_val; + } } + *(reinterpret_cast(out_ptr)) = out_temp; + }, + wei); + }, + out); +} - win = calculate_max_window(*output, Steps(num_elems_written_per_iteration_x, num_elems_written_per_iteration_y)); - - AccessWindowStatic input_access(input, 0, -border_size.top, - ceil_to_multiple(input->dimension(0), num_elems_read_per_iteration_x), - input->dimension(1) + border_size.bottom); - - AccessWindowStatic weights_access(weights, 0, 0, ceil_to_multiple(weights->dimension(0), num_elems_read_per_iteration_x), weights->dimension(1)); - AccessWindowRectangle output_access(output, 0, 0, num_elems_written_per_iteration_x, num_elems_written_per_iteration_y); - window_changed = update_window_and_padding(win, input_access, weights_access, output_access); - output_access.set_valid_region(win, ValidRegion(Coordinates(), output->tensor_shape())); - } - else +template +void NEDirectConvolutionLayerKernel::convolve_nhwc(const Window &window) +{ + // Declare useful types + using vtype = wrapper::traits::neon_bitvector; + using vector_type = typename vtype::type; + using tag_type = typename vtype::tag_type; + + // Scalar quantities + const int element_size = _input->info()->element_size(); + const int input_stride_w = _input->info()->strides_in_bytes().y() / element_size; + const int input_stride_h = _input->info()->strides_in_bytes().z() / element_size; + const int input_stride_n = _input->info()->strides_in_bytes()[3] / element_size; + const int input_dim_w = _input->info()->dimension(1); + const int input_dim_h = _input->info()->dimension(2); + + const int output_stride_c = _output->info()->strides_in_bytes().x(); + + const unsigned int kernel_stride_w = _weights->info()->strides_in_bytes().y() / element_size; + const unsigned int kernel_stride_h = _weights->info()->strides_in_bytes().z() / element_size; + const int kernel_dim_w = _weights->info()->dimension(1); + const int kernel_dim_h = _weights->info()->dimension(2); + + const int conv_pad_top = _conv_info.pad_top(); + const int conv_pad_left = _conv_info.pad_left(); + const int conv_stride_w = std::get<0>(_conv_info.stride()); + const int conv_stride_h = std::get<1>(_conv_info.stride()); + + // Setup input window for the output iterator + Window window_out = window; + window_out.set(Window::DimX, Window::Dimension(0, 1, 1)); + + // Setup input window for the weights iterator + Window window_w = calculate_max_window(*_weights->info(), Steps()); + window_w.set(Window::DimX, Window::Dimension(0, 1, 1)); + window_w.set(Window::DimY, Window::Dimension(0, 1, 1)); + window_w.set(Window::DimZ, Window::Dimension(0, 1, 1)); + + Iterator out(_output, window_out); + Iterator wei(_weights, window_w); + + constexpr int num_elems_read_per_iteration = 16 / sizeof(T); + + execute_window_loop(window_out, [&](const Coordinates & id) + { + // We are computing the theoretical starting input starting points + const int in_w_start_t = static_cast(id.y()) * conv_stride_w - conv_pad_left; + const int in_h_start_t = static_cast(id.z()) * conv_stride_h - conv_pad_top; + const int in_w_end_t = in_w_start_t + kernel_dim_w; + const int in_h_end_t = in_h_start_t + kernel_dim_h; + + // We are computing the valid initial and ending input points by checking the borders + const int in_w_start = std::max(in_w_start_t, 0); + const int in_h_start = std::max(in_h_start_t, 0); + const int in_w_end = std::min(in_w_end_t, input_dim_w); + const int in_h_end = std::min(in_h_end_t, input_dim_h); + + // We use the input points to select the valid weight points to use + const int wei_w_start = in_w_start - in_w_start_t; + const int wei_h_start = in_h_start - in_h_start_t; + const int wei_w_end = kernel_dim_w - (in_w_end_t - in_w_end); + const int wei_h_end = kernel_dim_h - (in_h_end_t - in_h_end); + + const int index_c_end = _weights->info()->dimension(0); + const T *const in_ptr_start = reinterpret_cast(_input->buffer() + _input->info()->offset_first_element_in_bytes()) + id[3] * input_stride_n; + + execute_window_loop(window_w, [&](const Coordinates & id_w) { - border_size.left = 0; - border_size.top = conv_info.pad_left(); - border_size.right = 0; - border_size.bottom = conv_info.pad_right(); - num_elems_read_per_iteration = 16 / element_size_from_data_type(input->data_type()); - win = calculate_max_window(*output, Steps()); - - AccessWindowRectangle input_access(input, 0, -border_size.top, num_elems_read_per_iteration, kernel_size, 1.f, conv_stride_x); - AccessWindowRectangle weights_access(weights, 0, 0, num_elems_read_per_iteration, kernel_size); - window_changed = update_window_and_padding(win, input_access, weights_access); - } - } + const T *const weights_ptr_start = reinterpret_cast(wei.ptr()); + uint8_t *out_ptr = out.ptr() + id_w[3] * output_stride_c; - Status err = (window_changed) ? ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Insufficient Padding!") : Status{}; - return std::make_pair(err, win); + T out_temp = static_cast(0); + for(int index_wei_h = wei_h_start, index_in_h = in_h_start; index_wei_h < wei_h_end; ++index_wei_h, ++index_in_h) + { + const T *const in_ptr_row = in_ptr_start + index_in_h * input_stride_h; + const T *const weights_ptr_row = weights_ptr_start + index_wei_h * kernel_stride_h; + for(int index_wei_w = wei_w_start, index_in_w = in_w_start; index_wei_w < wei_w_end; ++index_wei_w, ++index_in_w) + { + const T *in_ptr_mover = in_ptr_row + index_in_w * input_stride_w; + const T *weights_ptr_mover = weights_ptr_row + index_wei_w * kernel_stride_w; + int index_c = 0; + vector_type out_temp_vec = wrapper::vdup_n(static_cast(0), tag_type()); + for(; index_c <= index_c_end - num_elems_read_per_iteration; index_c += num_elems_read_per_iteration, in_ptr_mover += num_elems_read_per_iteration, weights_ptr_mover += num_elems_read_per_iteration) + { + const auto src_vec = wrapper::vloadq(in_ptr_mover); + const auto w_vec = wrapper::vloadq(weights_ptr_mover); + out_temp_vec = wrapper::vmla(out_temp_vec, w_vec, src_vec); + } + out_temp += vreduce(out_temp_vec); + for(; index_c < index_c_end; ++index_c, ++in_ptr_mover, ++weights_ptr_mover) + { + const auto src_val = *(in_ptr_mover); + const auto w_val = *(weights_ptr_mover); + out_temp += src_val * w_val; + } + } + } + *(reinterpret_cast(out_ptr)) = out_temp; + }, + wei); + }, + out); } -} // namespace NEDirectConvolutionLayerKernel::NEDirectConvolutionLayerKernel() : _input(nullptr), _weights(nullptr), _output(nullptr), _conv_info(), _border_size(0), _kernel_size(0), _num_weight_elems_read_per_row(0), _num_elems_read_per_iteration(0), @@ -1425,7 +1241,14 @@ void NEDirectConvolutionLayerKernel::configure(const ITensor *input, const ITens const unsigned int conv_pad_top = conv_info.pad_top(); const unsigned int conv_pad_right = conv_info.pad_right(); const unsigned int conv_pad_bottom = conv_info.pad_bottom(); - _border_size = BorderSize(conv_pad_top, conv_pad_right, conv_pad_bottom, conv_pad_left); + if(_input->info()->data_layout() == DataLayout::NCHW) + { + _border_size = BorderSize(conv_pad_top, conv_pad_right, conv_pad_bottom, conv_pad_left); + } + else + { + _border_size = BorderSize(0); + } // Get convolved dimensions TensorShape output_shape = misc::shape_calculator::compute_deep_convolution_shape(*input->info(), *weights->info(), conv_info); @@ -1536,22 +1359,17 @@ void NEDirectConvolutionLayerKernel::run(const Window &window, const ThreadInfo } else { - const int kernel_size = _weights->info()->dimension(get_data_layout_dimension_index(_weights->info()->data_layout(), DataLayoutDimension::WIDTH)); - const int stride_x = std::get<0>(_conv_info.stride()); - const int stride_y = std::get<1>(_conv_info.stride()); - switch(_input->info()->data_type()) { case DataType::F32: { - if(kernel_size == 9 && stride_x == 1 && stride_y == 1) + if(have_zero_x_internal_padding(_input->info(), _weights->info())) { - using vtype = wrapper::traits::neon_vector; - convolve_9x9_nhwc(window, _num_elems_read_per_iteration, _input, _weights, _output, _conv_info); + convolve_nhwc_optimized(window); } else { - convolver_nhwc::convolve(window, kernel_size, _num_elems_read_per_iteration, _input, _weights, _output, _conv_info); + convolve_nhwc(window); } break; } @@ -1561,3 +1379,4 @@ void NEDirectConvolutionLayerKernel::run(const Window &window, const ThreadInfo } } } +} // namespace arm_compute \ No newline at end of file diff --git a/src/runtime/NEON/functions/NEDirectConvolutionLayer.cpp b/src/runtime/NEON/functions/NEDirectConvolutionLayer.cpp index da7e771aaf..fe545905d5 100644 --- a/src/runtime/NEON/functions/NEDirectConvolutionLayer.cpp +++ b/src/runtime/NEON/functions/NEDirectConvolutionLayer.cpp @@ -28,14 +28,11 @@ #include "arm_compute/core/Validate.h" #include "arm_compute/runtime/NEON/NEScheduler.h" -#include -#include - namespace arm_compute { NEDirectConvolutionLayer::NEDirectConvolutionLayer(std::shared_ptr memory_manager) : _memory_group(std::move(memory_manager)), _output_stage_kernel(), _conv_kernel(), _input_border_handler(), _activationlayer_function(), _accumulator(), _has_bias(false), - _is_activationlayer_enabled(false), _dim_split(Window::DimZ) + _is_activationlayer_enabled(false), _dim_split(Window::DimZ), _is_padding_required() { } @@ -59,9 +56,13 @@ void NEDirectConvolutionLayer::configure(ITensor *input, const ITensor *weights, { _output_stage_kernel.configure(output, bias); } + _is_padding_required = !_conv_kernel.border_size().empty(); - // Add zero padding XY - _input_border_handler.configure(input, _conv_kernel.border_size(), BorderMode::CONSTANT, PixelValue(static_cast(0.f))); + if(_is_padding_required) + { + // Add zero padding XY + _input_border_handler.configure(input, _conv_kernel.border_size(), BorderMode::CONSTANT, PixelValue(static_cast(0.f))); + } //Configure Activation Layer _is_activationlayer_enabled = act_info.enabled(); @@ -104,10 +105,12 @@ Status NEDirectConvolutionLayer::validate(const ITensorInfo *input, const ITenso void NEDirectConvolutionLayer::run() { - NEScheduler::get().schedule(&_input_border_handler, Window::DimZ); - MemoryGroupResourceScope scope_mg(_memory_group); + if(_is_padding_required) + { + NEScheduler::get().schedule(&_input_border_handler, Window::DimZ); + } NEScheduler::get().schedule(&_conv_kernel, _dim_split); if(_has_bias) { diff --git a/tests/validation/NEON/DirectConvolutionLayer.cpp b/tests/validation/NEON/DirectConvolutionLayer.cpp index 7277592736..afd9e3952f 100644 --- a/tests/validation/NEON/DirectConvolutionLayer.cpp +++ b/tests/validation/NEON/DirectConvolutionLayer.cpp @@ -21,6 +21,7 @@ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE * SOFTWARE. */ +#include "arm_compute/core/Helpers.h" #include "arm_compute/core/Types.h" #include "arm_compute/runtime/NEON/functions/NEDirectConvolutionLayer.h" #include "arm_compute/runtime/Tensor.h" @@ -78,12 +79,12 @@ const auto data_f16 = combine(datasets::SmallDirectConvolutionShapes(), combine(framework::dataset::make("StrideY", { 1, 2, 3 }), data_pad_f16))); -const auto data = combine(datasets::SmallDirectConvolutionShapes(), - combine(framework::dataset::make("StrideX", { 1 }), - combine(framework::dataset::make("StrideY", { 1 }), - combine(framework::dataset::make("PadX", { 1 }), - combine(framework::dataset::make("PadY", { 1 }), - framework::dataset::make("KernelSize", 3)))))); +const auto data_prec = combine(datasets::SmallDirectConvolutionShapes(), + combine(framework::dataset::make("StrideX", { 1 }), + combine(framework::dataset::make("StrideY", { 1 }), + combine(framework::dataset::make("PadX", { 1 }), + combine(framework::dataset::make("PadY", { 1 }), + framework::dataset::make("KernelSize", 3)))))); const auto data9x9 = combine(datasets::SmallDirectConvolutionShapes(), combine(framework::dataset::make("StrideX", { 1 }), @@ -95,7 +96,7 @@ const auto data9x9 = combine(datasets::SmallDirectConvolutionShapes(), const auto data_f32_nightly = combine(data_f32, framework::dataset::make("NumKernels", { 1, 4 })); const auto data_f16_nightly = combine(data_f16, framework::dataset::make("NumKernels", { 1, 4 })); -const auto data_precommit = combine(data, framework::dataset::make("NumKernels", { 1 })); +const auto data_precommit = combine(data_prec, framework::dataset::make("NumKernels", { 1 })); const auto data_precommit9x9 = combine(data9x9, framework::dataset::make("NumKernels", { 4 })); /* The following tests is from real use-case that made DirectConvolution @@ -195,6 +196,43 @@ DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip(zip(zip(zip( // clang-format on // *INDENT-ON* +DATA_TEST_CASE(NoPaddingNHWCKernel, framework::DatasetMode::ALL, combine(combine(combine(data_precommit, + framework::dataset::make("DataType", DataType::F32)), + ActivationFunctionsDataset), + framework::dataset::make("DataLayout", { DataLayout::NHWC })), + + shape, stride_x, stride_y, pad_x, pad_y, kernel_size, num_kernels, data_type, act_info, data_layout) +{ + TensorShape input_shape = TensorShape(shape); + TensorShape weights_shape(kernel_size, kernel_size, input_shape.z(), num_kernels); + const PadStrideInfo info(stride_x, stride_y, pad_x, pad_y, DimensionRoundingType::FLOOR); + + TensorInfo input_info = TensorInfo(input_shape, 1, data_type); + TensorInfo weights_info = TensorInfo(weights_shape, 1, data_type); + + TensorShape output_shape = compute_deep_convolution_shape(input_info, weights_info, info); + + if(data_layout == DataLayout::NHWC) + { + permute(input_shape, PermutationVector(2U, 0U, 1U)); + permute(weights_shape, PermutationVector(2U, 0U, 1U)); + permute(output_shape, PermutationVector(2U, 0U, 1U)); + } + + // Create tensors + Tensor src = create_tensor(input_shape, data_type, 1, QuantizationInfo(), data_layout); + Tensor weights = create_tensor(weights_shape, data_type, 1, QuantizationInfo(), data_layout); + Tensor dst = create_tensor(output_shape, data_type, 1, QuantizationInfo(), data_layout); + + // Create and configure function + NEDirectConvolutionLayer conv; + conv.configure(&src, &weights, nullptr, &dst, info, act_info); + + validate(src.info()->padding(), PaddingSize(0, 0, 0, 0)); + validate(weights.info()->padding(), PaddingSize(0, 0, 0, 0)); + validate(dst.info()->padding(), PaddingSize(0, 0, 0, 0)); +} + template using NEDirectConvolutionLayerFixture = DirectConvolutionValidationFixture; diff --git a/tests/validation/fixtures/DirectConvolutionLayerFixture.h b/tests/validation/fixtures/DirectConvolutionLayerFixture.h index 3da5158e97..e37063e2e5 100644 --- a/tests/validation/fixtures/DirectConvolutionLayerFixture.h +++ b/tests/validation/fixtures/DirectConvolutionLayerFixture.h @@ -51,13 +51,10 @@ class DirectConvolutionValidationGenericFixture : public framework::Fixture public: using TBias = typename std::conditional < std::is_same::value || std::is_same::value, int32_t, T >::type; -public: template void setup(TensorShape input_shape, int stride_x, int stride_y, int pad_x, int pad_y, unsigned int kernel_size, unsigned int num_kernels, DataType data_type, QuantizationInfo quantization_info, ActivationLayerInfo act_info, DataLayout data_layout) { - ARM_COMPUTE_ERROR_ON(data_layout == DataLayout::UNKNOWN); - _quantization_info = quantization_info; _data_type = data_type; -- cgit v1.2.1