Diffstat (limited to 'src/core')
-rw-r--r--   src/core/NEON/kernels/NEBatchNormalizationLayerKernel.cpp   62
1 file changed, 50 insertions(+), 12 deletions(-)
diff --git a/src/core/NEON/kernels/NEBatchNormalizationLayerKernel.cpp b/src/core/NEON/kernels/NEBatchNormalizationLayerKernel.cpp
index ac1fc393c4..683d48b030 100644
--- a/src/core/NEON/kernels/NEBatchNormalizationLayerKernel.cpp
+++ b/src/core/NEON/kernels/NEBatchNormalizationLayerKernel.cpp
@@ -45,13 +45,11 @@ validate_arguments(const ITensorInfo *input, const ITensorInfo *output, const IT
 {
     ARM_COMPUTE_UNUSED(epsilon);
     ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(input);
-    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F16,
-                                                         DataType::F32);
+    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F16, DataType::F32);
 
     if(act_info.enabled())
     {
         ActivationLayerInfo::ActivationFunction act = act_info.activation();
-        ARM_COMPUTE_RETURN_ERROR_ON(input->data_type() != DataType::F32);
         ARM_COMPUTE_RETURN_ERROR_ON(act != ActivationLayerInfo::ActivationLayerInfo::ActivationFunction::RELU && act != ActivationLayerInfo::ActivationLayerInfo::ActivationFunction::BOUNDED_RELU &&
                                     act != ActivationLayerInfo::ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU);
@@ -102,16 +100,16 @@ std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input, ITen
     }
 } //namespace
 
-template <bool fused_activation>
+template <bool fused_activation, typename F>
 void NEBatchNormalizationLayerKernel::batch_normalization_fp16_nchw(const Window &window)
 {
-    static_assert(!fused_activation, "Activation is not supported for FP16");
-    ARM_COMPUTE_UNUSED(window);
 #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
     Iterator input(_input, window);
     Iterator output(_output, window);
 
+    F activation_functor(_act_info);
+
     // Hold information about the current feature map we are iterating.
     // Only compute denominator and NEON vectors once per feature map.
     int slice = -1;
@@ -151,22 +149,30 @@ void NEBatchNormalizationLayerKernel::batch_normalization_fp16_nchw(const Window
         // Calculate x bar and store results
         const float16x8_t numerator = vsubq_f16(vld1q_f16(reinterpret_cast<const float16_t *>(input.ptr())), mean_vec);
         const float16x8_t x_bar     = vmulq_f16(numerator, denominator);
-        vst1q_f16(reinterpret_cast<float16_t *>(output.ptr()), vaddq_f16(beta_vec, vmulq_f16(x_bar, gamma_vec)));
+        float16x8_t       res       = vaddq_f16(beta_vec, vmulq_f16(x_bar, gamma_vec));
+
+        // Perform fused activation
+        if(fused_activation)
+        {
+            activation_functor(res);
+        }
+
+        vst1q_f16(reinterpret_cast<float16_t *>(output.ptr()), res);
     },
     input, output);
 #endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
 }
 
-template <bool fused_activation>
+template <bool fused_activation, typename F>
 void NEBatchNormalizationLayerKernel::batch_normalization_fp16_nhwc(const Window &window)
 {
-    static_assert(!fused_activation, "Activation is not supported for FP16");
-    ARM_COMPUTE_UNUSED(window);
 #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
     Iterator input(_input, window);
    Iterator output(_output, window);
 
+    F activation_functor(_act_info);
+
     const auto input_mean  = reinterpret_cast<const float16_t *>(_mean->ptr_to_element(Coordinates(0, 0)));
     const auto input_var   = reinterpret_cast<const float16_t *>(_var->ptr_to_element(Coordinates(0, 0)));
     const auto input_gamma = (_gamma != nullptr) ? reinterpret_cast<const float16_t *>(_gamma->ptr_to_element(Coordinates(0, 0))) : nullptr;
@@ -186,7 +192,15 @@ void NEBatchNormalizationLayerKernel::batch_normalization_fp16_nhwc(const Window
         // Calculate x bar and store results
         const float16x8_t numerator = vsubq_f16(vld1q_f16(reinterpret_cast<const float16_t *>(input.ptr())), mean_vec);
         const float16x8_t x_bar     = vmulq_f16(numerator, denominator);
-        vst1q_f16(reinterpret_cast<float16_t *>(output.ptr()), vaddq_f16(beta_vec, vmulq_f16(x_bar, gamma_vec)));
+        float16x8_t       res       = vaddq_f16(beta_vec, vmulq_f16(x_bar, gamma_vec));
+
+        // Perform fused activation
+        if(fused_activation)
+        {
+            activation_functor(res);
+        }
+
+        vst1q_f16(reinterpret_cast<float16_t *>(output.ptr()), res);
     },
     input, output);
 #endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
@@ -299,9 +313,12 @@ void NEBatchNormalizationLayerKernel::configure_non_fused()
     const bool is_nhwc = _input->info()->data_layout() == DataLayout::NHWC;
     switch(_input->info()->data_type())
     {
+#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
         case DataType::F16:
-            _func = (is_nhwc) ? &NEBatchNormalizationLayerKernel::batch_normalization_fp16_nhwc<false> : &NEBatchNormalizationLayerKernel::batch_normalization_fp16_nchw<false>;
+            _func = (is_nhwc) ? &NEBatchNormalizationLayerKernel::batch_normalization_fp16_nhwc<false, ::detail::dummy<float16_t, 8>> :
+                    &NEBatchNormalizationLayerKernel::batch_normalization_fp16_nchw<false, ::detail::dummy<float16_t, 8>>;
             break;
+#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
         case DataType::F32:
             _func = (is_nhwc) ? &NEBatchNormalizationLayerKernel::batch_normalization_fp32_nhwc<false, ::detail::dummy<float, 4>> :
                     &NEBatchNormalizationLayerKernel::batch_normalization_fp32_nchw<false, ::detail::dummy<float, 4>>;
@@ -328,9 +345,30 @@ void NEBatchNormalizationLayerKernel::configure_fused()
         { ActivationLayerInfo::ActivationFunction::BOUNDED_RELU, &NEBatchNormalizationLayerKernel::batch_normalization_fp32_nhwc<true, ::detail::brelu<float, 4>> },
         { ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, &NEBatchNormalizationLayerKernel::batch_normalization_fp32_nhwc<true, ::detail::lubrelu<float, 4>> }
     };
+#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+    // NCHW Fused Batched Normalization with activation functions : FP16
+    static std::map<ActivationLayerInfo::ActivationFunction, BatchNormFunctionPtr> bn_fused_map_f16_nchw =
+    {
+        { ActivationLayerInfo::ActivationFunction::RELU, &NEBatchNormalizationLayerKernel::batch_normalization_fp16_nchw<true, ::detail::relu<float16_t, 8>> },
+        { ActivationLayerInfo::ActivationFunction::BOUNDED_RELU, &NEBatchNormalizationLayerKernel::batch_normalization_fp16_nchw<true, ::detail::brelu<float16_t, 8>> },
+        { ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, &NEBatchNormalizationLayerKernel::batch_normalization_fp16_nchw<true, ::detail::lubrelu<float16_t, 8>> }
+    };
+    // NHWC Fused Batched Normalization with activation functions : FP16
+    static std::map<ActivationLayerInfo::ActivationFunction, BatchNormFunctionPtr> bn_fused_map_f16_nhwc =
+    {
+        { ActivationLayerInfo::ActivationFunction::RELU, &NEBatchNormalizationLayerKernel::batch_normalization_fp16_nhwc<true, ::detail::relu<float16_t, 8>> },
+        { ActivationLayerInfo::ActivationFunction::BOUNDED_RELU, &NEBatchNormalizationLayerKernel::batch_normalization_fp16_nhwc<true, ::detail::brelu<float16_t, 8>> },
+        { ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, &NEBatchNormalizationLayerKernel::batch_normalization_fp16_nhwc<true, ::detail::lubrelu<float16_t, 8>> }
+    };
+#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
 
     switch(_input->info()->data_type())
     {
+#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+        case DataType::F16:
+            _func = (_input->info()->data_layout() == DataLayout::NHWC) ? bn_fused_map_f16_nhwc[_act_info.activation()] : bn_fused_map_f16_nchw[_act_info.activation()];
+            break;
+#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
         case DataType::F32:
            _func = (_input->info()->data_layout() == DataLayout::NHWC) ? bn_fused_map_f32_nhwc[_act_info.activation()] : bn_fused_map_f32_nchw[_act_info.activation()];
            break;
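
The change above turns the activation into a compile-time template parameter: each FP16 kernel is instantiated once per activation functor (::detail::dummy, ::detail::relu, ::detail::brelu, ::detail::lubrelu), and configure_fused()/configure_non_fused() pick the matching instantiation at configure time. Below is a minimal, self-contained sketch of that dispatch pattern; ActInfo, Dummy, Relu and batch_normalization are illustrative stand-ins, not the real arm_compute types, and the real kernels apply the functor to float16x8_t NEON registers rather than a std::array.

#include <algorithm>
#include <array>
#include <iostream>

// Illustrative stand-in for ActivationLayerInfo (not the real arm_compute class).
struct ActInfo
{
    float upper_bound{ 6.f }; // a bounded variant would read this
};

// No-op functor, playing the role of ::detail::dummy<T, S>.
struct Dummy
{
    explicit Dummy(const ActInfo &) {}
    void operator()(std::array<float, 8> &) const {}
};

// ReLU functor, playing the role of ::detail::relu<T, S>.
struct Relu
{
    explicit Relu(const ActInfo &) {}
    void operator()(std::array<float, 8> &v) const
    {
        for(auto &x : v)
        {
            x = std::max(x, 0.f);
        }
    }
};

// Kernel skeleton: the activation type F and the fused_activation flag are
// both template parameters, mirroring the kernel signatures in the diff.
template <bool fused_activation, typename F>
void batch_normalization(std::array<float, 8> &res, const ActInfo &act_info)
{
    F activation_functor(act_info);

    // ... batch normalisation maths would compute `res` here ...

    // Perform fused activation (branch is constant for each instantiation)
    if(fused_activation)
    {
        activation_functor(res);
    }
}

int main()
{
    std::array<float, 8> v{ -1.f, 2.f, -3.f, 4.f, -5.f, 6.f, -7.f, 8.f };
    batch_normalization<true, Relu>(v, ActInfo{});  // fused ReLU
    for(float x : v)
    {
        std::cout << x << ' ';
    }
    std::cout << '\n';
}

Because fused_activation is a template parameter, the if(fused_activation) branch is resolved per instantiation, so the non-fused <false, dummy> variants pay no per-element cost for the disabled activation, while the fused variants inline the chosen functor directly into the store path.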