diff options
Diffstat (limited to 'arm_compute/core/NEON/kernels/convolution/depthwise/impl_base.hpp')
-rw-r--r-- | arm_compute/core/NEON/kernels/convolution/depthwise/impl_base.hpp | 146 |
1 files changed, 92 insertions, 54 deletions
diff --git a/arm_compute/core/NEON/kernels/convolution/depthwise/impl_base.hpp b/arm_compute/core/NEON/kernels/convolution/depthwise/impl_base.hpp index 674fc4d2df..493b2991dc 100644 --- a/arm_compute/core/NEON/kernels/convolution/depthwise/impl_base.hpp +++ b/arm_compute/core/NEON/kernels/convolution/depthwise/impl_base.hpp @@ -112,11 +112,7 @@ MEMBERFN()::DepthwiseConvolutionBase( _padding_right(padding_right), _activation(activation), _input_col_stride(0), _input_row_stride(0), _input_batch_stride(0), - _input_ws_col_stride(_n_channels), - _input_ws_row_stride(_input_ws_col_stride * inner_tile_cols), - _output_col_stride(0), _output_row_stride(0), _output_batch_stride(0), - _output_ws_col_stride(_n_channels), - _output_ws_row_stride(_output_ws_col_stride * OutputTileColumns) + _output_col_stride(0), _output_row_stride(0), _output_batch_stride(0) { } @@ -231,12 +227,12 @@ MEMBERFN(void)::set_working_space(void *buffer) MEMBERFN(size_t)::_get_input_working_space_size(void) const { - return sizeof(TIn) * inner_tile_rows * inner_tile_cols * _n_channels; + return sizeof(TIn) * _n_channels; } MEMBERFN(size_t)::_get_output_working_space_size(void) const { - return sizeof(TOut) * OutputTileRows * OutputTileColumns * _n_channels; + return sizeof(TOut) * _n_channels; } MEMBERFN(void *)::_get_input_working_space(const unsigned int threadid) const @@ -263,6 +259,14 @@ MEMBERFN(void)::run( const unsigned int threadid ) { + // Clear the input padding buffer + TIn *buf = static_cast<TIn *>(_get_input_working_space(threadid)); + const TIn pad_value = static_cast<Derived *>(this)->_input_padding_value(); + for (int n = 0; n < _n_channels; n++) + { + buf[n] = pad_value; + } + // Parallelise over blocks of channels const auto start_channel = CHANNEL_BLOCK * start; const auto stop_channel = std::min<unsigned int>(_n_channels, CHANNEL_BLOCK * stop); @@ -379,60 +383,94 @@ MEMBERFN(void)::process_tile( const int pad_out_right ) { + Derived * dthis = static_cast<Derived *>(this); const bool pad_input = pad_in_top || pad_in_left || pad_in_bottom || pad_in_right; const bool pad_output = pad_out_bottom || pad_out_right; - if (pad_input) + if (!pad_input && !pad_output) { - // Copy the input into the temporary buffer, applying padding - padding::copy_and_pad_tile<TIn>( - inner_tile_rows, inner_tile_cols, n_channels, - inptr, _input_row_stride, _input_col_stride, - static_cast<TIn *>(_get_input_working_space(threadid)), _input_ws_row_stride, _input_ws_col_stride, - pad_in_top, pad_in_left, pad_in_bottom, pad_in_right, - static_cast<Derived *>(this)->_input_padding_value() - ); + switch(_activation) + { + case ActivationFunction::ReLU: + dthis->template execute_tile<ActivationFunction::ReLU>( + n_channels, packed_params, + inptr, _input_row_stride, _input_col_stride, + outptr, _output_row_stride, _output_col_stride + ); + break; + case ActivationFunction::ReLU6: + dthis->template execute_tile<ActivationFunction::ReLU6>( + n_channels, packed_params, + inptr, _input_row_stride, _input_col_stride, + outptr, _output_row_stride, _output_col_stride + ); + break; + default: + dthis->template execute_tile<ActivationFunction::None>( + n_channels, packed_params, + inptr, _input_row_stride, _input_col_stride, + outptr, _output_row_stride, _output_col_stride + ); + break; + } } - - // Execute the kernel - const TIn * const tile_inptr = !pad_input ? inptr : static_cast<const TIn *>(_get_input_working_space(threadid)); - const int in_row_stride = !pad_input ? _input_row_stride : _input_ws_row_stride; - const int in_col_stride = !pad_input ? _input_col_stride : _input_ws_col_stride; - - TOut * const tile_outptr = !pad_output ? outptr : static_cast<TOut *>(_get_output_working_space(threadid)); - const int out_row_stride = !pad_output ? _output_row_stride : _output_ws_row_stride; - const int out_col_stride = !pad_output ? _output_col_stride : _output_ws_col_stride; - - Derived * dthis = static_cast<Derived *>(this); - - switch(_activation) + else { - case ActivationFunction::ReLU: - dthis->template execute_tile<ActivationFunction::ReLU>( - n_channels, packed_params, tile_inptr, in_row_stride, in_col_stride, tile_outptr, out_row_stride, out_col_stride - ); - break; - case ActivationFunction::ReLU6: - dthis->template execute_tile<ActivationFunction::ReLU6>( - n_channels, packed_params, tile_inptr, in_row_stride, in_col_stride, tile_outptr, out_row_stride, out_col_stride - ); - break; - default: - dthis->template execute_tile<ActivationFunction::None>( - n_channels, packed_params, tile_inptr, in_row_stride, in_col_stride, tile_outptr, out_row_stride, out_col_stride - ); - break; - } + // Create arrays of input and output pointers, pointing padded elements to + // the working space padding buffers provided. + const TIn *inptrs[inner_tile_rows][inner_tile_cols]; + for (int i = 0; i < inner_tile_rows; i++) + { + for (int j = 0; j < inner_tile_cols; j++) + { + if (i < pad_in_top || (inner_tile_rows - pad_in_bottom) <= i || + j < pad_in_left || (inner_tile_cols - pad_in_right) <= j) + { + // Padded input + inptrs[i][j] = static_cast<const TIn *>(_get_input_working_space(threadid)); + } + else + { + inptrs[i][j] = inptr + (i - pad_in_top)*_input_row_stride + (j - pad_in_left)*_input_col_stride; + } + } + } - if (pad_output) - { - // Copy the output from the temporary buffer, removing unnecessary values - padding::CopyCropped<OutputTileRows, OutputTileColumns>::execute( - n_channels * sizeof(TOut), - _get_output_working_space(threadid), _output_ws_row_stride * sizeof(TOut), _output_ws_col_stride * sizeof(TOut), - outptr, _output_row_stride * sizeof(TOut), _output_col_stride * sizeof(TOut), - 0, 0, pad_out_bottom, pad_out_right - ); + TOut *outptrs[output_tile_rows][output_tile_cols]; + for (int i = 0; i < output_tile_rows; i++) + { + for (int j = 0; j < output_tile_cols; j++) + { + if (i < (output_tile_rows - pad_out_bottom) && + j < (output_tile_cols - pad_out_right)) + { + outptrs[i][j] = outptr + i*_output_row_stride + j*_output_col_stride; + } + else + { + outptrs[i][j] = static_cast<TOut *>(_get_output_working_space(threadid)); + } + } + } + + switch(_activation) + { + case ActivationFunction::ReLU: + dthis->template execute_tile<ActivationFunction::ReLU>( + n_channels, packed_params, inptrs, outptrs + ); + break; + case ActivationFunction::ReLU6: + dthis->template execute_tile<ActivationFunction::ReLU6>( + n_channels, packed_params, inptrs, outptrs + ); + break; + default: + dthis->template execute_tile<ActivationFunction::None>( + n_channels, packed_params, inptrs, outptrs + ); + break; + } } } |