diff options
author | Sang-Hoon Park <sang-hoon.park@arm.com> | 2020-07-16 14:26:16 +0100 |
---|---|---|
committer | Sang-Hoon Park <sang-hoon.park@arm.com> | 2020-07-28 08:17:55 +0000 |
commit | 3351f2a454a11e15934fa8bfac635785783cf8e1 (patch) | |
tree | 991c4f863af9bca765f25e3c1a91bb7fc1b2a75b /src/core/NEON/kernels/NEInstanceNormalizationLayerKernel.cpp | |
parent | ad7515d231acb075a9585e52f257373b1a1b5d1f (diff) | |
download | ComputeLibrary-3351f2a454a11e15934fa8bfac635785783cf8e1.tar.gz |
COMPMID-3575: Mixed preicision in NEInstanceNormalizationLayerKernel
In order to fix the issue caused by the limited precision of FP16.
mixed precision (float accumulator) is introduced to
NEInstanceNormalizationLayerKernel. Since the reference kernel
is doing the mixed precision, currently mixed preicision computation
is default when it is called from NEInstanceNormalizationLayer.
- Make NEInstanceNormalizationLayerKernel use kernel descriptor
to enable mixed precision computation
- NEInstanceNormalizationLayer is modified to use the descriptor
Change-Id: I7766622d715df054e303f9b441380b15b51f02b2
Signed-off-by: Sang-Hoon Park <sang-hoon.park@arm.com>
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/3589
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Michele Di Giorgio <michele.digiorgio@arm.com>
Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'src/core/NEON/kernels/NEInstanceNormalizationLayerKernel.cpp')
-rw-r--r-- | src/core/NEON/kernels/NEInstanceNormalizationLayerKernel.cpp | 112 |
1 files changed, 78 insertions, 34 deletions
diff --git a/src/core/NEON/kernels/NEInstanceNormalizationLayerKernel.cpp b/src/core/NEON/kernels/NEInstanceNormalizationLayerKernel.cpp index 3f3817902f..f650d97c45 100644 --- a/src/core/NEON/kernels/NEInstanceNormalizationLayerKernel.cpp +++ b/src/core/NEON/kernels/NEInstanceNormalizationLayerKernel.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019 Arm Limited. + * Copyright (c) 2019-2020 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -27,6 +27,7 @@ #include "arm_compute/core/Error.h" #include "arm_compute/core/Helpers.h" #include "arm_compute/core/ITensor.h" +#include "arm_compute/core/KernelDescriptors.h" #include "arm_compute/core/NEON/NEMath.h" #include "arm_compute/core/NEON/wrapper/wrapper.h" #include "arm_compute/core/TensorInfo.h" @@ -40,7 +41,43 @@ namespace arm_compute { namespace { -template <typename T> +template <typename InputType, typename AccType = InputType> +void vector_float_sum(AccType &result, AccType &result_square, const InputType &inputs) +{ + result = wrapper::vadd(result, inputs); + result_square = wrapper::vadd(result_square, wrapper::vmul(inputs, inputs)); +} + +#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +template <> +inline void vector_float_sum(float32x4_t &result, float32x4_t &result_square, const float16x8_t &inputs) +{ + vector_float_sum(result, result_square, wrapper::vcvt<float>(wrapper::vgetlow(inputs))); + vector_float_sum(result, result_square, wrapper::vcvt<float>(wrapper::vgethigh(inputs))); +} +#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + +template <typename InputType, typename AccType = InputType> +InputType vector_float_norm(const InputType &inputs, const AccType &vec_mean, const AccType &vec_multip, const AccType &vec_beta) +{ + return wrapper::vadd(wrapper::vmul(wrapper::vsub(inputs, vec_mean), vec_multip), vec_beta); +} + +#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +template <> +inline float16x8_t vector_float_norm(const float16x8_t &inputs, const float32x4_t &vec_mean, const float32x4_t &vec_multip, const float32x4_t &vec_beta) +{ + const auto input_low = wrapper::vcvt<float>(wrapper::vgetlow(inputs)); + const auto input_high = wrapper::vcvt<float>(wrapper::vgethigh(inputs)); + const auto result_low = wrapper::vcvt<float16_t>(vector_float_norm(input_low, vec_mean, vec_multip, vec_beta)); + const auto result_high = wrapper::vcvt<float16_t>(vector_float_norm(input_high, vec_mean, vec_multip, vec_beta)); + float16x8_t result = wrapper::vcombine(result_low, result_high); + + return result; +} +#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + +template <typename T, typename AccType = T> void instance_normalization_nchw(ITensor *input, ITensor *output, float gamma, float beta, float epsilon, const Window &window) { /** NEON vector tag type. */ @@ -65,39 +102,37 @@ void instance_normalization_nchw(ITensor *input, ITensor *output, float gamma, f Iterator input_plane_it(input, win_plane); Iterator output_plane_it(output, win_plane); - auto sum_h_w = static_cast<T>(0.f); - auto sum_squares_h_w = static_cast<T>(0.f); + auto sum_h_w = static_cast<AccType>(0.f); + auto sum_squares_h_w = static_cast<AccType>(0.f); execute_window_loop(win_plane, [&](const Coordinates &) { const auto input_ptr = reinterpret_cast<const T *>(input_plane_it.ptr()); - auto vec_sum_h_w = wrapper::vdup_n(static_cast<T>(0.f), ExactTagType{}); - auto vec_sum_squares_h_w = wrapper::vdup_n(static_cast<T>(0.f), ExactTagType{}); + auto vec_sum_h_w = wrapper::vdup_n(static_cast<AccType>(0.f), ExactTagType{}); + auto vec_sum_squares_h_w = wrapper::vdup_n(static_cast<AccType>(0.f), ExactTagType{}); // Compute S elements per iteration int x = window.x().start(); for(; x <= (window.x().end() - window_step_x); x += window_step_x) { - auto vec_input_val = wrapper::vloadq(input_ptr + x); - vec_sum_h_w = wrapper::vadd(vec_sum_h_w, vec_input_val); - vec_sum_squares_h_w = wrapper::vadd(vec_sum_squares_h_w, wrapper::vmul(vec_input_val, vec_input_val)); + auto vec_input_val = wrapper::vloadq(input_ptr + x); + vector_float_sum(vec_sum_h_w, vec_sum_squares_h_w, vec_input_val); } auto vec2_sum_h_w = wrapper::vpadd(wrapper::vgethigh(vec_sum_h_w), wrapper::vgetlow(vec_sum_h_w)); auto vec2_sum_squares_h_w = wrapper::vpadd(wrapper::vgethigh(vec_sum_squares_h_w), wrapper::vgetlow(vec_sum_squares_h_w)); - for(int i = 0; i < window_step_x / 4; ++i) - { - vec2_sum_h_w = wrapper::vpadd(vec2_sum_h_w, vec2_sum_h_w); - vec2_sum_squares_h_w = wrapper::vpadd(vec2_sum_squares_h_w, vec2_sum_squares_h_w); - } + + vec2_sum_h_w = wrapper::vpadd(vec2_sum_h_w, vec2_sum_h_w); + vec2_sum_squares_h_w = wrapper::vpadd(vec2_sum_squares_h_w, vec2_sum_squares_h_w); + sum_h_w += wrapper::vgetlane(vec2_sum_h_w, 0); sum_squares_h_w += wrapper::vgetlane(vec2_sum_squares_h_w, 0); // Compute left-over elements for(; x < window.x().end(); ++x) { - const auto value = *(input_ptr + x); + const auto value = static_cast<AccType>(*(input_ptr + x)); sum_h_w += value; sum_squares_h_w += value * value; } @@ -108,9 +143,9 @@ void instance_normalization_nchw(ITensor *input, ITensor *output, float gamma, f const auto var_h_w = sum_squares_h_w / elements_plane - mean_h_w * mean_h_w; const auto multip_h_w = gamma / std::sqrt(var_h_w + epsilon); - const auto vec_mean_h_w = wrapper::vdup_n(static_cast<T>(mean_h_w), ExactTagType{}); - const auto vec_multip_h_w = wrapper::vdup_n(static_cast<T>(multip_h_w), ExactTagType{}); - const auto vec_beta = wrapper::vdup_n(static_cast<T>(beta), ExactTagType{}); + const auto vec_mean_h_w = wrapper::vdup_n(static_cast<AccType>(mean_h_w), ExactTagType{}); + const auto vec_multip_h_w = wrapper::vdup_n(static_cast<AccType>(multip_h_w), ExactTagType{}); + const auto vec_beta = wrapper::vdup_n(static_cast<AccType>(beta), ExactTagType{}); execute_window_loop(win_plane, [&](const Coordinates &) { @@ -118,19 +153,20 @@ void instance_normalization_nchw(ITensor *input, ITensor *output, float gamma, f auto output_ptr = reinterpret_cast<T *>(output_plane_it.ptr()); // Compute S elements per iteration - int x = window.x().start(); - auto vec_val = wrapper::vdup_n(static_cast<T>(0.0f), ExactTagType{}); + int x = window.x().start(); + //auto vec_val = wrapper::vdup_n(static_cast<T>(0.0f), ExactTagType{}); for(; x <= (window.x().end() - window_step_x); x += window_step_x) { - vec_val = wrapper::vloadq(input_ptr + x); - vec_val = wrapper::vadd(wrapper::vmul(wrapper::vsub(vec_val, vec_mean_h_w), vec_multip_h_w), vec_beta); - wrapper::vstore(output_ptr + x, vec_val); + const auto vec_val = wrapper::vloadq(input_ptr + x); + const auto normalized_vec = vector_float_norm(vec_val, vec_mean_h_w, vec_multip_h_w, vec_beta); + wrapper::vstore(output_ptr + x, normalized_vec); } // Compute left-over elements for(; x < window.x().end(); ++x) { - *(output_ptr + x) = ((*(input_ptr + x)) - mean_h_w) * multip_h_w + beta; + const auto val = static_cast<AccType>(*(input_ptr + x)); + *(output_ptr + x) = static_cast<T>((val - mean_h_w) * multip_h_w + beta); } }, input_plane_it, output_plane_it); @@ -179,17 +215,18 @@ NEInstanceNormalizationLayerKernel::NEInstanceNormalizationLayerKernel() { } -void NEInstanceNormalizationLayerKernel::configure(ITensor *input, ITensor *output, float gamma, float beta, float epsilon) +void NEInstanceNormalizationLayerKernel::configure(ITensor *input, ITensor *output, const InstanceNormalizationLayerKernelInfo &info) { ARM_COMPUTE_ERROR_ON_NULLPTR(input); - _input = input; - _output = output == nullptr ? input : output; - _gamma = gamma; - _beta = beta; - _epsilon = epsilon; + _input = input; + _output = output == nullptr ? input : output; + _gamma = info.gamma; + _beta = info.beta; + _epsilon = info.epsilon; + _use_mixed_precision = info.use_mixed_precision; - ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(_input->info(), _output->info(), gamma, beta, epsilon)); + ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(_input->info(), _output->info(), _gamma, _beta, _epsilon)); if(_input->info()->data_type() == DataType::F32) { @@ -198,7 +235,14 @@ void NEInstanceNormalizationLayerKernel::configure(ITensor *input, ITensor *outp #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC else if(_input->info()->data_type() == DataType::F16) { - _func = &instance_normalization_nchw<float16_t>; + if(_use_mixed_precision) + { + _func = &instance_normalization_nchw<float16_t, float>; + } + else + { + _func = &instance_normalization_nchw<float16_t>; + } } #endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC else @@ -213,9 +257,9 @@ void NEInstanceNormalizationLayerKernel::configure(ITensor *input, ITensor *outp INEKernel::configure(std::get<1>(win_config)); } -Status NEInstanceNormalizationLayerKernel::validate(const ITensorInfo *input, const ITensorInfo *output, float gamma, float beta, float epsilon) +Status NEInstanceNormalizationLayerKernel::validate(const ITensorInfo *input, const ITensorInfo *output, const InstanceNormalizationLayerKernelInfo &info) { - ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input, output, gamma, beta, epsilon)); + ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input, output, info.gamma, info.beta, info.epsilon)); ARM_COMPUTE_RETURN_ON_ERROR(std::get<0>(validate_and_configure_window(input->clone().get(), (output == nullptr ? input->clone().get() : output->clone().get())))); return Status{}; } |