From 1ed1fc6d3b7d8494ce3bbc5f8b46bfde6fc586f9 Mon Sep 17 00:00:00 2001 From: Giorgio Arena Date: Mon, 26 Mar 2018 16:20:05 +0100 Subject: COMPMID-812 Add NHWC data format support for NEON depthwise convolution (optimized case). Change-Id: Icdfd6c02ed526daf4f59a4b76c7bbc1bc48fde74 Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/125938 Tested-by: Jenkins Reviewed-by: Anthony Barbier --- .../NEDirectConvolutionLayerOutputStageKernel.cpp | 177 ++++++++++++++++----- 1 file changed, 135 insertions(+), 42 deletions(-) (limited to 'src/core/NEON/kernels/NEDirectConvolutionLayerOutputStageKernel.cpp') diff --git a/src/core/NEON/kernels/NEDirectConvolutionLayerOutputStageKernel.cpp b/src/core/NEON/kernels/NEDirectConvolutionLayerOutputStageKernel.cpp index 08d8f8ce56..edda2cd9da 100644 --- a/src/core/NEON/kernels/NEDirectConvolutionLayerOutputStageKernel.cpp +++ b/src/core/NEON/kernels/NEDirectConvolutionLayerOutputStageKernel.cpp @@ -44,6 +44,7 @@ namespace Status validate_arguments(const ITensorInfo *input, const ITensorInfo *bias, const ITensorInfo *output) { ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input); + ARM_COMPUTE_RETURN_ERROR_ON(input->data_layout() == DataLayout::UNKNOWN); ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::QASYMM8, DataType::QS16, DataType::F16, DataType::QS32, DataType::S32, DataType::F32); @@ -68,6 +69,7 @@ Status validate_arguments(const ITensorInfo *input, const ITensorInfo *bias, con ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, bias); } + ARM_COMPUTE_RETURN_ERROR_ON(bias->dimension(0) != input->dimension(get_data_layout_dimension_index(input->data_layout(), DataLayoutDimension::CHANNEL))); ARM_COMPUTE_RETURN_ERROR_ON(bias->num_dimensions() > 1); } else @@ -79,6 +81,8 @@ Status validate_arguments(const ITensorInfo *input, const ITensorInfo *bias, con if((output != nullptr) && (output->total_size() != 0)) { ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::QS8, DataType::QASYMM8, DataType::QS16, DataType::F32); + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input, output); + if(is_data_type_fixed_point(input->data_type())) { ARM_COMPUTE_RETURN_ERROR_ON_MSG(input->data_type() == DataType::QS8 && output->data_type() != DataType::QS8, "Wrong data type for output"); @@ -101,6 +105,8 @@ Status validate_arguments(const ITensorInfo *input, const ITensorInfo *bias, con std::pair validate_and_configure_window(ITensorInfo *input, ITensorInfo *bias, ITensorInfo *output) { + ARM_COMPUTE_ERROR_ON(input->data_layout() == DataLayout::UNKNOWN); + bool window_changed = false; unsigned int num_elems_processed_per_iteration = 16 / element_size_from_data_type(input->data_type()); @@ -138,8 +144,16 @@ std::pair validate_and_configure_window(ITensorInfo *input, ITen } else { - AccessWindowStatic bias_access(bias, 0, 0, bias->dimension(0), bias->dimension(1)); - window_changed = update_window_and_padding(win, input_access, bias_access); + if(input->data_layout() == DataLayout::NCHW) + { + AccessWindowStatic bias_access(bias, 0, 0, bias->dimension(0), bias->dimension(1)); + window_changed = update_window_and_padding(win, input_access, bias_access); + } + else + { + AccessWindowHorizontal bias_access(bias, 0, num_elems_processed_per_iteration); + window_changed = update_window_and_padding(win, input_access, bias_access); + } } input_access.set_valid_region(win, ValidRegion(Coordinates(), input->tensor_shape())); @@ -253,6 +267,7 @@ template void output_stage(ITensor *input, const ITensor *bias, const Window &window, ITensor *output, int result_fixedpoint_multiplier, int result_shift, int result_offset_after_shift) { + ARM_COMPUTE_ERROR_ON(input->info()->data_layout() == DataLayout::UNKNOWN); ARM_COMPUTE_UNUSED(result_fixedpoint_multiplier); ARM_COMPUTE_UNUSED(result_shift); ARM_COMPUTE_UNUSED(result_offset_after_shift); @@ -303,6 +318,66 @@ void output_stage(ITensor *input, const ITensor *bias, const Window &window, ITe } } +template +void output_stage_nhwc(ITensor *input, const ITensor *bias, const Window &window, ITensor *output, + int result_fixedpoint_multiplier, int result_shift, int result_offset_after_shift) +{ + ARM_COMPUTE_UNUSED(result_fixedpoint_multiplier); + ARM_COMPUTE_UNUSED(result_shift); + ARM_COMPUTE_UNUSED(result_offset_after_shift); + + Window window_bias = window; + window_bias.set(Window::DimY, Window::Dimension(0, 0, 0)); + window_bias.set(Window::DimZ, Window::Dimension(0, 0, 0)); + window_bias.set(3, Window::Dimension(0, 0, 0)); + + Iterator in(input, window); + Iterator bi(bias, window_bias); + + if(in_place) // In place accumulate + { + execute_window_loop(window, [&](const Coordinates & id) + { + // Get bias and pointer to input + const auto in_ptr = reinterpret_cast(in.ptr()); + const auto bias_ptr = reinterpret_cast(bi.ptr()); + + // Accumulate bias + if(has_bias) + { + internal_vst1q(in_ptr, internal_vqaddq(internal_vld1q(in_ptr), internal_vld1q(bias_ptr))); + } + else + { + internal_vst1q(in_ptr, internal_vld1q(in_ptr)); + } + }, + in, bi); + } + else // Out of place accumulate + { + Iterator out(output, window); + execute_window_loop(window, [&](const Coordinates & id) + { + // Get bias and pointer to input + const auto in_ptr = reinterpret_cast(in.ptr()); + const auto out_ptr = reinterpret_cast(out.ptr()); + const auto bias_ptr = reinterpret_cast(bi.ptr()); + + // Accumulate bias + if(has_bias) + { + internal_vst1q(out_ptr, internal_vqaddq(internal_vld1q(in_ptr), internal_vld1q(bias_ptr))); + } + else + { + internal_vst1q(out_ptr, internal_vld1q(in_ptr)); + } + }, + in, bi); + } +} + // QASYMM8 specializations template <> void output_stage(ITensor *input, const ITensor *bias, const Window &window, ITensor *output, @@ -415,61 +490,79 @@ void NEDirectConvolutionLayerOutputStageKernel::configure(ITensor *input, const INEKernel::configure(win_config.second); // Set appropriate function - switch(input->info()->data_type()) + if(input->info()->data_layout() == DataLayout::NCHW) { - case DataType::QS8: + switch(input->info()->data_type()) { - if(bias == nullptr) + case DataType::QS8: { - _func = (output == nullptr) ? &output_stage : &output_stage; + if(bias == nullptr) + { + _func = (output == nullptr) ? &output_stage : &output_stage; + } + else + { + _func = (output == nullptr) ? &output_stage : &output_stage; + } + break; } - else + case DataType::QS16: { - _func = (output == nullptr) ? &output_stage : &output_stage; + if(bias != nullptr && bias->info()->data_type() == DataType::QS8) + { + _func = (output == nullptr) ? &output_stage : &output_stage; + } + else if(bias == nullptr) + { + _func = (output == nullptr) ? &output_stage : &output_stage; + } + else + { + ARM_COMPUTE_ERROR("Not implemented"); + } + break; } - break; - } - case DataType::QS16: - { - if(bias != nullptr && bias->info()->data_type() == DataType::QS8) + case DataType::QS32: { - _func = (output == nullptr) ? &output_stage : &output_stage; + _func = (output == nullptr) ? &output_stage : &output_stage; + break; } - else if(bias == nullptr) + case DataType::S32: { - _func = (output == nullptr) ? &output_stage : &output_stage; + _func = (bias == nullptr) ? &output_stage : &output_stage; + break; } - else +#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + case DataType::F16: { - ARM_COMPUTE_ERROR("Not implemented"); + _func = (output == nullptr) ? &output_stage : &output_stage; + break; } - break; - } - case DataType::QS32: - { - _func = (output == nullptr) ? &output_stage : &output_stage; - break; - } - case DataType::S32: - { - _func = (bias == nullptr) ? &output_stage : &output_stage; - break; - } -#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC - case DataType::F16: - { - _func = (output == nullptr) ? &output_stage : &output_stage; - break; - } #endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */ - case DataType::F32: - { - _func = (output == nullptr) ? &output_stage : &output_stage; - break; + case DataType::F32: + { + _func = (output == nullptr) ? &output_stage : &output_stage; + break; + } + default: + { + ARM_COMPUTE_ERROR("Unsupported combination of types among the inputs."); + } } - default: + } + else + { + switch(input->info()->data_type()) { - ARM_COMPUTE_ERROR("Unsupported combination of types among the inputs."); + case DataType::F32: + { + _func = (output == nullptr) ? &output_stage_nhwc : &output_stage_nhwc; + break; + } + default: + { + ARM_COMPUTE_ERROR("Unsupported combination of types among the inputs."); + } } } } -- cgit v1.2.1