aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMichalis Spyrou <michalis.spyrou@arm.com>2022-06-28 19:46:42 +0100
committerMichalis Spyrou <michalis.spyrou@arm.com>2022-07-04 16:00:54 +0000
commite417ff1d9fde119a238582a3b1feb914edd95c38 (patch)
tree7fe8c97e277931bd7d40ce1c98d13851daba0939
parent84a0941cf5bdaffc6127d4cae2e949e6e9109e4a (diff)
downloadComputeLibrary-e417ff1d9fde119a238582a3b1feb914edd95c38.tar.gz
Fix build errors on armv8.6 SVE2 with NDK 23 and 24
Extensive use of templates resulted in a compiler crash on NDK 23 and 24. This rework solves the issue and also reduces the library size by 101Kb. Resolves: COMPMID-5384 Change-Id: I9c5c68c5e36f236b0891e44d25478743417fb16d Signed-off-by: Michalis Spyrou <michalis.spyrou@arm.com> Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/7871 Reviewed-by: Gunes Bayir <gunes.bayir@arm.com> Tested-by: Arm Jenkins <bsgcomp@arm.com> Comments-Addressed: Arm Jenkins <bsgcomp@arm.com> Benchmark: Arm Jenkins <bsgcomp@arm.com>
-rw-r--r--src/cpu/kernels/elementwise_binary/generic/sve/fp16.cpp4
-rw-r--r--src/cpu/kernels/elementwise_binary/generic/sve/fp32.cpp4
-rw-r--r--src/cpu/kernels/elementwise_binary/generic/sve/impl.cpp297
-rw-r--r--src/cpu/kernels/elementwise_binary/generic/sve/impl.h8
-rw-r--r--src/cpu/kernels/elementwise_binary/generic/sve/integer.cpp10
-rw-r--r--src/cpu/kernels/elementwise_binary/generic/sve2/impl.h303
-rw-r--r--src/cpu/kernels/elementwise_binary/generic/sve2/qasymm8.cpp4
-rw-r--r--src/cpu/kernels/elementwise_binary/generic/sve2/qasymm8_signed.cpp4
-rw-r--r--src/cpu/kernels/softmax/generic/sve/impl.cpp21
9 files changed, 311 insertions, 344 deletions
diff --git a/src/cpu/kernels/elementwise_binary/generic/sve/fp16.cpp b/src/cpu/kernels/elementwise_binary/generic/sve/fp16.cpp
index 8adacbfe67..85224351df 100644
--- a/src/cpu/kernels/elementwise_binary/generic/sve/fp16.cpp
+++ b/src/cpu/kernels/elementwise_binary/generic/sve/fp16.cpp
@@ -33,7 +33,7 @@ namespace cpu
template <ArithmeticOperation op>
void sve_fp16_elementwise_binary(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
{
- return elementwise_arithmetic_op<op, float16_t>(in1, in2, out, window);
+ return elementwise_arithmetic_op<float16_t>(in1, in2, out, op, window);
}
template void sve_fp16_elementwise_binary<ArithmeticOperation::ADD>(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window);
@@ -48,7 +48,7 @@ template void sve_fp16_elementwise_binary<ArithmeticOperation::PRELU>(const ITen
template <ComparisonOperation op>
void sve_fp16_comparison_elementwise_binary(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
{
- return elementwise_comparison_op<op, float16_t>(in1, in2, out, window);
+ return elementwise_comparison_op<float16_t>(in1, in2, out, op, window);
}
template void sve_fp16_comparison_elementwise_binary<ComparisonOperation::Equal>(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window);
diff --git a/src/cpu/kernels/elementwise_binary/generic/sve/fp32.cpp b/src/cpu/kernels/elementwise_binary/generic/sve/fp32.cpp
index 0f80813d15..2b479f76f1 100644
--- a/src/cpu/kernels/elementwise_binary/generic/sve/fp32.cpp
+++ b/src/cpu/kernels/elementwise_binary/generic/sve/fp32.cpp
@@ -31,7 +31,7 @@ namespace cpu
template <ArithmeticOperation op>
void sve_fp32_elementwise_binary(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
{
- return elementwise_arithmetic_op<op, float32_t>(in1, in2, out, window);
+ return elementwise_arithmetic_op<float32_t>(in1, in2, out, op, window);
}
template void sve_fp32_elementwise_binary<ArithmeticOperation::ADD>(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window);
@@ -46,7 +46,7 @@ template void sve_fp32_elementwise_binary<ArithmeticOperation::PRELU>(const ITen
template <ComparisonOperation op>
void sve_fp32_comparison_elementwise_binary(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
{
- return elementwise_comparison_op<op, float>(in1, in2, out, window);
+ return elementwise_comparison_op<float>(in1, in2, out, op, window);
}
template void sve_fp32_comparison_elementwise_binary<ComparisonOperation::Equal>(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window);
template void sve_fp32_comparison_elementwise_binary<ComparisonOperation::NotEqual>(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window);
diff --git a/src/cpu/kernels/elementwise_binary/generic/sve/impl.cpp b/src/cpu/kernels/elementwise_binary/generic/sve/impl.cpp
index 2a8b155d14..c0515f2abc 100644
--- a/src/cpu/kernels/elementwise_binary/generic/sve/impl.cpp
+++ b/src/cpu/kernels/elementwise_binary/generic/sve/impl.cpp
@@ -32,81 +32,117 @@ namespace cpu
{
using namespace arm_compute::wrapper;
-template <typename InputScalarType, typename OutputScalarType, typename OperatorType>
-struct LoopArguments
+template <typename ScalarType>
+void elementwise_arithmetic_op(const ITensor *in1, const ITensor *in2, ITensor *out, ArithmeticOperation op, const Window &window)
{
- OperatorType op;
- const InputScalarType *input1_ptr;
- const InputScalarType *input2_ptr;
- OutputScalarType *output_ptr;
-};
-
-template <typename InputScalarType, typename OutputScalarType, typename OperatorType>
-struct BroadcastLoopArguments
-{
- OperatorType op;
- const InputScalarType *input1_ptr;
- InputScalarType broadcast_value;
- OutputScalarType *output_ptr;
- bool reorder;
-};
+ using VectorType = typename sve_vector<ScalarType>::type;
-template <typename InputScalarType, typename OutputScalarType>
-void arithmetic_op_loop(svbool_t pg, const LoopArguments<InputScalarType, OutputScalarType, ArithmeticOperation> &args)
-{
- const auto in1 = svld1(pg, args.input1_ptr);
- const auto in2 = svld1(pg, args.input2_ptr);
- const auto res = elementwise_arithmetic_op<typename sve_vector<InputScalarType>::type>(pg, in1, in2, args.op);
- svst1(pg, args.output_ptr, res);
-}
+ const auto all_true_pg = svptrue<ScalarType>();
-template <typename InputScalarType, typename OutputScalarType>
-void arithmetic_op_broadcast_loop(svbool_t pg, const BroadcastLoopArguments<InputScalarType, OutputScalarType, ArithmeticOperation> &args)
-{
- const auto non_broadcast_vector = svld1(pg, args.input1_ptr);
- const auto broadcast_vector = svdup_n(args.broadcast_value);
- const auto in1 = args.reorder ? broadcast_vector : non_broadcast_vector;
- const auto in2 = args.reorder ? non_broadcast_vector : broadcast_vector;
- const auto res = elementwise_arithmetic_op<typename sve_vector<InputScalarType>::type>(pg, in1, in2, args.op);
- svst1(pg, args.output_ptr, res);
-}
+ // Create input windows
+ Window input1_win = window.broadcast_if_dimension_le_one(in1->info()->tensor_shape());
+ Window input2_win = window.broadcast_if_dimension_le_one(in2->info()->tensor_shape());
-template <typename InputScalarType, typename OutputScalarType>
-void comparison_op_loop(svbool_t pg, const LoopArguments<InputScalarType, OutputScalarType, ComparisonOperation> &args)
-{
- const auto in1 = svld1(pg, args.input1_ptr);
- const auto in2 = svld1(pg, args.input2_ptr);
- const auto res = elementwise_comparison_op<typename sve_vector<InputScalarType>::type, typename sve_vector<OutputScalarType>::type>(pg, in1, in2, args.op);
- const svbool_t output_pg = narrow_to_byte_predicate<sizeof(InputScalarType)>(pg);
- svst1(output_pg, args.output_ptr, res);
-}
+ // Clear X Dimension on execution window as we handle manually
+ Window win = window;
+ win.set(Window::DimX, Window::Dimension(0, 1, 1));
-template <typename InputScalarType, typename OutputScalarType>
-void comparison_op_broadcast_loop(svbool_t pg, const BroadcastLoopArguments<InputScalarType, OutputScalarType, ComparisonOperation> &args)
-{
- const auto non_broadcast_vector = svld1(pg, args.input1_ptr);
- const auto broadcast_vector = svdup_n(args.broadcast_value);
- const auto in1 = args.reorder ? broadcast_vector : non_broadcast_vector;
- const auto in2 = args.reorder ? non_broadcast_vector : broadcast_vector;
- const auto res = elementwise_comparison_op<typename sve_vector<InputScalarType>::type, typename sve_vector<OutputScalarType>::type>(pg, in1, in2, args.op);
- const svbool_t output_pg = narrow_to_byte_predicate<sizeof(InputScalarType)>(pg);
- svst1(output_pg, args.output_ptr, res);
-}
+ const auto window_start_x = static_cast<int>(window.x().start());
+ const auto window_end_x = static_cast<int>(window.x().end());
+ const bool is_broadcast_across_x = in1->info()->tensor_shape().x() != in2->info()->tensor_shape().x();
+
+ if(is_broadcast_across_x)
+ {
+ const bool is_broadcast_input_2 = input2_win.x().step() == 0;
+ Window broadcast_win = is_broadcast_input_2 ? input2_win : input1_win;
+ Window non_broadcast_win = !is_broadcast_input_2 ? input2_win : input1_win;
+ const ITensor *broadcast_tensor = is_broadcast_input_2 ? in2 : in1;
+ const ITensor *non_broadcast_tensor = !is_broadcast_input_2 ? in2 : in1;
+
+ // Clear X Dimension on execution window as we handle manually
+ non_broadcast_win.set(Window::DimX, Window::Dimension(0, 1, 1));
+
+ Iterator broadcast_input(broadcast_tensor, broadcast_win);
+ Iterator non_broadcast_input(non_broadcast_tensor, non_broadcast_win);
+ Iterator output(out, win);
+
+ execute_window_loop(win, [&](const Coordinates &)
+ {
+ auto output_ptr = reinterpret_cast<ScalarType *>(output.ptr());
+ const auto non_broadcast_input_ptr = reinterpret_cast<const ScalarType *>(non_broadcast_input.ptr());
+ const ScalarType broadcast_value = *reinterpret_cast<const ScalarType *>(broadcast_input.ptr());
+ const auto broadcast_vector = svdup_n(broadcast_value);
+
+ int x = window_start_x;
-template <typename InputScalarType, typename OutputScalarType, typename OperatorType>
-using LoopFuncType = void (*)(svbool_t, const LoopArguments<InputScalarType, OutputScalarType, OperatorType> &);
+ svbool_t pg = svwhilelt<ScalarType>(x, window_end_x);
+ do
+ {
+ const auto non_broadcast_vector = svld1(pg, non_broadcast_input_ptr + x);
+ VectorType res{};
-template <typename InputScalarType, typename OutputScalarType, typename OperatorType>
-using BroadcastLoopFuncType = void (*)(svbool_t, const BroadcastLoopArguments<InputScalarType, OutputScalarType, OperatorType> &);
+ if(is_broadcast_input_2)
+ {
+ res = elementwise_arithmetic_op<typename sve_vector<ScalarType>::type>(pg, non_broadcast_vector, broadcast_vector, op);
+ }
+ else
+ {
+ res = elementwise_arithmetic_op<typename sve_vector<ScalarType>::type>(pg, broadcast_vector, non_broadcast_vector, op);
+ }
+ svst1(pg, output_ptr + x, res);
-template <typename InputVectorType, typename OutputVectorType, typename OperatorType,
- typename InputScalarType = typename sve_scalar<InputVectorType>::type,
- typename OutputScalarType = typename sve_scalar<OutputVectorType>::type>
-void elementwise_op(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window,
- OperatorType op,
- LoopFuncType<InputScalarType, OutputScalarType, OperatorType> func,
- BroadcastLoopFuncType<InputScalarType, OutputScalarType, OperatorType> broadcast_func)
+ x += svcnt<ScalarType>();
+ pg = svwhilelt<ScalarType>(x, window_end_x);
+ }
+ while(svptest_any(all_true_pg, pg));
+ },
+ broadcast_input, non_broadcast_input, output);
+ }
+ else
+ {
+ // Clear X Dimension on execution window as we handle manually
+ input1_win.set(Window::DimX, Window::Dimension(0, 1, 1));
+ input2_win.set(Window::DimX, Window::Dimension(0, 1, 1));
+
+ Iterator input1(in1, input1_win);
+ Iterator input2(in2, input2_win);
+ Iterator output(out, win);
+
+ execute_window_loop(win, [&](const Coordinates &)
+ {
+ auto output_ptr = reinterpret_cast<ScalarType *>(output.ptr());
+ const auto input1_ptr = reinterpret_cast<const ScalarType *>(input1.ptr());
+ const auto input2_ptr = reinterpret_cast<const ScalarType *>(input2.ptr());
+
+ int x = window_start_x;
+
+ svbool_t pg = svwhilelt<ScalarType>(x, window_end_x);
+ do
+ {
+ const auto in1 = svld1(pg, input1_ptr + x);
+ const auto in2 = svld1(pg, input2_ptr + x);
+ const auto res = elementwise_arithmetic_op<typename sve_vector<ScalarType>::type>(pg, in1, in2, op);
+ svst1(pg, output_ptr + x, res);
+
+ x += svcnt<ScalarType>();
+ pg = svwhilelt<ScalarType>(x, window_end_x);
+ }
+ while(svptest_any(all_true_pg, pg));
+ },
+ input1, input2, output);
+ }
+}
+template void elementwise_arithmetic_op<float32_t>(const ITensor *in1, const ITensor *in2, ITensor *out, const ArithmeticOperation op, const Window &window);
+template void elementwise_arithmetic_op<float16_t>(const ITensor *in1, const ITensor *in2, ITensor *out, const ArithmeticOperation op, const Window &window);
+template void elementwise_arithmetic_op<int16_t>(const ITensor *in1, const ITensor *in2, ITensor *out, const ArithmeticOperation op, const Window &window);
+template void elementwise_arithmetic_op<int32_t>(const ITensor *in1, const ITensor *in2, ITensor *out, const ArithmeticOperation op, const Window &window);
+
+template <typename InputScalarType, typename OutputScalarType>
+void elementwise_comparison_op(const ITensor *in1, const ITensor *in2, ITensor *out, ComparisonOperation op, const Window &window)
{
+ static_assert(sizeof(InputScalarType) >= sizeof(OutputScalarType), "input data type's width should be equal to or greater than output data type's width");
+
+ using OutputVectorType = typename sve_vector<OutputScalarType>::type;
const auto all_true_pg = svptrue<InputScalarType>();
// Create input windows
@@ -141,20 +177,26 @@ void elementwise_op(const ITensor *in1, const ITensor *in2, ITensor *out, const
auto output_ptr = reinterpret_cast<OutputScalarType *>(output.ptr());
const auto non_broadcast_input_ptr = reinterpret_cast<const InputScalarType *>(non_broadcast_input.ptr());
const InputScalarType broadcast_value = *reinterpret_cast<const InputScalarType *>(broadcast_input.ptr());
+ const auto broadcast_vector = svdup_n(broadcast_value);
int x = window_start_x;
svbool_t pg = svwhilelt<InputScalarType>(x, window_end_x);
do
{
- broadcast_func(pg,
+ const auto non_broadcast_vector = svld1(pg, non_broadcast_input_ptr + x);
+ const svbool_t output_pg = narrow_to_byte_predicate<sizeof(InputScalarType)>(pg);
+ OutputVectorType res{};
+ if(is_broadcast_input_2)
{
- op,
- non_broadcast_input_ptr + x,
- broadcast_value,
- output_ptr + x,
- !is_broadcast_input_2
- });
+ res = elementwise_comparison_op<typename sve_vector<InputScalarType>::type, typename sve_vector<OutputScalarType>::type>(pg, non_broadcast_vector, broadcast_vector, op);
+ }
+ else
+ {
+ res = elementwise_comparison_op<typename sve_vector<InputScalarType>::type, typename sve_vector<OutputScalarType>::type>(pg, broadcast_vector, non_broadcast_vector, op);
+ }
+ svst1(output_pg, output_ptr + x, res);
+
x += svcnt<InputScalarType>();
pg = svwhilelt<InputScalarType>(x, window_end_x);
}
@@ -183,13 +225,12 @@ void elementwise_op(const ITensor *in1, const ITensor *in2, ITensor *out, const
svbool_t pg = svwhilelt<InputScalarType>(x, window_end_x);
do
{
- func(pg,
- {
- op,
- input1_ptr + x,
- input2_ptr + x,
- output_ptr + x
- });
+ const auto in1 = svld1(pg, input1_ptr + x);
+ const auto in2 = svld1(pg, input2_ptr + x);
+ const auto res = elementwise_comparison_op<typename sve_vector<InputScalarType>::type, typename sve_vector<OutputScalarType>::type>(pg, in1, in2, op);
+ const svbool_t output_pg = narrow_to_byte_predicate<sizeof(InputScalarType)>(pg);
+ svst1(output_pg, output_ptr + x, res);
+
x += svcnt<InputScalarType>();
pg = svwhilelt<InputScalarType>(x, window_end_x);
}
@@ -199,97 +240,11 @@ void elementwise_op(const ITensor *in1, const ITensor *in2, ITensor *out, const
}
}
-template <ArithmeticOperation op, typename ScalarType>
-void elementwise_arithmetic_op(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
-{
- using VectorType = typename sve_vector<ScalarType>::type;
-
- elementwise_op<VectorType, VectorType, ArithmeticOperation>(in1, in2, out, window, op,
- &arithmetic_op_loop<ScalarType, ScalarType>,
- &arithmetic_op_broadcast_loop<ScalarType, ScalarType>);
-}
-template void elementwise_arithmetic_op<ArithmeticOperation::ADD, float32_t>(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window);
-template void elementwise_arithmetic_op<ArithmeticOperation::SUB, float32_t>(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window);
-template void elementwise_arithmetic_op<ArithmeticOperation::DIV, float32_t>(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window);
-template void elementwise_arithmetic_op<ArithmeticOperation::MIN, float32_t>(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window);
-template void elementwise_arithmetic_op<ArithmeticOperation::MAX, float32_t>(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window);
-template void elementwise_arithmetic_op<ArithmeticOperation::SQUARED_DIFF, float32_t>(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window);
-template void elementwise_arithmetic_op<ArithmeticOperation::POWER, float32_t>(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window);
-template void elementwise_arithmetic_op<ArithmeticOperation::PRELU, float32_t>(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window);
-
-template void elementwise_arithmetic_op<ArithmeticOperation::ADD, float16_t>(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window);
-template void elementwise_arithmetic_op<ArithmeticOperation::SUB, float16_t>(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window);
-template void elementwise_arithmetic_op<ArithmeticOperation::DIV, float16_t>(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window);
-template void elementwise_arithmetic_op<ArithmeticOperation::MIN, float16_t>(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window);
-template void elementwise_arithmetic_op<ArithmeticOperation::MAX, float16_t>(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window);
-template void elementwise_arithmetic_op<ArithmeticOperation::SQUARED_DIFF, float16_t>(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window);
-template void elementwise_arithmetic_op<ArithmeticOperation::POWER, float16_t>(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window);
-template void elementwise_arithmetic_op<ArithmeticOperation::PRELU, float16_t>(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window);
-
-template void elementwise_arithmetic_op<ArithmeticOperation::ADD, int16_t>(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window);
-template void elementwise_arithmetic_op<ArithmeticOperation::SUB, int16_t>(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window);
-template void elementwise_arithmetic_op<ArithmeticOperation::DIV, int16_t>(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window);
-template void elementwise_arithmetic_op<ArithmeticOperation::MIN, int16_t>(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window);
-template void elementwise_arithmetic_op<ArithmeticOperation::MAX, int16_t>(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window);
-template void elementwise_arithmetic_op<ArithmeticOperation::SQUARED_DIFF, int16_t>(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window);
-template void elementwise_arithmetic_op<ArithmeticOperation::POWER, int16_t>(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window);
-template void elementwise_arithmetic_op<ArithmeticOperation::PRELU, int16_t>(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window);
-
-template void elementwise_arithmetic_op<ArithmeticOperation::ADD, int32_t>(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window);
-template void elementwise_arithmetic_op<ArithmeticOperation::SUB, int32_t>(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window);
-template void elementwise_arithmetic_op<ArithmeticOperation::DIV, int32_t>(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window);
-template void elementwise_arithmetic_op<ArithmeticOperation::MIN, int32_t>(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window);
-template void elementwise_arithmetic_op<ArithmeticOperation::MAX, int32_t>(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window);
-template void elementwise_arithmetic_op<ArithmeticOperation::SQUARED_DIFF, int32_t>(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window);
-template void elementwise_arithmetic_op<ArithmeticOperation::POWER, int32_t>(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window);
-template void elementwise_arithmetic_op<ArithmeticOperation::PRELU, int32_t>(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window);
-
-template <ComparisonOperation op, typename InputScalarType, typename OutputScalarType>
-void elementwise_comparison_op(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
-{
- static_assert(sizeof(InputScalarType) >= sizeof(OutputScalarType), "input data type's width should be equal to or greater than output data type's width");
- using InputVectorType = typename sve_vector<InputScalarType>::type;
- using OutputVectorType = typename sve_vector<OutputScalarType>::type;
-
- elementwise_op<InputVectorType, OutputVectorType, ComparisonOperation>(in1, in2, out, window, op,
- &comparison_op_loop<InputScalarType, OutputScalarType>,
- &comparison_op_broadcast_loop<InputScalarType, OutputScalarType>);
-}
-
-template void elementwise_comparison_op<ComparisonOperation::Equal, float32_t>(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window);
-template void elementwise_comparison_op<ComparisonOperation::NotEqual, float32_t>(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window);
-template void elementwise_comparison_op<ComparisonOperation::Greater, float32_t>(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window);
-template void elementwise_comparison_op<ComparisonOperation::GreaterEqual, float32_t>(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window);
-template void elementwise_comparison_op<ComparisonOperation::Less, float32_t>(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window);
-template void elementwise_comparison_op<ComparisonOperation::LessEqual, float32_t>(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window);
-
-template void elementwise_comparison_op<ComparisonOperation::Equal, float16_t>(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window);
-template void elementwise_comparison_op<ComparisonOperation::NotEqual, float16_t>(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window);
-template void elementwise_comparison_op<ComparisonOperation::Greater, float16_t>(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window);
-template void elementwise_comparison_op<ComparisonOperation::GreaterEqual, float16_t>(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window);
-template void elementwise_comparison_op<ComparisonOperation::Less, float16_t>(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window);
-template void elementwise_comparison_op<ComparisonOperation::LessEqual, float16_t>(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window);
-
-template void elementwise_comparison_op<ComparisonOperation::Equal, uint8_t>(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window);
-template void elementwise_comparison_op<ComparisonOperation::NotEqual, uint8_t>(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window);
-template void elementwise_comparison_op<ComparisonOperation::Greater, uint8_t>(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window);
-template void elementwise_comparison_op<ComparisonOperation::GreaterEqual, uint8_t>(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window);
-template void elementwise_comparison_op<ComparisonOperation::Less, uint8_t>(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window);
-template void elementwise_comparison_op<ComparisonOperation::LessEqual, uint8_t>(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window);
-
-template void elementwise_comparison_op<ComparisonOperation::Equal, int16_t>(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window);
-template void elementwise_comparison_op<ComparisonOperation::NotEqual, int16_t>(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window);
-template void elementwise_comparison_op<ComparisonOperation::Greater, int16_t>(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window);
-template void elementwise_comparison_op<ComparisonOperation::GreaterEqual, int16_t>(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window);
-template void elementwise_comparison_op<ComparisonOperation::Less, int16_t>(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window);
-template void elementwise_comparison_op<ComparisonOperation::LessEqual, int16_t>(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window);
-
-template void elementwise_comparison_op<ComparisonOperation::Equal, int32_t>(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window);
-template void elementwise_comparison_op<ComparisonOperation::NotEqual, int32_t>(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window);
-template void elementwise_comparison_op<ComparisonOperation::Greater, int32_t>(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window);
-template void elementwise_comparison_op<ComparisonOperation::GreaterEqual, int32_t>(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window);
-template void elementwise_comparison_op<ComparisonOperation::Less, int32_t>(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window);
-template void elementwise_comparison_op<ComparisonOperation::LessEqual, int32_t>(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window);
+template void elementwise_comparison_op<float32_t>(const ITensor *in1, const ITensor *in2, ITensor *out, const ComparisonOperation op, const Window &window);
+template void elementwise_comparison_op<float16_t>(const ITensor *in1, const ITensor *in2, ITensor *out, const ComparisonOperation op, const Window &window);
+template void elementwise_comparison_op<uint8_t>(const ITensor *in1, const ITensor *in2, ITensor *out, const ComparisonOperation op, const Window &window);
+template void elementwise_comparison_op<int16_t>(const ITensor *in1, const ITensor *in2, ITensor *out, const ComparisonOperation op, const Window &window);
+template void elementwise_comparison_op<int32_t>(const ITensor *in1, const ITensor *in2, ITensor *out, const ComparisonOperation op, const Window &window);
template <>
svint32_t elementwise_pow<svint32_t>(svbool_t &pg, const svint32_t &a, const svint32_t &b)
diff --git a/src/cpu/kernels/elementwise_binary/generic/sve/impl.h b/src/cpu/kernels/elementwise_binary/generic/sve/impl.h
index 606090d417..860c50a1e0 100644
--- a/src/cpu/kernels/elementwise_binary/generic/sve/impl.h
+++ b/src/cpu/kernels/elementwise_binary/generic/sve/impl.h
@@ -153,11 +153,11 @@ OutputVectorType elementwise_comparison_op(svbool_t &pg, const InputVectorType &
return ret;
}
-template <ArithmeticOperation op, typename ScalarType>
-void elementwise_arithmetic_op(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window);
+template <typename ScalarType>
+void elementwise_arithmetic_op(const ITensor *in1, const ITensor *in2, ITensor *out, ArithmeticOperation op, const Window &window);
-template <ComparisonOperation op, typename ScalarType, typename OutputScalarType = uint8_t>
-void elementwise_comparison_op(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window);
+template <typename ScalarType, typename OutputScalarType = uint8_t>
+void elementwise_comparison_op(const ITensor *in1, const ITensor *in2, ITensor *out, ComparisonOperation op, const Window &window);
} // namespace cpu
} // namespace arm_compute
#endif /* SRC_CORE_SVE_KERNELS_ELEMENTWISE_LIST_H */
diff --git a/src/cpu/kernels/elementwise_binary/generic/sve/integer.cpp b/src/cpu/kernels/elementwise_binary/generic/sve/integer.cpp
index 8f7e27184b..c313fc6e04 100644
--- a/src/cpu/kernels/elementwise_binary/generic/sve/integer.cpp
+++ b/src/cpu/kernels/elementwise_binary/generic/sve/integer.cpp
@@ -31,7 +31,7 @@ namespace cpu
template <ArithmeticOperation op>
void sve_s32_elementwise_binary(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
{
- return elementwise_arithmetic_op<op, int32_t>(in1, in2, out, window);
+ return elementwise_arithmetic_op<int32_t>(in1, in2, out, op, window);
}
template void sve_s32_elementwise_binary<ArithmeticOperation::ADD>(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window);
template void sve_s32_elementwise_binary<ArithmeticOperation::SUB>(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window);
@@ -45,7 +45,7 @@ template void sve_s32_elementwise_binary<ArithmeticOperation::PRELU>(const ITens
template <ArithmeticOperation op>
void sve_s16_elementwise_binary(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
{
- return elementwise_arithmetic_op<op, int16_t>(in1, in2, out, window);
+ return elementwise_arithmetic_op<int16_t>(in1, in2, out, op, window);
}
template void sve_s16_elementwise_binary<ArithmeticOperation::ADD>(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window);
template void sve_s16_elementwise_binary<ArithmeticOperation::SUB>(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window);
@@ -59,7 +59,7 @@ template void sve_s16_elementwise_binary<ArithmeticOperation::PRELU>(const ITens
template <ComparisonOperation op>
void sve_u8_comparison_elementwise_binary(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
{
- return elementwise_comparison_op<op, uint8_t>(in1, in2, out, window);
+ return elementwise_comparison_op<uint8_t>(in1, in2, out, op, window);
}
template void sve_u8_comparison_elementwise_binary<ComparisonOperation::Equal>(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window);
template void sve_u8_comparison_elementwise_binary<ComparisonOperation::NotEqual>(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window);
@@ -71,7 +71,7 @@ template void sve_u8_comparison_elementwise_binary<ComparisonOperation::LessEqua
template <ComparisonOperation op>
void sve_s16_comparison_elementwise_binary(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
{
- return elementwise_comparison_op<op, int16_t>(in1, in2, out, window);
+ return elementwise_comparison_op<int16_t>(in1, in2, out, op, window);
}
template void sve_s16_comparison_elementwise_binary<ComparisonOperation::Equal>(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window);
template void sve_s16_comparison_elementwise_binary<ComparisonOperation::NotEqual>(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window);
@@ -83,7 +83,7 @@ template void sve_s16_comparison_elementwise_binary<ComparisonOperation::LessEqu
template <ComparisonOperation op>
void sve_s32_comparison_elementwise_binary(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
{
- return elementwise_comparison_op<op, int32_t>(in1, in2, out, window);
+ return elementwise_comparison_op<int32_t>(in1, in2, out, op, window);
}
template void sve_s32_comparison_elementwise_binary<ComparisonOperation::Equal>(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window);
template void sve_s32_comparison_elementwise_binary<ComparisonOperation::NotEqual>(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window);
diff --git a/src/cpu/kernels/elementwise_binary/generic/sve2/impl.h b/src/cpu/kernels/elementwise_binary/generic/sve2/impl.h
index f34d05eb37..41e0ac77db 100644
--- a/src/cpu/kernels/elementwise_binary/generic/sve2/impl.h
+++ b/src/cpu/kernels/elementwise_binary/generic/sve2/impl.h
@@ -31,37 +31,6 @@ namespace cpu
{
using namespace arm_compute::wrapper;
-template <typename InputScalarType, typename OutputScalarType, typename OperatorType>
-struct QuantizedLoopArguments
-{
- OperatorType op;
- const InputScalarType *input1_ptr;
- const InputScalarType *input2_ptr;
- OutputScalarType *output_ptr;
-
- const svint32_t &in1_offset;
- const svint32_t &in2_offset;
- const svint32_t &out_offset;
- const svfloat32_t &in1_scale;
- const svfloat32_t &in2_scale;
- const svfloat32_t &out_scale;
-};
-
-template <typename InputScalarType, typename OutputScalarType, typename OperatorType>
-struct BroadcastQuantizedLoopArguments
-{
- OperatorType op;
- const InputScalarType *input1_ptr;
- float broadcast_value;
- OutputScalarType *output_ptr;
- bool reorder;
-
- const svint32_t &in1_offset;
- const svint32_t &out_offset;
- const svfloat32_t &in1_scale;
- const svfloat32_t &out_scale;
-};
-
inline svfloat32x4_t load_quantized(const int8_t *ptr, svbool_t pg, const svint32_t &offset, const svfloat32_t &scale)
{
auto x = svld1(pg, ptr);
@@ -131,98 +100,143 @@ inline void store_quantized(int8_t *ptr, svbool_t pg, svfloat32x4_t data, const
svst1(pg, ptr, narrowed);
}
-template <typename InputScalarType, typename OutputScalarType>
-inline void arithmetic_op_quantized_loop(svbool_t pg, const QuantizedLoopArguments<InputScalarType, OutputScalarType, ArithmeticOperation> &args)
+template <typename ScalarType>
+void elementwise_arithmetic_quantized_op(const ITensor *in1, const ITensor *in2, ITensor *out, ArithmeticOperation op, const Window &window)
{
- const auto in1 = load_quantized(args.input1_ptr, pg, args.in1_offset, args.in1_scale);
- const auto in2 = load_quantized(args.input2_ptr, pg, args.in2_offset, args.in2_scale);
+ const auto all_true_pg = wrapper::svptrue<ScalarType>();
- const auto result = svcreate4(
- elementwise_arithmetic_op<svfloat32_t>(pg, svget4(in1, 0), svget4(in2, 0), args.op),
- elementwise_arithmetic_op<svfloat32_t>(pg, svget4(in1, 1), svget4(in2, 1), args.op),
- elementwise_arithmetic_op<svfloat32_t>(pg, svget4(in1, 2), svget4(in2, 2), args.op),
- elementwise_arithmetic_op<svfloat32_t>(pg, svget4(in1, 3), svget4(in2, 3), args.op));
+ // Create input windows
+ Window input1_win = window.broadcast_if_dimension_le_one(in1->info()->tensor_shape());
+ Window input2_win = window.broadcast_if_dimension_le_one(in2->info()->tensor_shape());
- store_quantized(args.output_ptr, pg, result, args.out_offset, args.out_scale);
-}
+ // Clear X Dimension on execution window as we handle manually
+ Window win = window;
+ win.set(Window::DimX, Window::Dimension(0, 1, 1));
-template <typename InputScalarType, typename OutputScalarType>
-inline void arithmetic_op_broadcast_quantized_loop(svbool_t pg, const BroadcastQuantizedLoopArguments<InputScalarType, OutputScalarType, ArithmeticOperation> &args)
-{
- const auto in1 = load_quantized(args.input1_ptr, pg, args.in1_offset, args.in1_scale);
- const auto in2 = svcreate4(
- svdup_n(args.broadcast_value), svdup_n(args.broadcast_value), svdup_n(args.broadcast_value), svdup_n(args.broadcast_value));
+ const auto window_start_x = static_cast<int>(window.x().start());
+ const auto window_end_x = static_cast<int>(window.x().end());
+ const bool is_broadcast_across_x = in1->info()->tensor_shape().x() != in2->info()->tensor_shape().x();
- const auto &af = args.reorder ? in2 : in1;
- const auto &bf = args.reorder ? in1 : in2;
+ const auto output_voffset = svdup_n(out->info()->quantization_info().uniform().offset);
+ const auto output_vscale = svdup_n(1.f / out->info()->quantization_info().uniform().scale);
- const auto result = svcreate4(
- elementwise_arithmetic_op<svfloat32_t>(pg, svget4(af, 0), svget4(bf, 0), args.op),
- elementwise_arithmetic_op<svfloat32_t>(pg, svget4(af, 1), svget4(bf, 1), args.op),
- elementwise_arithmetic_op<svfloat32_t>(pg, svget4(af, 2), svget4(bf, 2), args.op),
- elementwise_arithmetic_op<svfloat32_t>(pg, svget4(af, 3), svget4(bf, 3), args.op));
+ if(is_broadcast_across_x)
+ {
+ const bool is_broadcast_input_2 = input2_win.x().step() == 0;
+ Window broadcast_win = is_broadcast_input_2 ? input2_win : input1_win;
+ Window non_broadcast_win = !is_broadcast_input_2 ? input2_win : input1_win;
+ const ITensor *broadcast_tensor = is_broadcast_input_2 ? in2 : in1;
+ const ITensor *non_broadcast_tensor = !is_broadcast_input_2 ? in2 : in1;
- store_quantized(args.output_ptr, pg, result, args.out_offset, args.out_scale);
-}
+ const auto non_broadcast_qinfo = is_broadcast_input_2 ? in1->info()->quantization_info() : in2->info()->quantization_info();
+ const auto broadcast_qinfo = is_broadcast_input_2 ? in2->info()->quantization_info() : in1->info()->quantization_info();
-template <typename InputScalarType, typename OutputScalarType>
-inline void comparison_op_quantized_loop(svbool_t pg, const QuantizedLoopArguments<InputScalarType, OutputScalarType, ComparisonOperation> &args)
-{
- const auto in1 = load_quantized(args.input1_ptr, pg, args.in1_offset, args.in1_scale);
- const auto in2 = load_quantized(args.input2_ptr, pg, args.in2_offset, args.in2_scale);
+ const auto non_broadcast_voffset = svdup_n(non_broadcast_qinfo.uniform().offset);
+ const auto non_broadcast_vscale = svdup_n(non_broadcast_qinfo.uniform().scale);
- using OutputVectorType = typename wrapper::traits::sve_vector<OutputScalarType>::type;
+ // Clear X Dimension on execution window as we handle manually
+ non_broadcast_win.set(Window::DimX, Window::Dimension(0, 1, 1));
- const auto result = svcreate4(
- elementwise_comparison_op<svfloat32_t, OutputVectorType>(pg, svget4(in1, 0), svget4(in2, 0), args.op),
- elementwise_comparison_op<svfloat32_t, OutputVectorType>(pg, svget4(in1, 1), svget4(in2, 1), args.op),
- elementwise_comparison_op<svfloat32_t, OutputVectorType>(pg, svget4(in1, 2), svget4(in2, 2), args.op),
- elementwise_comparison_op<svfloat32_t, OutputVectorType>(pg, svget4(in1, 3), svget4(in2, 3), args.op));
+ Iterator broadcast_input(broadcast_tensor, broadcast_win);
+ Iterator non_broadcast_input(non_broadcast_tensor, non_broadcast_win);
+ Iterator output(out, win);
- const auto zipped_bottom = svzip1(svget4(result, 0), svget4(result, 1));
- const auto zipped_top = svzip1(svget4(result, 2), svget4(result, 3));
- const auto zipped = svzip1(zipped_bottom, zipped_top);
- svst1(pg, args.output_ptr, zipped);
-}
+ execute_window_loop(win, [&](const Coordinates &)
+ {
+ auto output_ptr = reinterpret_cast<ScalarType *>(output.ptr());
+ const auto non_broadcast_input_ptr = reinterpret_cast<const ScalarType *>(non_broadcast_input.ptr());
+ const ScalarType broadcast_value = *reinterpret_cast<const ScalarType *>(broadcast_input.ptr());
+ const float broadcast_value_f = Qasymm8QuantizationHelper<ScalarType>::dequantize(broadcast_value, broadcast_qinfo);
+ const auto in2 = svcreate4(svdup_n(broadcast_value_f), svdup_n(broadcast_value_f), svdup_n(broadcast_value_f), svdup_n(broadcast_value_f));
-template <typename InputScalarType, typename OutputScalarType>
-inline void comparison_op_broadcast_quantized_loop(svbool_t pg, const BroadcastQuantizedLoopArguments<InputScalarType, OutputScalarType, ComparisonOperation> &args)
-{
- const auto in1 = load_quantized(args.input1_ptr, pg, args.in1_offset, args.in1_scale);
- const auto in2 = svcreate4(
- svdup_n(args.broadcast_value), svdup_n(args.broadcast_value), svdup_n(args.broadcast_value), svdup_n(args.broadcast_value));
+ int x = window_start_x;
+
+ svbool_t pg = wrapper::svwhilelt<ScalarType>(x, window_end_x);
+ do
+ {
+ const auto in1 = load_quantized(non_broadcast_input_ptr + x, pg, non_broadcast_voffset, non_broadcast_vscale);
- const auto &af = args.reorder ? in2 : in1;
- const auto &bf = args.reorder ? in1 : in2;
+ svfloat32x4_t result{};
- using OutputVectorType = typename wrapper::traits::sve_vector<OutputScalarType>::type;
+ if(!is_broadcast_input_2)
+ {
+ result = svcreate4(
+ elementwise_arithmetic_op<svfloat32_t>(pg, svget4(in2, 0), svget4(in1, 0), op),
+ elementwise_arithmetic_op<svfloat32_t>(pg, svget4(in2, 1), svget4(in1, 1), op),
+ elementwise_arithmetic_op<svfloat32_t>(pg, svget4(in2, 2), svget4(in1, 2), op),
+ elementwise_arithmetic_op<svfloat32_t>(pg, svget4(in2, 3), svget4(in1, 3), op));
+ }
+ else
+ {
+ result = svcreate4(
+ elementwise_arithmetic_op<svfloat32_t>(pg, svget4(in1, 0), svget4(in2, 0), op),
+ elementwise_arithmetic_op<svfloat32_t>(pg, svget4(in1, 1), svget4(in2, 1), op),
+ elementwise_arithmetic_op<svfloat32_t>(pg, svget4(in1, 2), svget4(in2, 2), op),
+ elementwise_arithmetic_op<svfloat32_t>(pg, svget4(in1, 3), svget4(in2, 3), op));
+ }
- const auto result = svcreate4(
- elementwise_comparison_op<svfloat32_t, OutputVectorType>(pg, svget4(af, 0), svget4(bf, 0), args.op),
- elementwise_comparison_op<svfloat32_t, OutputVectorType>(pg, svget4(af, 1), svget4(bf, 1), args.op),
- elementwise_comparison_op<svfloat32_t, OutputVectorType>(pg, svget4(af, 2), svget4(bf, 2), args.op),
- elementwise_comparison_op<svfloat32_t, OutputVectorType>(pg, svget4(af, 3), svget4(bf, 3), args.op));
+ store_quantized(output_ptr + x, pg, result, output_voffset, output_vscale);
- const auto zipped_bottom = svzip1(svget4(result, 0), svget4(result, 1));
- const auto zipped_top = svzip1(svget4(result, 2), svget4(result, 3));
- const auto zipped = svzip1(zipped_bottom, zipped_top);
- svst1(pg, args.output_ptr, zipped);
-}
+ x += wrapper::svcnt<ScalarType>();
+ pg = wrapper::svwhilelt<ScalarType>(x, window_end_x);
+ }
+ while(svptest_any(all_true_pg, pg));
+ },
+ broadcast_input, non_broadcast_input, output);
+ }
+ else
+ {
+ // Clear X Dimension on execution window as we handle manually
+ input1_win.set(Window::DimX, Window::Dimension(0, 1, 1));
+ input2_win.set(Window::DimX, Window::Dimension(0, 1, 1));
+
+ Iterator input1(in1, input1_win);
+ Iterator input2(in2, input2_win);
+ Iterator output(out, win);
+
+ const auto in1_voffset = svdup_n(in1->info()->quantization_info().uniform().offset);
+ const auto in1_vscale = svdup_n(in1->info()->quantization_info().uniform().scale);
-template <typename InputScalarType, typename OutputScalarType, typename OperatorType>
-using LoopQuantizedFuncType = void (*)(svbool_t, const QuantizedLoopArguments<InputScalarType, OutputScalarType, OperatorType> &);
+ const auto in2_voffset = svdup_n(in2->info()->quantization_info().uniform().offset);
+ const auto in2_vscale = svdup_n(in2->info()->quantization_info().uniform().scale);
-template <typename InputScalarType, typename OutputScalarType, typename OperatorType>
-using BroadcastQuantizedLoopFuncType = void (*)(svbool_t, const BroadcastQuantizedLoopArguments<InputScalarType, OutputScalarType, OperatorType> &);
+ execute_window_loop(win, [&](const Coordinates &)
+ {
+ auto output_ptr = reinterpret_cast<ScalarType *>(output.ptr());
+ const auto input1_ptr = reinterpret_cast<const ScalarType *>(input1.ptr());
+ const auto input2_ptr = reinterpret_cast<const ScalarType *>(input2.ptr());
-template <typename InputVectorType, typename OutputVectorType, typename OperatorType,
- typename InputScalarType = typename wrapper::sve_scalar<InputVectorType>::type,
- typename OutputScalarType = typename wrapper::sve_scalar<OutputVectorType>::type>
-void elementwise_quantized_op(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window,
- OperatorType op,
- LoopQuantizedFuncType<InputScalarType, OutputScalarType, OperatorType> func,
- BroadcastQuantizedLoopFuncType<InputScalarType, OutputScalarType, OperatorType> broadcast_func)
+ int x = window_start_x;
+
+ svbool_t pg = wrapper::svwhilelt<ScalarType>(x, window_end_x);
+ do
+ {
+ const auto in1 = load_quantized(input1_ptr + x, pg, in1_voffset, in1_vscale);
+ const auto in2 = load_quantized(input2_ptr + x, pg, in2_voffset, in2_vscale);
+
+ const auto result = svcreate4(
+ elementwise_arithmetic_op<svfloat32_t>(pg, svget4(in1, 0), svget4(in2, 0), op),
+ elementwise_arithmetic_op<svfloat32_t>(pg, svget4(in1, 1), svget4(in2, 1), op),
+ elementwise_arithmetic_op<svfloat32_t>(pg, svget4(in1, 2), svget4(in2, 2), op),
+ elementwise_arithmetic_op<svfloat32_t>(pg, svget4(in1, 3), svget4(in2, 3), op));
+
+ store_quantized(output_ptr + x, pg, result, output_voffset, output_vscale);
+
+ x += wrapper::svcnt<ScalarType>();
+ pg = wrapper::svwhilelt<ScalarType>(x, window_end_x);
+ }
+ while(svptest_any(all_true_pg, pg));
+ },
+ input1, input2, output);
+ }
+}
+
+template <typename InputScalarType, typename OutputScalarType = uint8_t>
+void elementwise_comparison_quantized_op(const ITensor *in1, const ITensor *in2, ITensor *out, ComparisonOperation op, const Window &window)
{
+ static_assert(sizeof(InputScalarType) >= sizeof(OutputScalarType), "input data type's width should be equal to or greater than output data type's width");
+
+ using OutputVectorType = typename wrapper::traits::sve_vector<OutputScalarType>::type;
const auto all_true_pg = wrapper::svptrue<InputScalarType>();
// Create input windows
@@ -237,9 +251,6 @@ void elementwise_quantized_op(const ITensor *in1, const ITensor *in2, ITensor *o
const auto window_end_x = static_cast<int>(window.x().end());
const bool is_broadcast_across_x = in1->info()->tensor_shape().x() != in2->info()->tensor_shape().x();
- const auto output_voffset = svdup_n(out->info()->quantization_info().uniform().offset);
- const auto output_vscale = svdup_n(1.f / out->info()->quantization_info().uniform().scale);
-
if(is_broadcast_across_x)
{
const bool is_broadcast_input_2 = input2_win.x().step() == 0;
@@ -266,23 +277,40 @@ void elementwise_quantized_op(const ITensor *in1, const ITensor *in2, ITensor *o
auto output_ptr = reinterpret_cast<OutputScalarType *>(output.ptr());
const auto non_broadcast_input_ptr = reinterpret_cast<const InputScalarType *>(non_broadcast_input.ptr());
const InputScalarType broadcast_value = *reinterpret_cast<const InputScalarType *>(broadcast_input.ptr());
+ const float broadcast_value_f = Qasymm8QuantizationHelper<InputScalarType>::dequantize(broadcast_value, broadcast_qinfo);
+ const auto in2 = svcreate4(svdup_n(broadcast_value_f), svdup_n(broadcast_value_f), svdup_n(broadcast_value_f), svdup_n(broadcast_value_f));
int x = window_start_x;
svbool_t pg = wrapper::svwhilelt<InputScalarType>(x, window_end_x);
do
{
- const auto args = BroadcastQuantizedLoopArguments<InputScalarType, OutputScalarType, OperatorType>
+ const auto in1 = load_quantized(non_broadcast_input_ptr + x, pg, non_broadcast_voffset, non_broadcast_vscale);
+
+ svuint8x4_t result{};
+
+ if(!is_broadcast_input_2)
{
- op,
- non_broadcast_input_ptr + x,
- Qasymm8QuantizationHelper<InputScalarType>::dequantize(broadcast_value, broadcast_qinfo),
- output_ptr + x,
- !is_broadcast_input_2,
- non_broadcast_voffset, output_voffset,
- non_broadcast_vscale, output_vscale
- };
- broadcast_func(pg, args);
+ result = svcreate4(
+ elementwise_comparison_op<svfloat32_t, OutputVectorType>(pg, svget4(in2, 0), svget4(in1, 0), op),
+ elementwise_comparison_op<svfloat32_t, OutputVectorType>(pg, svget4(in2, 1), svget4(in1, 1), op),
+ elementwise_comparison_op<svfloat32_t, OutputVectorType>(pg, svget4(in2, 2), svget4(in1, 2), op),
+ elementwise_comparison_op<svfloat32_t, OutputVectorType>(pg, svget4(in2, 3), svget4(in1, 3), op));
+ }
+ else
+ {
+ result = svcreate4(
+ elementwise_comparison_op<svfloat32_t, OutputVectorType>(pg, svget4(in1, 0), svget4(in2, 0), op),
+ elementwise_comparison_op<svfloat32_t, OutputVectorType>(pg, svget4(in1, 1), svget4(in2, 1), op),
+ elementwise_comparison_op<svfloat32_t, OutputVectorType>(pg, svget4(in1, 2), svget4(in2, 2), op),
+ elementwise_comparison_op<svfloat32_t, OutputVectorType>(pg, svget4(in1, 3), svget4(in2, 3), op));
+ }
+
+ const auto zipped_bottom = svzip1(svget4(result, 0), svget4(result, 1));
+ const auto zipped_top = svzip1(svget4(result, 2), svget4(result, 3));
+ const auto zipped = svzip1(zipped_bottom, zipped_top);
+ svst1(pg, output_ptr + x, zipped);
+
x += wrapper::svcnt<InputScalarType>();
pg = wrapper::svwhilelt<InputScalarType>(x, window_end_x);
}
@@ -317,16 +345,19 @@ void elementwise_quantized_op(const ITensor *in1, const ITensor *in2, ITensor *o
svbool_t pg = wrapper::svwhilelt<InputScalarType>(x, window_end_x);
do
{
- const auto args = QuantizedLoopArguments<InputScalarType, OutputScalarType, OperatorType>
- {
- op,
- input1_ptr + x,
- input2_ptr + x,
- output_ptr + x,
- in1_voffset, in2_voffset, output_voffset,
- in1_vscale, in2_vscale, output_vscale
- };
- func(pg, args);
+ const auto in1 = load_quantized(input1_ptr + x, pg, in1_voffset, in1_vscale);
+ const auto in2 = load_quantized(input2_ptr + x, pg, in2_voffset, in2_vscale);
+ const auto result = svcreate4(
+ elementwise_comparison_op<svfloat32_t, OutputVectorType>(pg, svget4(in1, 0), svget4(in2, 0), op),
+ elementwise_comparison_op<svfloat32_t, OutputVectorType>(pg, svget4(in1, 1), svget4(in2, 1), op),
+ elementwise_comparison_op<svfloat32_t, OutputVectorType>(pg, svget4(in1, 2), svget4(in2, 2), op),
+ elementwise_comparison_op<svfloat32_t, OutputVectorType>(pg, svget4(in1, 3), svget4(in2, 3), op));
+
+ const auto zipped_bottom = svzip1(svget4(result, 0), svget4(result, 1));
+ const auto zipped_top = svzip1(svget4(result, 2), svget4(result, 3));
+ const auto zipped = svzip1(zipped_bottom, zipped_top);
+ svst1(pg, output_ptr + x, zipped);
+
x += wrapper::svcnt<InputScalarType>();
pg = wrapper::svwhilelt<InputScalarType>(x, window_end_x);
}
@@ -335,26 +366,6 @@ void elementwise_quantized_op(const ITensor *in1, const ITensor *in2, ITensor *o
input1, input2, output);
}
}
-
-template <ArithmeticOperation op, typename ScalarType>
-void elementwise_arithmetic_quantized_op(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
-{
- using VectorType = typename wrapper::traits::sve_vector<ScalarType>::type;
- elementwise_quantized_op<VectorType, VectorType, ArithmeticOperation>(in1, in2, out, window, op,
- &arithmetic_op_quantized_loop<ScalarType, ScalarType>,
- &arithmetic_op_broadcast_quantized_loop<ScalarType, ScalarType>);
-}
-
-template <ComparisonOperation op, typename InputScalarType, typename OutputScalarType = uint8_t>
-void elementwise_comparison_quantized_op(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
-{
- static_assert(sizeof(InputScalarType) >= sizeof(OutputScalarType), "input data type's width should be equal to or greater than output data type's width");
- using InputVectorType = typename wrapper::traits::sve_vector<InputScalarType>::type;
- using OutputVectorType = typename wrapper::traits::sve_vector<OutputScalarType>::type;
- elementwise_quantized_op<InputVectorType, OutputVectorType, ComparisonOperation>(in1, in2, out, window, op,
- &comparison_op_quantized_loop<InputScalarType, OutputScalarType>,
- &comparison_op_broadcast_quantized_loop<InputScalarType, OutputScalarType>);
-}
} // namespace cpu
} // namespace arm_compute
diff --git a/src/cpu/kernels/elementwise_binary/generic/sve2/qasymm8.cpp b/src/cpu/kernels/elementwise_binary/generic/sve2/qasymm8.cpp
index 479f053685..7435bb4f29 100644
--- a/src/cpu/kernels/elementwise_binary/generic/sve2/qasymm8.cpp
+++ b/src/cpu/kernels/elementwise_binary/generic/sve2/qasymm8.cpp
@@ -31,7 +31,7 @@ namespace cpu
template <ArithmeticOperation op>
void sve2_qasymm8_elementwise_binary(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
{
- return elementwise_arithmetic_quantized_op<op, uint8_t>(in1, in2, out, window);
+ return elementwise_arithmetic_quantized_op<uint8_t>(in1, in2, out, op, window);
}
template void sve2_qasymm8_elementwise_binary<ArithmeticOperation::ADD>(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window);
@@ -46,7 +46,7 @@ template void sve2_qasymm8_elementwise_binary<ArithmeticOperation::PRELU>(const
template <ComparisonOperation op>
void sve2_qasymm8_comparison_elementwise_binary(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
{
- return elementwise_comparison_quantized_op<op, uint8_t>(in1, in2, out, window);
+ return elementwise_comparison_quantized_op<uint8_t>(in1, in2, out, op, window);
}
template void sve2_qasymm8_comparison_elementwise_binary<ComparisonOperation::Equal>(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window);
diff --git a/src/cpu/kernels/elementwise_binary/generic/sve2/qasymm8_signed.cpp b/src/cpu/kernels/elementwise_binary/generic/sve2/qasymm8_signed.cpp
index ce6250be35..1027a1eed0 100644
--- a/src/cpu/kernels/elementwise_binary/generic/sve2/qasymm8_signed.cpp
+++ b/src/cpu/kernels/elementwise_binary/generic/sve2/qasymm8_signed.cpp
@@ -31,7 +31,7 @@ namespace cpu
template <ArithmeticOperation op>
void sve2_qasymm8_signed_elementwise_binary(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
{
- return elementwise_arithmetic_quantized_op<op, int8_t>(in1, in2, out, window);
+ return elementwise_arithmetic_quantized_op<int8_t>(in1, in2, out, op, window);
}
template void sve2_qasymm8_signed_elementwise_binary<ArithmeticOperation::ADD>(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window);
@@ -46,7 +46,7 @@ template void sve2_qasymm8_signed_elementwise_binary<ArithmeticOperation::PRELU>
template <ComparisonOperation op>
void sve2_qasymm8_signed_comparison_elementwise_binary(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
{
- return elementwise_comparison_quantized_op<op, int8_t>(in1, in2, out, window);
+ return elementwise_comparison_quantized_op<int8_t>(in1, in2, out, op, window);
}
template void sve2_qasymm8_signed_comparison_elementwise_binary<ComparisonOperation::Equal>(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window);
diff --git a/src/cpu/kernels/softmax/generic/sve/impl.cpp b/src/cpu/kernels/softmax/generic/sve/impl.cpp
index f1442224e8..2340a31cbd 100644
--- a/src/cpu/kernels/softmax/generic/sve/impl.cpp
+++ b/src/cpu/kernels/softmax/generic/sve/impl.cpp
@@ -94,8 +94,9 @@ void sve_softmax_logits_1d_float(const ITensor *in, const ITensor *max, void *co
/* Compute exponentials and sum */
{
/* Get max value */
- const auto max_val = *reinterpret_cast<const ScalarType *>(max_it.ptr());
- const auto vec_max = wrapper::svdup_n(max_val);
+ const auto max_val = *reinterpret_cast<const ScalarType *>(max_it.ptr());
+ const auto vec_max = wrapper::svdup_n(max_val);
+ const auto vec_beta = wrapper::svdup_n(static_cast<ScalarType>(beta));
/* Init sum to zero */
auto vec_sum = wrapper::svdup_n(static_cast<ScalarType>(0));
@@ -106,19 +107,19 @@ void sve_softmax_logits_1d_float(const ITensor *in, const ITensor *max, void *co
do
{
auto vec_elements = svld1(pg, in_ptr + x);
- vec_elements = svsub_z(pg, vec_elements, vec_max);
- if(is_log)
- {
- vec_elements = svmul_z(pg, vec_elements, wrapper::svdup_n(static_cast<ScalarType>(beta)));
- vec_sum = svadd_m(pg, vec_sum, wrapper::svexp_z(pg, vec_elements));
- }
- else
+ vec_elements = svmul_z(pg, svsub_z(pg, vec_elements, vec_max), vec_beta);
+ if(!is_log)
{
- vec_elements = wrapper::svexp_z(pg, svmul_z(pg, vec_elements, wrapper::svdup_n(static_cast<ScalarType>(beta))));
+ vec_elements = wrapper::svexp_z(pg, vec_elements);
vec_sum = svadd_m(pg, vec_sum, vec_elements);
}
svst1(pg, tmp_ptr + x, vec_elements);
+ if(is_log)
+ {
+ vec_sum = svadd_m(pg, vec_sum, wrapper::svexp_z(pg, vec_elements));
+ }
+
x += wrapper::svcnt<ScalarType>();
pg = wrapper::svwhilelt<ScalarType>(x, input_width);
}