author      Sang-Hoon Park <sang-hoon.park@arm.com>    2020-07-16 14:26:16 +0100
committer   Sang-Hoon Park <sang-hoon.park@arm.com>    2020-07-28 08:17:55 +0000
commit      3351f2a454a11e15934fa8bfac635785783cf8e1 (patch)
tree        991c4f863af9bca765f25e3c1a91bb7fc1b2a75b
parent      ad7515d231acb075a9585e52f257373b1a1b5d1f (diff)
download    ComputeLibrary-3351f2a454a11e15934fa8bfac635785783cf8e1.tar.gz
COMPMID-3575: Mixed precision in NEInstanceNormalizationLayerKernel

In order to fix the issue caused by the limited precision of FP16, mixed
precision (a float accumulator) is introduced to
NEInstanceNormalizationLayerKernel. Since the reference kernel already
computes in mixed precision, mixed-precision computation is currently the
default when the kernel is called from NEInstanceNormalizationLayer.

- Make NEInstanceNormalizationLayerKernel use a kernel descriptor to
  enable mixed-precision computation
- Modify NEInstanceNormalizationLayer to use the descriptor

Change-Id: I7766622d715df054e303f9b441380b15b51f02b2
Signed-off-by: Sang-Hoon Park <sang-hoon.park@arm.com>
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/3589
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Michele Di Giorgio <michele.digiorgio@arm.com>
Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
-rw-r--r--  arm_compute/core/NEON/kernels/NEInstanceNormalizationLayerKernel.h   26
-rw-r--r--  arm_compute/core/NEON/wrapper/intrinsics/cvt.h                        16
-rw-r--r--  src/core/NEON/kernels/NEInstanceNormalizationLayerKernel.cpp        112
-rw-r--r--  src/runtime/NEON/functions/NEInstanceNormalizationLayer.cpp          14
-rw-r--r--  tests/validation/NEON/InstanceNormalizationLayer.cpp                  8
5 files changed, 121 insertions, 55 deletions
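
The core of the change: FP16 values are widened to FP32 before being
accumulated, so long sum and sum-of-squares reductions no longer lose
precision to FP16's 10-bit mantissa. A minimal scalar sketch of the idea
(illustrative only; the kernel below does the same with NEON vectors,
float16_t needs an FP16-capable Arm toolchain, and plane_sums is a
hypothetical helper name):

#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
#include <arm_neon.h> // float16_t

// Accumulate an FP16 plane in FP32.
void plane_sums(const float16_t *data, int n, float &sum, float &sum_squares)
{
    sum = sum_squares = 0.f;
    for(int i = 0; i < n; ++i)
    {
        const float v = static_cast<float>(data[i]); // widen before accumulating
        sum += v;
        sum_squares += v * v;
    }
}
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC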
diff --git a/arm_compute/core/NEON/kernels/NEInstanceNormalizationLayerKernel.h b/arm_compute/core/NEON/kernels/NEInstanceNormalizationLayerKernel.h
index 43f6ec3f13..a5bd453ac7 100644
--- a/arm_compute/core/NEON/kernels/NEInstanceNormalizationLayerKernel.h
+++ b/arm_compute/core/NEON/kernels/NEInstanceNormalizationLayerKernel.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2019 Arm Limited.
+ * Copyright (c) 2019-2020 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -29,6 +29,7 @@
namespace arm_compute
{
class ITensor;
+struct InstanceNormalizationLayerKernelInfo;
/** Interface for performing an instance normalization */
class NEInstanceNormalizationLayerKernel : public INEKernel
@@ -52,26 +53,22 @@ public:
~NEInstanceNormalizationLayerKernel() = default;
/** Set the input and output tensors.
*
- * @param[in, out] input Source tensor. Data types supported: F16/F32. Data layout supported: NCHW
- * In case of @p output tensor = nullptr this tensor will store the result of the normalization.
- * @param[out] output Destination tensor. Data types and data layouts supported: same as @p input.
- * @param[in] gamma (Optional) The scale scalar value applied to the normalized tensor. Defaults to 1.0
- * @param[in] beta (Optional) The offset scalar value applied to the normalized tensor. Defaults to 0.0
- * @param[in] epsilon (Optional) Lower bound value for the normalization. Defaults to 1e-12
+ * @param[in, out] input Source tensor. Data types supported: F16/F32. Data layout supported: NCHW
+ * In case of @p output tensor = nullptr this tensor will store the result of the normalization.
+ * @param[out] output Destination tensor. Data types and data layouts supported: same as @p input.
+ * @param[in] info Kernel meta-data descriptor
*/
- void configure(ITensor *input, ITensor *output, float gamma = 1.0f, float beta = 0.0f, float epsilon = 1e-12f);
+ void configure(ITensor *input, ITensor *output, const InstanceNormalizationLayerKernelInfo &info);
/** Static function to check if given info will lead to a valid configuration of @ref NEInstanceNormalizationLayer.
*
- * @param[in] input Source tensor info. Data types supported: F16/F32. Data layout supported: NCHW
- * @param[in] output Destination tensor info. Data types and data layouts supported: same as @p input.
- * @param[in] gamma (Optional) The scale scalar value applied to the normalized tensor. Defaults to 1.0
- * @param[in] beta (Optional) The offset scalar value applied to the normalized tensor. Defaults to 0.0
- * @param[in] epsilon (Optional) Lower bound value for the normalization. Defaults to 1e-12
+ * @param[in] input Source tensor info. Data types supported: F16/F32. Data layout supported: NCHW
+ * @param[in] output Destination tensor info. Data types and data layouts supported: same as @p input.
+ * @param[in] info Kernel meta-data descriptor
*
* @return a status
*/
- static Status validate(const ITensorInfo *input, const ITensorInfo *output, float gamma = 1.0f, float beta = 0.0f, float epsilon = 1e-12f);
+ static Status validate(const ITensorInfo *input, const ITensorInfo *output, const InstanceNormalizationLayerKernelInfo &info);
// Inherited methods overridden:
void run(const Window &window, const ThreadInfo &info) override;
@@ -93,6 +90,7 @@ private:
float _gamma;
float _beta;
float _epsilon;
+ bool _use_mixed_precision{ true };
};
} // namespace arm_compute
#endif /*ARM_COMPUTE_NEINSTANCENORMALIZATIONLAYERKERNEL_H */
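
For callers, migration is a matter of packing the three scalars into the
descriptor. A hedged sketch of the new call site (the field order { gamma,
beta, epsilon, use_mixed_precision } follows the braced initializers in the
diffs below; the tensors are assumed to be prepared elsewhere):

#include "arm_compute/core/KernelDescriptors.h"
#include "arm_compute/core/NEON/kernels/NEInstanceNormalizationLayerKernel.h"

// Before: kernel.configure(&input, &output, 1.0f, 0.0f, 1e-12f);
// After:
const InstanceNormalizationLayerKernelInfo info{ 1.0f /*gamma*/, 0.0f /*beta*/, 1e-12f /*epsilon*/, true /*use_mixed_precision*/ };
NEInstanceNormalizationLayerKernel kernel;
kernel.configure(&input, &output, info); // input/output: ITensor* set up elsewhere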
diff --git a/arm_compute/core/NEON/wrapper/intrinsics/cvt.h b/arm_compute/core/NEON/wrapper/intrinsics/cvt.h
index de1261bdd0..6e79a92bc2 100644
--- a/arm_compute/core/NEON/wrapper/intrinsics/cvt.h
+++ b/arm_compute/core/NEON/wrapper/intrinsics/cvt.h
@@ -40,8 +40,24 @@ namespace wrapper
VCVT_TO_F32_IMPL(float32x4_t, uint32x4_t, vcvtq, f32, u32)
VCVT_TO_F32_IMPL(float32x4_t, int32x4_t, vcvtq, f32, s32)
+#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+VCVT_TO_F32_IMPL(float32x4_t, float16x4_t, vcvt, f32, f16)
+#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
#undef VCVT_TO_F32_IMPL
+#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+#define VCVT_TO_F16_IMPL(ptype, vtype, prefix, postfix1, postfix2) \
+ template <typename T> \
+ inline typename std::enable_if<std::is_same<T, float16_t>::value, float16x4_t>::type \
+ vcvt(const vtype &a) \
+ { \
+ return prefix##_##postfix1##_##postfix2(a); \
+ }
+
+VCVT_TO_F16_IMPL(float16x4_t, float32x4_t, vcvt, f16, f32)
+#undef VCVT_TO_F16_IMPL
+#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+
template <typename T>
inline typename std::enable_if<std::is_same<T, uint8_t>::value, uint32x4_t>::type
vcvt(const float32x4_t &a)
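
After preprocessing, the invocation VCVT_TO_F16_IMPL(float16x4_t,
float32x4_t, vcvt, f16, f32) above generates roughly the following overload;
the token paste prefix##_##postfix1##_##postfix2 resolves to the NEON
intrinsic vcvt_f16_f32, which narrows four FP32 lanes to four FP16 lanes:

template <typename T>
inline typename std::enable_if<std::is_same<T, float16_t>::value, float16x4_t>::type
vcvt(const float32x4_t &a)
{
    return vcvt_f16_f32(a);
}

// Usage: narrow an FP32 accumulator back to FP16 lanes, e.g.
// float16x4_t lanes = wrapper::vcvt<float16_t>(acc);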
diff --git a/src/core/NEON/kernels/NEInstanceNormalizationLayerKernel.cpp b/src/core/NEON/kernels/NEInstanceNormalizationLayerKernel.cpp
index 3f3817902f..f650d97c45 100644
--- a/src/core/NEON/kernels/NEInstanceNormalizationLayerKernel.cpp
+++ b/src/core/NEON/kernels/NEInstanceNormalizationLayerKernel.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2019 Arm Limited.
+ * Copyright (c) 2019-2020 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -27,6 +27,7 @@
#include "arm_compute/core/Error.h"
#include "arm_compute/core/Helpers.h"
#include "arm_compute/core/ITensor.h"
+#include "arm_compute/core/KernelDescriptors.h"
#include "arm_compute/core/NEON/NEMath.h"
#include "arm_compute/core/NEON/wrapper/wrapper.h"
#include "arm_compute/core/TensorInfo.h"
@@ -40,7 +41,43 @@ namespace arm_compute
{
namespace
{
-template <typename T>
+template <typename InputType, typename AccType = InputType>
+void vector_float_sum(AccType &result, AccType &result_square, const InputType &inputs)
+{
+ result = wrapper::vadd(result, inputs);
+ result_square = wrapper::vadd(result_square, wrapper::vmul(inputs, inputs));
+}
+
+#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+template <>
+inline void vector_float_sum(float32x4_t &result, float32x4_t &result_square, const float16x8_t &inputs)
+{
+ vector_float_sum(result, result_square, wrapper::vcvt<float>(wrapper::vgetlow(inputs)));
+ vector_float_sum(result, result_square, wrapper::vcvt<float>(wrapper::vgethigh(inputs)));
+}
+#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+
+template <typename InputType, typename AccType = InputType>
+InputType vector_float_norm(const InputType &inputs, const AccType &vec_mean, const AccType &vec_multip, const AccType &vec_beta)
+{
+ return wrapper::vadd(wrapper::vmul(wrapper::vsub(inputs, vec_mean), vec_multip), vec_beta);
+}
+
+#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+template <>
+inline float16x8_t vector_float_norm(const float16x8_t &inputs, const float32x4_t &vec_mean, const float32x4_t &vec_multip, const float32x4_t &vec_beta)
+{
+ const auto input_low = wrapper::vcvt<float>(wrapper::vgetlow(inputs));
+ const auto input_high = wrapper::vcvt<float>(wrapper::vgethigh(inputs));
+ const auto result_low = wrapper::vcvt<float16_t>(vector_float_norm(input_low, vec_mean, vec_multip, vec_beta));
+ const auto result_high = wrapper::vcvt<float16_t>(vector_float_norm(input_high, vec_mean, vec_multip, vec_beta));
+ float16x8_t result = wrapper::vcombine(result_low, result_high);
+
+ return result;
+}
+#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+
+template <typename T, typename AccType = T>
void instance_normalization_nchw(ITensor *input, ITensor *output, float gamma, float beta, float epsilon, const Window &window)
{
/** NEON vector tag type. */
@@ -65,39 +102,37 @@ void instance_normalization_nchw(ITensor *input, ITensor *output, float gamma, f
Iterator input_plane_it(input, win_plane);
Iterator output_plane_it(output, win_plane);
- auto sum_h_w = static_cast<T>(0.f);
- auto sum_squares_h_w = static_cast<T>(0.f);
+ auto sum_h_w = static_cast<AccType>(0.f);
+ auto sum_squares_h_w = static_cast<AccType>(0.f);
execute_window_loop(win_plane, [&](const Coordinates &)
{
const auto input_ptr = reinterpret_cast<const T *>(input_plane_it.ptr());
- auto vec_sum_h_w = wrapper::vdup_n(static_cast<T>(0.f), ExactTagType{});
- auto vec_sum_squares_h_w = wrapper::vdup_n(static_cast<T>(0.f), ExactTagType{});
+ auto vec_sum_h_w = wrapper::vdup_n(static_cast<AccType>(0.f), ExactTagType{});
+ auto vec_sum_squares_h_w = wrapper::vdup_n(static_cast<AccType>(0.f), ExactTagType{});
// Compute S elements per iteration
int x = window.x().start();
for(; x <= (window.x().end() - window_step_x); x += window_step_x)
{
- auto vec_input_val = wrapper::vloadq(input_ptr + x);
- vec_sum_h_w = wrapper::vadd(vec_sum_h_w, vec_input_val);
- vec_sum_squares_h_w = wrapper::vadd(vec_sum_squares_h_w, wrapper::vmul(vec_input_val, vec_input_val));
+ auto vec_input_val = wrapper::vloadq(input_ptr + x);
+ vector_float_sum(vec_sum_h_w, vec_sum_squares_h_w, vec_input_val);
}
auto vec2_sum_h_w = wrapper::vpadd(wrapper::vgethigh(vec_sum_h_w), wrapper::vgetlow(vec_sum_h_w));
auto vec2_sum_squares_h_w = wrapper::vpadd(wrapper::vgethigh(vec_sum_squares_h_w), wrapper::vgetlow(vec_sum_squares_h_w));
- for(int i = 0; i < window_step_x / 4; ++i)
- {
- vec2_sum_h_w = wrapper::vpadd(vec2_sum_h_w, vec2_sum_h_w);
- vec2_sum_squares_h_w = wrapper::vpadd(vec2_sum_squares_h_w, vec2_sum_squares_h_w);
- }
+
+ vec2_sum_h_w = wrapper::vpadd(vec2_sum_h_w, vec2_sum_h_w);
+ vec2_sum_squares_h_w = wrapper::vpadd(vec2_sum_squares_h_w, vec2_sum_squares_h_w);
+
sum_h_w += wrapper::vgetlane(vec2_sum_h_w, 0);
sum_squares_h_w += wrapper::vgetlane(vec2_sum_squares_h_w, 0);
// Compute left-over elements
for(; x < window.x().end(); ++x)
{
- const auto value = *(input_ptr + x);
+ const auto value = static_cast<AccType>(*(input_ptr + x));
sum_h_w += value;
sum_squares_h_w += value * value;
}
@@ -108,9 +143,9 @@ void instance_normalization_nchw(ITensor *input, ITensor *output, float gamma, f
const auto var_h_w = sum_squares_h_w / elements_plane - mean_h_w * mean_h_w;
const auto multip_h_w = gamma / std::sqrt(var_h_w + epsilon);
- const auto vec_mean_h_w = wrapper::vdup_n(static_cast<T>(mean_h_w), ExactTagType{});
- const auto vec_multip_h_w = wrapper::vdup_n(static_cast<T>(multip_h_w), ExactTagType{});
- const auto vec_beta = wrapper::vdup_n(static_cast<T>(beta), ExactTagType{});
+ const auto vec_mean_h_w = wrapper::vdup_n(static_cast<AccType>(mean_h_w), ExactTagType{});
+ const auto vec_multip_h_w = wrapper::vdup_n(static_cast<AccType>(multip_h_w), ExactTagType{});
+ const auto vec_beta = wrapper::vdup_n(static_cast<AccType>(beta), ExactTagType{});
execute_window_loop(win_plane, [&](const Coordinates &)
{
@@ -118,19 +153,20 @@ void instance_normalization_nchw(ITensor *input, ITensor *output, float gamma, f
auto output_ptr = reinterpret_cast<T *>(output_plane_it.ptr());
// Compute S elements per iteration
- int x = window.x().start();
- auto vec_val = wrapper::vdup_n(static_cast<T>(0.0f), ExactTagType{});
+ int x = window.x().start();
for(; x <= (window.x().end() - window_step_x); x += window_step_x)
{
- vec_val = wrapper::vloadq(input_ptr + x);
- vec_val = wrapper::vadd(wrapper::vmul(wrapper::vsub(vec_val, vec_mean_h_w), vec_multip_h_w), vec_beta);
- wrapper::vstore(output_ptr + x, vec_val);
+ const auto vec_val = wrapper::vloadq(input_ptr + x);
+ const auto normalized_vec = vector_float_norm(vec_val, vec_mean_h_w, vec_multip_h_w, vec_beta);
+ wrapper::vstore(output_ptr + x, normalized_vec);
}
// Compute left-over elements
for(; x < window.x().end(); ++x)
{
- *(output_ptr + x) = ((*(input_ptr + x)) - mean_h_w) * multip_h_w + beta;
+ const auto val = static_cast<AccType>(*(input_ptr + x));
+ *(output_ptr + x) = static_cast<T>((val - mean_h_w) * multip_h_w + beta);
}
},
input_plane_it, output_plane_it);
@@ -179,17 +215,18 @@ NEInstanceNormalizationLayerKernel::NEInstanceNormalizationLayerKernel()
{
}
-void NEInstanceNormalizationLayerKernel::configure(ITensor *input, ITensor *output, float gamma, float beta, float epsilon)
+void NEInstanceNormalizationLayerKernel::configure(ITensor *input, ITensor *output, const InstanceNormalizationLayerKernelInfo &info)
{
ARM_COMPUTE_ERROR_ON_NULLPTR(input);
- _input = input;
- _output = output == nullptr ? input : output;
- _gamma = gamma;
- _beta = beta;
- _epsilon = epsilon;
+ _input = input;
+ _output = output == nullptr ? input : output;
+ _gamma = info.gamma;
+ _beta = info.beta;
+ _epsilon = info.epsilon;
+ _use_mixed_precision = info.use_mixed_precision;
- ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(_input->info(), _output->info(), gamma, beta, epsilon));
+ ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(_input->info(), _output->info(), _gamma, _beta, _epsilon));
if(_input->info()->data_type() == DataType::F32)
{
@@ -198,7 +235,14 @@ void NEInstanceNormalizationLayerKernel::configure(ITensor *input, ITensor *outp
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
else if(_input->info()->data_type() == DataType::F16)
{
- _func = &instance_normalization_nchw<float16_t>;
+ if(_use_mixed_precision)
+ {
+ _func = &instance_normalization_nchw<float16_t, float>;
+ }
+ else
+ {
+ _func = &instance_normalization_nchw<float16_t>;
+ }
}
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
else
@@ -213,9 +257,9 @@ void NEInstanceNormalizationLayerKernel::configure(ITensor *input, ITensor *outp
INEKernel::configure(std::get<1>(win_config));
}
-Status NEInstanceNormalizationLayerKernel::validate(const ITensorInfo *input, const ITensorInfo *output, float gamma, float beta, float epsilon)
+Status NEInstanceNormalizationLayerKernel::validate(const ITensorInfo *input, const ITensorInfo *output, const InstanceNormalizationLayerKernelInfo &info)
{
- ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input, output, gamma, beta, epsilon));
+ ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input, output, info.gamma, info.beta, info.epsilon));
ARM_COMPUTE_RETURN_ON_ERROR(std::get<0>(validate_and_configure_window(input->clone().get(), (output == nullptr ? input->clone().get() : output->clone().get()))));
return Status{};
}
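
Two details in the kernel are worth noting. First, the horizontal reduction:
the old step-count loop is replaced by a fixed pair of pairwise adds, which
fully reduces a float32x4_t accumulator. A standalone sketch with raw NEON
intrinsics (illustrative; the kernel uses the wrapper:: equivalents, and
hsum_f32x4 is a hypothetical helper name):

#include <arm_neon.h>

// Horizontal sum of {v0, v1, v2, v3} via pairwise adds.
inline float hsum_f32x4(float32x4_t v)
{
    float32x2_t sum2 = vpadd_f32(vget_high_f32(v), vget_low_f32(v)); // {v2+v3, v0+v1}
    sum2             = vpadd_f32(sum2, sum2);                        // both lanes hold the total
    return vget_lane_f32(sum2, 0);
}

Second, the template dispatch: instance_normalization_nchw<float16_t, float>
accumulates FP16 data in FP32, while instance_normalization_nchw<float16_t>
keeps the old all-FP16 behaviour, since AccType defaults to InputType.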
diff --git a/src/runtime/NEON/functions/NEInstanceNormalizationLayer.cpp b/src/runtime/NEON/functions/NEInstanceNormalizationLayer.cpp
index 7c803babb3..57d01ff2d6 100644
--- a/src/runtime/NEON/functions/NEInstanceNormalizationLayer.cpp
+++ b/src/runtime/NEON/functions/NEInstanceNormalizationLayer.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2019 Arm Limited.
+ * Copyright (c) 2019-2020 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -24,6 +24,7 @@
#include "arm_compute/runtime/NEON/functions/NEInstanceNormalizationLayer.h"
#include "arm_compute/core/Helpers.h"
+#include "arm_compute/core/KernelDescriptors.h"
#include "arm_compute/runtime/NEON/NEScheduler.h"
namespace arm_compute
@@ -35,7 +36,8 @@ NEInstanceNormalizationLayer::NEInstanceNormalizationLayer(std::shared_ptr<IMemo
void NEInstanceNormalizationLayer::configure(ITensor *input, ITensor *output, float gamma, float beta, float epsilon)
{
- const DataLayout data_layout = input->info()->data_layout();
+ const DataLayout data_layout = input->info()->data_layout();
+ const auto kernel_descriptor = InstanceNormalizationLayerKernelInfo{ gamma, beta, epsilon, true };
// Configure Kernels
_is_nchw = data_layout == DataLayout::NCHW;
@@ -49,7 +51,7 @@ void NEInstanceNormalizationLayer::configure(ITensor *input, ITensor *output, fl
_permute_input.configure(input, &_permuted_input, PermutationVector(1U, 2U, 0U));
_permuted_input.info()->set_data_layout(DataLayout::NCHW);
- _normalization_kernel.configure(&_permuted_input, &_permuted_output, gamma, beta, epsilon);
+ _normalization_kernel.configure(&_permuted_input, &_permuted_output, kernel_descriptor);
_permuted_output.info()->set_data_layout(DataLayout::NCHW);
_permute_output.configure(&_permuted_output, output != nullptr ? output : input, PermutationVector(2U, 0U, 1U));
@@ -58,13 +60,15 @@ void NEInstanceNormalizationLayer::configure(ITensor *input, ITensor *output, fl
}
else
{
- _normalization_kernel.configure(input, output, gamma, beta, epsilon);
+ _normalization_kernel.configure(input, output, kernel_descriptor);
}
}
Status NEInstanceNormalizationLayer::validate(const ITensorInfo *input, const ITensorInfo *output, float gamma, float beta, float epsilon)
{
- return NEInstanceNormalizationLayerKernel::validate(&input->clone()->set_data_layout(DataLayout::NCHW), &output->clone()->set_data_layout(DataLayout::NCHW), gamma, beta, epsilon);
+ return NEInstanceNormalizationLayerKernel::validate(&input->clone()->set_data_layout(DataLayout::NCHW),
+ &output->clone()->set_data_layout(DataLayout::NCHW),
+ InstanceNormalizationLayerKernelInfo{ gamma, beta, epsilon, true });
}
void NEInstanceNormalizationLayer::run()
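
At the function level the caller-facing API is unchanged; the layer builds
the descriptor itself and always requests mixed precision (the trailing true
in InstanceNormalizationLayerKernelInfo{ gamma, beta, epsilon, true }). A
hedged end-to-end sketch (tensor setup is abbreviated and assumed to happen
elsewhere):

#include "arm_compute/runtime/NEON/functions/NEInstanceNormalizationLayer.h"
#include "arm_compute/runtime/Tensor.h"

Tensor src, dst; // assumed initialised as F16 tensors and allocated

NEInstanceNormalizationLayer norm;
norm.configure(&src, &dst, /*gamma=*/1.0f, /*beta=*/0.0f, /*epsilon=*/1e-12f);
norm.run(); // internally configures the kernel with use_mixed_precision = true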
diff --git a/tests/validation/NEON/InstanceNormalizationLayer.cpp b/tests/validation/NEON/InstanceNormalizationLayer.cpp
index 9d13a2b784..1073a7f6f7 100644
--- a/tests/validation/NEON/InstanceNormalizationLayer.cpp
+++ b/tests/validation/NEON/InstanceNormalizationLayer.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2019 Arm Limited.
+ * Copyright (c) 2019-2020 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -45,7 +45,11 @@ namespace
/** Tolerance for float operations */
AbsoluteTolerance<float> tolerance_f32(0.0015f);
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
-AbsoluteTolerance<float> tolerance_f16(0.5f);
+// This tolerance is one ULP of float16_t for values in [16, 32)
+// (2^(4-10) = 0.015625, given FP16's 10 stored mantissa bits), settled
+// on after multiple test runs. With randomly generated inputs, however,
+// there is no guarantee that this tolerance will always be large enough.
+AbsoluteTolerance<half> tolerance_f16(static_cast<half>(0.015625f));
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
} // namespace