diff options
Diffstat (limited to 'src/core/NEON')
-rw-r--r-- | src/core/NEON/kernels/NEDepthwiseConvolutionLayer3x3Kernel.cpp | 4 | ||||
-rw-r--r-- | src/core/NEON/kernels/NEDirectConvolutionLayerOutputStageKernel.cpp | 2 |
2 files changed, 4 insertions, 2 deletions
diff --git a/src/core/NEON/kernels/NEDepthwiseConvolutionLayer3x3Kernel.cpp b/src/core/NEON/kernels/NEDepthwiseConvolutionLayer3x3Kernel.cpp index 92383d9f15..dad4fee837 100644 --- a/src/core/NEON/kernels/NEDepthwiseConvolutionLayer3x3Kernel.cpp +++ b/src/core/NEON/kernels/NEDepthwiseConvolutionLayer3x3Kernel.cpp @@ -272,8 +272,8 @@ void NEDepthwiseConvolutionLayer3x3Kernel::configure_generic() -conv_pad_top, (num_x_steps - 1) * input_num_elems_processed + num_elems_read_per_iteration, _input->info()->tensor_shape().y() + conv_pad_bottom); - AccessWindowStatic weights_access(_weights->info(), 0, 0, _weights->info()->dimension(0), _weights->info()->dimension(1)); - AccessWindowStatic output_access(_output->info(), 0, 0, num_x_steps * _num_elems_written_per_iteration, output_shape.y()); + AccessWindowStatic weights_access(_weights->info(), 0, 0, _weights->info()->dimension(0), _weights->info()->dimension(1)); + AccessWindowHorizontal output_access(_output->info(), 0, _num_elems_written_per_iteration); update_window_and_padding(win, input_access, weights_access, output_access); output_access.set_valid_region(win, ValidRegion(Coordinates(), _output->info()->tensor_shape())); diff --git a/src/core/NEON/kernels/NEDirectConvolutionLayerOutputStageKernel.cpp b/src/core/NEON/kernels/NEDirectConvolutionLayerOutputStageKernel.cpp index 3f33c43b59..08d8f8ce56 100644 --- a/src/core/NEON/kernels/NEDirectConvolutionLayerOutputStageKernel.cpp +++ b/src/core/NEON/kernels/NEDirectConvolutionLayerOutputStageKernel.cpp @@ -451,8 +451,10 @@ void NEDirectConvolutionLayerOutputStageKernel::configure(ITensor *input, const break; } case DataType::S32: + { _func = (bias == nullptr) ? &output_stage<int32_t, uint8_t, false, false> : &output_stage<int32_t, uint8_t, false, true>; break; + } #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC case DataType::F16: { |