about | summary | refs | log | tree | commit | diff
path: root/src/cpu/kernels/select/generic/neon/impl.h
diff options
context:
space:
mode:
Diffstat (limited to 'src/cpu/kernels/select/generic/neon/impl.h')
-rw-r--r-- src/cpu/kernels/select/generic/neon/impl.h 111
1 files changed, 66 insertions, 45 deletions
diff --git a/src/cpu/kernels/select/generic/neon/impl.h b/src/cpu/kernels/select/generic/neon/impl.h
index 6a6d9969f8..7ce640b6ff 100644
--- a/src/cpu/kernels/select/generic/neon/impl.h
+++ b/src/cpu/kernels/select/generic/neon/impl.h
@@ -25,6 +25,7 @@
#define ACL_SRC_CPU_KERNELS_SELECT_GENERIC_NEON_IMPL_H
#include "arm_compute/core/TensorInfo.h"
+
#include "src/core/NEON/NEAsymm.h"
#include "src/cpu/kernels/select/generic/neon/impl.h"
@@ -37,8 +38,16 @@ namespace arm_compute
namespace cpu
{
template <typename ScalarType, typename VectorType>
-void select_op(const ITensor *cond, const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window,
- const int window_step_x, const int window_start_x, const int window_end_x, const int limit, VectorType (*condition_conversion)(const uint8_t *))
+void select_op(const ITensor *cond,
+ const ITensor *in1,
+ const ITensor *in2,
+ ITensor *out,
+ const Window &window,
+ const int window_step_x,
+ const int window_start_x,
+ const int window_end_x,
+ const int limit,
+ VectorType (*condition_conversion)(const uint8_t *))
{
Window win = window;
win.set(Window::DimX, Window::Dimension(0, 1, 1));
@@ -48,30 +57,32 @@ void select_op(const ITensor *cond, const ITensor *in1, const ITensor *in2, ITen
Iterator input2(in2, win);
Iterator output(out, win);
- execute_window_loop(win, [&](const Coordinates &)
- {
- auto output_ptr = reinterpret_cast<ScalarType *>(output.ptr());
- const auto condition_ptr = reinterpret_cast<const uint8_t *>(condition.ptr());
- const auto input1_ptr = reinterpret_cast<const ScalarType *>(input1.ptr());
- const auto input2_ptr = reinterpret_cast<const ScalarType *>(input2.ptr());
-
- int x = window_start_x;
- for(; x <= limit; x += window_step_x)
+ execute_window_loop(
+ win,
+ [&](const Coordinates &)
{
- const auto c = (*condition_conversion)(condition_ptr + x);
- const auto a = wrapper::vloadq(input1_ptr + x);
- const auto b = wrapper::vloadq(input2_ptr + x);
- wrapper::vstore(output_ptr + x, wrapper::vbsl(c, a, b));
- }
- for(; x < window_end_x; ++x)
- {
- const auto c = *(condition_ptr + x);
- const auto a = *(input1_ptr + x);
- const auto b = *(input2_ptr + x);
- *(output_ptr + x) = static_cast<bool>(c) ? a : b;
- }
- },
- condition, input1, input2, output);
+ auto output_ptr = reinterpret_cast<ScalarType *>(output.ptr());
+ const auto condition_ptr = reinterpret_cast<const uint8_t *>(condition.ptr());
+ const auto input1_ptr = reinterpret_cast<const ScalarType *>(input1.ptr());
+ const auto input2_ptr = reinterpret_cast<const ScalarType *>(input2.ptr());
+
+ int x = window_start_x;
+ for (; x <= limit; x += window_step_x)
+ {
+ const auto c = (*condition_conversion)(condition_ptr + x);
+ const auto a = wrapper::vloadq(input1_ptr + x);
+ const auto b = wrapper::vloadq(input2_ptr + x);
+ wrapper::vstore(output_ptr + x, wrapper::vbsl(c, a, b));
+ }
+ for (; x < window_end_x; ++x)
+ {
+ const auto c = *(condition_ptr + x);
+ const auto a = *(input1_ptr + x);
+ const auto b = *(input2_ptr + x);
+ *(output_ptr + x) = static_cast<bool>(c) ? a : b;
+ }
+ },
+ condition, input1, input2, output);
}
template <typename ScalarType, typename VectorType>
@@ -81,11 +92,14 @@ void select_op_8(const ITensor *cond, const ITensor *in1, const ITensor *in2, IT
const auto window_start_x = static_cast<int>(window.x().start());
const auto window_end_x = static_cast<int>(window.x().end());
- select_op<ScalarType, VectorType>(cond, in1, in2, out, window, window_step_x, window_start_x, window_end_x, window_end_x - window_step_x, [](const uint8_t *condition_ptr) -> VectorType
- {
- static const auto zero = wrapper::vdup_n(static_cast<uint8_t>(0), arm_compute::wrapper::traits::vector_128_tag());
- return wrapper::vcgt(wrapper::vloadq(condition_ptr), zero);
- });
+ select_op<ScalarType, VectorType>(
+ cond, in1, in2, out, window, window_step_x, window_start_x, window_end_x, window_end_x - window_step_x,
+ [](const uint8_t *condition_ptr) -> VectorType
+ {
+ static const auto zero =
+ wrapper::vdup_n(static_cast<uint8_t>(0), arm_compute::wrapper::traits::vector_128_tag());
+ return wrapper::vcgt(wrapper::vloadq(condition_ptr), zero);
+ });
}
template <typename ScalarType, typename VectorType>
@@ -95,11 +109,14 @@ void select_op_16(const ITensor *cond, const ITensor *in1, const ITensor *in2, I
const auto window_start_x = static_cast<int>(window.x().start());
const auto window_end_x = static_cast<int>(window.x().end());
- select_op<ScalarType, VectorType>(cond, in1, in2, out, window, window_step_x, window_start_x, window_end_x, window_end_x - window_step_x, [](const uint8_t *condition_ptr) -> VectorType
- {
- static const auto zero = wrapper::vdup_n(static_cast<uint16_t>(0), arm_compute::wrapper::traits::vector_128_tag());
- return wrapper::vcgt(wrapper::vmovl(wrapper::vload(condition_ptr)), zero);
- });
+ select_op<ScalarType, VectorType>(
+ cond, in1, in2, out, window, window_step_x, window_start_x, window_end_x, window_end_x - window_step_x,
+ [](const uint8_t *condition_ptr) -> VectorType
+ {
+ static const auto zero =
+ wrapper::vdup_n(static_cast<uint16_t>(0), arm_compute::wrapper::traits::vector_128_tag());
+ return wrapper::vcgt(wrapper::vmovl(wrapper::vload(condition_ptr)), zero);
+ });
}
template <typename ScalarType, typename VectorType>
@@ -109,15 +126,19 @@ void select_op_32(const ITensor *cond, const ITensor *in1, const ITensor *in2, I
const auto window_start_x = static_cast<int>(window.x().start());
const auto window_end_x = static_cast<int>(window.x().end());
- select_op<ScalarType, VectorType>(cond, in1, in2, out, window, window_step_x, window_start_x, window_end_x, window_end_x - window_step_x, [](const uint8_t *condition_ptr) -> VectorType
- {
- static const auto zero = wrapper::vdup_n(static_cast<uint32_t>(0), arm_compute::wrapper::traits::vector_128_tag());
- return wrapper::vcgt(wrapper::vmovl(wrapper::vgetlow(wrapper::vmovl(wrapper::vload(condition_ptr)))), zero);
- });
+ select_op<ScalarType, VectorType>(
+ cond, in1, in2, out, window, window_step_x, window_start_x, window_end_x, window_end_x - window_step_x,
+ [](const uint8_t *condition_ptr) -> VectorType
+ {
+ static const auto zero =
+ wrapper::vdup_n(static_cast<uint32_t>(0), arm_compute::wrapper::traits::vector_128_tag());
+ return wrapper::vcgt(wrapper::vmovl(wrapper::vgetlow(wrapper::vmovl(wrapper::vload(condition_ptr)))), zero);
+ });
}
template <typename ScalarType>
-void select_op_not_same_rank(const ITensor *cond, const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
+void select_op_not_same_rank(
+ const ITensor *cond, const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
{
ARM_COMPUTE_UNUSED(window);
@@ -131,20 +152,20 @@ void select_op_not_same_rank(const ITensor *cond, const ITensor *in1, const ITen
int offset = 0;
const int step = 16 / in1->info()->element_size();
- for(int i = 0; i < outer_size; ++i)
+ for (int i = 0; i < outer_size; ++i)
{
int x = offset;
const auto input_ptr = static_cast<bool>(*(condition_ptr + i)) ? input1_ptr : input2_ptr;
- for(; x <= offset + inner_size - step; x += step)
+ for (; x <= offset + inner_size - step; x += step)
{
wrapper::vstore(output_ptr + x, wrapper::vloadq(input_ptr + x));
}
- if(x <= offset + inner_size - (step / 2))
+ if (x <= offset + inner_size - (step / 2))
{
wrapper::vstore(output_ptr + x, wrapper::vload(input_ptr + x));
x += step / 2;
}
- for(; x < offset + inner_size; ++x)
+ for (; x < offset + inner_size; ++x)
{
*(output_ptr + x) = *(input_ptr + x);
}