diff options
Diffstat (limited to 'src/cpu/kernels/meanstddevnorm/generic/neon/impl.cpp')
-rw-r--r-- | src/cpu/kernels/meanstddevnorm/generic/neon/impl.cpp | 97 |
1 files changed, 50 insertions, 47 deletions
diff --git a/src/cpu/kernels/meanstddevnorm/generic/neon/impl.cpp b/src/cpu/kernels/meanstddevnorm/generic/neon/impl.cpp index 0522d6e277..11f6294a35 100644 --- a/src/cpu/kernels/meanstddevnorm/generic/neon/impl.cpp +++ b/src/cpu/kernels/meanstddevnorm/generic/neon/impl.cpp @@ -23,6 +23,7 @@ */ #include "src/cpu/kernels/meanstddevnorm/generic/neon/impl.h" + #include "src/core/NEON/wrapper/wrapper.h" namespace arm_compute @@ -45,60 +46,62 @@ void mean_stddev_normalization(ITensor *input, ITensor *output, float epsilon, c Iterator input_itr(input, win); Iterator output_itr(output, win); - execute_window_loop(win, [&](const Coordinates &) - { - int x = window_start_x; - auto in_ptr = reinterpret_cast<const ScalarType *>(input_itr.ptr()); - auto out_ptr = reinterpret_cast<ScalarType *>(output_itr.ptr()); + execute_window_loop( + win, + [&](const Coordinates &) + { + int x = window_start_x; + auto in_ptr = reinterpret_cast<const ScalarType *>(input_itr.ptr()); + auto out_ptr = reinterpret_cast<ScalarType *>(output_itr.ptr()); - auto sum_vec = wrapper::vdup_n(static_cast<ScalarType>(0.f), ExactTagType{}); - auto sum_sq_vec = wrapper::vdup_n(static_cast<ScalarType>(0.f), ExactTagType{}); + auto sum_vec = wrapper::vdup_n(static_cast<ScalarType>(0.f), ExactTagType{}); + auto sum_sq_vec = wrapper::vdup_n(static_cast<ScalarType>(0.f), ExactTagType{}); - for(; x <= (window_end_x - window_step_x); x += window_step_x) - { - auto data = wrapper::vloadq(in_ptr + x); - sum_vec = wrapper::vadd(sum_vec, data); - sum_sq_vec = wrapper::vadd(sum_sq_vec, wrapper::vmul(data, data)); - } + for (; x <= (window_end_x - window_step_x); x += window_step_x) + { + auto data = wrapper::vloadq(in_ptr + x); + sum_vec = wrapper::vadd(sum_vec, data); + sum_sq_vec = wrapper::vadd(sum_sq_vec, wrapper::vmul(data, data)); + } - auto sum_carry_res = wrapper::vpadd(wrapper::vgethigh(sum_vec), wrapper::vgetlow(sum_vec)); - auto sum_sq_carry_res = wrapper::vpadd(wrapper::vgethigh(sum_sq_vec), wrapper::vgetlow(sum_sq_vec)); - for(int i = 0; i < size / 4; ++i) - { - sum_carry_res = wrapper::vpadd(sum_carry_res, sum_carry_res); - sum_sq_carry_res = wrapper::vpadd(sum_sq_carry_res, sum_sq_carry_res); - } + auto sum_carry_res = wrapper::vpadd(wrapper::vgethigh(sum_vec), wrapper::vgetlow(sum_vec)); + auto sum_sq_carry_res = wrapper::vpadd(wrapper::vgethigh(sum_sq_vec), wrapper::vgetlow(sum_sq_vec)); + for (int i = 0; i < size / 4; ++i) + { + sum_carry_res = wrapper::vpadd(sum_carry_res, sum_carry_res); + sum_sq_carry_res = wrapper::vpadd(sum_sq_carry_res, sum_sq_carry_res); + } - auto sum = wrapper::vgetlane(sum_carry_res, 0); - auto sum_sq = wrapper::vgetlane(sum_sq_carry_res, 0); + auto sum = wrapper::vgetlane(sum_carry_res, 0); + auto sum_sq = wrapper::vgetlane(sum_sq_carry_res, 0); - // Compute left-over elements - for(; x < window_end_x; ++x) - { - ScalarType data = *(in_ptr + x); - sum += data; - sum_sq += data * data; - } + // Compute left-over elements + for (; x < window_end_x; ++x) + { + ScalarType data = *(in_ptr + x); + sum += data; + sum_sq += data * data; + } - ScalarType mean = sum / input->info()->dimension(0); - ScalarType var = (sum_sq / input->info()->dimension(0)) - (mean * mean); - ScalarType stddev_inv = 1.f / sqrt(var + epsilon); + ScalarType mean = sum / input->info()->dimension(0); + ScalarType var = (sum_sq / input->info()->dimension(0)) - (mean * mean); + ScalarType stddev_inv = 1.f / sqrt(var + epsilon); - auto mean_vec = wrapper::vdup_n(mean, ExactTagType{}); - auto stddev_inv_vec = wrapper::vdup_n(stddev_inv, ExactTagType{}); - for(x = window_start_x; x <= (window_end_x - window_step_x); x += window_step_x) - { - auto data = wrapper::vloadq(in_ptr + x); - auto res = wrapper::vmul(wrapper::vsub(data, mean_vec), stddev_inv_vec); - // Store results - wrapper::vstore(out_ptr + x, res); - } - for(; x < window_end_x; ++x) - { - *(out_ptr + x) = (*(in_ptr + x) - mean) * stddev_inv; - } - }, - input_itr, output_itr); + auto mean_vec = wrapper::vdup_n(mean, ExactTagType{}); + auto stddev_inv_vec = wrapper::vdup_n(stddev_inv, ExactTagType{}); + for (x = window_start_x; x <= (window_end_x - window_step_x); x += window_step_x) + { + auto data = wrapper::vloadq(in_ptr + x); + auto res = wrapper::vmul(wrapper::vsub(data, mean_vec), stddev_inv_vec); + // Store results + wrapper::vstore(out_ptr + x, res); + } + for (; x < window_end_x; ++x) + { + *(out_ptr + x) = (*(in_ptr + x) - mean) * stddev_inv; + } + }, + input_itr, output_itr); } template void mean_stddev_normalization<float, 4>(ITensor *input, ITensor *output, float epsilon, const Window &window); } // namespace cpu |