From 0cbb927ac309e332ac6e6f1ab9170f041f0138ab Mon Sep 17 00:00:00 2001
From: Michele Di Giorgio
Date: Thu, 1 Mar 2018 16:56:48 +0000
Subject: COMPMID-804: Add NHWC data format support for NEON batch normalisation

Change-Id: I04892e7be3f5aa58cd95917a4f90a6b4ffcf6efc
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/122897
Reviewed-by: Giorgio Arena
Tested-by: Jenkins
Reviewed-by: Anthony Barbier
---
 .../kernels/NEBatchNormalizationLayerKernel.cpp    | 110 ++++++++++++++++++---
 1 file changed, 98 insertions(+), 12 deletions(-)

(limited to 'src/core/NEON/kernels/NEBatchNormalizationLayerKernel.cpp')

diff --git a/src/core/NEON/kernels/NEBatchNormalizationLayerKernel.cpp b/src/core/NEON/kernels/NEBatchNormalizationLayerKernel.cpp
index d1bdfac2da..6be50fdb0d 100644
--- a/src/core/NEON/kernels/NEBatchNormalizationLayerKernel.cpp
+++ b/src/core/NEON/kernels/NEBatchNormalizationLayerKernel.cpp
@@ -58,6 +58,7 @@ validate_arguments(const ITensorInfo *input, const ITensorInfo *output, const IT
     if(nullptr != output)
     {
         ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input, output);
+        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_LAYOUT(input, output);
         ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
         ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input, output);
     }
@@ -77,7 +78,7 @@ validate_arguments(const ITensorInfo *input, const ITensorInfo *output, const IT
         ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input, gamma);
         ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(mean, gamma);
     }
-    ARM_COMPUTE_RETURN_ERROR_ON(input->dimension(2) != mean->dimension(0));
+    ARM_COMPUTE_RETURN_ERROR_ON(input->dimension(get_data_layout_dimension_index(input->data_layout(), DataLayoutDimension::CHANNEL)) != mean->dimension(0));
 
     return Status{};
 }
@@ -209,9 +210,9 @@ void NEBatchNormalizationLayerKernel::batch_normalization_qs16(const Window &win
 }
 
 template <bool fused_activation>
-void NEBatchNormalizationLayerKernel::batch_normalization_fp16(const Window &window)
+void NEBatchNormalizationLayerKernel::batch_normalization_fp16_nchw(const Window &window)
 {
-    static_assert(!fused_activation, "Activation is not supported for QS8");
+    static_assert(!fused_activation, "Activation is not supported for FP16");
 
     ARM_COMPUTE_UNUSED(window);
 #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
@@ -263,8 +264,43 @@ void NEBatchNormalizationLayerKernel::batch_normalization_fp16(const Window &win
 #endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
 }
 
+template <bool fused_activation>
+void NEBatchNormalizationLayerKernel::batch_normalization_fp16_nhwc(const Window &window)
+{
+    static_assert(!fused_activation, "Activation is not supported for FP16");
+
+    ARM_COMPUTE_UNUSED(window);
+#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+    Iterator input(_input, window);
+    Iterator output(_output, window);
+
+    const auto input_mean  = reinterpret_cast<const float16_t *>(_mean->ptr_to_element(Coordinates(0, 0)));
+    const auto input_var   = reinterpret_cast<const float16_t *>(_var->ptr_to_element(Coordinates(0, 0)));
+    const auto input_gamma = (_gamma != nullptr) ? reinterpret_cast<const float16_t *>(_gamma->ptr_to_element(Coordinates(0, 0))) : nullptr;
+    const auto input_beta  = (_beta != nullptr) ? reinterpret_cast<const float16_t *>(_beta->ptr_to_element(Coordinates(0, 0))) : nullptr;
+
+    const float16x8_t epsilon_vec = vdupq_n_f16(_epsilon);
+    execute_window_loop(window, [&](const Coordinates & id)
+    {
+        // Construct vectors
+        const float16x8_t mean_vec  = vld1q_f16(input_mean + id.x());
+        const float16x8_t var_vec   = vld1q_f16(input_var + id.x());
+        const float16x8_t gamma_vec = (input_gamma != nullptr) ? vld1q_f16(input_gamma + id.x()) : vdupq_n_f16(1.0);
+        const float16x8_t beta_vec  = (input_beta != nullptr) ? vld1q_f16(input_beta + id.x()) : vdupq_n_f16(0.0);
+        // Calculate denominator
+        const float16x8_t denominator = vinvsqrtq_f16(vaddq_f16(var_vec, epsilon_vec));
+
+        // Calculate x bar and store results
+        const float16x8_t numerator = vsubq_f16(vld1q_f16(reinterpret_cast<const float16_t *>(input.ptr())), mean_vec);
+        const float16x8_t x_bar     = vmulq_f16(numerator, denominator);
+        vst1q_f16(reinterpret_cast<float16_t *>(output.ptr()), vaddq_f16(beta_vec, vmulq_f16(x_bar, gamma_vec)));
+    },
+    input, output);
+#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
+}
+
 template <bool fused_activation, typename F>
-void NEBatchNormalizationLayerKernel::batch_normalization_fp32(const Window &window)
+void NEBatchNormalizationLayerKernel::batch_normalization_fp32_nchw(const Window &window)
 {
     Iterator input(_input, window);
     Iterator output(_output, window);
@@ -324,8 +360,50 @@ void NEBatchNormalizationLayerKernel::batch_normalization_fp32(const Window &win
     input, output);
 }
 
+template <bool fused_activation, typename F>
+void NEBatchNormalizationLayerKernel::batch_normalization_fp32_nhwc(const Window &window)
+{
+    Iterator input(_input, window);
+    Iterator output(_output, window);
+
+    F activation_functor(_act_info);
+
+    const auto input_mean  = reinterpret_cast<const float *>(_mean->ptr_to_element(Coordinates(0, 0)));
+    const auto input_var   = reinterpret_cast<const float *>(_var->ptr_to_element(Coordinates(0, 0)));
+    const auto input_gamma = (_gamma != nullptr) ? reinterpret_cast<const float *>(_gamma->ptr_to_element(Coordinates(0, 0))) : nullptr;
+    const auto input_beta  = (_beta != nullptr) ? reinterpret_cast<const float *>(_beta->ptr_to_element(Coordinates(0, 0))) : nullptr;
+
+    const float32x4_t epsilon_vec = vdupq_n_f32(_epsilon);
+    execute_window_loop(window, [&](const Coordinates & id)
+    {
+        // Construct vectors
+        const float32x4_t mean_vec  = vld1q_f32(input_mean + id.x());
+        const float32x4_t var_vec   = vld1q_f32(input_var + id.x());
+        const float32x4_t gamma_vec = (input_gamma != nullptr) ? vld1q_f32(input_gamma + id.x()) : vdupq_n_f32(1.0);
+        const float32x4_t beta_vec  = (input_beta != nullptr) ? vld1q_f32(input_beta + id.x()) : vdupq_n_f32(0.0);
+        // Calculate denominator
+        const float32x4_t denominator = vinvsqrtq_f32(vaddq_f32(var_vec, epsilon_vec));
+
+        // Calculate x bar
+        const float32x4_t numerator = vsubq_f32(vld1q_f32(reinterpret_cast<const float *>(input.ptr())), mean_vec);
+        const float32x4_t x_bar     = vmulq_f32(numerator, denominator);
+        float32x4_t       res       = vmlaq_f32(beta_vec, x_bar, gamma_vec);
+
+        // Perform fused activation
+        if(fused_activation)
+        {
+            activation_functor(res);
+        }
+
+        // Store results
+        vst1q_f32(reinterpret_cast<float *>(output.ptr()), res);
+    },
+    input, output);
+}
+
 void NEBatchNormalizationLayerKernel::configure_non_fused()
 {
+    const bool is_nhwc = _input->info()->data_layout() == DataLayout::NHWC;
     switch(_input->info()->data_type())
     {
         case DataType::QS8:
@@ -335,10 +413,11 @@ void NEBatchNormalizationLayerKernel::configure_non_fused()
             _func = &NEBatchNormalizationLayerKernel::batch_normalization_qs16<false>;
             break;
         case DataType::F16:
-            _func = &NEBatchNormalizationLayerKernel::batch_normalization_fp16<false>;
+            _func = (is_nhwc) ? &NEBatchNormalizationLayerKernel::batch_normalization_fp16_nhwc<false> : &NEBatchNormalizationLayerKernel::batch_normalization_fp16_nchw<false>;
             break;
         case DataType::F32:
-            _func = &NEBatchNormalizationLayerKernel::batch_normalization_fp32<false, ::detail::dummy<float, 4>>;
+            _func = (is_nhwc) ? &NEBatchNormalizationLayerKernel::batch_normalization_fp32_nhwc<false, ::detail::dummy<float, 4>> :
+                    &NEBatchNormalizationLayerKernel::batch_normalization_fp32_nchw<false, ::detail::dummy<float, 4>>;
             break;
         default:
             ARM_COMPUTE_ERROR("Element size not supported");
@@ -348,18 +427,25 @@ void NEBatchNormalizationLayerKernel::configure_non_fused()
 
 void NEBatchNormalizationLayerKernel::configure_fused()
 {
-    // Fused Batched Normalization with activation functions : FP32
-    static std::map<ActivationLayerInfo::ActivationFunction, BatchNormFunctionPtr> bn_fused_map_f32 =
+    // NCHW Fused Batched Normalization with activation functions : FP32
+    static std::map<ActivationLayerInfo::ActivationFunction, BatchNormFunctionPtr> bn_fused_map_f32_nchw =
+    {
+        { ActivationLayerInfo::ActivationFunction::RELU, &NEBatchNormalizationLayerKernel::batch_normalization_fp32_nchw<true, ::detail::relu<float, 4>> },
+        { ActivationLayerInfo::ActivationFunction::BOUNDED_RELU, &NEBatchNormalizationLayerKernel::batch_normalization_fp32_nchw<true, ::detail::brelu<float, 4>> },
+        { ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, &NEBatchNormalizationLayerKernel::batch_normalization_fp32_nchw<true, ::detail::lubrelu<float, 4>> }
+    };
+    // NHWC Fused Batched Normalization with activation functions : FP32
+    static std::map<ActivationLayerInfo::ActivationFunction, BatchNormFunctionPtr> bn_fused_map_f32_nhwc =
     {
-        { ActivationLayerInfo::ActivationFunction::RELU, &NEBatchNormalizationLayerKernel::batch_normalization_fp32<true, ::detail::relu<float, 4>> },
-        { ActivationLayerInfo::ActivationFunction::BOUNDED_RELU, &NEBatchNormalizationLayerKernel::batch_normalization_fp32<true, ::detail::brelu<float, 4>> },
-        { ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, &NEBatchNormalizationLayerKernel::batch_normalization_fp32<true, ::detail::lubrelu<float, 4>> }
+        { ActivationLayerInfo::ActivationFunction::RELU, &NEBatchNormalizationLayerKernel::batch_normalization_fp32_nhwc<true, ::detail::relu<float, 4>> },
+        { ActivationLayerInfo::ActivationFunction::BOUNDED_RELU, &NEBatchNormalizationLayerKernel::batch_normalization_fp32_nhwc<true, ::detail::brelu<float, 4>> },
+        { ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, &NEBatchNormalizationLayerKernel::batch_normalization_fp32_nhwc<true, ::detail::lubrelu<float, 4>> }
    };
 
     switch(_input->info()->data_type())
     {
         case DataType::F32:
-            _func = bn_fused_map_f32[_act_info.activation()];
+            _func = (_input->info()->data_layout() == DataLayout::NHWC) ? bn_fused_map_f32_nhwc[_act_info.activation()] : bn_fused_map_f32_nchw[_act_info.activation()];
             break;
         default:
             ARM_COMPUTE_ERROR("Element size not supported");
-- 
cgit v1.2.1
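
A note on the validate_arguments() change: the hard-coded input->dimension(2) only holds for NCHW, because the library indexes tensor dimensions innermost-first, and get_data_layout_dimension_index() resolves the channel dimension per layout. A minimal sketch of the mapping it performs, with hypothetical names (Layout and channel_index are illustrative stand-ins, not library symbols):

#include <cstddef>

enum class Layout { NCHW, NHWC };

// Dimensions are indexed innermost-first:
//   NCHW -> [W, H, C, N], channel at index 2
//   NHWC -> [C, W, H, N], channel at index 0
constexpr std::size_t channel_index(Layout l)
{
    return (l == Layout::NCHW) ? 2 : 0;
}

static_assert(channel_index(Layout::NCHW) == 2, "the old hard-coded dimension(2) check");
static_assert(channel_index(Layout::NHWC) == 0, "mean/var/gamma/beta length is checked against dim 0");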
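The two new *_nhwc kernels rely on channels being the innermost (fastest-moving) dimension in NHWC: mean, var, gamma and beta can then be loaded as contiguous vectors at offset id.x(), where the NCHW paths instead broadcast one scalar per channel. Below is a plain scalar reference of the computation the vectorised loops implement; it is an illustrative sketch, not library code, and it ignores vector-width tail handling:

#include <cmath>
#include <cstddef>

// y = gamma * (x - mean) / sqrt(var + epsilon) + beta, NHWC layout.
// The inner loop walks the channel axis, so the per-channel parameters
// mean[ch], var[ch], gamma[ch], beta[ch] are read contiguously. That
// contiguity is what allows the kernels above to use
// vld1q_f32(input_mean + id.x()) and vld1q_f16(input_mean + id.x()).
void nhwc_batch_norm_ref(const float *x, float *y,
                         const float *mean, const float *var,
                         const float *gamma, const float *beta,
                         std::size_t nhw, std::size_t c, float epsilon)
{
    for(std::size_t i = 0; i < nhw; ++i)
    {
        for(std::size_t ch = 0; ch < c; ++ch)
        {
            const std::size_t idx   = i * c + ch;
            const float       x_bar = (x[idx] - mean[ch]) / std::sqrt(var[ch] + epsilon);
            y[idx] = gamma[ch] * x_bar + beta[ch];
        }
    }
}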
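On the arithmetic itself: both kernels compute x_bar = (x - mean) * 1/sqrt(var + epsilon) and then beta + x_bar * gamma (a single vmlaq_f32 in FP32). vinvsqrtq_f32 and vinvsqrtq_f16 appear to be the library's own NEMath helpers rather than raw NEON intrinsics; a standalone sketch using only plain intrinsics would refine the hardware reciprocal-square-root estimate with Newton-Raphson steps, roughly as follows (inv_sqrt_f32 and batch_norm4 are hypothetical helper names):

#include <arm_neon.h>

// 1/sqrt(x) for four lanes: vrsqrteq_f32 gives a coarse estimate which
// two vrsqrtsq_f32 Newton-Raphson steps refine to near full precision.
static inline float32x4_t inv_sqrt_f32(float32x4_t x)
{
    float32x4_t e = vrsqrteq_f32(x);
    e = vmulq_f32(e, vrsqrtsq_f32(vmulq_f32(x, e), e));
    e = vmulq_f32(e, vrsqrtsq_f32(vmulq_f32(x, e), e));
    return e;
}

// Normalise four NHWC-adjacent channels at once:
// beta + gamma * (x - mean) / sqrt(var + eps)
static inline float32x4_t batch_norm4(float32x4_t x, float32x4_t mean,
                                      float32x4_t var, float32x4_t gamma,
                                      float32x4_t beta, float32x4_t eps)
{
    const float32x4_t denominator = inv_sqrt_f32(vaddq_f32(var, eps));
    const float32x4_t x_bar       = vmulq_f32(vsubq_f32(x, mean), denominator);
    return vmlaq_f32(beta, x_bar, gamma); // beta + x_bar * gamma
}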
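Finally, the dispatch in configure_fused() keeps the existing pattern: the activation is resolved once, at configure time, into a member-function pointer instantiated with the matching functor, so the per-element loop carries no activation branch; the NHWC change only selects between two such maps. A simplified, self-contained sketch of that pattern, using hypothetical stand-in types rather than the library's ::detail functors:

#include <algorithm>
#include <map>

enum class Act { RELU, BOUNDED_RELU };

struct ReluOp  { float operator()(float v) const { return std::max(0.0f, v); } };
struct BReluOp { float operator()(float v) const { return std::min(6.0f, std::max(0.0f, v)); } };

class Kernel
{
public:
    // Resolve the activation to a concrete kernel instantiation up front.
    void configure(Act act)
    {
        static std::map<Act, void (Kernel::*)(float &)> fused_map =
        {
            { Act::RELU,         &Kernel::normalize<ReluOp>  },
            { Act::BOUNDED_RELU, &Kernel::normalize<BReluOp> }
        };
        _func = fused_map[act];
    }

    // The hot path just calls through the pointer; no switch per element.
    void run(float &v) { (this->*_func)(v); }

private:
    template <typename F>
    void normalize(float &v)
    {
        // ... batch-normalisation arithmetic would go here ...
        v = F{}(v);
    }

    void (Kernel::*_func)(float &) = nullptr;
};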