diff options
author | Gian Marco Iodice <gianmarco.iodice@arm.com> | 2017-06-21 08:54:02 +0100 |
---|---|---|
committer | Anthony Barbier <anthony.barbier@arm.com> | 2018-09-17 13:03:43 +0100 |
commit | 7b06cde62a19ae0fda7bbbf0ab0bf18a88470af8 (patch) | |
tree | cf6013541a22713b2ca3b4a9757d8408fe310652 | |
parent | 6ff3b19ee6120edf015fad8caab2991faa3070af (diff) | |
download | ComputeLibrary-7b06cde62a19ae0fda7bbbf0ab0bf18a88470af8.tar.gz |
COMPMID-345 - Fixed issue with non rectangular kernels in NEConvolutionLayer funcion
Change-Id: I9157c274ce8545b7dc391ee91623d7df8ed77395
Reviewed-on: http://mpd-gerrit.cambridge.arm.com/78331
Reviewed-by: Moritz Pflanzer <moritz.pflanzer@arm.com>
Tested-by: Kaizen <jeremy.johnson+kaizengerrit@arm.com>
-rw-r--r-- | src/core/NEON/kernels/NEWeightsReshapeKernel.cpp | 7 |
1 files changed, 4 insertions, 3 deletions
diff --git a/src/core/NEON/kernels/NEWeightsReshapeKernel.cpp b/src/core/NEON/kernels/NEWeightsReshapeKernel.cpp index aa6be44bee..f31cde719a 100644 --- a/src/core/NEON/kernels/NEWeightsReshapeKernel.cpp +++ b/src/core/NEON/kernels/NEWeightsReshapeKernel.cpp @@ -37,7 +37,8 @@ namespace template <typename T> void weights_reshape(const ITensor *input, const ITensor *bias, ITensor *output, const Window &window) { - const unsigned int kernel_size = input->info()->dimension(0); + const unsigned int kernel_size_x = input->info()->dimension(0); + const unsigned int kernel_size_y = input->info()->dimension(1); const unsigned int kernel_depth = input->info()->dimension(2); const unsigned int input_stride_x = input->info()->strides_in_bytes().x(); const unsigned int input_stride_y = input->info()->strides_in_bytes().y(); @@ -61,9 +62,9 @@ void weights_reshape(const ITensor *input, const ITensor *bias, ITensor *output, // Linearize volume for(unsigned int d = 0; d < kernel_depth; ++d) { - for(unsigned int j = 0; j < kernel_size; ++j) + for(unsigned int j = 0; j < kernel_size_y; ++j) { - for(unsigned int i = 0; i < kernel_size; ++i) + for(unsigned int i = 0; i < kernel_size_x; ++i) { *(reinterpret_cast<T *>(tmp_output_ptr)) = *(reinterpret_cast<const T *>(tmp_input_ptr)); tmp_input_ptr += input_stride_x; |