diff options
Diffstat (limited to 'src/cpu/kernels/l2normlayer/generic/neon/impl.h')
-rw-r--r-- | src/cpu/kernels/l2normlayer/generic/neon/impl.h | 96 |
1 files changed, 51 insertions, 45 deletions
diff --git a/src/cpu/kernels/l2normlayer/generic/neon/impl.h b/src/cpu/kernels/l2normlayer/generic/neon/impl.h index a06cdd33d3..6bd19299b7 100644 --- a/src/cpu/kernels/l2normlayer/generic/neon/impl.h +++ b/src/cpu/kernels/l2normlayer/generic/neon/impl.h @@ -26,8 +26,9 @@ #include "arm_compute/core/Helpers.h" #include "arm_compute/core/TensorInfo.h" -#include "src/core/NEON/wrapper/wrapper.h" + #include "src/core/common/Registrars.h" +#include "src/core/NEON/wrapper/wrapper.h" #include <cstddef> @@ -51,33 +52,36 @@ void l2_normalize_x(const ITensor *in, const ITensor *sum, ITensor *out, float e Iterator sum_it(sum, win_collapsed); Iterator output_it(out, win_collapsed); - execute_window_loop(win_collapsed, [&](const Coordinates &) - { - const auto in_ptr = reinterpret_cast<const T *>(input_it.ptr()); - const auto out_ptr = reinterpret_cast<T *>(output_it.ptr()); - - const T sum_value = *reinterpret_cast<const T *>(sum_it.ptr()); - const T norm_value = static_cast<T>(1.f) / std::sqrt(std::max(sum_value, static_cast<T>(epsilon))); - const auto vec_norm_value = wrapper::vdup_n(norm_value, ExactTagType{}); - - // Compute elements over vector steps - int x = window_start_x; - for(; x <= (window_end_x - window_step_x); x += window_step_x) + execute_window_loop( + win_collapsed, + [&](const Coordinates &) { - wrapper::vstore(out_ptr + x, wrapper::vmul(wrapper::vloadq(in_ptr + x), vec_norm_value)); - } - - // Compute left-over elements - for(; x < window_end_x; ++x) - { - out_ptr[x] = in_ptr[x] * norm_value; - } - }, - input_it, sum_it, output_it); + const auto in_ptr = reinterpret_cast<const T *>(input_it.ptr()); + const auto out_ptr = reinterpret_cast<T *>(output_it.ptr()); + + const T sum_value = *reinterpret_cast<const T *>(sum_it.ptr()); + const T norm_value = static_cast<T>(1.f) / std::sqrt(std::max(sum_value, static_cast<T>(epsilon))); + const auto vec_norm_value = wrapper::vdup_n(norm_value, ExactTagType{}); + + // Compute elements over vector steps + int x = window_start_x; + for (; x <= (window_end_x - window_step_x); x += window_step_x) + { + wrapper::vstore(out_ptr + x, wrapper::vmul(wrapper::vloadq(in_ptr + x), vec_norm_value)); + } + + // Compute left-over elements + for (; x < window_end_x; ++x) + { + out_ptr[x] = in_ptr[x] * norm_value; + } + }, + input_it, sum_it, output_it); } template <typename T, int S> -void l2_normalize_yz(const ITensor *in, const ITensor *sum, ITensor *out, float epsilon, const Window &window, size_t axis) +void l2_normalize_yz( + const ITensor *in, const ITensor *sum, ITensor *out, float epsilon, const Window &window, size_t axis) { using ExactTagType = typename wrapper::traits::neon_vector<T, S>::tag_type; @@ -97,28 +101,30 @@ void l2_normalize_yz(const ITensor *in, const ITensor *sum, ITensor *out, float const auto vec_eps = wrapper::vdup_n(static_cast<T>(epsilon), ExactTagType{}); - execute_window_loop(win, [&](const Coordinates &) - { - const auto in_ptr = reinterpret_cast<const T *>(input_it.ptr()); - const auto sum_ptr = reinterpret_cast<const T *>(sum_it.ptr()); - const auto out_ptr = reinterpret_cast<T *>(output_it.ptr()); - - // Compute elements over vector steps - int x = window_start_x; - for(; x <= (window_end_x - window_step_x); x += window_step_x) - { - const auto vec_norm_value = wrapper::vinvsqrt(wrapper::vmax(wrapper::vloadq(sum_ptr + x), vec_eps)); - wrapper::vstore(out_ptr + x, wrapper::vmul(wrapper::vloadq(in_ptr + x), vec_norm_value)); - } - - // Compute left-over elements - for(; x < window_end_x; ++x) + execute_window_loop( + win, + [&](const Coordinates &) { - const T norm_value = static_cast<T>(1.f) / std::sqrt(std::max(sum_ptr[x], static_cast<T>(epsilon))); - out_ptr[x] = in_ptr[x] * norm_value; - } - }, - input_it, sum_it, output_it); + const auto in_ptr = reinterpret_cast<const T *>(input_it.ptr()); + const auto sum_ptr = reinterpret_cast<const T *>(sum_it.ptr()); + const auto out_ptr = reinterpret_cast<T *>(output_it.ptr()); + + // Compute elements over vector steps + int x = window_start_x; + for (; x <= (window_end_x - window_step_x); x += window_step_x) + { + const auto vec_norm_value = wrapper::vinvsqrt(wrapper::vmax(wrapper::vloadq(sum_ptr + x), vec_eps)); + wrapper::vstore(out_ptr + x, wrapper::vmul(wrapper::vloadq(in_ptr + x), vec_norm_value)); + } + + // Compute left-over elements + for (; x < window_end_x; ++x) + { + const T norm_value = static_cast<T>(1.f) / std::sqrt(std::max(sum_ptr[x], static_cast<T>(epsilon))); + out_ptr[x] = in_ptr[x] * norm_value; + } + }, + input_it, sum_it, output_it); } } // namespace cpu } // namespace arm_compute |