author    Michele Di Giorgio <michele.digiorgio@arm.com>  2018-03-01 16:56:48 +0000
committer Anthony Barbier <anthony.barbier@arm.com>       2018-11-02 16:49:54 +0000
commit    0cbb927ac309e332ac6e6f1ab9170f041f0138ab (patch)
tree      102d50dec9f741f04b1126ae03e6e491dda2d3ba /src/core/NEON/kernels/NEBatchNormalizationLayerKernel.cpp
parent    82b51482479951cf133c223eb81aae291cb4d590 (diff)
download  ComputeLibrary-0cbb927ac309e332ac6e6f1ab9170f041f0138ab.tar.gz
COMPMID-804: Add NHWC data format support for NEON batch normalisation
Change-Id: I04892e7be3f5aa58cd95917a4f90a6b4ffcf6efc
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/122897
Reviewed-by: Giorgio Arena <giorgio.arena@arm.com>
Tested-by: Jenkins <bsgcomp@arm.com>
Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
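For context, both layouts apply the same per-channel transform, out = gamma * (x - mean) / sqrt(var + epsilon) + beta; only the order in which the tensor is walked differs. A minimal scalar reference of that math, assuming an NHWC-flattened buffer (batch_norm_ref and its signature are illustrative, not library API):

    #include <cmath>
    #include <cstddef>

    // Reference for the transform both the NCHW and NHWC kernels implement.
    // mean/var/gamma/beta each hold one value per channel.
    void batch_norm_ref(const float *in, float *out, std::size_t n, std::size_t channels,
                        const float *mean, const float *var,
                        const float *gamma, const float *beta, float epsilon)
    {
        for(std::size_t i = 0; i < n; ++i)
        {
            const std::size_t c = i % channels; // NHWC: channel is the innermost index
            const float x_bar   = (in[i] - mean[c]) / std::sqrt(var[c] + epsilon);
            out[i]              = gamma[c] * x_bar + beta[c];
        }
    }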
Diffstat (limited to 'src/core/NEON/kernels/NEBatchNormalizationLayerKernel.cpp')
-rw-r--r-- src/core/NEON/kernels/NEBatchNormalizationLayerKernel.cpp | 110
1 file changed, 98 insertions(+), 12 deletions(-)
diff --git a/src/core/NEON/kernels/NEBatchNormalizationLayerKernel.cpp b/src/core/NEON/kernels/NEBatchNormalizationLayerKernel.cpp
index d1bdfac2da..6be50fdb0d 100644
--- a/src/core/NEON/kernels/NEBatchNormalizationLayerKernel.cpp
+++ b/src/core/NEON/kernels/NEBatchNormalizationLayerKernel.cpp
@@ -58,6 +58,7 @@ validate_arguments(const ITensorInfo *input, const ITensorInfo *output, const IT
if(nullptr != output)
{
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input, output);
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_LAYOUT(input, output);
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input, output);
}
@@ -77,7 +78,7 @@ validate_arguments(const ITensorInfo *input, const ITensorInfo *output, const IT
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input, gamma);
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(mean, gamma);
}
- ARM_COMPUTE_RETURN_ERROR_ON(input->dimension(2) != mean->dimension(0));
+ ARM_COMPUTE_RETURN_ERROR_ON(input->dimension(get_data_layout_dimension_index(input->data_layout(), DataLayoutDimension::CHANNEL)) != mean->dimension(0));
return Status{};
}
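The check above now resolves the channel dimension from the tensor's layout instead of hard-coding index 2, which only holds for NCHW. In the library's dimension ordering the innermost dimension is index 0, so the mapping the helper encodes works out as below (a sketch only; get_data_layout_dimension_index is the real API):

    // NCHW is stored W, H, C, N -> channel lives at dimension 2.
    // NHWC is stored C, W, H, N -> channel lives at dimension 0.
    inline size_t channel_dim_index(DataLayout layout)
    {
        return (layout == DataLayout::NHWC) ? 0 : 2;
    }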
@@ -209,9 +210,9 @@ void NEBatchNormalizationLayerKernel::batch_normalization_qs16(const Window &win
}
template <bool fused_activation>
-void NEBatchNormalizationLayerKernel::batch_normalization_fp16(const Window &window)
+void NEBatchNormalizationLayerKernel::batch_normalization_fp16_nchw(const Window &window)
{
- static_assert(!fused_activation, "Activation is not supported for QS8");
+ static_assert(!fused_activation, "Activation is not supported for FP16");
ARM_COMPUTE_UNUSED(window);
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
@@ -263,8 +264,43 @@ void NEBatchNormalizationLayerKernel::batch_normalization_fp16(const Window &win
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
}
+template <bool fused_activation>
+void NEBatchNormalizationLayerKernel::batch_normalization_fp16_nhwc(const Window &window)
+{
+ static_assert(!fused_activation, "Activation is not supported for FP16");
+
+ ARM_COMPUTE_UNUSED(window);
+#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+ Iterator input(_input, window);
+ Iterator output(_output, window);
+
+ const auto input_mean = reinterpret_cast<const float16_t *>(_mean->ptr_to_element(Coordinates(0, 0)));
+ const auto input_var = reinterpret_cast<const float16_t *>(_var->ptr_to_element(Coordinates(0, 0)));
+ const auto input_gamma = (_gamma != nullptr) ? reinterpret_cast<const float16_t *>(_gamma->ptr_to_element(Coordinates(0, 0))) : nullptr;
+ const auto input_beta = (_beta != nullptr) ? reinterpret_cast<const float16_t *>(_beta->ptr_to_element(Coordinates(0, 0))) : nullptr;
+
+ const float16x8_t epsilon_vec = vdupq_n_f16(_epsilon);
+ execute_window_loop(window, [&](const Coordinates & id)
+ {
+ // Construct vectors
+ const float16x8_t mean_vec = vld1q_f16(input_mean + id.x());
+ const float16x8_t var_vec = vld1q_f16(input_var + id.x());
+ const float16x8_t gamma_vec = (input_gamma != nullptr) ? vld1q_f16(input_gamma + id.x()) : vdupq_n_f16(1.0);
+ const float16x8_t beta_vec = (input_beta != nullptr) ? vld1q_f16(input_beta + id.x()) : vdupq_n_f16(0.0);
+ // Calculate denominator
+ const float16x8_t denominator = vinvsqrtq_f16(vaddq_f16(var_vec, epsilon_vec));
+
+ // Calculate x bar and store results
+ const float16x8_t numerator = vsubq_f16(vld1q_f16(reinterpret_cast<const float16_t *>(input.ptr())), mean_vec);
+ const float16x8_t x_bar = vmulq_f16(numerator, denominator);
+ vst1q_f16(reinterpret_cast<float16_t *>(output.ptr()), vaddq_f16(beta_vec, vmulq_f16(x_bar, gamma_vec)));
+ },
+ input, output);
+#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
+}
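Note why this NHWC path is simpler than its NCHW counterpart: with channels innermost, the eight FP16 lanes of one vector belong to eight consecutive channels, so the per-channel parameters are fetched with plain vld1q_f16 loads at offset id.x() instead of being splatted one channel at a time. A scalar view of one loop iteration, with illustrative names:

    // What one float16x8_t iteration computes, lane by lane (id_x = id.x()).
    for(int lane = 0; lane < 8; ++lane)
    {
        const int c = id_x + lane;                                        // adjacent lanes = adjacent channels
        const float x_bar = (in[lane] - mean[c]) / std::sqrt(var[c] + epsilon);
        out[lane]         = beta[c] + gamma[c] * x_bar;                   // matches the vaddq/vmulq above
    }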
+
template <bool fused_activation, typename F>
-void NEBatchNormalizationLayerKernel::batch_normalization_fp32(const Window &window)
+void NEBatchNormalizationLayerKernel::batch_normalization_fp32_nchw(const Window &window)
{
Iterator input(_input, window);
Iterator output(_output, window);
@@ -324,8 +360,50 @@ void NEBatchNormalizationLayerKernel::batch_normalization_fp32(const Window &win
input, output);
}
+template <bool fused_activation, typename F>
+void NEBatchNormalizationLayerKernel::batch_normalization_fp32_nhwc(const Window &window)
+{
+ Iterator input(_input, window);
+ Iterator output(_output, window);
+
+ F activation_functor(_act_info);
+
+ const auto input_mean = reinterpret_cast<const float *>(_mean->ptr_to_element(Coordinates(0, 0)));
+ const auto input_var = reinterpret_cast<const float *>(_var->ptr_to_element(Coordinates(0, 0)));
+ const auto input_gamma = (_gamma != nullptr) ? reinterpret_cast<const float *>(_gamma->ptr_to_element(Coordinates(0, 0))) : nullptr;
+ const auto input_beta = (_beta != nullptr) ? reinterpret_cast<const float *>(_beta->ptr_to_element(Coordinates(0, 0))) : nullptr;
+
+ const float32x4_t epsilon_vec = vdupq_n_f32(_epsilon);
+ execute_window_loop(window, [&](const Coordinates & id)
+ {
+ // Construct vectors
+ const float32x4_t mean_vec = vld1q_f32(input_mean + id.x());
+ const float32x4_t var_vec = vld1q_f32(input_var + id.x());
+ const float32x4_t gamma_vec = (input_gamma != nullptr) ? vld1q_f32(input_gamma + id.x()) : vdupq_n_f32(1.0);
+ const float32x4_t beta_vec = (input_beta != nullptr) ? vld1q_f32(input_beta + id.x()) : vdupq_n_f32(0.0);
+ // Calculate denominator
+ const float32x4_t denominator = vinvsqrtq_f32(vaddq_f32(var_vec, epsilon_vec));
+
+ // Calculate x bar
+ const float32x4_t numerator = vsubq_f32(vld1q_f32(reinterpret_cast<const float *>(input.ptr())), mean_vec);
+ const float32x4_t x_bar = vmulq_f32(numerator, denominator);
+ float32x4_t res = vmlaq_f32(beta_vec, x_bar, gamma_vec);
+
+ // Perform fused activation
+ if(fused_activation)
+ {
+ activation_functor(res);
+ }
+
+ // Store results
+ vst1q_f32(reinterpret_cast<float *>(output.ptr()), res);
+ },
+ input, output);
+}
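vmlaq_f32(a, b, c) computes a + b * c, so applying gamma and beta collapses into a single multiply-accumulate, and the activation functor F then runs on the register before the store. A minimal sketch of such a functor, assuming the shape of ::detail::relu<float, 4> rather than reproducing it:

    #include <arm_neon.h>

    // Hypothetical stand-in for ::detail::relu<float, 4>:
    // clamps each lane to >= 0, in place.
    struct relu_f32x4
    {
        void operator()(float32x4_t &x) const
        {
            x = vmaxq_f32(x, vdupq_n_f32(0.f));
        }
    };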
+
void NEBatchNormalizationLayerKernel::configure_non_fused()
{
+ const bool is_nhwc = _input->info()->data_layout() == DataLayout::NHWC;
switch(_input->info()->data_type())
{
case DataType::QS8:
@@ -335,10 +413,11 @@ void NEBatchNormalizationLayerKernel::configure_non_fused()
_func = &NEBatchNormalizationLayerKernel::batch_normalization_qs16<false>;
break;
case DataType::F16:
- _func = &NEBatchNormalizationLayerKernel::batch_normalization_fp16<false>;
+ _func = (is_nhwc) ? &NEBatchNormalizationLayerKernel::batch_normalization_fp16_nhwc<false> : &NEBatchNormalizationLayerKernel::batch_normalization_fp16_nchw<false>;
break;
case DataType::F32:
- _func = &NEBatchNormalizationLayerKernel::batch_normalization_fp32<false, ::detail::dummy<float, 4>>;
+ _func = (is_nhwc) ? &NEBatchNormalizationLayerKernel::batch_normalization_fp32_nhwc<false, ::detail::dummy<float, 4>> :
+ &NEBatchNormalizationLayerKernel::batch_normalization_fp32_nchw<false, ::detail::dummy<float, 4>>;
break;
default:
ARM_COMPUTE_ERROR("Element size not supported");
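Both configure paths pick the kernel once, at configure time, through the _func member-function pointer, so the run path does no per-call layout or type branching. The pattern, reduced to a hypothetical minimum:

    // Hypothetical sketch of configure-time dispatch via a member-function pointer.
    class Kernel
    {
    public:
        void configure(bool is_nhwc)
        {
            _func = is_nhwc ? &Kernel::run_nhwc : &Kernel::run_nchw;
        }
        void run() { (this->*_func)(); }

    private:
        void run_nchw() {}
        void run_nhwc() {}
        void (Kernel::*_func)() = nullptr;
    };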
@@ -348,18 +427,25 @@ void NEBatchNormalizationLayerKernel::configure_non_fused()
void NEBatchNormalizationLayerKernel::configure_fused()
{
- // Fused Batched Normalization with activation functions : FP32
- static std::map<ActivationLayerInfo::ActivationFunction, BatchNormFunctionPtr> bn_fused_map_f32 =
+ // NCHW Fused Batch Normalization with activation functions : FP32
+ static std::map<ActivationLayerInfo::ActivationFunction, BatchNormFunctionPtr> bn_fused_map_f32_nchw =
+ {
+ { ActivationLayerInfo::ActivationFunction::RELU, &NEBatchNormalizationLayerKernel::batch_normalization_fp32_nchw<true, ::detail::relu<float, 4>> },
+ { ActivationLayerInfo::ActivationFunction::BOUNDED_RELU, &NEBatchNormalizationLayerKernel::batch_normalization_fp32_nchw<true, ::detail::brelu<float, 4>> },
+ { ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, &NEBatchNormalizationLayerKernel::batch_normalization_fp32_nchw<true, ::detail::lubrelu<float, 4>> }
+ };
+ // NHWC Fused Batch Normalization with activation functions : FP32
+ static std::map<ActivationLayerInfo::ActivationFunction, BatchNormFunctionPtr> bn_fused_map_f32_nhwc =
{
- { ActivationLayerInfo::ActivationFunction::RELU, &NEBatchNormalizationLayerKernel::batch_normalization_fp32<true, ::detail::relu<float, 4>> },
- { ActivationLayerInfo::ActivationFunction::BOUNDED_RELU, &NEBatchNormalizationLayerKernel::batch_normalization_fp32<true, ::detail::brelu<float, 4>> },
- { ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, &NEBatchNormalizationLayerKernel::batch_normalization_fp32<true, ::detail::lubrelu<float, 4>> }
+ { ActivationLayerInfo::ActivationFunction::RELU, &NEBatchNormalizationLayerKernel::batch_normalization_fp32_nhwc<true, ::detail::relu<float, 4>> },
+ { ActivationLayerInfo::ActivationFunction::BOUNDED_RELU, &NEBatchNormalizationLayerKernel::batch_normalization_fp32_nhwc<true, ::detail::brelu<float, 4>> },
+ { ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, &NEBatchNormalizationLayerKernel::batch_normalization_fp32_nhwc<true, ::detail::lubrelu<float, 4>> }
};
switch(_input->info()->data_type())
{
case DataType::F32:
- _func = bn_fused_map_f32[_act_info.activation()];
+ _func = (_input->info()->data_layout() == DataLayout::NHWC) ? bn_fused_map_f32_nhwc[_act_info.activation()] : bn_fused_map_f32_nchw[_act_info.activation()];
break;
default:
ARM_COMPUTE_ERROR("Element size not supported");