aboutsummaryrefslogtreecommitdiff
path: root/src/runtime/NEON/functions/NEGEMMConvolutionLayer.cpp
diff options
context:
space:
mode:
authorGian Marco Iodice <gianmarco.iodice@arm.com>2018-08-02 16:43:24 +0100
committerAnthony Barbier <anthony.barbier@arm.com>2018-11-02 16:54:54 +0000
commitebbb7f2ee00bdebad5da3629bbc78dc3a65fe0c5 (patch)
treeece1bc256de86ea6fa25803e30976b1eed7b3ffd /src/runtime/NEON/functions/NEGEMMConvolutionLayer.cpp
parent4cb39096ded770bfc4cf9712b85ed38e66c0e3f7 (diff)
downloadComputeLibrary-ebbb7f2ee00bdebad5da3629bbc78dc3a65fe0c5.tar.gz
COMPMID-1188 - Fixed CLGEMMConvolutionLayer/NEGEMMConvolutionLayer for NHWC
We skipped im2col also without unit strides Change-Id: I04c63a6dda8553b3890e832a56ff6854349c829a Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/142520 Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com> Tested-by: Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'src/runtime/NEON/functions/NEGEMMConvolutionLayer.cpp')
-rw-r--r--src/runtime/NEON/functions/NEGEMMConvolutionLayer.cpp2
1 files changed, 1 insertions, 1 deletions
diff --git a/src/runtime/NEON/functions/NEGEMMConvolutionLayer.cpp b/src/runtime/NEON/functions/NEGEMMConvolutionLayer.cpp
index 25e8d9e60b..aace261e32 100644
--- a/src/runtime/NEON/functions/NEGEMMConvolutionLayer.cpp
+++ b/src/runtime/NEON/functions/NEGEMMConvolutionLayer.cpp
@@ -209,7 +209,7 @@ Status validate_and_initialize_values(const ITensorInfo *input, const ITensorInf
kernel_height = (are_weights_reshaped) ? weights_info.kernel_size().second : weights->dimension(idx_height);
mat_weights_cols = weights->dimension(3);
mat_weights_rows = weights->dimension(idx_width) * weights->dimension(idx_height) * weights->dimension(idx_channel) + ((append_bias && !skip_im2col) ? 1 : 0);
- skip_im2col = (data_layout == DataLayout::NHWC && kernel_width == 1 && kernel_height == 1);
+ skip_im2col = (data_layout == DataLayout::NHWC && kernel_width == 1 && kernel_height == 1 && conv_info.stride().first == 1 && conv_info.stride().second == 1);
std::tie(conv_w, conv_h) = scaled_dimensions(input->dimension(idx_width), input->dimension(idx_height), kernel_width, kernel_height,
conv_info, dilation);