Diffstat (limited to 'src/core')
-rw-r--r-- | src/core/NEON/kernels/NEDirectConvolutionLayerOutputStageKernel.cpp | 177 |
1 file changed, 135 insertions, 42 deletions
diff --git a/src/core/NEON/kernels/NEDirectConvolutionLayerOutputStageKernel.cpp b/src/core/NEON/kernels/NEDirectConvolutionLayerOutputStageKernel.cpp
index 08d8f8ce56..edda2cd9da 100644
--- a/src/core/NEON/kernels/NEDirectConvolutionLayerOutputStageKernel.cpp
+++ b/src/core/NEON/kernels/NEDirectConvolutionLayerOutputStageKernel.cpp
@@ -44,6 +44,7 @@ namespace
 Status validate_arguments(const ITensorInfo *input, const ITensorInfo *bias, const ITensorInfo *output)
 {
     ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input);
+    ARM_COMPUTE_RETURN_ERROR_ON(input->data_layout() == DataLayout::UNKNOWN);
     ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::QASYMM8, DataType::QS16, DataType::F16,
                                                          DataType::QS32, DataType::S32, DataType::F32);
 
@@ -68,6 +69,7 @@ Status validate_arguments(const ITensorInfo *input, const ITensorInfo *bias, con
             ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, bias);
         }
 
+        ARM_COMPUTE_RETURN_ERROR_ON(bias->dimension(0) != input->dimension(get_data_layout_dimension_index(input->data_layout(), DataLayoutDimension::CHANNEL)));
         ARM_COMPUTE_RETURN_ERROR_ON(bias->num_dimensions() > 1);
     }
     else
@@ -79,6 +81,8 @@ Status validate_arguments(const ITensorInfo *input, const ITensorInfo *bias, con
     if((output != nullptr) && (output->total_size() != 0))
     {
         ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::QS8, DataType::QASYMM8, DataType::QS16, DataType::F32);
+        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input, output);
+
         if(is_data_type_fixed_point(input->data_type()))
         {
             ARM_COMPUTE_RETURN_ERROR_ON_MSG(input->data_type() == DataType::QS8 && output->data_type() != DataType::QS8, "Wrong data type for output");
@@ -101,6 +105,8 @@ Status validate_arguments(const ITensorInfo *input, const ITensorInfo *bias, con
 
 std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input, ITensorInfo *bias, ITensorInfo *output)
 {
+    ARM_COMPUTE_ERROR_ON(input->data_layout() == DataLayout::UNKNOWN);
+
     bool         window_changed                    = false;
     unsigned int num_elems_processed_per_iteration = 16 / element_size_from_data_type(input->data_type());
 
@@ -138,8 +144,16 @@ std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input, ITen
     }
     else
     {
-        AccessWindowStatic bias_access(bias, 0, 0, bias->dimension(0), bias->dimension(1));
-        window_changed = update_window_and_padding(win, input_access, bias_access);
+        if(input->data_layout() == DataLayout::NCHW)
+        {
+            AccessWindowStatic bias_access(bias, 0, 0, bias->dimension(0), bias->dimension(1));
+            window_changed = update_window_and_padding(win, input_access, bias_access);
+        }
+        else
+        {
+            AccessWindowHorizontal bias_access(bias, 0, num_elems_processed_per_iteration);
+            window_changed = update_window_and_padding(win, input_access, bias_access);
+        }
     }
 
     input_access.set_valid_region(win, ValidRegion(Coordinates(), input->tensor_shape()));
@@ -253,6 +267,7 @@ template <typename T1, typename T2, bool in_place, bool has_bias>
 void output_stage(ITensor *input, const ITensor *bias, const Window &window, ITensor *output,
                   int result_fixedpoint_multiplier, int result_shift, int result_offset_after_shift)
 {
+    ARM_COMPUTE_ERROR_ON(input->info()->data_layout() == DataLayout::UNKNOWN);
     ARM_COMPUTE_UNUSED(result_fixedpoint_multiplier);
     ARM_COMPUTE_UNUSED(result_shift);
     ARM_COMPUTE_UNUSED(result_offset_after_shift);
@@ -303,6 +318,66 @@ void output_stage(ITensor *input, const ITensor *bias, const Window &window, ITe
     }
 }
 
+template <typename T1, typename T2, bool in_place, bool has_bias>
+void output_stage_nhwc(ITensor *input, const ITensor *bias, const Window &window, ITensor *output,
+                       int result_fixedpoint_multiplier, int result_shift, int result_offset_after_shift)
+{
+    ARM_COMPUTE_UNUSED(result_fixedpoint_multiplier);
+    ARM_COMPUTE_UNUSED(result_shift);
+    ARM_COMPUTE_UNUSED(result_offset_after_shift);
+
+    Window window_bias = window;
+    window_bias.set(Window::DimY, Window::Dimension(0, 0, 0));
+    window_bias.set(Window::DimZ, Window::Dimension(0, 0, 0));
+    window_bias.set(3, Window::Dimension(0, 0, 0));
+
+    Iterator in(input, window);
+    Iterator bi(bias, window_bias);
+
+    if(in_place) // In place accumulate
+    {
+        execute_window_loop(window, [&](const Coordinates & id)
+        {
+            // Get bias and pointer to input
+            const auto in_ptr   = reinterpret_cast<T1 *>(in.ptr());
+            const auto bias_ptr = reinterpret_cast<T2 *>(bi.ptr());
+
+            // Accumulate bias
+            if(has_bias)
+            {
+                internal_vst1q(in_ptr, internal_vqaddq(internal_vld1q(in_ptr), internal_vld1q(bias_ptr)));
+            }
+            else
+            {
+                internal_vst1q(in_ptr, internal_vld1q(in_ptr));
+            }
+        },
+        in, bi);
+    }
+    else // Out of place accumulate
+    {
+        Iterator out(output, window);
+        execute_window_loop(window, [&](const Coordinates & id)
+        {
+            // Get bias and pointer to input
+            const auto in_ptr   = reinterpret_cast<T1 *>(in.ptr());
+            const auto out_ptr  = reinterpret_cast<T2 *>(out.ptr());
+            const auto bias_ptr = reinterpret_cast<T2 *>(bi.ptr());
+
+            // Accumulate bias
+            if(has_bias)
+            {
+                internal_vst1q(out_ptr, internal_vqaddq(internal_vld1q(in_ptr), internal_vld1q(bias_ptr)));
+            }
+            else
+            {
+                internal_vst1q(out_ptr, internal_vld1q(in_ptr));
+            }
+        },
+        in, bi);
+    }
+}
+
 // QASYMM8 specializations
 template <>
 void output_stage<int32_t, uint8_t, false, true>(ITensor *input, const ITensor *bias, const Window &window, ITensor *output,
@@ -415,61 +490,79 @@ void NEDirectConvolutionLayerOutputStageKernel::configure(ITensor *input, const
     INEKernel::configure(win_config.second);
 
     // Set appropriate function
-    switch(input->info()->data_type())
+    if(input->info()->data_layout() == DataLayout::NCHW)
     {
-        case DataType::QS8:
+        switch(input->info()->data_type())
         {
-            if(bias == nullptr)
+            case DataType::QS8:
             {
-                _func = (output == nullptr) ? &output_stage<qint8_t, qint8_t, true, false> : &output_stage<qint8_t, qint8_t, false, false>;
+                if(bias == nullptr)
+                {
+                    _func = (output == nullptr) ? &output_stage<qint8_t, qint8_t, true, false> : &output_stage<qint8_t, qint8_t, false, false>;
+                }
+                else
+                {
+                    _func = (output == nullptr) ? &output_stage<qint8_t, qint8_t, true, true> : &output_stage<qint8_t, qint8_t, false, true>;
+                }
+                break;
             }
-            else
+            case DataType::QS16:
             {
-                _func = (output == nullptr) ? &output_stage<qint8_t, qint8_t, true, true> : &output_stage<qint8_t, qint8_t, false, true>;
+                if(bias != nullptr && bias->info()->data_type() == DataType::QS8)
+                {
+                    _func = (output == nullptr) ? &output_stage<qint16_t, qint8_t, true, true> : &output_stage<qint16_t, qint8_t, false, true>;
+                }
+                else if(bias == nullptr)
+                {
+                    _func = (output == nullptr) ? &output_stage<qint16_t, qint8_t, true, false> : &output_stage<qint16_t, qint8_t, false, false>;
+                }
+                else
+                {
+                    ARM_COMPUTE_ERROR("Not implemented");
+                }
+                break;
             }
-            break;
-        }
-        case DataType::QS16:
-        {
-            if(bias != nullptr && bias->info()->data_type() == DataType::QS8)
+            case DataType::QS32:
             {
-                _func = (output == nullptr) ? &output_stage<qint16_t, qint8_t, true, true> : &output_stage<qint16_t, qint8_t, false, true>;
+                _func = (output == nullptr) ? &output_stage<qint32_t, qint16_t, true, true> : &output_stage<qint32_t, qint16_t, false, true>;
+                break;
             }
-            else if(bias == nullptr)
+            case DataType::S32:
             {
-                _func = (output == nullptr) ? &output_stage<qint16_t, qint8_t, true, false> : &output_stage<qint16_t, qint8_t, false, false>;
+                _func = (bias == nullptr) ? &output_stage<int32_t, uint8_t, false, false> : &output_stage<int32_t, uint8_t, false, true>;
+                break;
             }
-            else
+#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+            case DataType::F16:
             {
-                ARM_COMPUTE_ERROR("Not implemented");
+                _func = (output == nullptr) ? &output_stage<float16_t, float16_t, true, true> : &output_stage<float16_t, float16_t, false, true>;
+                break;
             }
-            break;
-        }
-        case DataType::QS32:
-        {
-            _func = (output == nullptr) ? &output_stage<qint32_t, qint16_t, true, true> : &output_stage<qint32_t, qint16_t, false, true>;
-            break;
-        }
-        case DataType::S32:
-        {
-            _func = (bias == nullptr) ? &output_stage<int32_t, uint8_t, false, false> : &output_stage<int32_t, uint8_t, false, true>;
-            break;
-        }
-#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
-        case DataType::F16:
-        {
-            _func = (output == nullptr) ? &output_stage<float16_t, float16_t, true, true> : &output_stage<float16_t, float16_t, false, true>;
-            break;
-        }
 #endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
-        case DataType::F32:
-        {
-            _func = (output == nullptr) ? &output_stage<float, float, true, true> : &output_stage<float, float, false, true>;
-            break;
+            case DataType::F32:
+            {
+                _func = (output == nullptr) ? &output_stage<float, float, true, true> : &output_stage<float, float, false, true>;
+                break;
+            }
+            default:
+            {
+                ARM_COMPUTE_ERROR("Unsupported combination of types among the inputs.");
+            }
         }
-        default:
+    }
+    else
+    {
+        switch(input->info()->data_type())
         {
-            ARM_COMPUTE_ERROR("Unsupported combination of types among the inputs.");
+            case DataType::F32:
+            {
+                _func = (output == nullptr) ? &output_stage_nhwc<float, float, true, true> : &output_stage_nhwc<float, float, false, true>;
+                break;
+            }
+            default:
+            {
+                ARM_COMPUTE_ERROR("Unsupported combination of types among the inputs.");
+            }
        }
    }
}
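In short, the patch teaches the output stage about data layouts: validate_arguments() now rejects tensors with an UNKNOWN layout and checks the bias length against the channel dimension resolved through the layout, validate_and_configure_window() uses a horizontal access window for the bias outside NCHW, and configure() dispatches NHWC F32 tensors to the new output_stage_nhwc kernels. Those kernels collapse the Y, Z and batch dimensions of the bias window to empty dimensions, so the bias iterator broadcasts the same per-channel vector at every spatial and batch position.

Below is a minimal usage sketch, not taken from the patch, that exercises the new F32/NHWC path end to end. It assumes the arm_compute public API of this development cycle (the header path, TensorInfo::set_data_layout(), and the defaulted integer arguments of configure()); the shapes and the scheduling split dimension are illustrative.

// Usage sketch (not part of the commit): drive the new F32/NHWC output stage.
#include "arm_compute/core/NEON/kernels/NEDirectConvolutionLayerOutputStageKernel.h"
#include "arm_compute/runtime/NEON/NEScheduler.h"
#include "arm_compute/runtime/Tensor.h"

using namespace arm_compute;

int main()
{
    // NHWC accumulator: the channel dimension is dimension 0, so the bias
    // length must match input->dimension(0) -- the new validate check above.
    TensorInfo acc_info(TensorShape(16U, 8U, 8U), 1, DataType::F32);
    acc_info.set_data_layout(DataLayout::NHWC);

    Tensor acc{}, bias{}, dst{};
    acc.allocator()->init(acc_info);
    dst.allocator()->init(acc_info); // validate requires matching shapes
    bias.allocator()->init(TensorInfo(TensorShape(16U), 1, DataType::F32));

    // F32 + NHWC selects output_stage_nhwc<float, float, false, true>
    // (out of place, with bias) in the new dispatch code.
    NEDirectConvolutionLayerOutputStageKernel stage{};
    stage.configure(&acc, &bias, &dst);

    acc.allocator()->allocate();
    bias.allocator()->allocate();
    dst.allocator()->allocate();

    // Fill acc and bias here, then run the kernel over its window.
    NEScheduler::get().schedule(&stage, Window::DimY);
    return 0;
}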