diff options
6 files changed, 74 insertions, 14 deletions
diff --git a/arm_compute/core/NEON/kernels/NEBatchNormalizationLayerKernel.h b/arm_compute/core/NEON/kernels/NEBatchNormalizationLayerKernel.h index 2a540c151b..092b1d9514 100644 --- a/arm_compute/core/NEON/kernels/NEBatchNormalizationLayerKernel.h +++ b/arm_compute/core/NEON/kernels/NEBatchNormalizationLayerKernel.h @@ -103,7 +103,7 @@ private: * * @param[in] window Region on which to execute the kernel. (Must be a valid region of the window returned by window()). */ - template <bool fused_activation> + template <bool fused_activation, typename F> void batch_normalization_fp16_nchw(const Window &window); /** Template function to run batch normalization on fp16 on tensors with NHWC format * @@ -111,7 +111,7 @@ private: * * @param[in] window Region on which to execute the kernel. (Must be a valid region of the window returned by window()). */ - template <bool fused_activation> + template <bool fused_activation, typename F> void batch_normalization_fp16_nhwc(const Window &window); /** Template function to run batch normalization on fp32 * diff --git a/arm_compute/core/NEON/wrapper/intrinsics/dup_n.h b/arm_compute/core/NEON/wrapper/intrinsics/dup_n.h index 1c07b4f3ff..4d9a7952c0 100644 --- a/arm_compute/core/NEON/wrapper/intrinsics/dup_n.h +++ b/arm_compute/core/NEON/wrapper/intrinsics/dup_n.h @@ -45,6 +45,9 @@ VDUP_N_IMPL(int16_t, int16x4_t, vdup_n, s16, traits::vector_64_tag) VDUP_N_IMPL(uint32_t, uint32x2_t, vdup_n, u32, traits::vector_64_tag) VDUP_N_IMPL(int32_t, int32x2_t, vdup_n, s32, traits::vector_64_tag) VDUP_N_IMPL(float, float32x2_t, vdup_n, f32, traits::vector_64_tag) +#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +VDUP_N_IMPL(float16_t, float16x4_t, vdup_n, f16, traits::vector_64_tag) +#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC VDUP_N_IMPL(uint8_t, uint8x16_t, vdupq_n, u8, traits::vector_128_tag) VDUP_N_IMPL(int8_t, int8x16_t, vdupq_n, s8, traits::vector_128_tag) @@ -53,6 +56,9 @@ VDUP_N_IMPL(int16_t, int16x8_t, vdupq_n, s16, traits::vector_128_tag) VDUP_N_IMPL(uint32_t, uint32x4_t, vdupq_n, u32, traits::vector_128_tag) VDUP_N_IMPL(int32_t, int32x4_t, vdupq_n, s32, traits::vector_128_tag) VDUP_N_IMPL(float, float32x4_t, vdupq_n, f32, traits::vector_128_tag) +#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +VDUP_N_IMPL(float16_t, float16x8_t, vdupq_n, f16, traits::vector_128_tag) +#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC #undef VDUP_N_IMPL } // namespace wrapper diff --git a/arm_compute/core/NEON/wrapper/intrinsics/max.h b/arm_compute/core/NEON/wrapper/intrinsics/max.h index 1a8e95de87..05ed051c62 100644 --- a/arm_compute/core/NEON/wrapper/intrinsics/max.h +++ b/arm_compute/core/NEON/wrapper/intrinsics/max.h @@ -43,6 +43,9 @@ VMAX_IMPL(int16_t, int16x4_t, vmax, s16) VMAX_IMPL(uint32_t, uint32x2_t, vmax, u32) VMAX_IMPL(int32_t, int32x2_t, vmax, s32) VMAX_IMPL(float, float32x2_t, vmax, f32) +#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +VMAX_IMPL(float16_t, float16x4_t, vmax, f16) +#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC VMAX_IMPL(uint8_t, uint8x16_t, vmaxq, u8) VMAX_IMPL(int8_t, int8x16_t, vmaxq, s8) @@ -51,6 +54,9 @@ VMAX_IMPL(int16_t, int16x8_t, vmaxq, s16) VMAX_IMPL(uint32_t, uint32x4_t, vmaxq, u32) VMAX_IMPL(int32_t, int32x4_t, vmaxq, s32) VMAX_IMPL(float, float32x4_t, vmaxq, f32) +#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +VMAX_IMPL(float16_t, float16x8_t, vmaxq, f16) +#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC #undef VMAX_IMPL } // namespace wrapper diff --git a/arm_compute/core/NEON/wrapper/intrinsics/min.h b/arm_compute/core/NEON/wrapper/intrinsics/min.h index ae79631190..5ea2068f24 100644 --- a/arm_compute/core/NEON/wrapper/intrinsics/min.h +++ b/arm_compute/core/NEON/wrapper/intrinsics/min.h @@ -43,6 +43,9 @@ VMIN_IMPL(int16_t, int16x4_t, vmin, s16) VMIN_IMPL(uint32_t, uint32x2_t, vmin, u32) VMIN_IMPL(int32_t, int32x2_t, vmin, s32) VMIN_IMPL(float, float32x2_t, vmin, f32) +#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +VMIN_IMPL(float16_t, float16x4_t, vmin, f16) +#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC VMIN_IMPL(uint8_t, uint8x16_t, vminq, u8) VMIN_IMPL(int8_t, int8x16_t, vminq, s8) @@ -51,6 +54,9 @@ VMIN_IMPL(int16_t, int16x8_t, vminq, s16) VMIN_IMPL(uint32_t, uint32x4_t, vminq, u32) VMIN_IMPL(int32_t, int32x4_t, vminq, s32) VMIN_IMPL(float, float32x4_t, vminq, f32) +#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +VMIN_IMPL(float16_t, float16x8_t, vminq, f16) +#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC #undef VMIN_IMPL } // namespace wrapper diff --git a/arm_compute/core/NEON/wrapper/traits.h b/arm_compute/core/NEON/wrapper/traits.h index 495ddbb1af..5cd6086c0c 100644 --- a/arm_compute/core/NEON/wrapper/traits.h +++ b/arm_compute/core/NEON/wrapper/traits.h @@ -62,6 +62,10 @@ template <> struct neon_vector<uint64_t, 2>{ using type = uint64x2_t; using tag_ template <> struct neon_vector<int64_t, 2>{ using type = int64x2_t; using tag_type = vector_128_tag; }; template <> struct neon_vector<float_t, 2>{ using type = float32x2_t; using tag_type = vector_64_tag; }; template <> struct neon_vector<float_t, 4>{ using type = float32x4_t; using tag_type = vector_128_tag; }; +#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +template <> struct neon_vector<float16_t, 4>{ using type = float16x4_t; using tag_type = vector_64_tag; }; +template <> struct neon_vector<float16_t, 8>{ using type = float16x8_t; using tag_type = vector_128_tag; }; +#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC #endif /* DOXYGEN_SKIP_THIS */ /** Helper type template to get the type of a neon vector */ diff --git a/src/core/NEON/kernels/NEBatchNormalizationLayerKernel.cpp b/src/core/NEON/kernels/NEBatchNormalizationLayerKernel.cpp index ac1fc393c4..683d48b030 100644 --- a/src/core/NEON/kernels/NEBatchNormalizationLayerKernel.cpp +++ b/src/core/NEON/kernels/NEBatchNormalizationLayerKernel.cpp @@ -45,13 +45,11 @@ validate_arguments(const ITensorInfo *input, const ITensorInfo *output, const IT { ARM_COMPUTE_UNUSED(epsilon); ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(input); - ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F16, - DataType::F32); + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F16, DataType::F32); if(act_info.enabled()) { ActivationLayerInfo::ActivationFunction act = act_info.activation(); - ARM_COMPUTE_RETURN_ERROR_ON(input->data_type() != DataType::F32); ARM_COMPUTE_RETURN_ERROR_ON(act != ActivationLayerInfo::ActivationLayerInfo::ActivationFunction::RELU && act != ActivationLayerInfo::ActivationLayerInfo::ActivationFunction::BOUNDED_RELU && act != ActivationLayerInfo::ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU); @@ -102,16 +100,16 @@ std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input, ITen } } //namespace -template <bool fused_activation> +template <bool fused_activation, typename F> void NEBatchNormalizationLayerKernel::batch_normalization_fp16_nchw(const Window &window) { - static_assert(!fused_activation, "Activation is not supported for FP16"); - ARM_COMPUTE_UNUSED(window); #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC Iterator input(_input, window); Iterator output(_output, window); + F activation_functor(_act_info); + // Hold information about the current feature map we are iterating. // Only compute denominator and NEON vectors once per feature map. int slice = -1; @@ -151,22 +149,30 @@ void NEBatchNormalizationLayerKernel::batch_normalization_fp16_nchw(const Window // Calculate x bar and store results const float16x8_t numerator = vsubq_f16(vld1q_f16(reinterpret_cast<const float16_t *>(input.ptr())), mean_vec); const float16x8_t x_bar = vmulq_f16(numerator, denominator); - vst1q_f16(reinterpret_cast<float16_t *>(output.ptr()), vaddq_f16(beta_vec, vmulq_f16(x_bar, gamma_vec))); + float16x8_t res = vaddq_f16(beta_vec, vmulq_f16(x_bar, gamma_vec)); + + // Perform fused activation + if(fused_activation) + { + activation_functor(res); + } + + vst1q_f16(reinterpret_cast<float16_t *>(output.ptr()), res); }, input, output); #endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */ } -template <bool fused_activation> +template <bool fused_activation, typename F> void NEBatchNormalizationLayerKernel::batch_normalization_fp16_nhwc(const Window &window) { - static_assert(!fused_activation, "Activation is not supported for FP16"); - ARM_COMPUTE_UNUSED(window); #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC Iterator input(_input, window); Iterator output(_output, window); + F activation_functor(_act_info); + const auto input_mean = reinterpret_cast<const float16_t *>(_mean->ptr_to_element(Coordinates(0, 0))); const auto input_var = reinterpret_cast<const float16_t *>(_var->ptr_to_element(Coordinates(0, 0))); const auto input_gamma = (_gamma != nullptr) ? reinterpret_cast<const float16_t *>(_gamma->ptr_to_element(Coordinates(0, 0))) : nullptr; @@ -186,7 +192,15 @@ void NEBatchNormalizationLayerKernel::batch_normalization_fp16_nhwc(const Window // Calculate x bar and store results const float16x8_t numerator = vsubq_f16(vld1q_f16(reinterpret_cast<const float16_t *>(input.ptr())), mean_vec); const float16x8_t x_bar = vmulq_f16(numerator, denominator); - vst1q_f16(reinterpret_cast<float16_t *>(output.ptr()), vaddq_f16(beta_vec, vmulq_f16(x_bar, gamma_vec))); + float16x8_t res = vaddq_f16(beta_vec, vmulq_f16(x_bar, gamma_vec)); + + // Perform fused activation + if(fused_activation) + { + activation_functor(res); + } + + vst1q_f16(reinterpret_cast<float16_t *>(output.ptr()), res); }, input, output); #endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */ @@ -299,9 +313,12 @@ void NEBatchNormalizationLayerKernel::configure_non_fused() const bool is_nhwc = _input->info()->data_layout() == DataLayout::NHWC; switch(_input->info()->data_type()) { +#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC case DataType::F16: - _func = (is_nhwc) ? &NEBatchNormalizationLayerKernel::batch_normalization_fp16_nhwc<false> : &NEBatchNormalizationLayerKernel::batch_normalization_fp16_nchw<false>; + _func = (is_nhwc) ? &NEBatchNormalizationLayerKernel::batch_normalization_fp16_nhwc<false, ::detail::dummy<float16_t, 8>> : + &NEBatchNormalizationLayerKernel::batch_normalization_fp16_nchw<false, ::detail::dummy<float16_t, 8>>; break; +#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC case DataType::F32: _func = (is_nhwc) ? &NEBatchNormalizationLayerKernel::batch_normalization_fp32_nhwc<false, ::detail::dummy<float, 4>> : &NEBatchNormalizationLayerKernel::batch_normalization_fp32_nchw<false, ::detail::dummy<float, 4>>; @@ -328,9 +345,30 @@ void NEBatchNormalizationLayerKernel::configure_fused() { ActivationLayerInfo::ActivationFunction::BOUNDED_RELU, &NEBatchNormalizationLayerKernel::batch_normalization_fp32_nhwc<true, ::detail::brelu<float, 4>> }, { ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, &NEBatchNormalizationLayerKernel::batch_normalization_fp32_nhwc<true, ::detail::lubrelu<float, 4>> } }; +#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + // NCHW Fused Batched Normalization with activation functions : FP16 + static std::map<ActivationLayerInfo::ActivationFunction, BatchNormFunctionPtr> bn_fused_map_f16_nchw = + { + { ActivationLayerInfo::ActivationFunction::RELU, &NEBatchNormalizationLayerKernel::batch_normalization_fp16_nchw<true, ::detail::relu<float16_t, 8>> }, + { ActivationLayerInfo::ActivationFunction::BOUNDED_RELU, &NEBatchNormalizationLayerKernel::batch_normalization_fp16_nchw<true, ::detail::brelu<float16_t, 8>> }, + { ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, &NEBatchNormalizationLayerKernel::batch_normalization_fp16_nchw<true, ::detail::lubrelu<float16_t, 8>> } + }; + // NHWC Fused Batched Normalization with activation functions : FP16 + static std::map<ActivationLayerInfo::ActivationFunction, BatchNormFunctionPtr> bn_fused_map_f16_nhwc = + { + { ActivationLayerInfo::ActivationFunction::RELU, &NEBatchNormalizationLayerKernel::batch_normalization_fp16_nhwc<true, ::detail::relu<float16_t, 8>> }, + { ActivationLayerInfo::ActivationFunction::BOUNDED_RELU, &NEBatchNormalizationLayerKernel::batch_normalization_fp16_nhwc<true, ::detail::brelu<float16_t, 8>> }, + { ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, &NEBatchNormalizationLayerKernel::batch_normalization_fp16_nhwc<true, ::detail::lubrelu<float16_t, 8>> } + }; +#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC switch(_input->info()->data_type()) { +#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + case DataType::F16: + _func = (_input->info()->data_layout() == DataLayout::NHWC) ? bn_fused_map_f16_nhwc[_act_info.activation()] : bn_fused_map_f16_nchw[_act_info.activation()]; + break; +#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC case DataType::F32: _func = (_input->info()->data_layout() == DataLayout::NHWC) ? bn_fused_map_f32_nhwc[_act_info.activation()] : bn_fused_map_f32_nchw[_act_info.activation()]; break; |