aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--arm_compute/core/NEON/kernels/NEInstanceNormalizationLayerKernel.h26
-rw-r--r--arm_compute/core/NEON/wrapper/intrinsics/cvt.h16
-rw-r--r--src/core/NEON/kernels/NEInstanceNormalizationLayerKernel.cpp112
-rw-r--r--src/runtime/NEON/functions/NEInstanceNormalizationLayer.cpp14
-rw-r--r--tests/validation/NEON/InstanceNormalizationLayer.cpp8
5 files changed, 121 insertions, 55 deletions
diff --git a/arm_compute/core/NEON/kernels/NEInstanceNormalizationLayerKernel.h b/arm_compute/core/NEON/kernels/NEInstanceNormalizationLayerKernel.h
index 43f6ec3f13..a5bd453ac7 100644
--- a/arm_compute/core/NEON/kernels/NEInstanceNormalizationLayerKernel.h
+++ b/arm_compute/core/NEON/kernels/NEInstanceNormalizationLayerKernel.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2019 Arm Limited.
+ * Copyright (c) 2019-2020 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -29,6 +29,7 @@
namespace arm_compute
{
class ITensor;
+struct InstanceNormalizationLayerKernelInfo;
/** Interface for performing an instance normalization */
class NEInstanceNormalizationLayerKernel : public INEKernel
@@ -52,26 +53,22 @@ public:
~NEInstanceNormalizationLayerKernel() = default;
/** Set the input and output tensors.
*
- * @param[in, out] input Source tensor. Data types supported: F16/F32. Data layout supported: NCHW
- * In case of @p output tensor = nullptr this tensor will store the result of the normalization.
- * @param[out] output Destination tensor. Data types and data layouts supported: same as @p input.
- * @param[in] gamma (Optional) The scale scalar value applied to the normalized tensor. Defaults to 1.0
- * @param[in] beta (Optional) The offset scalar value applied to the normalized tensor. Defaults to 0.0
- * @param[in] epsilon (Optional) Lower bound value for the normalization. Defaults to 1e-12
+ * @param[in, out] input Source tensor. Data types supported: F16/F32. Data layout supported: NCHW
+ * In case of @p output tensor = nullptr this tensor will store the result of the normalization.
+ * @param[out] output Destination tensor. Data types and data layouts supported: same as @p input.
+ * @param[in] info Kernel meta-data descriptor
*/
- void configure(ITensor *input, ITensor *output, float gamma = 1.0f, float beta = 0.0f, float epsilon = 1e-12f);
+ void configure(ITensor *input, ITensor *output, const InstanceNormalizationLayerKernelInfo &info);
/** Static function to check if given info will lead to a valid configuration of @ref NEInstanceNormalizationLayer.
*
- * @param[in] input Source tensor info. Data types supported: F16/F32. Data layout supported: NCHW
- * @param[in] output Destination tensor info. Data types and data layouts supported: same as @p input.
- * @param[in] gamma (Optional) The scale scalar value applied to the normalized tensor. Defaults to 1.0
- * @param[in] beta (Optional) The offset scalar value applied to the normalized tensor. Defaults to 0.0
- * @param[in] epsilon (Optional) Lower bound value for the normalization. Defaults to 1e-12
+ * @param[in] input Source tensor info. Data types supported: F16/F32. Data layout supported: NCHW
+ * @param[in] output Destination tensor info. Data types and data layouts supported: same as @p input.
+ * @param[in] info Kernel meta-data descriptor
*
* @return a status
*/
- static Status validate(const ITensorInfo *input, const ITensorInfo *output, float gamma = 1.0f, float beta = 0.0f, float epsilon = 1e-12f);
+ static Status validate(const ITensorInfo *input, const ITensorInfo *output, const InstanceNormalizationLayerKernelInfo &info);
// Inherited methods overridden:
void run(const Window &window, const ThreadInfo &info) override;
@@ -93,6 +90,7 @@ private:
float _gamma;
float _beta;
float _epsilon;
+ bool _use_mixed_precision{ true };
};
} // namespace arm_compute
#endif /*ARM_COMPUTE_NEINSTANCENORMALIZATIONLAYERKERNEL_H */
diff --git a/arm_compute/core/NEON/wrapper/intrinsics/cvt.h b/arm_compute/core/NEON/wrapper/intrinsics/cvt.h
index de1261bdd0..6e79a92bc2 100644
--- a/arm_compute/core/NEON/wrapper/intrinsics/cvt.h
+++ b/arm_compute/core/NEON/wrapper/intrinsics/cvt.h
@@ -40,8 +40,24 @@ namespace wrapper
VCVT_TO_F32_IMPL(float32x4_t, uint32x4_t, vcvtq, f32, u32)
VCVT_TO_F32_IMPL(float32x4_t, int32x4_t, vcvtq, f32, s32)
+#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+VCVT_TO_F32_IMPL(float32x4_t, float16x4_t, vcvt, f32, f16)
+#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
#undef VCVT_TO_F32_IMPL
+#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+#define VCVT_TO_F16_IMPL(ptype, vtype, prefix, postfix1, postfix2) \
+ template <typename T> \
+ inline typename std::enable_if<std::is_same<T, float16_t>::value, float16x4_t>::type \
+ vcvt(const vtype &a) \
+ { \
+ return prefix##_##postfix1##_##postfix2(a); \
+ }
+
+VCVT_TO_F16_IMPL(float16x4_t, float32x4_t, vcvt, f16, f32)
+#undef VCVT_TO_F16_IMPL
+#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+
template <typename T>
inline typename std::enable_if<std::is_same<T, uint8_t>::value, uint32x4_t>::type
vcvt(const float32x4_t &a)
diff --git a/src/core/NEON/kernels/NEInstanceNormalizationLayerKernel.cpp b/src/core/NEON/kernels/NEInstanceNormalizationLayerKernel.cpp
index 3f3817902f..f650d97c45 100644
--- a/src/core/NEON/kernels/NEInstanceNormalizationLayerKernel.cpp
+++ b/src/core/NEON/kernels/NEInstanceNormalizationLayerKernel.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2019 Arm Limited.
+ * Copyright (c) 2019-2020 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -27,6 +27,7 @@
#include "arm_compute/core/Error.h"
#include "arm_compute/core/Helpers.h"
#include "arm_compute/core/ITensor.h"
+#include "arm_compute/core/KernelDescriptors.h"
#include "arm_compute/core/NEON/NEMath.h"
#include "arm_compute/core/NEON/wrapper/wrapper.h"
#include "arm_compute/core/TensorInfo.h"
@@ -40,7 +41,43 @@ namespace arm_compute
{
namespace
{
-template <typename T>
+template <typename InputType, typename AccType = InputType>
+void vector_float_sum(AccType &result, AccType &result_square, const InputType &inputs)
+{
+ result = wrapper::vadd(result, inputs);
+ result_square = wrapper::vadd(result_square, wrapper::vmul(inputs, inputs));
+}
+
+#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+template <>
+inline void vector_float_sum(float32x4_t &result, float32x4_t &result_square, const float16x8_t &inputs)
+{
+ vector_float_sum(result, result_square, wrapper::vcvt<float>(wrapper::vgetlow(inputs)));
+ vector_float_sum(result, result_square, wrapper::vcvt<float>(wrapper::vgethigh(inputs)));
+}
+#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+
+template <typename InputType, typename AccType = InputType>
+InputType vector_float_norm(const InputType &inputs, const AccType &vec_mean, const AccType &vec_multip, const AccType &vec_beta)
+{
+ return wrapper::vadd(wrapper::vmul(wrapper::vsub(inputs, vec_mean), vec_multip), vec_beta);
+}
+
+#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+template <>
+inline float16x8_t vector_float_norm(const float16x8_t &inputs, const float32x4_t &vec_mean, const float32x4_t &vec_multip, const float32x4_t &vec_beta)
+{
+ const auto input_low = wrapper::vcvt<float>(wrapper::vgetlow(inputs));
+ const auto input_high = wrapper::vcvt<float>(wrapper::vgethigh(inputs));
+ const auto result_low = wrapper::vcvt<float16_t>(vector_float_norm(input_low, vec_mean, vec_multip, vec_beta));
+ const auto result_high = wrapper::vcvt<float16_t>(vector_float_norm(input_high, vec_mean, vec_multip, vec_beta));
+ float16x8_t result = wrapper::vcombine(result_low, result_high);
+
+ return result;
+}
+#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+
+template <typename T, typename AccType = T>
void instance_normalization_nchw(ITensor *input, ITensor *output, float gamma, float beta, float epsilon, const Window &window)
{
/** NEON vector tag type. */
@@ -65,39 +102,37 @@ void instance_normalization_nchw(ITensor *input, ITensor *output, float gamma, f
Iterator input_plane_it(input, win_plane);
Iterator output_plane_it(output, win_plane);
- auto sum_h_w = static_cast<T>(0.f);
- auto sum_squares_h_w = static_cast<T>(0.f);
+ auto sum_h_w = static_cast<AccType>(0.f);
+ auto sum_squares_h_w = static_cast<AccType>(0.f);
execute_window_loop(win_plane, [&](const Coordinates &)
{
const auto input_ptr = reinterpret_cast<const T *>(input_plane_it.ptr());
- auto vec_sum_h_w = wrapper::vdup_n(static_cast<T>(0.f), ExactTagType{});
- auto vec_sum_squares_h_w = wrapper::vdup_n(static_cast<T>(0.f), ExactTagType{});
+ auto vec_sum_h_w = wrapper::vdup_n(static_cast<AccType>(0.f), ExactTagType{});
+ auto vec_sum_squares_h_w = wrapper::vdup_n(static_cast<AccType>(0.f), ExactTagType{});
// Compute S elements per iteration
int x = window.x().start();
for(; x <= (window.x().end() - window_step_x); x += window_step_x)
{
- auto vec_input_val = wrapper::vloadq(input_ptr + x);
- vec_sum_h_w = wrapper::vadd(vec_sum_h_w, vec_input_val);
- vec_sum_squares_h_w = wrapper::vadd(vec_sum_squares_h_w, wrapper::vmul(vec_input_val, vec_input_val));
+ auto vec_input_val = wrapper::vloadq(input_ptr + x);
+ vector_float_sum(vec_sum_h_w, vec_sum_squares_h_w, vec_input_val);
}
auto vec2_sum_h_w = wrapper::vpadd(wrapper::vgethigh(vec_sum_h_w), wrapper::vgetlow(vec_sum_h_w));
auto vec2_sum_squares_h_w = wrapper::vpadd(wrapper::vgethigh(vec_sum_squares_h_w), wrapper::vgetlow(vec_sum_squares_h_w));
- for(int i = 0; i < window_step_x / 4; ++i)
- {
- vec2_sum_h_w = wrapper::vpadd(vec2_sum_h_w, vec2_sum_h_w);
- vec2_sum_squares_h_w = wrapper::vpadd(vec2_sum_squares_h_w, vec2_sum_squares_h_w);
- }
+
+ vec2_sum_h_w = wrapper::vpadd(vec2_sum_h_w, vec2_sum_h_w);
+ vec2_sum_squares_h_w = wrapper::vpadd(vec2_sum_squares_h_w, vec2_sum_squares_h_w);
+
sum_h_w += wrapper::vgetlane(vec2_sum_h_w, 0);
sum_squares_h_w += wrapper::vgetlane(vec2_sum_squares_h_w, 0);
// Compute left-over elements
for(; x < window.x().end(); ++x)
{
- const auto value = *(input_ptr + x);
+ const auto value = static_cast<AccType>(*(input_ptr + x));
sum_h_w += value;
sum_squares_h_w += value * value;
}
@@ -108,9 +143,9 @@ void instance_normalization_nchw(ITensor *input, ITensor *output, float gamma, f
const auto var_h_w = sum_squares_h_w / elements_plane - mean_h_w * mean_h_w;
const auto multip_h_w = gamma / std::sqrt(var_h_w + epsilon);
- const auto vec_mean_h_w = wrapper::vdup_n(static_cast<T>(mean_h_w), ExactTagType{});
- const auto vec_multip_h_w = wrapper::vdup_n(static_cast<T>(multip_h_w), ExactTagType{});
- const auto vec_beta = wrapper::vdup_n(static_cast<T>(beta), ExactTagType{});
+ const auto vec_mean_h_w = wrapper::vdup_n(static_cast<AccType>(mean_h_w), ExactTagType{});
+ const auto vec_multip_h_w = wrapper::vdup_n(static_cast<AccType>(multip_h_w), ExactTagType{});
+ const auto vec_beta = wrapper::vdup_n(static_cast<AccType>(beta), ExactTagType{});
execute_window_loop(win_plane, [&](const Coordinates &)
{
@@ -118,19 +153,20 @@ void instance_normalization_nchw(ITensor *input, ITensor *output, float gamma, f
auto output_ptr = reinterpret_cast<T *>(output_plane_it.ptr());
// Compute S elements per iteration
- int x = window.x().start();
- auto vec_val = wrapper::vdup_n(static_cast<T>(0.0f), ExactTagType{});
+ int x = window.x().start();
+ //auto vec_val = wrapper::vdup_n(static_cast<T>(0.0f), ExactTagType{});
for(; x <= (window.x().end() - window_step_x); x += window_step_x)
{
- vec_val = wrapper::vloadq(input_ptr + x);
- vec_val = wrapper::vadd(wrapper::vmul(wrapper::vsub(vec_val, vec_mean_h_w), vec_multip_h_w), vec_beta);
- wrapper::vstore(output_ptr + x, vec_val);
+ const auto vec_val = wrapper::vloadq(input_ptr + x);
+ const auto normalized_vec = vector_float_norm(vec_val, vec_mean_h_w, vec_multip_h_w, vec_beta);
+ wrapper::vstore(output_ptr + x, normalized_vec);
}
// Compute left-over elements
for(; x < window.x().end(); ++x)
{
- *(output_ptr + x) = ((*(input_ptr + x)) - mean_h_w) * multip_h_w + beta;
+ const auto val = static_cast<AccType>(*(input_ptr + x));
+ *(output_ptr + x) = static_cast<T>((val - mean_h_w) * multip_h_w + beta);
}
},
input_plane_it, output_plane_it);
@@ -179,17 +215,18 @@ NEInstanceNormalizationLayerKernel::NEInstanceNormalizationLayerKernel()
{
}
-void NEInstanceNormalizationLayerKernel::configure(ITensor *input, ITensor *output, float gamma, float beta, float epsilon)
+void NEInstanceNormalizationLayerKernel::configure(ITensor *input, ITensor *output, const InstanceNormalizationLayerKernelInfo &info)
{
ARM_COMPUTE_ERROR_ON_NULLPTR(input);
- _input = input;
- _output = output == nullptr ? input : output;
- _gamma = gamma;
- _beta = beta;
- _epsilon = epsilon;
+ _input = input;
+ _output = output == nullptr ? input : output;
+ _gamma = info.gamma;
+ _beta = info.beta;
+ _epsilon = info.epsilon;
+ _use_mixed_precision = info.use_mixed_precision;
- ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(_input->info(), _output->info(), gamma, beta, epsilon));
+ ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(_input->info(), _output->info(), _gamma, _beta, _epsilon));
if(_input->info()->data_type() == DataType::F32)
{
@@ -198,7 +235,14 @@ void NEInstanceNormalizationLayerKernel::configure(ITensor *input, ITensor *outp
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
else if(_input->info()->data_type() == DataType::F16)
{
- _func = &instance_normalization_nchw<float16_t>;
+ if(_use_mixed_precision)
+ {
+ _func = &instance_normalization_nchw<float16_t, float>;
+ }
+ else
+ {
+ _func = &instance_normalization_nchw<float16_t>;
+ }
}
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
else
@@ -213,9 +257,9 @@ void NEInstanceNormalizationLayerKernel::configure(ITensor *input, ITensor *outp
INEKernel::configure(std::get<1>(win_config));
}
-Status NEInstanceNormalizationLayerKernel::validate(const ITensorInfo *input, const ITensorInfo *output, float gamma, float beta, float epsilon)
+Status NEInstanceNormalizationLayerKernel::validate(const ITensorInfo *input, const ITensorInfo *output, const InstanceNormalizationLayerKernelInfo &info)
{
- ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input, output, gamma, beta, epsilon));
+ ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input, output, info.gamma, info.beta, info.epsilon));
ARM_COMPUTE_RETURN_ON_ERROR(std::get<0>(validate_and_configure_window(input->clone().get(), (output == nullptr ? input->clone().get() : output->clone().get()))));
return Status{};
}
diff --git a/src/runtime/NEON/functions/NEInstanceNormalizationLayer.cpp b/src/runtime/NEON/functions/NEInstanceNormalizationLayer.cpp
index 7c803babb3..57d01ff2d6 100644
--- a/src/runtime/NEON/functions/NEInstanceNormalizationLayer.cpp
+++ b/src/runtime/NEON/functions/NEInstanceNormalizationLayer.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2019 Arm Limited.
+ * Copyright (c) 2019-2020 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -24,6 +24,7 @@
#include "arm_compute/runtime/NEON/functions/NEInstanceNormalizationLayer.h"
#include "arm_compute/core/Helpers.h"
+#include "arm_compute/core/KernelDescriptors.h"
#include "arm_compute/runtime/NEON/NEScheduler.h"
namespace arm_compute
@@ -35,7 +36,8 @@ NEInstanceNormalizationLayer::NEInstanceNormalizationLayer(std::shared_ptr<IMemo
void NEInstanceNormalizationLayer::configure(ITensor *input, ITensor *output, float gamma, float beta, float epsilon)
{
- const DataLayout data_layout = input->info()->data_layout();
+ const DataLayout data_layout = input->info()->data_layout();
+ const auto kernel_descriptor = InstanceNormalizationLayerKernelInfo{ gamma, beta, epsilon, true };
// Configure Kernels
_is_nchw = data_layout == DataLayout::NCHW;
@@ -49,7 +51,7 @@ void NEInstanceNormalizationLayer::configure(ITensor *input, ITensor *output, fl
_permute_input.configure(input, &_permuted_input, PermutationVector(1U, 2U, 0U));
_permuted_input.info()->set_data_layout(DataLayout::NCHW);
- _normalization_kernel.configure(&_permuted_input, &_permuted_output, gamma, beta, epsilon);
+ _normalization_kernel.configure(&_permuted_input, &_permuted_output, kernel_descriptor);
_permuted_output.info()->set_data_layout(DataLayout::NCHW);
_permute_output.configure(&_permuted_output, output != nullptr ? output : input, PermutationVector(2U, 0U, 1U));
@@ -58,13 +60,15 @@ void NEInstanceNormalizationLayer::configure(ITensor *input, ITensor *output, fl
}
else
{
- _normalization_kernel.configure(input, output, gamma, beta, epsilon);
+ _normalization_kernel.configure(input, output, kernel_descriptor);
}
}
Status NEInstanceNormalizationLayer::validate(const ITensorInfo *input, const ITensorInfo *output, float gamma, float beta, float epsilon)
{
- return NEInstanceNormalizationLayerKernel::validate(&input->clone()->set_data_layout(DataLayout::NCHW), &output->clone()->set_data_layout(DataLayout::NCHW), gamma, beta, epsilon);
+ return NEInstanceNormalizationLayerKernel::validate(&input->clone()->set_data_layout(DataLayout::NCHW),
+ &output->clone()->set_data_layout(DataLayout::NCHW),
+ InstanceNormalizationLayerKernelInfo{ gamma, beta, epsilon, true });
}
void NEInstanceNormalizationLayer::run()
diff --git a/tests/validation/NEON/InstanceNormalizationLayer.cpp b/tests/validation/NEON/InstanceNormalizationLayer.cpp
index 9d13a2b784..1073a7f6f7 100644
--- a/tests/validation/NEON/InstanceNormalizationLayer.cpp
+++ b/tests/validation/NEON/InstanceNormalizationLayer.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2019 Arm Limited.
+ * Copyright (c) 2019-2020 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -45,7 +45,11 @@ namespace
/** Tolerance for float operations */
AbsoluteTolerance<float> tolerance_f32(0.0015f);
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
-AbsoluteTolerance<float> tolerance_f16(0.5f);
+// This precision is chosen based on the precision float16_t can provide
+// for the decimal numbers between 16 and 32 and decided based on multiple
+// times of execution of tests. Although, with randomly generated numbers
+// there is no gaurantee that this tolerance will be always large enough.
+AbsoluteTolerance<half> tolerance_f16(static_cast<half>(0.015625f));
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
} // namespace