aboutsummaryrefslogtreecommitdiff
path: root/src/core/NEON/kernels/NEDirectConvolutionLayerKernel.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/NEON/kernels/NEDirectConvolutionLayerKernel.cpp')
-rw-r--r--src/core/NEON/kernels/NEDirectConvolutionLayerKernel.cpp11
1 files changed, 6 insertions, 5 deletions
diff --git a/src/core/NEON/kernels/NEDirectConvolutionLayerKernel.cpp b/src/core/NEON/kernels/NEDirectConvolutionLayerKernel.cpp
index 09f99748bf..98b76c7db3 100644
--- a/src/core/NEON/kernels/NEDirectConvolutionLayerKernel.cpp
+++ b/src/core/NEON/kernels/NEDirectConvolutionLayerKernel.cpp
@@ -1217,7 +1217,7 @@ void NEDirectConvolutionLayerKernel::convolve_nhwc(const Window &window)
NEDirectConvolutionLayerKernel::NEDirectConvolutionLayerKernel()
: _input(nullptr), _weights(nullptr), _output(nullptr), _conv_info(), _border_size(0), _kernel_size(0), _num_weight_elems_read_per_row(0), _num_elems_read_per_iteration(0),
- _num_elems_written_per_iteration(0)
+ _num_elems_written_per_iteration(0), _data_layout()
{
}
@@ -1234,13 +1234,14 @@ void NEDirectConvolutionLayerKernel::configure(const ITensor *input, const ITens
_weights = weights;
_output = output;
_conv_info = conv_info;
- _kernel_size = weights->info()->dimension(get_data_layout_dimension_index(weights->info()->data_layout(), DataLayoutDimension::WIDTH));
+ _data_layout = _input->info()->data_layout();
+ _kernel_size = weights->info()->dimension(get_data_layout_dimension_index(_data_layout, DataLayoutDimension::WIDTH));
const unsigned int conv_pad_left = conv_info.pad_left();
const unsigned int conv_pad_top = conv_info.pad_top();
const unsigned int conv_pad_right = conv_info.pad_right();
const unsigned int conv_pad_bottom = conv_info.pad_bottom();
- if(_input->info()->data_layout() == DataLayout::NCHW)
+ if(_data_layout == DataLayout::NCHW)
{
_border_size = BorderSize(conv_pad_top, conv_pad_right, conv_pad_bottom, conv_pad_left);
}
@@ -1294,9 +1295,9 @@ void NEDirectConvolutionLayerKernel::run(const Window &window, const ThreadInfo
ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
ARM_COMPUTE_ERROR_ON(_input->buffer() == nullptr);
- const int kernel_size = _weights->info()->dimension(get_data_layout_dimension_index(_weights->info()->data_layout(), DataLayoutDimension::WIDTH));
+ const int kernel_size = _weights->info()->dimension(get_data_layout_dimension_index(_data_layout, DataLayoutDimension::WIDTH));
- if(_input->info()->data_layout() == DataLayout::NCHW)
+ if(_data_layout == DataLayout::NCHW)
{
switch(kernel_size)
{