aboutsummaryrefslogtreecommitdiff
path: root/src/cpu/kernels/elementwise_binary/generic/sve/impl.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/cpu/kernels/elementwise_binary/generic/sve/impl.cpp')
-rw-r--r--src/cpu/kernels/elementwise_binary/generic/sve/impl.cpp250
1 files changed, 139 insertions, 111 deletions
diff --git a/src/cpu/kernels/elementwise_binary/generic/sve/impl.cpp b/src/cpu/kernels/elementwise_binary/generic/sve/impl.cpp
index c0515f2abc..fa48407e9b 100644
--- a/src/cpu/kernels/elementwise_binary/generic/sve/impl.cpp
+++ b/src/cpu/kernels/elementwise_binary/generic/sve/impl.cpp
@@ -23,7 +23,9 @@
*/
#include "src/cpu/kernels/elementwise_binary/generic/sve/impl.h"
+
#include "src/core/NEON/SVEMath.h"
+
#include <arm_sve.h>
namespace arm_compute
@@ -33,7 +35,8 @@ namespace cpu
using namespace arm_compute::wrapper;
template <typename ScalarType>
-void elementwise_arithmetic_op(const ITensor *in1, const ITensor *in2, ITensor *out, ArithmeticOperation op, const Window &window)
+void elementwise_arithmetic_op(
+ const ITensor *in1, const ITensor *in2, ITensor *out, ArithmeticOperation op, const Window &window)
{
using VectorType = typename sve_vector<ScalarType>::type;
@@ -51,7 +54,7 @@ void elementwise_arithmetic_op(const ITensor *in1, const ITensor *in2, ITensor *
const auto window_end_x = static_cast<int>(window.x().end());
const bool is_broadcast_across_x = in1->info()->tensor_shape().x() != in2->info()->tensor_shape().x();
- if(is_broadcast_across_x)
+ if (is_broadcast_across_x)
{
const bool is_broadcast_input_2 = input2_win.x().step() == 0;
Window broadcast_win = is_broadcast_input_2 ? input2_win : input1_win;
@@ -66,37 +69,40 @@ void elementwise_arithmetic_op(const ITensor *in1, const ITensor *in2, ITensor *
Iterator non_broadcast_input(non_broadcast_tensor, non_broadcast_win);
Iterator output(out, win);
- execute_window_loop(win, [&](const Coordinates &)
- {
- auto output_ptr = reinterpret_cast<ScalarType *>(output.ptr());
- const auto non_broadcast_input_ptr = reinterpret_cast<const ScalarType *>(non_broadcast_input.ptr());
- const ScalarType broadcast_value = *reinterpret_cast<const ScalarType *>(broadcast_input.ptr());
- const auto broadcast_vector = svdup_n(broadcast_value);
-
- int x = window_start_x;
-
- svbool_t pg = svwhilelt<ScalarType>(x, window_end_x);
- do
+ execute_window_loop(
+ win,
+ [&](const Coordinates &)
{
- const auto non_broadcast_vector = svld1(pg, non_broadcast_input_ptr + x);
- VectorType res{};
+ auto output_ptr = reinterpret_cast<ScalarType *>(output.ptr());
+ const auto non_broadcast_input_ptr = reinterpret_cast<const ScalarType *>(non_broadcast_input.ptr());
+ const ScalarType broadcast_value = *reinterpret_cast<const ScalarType *>(broadcast_input.ptr());
+ const auto broadcast_vector = svdup_n(broadcast_value);
- if(is_broadcast_input_2)
- {
- res = elementwise_arithmetic_op<typename sve_vector<ScalarType>::type>(pg, non_broadcast_vector, broadcast_vector, op);
- }
- else
+ int x = window_start_x;
+
+ svbool_t pg = svwhilelt<ScalarType>(x, window_end_x);
+ do
{
- res = elementwise_arithmetic_op<typename sve_vector<ScalarType>::type>(pg, broadcast_vector, non_broadcast_vector, op);
- }
- svst1(pg, output_ptr + x, res);
-
- x += svcnt<ScalarType>();
- pg = svwhilelt<ScalarType>(x, window_end_x);
- }
- while(svptest_any(all_true_pg, pg));
- },
- broadcast_input, non_broadcast_input, output);
+ const auto non_broadcast_vector = svld1(pg, non_broadcast_input_ptr + x);
+ VectorType res{};
+
+ if (is_broadcast_input_2)
+ {
+ res = elementwise_arithmetic_op<typename sve_vector<ScalarType>::type>(pg, non_broadcast_vector,
+ broadcast_vector, op);
+ }
+ else
+ {
+ res = elementwise_arithmetic_op<typename sve_vector<ScalarType>::type>(
+ pg, broadcast_vector, non_broadcast_vector, op);
+ }
+ svst1(pg, output_ptr + x, res);
+
+ x += svcnt<ScalarType>();
+ pg = svwhilelt<ScalarType>(x, window_end_x);
+ } while (svptest_any(all_true_pg, pg));
+ },
+ broadcast_input, non_broadcast_input, output);
}
else
{
@@ -108,39 +114,46 @@ void elementwise_arithmetic_op(const ITensor *in1, const ITensor *in2, ITensor *
Iterator input2(in2, input2_win);
Iterator output(out, win);
- execute_window_loop(win, [&](const Coordinates &)
- {
- auto output_ptr = reinterpret_cast<ScalarType *>(output.ptr());
- const auto input1_ptr = reinterpret_cast<const ScalarType *>(input1.ptr());
- const auto input2_ptr = reinterpret_cast<const ScalarType *>(input2.ptr());
+ execute_window_loop(
+ win,
+ [&](const Coordinates &)
+ {
+ auto output_ptr = reinterpret_cast<ScalarType *>(output.ptr());
+ const auto input1_ptr = reinterpret_cast<const ScalarType *>(input1.ptr());
+ const auto input2_ptr = reinterpret_cast<const ScalarType *>(input2.ptr());
- int x = window_start_x;
+ int x = window_start_x;
- svbool_t pg = svwhilelt<ScalarType>(x, window_end_x);
- do
- {
- const auto in1 = svld1(pg, input1_ptr + x);
- const auto in2 = svld1(pg, input2_ptr + x);
- const auto res = elementwise_arithmetic_op<typename sve_vector<ScalarType>::type>(pg, in1, in2, op);
- svst1(pg, output_ptr + x, res);
-
- x += svcnt<ScalarType>();
- pg = svwhilelt<ScalarType>(x, window_end_x);
- }
- while(svptest_any(all_true_pg, pg));
- },
- input1, input2, output);
+ svbool_t pg = svwhilelt<ScalarType>(x, window_end_x);
+ do
+ {
+ const auto in1 = svld1(pg, input1_ptr + x);
+ const auto in2 = svld1(pg, input2_ptr + x);
+ const auto res = elementwise_arithmetic_op<typename sve_vector<ScalarType>::type>(pg, in1, in2, op);
+ svst1(pg, output_ptr + x, res);
+
+ x += svcnt<ScalarType>();
+ pg = svwhilelt<ScalarType>(x, window_end_x);
+ } while (svptest_any(all_true_pg, pg));
+ },
+ input1, input2, output);
}
}
-template void elementwise_arithmetic_op<float32_t>(const ITensor *in1, const ITensor *in2, ITensor *out, const ArithmeticOperation op, const Window &window);
-template void elementwise_arithmetic_op<float16_t>(const ITensor *in1, const ITensor *in2, ITensor *out, const ArithmeticOperation op, const Window &window);
-template void elementwise_arithmetic_op<int16_t>(const ITensor *in1, const ITensor *in2, ITensor *out, const ArithmeticOperation op, const Window &window);
-template void elementwise_arithmetic_op<int32_t>(const ITensor *in1, const ITensor *in2, ITensor *out, const ArithmeticOperation op, const Window &window);
+template void elementwise_arithmetic_op<float32_t>(
+ const ITensor *in1, const ITensor *in2, ITensor *out, const ArithmeticOperation op, const Window &window);
+template void elementwise_arithmetic_op<float16_t>(
+ const ITensor *in1, const ITensor *in2, ITensor *out, const ArithmeticOperation op, const Window &window);
+template void elementwise_arithmetic_op<int16_t>(
+ const ITensor *in1, const ITensor *in2, ITensor *out, const ArithmeticOperation op, const Window &window);
+template void elementwise_arithmetic_op<int32_t>(
+ const ITensor *in1, const ITensor *in2, ITensor *out, const ArithmeticOperation op, const Window &window);
template <typename InputScalarType, typename OutputScalarType>
-void elementwise_comparison_op(const ITensor *in1, const ITensor *in2, ITensor *out, ComparisonOperation op, const Window &window)
+void elementwise_comparison_op(
+ const ITensor *in1, const ITensor *in2, ITensor *out, ComparisonOperation op, const Window &window)
{
- static_assert(sizeof(InputScalarType) >= sizeof(OutputScalarType), "input data type's width should be equal to or greater than output data type's width");
+ static_assert(sizeof(InputScalarType) >= sizeof(OutputScalarType),
+ "input data type's width should be equal to or greater than output data type's width");
using OutputVectorType = typename sve_vector<OutputScalarType>::type;
const auto all_true_pg = svptrue<InputScalarType>();
@@ -157,7 +170,7 @@ void elementwise_comparison_op(const ITensor *in1, const ITensor *in2, ITensor *
const auto window_end_x = static_cast<int>(window.x().end());
const bool is_broadcast_across_x = in1->info()->tensor_shape().x() != in2->info()->tensor_shape().x();
- if(is_broadcast_across_x)
+ if (is_broadcast_across_x)
{
const bool is_broadcast_input_2 = input2_win.x().step() == 0;
Window broadcast_win = is_broadcast_input_2 ? input2_win : input1_win;
@@ -172,37 +185,44 @@ void elementwise_comparison_op(const ITensor *in1, const ITensor *in2, ITensor *
Iterator non_broadcast_input(non_broadcast_tensor, non_broadcast_win);
Iterator output(out, win);
- execute_window_loop(win, [&](const Coordinates &)
- {
- auto output_ptr = reinterpret_cast<OutputScalarType *>(output.ptr());
- const auto non_broadcast_input_ptr = reinterpret_cast<const InputScalarType *>(non_broadcast_input.ptr());
- const InputScalarType broadcast_value = *reinterpret_cast<const InputScalarType *>(broadcast_input.ptr());
- const auto broadcast_vector = svdup_n(broadcast_value);
+ execute_window_loop(
+ win,
+ [&](const Coordinates &)
+ {
+ auto output_ptr = reinterpret_cast<OutputScalarType *>(output.ptr());
+ const auto non_broadcast_input_ptr =
+ reinterpret_cast<const InputScalarType *>(non_broadcast_input.ptr());
+ const InputScalarType broadcast_value =
+ *reinterpret_cast<const InputScalarType *>(broadcast_input.ptr());
+ const auto broadcast_vector = svdup_n(broadcast_value);
- int x = window_start_x;
+ int x = window_start_x;
- svbool_t pg = svwhilelt<InputScalarType>(x, window_end_x);
- do
- {
- const auto non_broadcast_vector = svld1(pg, non_broadcast_input_ptr + x);
- const svbool_t output_pg = narrow_to_byte_predicate<sizeof(InputScalarType)>(pg);
- OutputVectorType res{};
- if(is_broadcast_input_2)
- {
- res = elementwise_comparison_op<typename sve_vector<InputScalarType>::type, typename sve_vector<OutputScalarType>::type>(pg, non_broadcast_vector, broadcast_vector, op);
- }
- else
+ svbool_t pg = svwhilelt<InputScalarType>(x, window_end_x);
+ do
{
- res = elementwise_comparison_op<typename sve_vector<InputScalarType>::type, typename sve_vector<OutputScalarType>::type>(pg, broadcast_vector, non_broadcast_vector, op);
- }
- svst1(output_pg, output_ptr + x, res);
-
- x += svcnt<InputScalarType>();
- pg = svwhilelt<InputScalarType>(x, window_end_x);
- }
- while(svptest_any(all_true_pg, pg));
- },
- broadcast_input, non_broadcast_input, output);
+ const auto non_broadcast_vector = svld1(pg, non_broadcast_input_ptr + x);
+ const svbool_t output_pg = narrow_to_byte_predicate<sizeof(InputScalarType)>(pg);
+ OutputVectorType res{};
+ if (is_broadcast_input_2)
+ {
+ res = elementwise_comparison_op<typename sve_vector<InputScalarType>::type,
+ typename sve_vector<OutputScalarType>::type>(
+ pg, non_broadcast_vector, broadcast_vector, op);
+ }
+ else
+ {
+ res = elementwise_comparison_op<typename sve_vector<InputScalarType>::type,
+ typename sve_vector<OutputScalarType>::type>(
+ pg, broadcast_vector, non_broadcast_vector, op);
+ }
+ svst1(output_pg, output_ptr + x, res);
+
+ x += svcnt<InputScalarType>();
+ pg = svwhilelt<InputScalarType>(x, window_end_x);
+ } while (svptest_any(all_true_pg, pg));
+ },
+ broadcast_input, non_broadcast_input, output);
}
else
{
@@ -214,37 +234,45 @@ void elementwise_comparison_op(const ITensor *in1, const ITensor *in2, ITensor *
Iterator input2(in2, input2_win);
Iterator output(out, win);
- execute_window_loop(win, [&](const Coordinates &)
- {
- auto output_ptr = reinterpret_cast<OutputScalarType *>(output.ptr());
- const auto input1_ptr = reinterpret_cast<const InputScalarType *>(input1.ptr());
- const auto input2_ptr = reinterpret_cast<const InputScalarType *>(input2.ptr());
+ execute_window_loop(
+ win,
+ [&](const Coordinates &)
+ {
+ auto output_ptr = reinterpret_cast<OutputScalarType *>(output.ptr());
+ const auto input1_ptr = reinterpret_cast<const InputScalarType *>(input1.ptr());
+ const auto input2_ptr = reinterpret_cast<const InputScalarType *>(input2.ptr());
- int x = window_start_x;
+ int x = window_start_x;
- svbool_t pg = svwhilelt<InputScalarType>(x, window_end_x);
- do
- {
- const auto in1 = svld1(pg, input1_ptr + x);
- const auto in2 = svld1(pg, input2_ptr + x);
- const auto res = elementwise_comparison_op<typename sve_vector<InputScalarType>::type, typename sve_vector<OutputScalarType>::type>(pg, in1, in2, op);
- const svbool_t output_pg = narrow_to_byte_predicate<sizeof(InputScalarType)>(pg);
- svst1(output_pg, output_ptr + x, res);
-
- x += svcnt<InputScalarType>();
- pg = svwhilelt<InputScalarType>(x, window_end_x);
- }
- while(svptest_any(all_true_pg, pg));
- },
- input1, input2, output);
+ svbool_t pg = svwhilelt<InputScalarType>(x, window_end_x);
+ do
+ {
+ const auto in1 = svld1(pg, input1_ptr + x);
+ const auto in2 = svld1(pg, input2_ptr + x);
+ const auto res =
+ elementwise_comparison_op<typename sve_vector<InputScalarType>::type,
+ typename sve_vector<OutputScalarType>::type>(pg, in1, in2, op);
+ const svbool_t output_pg = narrow_to_byte_predicate<sizeof(InputScalarType)>(pg);
+ svst1(output_pg, output_ptr + x, res);
+
+ x += svcnt<InputScalarType>();
+ pg = svwhilelt<InputScalarType>(x, window_end_x);
+ } while (svptest_any(all_true_pg, pg));
+ },
+ input1, input2, output);
}
}
-template void elementwise_comparison_op<float32_t>(const ITensor *in1, const ITensor *in2, ITensor *out, const ComparisonOperation op, const Window &window);
-template void elementwise_comparison_op<float16_t>(const ITensor *in1, const ITensor *in2, ITensor *out, const ComparisonOperation op, const Window &window);
-template void elementwise_comparison_op<uint8_t>(const ITensor *in1, const ITensor *in2, ITensor *out, const ComparisonOperation op, const Window &window);
-template void elementwise_comparison_op<int16_t>(const ITensor *in1, const ITensor *in2, ITensor *out, const ComparisonOperation op, const Window &window);
-template void elementwise_comparison_op<int32_t>(const ITensor *in1, const ITensor *in2, ITensor *out, const ComparisonOperation op, const Window &window);
+template void elementwise_comparison_op<float32_t>(
+ const ITensor *in1, const ITensor *in2, ITensor *out, const ComparisonOperation op, const Window &window);
+template void elementwise_comparison_op<float16_t>(
+ const ITensor *in1, const ITensor *in2, ITensor *out, const ComparisonOperation op, const Window &window);
+template void elementwise_comparison_op<uint8_t>(
+ const ITensor *in1, const ITensor *in2, ITensor *out, const ComparisonOperation op, const Window &window);
+template void elementwise_comparison_op<int16_t>(
+ const ITensor *in1, const ITensor *in2, ITensor *out, const ComparisonOperation op, const Window &window);
+template void elementwise_comparison_op<int32_t>(
+ const ITensor *in1, const ITensor *in2, ITensor *out, const ComparisonOperation op, const Window &window);
template <>
svint32_t elementwise_pow<svint32_t>(svbool_t &pg, const svint32_t &a, const svint32_t &b)