diff options
author | Giorgio Arena <giorgio.arena@arm.com> | 2018-04-20 16:06:21 +0100 |
---|---|---|
committer | Anthony Barbier <anthony.barbier@arm.com> | 2018-11-02 16:49:37 +0000 |
commit | f485a100e3f11911d25a16b4ccc286c6c0816061 (patch) | |
tree | e5ded0791adee9f8c83279664d20f3db5e490a0e | |
parent | 48c19f1308ecdc7ea37a6bf5ce9459e0954e9007 (diff) | |
download | ComputeLibrary-f485a100e3f11911d25a16b4ccc286c6c0816061.tar.gz |
COMPMID-802 Fix NEIm2Col NHWC
Change-Id: I513e0199b6fa665c4a7d2a739f4871b4575ef347
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/128490
Tested-by: Jenkins <bsgcomp@arm.com>
Reviewed-by: Pablo Tello <pablo.tello@arm.com>
-rw-r--r-- | arm_compute/core/utils/misc/ShapeCalculator.h | 6 |
-rw-r--r-- | src/core/NEON/kernels/NEIm2ColKernel.cpp | 22 |
-rw-r--r-- | tests/validation/reference/Im2Col.cpp | 12 |
3 files changed, 16 insertions, 24 deletions
diff --git a/arm_compute/core/utils/misc/ShapeCalculator.h b/arm_compute/core/utils/misc/ShapeCalculator.h
index b91e52a657..8d4c024f62 100644
--- a/arm_compute/core/utils/misc/ShapeCalculator.h
+++ b/arm_compute/core/utils/misc/ShapeCalculator.h
@@ -168,9 +168,9 @@ inline TensorShape compute_im2col_conv_shape(const ITensorInfo *input, const Siz
     const int channel_idx = get_data_layout_dimension_index(data_layout, DataLayoutDimension::CHANNEL);
 
     std::pair<unsigned int, unsigned int> out_dims = scaled_dimensions(output_shape[width_idx], output_shape[height_idx], kernel_dims.width, kernel_dims.height, conv_info, dilation);
-    output_shape.set(width_idx, (output_shape[channel_idx] * kernel_dims.area() + (has_bias ? 1 : 0)));
-    output_shape.set(height_idx, (out_dims.first * out_dims.second));
-    output_shape.set(channel_idx, 1);
+    output_shape.set(0, (output_shape[channel_idx] * kernel_dims.area() + (has_bias ? 1 : 0)));
+    output_shape.set(1, (out_dims.first * out_dims.second));
+    output_shape.set(2, 1);
 
     return output_shape;
 }
diff --git a/src/core/NEON/kernels/NEIm2ColKernel.cpp b/src/core/NEON/kernels/NEIm2ColKernel.cpp
index 5e165a641c..86e3fd7a84 100644
--- a/src/core/NEON/kernels/NEIm2ColKernel.cpp
+++ b/src/core/NEON/kernels/NEIm2ColKernel.cpp
@@ -218,21 +218,15 @@ void NEIm2ColKernel::run_generic(const Window &window)
     const int start_x = -pad_left;
     const int start_y = -pad_top;
 
-    Window window_in(window);
-    // The first three dimensions of the input are increased by the inner loops
-    window_in.set(Window::DimX, Window::Dimension(0, 0, 0));
-    window_in.set(Window::DimY, Window::Dimension(0, 0, 0));
-    window_in.set(Window::DimZ, Window::Dimension(0, 0, 0));
-
-    // Setup output window
-    Window window_out(window);
-    window_out.set(width_idx, Window::Dimension(0, _output->info()->dimension(width_idx), _output->info()->strides_in_bytes()[width_idx + 1] / _output->info()->strides_in_bytes()[width_idx]));
-    window_out.set(height_idx, Window::Dimension(window[height_idx].start() * _convolved_dims.first, window[height_idx].end() * _convolved_dims.first, _convolved_dims.first));
-    window_out.set(channel_idx, Window::Dimension(0, 1, 1));
+    Window window_in_out(window);
+    // The first three dimensions of the input and output are increased by the inner loops
+    window_in_out.set(Window::DimX, Window::Dimension(0, 0, 0));
+    window_in_out.set(Window::DimY, Window::Dimension(0, 0, 0));
+    window_in_out.set(Window::DimZ, Window::Dimension(0, 0, 0));
 
     // Create iterators
-    Iterator in(_input, window_in);
-    Iterator out(_output, window_out);
+    Iterator in(_input, window_in_out);
+    Iterator out(_output, window_in_out);
 
     execute_window_loop(window, [&](const Coordinates & id)
     {
@@ -241,7 +235,7 @@ void NEIm2ColKernel::run_generic(const Window &window)
 
         // Get pointers
         const uint8_t *const input_ptr = in.ptr();
-        auto output_ptr = reinterpret_cast<T *>(out.ptr());
+        auto output_ptr = reinterpret_cast<T *>(out.ptr() + (id[width_idx] + id[height_idx] * _convolved_dims.first) * _output->info()->strides_in_bytes().y());
 
         // Linearize volume
         linearize_volume<T, has_pads>(input_ptr,
diff --git a/tests/validation/reference/Im2Col.cpp b/tests/validation/reference/Im2Col.cpp
index 825f0a6ee1..d309b7d5e6 100644
--- a/tests/validation/reference/Im2Col.cpp
+++ b/tests/validation/reference/Im2Col.cpp
@@ -88,14 +88,12 @@ SimpleTensor<T> im2col(const SimpleTensor<T> &src, const TensorShape &dst_shape,
     if(src.data_layout() == DataLayout::NHWC)
     {
         SimpleTensor<T> src_nchw = reference::permute<T>(src, PermutationVector(1U, 2U, 0U));
-        SimpleTensor<T> dst_nchw = reference::permute<T>(dst, PermutationVector(1U, 2U, 0U));
-
-        im2col_nchw(src_nchw, dst_nchw, kernel_dims, conv_info, has_bias);
-
-        return reference::permute<T>(dst_nchw, PermutationVector(2U, 0U, 1U));
+        im2col_nchw(src_nchw, dst, kernel_dims, conv_info, has_bias);
+    }
+    else
+    {
+        im2col_nchw(src, dst, kernel_dims, conv_info, has_bias);
     }
-
-    im2col_nchw(src, dst, kernel_dims, conv_info, has_bias);
 
     return dst;
 }