aboutsummaryrefslogtreecommitdiff
path: root/src/core/NEON/kernels/NEL2NormalizeLayerKernel.cpp
diff options
context:
space:
mode:
authorMichalis Spyrou <michalis.spyrou@arm.com>2018-11-20 18:38:29 +0000
committerGeorgios Pinitas <georgios.pinitas@arm.com>2018-11-22 13:22:24 +0000
commit2897e61e8fe04aaf95540f4525c3dd3f7f46ebfa (patch)
treea887648225c485ddf06363170d6111290b6111eb /src/core/NEON/kernels/NEL2NormalizeLayerKernel.cpp
parent303f0dbebf631b3db00d9d64e71018abbbe9d4fe (diff)
downloadComputeLibrary-2897e61e8fe04aaf95540f4525c3dd3f7f46ebfa.tar.gz
COMPMID-1645 NEL2Normalization for FP32/FP16 & NHWC
Change-Id: I29e35024e29781a6b943b568abec9c73649215e6
Diffstat (limited to 'src/core/NEON/kernels/NEL2NormalizeLayerKernel.cpp')
-rw-r--r--src/core/NEON/kernels/NEL2NormalizeLayerKernel.cpp139
1 files changed, 127 insertions, 12 deletions
diff --git a/src/core/NEON/kernels/NEL2NormalizeLayerKernel.cpp b/src/core/NEON/kernels/NEL2NormalizeLayerKernel.cpp
index ed037832af..cda041de66 100644
--- a/src/core/NEON/kernels/NEL2NormalizeLayerKernel.cpp
+++ b/src/core/NEON/kernels/NEL2NormalizeLayerKernel.cpp
@@ -32,15 +32,20 @@
#include "arm_compute/core/Validate.h"
#include "arm_compute/core/Window.h"
+#include "arm_compute/core/NEON/wrapper/wrapper.h"
#include <arm_neon.h>
#include <cmath>
-using namespace arm_compute;
-
+namespace arm_compute
+{
namespace
{
+template <typename T, int S>
void l2_normalize_X(const ITensor *in, const ITensor *sum, ITensor *out, float epsilon, const Window &window)
{
+ /** NEON vector tag type. */
+ using ExactTagType = typename wrapper::traits::neon_vector<T, S>::tag_type;
+
Window window_sum(window);
window_sum.set(Window::DimX, Window::Dimension(0, 0, 0));
@@ -53,30 +58,97 @@ void l2_normalize_X(const ITensor *in, const ITensor *sum, ITensor *out, float e
Iterator sum_it(sum, sum_slice);
Iterator output_it(out, in_slice);
- const float sum_value = *reinterpret_cast<const float *>(sum_it.ptr());
- const float32x4_t vec_normalize_value = vdupq_n_f32(1.f / std::sqrt(std::max(sum_value, epsilon)));
+ const auto sum_value = *reinterpret_cast<const T *>(sum_it.ptr());
+ const auto vec_normalize_value = wrapper::vdup_n(static_cast<T>(1.f / std::sqrt(std::max(sum_value, static_cast<T>(epsilon)))), ExactTagType{});
execute_window_loop(in_slice, [&](const Coordinates & id)
{
- const auto in_ptr = reinterpret_cast<const float *>(input_it.ptr());
- const auto out_ptr = reinterpret_cast<float *>(output_it.ptr());
+ const auto in_ptr = reinterpret_cast<const T *>(input_it.ptr());
+ const auto out_ptr = reinterpret_cast<T *>(output_it.ptr());
- vst1q_f32(out_ptr, vmulq_f32(vld1q_f32(in_ptr), vec_normalize_value));
+ wrapper::vstore(out_ptr, wrapper::vmul(wrapper::vloadq(in_ptr), vec_normalize_value));
},
input_it, output_it);
}
while(window.slide_window_slice_1D(in_slice) && window.slide_window_slice_1D(sum_slice));
}
+template <typename T, int S>
+void l2_normalize_Y(const ITensor *in, const ITensor *sum, ITensor *out, float epsilon, const Window &window)
+{
+ /** NEON vector tag type. */
+ using ExactTagType = typename wrapper::traits::neon_vector<T, S>::tag_type;
+
+ Window window_sum(window);
+ window_sum.set(Window::DimY, Window::Dimension(0, 0, 0));
+
+ Window in_slice = window.first_slice_window_2D();
+ Window sum_slice = window_sum.first_slice_window_2D();
+
+ do
+ {
+ Iterator input_it(in, in_slice);
+ Iterator sum_it(sum, sum_slice);
+ Iterator output_it(out, in_slice);
+
+ auto eps = wrapper::vdup_n(static_cast<T>(epsilon), ExactTagType{});
+
+ execute_window_loop(in_slice, [&](const Coordinates & id)
+ {
+ const auto in_ptr = reinterpret_cast<const T *>(input_it.ptr());
+ const auto sum_ptr = reinterpret_cast<const T *>(sum_it.ptr());
+ const auto out_ptr = reinterpret_cast<T *>(output_it.ptr());
+
+ const auto vec_normalize_value = wrapper::vinvsqrt(wrapper::vmax(wrapper::vloadq(sum_ptr), eps));
+ wrapper::vstore(out_ptr, wrapper::vmul(wrapper::vloadq(in_ptr), vec_normalize_value));
+ },
+ input_it, sum_it, output_it);
+ }
+ while(window.slide_window_slice_2D(in_slice) && window.slide_window_slice_2D(sum_slice));
+}
+
+template <typename T, int S>
+void l2_normalize_Z(const ITensor *in, const ITensor *sum, ITensor *out, float epsilon, const Window &window)
+{
+ /** NEON vector tag type. */
+ using ExactTagType = typename wrapper::traits::neon_vector<T, S>::tag_type;
+
+ Window window_sum(window);
+ window_sum.set(Window::DimZ, Window::Dimension(0, 0, 0));
+
+ Window in_slice = window.first_slice_window_3D();
+ Window sum_slice = window_sum.first_slice_window_3D();
+
+ do
+ {
+ Iterator input_it(in, in_slice);
+ Iterator sum_it(sum, sum_slice);
+ Iterator output_it(out, in_slice);
+
+ auto eps = wrapper::vdup_n(static_cast<T>(epsilon), ExactTagType{});
+
+ execute_window_loop(in_slice, [&](const Coordinates & id)
+ {
+ const auto in_ptr = reinterpret_cast<const T *>(input_it.ptr());
+ const auto sum_ptr = reinterpret_cast<const T *>(sum_it.ptr());
+ const auto out_ptr = reinterpret_cast<T *>(output_it.ptr());
+
+ const auto vec_normalize_value = wrapper::vinvsqrt(wrapper::vmax(wrapper::vloadq(sum_ptr), eps));
+ wrapper::vstore(out_ptr, wrapper::vmul(wrapper::vloadq(in_ptr), vec_normalize_value));
+ },
+ input_it, sum_it, output_it);
+ }
+ while(window.slide_window_slice_3D(in_slice) && window.slide_window_slice_3D(sum_slice));
+}
+
Status validate_arguments(const ITensorInfo *input, const ITensorInfo *sum, const ITensorInfo *output, unsigned int axis, float epsilon)
{
ARM_COMPUTE_UNUSED(epsilon);
ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, sum, output);
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, sum);
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F32);
- ARM_COMPUTE_RETURN_ERROR_ON(input->data_layout() != DataLayout::NCHW);
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(axis > 0, "Unsupported normalization axis, Supported axis is 0");
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F16, DataType::F32);
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(axis > 2, "Axis greater than 2 is not supported");
ARM_COMPUTE_RETURN_ERROR_ON_MSG(axis >= TensorShape::num_max_dimensions, "Normalization axis greater than max number of dimensions");
// Reduce shape on axis
@@ -89,7 +161,7 @@ Status validate_arguments(const ITensorInfo *input, const ITensorInfo *sum, cons
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input, output);
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DIMENSIONS(input->tensor_shape(), output->tensor_shape());
- ARM_COMPUTE_RETURN_ERROR_ON(output->data_layout() != DataLayout::NCHW);
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_LAYOUT(input, output);
}
return Status{};
@@ -158,9 +230,52 @@ void NEL2NormalizeLayerKernel::run(const Window &window, const ThreadInfo &info)
switch(_axis)
{
case 0:
- l2_normalize_X(_input, _sum, _output, _epsilon, window);
+ switch(_input->info()->data_type())
+ {
+ case DataType::F32:
+ l2_normalize_X<float, 4>(_input, _sum, _output, _epsilon, window);
+ break;
+#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+ case DataType::F16:
+ l2_normalize_X<float16_t, 8>(_input, _sum, _output, _epsilon, window);
+ break;
+#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+ default:
+ ARM_COMPUTE_ERROR("Not implemented");
+ }
+ break;
+ case 1:
+ switch(_input->info()->data_type())
+ {
+ case DataType::F32:
+ l2_normalize_Y<float, 4>(_input, _sum, _output, _epsilon, window);
+ break;
+#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+ case DataType::F16:
+ l2_normalize_Y<float16_t, 8>(_input, _sum, _output, _epsilon, window);
+#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+ break;
+ default:
+ ARM_COMPUTE_ERROR("Not implemented");
+ }
+ break;
+ case 2:
+ switch(_input->info()->data_type())
+ {
+ case DataType::F32:
+ l2_normalize_Z<float, 4>(_input, _sum, _output, _epsilon, window);
+ break;
+#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+ case DataType::F16:
+ l2_normalize_Z<float16_t, 8>(_input, _sum, _output, _epsilon, window);
+ break;
+#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+ default:
+ ARM_COMPUTE_ERROR("Not implemented");
+ }
break;
default:
ARM_COMPUTE_ERROR("Unsupported normalization axis");
}
}
+} // namespace arm_compute