diff options
Diffstat (limited to 'src/cpu/kernels/elementwise_unary/generic/sve/impl.cpp')
-rw-r--r-- | src/cpu/kernels/elementwise_unary/generic/sve/impl.cpp | 44 |
1 files changed, 24 insertions, 20 deletions
diff --git a/src/cpu/kernels/elementwise_unary/generic/sve/impl.cpp b/src/cpu/kernels/elementwise_unary/generic/sve/impl.cpp index a948862906..5af534d9e7 100644 --- a/src/cpu/kernels/elementwise_unary/generic/sve/impl.cpp +++ b/src/cpu/kernels/elementwise_unary/generic/sve/impl.cpp @@ -24,6 +24,7 @@ #include "arm_compute/core/Helpers.h" #include "arm_compute/core/utils/misc/Traits.h" + #include "src/core/NEON/wrapper/intrinsics/intrinsics.h" namespace arm_compute @@ -31,9 +32,10 @@ namespace arm_compute namespace cpu { template <typename ScalarType, typename VectorType> -inline typename std::enable_if<utils::traits::is_floating_point<ScalarType>::value, VectorType>::type elementwise_op_sve_imp(svbool_t pg, ElementWiseUnary op, const VectorType &a) +inline typename std::enable_if<utils::traits::is_floating_point<ScalarType>::value, VectorType>::type +elementwise_op_sve_imp(svbool_t pg, ElementWiseUnary op, const VectorType &a) { - switch(op) + switch (op) { case ElementWiseUnary::RSQRT: return svinvsqrt(pg, a); @@ -55,9 +57,10 @@ inline typename std::enable_if<utils::traits::is_floating_point<ScalarType>::val } template <typename ScalarType, typename VectorType> -inline typename std::enable_if<std::is_integral<ScalarType>::value, VectorType>::type elementwise_op_sve_imp(svbool_t pg, ElementWiseUnary op, const VectorType &a) +inline typename std::enable_if<std::is_integral<ScalarType>::value, VectorType>::type +elementwise_op_sve_imp(svbool_t pg, ElementWiseUnary op, const VectorType &a) { - switch(op) + switch (op) { case ElementWiseUnary::NEG: return svneg_z(pg, a); @@ -81,23 +84,24 @@ void elementwise_sve_op(const ITensor *in, ITensor *out, const Window &window, E Iterator input(in, win); Iterator output(out, win); - execute_window_loop(win, [&](const Coordinates &) - { - auto output_ptr = reinterpret_cast<ScalarType *>(output.ptr()); - const auto input_ptr = reinterpret_cast<const ScalarType *>(input.ptr()); - int x = window_start_x; - - svbool_t pg = wrapper::svwhilelt<ScalarType>(x, window_end_x); - do + execute_window_loop( + win, + [&](const Coordinates &) { - const auto vin = svld1(pg, input_ptr + x); - svst1(pg, output_ptr + x, elementwise_op_sve_imp<ScalarType, decltype(vin)>(pg, op, vin)); - x += wrapper::svcnt<ScalarType>(); - pg = wrapper::svwhilelt<ScalarType>(x, window_end_x); - } - while(svptest_any(all_true_pg, pg)); - }, - input, output); + auto output_ptr = reinterpret_cast<ScalarType *>(output.ptr()); + const auto input_ptr = reinterpret_cast<const ScalarType *>(input.ptr()); + int x = window_start_x; + + svbool_t pg = wrapper::svwhilelt<ScalarType>(x, window_end_x); + do + { + const auto vin = svld1(pg, input_ptr + x); + svst1(pg, output_ptr + x, elementwise_op_sve_imp<ScalarType, decltype(vin)>(pg, op, vin)); + x += wrapper::svcnt<ScalarType>(); + pg = wrapper::svwhilelt<ScalarType>(x, window_end_x); + } while (svptest_any(all_true_pg, pg)); + }, + input, output); } template void elementwise_sve_op<float16_t>(const ITensor *in, ITensor *out, const Window &window, ElementWiseUnary op); |