about | summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
-rw-r--r-- arm_compute/core/utils/misc/ShapeCalculator.h | 6
-rw-r--r-- src/core/NEON/kernels/NEIm2ColKernel.cpp | 22
-rw-r--r-- tests/validation/reference/Im2Col.cpp | 12
3 files changed, 16 insertions, 24 deletions
diff --git a/arm_compute/core/utils/misc/ShapeCalculator.h b/arm_compute/core/utils/misc/ShapeCalculator.h
index b91e52a657..8d4c024f62 100644
--- a/arm_compute/core/utils/misc/ShapeCalculator.h
+++ b/arm_compute/core/utils/misc/ShapeCalculator.h
@@ -168,9 +168,9 @@ inline TensorShape compute_im2col_conv_shape(const ITensorInfo *input, const Siz
const int channel_idx = get_data_layout_dimension_index(data_layout, DataLayoutDimension::CHANNEL);
std::pair<unsigned int, unsigned int> out_dims = scaled_dimensions(output_shape[width_idx], output_shape[height_idx], kernel_dims.width, kernel_dims.height, conv_info, dilation);
- output_shape.set(width_idx, (output_shape[channel_idx] * kernel_dims.area() + (has_bias ? 1 : 0)));
- output_shape.set(height_idx, (out_dims.first * out_dims.second));
- output_shape.set(channel_idx, 1);
+ output_shape.set(0, (output_shape[channel_idx] * kernel_dims.area() + (has_bias ? 1 : 0)));
+ output_shape.set(1, (out_dims.first * out_dims.second));
+ output_shape.set(2, 1);
return output_shape;
}
diff --git a/src/core/NEON/kernels/NEIm2ColKernel.cpp b/src/core/NEON/kernels/NEIm2ColKernel.cpp
index 5e165a641c..86e3fd7a84 100644
--- a/src/core/NEON/kernels/NEIm2ColKernel.cpp
+++ b/src/core/NEON/kernels/NEIm2ColKernel.cpp
@@ -218,21 +218,15 @@ void NEIm2ColKernel::run_generic(const Window &window)
const int start_x = -pad_left;
const int start_y = -pad_top;
- Window window_in(window);
- // The first three dimensions of the input are increased by the inner loops
- window_in.set(Window::DimX, Window::Dimension(0, 0, 0));
- window_in.set(Window::DimY, Window::Dimension(0, 0, 0));
- window_in.set(Window::DimZ, Window::Dimension(0, 0, 0));
-
- // Setup output window
- Window window_out(window);
- window_out.set(width_idx, Window::Dimension(0, _output->info()->dimension(width_idx), _output->info()->strides_in_bytes()[width_idx + 1] / _output->info()->strides_in_bytes()[width_idx]));
- window_out.set(height_idx, Window::Dimension(window[height_idx].start() * _convolved_dims.first, window[height_idx].end() * _convolved_dims.first, _convolved_dims.first));
- window_out.set(channel_idx, Window::Dimension(0, 1, 1));
+ Window window_in_out(window);
+ // The first three dimensions of the input and output are increased by the inner loops
+ window_in_out.set(Window::DimX, Window::Dimension(0, 0, 0));
+ window_in_out.set(Window::DimY, Window::Dimension(0, 0, 0));
+ window_in_out.set(Window::DimZ, Window::Dimension(0, 0, 0));
// Create iterators
- Iterator in(_input, window_in);
- Iterator out(_output, window_out);
+ Iterator in(_input, window_in_out);
+ Iterator out(_output, window_in_out);
execute_window_loop(window, [&](const Coordinates & id)
{
@@ -241,7 +235,7 @@ void NEIm2ColKernel::run_generic(const Window &window)
// Get pointers
const uint8_t *const input_ptr = in.ptr();
- auto output_ptr = reinterpret_cast<T *>(out.ptr());
+ auto output_ptr = reinterpret_cast<T *>(out.ptr() + (id[width_idx] + id[height_idx] * _convolved_dims.first) * _output->info()->strides_in_bytes().y());
// Linearize volume
linearize_volume<T, has_pads>(input_ptr,
diff --git a/tests/validation/reference/Im2Col.cpp b/tests/validation/reference/Im2Col.cpp
index 825f0a6ee1..d309b7d5e6 100644
--- a/tests/validation/reference/Im2Col.cpp
+++ b/tests/validation/reference/Im2Col.cpp
@@ -88,14 +88,12 @@ SimpleTensor<T> im2col(const SimpleTensor<T> &src, const TensorShape &dst_shape,
if(src.data_layout() == DataLayout::NHWC)
{
SimpleTensor<T> src_nchw = reference::permute<T>(src, PermutationVector(1U, 2U, 0U));
- SimpleTensor<T> dst_nchw = reference::permute<T>(dst, PermutationVector(1U, 2U, 0U));
-
- im2col_nchw(src_nchw, dst_nchw, kernel_dims, conv_info, has_bias);
-
- return reference::permute<T>(dst_nchw, PermutationVector(2U, 0U, 1U));
+ im2col_nchw(src_nchw, dst, kernel_dims, conv_info, has_bias);
+ }
+ else
+ {
+ im2col_nchw(src, dst, kernel_dims, conv_info, has_bias);
}
-
- im2col_nchw(src, dst, kernel_dims, conv_info, has_bias);
return dst;
}