Diffstat (limited to 'src/cpu/kernels')
 src/cpu/kernels/CpuActivationKernel.cpp                            |   18
 src/cpu/kernels/CpuDequantizeKernel.cpp                            |  328
 src/cpu/kernels/CpuDequantizeKernel.h                              |   16
 src/cpu/kernels/CpuGemmLowpOffsetContributionKernel.cpp            |  324
 src/cpu/kernels/CpuGemmLowpOffsetContributionKernel.h              |   41
 src/cpu/kernels/CpuGemmLowpOffsetContributionOutputStageKernel.cpp |   21
 src/cpu/kernels/CpuGemmLowpOffsetContributionOutputStageKernel.h   |   26
 src/cpu/kernels/CpuKernelSelectionTypes.h                          |    4
 src/cpu/kernels/CpuQuantizeKernel.cpp                              |  344
 src/cpu/kernels/CpuQuantizeKernel.h                                |   26
 src/cpu/kernels/CpuSoftmaxKernel.cpp                               |   60
 src/cpu/kernels/CpuSoftmaxKernel.h                                 |   13
 src/cpu/kernels/assembly/arm_gemm.hpp                              |   24
 src/cpu/kernels/assembly/gemm_common.hpp                           |    6
 src/cpu/kernels/dequantize/generic/neon/fp16.cpp                   |   37
 src/cpu/kernels/dequantize/generic/neon/fp32.cpp                   |   35
 src/cpu/kernels/dequantize/generic/neon/impl.h                     |  340
 src/cpu/kernels/dequantize/generic/neon/list.h                     |   43
 src/cpu/kernels/quantize/generic/neon/fp16.cpp                     |   45
 src/cpu/kernels/quantize/generic/neon/fp32.cpp                     |   48
 src/cpu/kernels/quantize/generic/neon/impl.h                       |  330
 src/cpu/kernels/quantize/generic/neon/integer.cpp                  |   82
 src/cpu/kernels/quantize/generic/neon/list.h                       |   66
 src/cpu/kernels/reduction_layer/generic/neon/fp16.cpp              |   65
 src/cpu/kernels/reduction_layer/generic/neon/fp32.cpp              |   73
 src/cpu/kernels/reduction_layer/generic/neon/impl.h                | 1633
 src/cpu/kernels/reduction_layer/generic/neon/integer.cpp           |   62
 src/cpu/kernels/reduction_layer/generic/neon/list.h                |   66
 src/cpu/kernels/reduction_layer/generic/neon/qasymm8.cpp           |   63
 src/cpu/kernels/reduction_layer/generic/neon/qasymm8_signed.cpp    |   63
 src/cpu/kernels/softmax/generic/neon/fp16.cpp                      |   28
 src/cpu/kernels/softmax/generic/neon/fp32.cpp                      |   28
 src/cpu/kernels/softmax/generic/neon/qasymm8.cpp                   |   28
 src/cpu/kernels/softmax/generic/neon/qasymm8_signed.cpp            |   28
 src/cpu/kernels/softmax/generic/sme2/fp16.cpp                      |  781
 src/cpu/kernels/softmax/generic/sme2/fp32.cpp                      |  585
 src/cpu/kernels/softmax/generic/sme2/qasymm8.cpp                   |  634
 src/cpu/kernels/softmax/generic/sme2/qasymm8_signed.cpp            |  655
 src/cpu/kernels/softmax/list.h                                     |   43
 39 files changed, 6380 insertions(+), 732 deletions(-)
diff --git a/src/cpu/kernels/CpuActivationKernel.cpp b/src/cpu/kernels/CpuActivationKernel.cpp
index 7cfa39b286..4253027231 100644
--- a/src/cpu/kernels/CpuActivationKernel.cpp
+++ b/src/cpu/kernels/CpuActivationKernel.cpp
@@ -43,6 +43,13 @@ namespace kernels
{
namespace
{
+
+bool is_fp16_lut_supported(ActivationLayerInfo::ActivationFunction func)
+{
+ return func == ActivationLayerInfo::ActivationFunction::LOGISTIC ||
+ func == ActivationLayerInfo::ActivationFunction::TANH;
+}
+
static const std::vector<CpuActivationKernel::ActivationKernel> available_kernels = {
#ifdef ARM_COMPUTE_ENABLE_SVE
{"sve2_q8_activation_lut",
@@ -85,10 +92,7 @@ static const std::vector<CpuActivationKernel::ActivationKernel> available_kernel
REGISTER_QSYMM16_SVE2(arm_compute::cpu::sve2_qsymm16_activation)},
{"sve_fp16_activation_lut",
[](const ActivationDataTypeISASelectorData &data)
- {
- return data.dt == DataType::F16 && data.isa.fp16 && data.isa.sve &&
- data.f == ActivationLayerInfo::ActivationFunction::LOGISTIC;
- },
+ { return data.dt == DataType::F16 && data.isa.fp16 && data.isa.sve && is_fp16_lut_supported(data.f); },
REGISTER_FP16_SVE(arm_compute::cpu::sve_fp16_activation_lut)},
{"sve_fp16_activation",
[](const ActivationDataTypeISASelectorData &data)
@@ -299,10 +303,10 @@ void CpuActivationKernel::configure(const ITensorInfo *src, ITensorInfo *dst, Ac
activation_info.setLookupTable256(tmp_lut);
}
- if (src->data_type() == DataType::F16 &&
- activation_info.activation() == ActivationLayerInfo::ActivationFunction::LOGISTIC)
+ if (std::string(uk->name) == "sve_fp16_activation_lut")
{
- const LUTInfo info = {activation_info.activation(), src->data_type(), src->quantization_info()};
+ const LUTInfo info = {activation_info.activation(), activation_info.a(), activation_info.b(), src->data_type(),
+ src->quantization_info().uniform()};
activation_info.setLookupTable65536((lut_manager.get_lut_table(info)));
}
#endif // __aarch64__
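
Note: the switch from a hard-coded LOGISTIC check to is_fp16_lut_supported() works because an FP16 input has only 2^16 bit patterns, so any unary activation can be tabulated exactly. The LUTInfo key now also carries a() and b() since TANH in the library is parameterised as a * tanh(b * x) and the table contents depend on both. A minimal sketch of the idea, assuming an AArch64 toolchain with the __fp16 extension (names are illustrative, not the library's API):

#include <cmath>
#include <cstdint>
#include <cstring>
#include <vector>

// Build one entry per possible FP16 bit pattern (2^16 of them), so the
// activation costs a single table lookup per element at run time.
// Plain tanh is shown; the real table would bake in the a/b parameters.
std::vector<__fp16> build_fp16_tanh_lut()
{
    std::vector<__fp16> lut(65536);
    for (uint32_t bits = 0; bits < 65536; ++bits)
    {
        const uint16_t b16 = static_cast<uint16_t>(bits);
        __fp16         in;
        std::memcpy(&in, &b16, sizeof(in)); // reinterpret the bits as an FP16 value
        lut[bits] = static_cast<__fp16>(std::tanh(static_cast<float>(in)));
    }
    return lut;
}

// Applying the LUT: index by the input's raw 16-bit pattern.
__fp16 lut_tanh(const std::vector<__fp16> &lut, __fp16 x)
{
    uint16_t bits;
    std::memcpy(&bits, &x, sizeof(bits));
    return lut[bits];
}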
diff --git a/src/cpu/kernels/CpuDequantizeKernel.cpp b/src/cpu/kernels/CpuDequantizeKernel.cpp
index d17128b5ac..5595ace998 100644
--- a/src/cpu/kernels/CpuDequantizeKernel.cpp
+++ b/src/cpu/kernels/CpuDequantizeKernel.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2021 Arm Limited.
+ * Copyright (c) 2017-2021, 2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -29,12 +29,14 @@
#include "arm_compute/core/Validate.h"
#include "arm_compute/core/Window.h"
+#include "src/core/common/Registrars.h"
#include "src/core/CPP/Validate.h"
#include "src/core/helpers/AutoConfiguration.h"
#include "src/core/helpers/WindowHelpers.h"
#include "src/core/NEON/NEAsymm.h"
#include "src/core/NEON/NESymm.h"
#include "src/core/NEON/wrapper/wrapper.h"
+#include "src/cpu/kernels/dequantize/generic/neon/list.h"
#include <arm_neon.h>
@@ -62,301 +64,6 @@ Status validate_arguments(const ITensorInfo *src, const ITensorInfo *dst)
return Status{};
}
-
-template <typename T>
-inline void store_result(T *ptr, const float32x4x4_t &v)
-{
- ARM_COMPUTE_UNUSED(ptr, v);
-}
-
-template <>
-inline void store_result<float>(float *ptr, const float32x4x4_t &v)
-{
- wrapper::vstore(ptr, v.val[0]);
- wrapper::vstore(ptr + 4, v.val[1]);
- wrapper::vstore(ptr + 8, v.val[2]);
- wrapper::vstore(ptr + 12, v.val[3]);
-}
-
-#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
-template <>
-inline void store_result<float16_t>(float16_t *ptr, const float32x4x4_t &v)
-{
- wrapper::vstore(ptr, vcombine_f16(vcvt_f16_f32(v.val[0]), vcvt_f16_f32(v.val[1])));
- wrapper::vstore(ptr + 8, vcombine_f16(vcvt_f16_f32(v.val[2]), vcvt_f16_f32(v.val[3])));
-}
-#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
-
-template <typename T>
-inline void store_result(T *ptr, const float32x4x2_t &v)
-{
- ARM_COMPUTE_UNUSED(ptr, v);
-}
-
-template <>
-inline void store_result<float>(float *ptr, const float32x4x2_t &v)
-{
- wrapper::vstore(ptr, v.val[0]);
- wrapper::vstore(ptr + 4, v.val[1]);
-}
-
-#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
-template <>
-inline void store_result<float16_t>(float16_t *ptr, const float32x4x2_t &v)
-{
- wrapper::vstore(ptr, vcombine_f16(vcvt_f16_f32(v.val[0]), vcvt_f16_f32(v.val[1])));
-}
-#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
-
-template <typename TOut, typename TIn>
-void run_dequantization_qasymm8(const ITensor *input, ITensor *output, const Window &window)
-{
- const UniformQuantizationInfo &qinfo = input->info()->quantization_info().uniform();
- const float scale = qinfo.scale;
- const int32_t offset = qinfo.offset;
-
- const int window_step_x = 16;
- const auto window_start_x = static_cast<int>(window.x().start());
- const auto window_end_x = static_cast<int>(window.x().end());
-
- // Collapse window and reset first dimension to handle tail calculations manually
- Window win_collapsed = window.collapse_if_possible(window, Window::DimZ);
- win_collapsed.set(Window::DimX, Window::Dimension(0, 1, 1));
-
- // Create iterators
- Iterator in(input, win_collapsed);
- Iterator out(output, win_collapsed);
-
- execute_window_loop(
- win_collapsed,
- [&](const Coordinates &)
- {
- const auto in_ptr = reinterpret_cast<const TIn *>(in.ptr());
- const auto out_ptr = reinterpret_cast<TOut *>(out.ptr());
-
- int x = window_start_x;
- for (; x <= (window_end_x - window_step_x); x += window_step_x)
- {
- const auto vin = wrapper::vloadq(in_ptr + x);
- const auto vdeq = vdequantize(vin, scale, offset);
-
- store_result(reinterpret_cast<TOut *>(out_ptr + x), vdeq);
- }
-
- // Compute left-over elements
- for (; x < window_end_x; ++x)
- {
- auto val = *(in_ptr + x);
- *(out_ptr + x) = static_cast<TOut>(Qasymm8QuantizationHelper<TIn>::dequantize(val, qinfo));
- }
- },
- in, out);
-}
-
-template <typename T>
-void run_dequantization_qsymm8_per_channel_nchw(const ITensor *input, ITensor *output, const Window &window)
-{
- const auto scale = input->info()->quantization_info().scale();
-
- const int window_step_x = 16;
- const auto window_start_x = static_cast<int>(window.x().start());
- const auto window_end_x = static_cast<int>(window.x().end());
-
- // Reset first dimension to handle tail calculations manually
- Window win(window);
- win.set(Window::DimX, Window::Dimension(0, 1, 1));
-
- // Create iterators
- Iterator in(input, win);
- Iterator out(output, win);
-
- execute_window_loop(
- win,
- [&](const Coordinates &id)
- {
- const auto in_ptr = reinterpret_cast<const int8_t *>(in.ptr());
- const auto out_ptr = reinterpret_cast<T *>(out.ptr());
-
- int x = window_start_x;
- for (; x <= (window_end_x - window_step_x); x += window_step_x)
- {
- const auto vin = wrapper::vloadq(in_ptr + x);
- const auto vdeq = vdequantize(vin, scale[id.z()]);
-
- store_result<T>(reinterpret_cast<T *>(out_ptr + x), vdeq);
- }
-
- // Compute left-over elements
- for (; x < window_end_x; ++x)
- {
- int8_t val = *(in_ptr + x);
- *(out_ptr + x) = static_cast<T>(dequantize(val, scale[id.z()]));
- }
- },
- in, out);
-}
-
-template <typename T>
-void run_dequantization_qsymm8_per_channel_nhwc(const ITensor *input, ITensor *output, const Window &window)
-{
- const auto scale = input->info()->quantization_info().scale();
-
- const int window_step_x = 16;
- const auto window_start_x = static_cast<int>(window.x().start());
- const auto window_end_x = static_cast<int>(window.x().end());
-
- // Reset first dimension to handle tail calculations manually
- Window win(window);
- win.set(Window::DimX, Window::Dimension(0, 1, 1));
-
- // Create iterators
- Iterator in(input, win);
- Iterator out(output, win);
-
- execute_window_loop(
- win,
- [&](const Coordinates &)
- {
- const auto in_ptr = reinterpret_cast<const int8_t *>(in.ptr());
- const auto out_ptr = reinterpret_cast<T *>(out.ptr());
-
- int x = window_start_x;
- for (; x <= (window_end_x - window_step_x); x += window_step_x)
- {
- const float32x4x4_t vscale = {{scale[x + 0], scale[x + 1], scale[x + 2], scale[x + 3], scale[x + 4],
- scale[x + 5], scale[x + 6], scale[x + 7], scale[x + 8], scale[x + 9],
- scale[x + 10], scale[x + 11], scale[x + 12], scale[x + 13],
- scale[x + 14], scale[x + 15]}};
- const auto vin = wrapper::vloadq(in_ptr + x);
- const auto vdeq = vdequantize(vin, vscale);
-
- store_result<T>(reinterpret_cast<T *>(out_ptr + x), vdeq);
- }
-
- // Compute left-over elements
- for (; x < window_end_x; ++x)
- {
- int8_t val = *(in_ptr + x);
- *(out_ptr + x) = static_cast<T>(dequantize(val, scale[x]));
- }
- },
- in, out);
-}
-
-template <typename T>
-void run_dequantization_qsymm8(const ITensor *input, ITensor *output, const Window &window)
-{
- const UniformQuantizationInfo &qinfo = input->info()->quantization_info().uniform();
- const float scale = qinfo.scale;
-
- const int window_step_x = 16;
- const auto window_start_x = static_cast<int>(window.x().start());
- const auto window_end_x = static_cast<int>(window.x().end());
-
- // Collapse window and reset first dimension to handle tail calculations manually
- Window win_collapsed = window.collapse_if_possible(window, Window::DimZ);
- win_collapsed.set(Window::DimX, Window::Dimension(0, 1, 1));
-
- // Create iterators
- Iterator in(input, win_collapsed);
- Iterator out(output, win_collapsed);
-
- execute_window_loop(
- win_collapsed,
- [&](const Coordinates &)
- {
- const auto in_ptr = reinterpret_cast<const int8_t *>(in.ptr());
- const auto out_ptr = reinterpret_cast<T *>(out.ptr());
-
- int x = window_start_x;
- for (; x <= (window_end_x - window_step_x); x += window_step_x)
- {
- const auto vin = wrapper::vloadq(in_ptr + x);
- const auto vdeq = vdequantize(vin, scale);
-
- store_result<T>(reinterpret_cast<T *>(out_ptr + x), vdeq);
- }
-
- // Compute left-over elements
- for (; x < window_end_x; ++x)
- {
- int8_t val = *(in_ptr + x);
- *(out_ptr + x) = static_cast<T>(dequantize(val, scale));
- }
- },
- in, out);
-}
-
-template <typename T>
-void run_dequantization_qsymm16(const ITensor *input, ITensor *output, const Window &window)
-{
- const UniformQuantizationInfo &qinfo = input->info()->quantization_info().uniform();
- const float scale = qinfo.scale;
-
- const int window_step_x = 8;
- const auto window_start_x = static_cast<int>(window.x().start());
- const auto window_end_x = static_cast<int>(window.x().end());
-
- // Collapse window and reset first dimension to handle tail calculations manually
- Window win_collapsed = window.collapse_if_possible(window, Window::DimZ);
- win_collapsed.set(Window::DimX, Window::Dimension(0, 1, 1));
-
- // Create iterators
- Iterator in(input, win_collapsed);
- Iterator out(output, win_collapsed);
-
- execute_window_loop(
- win_collapsed,
- [&](const Coordinates &)
- {
- const auto in_ptr = reinterpret_cast<const int16_t *>(in.ptr());
- const auto out_ptr = reinterpret_cast<T *>(out.ptr());
-
- int x = window_start_x;
- for (; x <= (window_end_x - window_step_x); x += window_step_x)
- {
- const auto vin = wrapper::vloadq(in_ptr + x);
- const auto vdeq = vdequantize_int16(vin, scale);
-
- store_result<T>(reinterpret_cast<T *>(out_ptr + x), vdeq);
- }
-
- // Compute left-over elements
- for (; x < window_end_x; ++x)
- {
- int16_t val = *(in_ptr + x);
- *(out_ptr + x) = static_cast<T>(dequantize_qsymm16(val, scale));
- }
- },
- in, out);
-}
-
-template <typename T>
-void run_dequantization_core(const ITensor *input, ITensor *output, const Window &window)
-{
- switch (input->info()->data_type())
- {
- case DataType::QASYMM8:
- run_dequantization_qasymm8<T, uint8_t>(input, output, window);
- break;
- case DataType::QASYMM8_SIGNED:
- run_dequantization_qasymm8<T, int8_t>(input, output, window);
- break;
- case DataType::QSYMM8_PER_CHANNEL:
- input->info()->data_layout() == DataLayout::NHWC
- ? run_dequantization_qsymm8_per_channel_nhwc<T>(input, output, window)
- : run_dequantization_qsymm8_per_channel_nchw<T>(input, output, window);
- break;
- case DataType::QSYMM8:
- run_dequantization_qsymm8<T>(input, output, window);
- break;
- case DataType::QSYMM16:
- run_dequantization_qsymm16<T>(input, output, window);
- break;
- default:
- ARM_COMPUTE_ERROR("Unsupported data type.");
- }
-}
} // namespace
void CpuDequantizeKernel::configure(const ITensorInfo *src, ITensorInfo *dst)
@@ -370,6 +77,20 @@ void CpuDequantizeKernel::configure(const ITensorInfo *src, ITensorInfo *dst)
auto_init_if_empty(*dst, src->tensor_shape(), 1, DataType::F32);
ICpuKernel::configure(win);
+
+ switch (dst->data_type())
+ {
+ case DataType::F32:
+ _func = REGISTER_FP32_NEON(fp32_run_dequantization_core);
+ break;
+#ifdef ARM_COMPUTE_ENABLE_FP16
+ case DataType::F16:
+ _func = REGISTER_FP16_NEON(fp16_run_dequantization_core);
+ break;
+#endif /* ARM_COMPUTE_ENABLE_FP16 */
+ default:
+ ARM_COMPUTE_ERROR("Unsupported data type.");
+ }
}
Status CpuDequantizeKernel::validate(const ITensorInfo *src, const ITensorInfo *dst)
@@ -386,20 +107,7 @@ void CpuDequantizeKernel::run_op(ITensorPack &tensors, const Window &window, con
const auto src = tensors.get_const_tensor(TensorType::ACL_SRC);
auto dst = tensors.get_tensor(TensorType::ACL_DST);
-
- switch (dst->info()->data_type())
- {
- case DataType::F32:
- run_dequantization_core<float>(src, dst, window);
- break;
-#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
- case DataType::F16:
- run_dequantization_core<float16_t>(src, dst, window);
- break;
-#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
- default:
- ARM_COMPUTE_ERROR("Unsupported data type.");
- }
+ _func(src, dst, window);
}
const char *CpuDequantizeKernel::name() const
{
diff --git a/src/cpu/kernels/CpuDequantizeKernel.h b/src/cpu/kernels/CpuDequantizeKernel.h
index 6ed58587c9..d8b6444f0a 100644
--- a/src/cpu/kernels/CpuDequantizeKernel.h
+++ b/src/cpu/kernels/CpuDequantizeKernel.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2022 Arm Limited.
+ * Copyright (c) 2017-2022, 2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -21,8 +21,8 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
-#ifndef ARM_COMPUTE_CPU_DEQUANTIZE_KERNEL_H
-#define ARM_COMPUTE_CPU_DEQUANTIZE_KERNEL_H
+#ifndef ACL_SRC_CPU_KERNELS_CPUDEQUANTIZEKERNEL_H
+#define ACL_SRC_CPU_KERNELS_CPUDEQUANTIZEKERNEL_H
#include "src/core/common/Macros.h"
#include "src/cpu/ICpuKernel.h"
@@ -56,8 +56,16 @@ public:
// Inherited methods overridden:
void run_op(ITensorPack &tensors, const Window &window, const ThreadInfo &info) override;
const char *name() const override;
+
+private:
+ /** Common signature for all the specialised @ref CpuDequantizeKernel functions
+ *
+ * @param[in] window Region on which to execute the kernel.
+ */
+ using DequantizeFunctionExecutorPtr = void (*)(const ITensor *input, ITensor *output, const Window &window);
+ DequantizeFunctionExecutorPtr _func{nullptr};
};
} // namespace kernels
} // namespace cpu
} // namespace arm_compute
-#endif /* ARM_COMPUTE_CPU_DEQUANTIZE_KERNEL_H */
+#endif // ACL_SRC_CPU_KERNELS_CPUDEQUANTIZEKERNEL_H
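
Note: all of the removed run_dequantization_* variants implement the same affine mapping and differ only in where scale and offset come from; the refactor moves them behind a single function pointer bound once in configure(). In scalar form (a sketch, not the library code):

#include <cstdint>

// Affine dequantization, common to every removed variant: the stored integer
// q maps back to scale * (q - offset). The QSYMM types use offset == 0, and
// QSYMM8_PER_CHANNEL looks scale up per channel instead of per tensor.
inline float dequantize(uint8_t q, float scale, int32_t offset)
{
    return scale * (static_cast<int32_t>(q) - offset);
}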
diff --git a/src/cpu/kernels/CpuGemmLowpOffsetContributionKernel.cpp b/src/cpu/kernels/CpuGemmLowpOffsetContributionKernel.cpp
index e290783021..2a76a5958d 100644
--- a/src/cpu/kernels/CpuGemmLowpOffsetContributionKernel.cpp
+++ b/src/cpu/kernels/CpuGemmLowpOffsetContributionKernel.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2022 Arm Limited.
+ * Copyright (c) 2017-2022,2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -51,17 +51,19 @@ Status validate_arguments(const ITensorInfo *mm_result,
int32_t a_offset,
int32_t b_offset)
{
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(mm_result, 1, DataType::S32);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(mm_result, 1, DataType::S32, DataType::F32);
- // If a_offset == 0, vector_sum_col can be a nullptr
- if (a_offset != 0)
+    // Run if the offset is non-zero or a column-sum tensor has been provided;
+    // the latter is needed when the QuantizationInfo is dynamic
+ if (a_offset != 0 || vector_sum_col != nullptr)
{
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(vector_sum_col, 1, DataType::S32);
ARM_COMPUTE_RETURN_ERROR_ON(vector_sum_col->dimension(0) != mm_result->dimension(0));
}
- // If b_offset == 0, vector_sum_row can be a nullptr
- if (b_offset != 0)
+    // Run if the offset is non-zero or a row-sum tensor has been provided;
+    // the latter is needed when the QuantizationInfo is dynamic
+ if (b_offset != 0 || vector_sum_row != nullptr)
{
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(vector_sum_row, 1, DataType::S32);
@@ -86,7 +88,7 @@ Status validate_arguments(const ITensorInfo *mm_result,
ARM_COMPUTE_RETURN_ERROR_ON_MSG(vector_sum_row_shape[1] != output_shape[output_batch_idx],
"mm_result tensor must have the same number of batches of output tensor");
- if (a_offset != 0)
+ if (vector_sum_col != nullptr)
{
TensorShape vector_sum_col_shape = vector_sum_col->tensor_shape();
vector_sum_col_shape.collapse_from(1);
@@ -102,6 +104,275 @@ Status validate_arguments(const ITensorInfo *mm_result,
return Status{};
}
+void run_offset_contribution_float(const Window &window,
+ ITensor *mm_result,
+ const ITensor *vector_sum_col,
+ const ITensor *vector_sum_row,
+ int32_t a_offset,
+ int32_t b_offset,
+ int32_t k_offset,
+ float scale,
+ bool slide_vector_sum_col,
+ bool is_gemm3d)
+{
+ Window collapsed_window = window.collapse_if_possible(window, Window::DimZ);
+ collapsed_window.set(Window::DimX, Window::Dimension(0, 1, 1));
+
+ const int height_input = is_gemm3d ? mm_result->info()->dimension(1) : 0;
+ const int depth_input = is_gemm3d ? mm_result->info()->dimension(2) : 1;
+
+ const int window_start_x = window.x().start();
+ const int window_end_x = window.x().end();
+ const int window_step_x = 16;
+
+ // if vector_sum_col is nullptr then stride_y is 0, else get stride_y
+ const size_t sum_col_stride_y = (vector_sum_col != nullptr) ? (vector_sum_col->info()->strides_in_bytes().y()) : 0;
+ Iterator mm_result_it(mm_result, collapsed_window);
+
+ if ((a_offset != 0) && (b_offset != 0) && (vector_sum_col != nullptr) && (vector_sum_row != nullptr)) // true, true
+ {
+ // Set window for vector_sum_col
+ Window win_vector_sum_col(collapsed_window);
+ win_vector_sum_col.set(Window::DimY, Window::Dimension(0, 0, 0));
+ win_vector_sum_col.set(Window::DimZ, Window::Dimension(0, 0, 0));
+
+ // Set window for vector_sum_row
+ Window win_vector_sum_row(collapsed_window);
+ win_vector_sum_row.set(Window::DimX, Window::Dimension(0, 0, 0));
+ win_vector_sum_row.set(Window::DimY, Window::Dimension(0, 0, 0));
+ win_vector_sum_row.set(Window::DimZ, Window::Dimension(0, 0, 0));
+
+ Iterator vector_sum_col_it(vector_sum_col, win_vector_sum_col);
+ Iterator vector_sum_row_it(vector_sum_row, win_vector_sum_row);
+
+ const size_t sum_row_stride_y = vector_sum_row->info()->strides_in_bytes().y();
+
+ // Offset in case vector_sum_col is batched
+ const int vector_sum_col_batch_offset =
+ slide_vector_sum_col ? vector_sum_col->info()->strides_in_bytes().z() : 0;
+
+ execute_window_loop(
+ collapsed_window,
+ [&](const Coordinates &id)
+ {
+ const int batch_id = id.z() / depth_input;
+ const size_t batch_offset_col = batch_id * (sum_col_stride_y);
+ auto vector_sum_col_ptr = reinterpret_cast<const int32_t *>(vector_sum_col_it.ptr() + batch_offset_col +
+ batch_id * vector_sum_col_batch_offset);
+ auto mm_result_ptr = reinterpret_cast<float *>(mm_result_it.ptr());
+
+ // Compute the leftover term due to b_offset.
+ int32_t b_offset_term_s32 =
+ *(reinterpret_cast<const int32_t *>(vector_sum_row_it.ptr() + batch_id * sum_row_stride_y) +
+ id.y() + (id.z() % depth_input) * height_input);
+ b_offset_term_s32 *= b_offset;
+
+ const int32x4_t b_offset_term_s32_vec = vdupq_n_s32(b_offset_term_s32);
+
+ int x = window_start_x;
+ for (; x <= (window_end_x - window_step_x); x += window_step_x)
+ {
+ // Compute the leftover term due to a_offset.
+ int32x4x4_t a_offset_term_s32 = {
+ {vld1q_s32(vector_sum_col_ptr + x + 0), vld1q_s32(vector_sum_col_ptr + x + 4),
+ vld1q_s32(vector_sum_col_ptr + x + 8), vld1q_s32(vector_sum_col_ptr + x + 12)}};
+
+ a_offset_term_s32.val[0] = vmulq_n_s32(a_offset_term_s32.val[0], a_offset);
+ a_offset_term_s32.val[1] = vmulq_n_s32(a_offset_term_s32.val[1], a_offset);
+ a_offset_term_s32.val[2] = vmulq_n_s32(a_offset_term_s32.val[2], a_offset);
+ a_offset_term_s32.val[3] = vmulq_n_s32(a_offset_term_s32.val[3], a_offset);
+
+ // Add a_offset_term_s32 and b_offset_term_s32
+ int32x4x4_t offset_term_s32 = {
+ {vdupq_n_s32(k_offset), vdupq_n_s32(k_offset), vdupq_n_s32(k_offset), vdupq_n_s32(k_offset)}};
+
+ offset_term_s32.val[0] =
+ vaddq_s32(offset_term_s32.val[0], vaddq_s32(a_offset_term_s32.val[0], b_offset_term_s32_vec));
+ offset_term_s32.val[1] =
+ vaddq_s32(offset_term_s32.val[1], vaddq_s32(a_offset_term_s32.val[1], b_offset_term_s32_vec));
+ offset_term_s32.val[2] =
+ vaddq_s32(offset_term_s32.val[2], vaddq_s32(a_offset_term_s32.val[2], b_offset_term_s32_vec));
+ offset_term_s32.val[3] =
+ vaddq_s32(offset_term_s32.val[3], vaddq_s32(a_offset_term_s32.val[3], b_offset_term_s32_vec));
+
+ float32x4x4_t in_f32 = {{vld1q_f32(mm_result_ptr + x + 0), vld1q_f32(mm_result_ptr + x + 4),
+ vld1q_f32(mm_result_ptr + x + 8), vld1q_f32(mm_result_ptr + x + 12)}};
+
+ // Convert and scale the S32 offsets to match the already scaled GEMM results
+ float32x4x4_t offset_terms_scaled = {{
+ vmulq_n_f32(vcvtq_f32_s32(offset_term_s32.val[0]), scale),
+ vmulq_n_f32(vcvtq_f32_s32(offset_term_s32.val[1]), scale),
+ vmulq_n_f32(vcvtq_f32_s32(offset_term_s32.val[2]), scale),
+ vmulq_n_f32(vcvtq_f32_s32(offset_term_s32.val[3]), scale),
+ }};
+
+ // Add the offset terms to the GEMM result
+ in_f32.val[0] = vaddq_f32(in_f32.val[0], offset_terms_scaled.val[0]);
+ in_f32.val[1] = vaddq_f32(in_f32.val[1], offset_terms_scaled.val[1]);
+ in_f32.val[2] = vaddq_f32(in_f32.val[2], offset_terms_scaled.val[2]);
+ in_f32.val[3] = vaddq_f32(in_f32.val[3], offset_terms_scaled.val[3]);
+
+ // Store the result with the offset contribution
+ vst1q_f32(mm_result_ptr + x + 0, in_f32.val[0]);
+ vst1q_f32(mm_result_ptr + x + 4, in_f32.val[1]);
+ vst1q_f32(mm_result_ptr + x + 8, in_f32.val[2]);
+ vst1q_f32(mm_result_ptr + x + 12, in_f32.val[3]);
+ }
+
+ // Left-overs loop
+ for (; x < window_end_x; ++x)
+ {
+ // Compute the leftover term due to a_offset.
+ int32_t a_offset_term_s32 = *(vector_sum_col_ptr + x);
+
+ a_offset_term_s32 *= a_offset;
+
+ // Add the offset terms to GEMM's result
+ // Store the result with the offset contribution
+ mm_result_ptr[x] += (k_offset + a_offset_term_s32 + b_offset_term_s32) * scale;
+ }
+ },
+ vector_sum_col_it, vector_sum_row_it, mm_result_it);
+ }
+ else if ((a_offset == 0) && (b_offset != 0) && (vector_sum_row != nullptr)) // false, true
+ {
+ ARM_COMPUTE_ERROR_ON_NULLPTR(vector_sum_row);
+
+ // Set window for vector_sum_row
+ Window win_vector_sum_row(collapsed_window);
+ win_vector_sum_row.set(Window::DimX, Window::Dimension(0, 0, 0));
+ win_vector_sum_row.set(Window::DimY, Window::Dimension(0, 0, 0));
+ win_vector_sum_row.set(Window::DimZ, Window::Dimension(0, 0, 0));
+
+ Iterator vector_sum_row_it(vector_sum_row, win_vector_sum_row);
+
+ const size_t sum_row_stride_y = vector_sum_row->info()->strides_in_bytes().y();
+
+ execute_window_loop(
+ collapsed_window,
+ [&](const Coordinates &id)
+ {
+ const int batch_id = id.z() / depth_input;
+ auto mm_result_ptr = reinterpret_cast<float *>(mm_result_it.ptr());
+
+ // Compute the leftover term due to b_offset.
+ int32_t row_sum =
+ *(reinterpret_cast<const int32_t *>(vector_sum_row_it.ptr() + batch_id * sum_row_stride_y) +
+ id.y() + (id.z() % depth_input) * height_input);
+ float scaled_b_offset_term_f32 = row_sum * b_offset * scale;
+
+ const float32x4_t b_offset_term_f32_vec = vdupq_n_f32(scaled_b_offset_term_f32);
+
+ int x = window_start_x;
+ for (; x <= (window_end_x - window_step_x); x += window_step_x)
+ {
+ float32x4x4_t in_f32 = {{vld1q_f32(mm_result_ptr + x + 0), vld1q_f32(mm_result_ptr + x + 4),
+ vld1q_f32(mm_result_ptr + x + 8), vld1q_f32(mm_result_ptr + x + 12)}};
+
+ // Add the offset terms to GEMM's result
+ in_f32.val[0] = vaddq_f32(in_f32.val[0], b_offset_term_f32_vec);
+ in_f32.val[1] = vaddq_f32(in_f32.val[1], b_offset_term_f32_vec);
+ in_f32.val[2] = vaddq_f32(in_f32.val[2], b_offset_term_f32_vec);
+ in_f32.val[3] = vaddq_f32(in_f32.val[3], b_offset_term_f32_vec);
+
+ // Store the result with the offset contribution
+ vst1q_f32(mm_result_ptr + x + 0, in_f32.val[0]);
+ vst1q_f32(mm_result_ptr + x + 4, in_f32.val[1]);
+ vst1q_f32(mm_result_ptr + x + 8, in_f32.val[2]);
+ vst1q_f32(mm_result_ptr + x + 12, in_f32.val[3]);
+ }
+
+ // Left-overs loop
+ for (; x < window_end_x; ++x)
+ {
+ // Add the offset terms to GEMM's result
+ // Store the result with the offset contribution
+ mm_result_ptr[x] += scaled_b_offset_term_f32;
+ }
+ },
+ vector_sum_row_it, mm_result_it);
+ }
+ else if ((a_offset != 0) && (b_offset == 0) && (vector_sum_col != nullptr)) // true, false
+ {
+ // Set window for vector_sum_col
+ Window win_vector_sum_col(collapsed_window);
+ win_vector_sum_col.set(Window::DimY, Window::Dimension(0, 0, 0));
+ win_vector_sum_col.set(Window::DimZ, Window::Dimension(0, 0, 0));
+
+ Iterator vector_sum_col_it(vector_sum_col, win_vector_sum_col);
+
+ // Offset in case vector_sum_col is batched
+ const int vector_sum_col_batch_offset =
+ slide_vector_sum_col ? vector_sum_col->info()->strides_in_bytes().z() : 0;
+
+ execute_window_loop(
+ collapsed_window,
+ [&](const Coordinates &id)
+ {
+ const int batch_id = id.z() / depth_input;
+ const size_t batch_offset_col =
+ batch_id *
+ (sum_col_stride_y); // Value to offset vector_sum_col_ptr to allow for iteration of y values in tensor
+ auto vector_sum_col_ptr = reinterpret_cast<const int32_t *>(vector_sum_col_it.ptr() + batch_offset_col +
+ batch_id * vector_sum_col_batch_offset);
+ auto mm_result_ptr = reinterpret_cast<float *>(mm_result_it.ptr());
+
+ int x = window_start_x;
+ for (; x <= (window_end_x - window_step_x); x += window_step_x)
+ {
+ // Compute the leftover term due to a_offset.
+ int32x4x4_t a_offset_term_s32 = {
+ {vld1q_s32(vector_sum_col_ptr + x + 0), vld1q_s32(vector_sum_col_ptr + x + 4),
+ vld1q_s32(vector_sum_col_ptr + x + 8), vld1q_s32(vector_sum_col_ptr + x + 12)}};
+
+ a_offset_term_s32.val[0] = vmulq_n_s32(a_offset_term_s32.val[0], a_offset);
+ a_offset_term_s32.val[1] = vmulq_n_s32(a_offset_term_s32.val[1], a_offset);
+ a_offset_term_s32.val[2] = vmulq_n_s32(a_offset_term_s32.val[2], a_offset);
+ a_offset_term_s32.val[3] = vmulq_n_s32(a_offset_term_s32.val[3], a_offset);
+
+ float32x4x4_t a_offset_term_scaled = {{
+ vmulq_n_f32(vcvtq_f32_s32(a_offset_term_s32.val[0]), scale),
+ vmulq_n_f32(vcvtq_f32_s32(a_offset_term_s32.val[1]), scale),
+ vmulq_n_f32(vcvtq_f32_s32(a_offset_term_s32.val[2]), scale),
+ vmulq_n_f32(vcvtq_f32_s32(a_offset_term_s32.val[3]), scale),
+ }};
+
+ float32x4x4_t in_f32 = {{vld1q_f32(mm_result_ptr + x + 0), vld1q_f32(mm_result_ptr + x + 4),
+ vld1q_f32(mm_result_ptr + x + 8), vld1q_f32(mm_result_ptr + x + 12)}};
+
+ // Add the offset terms to GEMM's result
+ in_f32.val[0] = vaddq_f32(in_f32.val[0], a_offset_term_scaled.val[0]);
+ in_f32.val[1] = vaddq_f32(in_f32.val[1], a_offset_term_scaled.val[1]);
+ in_f32.val[2] = vaddq_f32(in_f32.val[2], a_offset_term_scaled.val[2]);
+ in_f32.val[3] = vaddq_f32(in_f32.val[3], a_offset_term_scaled.val[3]);
+
+ // Store the result with the offset contribution
+ vst1q_f32(mm_result_ptr + x + 0, in_f32.val[0]);
+ vst1q_f32(mm_result_ptr + x + 4, in_f32.val[1]);
+ vst1q_f32(mm_result_ptr + x + 8, in_f32.val[2]);
+ vst1q_f32(mm_result_ptr + x + 12, in_f32.val[3]);
+ }
+
+ // Left-overs loop
+ for (; x < window_end_x; ++x)
+ {
+ // Compute the leftover term due to a_offset.
+ const int32_t a_offset_term_s32 = *(vector_sum_col_ptr + x);
+
+ // Add the offset terms to GEMM's result
+ // Store the result with the offset contribution
+ mm_result_ptr[x] += a_offset_term_s32 * a_offset * scale;
+ }
+ },
+ vector_sum_col_it, mm_result_it);
+ }
+ else // false, false
+ {
+ // No offset contribution from matrix A and matrix B
+ return;
+ }
+}
+
void run_offset_contribution(const Window &window,
ITensor *mm_result,
const ITensor *vector_sum_col,
@@ -361,7 +632,8 @@ void CpuGemmLowpOffsetContributionKernel::configure(ITensorInfo *mm_result,
ITensorInfo *vector_sum_row,
int32_t k,
int32_t a_offset,
- int32_t b_offset)
+ int32_t b_offset,
+ float scale)
{
// Perform validate step
ARM_COMPUTE_UNUSED(vector_sum_row);
@@ -370,10 +642,11 @@ void CpuGemmLowpOffsetContributionKernel::configure(ITensorInfo *mm_result,
_a_offset = a_offset;
_b_offset = b_offset;
- _k_offset = a_offset * b_offset * k;
+ _k = k;
- // If a_offset == 0, vector_sum_col can be a nullptr
- if (a_offset != 0)
+ _scale = scale;
+
+ if (vector_sum_col != nullptr)
{
// Check if vector_sum_col_shape should be slidden or not
// Don't slide vector_sum_col_shape along the y dimension if vector_sum_col_shape has just 1 dimension and vector_sum_row_shape more than 1
@@ -386,6 +659,21 @@ void CpuGemmLowpOffsetContributionKernel::configure(ITensorInfo *mm_result,
ICpuKernel::configure(win);
}
+void CpuGemmLowpOffsetContributionKernel::set_a_offset(int32_t a_offset)
+{
+ _a_offset = a_offset;
+}
+
+void CpuGemmLowpOffsetContributionKernel::set_b_offset(int32_t b_offset)
+{
+ _b_offset = b_offset;
+}
+
+void CpuGemmLowpOffsetContributionKernel::set_scale(float scale)
+{
+ _scale = scale;
+}
+
Status CpuGemmLowpOffsetContributionKernel::validate(const ITensorInfo *mm_result,
const ITensorInfo *vector_sum_col,
const ITensorInfo *vector_sum_row,
@@ -410,8 +698,18 @@ void CpuGemmLowpOffsetContributionKernel::run_op(ITensorPack &tensors, const Win
const bool reinterpret_as_3d = vector_sum_row != nullptr && mm_result->info()->num_dimensions() > 1 &&
mm_result->info()->tensor_shape().y() != vector_sum_row->info()->tensor_shape().x();
- run_offset_contribution(window, mm_result, vector_sum_col, vector_sum_row, _a_offset, _b_offset, _k_offset,
- _slide_vector_sum_col, reinterpret_as_3d);
+    // Select the implementation based on the data type of the result
+ auto k_offset = _a_offset * _b_offset * _k;
+ if (mm_result->info()->data_type() == DataType::F32)
+ {
+ run_offset_contribution_float(window, mm_result, vector_sum_col, vector_sum_row, _a_offset, _b_offset, k_offset,
+ _scale, _slide_vector_sum_col, reinterpret_as_3d);
+ }
+ else
+ {
+ run_offset_contribution(window, mm_result, vector_sum_col, vector_sum_row, _a_offset, _b_offset, k_offset,
+ _slide_vector_sum_col, reinterpret_as_3d);
+ }
}
const char *CpuGemmLowpOffsetContributionKernel::name() const
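
Note: the float path added above applies the same algebra as the integer path. Expanding the offset-adjusted product over the inner dimension K gives

  sum_k (qa[i,k] + a_off) * (qb[k,j] + b_off)
      = mm_result[i,j] + b_off * row_sum_A[i] + a_off * col_sum_B[j] + K * a_off * b_off

which is why vector_sum_row multiplies b_offset, vector_sum_col multiplies a_offset, and k_offset = a_offset * b_offset * k. The only difference in run_offset_contribution_float is that every correction term is further multiplied by scale, since mm_result has already been dequantized to F32. A scalar model of one element, matching the left-over loops above (names are illustrative):

// Scalar model of run_offset_contribution_float for a single element (i, j).
// sum_row_i = sum over k of A(i, k); sum_col_j = sum over k of B(k, j).
inline void offset_contribution_f32(float  &mm_result,
                                    int32_t sum_col_j,
                                    int32_t sum_row_i,
                                    int32_t a_offset,
                                    int32_t b_offset,
                                    int32_t k,
                                    float   scale)
{
    const int32_t k_offset = a_offset * b_offset * k; // as recomputed in run_op
    mm_result += (a_offset * sum_col_j + b_offset * sum_row_i + k_offset) * scale;
}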
diff --git a/src/cpu/kernels/CpuGemmLowpOffsetContributionKernel.h b/src/cpu/kernels/CpuGemmLowpOffsetContributionKernel.h
index 08b2d47529..ecbfb0c282 100644
--- a/src/cpu/kernels/CpuGemmLowpOffsetContributionKernel.h
+++ b/src/cpu/kernels/CpuGemmLowpOffsetContributionKernel.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2022 Arm Limited.
+ * Copyright (c) 2017-2022,2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -21,12 +21,14 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
-#ifndef ARM_COMPUTE_CPU_GEMMLOWP_OFFSETCONTRIBUTION_KERNEL_H
-#define ARM_COMPUTE_CPU_GEMMLOWP_OFFSETCONTRIBUTION_KERNEL_H
+#ifndef ACL_SRC_CPU_KERNELS_CPUGEMMLOWPOFFSETCONTRIBUTIONKERNEL_H
+#define ACL_SRC_CPU_KERNELS_CPUGEMMLOWPOFFSETCONTRIBUTIONKERNEL_H
#include "src/core/common/Macros.h"
#include "src/cpu/ICpuKernel.h"
+#include <cstdint>
+
namespace arm_compute
{
namespace cpu
@@ -62,13 +64,16 @@ public:
* @param[in] k Number of matrix A columns or Matrix B rows
* @param[in] a_offset Offset to be added to each element of the matrix A.
* @param[in] b_offset Offset to be added to each element of the matrix B.
+     * @param[in] scale            (Optional) Multiplies the contribution so that it matches the scale of the dst in the case
+     *                             where mm_result is float (and so has already been scaled). Default is 1.0
*/
void configure(ITensorInfo *mm_result,
ITensorInfo *vector_sum_col,
ITensorInfo *vector_sum_row,
int32_t k,
int32_t a_offset,
- int32_t b_offset);
+ int32_t b_offset,
+ float scale = 1.0f);
/** Static function to check if given info will lead to a valid configuration
*
* Similar to CpuGemmLowpOffsetContributionKernel::configure()
@@ -81,6 +86,29 @@ public:
int32_t a_offset,
int32_t b_offset);
+ /** Set the a offset
+ * Warning: if a_offset is non-zero then vector_sum_col must be set in run_op.
+     * Re-run configure or validate if you are unsure
+ *
+ * @param[in] a_offset Offset to be added to each element of the matrix A.
+ */
+ void set_a_offset(int32_t a_offset);
+
+ /** Set the b offset
+ * Warning: if b_offset is non-zero then vector_sum_row must be set in run_op.
+     * Re-run configure or validate if you are unsure
+ *
+ * @param[in] b_offset Offset to be added to each element of the matrix B.
+ */
+ void set_b_offset(int32_t b_offset);
+
+ /** Set the dequantize scale
+ *
+ * @param[in] scale Multiplies the contribution to make it the same scale as the dst in the case where
+ * mm_result is float (and so has already been scaled).
+ */
+ void set_scale(float scale);
+
// Inherited methods overridden:
void run_op(ITensorPack &tensors, const Window &window, const ThreadInfo &info) override;
const char *name() const override;
@@ -88,10 +116,11 @@ public:
private:
int32_t _a_offset{0};
int32_t _b_offset{0};
- int32_t _k_offset{0};
+ int32_t _k{0}; // Number of columns of A or rows of B, used in last offset term
+ float _scale{1.0};
bool _slide_vector_sum_col{true};
};
} // namespace kernels
} // namespace cpu
} // namespace arm_compute
-#endif /* ARM_COMPUTE_CPU_GEMMLOWP_OFFSETCONTRIBUTION_KERNEL_H */
+#endif // ACL_SRC_CPU_KERNELS_CPUGEMMLOWPOFFSETCONTRIBUTIONKERNEL_H
diff --git a/src/cpu/kernels/CpuGemmLowpOffsetContributionOutputStageKernel.cpp b/src/cpu/kernels/CpuGemmLowpOffsetContributionOutputStageKernel.cpp
index d008842398..3c113f2828 100644
--- a/src/cpu/kernels/CpuGemmLowpOffsetContributionOutputStageKernel.cpp
+++ b/src/cpu/kernels/CpuGemmLowpOffsetContributionOutputStageKernel.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2019-2021, 2023 Arm Limited.
+ * Copyright (c) 2019-2021, 2023-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -919,7 +919,7 @@ void CpuGemmLowpOffsetContributionOutputStageKernel::configure(const ITensorInfo
_a_offset = a_offset;
_b_offset = b_offset;
- _k_offset = a_offset * b_offset * k;
+ _k = k;
_output_stage = output_stage;
// If a_offset == 0, vector_sum_col can be a nullptr
@@ -958,6 +958,16 @@ Status CpuGemmLowpOffsetContributionOutputStageKernel::validate(const ITensorInf
return Status{};
}
+void CpuGemmLowpOffsetContributionOutputStageKernel::set_a_offset(int32_t a_offset)
+{
+ _a_offset = a_offset;
+}
+
+void CpuGemmLowpOffsetContributionOutputStageKernel::set_b_offset(int32_t b_offset)
+{
+ _b_offset = b_offset;
+}
+
void CpuGemmLowpOffsetContributionOutputStageKernel::run_op(ITensorPack &tensors,
const Window &window,
const ThreadInfo &info)
@@ -993,10 +1003,11 @@ void CpuGemmLowpOffsetContributionOutputStageKernel::run_op(ITensorPack &te
// Check if symmetric per-channel execution
const bool is_symm = _output_stage.is_quantized_per_channel;
+ auto k_offset = _a_offset * _b_offset * _k;
if (is_symm)
{
run_offset_contribution_output_stage_symm(window, mm_result, vector_sum_col, vector_sum_row, bias, dst,
- _a_offset, _b_offset, _k_offset, _is_vector_sum_col_batched,
+ _a_offset, _b_offset, k_offset, _is_vector_sum_col_batched,
_output_stage, reinterpret_as_3d, is_bounded_relu, is_fixed_point);
}
else
@@ -1004,13 +1015,13 @@ void CpuGemmLowpOffsetContributionOutputStageKernel::run_op(ITensorPack &te
if (is_signed)
{
run_offset_contribution_output_stage<int8_t>(
- window, mm_result, vector_sum_col, vector_sum_row, bias, dst, _a_offset, _b_offset, _k_offset,
+ window, mm_result, vector_sum_col, vector_sum_row, bias, dst, _a_offset, _b_offset, k_offset,
_is_vector_sum_col_batched, _output_stage, reinterpret_as_3d, is_bounded_relu, is_fixed_point);
}
else
{
run_offset_contribution_output_stage<uint8_t>(
- window, mm_result, vector_sum_col, vector_sum_row, bias, dst, _a_offset, _b_offset, _k_offset,
+ window, mm_result, vector_sum_col, vector_sum_row, bias, dst, _a_offset, _b_offset, k_offset,
_is_vector_sum_col_batched, _output_stage, reinterpret_as_3d, is_bounded_relu, is_fixed_point);
}
}
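
Note: as in the sibling kernel, caching _k instead of a precomputed _k_offset means run_op always derives k_offset from the current offsets, which is what makes the new set_a_offset()/set_b_offset() hooks safe for dynamically quantized inputs whose offsets are only known after configure(). A sketch of the intended call pattern (the surrounding code is illustrative):

// Offsets become known only at run time, e.g. with dynamic quantization;
// a_qinfo/b_qinfo stand for the runtime QuantizationInfo of each operand.
kernel.set_a_offset(a_qinfo.uniform().offset);
kernel.set_b_offset(b_qinfo.uniform().offset);
// run_op() then computes k_offset = a_offset * b_offset * k from the fresh
// values, so no stale product cached at configure() time can leak through.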
diff --git a/src/cpu/kernels/CpuGemmLowpOffsetContributionOutputStageKernel.h b/src/cpu/kernels/CpuGemmLowpOffsetContributionOutputStageKernel.h
index af477d4756..ff706ff3dc 100644
--- a/src/cpu/kernels/CpuGemmLowpOffsetContributionOutputStageKernel.h
+++ b/src/cpu/kernels/CpuGemmLowpOffsetContributionOutputStageKernel.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2019-2022 Arm Limited.
+ * Copyright (c) 2019-2022, 2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -21,8 +21,8 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
-#ifndef ARM_COMPUTE_CPU_GEMMLOWP_OFFSETCONTRIBUTION_OUTPUTSTAGE_KERNEL_H
-#define ARM_COMPUTE_CPU_GEMMLOWP_OFFSETCONTRIBUTION_OUTPUTSTAGE_KERNEL_H
+#ifndef ACL_SRC_CPU_KERNELS_CPUGEMMLOWPOFFSETCONTRIBUTIONOUTPUTSTAGEKERNEL_H
+#define ACL_SRC_CPU_KERNELS_CPUGEMMLOWPOFFSETCONTRIBUTIONOUTPUTSTAGEKERNEL_H
#include "arm_compute/core/KernelDescriptors.h"
@@ -110,6 +110,22 @@ public:
int32_t b_offset,
GEMMLowpOutputStageInfo output_stage);
+ /** Set the a offset
+ * Warning: if a_offset is non-zero then vector_sum_col must be set in run_op.
+     * Re-run configure or validate if you are unsure
+ *
+ * @param[in] a_offset Offset to be added to each element of the matrix A.
+ */
+ void set_a_offset(int32_t a_offset);
+
+ /** Set the b offset
+     * Warning: if b_offset is non-zero then vector_sum_row must be set in run_op.
+     * Re-run configure or validate if you are unsure
+ *
+ * @param[in] b_offset Offset to be added to each element of the matrix B.
+ */
+ void set_b_offset(int32_t b_offset);
+
// Inherited methods overridden:
void run_op(ITensorPack &tensors, const Window &window, const ThreadInfo &info) override;
const char *name() const override;
@@ -118,11 +134,11 @@ private:
/** Function to use for the particular tensors passed to configure() */
int32_t _a_offset{0};
int32_t _b_offset{0};
- int32_t _k_offset{0};
+ int32_t _k{0}; // Number of columns of A or rows of B, used in last offset term
bool _is_vector_sum_col_batched{true};
GEMMLowpOutputStageInfo _output_stage{GEMMLowpOutputStageInfo()};
};
} // namespace kernels
} // namespace cpu
} // namespace arm_compute
-#endif /* ARM_COMPUTE_CPU_GEMMLOWP_OFFSETCONTRIBUTION_OUTPUTSTAGE_KERNEL_H */
+#endif // ACL_SRC_CPU_KERNELS_CPUGEMMLOWPOFFSETCONTRIBUTIONOUTPUTSTAGEKERNEL_H
diff --git a/src/cpu/kernels/CpuKernelSelectionTypes.h b/src/cpu/kernels/CpuKernelSelectionTypes.h
index 45ebeec394..7c1e4772a6 100644
--- a/src/cpu/kernels/CpuKernelSelectionTypes.h
+++ b/src/cpu/kernels/CpuKernelSelectionTypes.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2021-2023 Arm Limited.
+ * Copyright (c) 2021-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -104,6 +104,8 @@ struct SoftmaxKernelDataTypeISASelectorData
DataType dt;
cpuinfo::CpuIsaInfo isa;
bool is_log;
+ int axis;
+ unsigned long sme2_vector_length;
};
// Selector pointer types
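
Note: the two new fields let softmax kernel selectors gate on the reduction axis (the SME2 kernels only handle axis 0) and expose the runtime SME2 vector length to selectors that need it. The entry shape, as it appears in the CpuSoftmaxKernel.cpp diff further below (this copy is shown for reference only):

// Selector entry gating on the new axis field; sme2_vector_length is
// available for similar checks where a kernel depends on it.
{"sme2_fp32_softmax",
 [](const SoftmaxKernelDataTypeISASelectorData &data)
 { return (!data.is_log && data.dt == DataType::F32 && data.isa.sme2 && data.axis == 0); },
 REGISTER_FP32_SME2(sme2_fp32_softmax)},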
diff --git a/src/cpu/kernels/CpuQuantizeKernel.cpp b/src/cpu/kernels/CpuQuantizeKernel.cpp
index d2ac6cf8ac..ed4675ae3d 100644
--- a/src/cpu/kernels/CpuQuantizeKernel.cpp
+++ b/src/cpu/kernels/CpuQuantizeKernel.cpp
@@ -29,12 +29,12 @@
#include "arm_compute/core/Validate.h"
#include "arm_compute/core/Window.h"
+#include "src/core/common/Registrars.h"
#include "src/core/CPP/Validate.h"
#include "src/core/helpers/AutoConfiguration.h"
#include "src/core/helpers/WindowHelpers.h"
-#include "src/core/NEON/NEAsymm.h"
-#include "src/core/NEON/NEMath.h"
#include "src/core/NEON/wrapper/wrapper.h"
+#include "src/cpu/kernels/quantize/generic/neon/list.h"
#include <arm_neon.h>
#include <map>
@@ -47,7 +47,6 @@ namespace kernels
{
namespace
{
-constexpr auto window_step = 16;
Status validate_arguments(const ITensorInfo *src, const ITensorInfo *dst)
{
@@ -63,59 +62,6 @@ Status validate_arguments(const ITensorInfo *src, const ITensorInfo *dst)
return Status{};
}
-template <typename T>
-inline float32x4x4_t load_value(const T *input_ptr)
-{
- using Tx16_t = typename wrapper::traits::neon_vector<T, 16>::type;
- return arm_compute::convert_to_float32x4x4<Tx16_t>(wrapper::vloadq(input_ptr));
-}
-
-template <>
-inline float32x4x4_t load_value(const float *input_ptr)
-{
- return {wrapper::vloadq(input_ptr), wrapper::vloadq(input_ptr + 4), wrapper::vloadq(input_ptr + 8),
- wrapper::vloadq(input_ptr + 12)};
-}
-#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
-template <>
-inline float32x4x4_t load_value(const float16_t *input_ptr)
-{
- return {vcvt_f32_f16(wrapper::vload(input_ptr)), vcvt_f32_f16(wrapper::vload(input_ptr + 4)),
- vcvt_f32_f16(wrapper::vload(input_ptr + 8)), vcvt_f32_f16(wrapper::vload(input_ptr + 12))};
-}
-
-#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
-
-template <typename element_type>
-using vector_type = wrapper::traits::neon_vector_t<element_type, window_step>;
-
-template <typename quantized_type>
-vector_type<quantized_type> vquantize_qasymm8(const float32x4x4_t &qv, const UniformQuantizationInfo &qi);
-
-template <>
-vector_type<uint8_t> vquantize_qasymm8<uint8_t>(const float32x4x4_t &qv, const UniformQuantizationInfo &qi)
-{
- return vquantize(qv, qi);
-}
-
-template <>
-vector_type<int8_t> vquantize_qasymm8<int8_t>(const float32x4x4_t &qv, const UniformQuantizationInfo &qi)
-{
- return vquantize_signed(qv, qi);
-}
-
-template <typename TOut, typename = typename std::enable_if<std::is_signed<TOut>::value, bool>::type>
-inline int8x16_t recombine_8_16(int16x8_t lower, int16x8_t upper)
-{
- return wrapper::vcombine(wrapper::vqmovn(lower), wrapper::vqmovn(upper));
-}
-
-template <typename TOut, typename = typename std::enable_if<std::is_unsigned<TOut>::value, bool>::type>
-inline uint8x16_t recombine_8_16(int16x8_t lower, int16x8_t upper)
-{
- return wrapper::vcombine(wrapper::vqmovun(lower), wrapper::vqmovun(upper));
-}
-
} // namespace
void CpuQuantizeKernel::configure(const ITensorInfo *src, ITensorInfo *dst)
@@ -124,38 +70,36 @@ void CpuQuantizeKernel::configure(const ITensorInfo *src, ITensorInfo *dst)
ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(src, dst));
static const std::map<std::string, QuantizeFunctionExecutorPtr> quant_map = {
- {"op_QASYMM8_QASYMM8", &CpuQuantizeKernel::run_quantize_qasymm8<uint8_t, uint8_t>},
- {"op_QASYMM8_QASYMM8_SIGNED", &CpuQuantizeKernel::run_quantize_qasymm8<uint8_t, int8_t>},
- {"op_QASYMM8_QASYMM16", &CpuQuantizeKernel::run_quantize_qasymm16<uint8_t>},
+ {"op_QASYMM8_QASYMM8", REGISTER_INTEGER_NEON(u8_u8_run_quantize_qasymm8)},
+ {"op_QASYMM8_QASYMM8_SIGNED", REGISTER_INTEGER_NEON(u8_i8_run_quantize_qasymm8)},
+ {"op_QASYMM8_QASYMM16", REGISTER_INTEGER_NEON(u8_run_quantize_qasymm16)},
- {"op_QASYMM8_SIGNED_QASYMM8", &CpuQuantizeKernel::run_quantize_qasymm8<int8_t, uint8_t>},
- {"op_QASYMM8_SIGNED_QASYMM8_SIGNED", &CpuQuantizeKernel::run_quantize_qasymm8<int8_t, int8_t>},
- {"op_QASYMM8_SIGNED_QASYMM16", &CpuQuantizeKernel::run_quantize_qasymm16<int8_t>},
+ {"op_QASYMM8_SIGNED_QASYMM8", REGISTER_INTEGER_NEON(i8_u8_run_quantize_qasymm8)},
+ {"op_QASYMM8_SIGNED_QASYMM8_SIGNED", REGISTER_INTEGER_NEON(i8_i8_run_quantize_qasymm8)},
+ {"op_QASYMM8_SIGNED_QASYMM16", REGISTER_INTEGER_NEON(i8_run_quantize_qasymm16)},
// Functions for offset only requantization
- {"op_OFFSET_ONLY_QASYMM8_QASYMM8", &CpuQuantizeKernel::run_requantize_offset_only<uint8_t, uint8_t>},
- {"op_OFFSET_ONLY_QASYMM8_QASYMM8_SIGNED", &CpuQuantizeKernel::run_requantize_offset_only<uint8_t, int8_t>},
- {"op_OFFSET_ONLY_QASYMM8_SIGNED_QASYMM8", &CpuQuantizeKernel::run_requantize_offset_only<int8_t, uint8_t>},
- {"op_OFFSET_ONLY_QASYMM8_SIGNED_QASYMM8_SIGNED",
- &CpuQuantizeKernel::run_requantize_offset_only<int8_t, int8_t>},
+ {"op_OFFSET_ONLY_QASYMM8_QASYMM8", REGISTER_INTEGER_NEON(u8_u8_run_requantize_offset_only)},
+ {"op_OFFSET_ONLY_QASYMM8_QASYMM8_SIGNED", REGISTER_INTEGER_NEON(u8_i8_run_requantize_offset_only)},
+ {"op_OFFSET_ONLY_QASYMM8_SIGNED_QASYMM8", REGISTER_INTEGER_NEON(i8_u8_run_requantize_offset_only)},
+ {"op_OFFSET_ONLY_QASYMM8_SIGNED_QASYMM8_SIGNED", REGISTER_INTEGER_NEON(i8_i8_run_requantize_offset_only)},
// Functions for offset uint8 to int8 and vice versa quantization (no scale changes)
{"op_OFFSET_ONLY_CONVERT_QASYMM8_SIGNED_QASYMM8",
- &CpuQuantizeKernel::run_requantize_offset_only_convert<int8_t, uint8_t>},
+ REGISTER_INTEGER_NEON(i8_u8_run_requantize_offset_only_convert)},
{"op_OFFSET_ONLY_CONVERT_QASYMM8_QASYMM8_SIGNED",
- &CpuQuantizeKernel::run_requantize_offset_only_convert<uint8_t, int8_t>},
-
- {"op_F32_QSYMM8", &CpuQuantizeKernel::run_quantize_qsymm8<float, int8_t>},
-
- {"op_F32_QASYMM8", &CpuQuantizeKernel::run_quantize_qasymm8<float, uint8_t>},
- {"op_F32_QASYMM8_SIGNED", &CpuQuantizeKernel::run_quantize_qasymm8<float, int8_t>},
- {"op_F32_QASYMM16", &CpuQuantizeKernel::run_quantize_qasymm16<float>},
-
-#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
- {"op_F16_QASYMM8", &CpuQuantizeKernel::run_quantize_qasymm8<float16_t, uint8_t>},
- {"op_F16_QASYMM8_SIGNED", &CpuQuantizeKernel::run_quantize_qasymm8<float16_t, int8_t>},
- {"op_F16_QASYMM16", &CpuQuantizeKernel::run_quantize_qasymm16<float16_t>},
-#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC*/
+ REGISTER_INTEGER_NEON(u8_i8_run_requantize_offset_only_convert)},
+
+ {"op_F32_QSYMM8", REGISTER_FP32_NEON(fp32_i8_run_quantize_qsymm8)},
+ {"op_F32_QASYMM8", REGISTER_FP32_NEON(fp32_u8_run_quantize_qasymm8)},
+ {"op_F32_QASYMM8_SIGNED", REGISTER_FP32_NEON(fp32_i8_run_quantize_qasymm8)},
+ {"op_F32_QASYMM16", REGISTER_FP32_NEON(fp32_run_quantize_qasymm16)},
+
+#ifdef ARM_COMPUTE_ENABLE_FP16
+ {"op_F16_QASYMM8", REGISTER_FP16_NEON(fp16_u8_run_quantize_qasymm8)},
+ {"op_F16_QASYMM8_SIGNED", REGISTER_FP16_NEON(fp16_i8_run_quantize_qasymm8)},
+ {"op_F16_QASYMM16", REGISTER_FP16_NEON(fp16_run_quantize_qasymm16)},
+#endif /* ARM_COMPUTE_ENABLE_FP16 */
};
std::string function_to_call("op_");
@@ -203,242 +147,6 @@ Status CpuQuantizeKernel::validate(const ITensorInfo *src, const ITensorInfo *ds
return Status{};
}
-template <typename TIn, typename TOut>
-void CpuQuantizeKernel::run_quantize_qsymm8(const ITensor *src, ITensor *dst, const Window &window)
-{
- const auto window_start_x = static_cast<int>(window.x().start());
- const auto window_end_x = static_cast<int>(window.x().end());
-
- const UniformQuantizationInfo uqinfo_in = src->info()->quantization_info().uniform();
- UniformQuantizationInfo uqinfo = dst->info()->quantization_info().uniform();
- uqinfo = compute_requantization_scale_offset(uqinfo_in, uqinfo);
-
- // Collapse window and reset first dimension to handle tail calculations manually
- Window win_collapsed = window.collapse_if_possible(window, Window::DimZ);
- win_collapsed.set(Window::DimX, Window::Dimension(0, 1, 1));
-
- Iterator input(src, win_collapsed);
- Iterator output(dst, win_collapsed);
- execute_window_loop(
- win_collapsed,
- [&](const Coordinates &)
- {
- auto input_ptr = reinterpret_cast<const TIn *>(input.ptr());
- auto output_ptr = reinterpret_cast<TOut *>(output.ptr());
- int x = window_start_x;
- for (; x <= (window_end_x - window_step); x += window_step)
- {
- wrapper::vstore(&output_ptr[x], vquantize_qasymm8<TOut>(load_value(&input_ptr[x]), uqinfo));
- }
- // Compute left-over elements
- for (; x < window_end_x; ++x)
- {
- output_ptr[x] = quantize_qsymm8(input_ptr[x], dst->info()->quantization_info());
- }
- },
- input, output);
-}
-
-template <typename TIn, typename TOut>
-void CpuQuantizeKernel::run_requantize_offset_only_convert(const ITensor *src, ITensor *dst, const Window &window)
-{
- const auto window_start_x = static_cast<int>(window.x().start());
- const auto window_end_x = static_cast<int>(window.x().end());
-
- // Calculate output offset difference.
- const UniformQuantizationInfo uqinfo_in = src->info()->quantization_info().uniform();
- UniformQuantizationInfo uqinfo = dst->info()->quantization_info().uniform();
- uqinfo = compute_requantization_scale_offset(uqinfo_in, uqinfo);
-
- // Collapse window and reset first dimension to handle tail calculations manually
- Window win_collapsed = window.collapse_if_possible(window, Window::DimZ);
-
- win_collapsed.set(Window::DimX, Window::Dimension(0, 1, 1));
-
- // Duplicate offset in signed vector format
- const int8x16_t offset = wrapper::vdup_n(static_cast<int8_t>(uqinfo.offset), wrapper::traits::vector_128_tag{});
-
- Iterator input(src, win_collapsed);
- Iterator output(dst, win_collapsed);
- execute_window_loop(
- win_collapsed,
- [&](const Coordinates &)
- {
- auto input_ptr = reinterpret_cast<const TIn *>(input.ptr());
- auto output_ptr = reinterpret_cast<TOut *>(output.ptr());
- int x = window_start_x;
- for (; x <= (window_end_x - window_step); x += window_step)
- {
- const wrapper::traits::neon_vector_t<TIn, window_step> qv =
- wrapper::vloadq(input_ptr + x); // load 128 bit vector of 8 bit datatype
-
- // Signed addition.
- auto res = vaddq_s8(reinterpret_cast<int8x16_t>(qv), offset);
-
- // Output is dependent on datatype.
- wrapper::vstore(&output_ptr[x],
- reinterpret_cast<wrapper::traits::neon_vector_t<TOut, window_step>>(res));
- }
- // Compute left-over elements
- for (; x < window_end_x; ++x)
- {
- auto result = uqinfo.offset + static_cast<int32_t>(input_ptr[x]);
- output_ptr[x] = static_cast<TOut>(result);
- }
- },
- input, output);
-}
-
-template <typename TIn, typename TOut>
-void CpuQuantizeKernel::run_requantize_offset_only(const ITensor *src, ITensor *dst, const Window &window)
-{
- const auto window_start_x = static_cast<int>(window.x().start());
- const auto window_end_x = static_cast<int>(window.x().end());
-
- const UniformQuantizationInfo uqinfo_in = src->info()->quantization_info().uniform();
- UniformQuantizationInfo uqinfo = dst->info()->quantization_info().uniform();
- uqinfo = compute_requantization_scale_offset(uqinfo_in, uqinfo);
-
- // Collapse window and reset first dimension to handle tail calculations manually
- Window win_collapsed = window.collapse_if_possible(window, Window::DimZ);
- win_collapsed.set(Window::DimX, Window::Dimension(0, 1, 1));
-
- // Duplicate offset in signed vector format
- const int16x8_t offset = wrapper::vdup_n(static_cast<int16_t>(uqinfo.offset), wrapper::traits::vector_128_tag{});
-
- const int32_t low_bound = (dst->info()->data_type() == DataType::QASYMM8) ? 0 : -128;
- const int32_t upper_bound = (dst->info()->data_type() == DataType::QASYMM8) ? 255 : 127;
-
- Iterator input(src, win_collapsed);
- Iterator output(dst, win_collapsed);
- execute_window_loop(
- win_collapsed,
- [&](const Coordinates &)
- {
- auto input_ptr = reinterpret_cast<const TIn *>(input.ptr());
- TOut *output_ptr = reinterpret_cast<TOut *>(output.ptr());
-
- int x = window_start_x;
- for (; x <= (window_end_x - window_step); x += window_step)
- {
- const auto qv = wrapper::vloadq(input_ptr + x); // load 128 bit vector of 8 bit datatype
- int16x8_t lower = reinterpret_cast<int16x8_t>(wrapper::vmovl(wrapper::vgetlow(qv)));
- int16x8_t upper = reinterpret_cast<int16x8_t>(wrapper::vmovl(wrapper::vgethigh(qv)));
-
- // Signed addition.
- lower = wrapper::vqadd(lower, offset);
- upper = wrapper::vqadd(upper, offset);
-
- // Output is dependent on datatype.
- auto res = recombine_8_16<TOut>(lower, upper);
- wrapper::vstore(&output_ptr[x], res);
- }
- // Compute left-over elements
- for (; x < window_end_x; ++x)
- {
- // Add offset and clamp result to within the range of the output datatype.
- int32_t result = uqinfo.offset + static_cast<int32_t>(input_ptr[x]);
- result = utility::clamp<int32_t>(result, low_bound, upper_bound);
-
- // Cast result to output datatype.
- output_ptr[x] = static_cast<TOut>(result);
- }
- },
- input, output);
-}
-
-template <typename TIn, typename TOut>
-void CpuQuantizeKernel::run_quantize_qasymm8(const ITensor *src, ITensor *dst, const Window &window)
-{
- const auto window_start_x = static_cast<int>(window.x().start());
- const auto window_end_x = static_cast<int>(window.x().end());
-
- const UniformQuantizationInfo uqinfo_in = src->info()->quantization_info().uniform();
- UniformQuantizationInfo uqinfo = dst->info()->quantization_info().uniform();
- if (is_data_type_quantized_asymmetric(src->info()->data_type()))
- {
- uqinfo = compute_requantization_scale_offset(uqinfo_in, uqinfo);
- }
-#ifdef __aarch64__
- constexpr RoundingPolicy rounding_policy = RoundingPolicy::TO_NEAREST_EVEN;
-#else //__aarch64__
- constexpr RoundingPolicy rounding_policy = RoundingPolicy::TO_ZERO;
-#endif //__aarch64__
-
- // Collapse window and reset first dimension to handle tail calculations manually
- Window win_collapsed = window.collapse_if_possible(window, Window::DimZ);
- win_collapsed.set(Window::DimX, Window::Dimension(0, 1, 1));
-
- Iterator input(src, win_collapsed);
- Iterator output(dst, win_collapsed);
- execute_window_loop(
- win_collapsed,
- [&](const Coordinates &)
- {
- auto input_ptr = reinterpret_cast<const TIn *>(input.ptr());
- auto output_ptr = reinterpret_cast<TOut *>(output.ptr());
-
- int x = window_start_x;
- for (; x <= (window_end_x - window_step); x += window_step)
- {
- wrapper::vstore(&output_ptr[x], vquantize_qasymm8<TOut>(load_value(&input_ptr[x]), uqinfo));
- }
- // Compute left-over elements
- for (; x < window_end_x; ++x)
- {
- output_ptr[x] = Qasymm8QuantizationHelper<TOut>::quantize(input_ptr[x], uqinfo, rounding_policy);
- }
- },
- input, output);
-}
-
-template <typename T>
-void CpuQuantizeKernel::run_quantize_qasymm16(const ITensor *src, ITensor *dst, const Window &window)
-{
- const auto window_start_x = static_cast<int>(window.x().start());
- const auto window_end_x = static_cast<int>(window.x().end());
-
- const UniformQuantizationInfo uqinfo_in = src->info()->quantization_info().uniform();
- UniformQuantizationInfo uqinfo = dst->info()->quantization_info().uniform();
- if (is_data_type_quantized_asymmetric(src->info()->data_type()))
- {
- uqinfo = compute_requantization_scale_offset(uqinfo_in, uqinfo);
- }
-#ifdef __aarch64__
- constexpr RoundingPolicy rounding_policy = RoundingPolicy::TO_NEAREST_EVEN;
-#else //__aarch64__
- constexpr RoundingPolicy rounding_policy = RoundingPolicy::TO_ZERO;
-#endif //__aarch64__
-
- // Collapse window and reset first dimension to handle tail calculations manually
- Window win_collapsed = window.collapse_if_possible(window, Window::DimZ);
- win_collapsed.set(Window::DimX, Window::Dimension(0, 1, 1));
-
- Iterator input(src, win_collapsed);
- Iterator output(dst, win_collapsed);
- execute_window_loop(
- win_collapsed,
- [&](const Coordinates &)
- {
- auto input_ptr = reinterpret_cast<const T *>(input.ptr());
- auto output_ptr = reinterpret_cast<uint16_t *>(output.ptr());
-
- int x = window_start_x;
- for (; x <= (window_end_x - window_step); x += window_step)
- {
- uint16x8x2_t tmp = vquantize_qasymm16(load_value(&input_ptr[x]), uqinfo);
- vst1q_u16(&output_ptr[x], tmp.val[0]);
- vst1q_u16(&output_ptr[x + 8], tmp.val[1]);
- }
- // Compute left-over elements
- for (; x < window_end_x; ++x)
- {
- output_ptr[x] = quantize_qasymm16(input_ptr[x], uqinfo, rounding_policy);
- }
- },
- input, output);
-}
-
void CpuQuantizeKernel::run_op(ITensorPack &tensors, const Window &window, const ThreadInfo &info)
{
ARM_COMPUTE_UNUSED(info);
@@ -448,7 +156,7 @@ void CpuQuantizeKernel::run_op(ITensorPack &tensors, const Window &window, const
const auto src = tensors.get_const_tensor(TensorType::ACL_SRC);
auto dst = tensors.get_tensor(TensorType::ACL_DST);
- (this->*_func)(src, dst, window);
+ (*_func)(src, dst, window);
}
const char *CpuQuantizeKernel::name() const
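Editor's note on the run_op change above: _func is now a plain function pointer, so the kernel is invoked as (*_func)(...) instead of through the pointer-to-member syntax (this->*_func)(...). A minimal standalone sketch of the two dispatch styles (names here are illustrative, not from the library):

#include <iostream>

struct Kernel
{
    using MemberFn = void (Kernel::*)(int); // old style: bound to an object
    using FreeFn   = void (*)(int);         // new style: no object needed

    void member_impl(int x) { std::cout << "member " << x << '\n'; }

    MemberFn member_func = &Kernel::member_impl;
    FreeFn   free_func   = [](int x) { std::cout << "free " << x << '\n'; };

    void run(int x)
    {
        (this->*member_func)(x); // pointer-to-member dispatch
        (*free_func)(x);         // free-function dispatch
    }
};

int main()
{
    Kernel k;
    k.run(42);
    return 0;
}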
diff --git a/src/cpu/kernels/CpuQuantizeKernel.h b/src/cpu/kernels/CpuQuantizeKernel.h
index c2f7ac6d9d..750310c811 100644
--- a/src/cpu/kernels/CpuQuantizeKernel.h
+++ b/src/cpu/kernels/CpuQuantizeKernel.h
@@ -76,31 +76,7 @@ private:
*
* @param[in] window Region on which to execute the kernel.
*/
- using QuantizeFunctionExecutorPtr = void (CpuQuantizeKernel::*)(const ITensor *src,
- ITensor *dst,
- const Window &window);
- /** Function to apply QASYMM8 or QASYMM8_SIGNED quantization on a tensor.
- *
- * @param[in] window Region on which to execute the kernel.
- */
- template <typename TIn, typename TOut>
- void run_quantize_qasymm8(const ITensor *src, ITensor *dst, const Window &window);
- /** Function to apply QASYMM16 quantization on a tensor.
- *
- * @param[in] window Region on which to execute the kernel.
- */
- template <typename T>
- void run_quantize_qasymm16(const ITensor *src, ITensor *dst, const Window &window);
-
- template <typename TIn, typename TOut>
- void run_quantize_qsymm8(const ITensor *src, ITensor *dst, const Window &window);
-
- template <typename TIn, typename TOut>
- void run_requantize_offset_only(const ITensor *src, ITensor *dst, const Window &window);
-
- template <typename TIn, typename TOut>
- void run_requantize_offset_only_convert(const ITensor *src, ITensor *dst, const Window &window);
-
+ using QuantizeFunctionExecutorPtr = void (*)(const ITensor *src, ITensor *dst, const Window &window);
QuantizeFunctionExecutorPtr _func{nullptr};
size_t _split_dimension{Window::DimY};
};
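With the executor type reduced to a free-function pointer, the .cpp side can pick an implementation from a plain lookup table keyed on the src/dst data types. A hedged sketch of that selection pattern (the map keys and stub kernels below are illustrative, not copied from the library's configure()):

#include <map>
#include <string>

class ITensor; // stand-ins for the real ACL types
class Window;

using QuantizeFunctionExecutorPtr = void (*)(const ITensor *src, ITensor *dst, const Window &window);

void u8_u8_run_quantize_qasymm8(const ITensor *, ITensor *, const Window &) {}
void fp32_u8_run_quantize_qasymm8(const ITensor *, ITensor *, const Window &) {}

// One entry per supported (src, dst) type pair, mirroring the free kernels
// declared in quantize/generic/neon/list.h.
static const std::map<std::string, QuantizeFunctionExecutorPtr> quant_table = {
    {"op_QASYMM8_QASYMM8", &u8_u8_run_quantize_qasymm8},
    {"op_F32_QASYMM8", &fp32_u8_run_quantize_qasymm8},
};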
diff --git a/src/cpu/kernels/CpuSoftmaxKernel.cpp b/src/cpu/kernels/CpuSoftmaxKernel.cpp
index 54ff858eeb..b7e395fb79 100644
--- a/src/cpu/kernels/CpuSoftmaxKernel.cpp
+++ b/src/cpu/kernels/CpuSoftmaxKernel.cpp
@@ -48,18 +48,41 @@ namespace kernels
{
namespace
{
+
/* Softmax */
static const std::vector<typename CpuSoftmaxKernel::SoftmaxKernel> available_kernels = {
+ {"sme2_fp32_softmax",
+ [](const SoftmaxKernelDataTypeISASelectorData &data)
+ { return (!data.is_log && data.dt == DataType::F32 && data.isa.sme2 && data.axis == 0); },
+ REGISTER_FP32_SME2(sme2_fp32_softmax)},
{"neon_fp32_softmax",
[](const SoftmaxKernelDataTypeISASelectorData &data) { return (!data.is_log && data.dt == DataType::F32); },
REGISTER_FP32_NEON(neon_fp32_softmax<false>)},
+ {"sme2_fp16_softmax",
+ [](const SoftmaxKernelDataTypeISASelectorData &data)
+ { return (!data.is_log && data.dt == DataType::F16 && data.isa.sme2 && data.axis == 0); },
+ REGISTER_FP16_SME2(sme2_fp16_softmax)},
{"neon_fp16_softmax",
[](const SoftmaxKernelDataTypeISASelectorData &data)
{ return (!data.is_log && data.dt == DataType::F16) && data.isa.fp16; },
REGISTER_FP16_NEON(neon_fp16_softmax<false>)},
+ {"sme2_qu8_softmax_lut_512VL",
+ [](const SoftmaxKernelDataTypeISASelectorData &data)
+ {
+ return (!data.is_log && data.dt == DataType::QASYMM8 && data.isa.sme2 && data.axis == 0 &&
+ data.sme2_vector_length == 512);
+ },
+ REGISTER_QASYMM8_SME2(sme2_qasymm8_softmax_lut_512VL)},
{"neon_qu8_softmax",
[](const SoftmaxKernelDataTypeISASelectorData &data) { return (!data.is_log && data.dt == DataType::QASYMM8); },
REGISTER_QASYMM8_NEON(arm_compute::cpu::neon_qasymm8_softmax<false>)},
+ {"sme2_qs8_softmax_lut_512VL",
+ [](const SoftmaxKernelDataTypeISASelectorData &data)
+ {
+ return (!data.is_log && data.dt == DataType::QASYMM8_SIGNED && data.isa.sme2 && data.axis == 0 &&
+ data.sme2_vector_length == 512);
+ },
+ REGISTER_QASYMM8_SIGNED_SME2(sme2_qasymm8_signed_softmax_lut_512VL)},
{"neon_qs8_softmax",
[](const SoftmaxKernelDataTypeISASelectorData &data)
{ return (!data.is_log && data.dt == DataType::QASYMM8_SIGNED); },
@@ -80,6 +103,28 @@ static const std::vector<typename CpuSoftmaxKernel::SoftmaxKernel> available_ker
REGISTER_QASYMM8_SIGNED_NEON(arm_compute::cpu::neon_qasymm8_signed_softmax<true>)},
};
+void init_lut(std::vector<float> &lut, DataType type, float scale, float beta)
+{
+ if (type == DataType::QASYMM8)
+ {
+ for (int i = 0; i < 256; ++i)
+ {
+ lut.push_back(std::exp(-scale * beta * i));
+ }
+ }
+ else if (type == DataType::QASYMM8_SIGNED)
+ {
+ for (int i = -128; i < 128; ++i)
+ {
+ lut.push_back(std::exp(-scale * beta * i));
+ }
+ }
+ else
+ {
+ ARM_COMPUTE_ERROR("Invalid datatype for QASYMM8/QASYMM8_SIGNED softmax");
+ }
+}
+
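For intuition on the table contents: with uniform quantization x = scale * (q - offset), the softmax numerator exp(beta * (x - x_max)) depends only on the integer gap d = q_max - q, since x - x_max = -scale * d. A 256-entry table indexed by d therefore covers every case. A scalar sketch of that lookup (the indexing convention is an assumption for illustration, not lifted from the SME2 kernels):

#include <cstdint>
#include <vector>

// lut[d] == std::exp(-scale * beta * d), as filled in by init_lut() above.
// For QASYMM8, d = q_max - q lies in [0, 255], so the table is never overrun.
float softmax_numerator(const std::vector<float> &lut, uint8_t q, uint8_t q_max)
{
    const int d = static_cast<int>(q_max) - static_cast<int>(q);
    return lut[d];
}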
Status validate_arguments_softmax(
const ITensorInfo &src, const ITensorInfo &dst, float beta, int axis, const ITensorInfo &tmp, bool is_log)
{
@@ -149,8 +194,8 @@ void CpuSoftmaxKernel::configure(
auto_init_if_empty(*tmp, TensorInfo(*src).set_data_type(DataType::F32).reset_padding());
}
- const auto *uk = CpuSoftmaxKernel::get_implementation(
- SoftmaxKernelDataTypeISASelectorData{src->data_type(), CPUInfo::get().get_isa(), is_log});
+ const auto *uk = CpuSoftmaxKernel::get_implementation(SoftmaxKernelDataTypeISASelectorData{
+ src->data_type(), CPUInfo::get().get_isa(), is_log, axis, CPUInfo::get().get_sme2_vector_length()});
ARM_COMPUTE_ERROR_ON(uk == nullptr || uk->ukernel == nullptr);
std::string kernel_name = is_log ? std::string("CpuLogSoftmaxKernel") : std::string("CpuSoftmaxKernel");
@@ -186,6 +231,13 @@ void CpuSoftmaxKernel::configure(
win.set(_axis, Window::Dimension(0, 1, 1));
ICpuKernel<CpuSoftmaxKernel>::configure(win);
+
+ const std::string uk_name = uk->name;
+ if (uk_name == "sme2_qu8_softmax_lut_512VL" || uk_name == "sme2_qs8_softmax_lut_512VL")
+ {
+ const float scale = src->quantization_info().uniform().scale;
+ init_lut(_lut, src->data_type(), scale, beta);
+ }
}
Status CpuSoftmaxKernel::validate(
@@ -222,11 +274,11 @@ void CpuSoftmaxKernel::run_op(ITensorPack &tensors, const Window &window, const
const unsigned int tmp_size_for_thread = tmp->info()->element_size() * num_elems_processed_per_iteration;
void *tmp_for_thread = tmp->buffer() + (info.thread_id * tmp_size_for_thread);
- _run_method(src, tmp_for_thread, dst, _beta, _axis, window);
+ _run_method(src, tmp_for_thread, dst, _beta, _axis, window, _lut.data());
}
else
{
- _run_method(src, nullptr, dst, _beta, _axis, window);
+ _run_method(src, nullptr, dst, _beta, _axis, window, nullptr);
}
}
diff --git a/src/cpu/kernels/CpuSoftmaxKernel.h b/src/cpu/kernels/CpuSoftmaxKernel.h
index 043ad975d5..676e79782b 100644
--- a/src/cpu/kernels/CpuSoftmaxKernel.h
+++ b/src/cpu/kernels/CpuSoftmaxKernel.h
@@ -37,8 +37,8 @@ namespace kernels
class CpuSoftmaxKernel : public ICpuKernel<CpuSoftmaxKernel>
{
private:
- using SoftmaxKernelPtr =
- std::add_pointer<void(const ITensor *, void *const, ITensor *, float, int, const Window &)>::type;
+ using SoftmaxKernelPtr = std::add_pointer<void(
+ const ITensor *, void *const, ITensor *, float, int, const Window &, const float *)>::type;
public:
CpuSoftmaxKernel() = default;
@@ -78,10 +78,11 @@ public:
static const std::vector<SoftmaxKernel> &get_available_kernels();
private:
- float _beta{1.0f};
- SoftmaxKernelPtr _run_method{nullptr};
- std::string _name{};
- int _axis{};
+ float _beta{1.0f};
+ SoftmaxKernelPtr _run_method{nullptr};
+ std::string _name{};
+ int _axis{};
+ std::vector<float> _lut = {};
};
} // namespace kernels
} // namespace cpu
diff --git a/src/cpu/kernels/assembly/arm_gemm.hpp b/src/cpu/kernels/assembly/arm_gemm.hpp
index 9a913c5c58..941fed0ba8 100644
--- a/src/cpu/kernels/assembly/arm_gemm.hpp
+++ b/src/cpu/kernels/assembly/arm_gemm.hpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018-2022 Arm Limited.
+ * Copyright (c) 2018-2022, 2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -21,6 +21,10 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
+
+#ifndef ACL_SRC_CPU_KERNELS_ASSEMBLY_ARM_GEMM_HPP
+#define ACL_SRC_CPU_KERNELS_ASSEMBLY_ARM_GEMM_HPP
+
#pragma once
#include "arm_gemm_local.hpp"
@@ -151,6 +155,7 @@ public:
int _maxthreads;
bool _fixed_format;
bool _fast_mode;
+ bool _accumulate;
const GemmConfig *_cfg;
GemmArgs(const CPUInfo *ci,
@@ -165,6 +170,7 @@ public:
const int maxthreads,
bool fixed_format = false,
bool fast_mode = false,
+ bool accumulate = false,
const GemmConfig *cfg = nullptr)
: _ci(ci),
_Msize(M),
@@ -178,6 +184,7 @@ public:
_maxthreads(maxthreads),
_fixed_format(fixed_format),
_fast_mode(fast_mode),
+ _accumulate(accumulate),
_cfg(cfg)
{
}
@@ -253,6 +260,19 @@ public:
}
};
+struct DequantizeFloat
+{
+public:
+ float scale = 0;
+
+ DequantizeFloat() = default;
+
+ // Constructor
+ DequantizeFloat(const float scale) : scale(scale)
+ {
+ }
+};
+
struct Nothing
{
};
@@ -278,3 +298,5 @@ template <typename Top, typename Tret, class OutputStage = Nothing>
bool has_opt_gemm(WeightFormat &weight_format, const GemmArgs &args, const OutputStage & = {});
} // namespace arm_gemm
+
+#endif // ACL_SRC_CPU_KERNELS_ASSEMBLY_ARM_GEMM_HPP
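DequantizeFloat is a new output stage carrying one number: the scale that turns the integer GEMM accumulator into a float result. A standalone restatement of the arithmetic it names (this is just the formula, not the assembly code path):

#include <cstdint>

// float_out = scale * float(int_accumulator)
float dequantize_accumulator(int32_t acc, float scale)
{
    return scale * static_cast<float>(acc);
}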
diff --git a/src/cpu/kernels/assembly/gemm_common.hpp b/src/cpu/kernels/assembly/gemm_common.hpp
index 4825814e31..45d1e43274 100644
--- a/src/cpu/kernels/assembly/gemm_common.hpp
+++ b/src/cpu/kernels/assembly/gemm_common.hpp
@@ -166,6 +166,12 @@ public:
{
}
+ /*** Dequantize scale interface (optional) ***/
+ /* Set the dequantize scale for GEMMs when converting from int to float (float_out = scale * float(int_out)) */
+ virtual void set_dequantize_scale(const float)
+ {
+ }
+
/*** Introspection interface ***/
/* Get the configuration of this GEMM */
virtual GemmConfig get_config() = 0;
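set_dequantize_scale() follows the same convention as the other optional interfaces in GemmCommon: a virtual with an empty default body, so back ends that never produce float output can ignore it while int-to-float GEMMs override it. A minimal sketch of the pattern (class and member names here are illustrative):

struct GemmSketch
{
    virtual ~GemmSketch() = default;

    // Optional interface: default is a deliberate no-op.
    virtual void set_dequantize_scale(const float) {}
};

struct DequantizingGemmSketch : GemmSketch
{
    float _scale = 0.f;

    void set_dequantize_scale(const float scale) override { _scale = scale; }
};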
diff --git a/src/cpu/kernels/dequantize/generic/neon/fp16.cpp b/src/cpu/kernels/dequantize/generic/neon/fp16.cpp
new file mode 100644
index 0000000000..caffdf53e1
--- /dev/null
+++ b/src/cpu/kernels/dequantize/generic/neon/fp16.cpp
@@ -0,0 +1,37 @@
+/*
+ * Copyright (c) 2024 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && defined(ENABLE_FP16_KERNELS)
+#include "src/cpu/kernels/dequantize/generic/neon/impl.h"
+
+namespace arm_compute
+{
+namespace cpu
+{
+void fp16_run_dequantization_core(const ITensor *input, ITensor *output, const Window &window)
+{
+ run_dequantization_core<float16_t>(input, output, window);
+}
+} // namespace cpu
+} // namespace arm_compute
+#endif /* defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && defined(ENABLE_FP16_KERNELS) */
diff --git a/src/cpu/kernels/dequantize/generic/neon/fp32.cpp b/src/cpu/kernels/dequantize/generic/neon/fp32.cpp
new file mode 100644
index 0000000000..58e987b450
--- /dev/null
+++ b/src/cpu/kernels/dequantize/generic/neon/fp32.cpp
@@ -0,0 +1,35 @@
+/*
+ * Copyright (c) 2024 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#include "src/cpu/kernels/dequantize/generic/neon/impl.h"
+
+namespace arm_compute
+{
+namespace cpu
+{
+void fp32_run_dequantization_core(const ITensor *input, ITensor *output, const Window &window)
+{
+ run_dequantization_core<float>(input, output, window);
+}
+} // namespace cpu
+} // namespace arm_compute
diff --git a/src/cpu/kernels/dequantize/generic/neon/impl.h b/src/cpu/kernels/dequantize/generic/neon/impl.h
new file mode 100644
index 0000000000..7197d4dff6
--- /dev/null
+++ b/src/cpu/kernels/dequantize/generic/neon/impl.h
@@ -0,0 +1,340 @@
+/*
+ * Copyright (c) 2024 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#ifndef ACL_SRC_CPU_KERNELS_DEQUANTIZE_GENERIC_NEON_IMPL_H
+#define ACL_SRC_CPU_KERNELS_DEQUANTIZE_GENERIC_NEON_IMPL_H
+
+#include "arm_compute/core/Helpers.h"
+#include "arm_compute/core/Window.h"
+
+#include "src/core/NEON/NEAsymm.h"
+#include "src/core/NEON/NESymm.h"
+#include "src/core/NEON/wrapper/wrapper.h"
+#include "src/cpu/kernels/dequantize/generic/neon/list.h"
+
+#include <arm_neon.h>
+
+namespace arm_compute
+{
+namespace cpu
+{
+
+template <typename T>
+inline void store_result(T *ptr, const float32x4x4_t &v)
+{
+ ARM_COMPUTE_UNUSED(ptr, v);
+}
+
+template <>
+inline void store_result<float>(float *ptr, const float32x4x4_t &v)
+{
+ wrapper::vstore(ptr, v.val[0]);
+ wrapper::vstore(ptr + 4, v.val[1]);
+ wrapper::vstore(ptr + 8, v.val[2]);
+ wrapper::vstore(ptr + 12, v.val[3]);
+}
+
+#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+template <>
+inline void store_result<float16_t>(float16_t *ptr, const float32x4x4_t &v)
+{
+ wrapper::vstore(ptr, vcombine_f16(vcvt_f16_f32(v.val[0]), vcvt_f16_f32(v.val[1])));
+ wrapper::vstore(ptr + 8, vcombine_f16(vcvt_f16_f32(v.val[2]), vcvt_f16_f32(v.val[3])));
+}
+#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
+
+template <typename T>
+inline void store_result(T *ptr, const float32x4x2_t &v)
+{
+ ARM_COMPUTE_UNUSED(ptr, v);
+}
+
+template <>
+inline void store_result<float>(float *ptr, const float32x4x2_t &v)
+{
+ wrapper::vstore(ptr, v.val[0]);
+ wrapper::vstore(ptr + 4, v.val[1]);
+}
+
+#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+template <>
+inline void store_result<float16_t>(float16_t *ptr, const float32x4x2_t &v)
+{
+ wrapper::vstore(ptr, vcombine_f16(vcvt_f16_f32(v.val[0]), vcvt_f16_f32(v.val[1])));
+}
+#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
+
+template <typename TOut, typename TIn>
+void run_dequantization_qasymm8(const ITensor *input, ITensor *output, const Window &window)
+{
+ const UniformQuantizationInfo &qinfo = input->info()->quantization_info().uniform();
+ const float scale = qinfo.scale;
+ const int32_t offset = qinfo.offset;
+
+ const int window_step_x = 16;
+ const auto window_start_x = static_cast<int>(window.x().start());
+ const auto window_end_x = static_cast<int>(window.x().end());
+
+ // Collapse window and reset first dimension to handle tail calculations manually
+ Window win_collapsed = window.collapse_if_possible(window, Window::DimZ);
+ win_collapsed.set(Window::DimX, Window::Dimension(0, 1, 1));
+
+ // Create iterators
+ Iterator in(input, win_collapsed);
+ Iterator out(output, win_collapsed);
+
+ execute_window_loop(
+ win_collapsed,
+ [&](const Coordinates &)
+ {
+ const auto in_ptr = reinterpret_cast<const TIn *>(in.ptr());
+ const auto out_ptr = reinterpret_cast<TOut *>(out.ptr());
+
+ int x = window_start_x;
+ for (; x <= (window_end_x - window_step_x); x += window_step_x)
+ {
+ const auto vin = wrapper::vloadq(in_ptr + x);
+ const auto vdeq = vdequantize(vin, scale, offset);
+
+ store_result(reinterpret_cast<TOut *>(out_ptr + x), vdeq);
+ }
+
+ // Compute left-over elements
+ for (; x < window_end_x; ++x)
+ {
+ auto val = *(in_ptr + x);
+ *(out_ptr + x) = static_cast<TOut>(Qasymm8QuantizationHelper<TIn>::dequantize(val, qinfo));
+ }
+ },
+ in, out);
+}
+
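Every kernel in this file repeats the same traversal: collapse the window, step through the x dimension window_step_x elements at a time with NEON, then finish the remainder with scalar code. A stripped-down restatement of that loop shape (free-standing, without the Window/Iterator machinery; vector_op and scalar_op stand in for the vdequantize/dequantize pair above):

#include <cstddef>

// Head: vector-width chunks; tail: one element at a time.
template <typename VecOp, typename ScalarOp>
void head_tail_loop(int start, int end, int step, VecOp vector_op, ScalarOp scalar_op)
{
    int x = start;
    for (; x <= end - step; x += step)
    {
        vector_op(x); // processes lanes [x, x + step)
    }
    for (; x < end; ++x)
    {
        scalar_op(x); // leftover elements
    }
}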
+template <typename T>
+void run_dequantization_qsymm8_per_channel_nchw(const ITensor *input, ITensor *output, const Window &window)
+{
+ const auto scale = input->info()->quantization_info().scale();
+
+ const int window_step_x = 16;
+ const auto window_start_x = static_cast<int>(window.x().start());
+ const auto window_end_x = static_cast<int>(window.x().end());
+
+ // Reset first dimension to handle tail calculations manually
+ Window win(window);
+ win.set(Window::DimX, Window::Dimension(0, 1, 1));
+
+ // Create iterators
+ Iterator in(input, win);
+ Iterator out(output, win);
+
+ execute_window_loop(
+ win,
+ [&](const Coordinates &id)
+ {
+ const auto in_ptr = reinterpret_cast<const int8_t *>(in.ptr());
+ const auto out_ptr = reinterpret_cast<T *>(out.ptr());
+
+ int x = window_start_x;
+ for (; x <= (window_end_x - window_step_x); x += window_step_x)
+ {
+ const auto vin = wrapper::vloadq(in_ptr + x);
+ const auto vdeq = vdequantize(vin, scale[id.z()]);
+
+ store_result<T>(reinterpret_cast<T *>(out_ptr + x), vdeq);
+ }
+
+ // Compute left-over elements
+ for (; x < window_end_x; ++x)
+ {
+ int8_t val = *(in_ptr + x);
+ *(out_ptr + x) = static_cast<T>(dequantize(val, scale[id.z()]));
+ }
+ },
+ in, out);
+}
+
+template <typename T>
+void run_dequantization_qsymm8_per_channel_nhwc(const ITensor *input, ITensor *output, const Window &window)
+{
+ const auto scale = input->info()->quantization_info().scale();
+
+ const int window_step_x = 16;
+ const auto window_start_x = static_cast<int>(window.x().start());
+ const auto window_end_x = static_cast<int>(window.x().end());
+
+ // Reset first dimension to handle tail calculations manually
+ Window win(window);
+ win.set(Window::DimX, Window::Dimension(0, 1, 1));
+
+ // Create iterators
+ Iterator in(input, win);
+ Iterator out(output, win);
+
+ execute_window_loop(
+ win,
+ [&](const Coordinates &)
+ {
+ const auto in_ptr = reinterpret_cast<const int8_t *>(in.ptr());
+ const auto out_ptr = reinterpret_cast<T *>(out.ptr());
+
+ int x = window_start_x;
+ for (; x <= (window_end_x - window_step_x); x += window_step_x)
+ {
+ const float32x4x4_t vscale = {{scale[x + 0], scale[x + 1], scale[x + 2], scale[x + 3], scale[x + 4],
+ scale[x + 5], scale[x + 6], scale[x + 7], scale[x + 8], scale[x + 9],
+ scale[x + 10], scale[x + 11], scale[x + 12], scale[x + 13],
+ scale[x + 14], scale[x + 15]}};
+ const auto vin = wrapper::vloadq(in_ptr + x);
+ const auto vdeq = vdequantize(vin, vscale);
+
+ store_result<T>(reinterpret_cast<T *>(out_ptr + x), vdeq);
+ }
+
+ // Compute left-over elements
+ for (; x < window_end_x; ++x)
+ {
+ int8_t val = *(in_ptr + x);
+ *(out_ptr + x) = static_cast<T>(dequantize(val, scale[x]));
+ }
+ },
+ in, out);
+}
+
+template <typename T>
+void run_dequantization_qsymm8(const ITensor *input, ITensor *output, const Window &window)
+{
+ const UniformQuantizationInfo &qinfo = input->info()->quantization_info().uniform();
+ const float scale = qinfo.scale;
+
+ const int window_step_x = 16;
+ const auto window_start_x = static_cast<int>(window.x().start());
+ const auto window_end_x = static_cast<int>(window.x().end());
+
+ // Collapse window and reset first dimension to handle tail calculations manually
+ Window win_collapsed = window.collapse_if_possible(window, Window::DimZ);
+ win_collapsed.set(Window::DimX, Window::Dimension(0, 1, 1));
+
+ // Create iterators
+ Iterator in(input, win_collapsed);
+ Iterator out(output, win_collapsed);
+
+ execute_window_loop(
+ win_collapsed,
+ [&](const Coordinates &)
+ {
+ const auto in_ptr = reinterpret_cast<const int8_t *>(in.ptr());
+ const auto out_ptr = reinterpret_cast<T *>(out.ptr());
+
+ int x = window_start_x;
+ for (; x <= (window_end_x - window_step_x); x += window_step_x)
+ {
+ const auto vin = wrapper::vloadq(in_ptr + x);
+ const auto vdeq = vdequantize(vin, scale);
+
+ store_result<T>(reinterpret_cast<T *>(out_ptr + x), vdeq);
+ }
+
+ // Compute left-over elements
+ for (; x < window_end_x; ++x)
+ {
+ int8_t val = *(in_ptr + x);
+ *(out_ptr + x) = static_cast<T>(dequantize(val, scale));
+ }
+ },
+ in, out);
+}
+
+template <typename T>
+void run_dequantization_qsymm16(const ITensor *input, ITensor *output, const Window &window)
+{
+ const UniformQuantizationInfo &qinfo = input->info()->quantization_info().uniform();
+ const float scale = qinfo.scale;
+
+ const int window_step_x = 8;
+ const auto window_start_x = static_cast<int>(window.x().start());
+ const auto window_end_x = static_cast<int>(window.x().end());
+
+ // Collapse window and reset first dimension to handle tail calculations manually
+ Window win_collapsed = window.collapse_if_possible(window, Window::DimZ);
+ win_collapsed.set(Window::DimX, Window::Dimension(0, 1, 1));
+
+ // Create iterators
+ Iterator in(input, win_collapsed);
+ Iterator out(output, win_collapsed);
+
+ execute_window_loop(
+ win_collapsed,
+ [&](const Coordinates &)
+ {
+ const auto in_ptr = reinterpret_cast<const int16_t *>(in.ptr());
+ const auto out_ptr = reinterpret_cast<T *>(out.ptr());
+
+ int x = window_start_x;
+ for (; x <= (window_end_x - window_step_x); x += window_step_x)
+ {
+ const auto vin = wrapper::vloadq(in_ptr + x);
+ const auto vdeq = vdequantize_int16(vin, scale);
+
+ store_result<T>(reinterpret_cast<T *>(out_ptr + x), vdeq);
+ }
+
+ // Compute left-over elements
+ for (; x < window_end_x; ++x)
+ {
+ int16_t val = *(in_ptr + x);
+ *(out_ptr + x) = static_cast<T>(dequantize_qsymm16(val, scale));
+ }
+ },
+ in, out);
+}
+
+template <typename T>
+void run_dequantization_core(const ITensor *input, ITensor *output, const Window &window)
+{
+ switch (input->info()->data_type())
+ {
+ case DataType::QASYMM8:
+ run_dequantization_qasymm8<T, uint8_t>(input, output, window);
+ break;
+ case DataType::QASYMM8_SIGNED:
+ run_dequantization_qasymm8<T, int8_t>(input, output, window);
+ break;
+ case DataType::QSYMM8_PER_CHANNEL:
+ input->info()->data_layout() == DataLayout::NHWC
+ ? run_dequantization_qsymm8_per_channel_nhwc<T>(input, output, window)
+ : run_dequantization_qsymm8_per_channel_nchw<T>(input, output, window);
+ break;
+ case DataType::QSYMM8:
+ run_dequantization_qsymm8<T>(input, output, window);
+ break;
+ case DataType::QSYMM16:
+ run_dequantization_qsymm16<T>(input, output, window);
+ break;
+ default:
+ ARM_COMPUTE_ERROR("Unsupported data type.");
+ }
+}
+
+} // namespace cpu
+} // namespace arm_compute
+
+#endif // ACL_SRC_CPU_KERNELS_DEQUANTIZE_GENERIC_NEON_IMPL_H
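All of the scalar tail paths above reduce to two formulas under uniform quantization. A free-standing restatement for reference:

#include <cstdint>

// Asymmetric (QASYMM8/QASYMM8_SIGNED): x = scale * (q - offset)
float dequantize_qasymm8_scalar(uint8_t q, float scale, int32_t offset)
{
    return scale * (static_cast<int32_t>(q) - offset);
}

// Symmetric (QSYMM8/QSYMM16, per-tensor or per-channel): x = scale * q
float dequantize_qsymm8_scalar(int8_t q, float scale)
{
    return scale * static_cast<float>(q);
}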
diff --git a/src/cpu/kernels/dequantize/generic/neon/list.h b/src/cpu/kernels/dequantize/generic/neon/list.h
new file mode 100644
index 0000000000..678eb2c01a
--- /dev/null
+++ b/src/cpu/kernels/dequantize/generic/neon/list.h
@@ -0,0 +1,43 @@
+/*
+ * Copyright (c) 2024 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#ifndef ACL_SRC_CPU_KERNELS_DEQUANTIZE_GENERIC_NEON_LIST_H
+#define ACL_SRC_CPU_KERNELS_DEQUANTIZE_GENERIC_NEON_LIST_H
+
+#include "arm_compute/core/Helpers.h"
+
+namespace arm_compute
+{
+namespace cpu
+{
+
+#define DECLARE_DEQUANTIZE_KERNEL(func_name) void func_name(const ITensor *input, ITensor *output, const Window &window)
+
+DECLARE_DEQUANTIZE_KERNEL(fp32_run_dequantization_core);
+DECLARE_DEQUANTIZE_KERNEL(fp16_run_dequantization_core);
+
+#undef DECLARE_DEQUANTIZE_KERNEL
+
+} // namespace cpu
+} // namespace arm_compute
+#endif // ACL_SRC_CPU_KERNELS_DEQUANTIZE_GENERIC_NEON_LIST_H
diff --git a/src/cpu/kernels/quantize/generic/neon/fp16.cpp b/src/cpu/kernels/quantize/generic/neon/fp16.cpp
new file mode 100644
index 0000000000..37bfb5b2aa
--- /dev/null
+++ b/src/cpu/kernels/quantize/generic/neon/fp16.cpp
@@ -0,0 +1,45 @@
+/*
+ * Copyright (c) 2024 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && defined(ENABLE_FP16_KERNELS)
+#include "src/cpu/kernels/quantize/generic/neon/impl.h"
+
+namespace arm_compute
+{
+namespace cpu
+{
+void fp16_u8_run_quantize_qasymm8(const ITensor *src, ITensor *dst, const Window &window)
+{
+ run_quantize_qasymm8<float16_t, uint8_t>(src, dst, window);
+}
+void fp16_i8_run_quantize_qasymm8(const ITensor *src, ITensor *dst, const Window &window)
+{
+ run_quantize_qasymm8<float16_t, int8_t>(src, dst, window);
+}
+void fp16_run_quantize_qasymm16(const ITensor *src, ITensor *dst, const Window &window)
+{
+ run_quantize_qasymm16<float16_t>(src, dst, window);
+}
+} // namespace cpu
+} // namespace arm_compute
+#endif /* defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && defined(ENABLE_FP16_KERNELS) */
diff --git a/src/cpu/kernels/quantize/generic/neon/fp32.cpp b/src/cpu/kernels/quantize/generic/neon/fp32.cpp
new file mode 100644
index 0000000000..0cba332fd6
--- /dev/null
+++ b/src/cpu/kernels/quantize/generic/neon/fp32.cpp
@@ -0,0 +1,48 @@
+/*
+ * Copyright (c) 2024 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#include "src/cpu/kernels/quantize/generic/neon/impl.h"
+
+namespace arm_compute
+{
+namespace cpu
+{
+void fp32_u8_run_quantize_qasymm8(const ITensor *src, ITensor *dst, const Window &window)
+{
+ run_quantize_qasymm8<float, uint8_t>(src, dst, window);
+}
+void fp32_i8_run_quantize_qasymm8(const ITensor *src, ITensor *dst, const Window &window)
+{
+ run_quantize_qasymm8<float, int8_t>(src, dst, window);
+}
+void fp32_run_quantize_qasymm16(const ITensor *src, ITensor *dst, const Window &window)
+{
+ run_quantize_qasymm16<float>(src, dst, window);
+}
+
+void fp32_i8_run_quantize_qsymm8(const ITensor *src, ITensor *dst, const Window &window)
+{
+ run_quantize_qsymm8<float, int8_t>(src, dst, window);
+}
+} // namespace cpu
+} // namespace arm_compute
diff --git a/src/cpu/kernels/quantize/generic/neon/impl.h b/src/cpu/kernels/quantize/generic/neon/impl.h
new file mode 100644
index 0000000000..9954a7645e
--- /dev/null
+++ b/src/cpu/kernels/quantize/generic/neon/impl.h
@@ -0,0 +1,330 @@
+/*
+ * Copyright (c) 2024 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#ifndef ACL_SRC_CPU_KERNELS_QUANTIZE_GENERIC_NEON_IMPL_H
+#define ACL_SRC_CPU_KERNELS_QUANTIZE_GENERIC_NEON_IMPL_H
+
+#include "arm_compute/core/Helpers.h"
+
+#include "src/core/helpers/WindowHelpers.h"
+#include "src/core/NEON/NEAsymm.h"
+#include "src/core/NEON/wrapper/intrinsics/intrinsics.h"
+
+namespace arm_compute
+{
+namespace cpu
+{
+constexpr auto window_step = 16;
+
+template <typename T>
+inline float32x4x4_t load_value(const T *input_ptr)
+{
+ using Tx16_t = typename wrapper::traits::neon_vector<T, 16>::type;
+ return arm_compute::convert_to_float32x4x4<Tx16_t>(wrapper::vloadq(input_ptr));
+}
+
+template <>
+inline float32x4x4_t load_value(const float *input_ptr)
+{
+ return {wrapper::vloadq(input_ptr), wrapper::vloadq(input_ptr + 4), wrapper::vloadq(input_ptr + 8),
+ wrapper::vloadq(input_ptr + 12)};
+}
+#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+template <>
+inline float32x4x4_t load_value(const float16_t *input_ptr)
+{
+ return {vcvt_f32_f16(wrapper::vload(input_ptr)), vcvt_f32_f16(wrapper::vload(input_ptr + 4)),
+ vcvt_f32_f16(wrapper::vload(input_ptr + 8)), vcvt_f32_f16(wrapper::vload(input_ptr + 12))};
+}
+
+#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+
+template <typename element_type>
+using vector_type = wrapper::traits::neon_vector_t<element_type, window_step>;
+
+template <typename quantized_type>
+inline vector_type<quantized_type> vquantize_qasymm8(const float32x4x4_t &qv, const UniformQuantizationInfo &qi);
+
+template <>
+inline vector_type<uint8_t> vquantize_qasymm8<uint8_t>(const float32x4x4_t &qv, const UniformQuantizationInfo &qi)
+{
+ return vquantize(qv, qi);
+}
+
+template <>
+inline vector_type<int8_t> vquantize_qasymm8<int8_t>(const float32x4x4_t &qv, const UniformQuantizationInfo &qi)
+{
+ return vquantize_signed(qv, qi);
+}
+
+template <typename TOut, typename = typename std::enable_if<std::is_signed<TOut>::value, bool>::type>
+inline int8x16_t recombine_8_16(int16x8_t lower, int16x8_t upper)
+{
+ return wrapper::vcombine(wrapper::vqmovn(lower), wrapper::vqmovn(upper));
+}
+
+template <typename TOut, typename = typename std::enable_if<std::is_unsigned<TOut>::value, bool>::type>
+inline uint8x16_t recombine_8_16(int16x8_t lower, int16x8_t upper)
+{
+ return wrapper::vcombine(wrapper::vqmovun(lower), wrapper::vqmovun(upper));
+}
+
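recombine_8_16 below packs two saturating-narrowed halves back into one 16-lane register, picking vqmovn or vqmovun by the signedness of the output type. The scalar equivalent of the per-lane narrowing, for reference:

#include <algorithm>
#include <cstdint>

// Scalar analogues of vqmovn (signed) and vqmovun (unsigned) on one lane.
int8_t sat_narrow_s8(int16_t v)
{
    return static_cast<int8_t>(std::min<int16_t>(127, std::max<int16_t>(-128, v)));
}
uint8_t sat_narrow_u8(int16_t v)
{
    return static_cast<uint8_t>(std::min<int16_t>(255, std::max<int16_t>(0, v)));
}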
+template <typename TIn, typename TOut>
+void run_quantize_qsymm8(const ITensor *src, ITensor *dst, const Window &window)
+{
+ const auto window_start_x = static_cast<int>(window.x().start());
+ const auto window_end_x = static_cast<int>(window.x().end());
+
+ const UniformQuantizationInfo uqinfo_in = src->info()->quantization_info().uniform();
+ UniformQuantizationInfo uqinfo = dst->info()->quantization_info().uniform();
+ uqinfo = compute_requantization_scale_offset(uqinfo_in, uqinfo);
+
+ // Collapse window and reset first dimension to handle tail calculations manually
+ Window win_collapsed = window.collapse_if_possible(window, Window::DimZ);
+ win_collapsed.set(Window::DimX, Window::Dimension(0, 1, 1));
+
+ Iterator input(src, win_collapsed);
+ Iterator output(dst, win_collapsed);
+ execute_window_loop(
+ win_collapsed,
+ [&](const Coordinates &)
+ {
+ auto input_ptr = reinterpret_cast<const TIn *>(input.ptr());
+ auto output_ptr = reinterpret_cast<TOut *>(output.ptr());
+ int x = window_start_x;
+ for (; x <= (window_end_x - window_step); x += window_step)
+ {
+ wrapper::vstore(&output_ptr[x], vquantize_qasymm8<TOut>(load_value(&input_ptr[x]), uqinfo));
+ }
+ // Compute left-over elements
+ for (; x < window_end_x; ++x)
+ {
+ output_ptr[x] = quantize_qsymm8(input_ptr[x], dst->info()->quantization_info());
+ }
+ },
+ input, output);
+}
+
+template <typename TIn, typename TOut>
+void run_requantize_offset_only_convert(const ITensor *src, ITensor *dst, const Window &window)
+{
+ const auto window_start_x = static_cast<int>(window.x().start());
+ const auto window_end_x = static_cast<int>(window.x().end());
+
+ // Calculate output offset difference.
+ const UniformQuantizationInfo uqinfo_in = src->info()->quantization_info().uniform();
+ UniformQuantizationInfo uqinfo = dst->info()->quantization_info().uniform();
+ uqinfo = compute_requantization_scale_offset(uqinfo_in, uqinfo);
+
+ // Collapse window and reset first dimension to handle tail calculations manually
+ Window win_collapsed = window.collapse_if_possible(window, Window::DimZ);
+
+ win_collapsed.set(Window::DimX, Window::Dimension(0, 1, 1));
+
+ // Duplicate offset in signed vector format
+ const int8x16_t offset = wrapper::vdup_n(static_cast<int8_t>(uqinfo.offset), wrapper::traits::vector_128_tag{});
+
+ Iterator input(src, win_collapsed);
+ Iterator output(dst, win_collapsed);
+ execute_window_loop(
+ win_collapsed,
+ [&](const Coordinates &)
+ {
+ auto input_ptr = reinterpret_cast<const TIn *>(input.ptr());
+ auto output_ptr = reinterpret_cast<TOut *>(output.ptr());
+ int x = window_start_x;
+ for (; x <= (window_end_x - window_step); x += window_step)
+ {
+ const wrapper::traits::neon_vector_t<TIn, window_step> qv =
+ wrapper::vloadq(input_ptr + x); // load a 128-bit vector of the 8-bit data type
+
+ // Signed addition.
+ auto res = vaddq_s8(reinterpret_cast<int8x16_t>(qv), offset);
+
+ // Output is dependent on datatype.
+ wrapper::vstore(&output_ptr[x],
+ reinterpret_cast<wrapper::traits::neon_vector_t<TOut, window_step>>(res));
+ }
+ // Compute left-over elements
+ for (; x < window_end_x; ++x)
+ {
+ auto result = uqinfo.offset + static_cast<int32_t>(input_ptr[x]);
+ output_ptr[x] = static_cast<TOut>(result);
+ }
+ },
+ input, output);
+}
+
+template <typename TIn, typename TOut>
+void run_requantize_offset_only(const ITensor *src, ITensor *dst, const Window &window)
+{
+ const auto window_start_x = static_cast<int>(window.x().start());
+ const auto window_end_x = static_cast<int>(window.x().end());
+
+ const UniformQuantizationInfo uqinfo_in = src->info()->quantization_info().uniform();
+ UniformQuantizationInfo uqinfo = dst->info()->quantization_info().uniform();
+ uqinfo = compute_requantization_scale_offset(uqinfo_in, uqinfo);
+
+ // Collapse window and reset first dimension to handle tail calculations manually
+ Window win_collapsed = window.collapse_if_possible(window, Window::DimZ);
+ win_collapsed.set(Window::DimX, Window::Dimension(0, 1, 1));
+
+ // Duplicate offset in signed vector format
+ const int16x8_t offset = wrapper::vdup_n(static_cast<int16_t>(uqinfo.offset), wrapper::traits::vector_128_tag{});
+
+ const int32_t low_bound = (dst->info()->data_type() == DataType::QASYMM8) ? 0 : -128;
+ const int32_t upper_bound = (dst->info()->data_type() == DataType::QASYMM8) ? 255 : 127;
+
+ Iterator input(src, win_collapsed);
+ Iterator output(dst, win_collapsed);
+ execute_window_loop(
+ win_collapsed,
+ [&](const Coordinates &)
+ {
+ auto input_ptr = reinterpret_cast<const TIn *>(input.ptr());
+ TOut *output_ptr = reinterpret_cast<TOut *>(output.ptr());
+
+ int x = window_start_x;
+ for (; x <= (window_end_x - window_step); x += window_step)
+ {
+ const auto qv = wrapper::vloadq(input_ptr + x); // load a 128-bit vector of the 8-bit data type
+ int16x8_t lower = reinterpret_cast<int16x8_t>(wrapper::vmovl(wrapper::vgetlow(qv)));
+ int16x8_t upper = reinterpret_cast<int16x8_t>(wrapper::vmovl(wrapper::vgethigh(qv)));
+
+ // Signed addition.
+ lower = wrapper::vqadd(lower, offset);
+ upper = wrapper::vqadd(upper, offset);
+
+ // Output is dependent on datatype.
+ auto res = recombine_8_16<TOut>(lower, upper);
+ wrapper::vstore(&output_ptr[x], res);
+ }
+ // Compute left-over elements
+ for (; x < window_end_x; ++x)
+ {
+ // Add offset and clamp result to within the range of the output datatype.
+ int32_t result = uqinfo.offset + static_cast<int32_t>(input_ptr[x]);
+ result = utility::clamp<int32_t>(result, low_bound, upper_bound);
+
+ // Cast result to output datatype.
+ output_ptr[x] = static_cast<TOut>(result);
+ }
+ },
+ input, output);
+}
+
+template <typename TIn, typename TOut>
+void run_quantize_qasymm8(const ITensor *src, ITensor *dst, const Window &window)
+{
+ const auto window_start_x = static_cast<int>(window.x().start());
+ const auto window_end_x = static_cast<int>(window.x().end());
+
+ const UniformQuantizationInfo uqinfo_in = src->info()->quantization_info().uniform();
+ UniformQuantizationInfo uqinfo = dst->info()->quantization_info().uniform();
+ if (is_data_type_quantized_asymmetric(src->info()->data_type()))
+ {
+ uqinfo = compute_requantization_scale_offset(uqinfo_in, uqinfo);
+ }
+#ifdef __aarch64__
+ constexpr RoundingPolicy rounding_policy = RoundingPolicy::TO_NEAREST_EVEN;
+#else //__aarch64__
+ constexpr RoundingPolicy rounding_policy = RoundingPolicy::TO_ZERO;
+#endif //__aarch64__
+
+ // Collapse window and reset first dimension to handle tail calculations manually
+ Window win_collapsed = window.collapse_if_possible(window, Window::DimZ);
+ win_collapsed.set(Window::DimX, Window::Dimension(0, 1, 1));
+
+ Iterator input(src, win_collapsed);
+ Iterator output(dst, win_collapsed);
+ execute_window_loop(
+ win_collapsed,
+ [&](const Coordinates &)
+ {
+ auto input_ptr = reinterpret_cast<const TIn *>(input.ptr());
+ auto output_ptr = reinterpret_cast<TOut *>(output.ptr());
+
+ int x = window_start_x;
+ for (; x <= (window_end_x - window_step); x += window_step)
+ {
+ wrapper::vstore(&output_ptr[x], vquantize_qasymm8<TOut>(load_value(&input_ptr[x]), uqinfo));
+ }
+ // Compute left-over elements
+ for (; x < window_end_x; ++x)
+ {
+ output_ptr[x] = Qasymm8QuantizationHelper<TOut>::quantize(input_ptr[x], uqinfo, rounding_policy);
+ }
+ },
+ input, output);
+}
+
+template <typename T>
+void run_quantize_qasymm16(const ITensor *src, ITensor *dst, const Window &window)
+{
+ const auto window_start_x = static_cast<int>(window.x().start());
+ const auto window_end_x = static_cast<int>(window.x().end());
+
+ const UniformQuantizationInfo uqinfo_in = src->info()->quantization_info().uniform();
+ UniformQuantizationInfo uqinfo = dst->info()->quantization_info().uniform();
+ if (is_data_type_quantized_asymmetric(src->info()->data_type()))
+ {
+ uqinfo = compute_requantization_scale_offset(uqinfo_in, uqinfo);
+ }
+#ifdef __aarch64__
+ constexpr RoundingPolicy rounding_policy = RoundingPolicy::TO_NEAREST_EVEN;
+#else //__aarch64__
+ constexpr RoundingPolicy rounding_policy = RoundingPolicy::TO_ZERO;
+#endif //__aarch64__
+
+ // Collapse window and reset first dimension to handle tail calculations manually
+ Window win_collapsed = window.collapse_if_possible(window, Window::DimZ);
+ win_collapsed.set(Window::DimX, Window::Dimension(0, 1, 1));
+
+ Iterator input(src, win_collapsed);
+ Iterator output(dst, win_collapsed);
+ execute_window_loop(
+ win_collapsed,
+ [&](const Coordinates &)
+ {
+ auto input_ptr = reinterpret_cast<const T *>(input.ptr());
+ auto output_ptr = reinterpret_cast<uint16_t *>(output.ptr());
+
+ int x = window_start_x;
+ for (; x <= (window_end_x - window_step); x += window_step)
+ {
+ uint16x8x2_t tmp = vquantize_qasymm16(load_value(&input_ptr[x]), uqinfo);
+ vst1q_u16(&output_ptr[x], tmp.val[0]);
+ vst1q_u16(&output_ptr[x + 8], tmp.val[1]);
+ }
+ // Compute left-over elements
+ for (; x < window_end_x; ++x)
+ {
+ output_ptr[x] = quantize_qasymm16(input_ptr[x], uqinfo, rounding_policy);
+ }
+ },
+ input, output);
+}
+} // namespace cpu
+} // namespace arm_compute
+
+#endif // ACL_SRC_CPU_KERNELS_QUANTIZE_GENERIC_NEON_IMPL_H
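The run_requantize_offset_only paths exist because, once compute_requantization_scale_offset has folded equal src/dst scales together, requantization collapses to an integer offset shift plus saturation. A scalar restatement of what each element sees (uint8-to-int8 direction shown; the other pairings are analogous):

#include <algorithm>
#include <cstdint>

// q_out = clamp(q_in + offset_delta, -128, 127), where offset_delta is the
// precomputed offset from compute_requantization_scale_offset().
int8_t requantize_offset_only_scalar(uint8_t q_in, int32_t offset_delta)
{
    const int32_t r = static_cast<int32_t>(q_in) + offset_delta;
    return static_cast<int8_t>(std::min(127, std::max(-128, r)));
}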
diff --git a/src/cpu/kernels/quantize/generic/neon/integer.cpp b/src/cpu/kernels/quantize/generic/neon/integer.cpp
new file mode 100644
index 0000000000..4e39afaaee
--- /dev/null
+++ b/src/cpu/kernels/quantize/generic/neon/integer.cpp
@@ -0,0 +1,82 @@
+/*
+ * Copyright (c) 2024 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#include "src/cpu/kernels/quantize/generic/neon/impl.h"
+
+namespace arm_compute
+{
+namespace cpu
+{
+void u8_u8_run_quantize_qasymm8(const ITensor *src, ITensor *dst, const Window &window)
+{
+ run_quantize_qasymm8<uint8_t, uint8_t>(src, dst, window);
+}
+void u8_i8_run_quantize_qasymm8(const ITensor *src, ITensor *dst, const Window &window)
+{
+ run_quantize_qasymm8<uint8_t, int8_t>(src, dst, window);
+}
+void i8_u8_run_quantize_qasymm8(const ITensor *src, ITensor *dst, const Window &window)
+{
+ run_quantize_qasymm8<int8_t, uint8_t>(src, dst, window);
+}
+void i8_i8_run_quantize_qasymm8(const ITensor *src, ITensor *dst, const Window &window)
+{
+ run_quantize_qasymm8<int8_t, int8_t>(src, dst, window);
+}
+
+void u8_run_quantize_qasymm16(const ITensor *src, ITensor *dst, const Window &window)
+{
+ run_quantize_qasymm16<uint8_t>(src, dst, window);
+}
+void i8_run_quantize_qasymm16(const ITensor *src, ITensor *dst, const Window &window)
+{
+ run_quantize_qasymm16<int8_t>(src, dst, window);
+}
+
+void u8_u8_run_requantize_offset_only(const ITensor *src, ITensor *dst, const Window &window)
+{
+ run_requantize_offset_only<uint8_t, uint8_t>(src, dst, window);
+}
+void u8_i8_run_requantize_offset_only(const ITensor *src, ITensor *dst, const Window &window)
+{
+ run_requantize_offset_only<uint8_t, int8_t>(src, dst, window);
+}
+void i8_u8_run_requantize_offset_only(const ITensor *src, ITensor *dst, const Window &window)
+{
+ run_requantize_offset_only<int8_t, uint8_t>(src, dst, window);
+}
+void i8_i8_run_requantize_offset_only(const ITensor *src, ITensor *dst, const Window &window)
+{
+ run_requantize_offset_only<int8_t, int8_t>(src, dst, window);
+}
+
+void i8_u8_run_requantize_offset_only_convert(const ITensor *src, ITensor *dst, const Window &window)
+{
+ run_requantize_offset_only_convert<int8_t, uint8_t>(src, dst, window);
+}
+void u8_i8_run_requantize_offset_only_convert(const ITensor *src, ITensor *dst, const Window &window)
+{
+ run_requantize_offset_only_convert<uint8_t, int8_t>(src, dst, window);
+}
+} // namespace cpu
+} // namespace arm_compute
diff --git a/src/cpu/kernels/quantize/generic/neon/list.h b/src/cpu/kernels/quantize/generic/neon/list.h
new file mode 100644
index 0000000000..c4fb1048eb
--- /dev/null
+++ b/src/cpu/kernels/quantize/generic/neon/list.h
@@ -0,0 +1,66 @@
+/*
+ * Copyright (c) 2024 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#ifndef ACL_SRC_CPU_KERNELS_QUANTIZE_GENERIC_NEON_LIST_H
+#define ACL_SRC_CPU_KERNELS_QUANTIZE_GENERIC_NEON_LIST_H
+
+#include "arm_compute/core/Helpers.h"
+
+namespace arm_compute
+{
+namespace cpu
+{
+
+#define DECLARE_QUANTIZE_KERNEL(func_name) void func_name(const ITensor *src, ITensor *dst, const Window &window)
+
+DECLARE_QUANTIZE_KERNEL(u8_u8_run_quantize_qasymm8);
+DECLARE_QUANTIZE_KERNEL(u8_i8_run_quantize_qasymm8);
+DECLARE_QUANTIZE_KERNEL(i8_u8_run_quantize_qasymm8);
+DECLARE_QUANTIZE_KERNEL(i8_i8_run_quantize_qasymm8);
+
+DECLARE_QUANTIZE_KERNEL(u8_u8_run_requantize_offset_only);
+DECLARE_QUANTIZE_KERNEL(u8_i8_run_requantize_offset_only);
+DECLARE_QUANTIZE_KERNEL(i8_u8_run_requantize_offset_only);
+DECLARE_QUANTIZE_KERNEL(i8_i8_run_requantize_offset_only);
+
+DECLARE_QUANTIZE_KERNEL(i8_u8_run_requantize_offset_only_convert);
+DECLARE_QUANTIZE_KERNEL(u8_i8_run_requantize_offset_only_convert);
+
+DECLARE_QUANTIZE_KERNEL(u8_run_quantize_qasymm16);
+DECLARE_QUANTIZE_KERNEL(i8_run_quantize_qasymm16);
+
+DECLARE_QUANTIZE_KERNEL(fp32_u8_run_quantize_qasymm8);
+DECLARE_QUANTIZE_KERNEL(fp32_i8_run_quantize_qasymm8);
+DECLARE_QUANTIZE_KERNEL(fp32_run_quantize_qasymm16);
+
+DECLARE_QUANTIZE_KERNEL(fp32_i8_run_quantize_qsymm8);
+
+DECLARE_QUANTIZE_KERNEL(fp16_u8_run_quantize_qasymm8);
+DECLARE_QUANTIZE_KERNEL(fp16_i8_run_quantize_qasymm8);
+DECLARE_QUANTIZE_KERNEL(fp16_run_quantize_qasymm16);
+
+#undef DECLARE_QUANTIZE_KERNEL
+
+} // namespace cpu
+} // namespace arm_compute
+#endif // ACL_SRC_CPU_KERNELS_QUANTIZE_GENERIC_NEON_LIST_H
diff --git a/src/cpu/kernels/reduction_layer/generic/neon/fp16.cpp b/src/cpu/kernels/reduction_layer/generic/neon/fp16.cpp
new file mode 100644
index 0000000000..143bb5487f
--- /dev/null
+++ b/src/cpu/kernels/reduction_layer/generic/neon/fp16.cpp
@@ -0,0 +1,65 @@
+/*
+ * Copyright (c) 2024 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && defined(ENABLE_FP16_KERNELS)
+
+#include "src/cpu/kernels/reduction_layer/generic/neon/impl.h"
+
+namespace arm_compute
+{
+namespace cpu
+{
+void reduce_RedOpX_reduceX_float16_8(const Window &window,
+ const ITensor *input,
+ ITensor *output,
+ const ReductionOperation op)
+{
+ return Reducer<RedOpX<float16_t, 8>>::reduceX(window, input, output, RedOpX<float16_t, 8>(), op);
+}
+
+void reduce_RedOpYZW_reduceY_float16_8(const Window &window,
+ const ITensor *input,
+ ITensor *output,
+ const ReductionOperation op)
+{
+ return Reducer<RedOpYZW<float16_t, 8>>::reduceY(window, input, output, RedOpYZW<float16_t, 8>(), op);
+}
+
+void reduce_RedOpYZW_reduceZ_float16_8(const Window &window,
+ const ITensor *input,
+ ITensor *output,
+ const ReductionOperation op)
+{
+ return Reducer<RedOpYZW<float16_t, 8>>::reduceZ(window, input, output, RedOpYZW<float16_t, 8>(), op);
+}
+
+void reduce_RedOpYZW_reduceW_float16_8(const Window &window,
+ const ITensor *input,
+ ITensor *output,
+ const ReductionOperation op)
+{
+ return Reducer<RedOpYZW<float16_t, 8>>::reduceW(window, input, output, RedOpYZW<float16_t, 8>(), op);
+}
+} // namespace cpu
+} // namespace arm_compute
+#endif /* defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && defined(ENABLE_FP16_KERNELS) */
diff --git a/src/cpu/kernels/reduction_layer/generic/neon/fp32.cpp b/src/cpu/kernels/reduction_layer/generic/neon/fp32.cpp
new file mode 100644
index 0000000000..6f5f13e571
--- /dev/null
+++ b/src/cpu/kernels/reduction_layer/generic/neon/fp32.cpp
@@ -0,0 +1,73 @@
+/*
+ * Copyright (c) 2024 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "src/cpu/kernels/reduction_layer/generic/neon/impl.h"
+
+namespace arm_compute
+{
+namespace cpu
+{
+void reduce_RedOpYZW_complex_reduceZ_float32_4_2_SUM(const Window &window,
+ const ITensor *input,
+ ITensor *output,
+ const ReductionOperation op)
+{
+ Reducer<RedOpYZW_complex<float, 4, 2, ReductionOperation::SUM>>::reduceZ(
+ window, input, output, RedOpYZW_complex<float, 4, 2, ReductionOperation::SUM>(), op);
+}
+
+void reduce_RedOpX_reduceX_float32_4(const Window &window,
+ const ITensor *input,
+ ITensor *output,
+ const ReductionOperation op)
+{
+ return Reducer<RedOpX<float, 4>>::reduceX(window, input, output, RedOpX<float, 4>(), op);
+}
+
+void reduce_RedOpYZW_reduceY_float32_4(const Window &window,
+ const ITensor *input,
+ ITensor *output,
+ const ReductionOperation op)
+{
+ return Reducer<RedOpYZW<float, 4>>::reduceY(window, input, output, RedOpYZW<float, 4>(), op);
+}
+
+void reduce_RedOpYZW_reduceZ_float32_4(const Window &window,
+ const ITensor *input,
+ ITensor *output,
+ const ReductionOperation op)
+{
+ return Reducer<RedOpYZW<float, 4>>::reduceZ(window, input, output, RedOpYZW<float, 4>(), op);
+}
+
+void reduce_RedOpYZW_reduceW_float32_4(const Window &window,
+ const ITensor *input,
+ ITensor *output,
+ const ReductionOperation op)
+{
+ return Reducer<RedOpYZW<float, 4>>::reduceW(window, input, output, RedOpYZW<float, 4>(), op);
+}
+
+} // namespace cpu
+} // namespace arm_compute
diff --git a/src/cpu/kernels/reduction_layer/generic/neon/impl.h b/src/cpu/kernels/reduction_layer/generic/neon/impl.h
new file mode 100644
index 0000000000..3fa821d3a4
--- /dev/null
+++ b/src/cpu/kernels/reduction_layer/generic/neon/impl.h
@@ -0,0 +1,1633 @@
+/*
+ * Copyright (c) 2024 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#ifndef ACL_SRC_CPU_KERNELS_REDUCTION_LAYER_GENERIC_NEON_IMPL_H
+#define ACL_SRC_CPU_KERNELS_REDUCTION_LAYER_GENERIC_NEON_IMPL_H
+
+#include "arm_compute/core/Coordinates.h"
+#include "arm_compute/core/Helpers.h"
+#include "arm_compute/core/TensorInfo.h"
+
+#include "src/core/NEON/NEMath.h"
+#include "src/core/NEON/wrapper/wrapper.h"
+#include "support/SaturateCast.h"
+
+#include <arm_neon.h>
+
+namespace arm_compute
+{
+// Helper function that calls vqmovun/vqmovn, vcombine and vstore; this allows templating of RedOpYZW_quantized.
+template <typename T>
+void combine_and_store(int16x8_t t1, int16x8_t t2, Iterator &output, int offset = 0)
+{
+ if (std::is_same<T, uint8_t>::value)
+ {
+ auto res = wrapper::vcombine(wrapper::vqmovun(t1), wrapper::vqmovun(t2));
+ wrapper::vstore(output.ptr() + offset, res);
+ }
+ else
+ {
+ auto res = wrapper::vcombine(wrapper::vqmovn(t1), wrapper::vqmovn(t2));
+ wrapper::vstore(reinterpret_cast<int8_t *>(output.ptr() + offset), res);
+ }
+}
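+// For illustration, the quantized SUM path in RedOpYZW_quantized below issues
+//   combine_and_store<T>(temp16x8t_1, temp16x8t_2, output, x);
+// note that the T == uint8_t branch narrows through vqmovun, so negative int16
+// lanes saturate to zero.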
+
+template <typename T>
+uint32x4x4_t calculate_index(uint32_t idx, T a, T b, uint32x4x4_t c, ReductionOperation op, int axis)
+{
+ uint32x4_t mask{0};
+ if (op == ReductionOperation::ARG_IDX_MIN)
+ {
+ mask = wrapper::vcgt(b, a);
+ }
+ else
+ {
+ mask = wrapper::vclt(b, a);
+ }
+
+ uint32x4_t vec_idx = {idx, idx + 1, idx + 2, idx + 3};
+ if (axis != 0)
+ {
+ vec_idx = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{});
+ }
+ uint32x4x4_t res = {{wrapper::vbsl(mask, vec_idx, c.val[0]), 0, 0, 0}};
+
+ return res;
+}
+
+template <typename T>
+uint32x4x4_t calculate_index_quantized(uint32_t idx, T a, T b, uint32x4x4_t c, ReductionOperation op, int axis)
+{
+ uint32x4x4_t mask{{0}};
+ uint8x16_t mask_u8{0};
+ if (op == ReductionOperation::ARG_IDX_MIN)
+ {
+ mask_u8 = wrapper::vcgt(b, a);
+ }
+ else
+ {
+ mask_u8 = wrapper::vclt(b, a);
+ }
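+    // Widen the byte mask into four 32-bit lane masks: OR-ing each zero-extended
+    // mask with itself shifted left by 8 turns every 0x00/0xFF byte into
+    // 0x0000/0xFFFF, and repeating the trick on the u16 halves yields full
+    // 0x00000000/0xFFFFFFFF masks for the bitwise selects below.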
+ auto wide_u16_1 =
+ wrapper::vorr(vshll_n_u8(wrapper::vgetlow(mask_u8), 8), wrapper::vmovl(wrapper::vgetlow(mask_u8)));
+ auto wide_u16_2 =
+ wrapper::vorr(vshll_n_u8(wrapper::vgethigh(mask_u8), 8), wrapper::vmovl(wrapper::vgethigh(mask_u8)));
+ mask.val[0] =
+ wrapper::vorr(vshll_n_u16(wrapper::vgetlow(wide_u16_1), 16), wrapper::vmovl(wrapper::vgetlow(wide_u16_1)));
+ mask.val[1] =
+ wrapper::vorr(vshll_n_u16(wrapper::vgethigh(wide_u16_1), 16), wrapper::vmovl(wrapper::vgethigh(wide_u16_1)));
+ mask.val[2] =
+ wrapper::vorr(vshll_n_u16(wrapper::vgetlow(wide_u16_2), 16), wrapper::vmovl(wrapper::vgetlow(wide_u16_2)));
+ mask.val[3] =
+ wrapper::vorr(vshll_n_u16(wrapper::vgethigh(wide_u16_2), 16), wrapper::vmovl(wrapper::vgethigh(wide_u16_2)));
+
+ uint32x4x4_t vec_idx = {{{idx + 0, idx + 1, idx + 2, idx + 3},
+ {idx + 4, idx + 5, idx + 6, idx + 7},
+ {idx + 8, idx + 9, idx + 10, idx + 11},
+ {idx + 12, idx + 13, idx + 14, idx + 15}}};
+ if (axis != 0)
+ {
+ vec_idx.val[0] = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{});
+ vec_idx.val[1] = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{});
+ vec_idx.val[2] = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{});
+ vec_idx.val[3] = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{});
+ }
+ uint32x4x4_t res = {
+ {vbslq_u32(mask.val[0], vec_idx.val[0], c.val[0]), vbslq_u32(mask.val[1], vec_idx.val[1], c.val[1]),
+ vbslq_u32(mask.val[2], vec_idx.val[2], c.val[2]), vbslq_u32(mask.val[3], vec_idx.val[3], c.val[3])}};
+
+ return res;
+}
+
+// Helper function to calculate the minimum value of the input vector. All the elements in the output vector contain the min value.
+template <typename T>
+inline typename std::enable_if<
+ std::is_same<T, float32x4_t>::value || std::is_same<T, int32x4_t>::value,
+ typename std::conditional<std::is_same<T, float32x4_t>::value, float32x2_t, int32x2_t>::type>::type
+calculate_min(T in)
+{
+ auto pmin = wrapper::vpmin(wrapper::vgethigh(in), wrapper::vgetlow(in));
+ return wrapper::vpmin(pmin, pmin);
+}
+
+// Helper function to calculate the minimum value of the input vector. All the elements in the output vector contain the min value.
+template <typename T>
+inline typename std::enable_if<
+ std::is_same<T, uint8x16_t>::value || std::is_same<T, int8x16_t>::value,
+ typename std::conditional<std::is_same<T, uint8x16_t>::value, uint8x8_t, int8x8_t>::type>::type
+calculate_min(T in)
+{
+ auto pmin = wrapper::vpmin(wrapper::vgethigh(in), wrapper::vgetlow(in));
+ pmin = wrapper::vpmin(pmin, pmin);
+ pmin = wrapper::vpmin(pmin, pmin);
+ return wrapper::vpmin(pmin, pmin);
+}
+
+// Helper function to calculate the maximum value of the input vector. All the elements in the output vector contain the max value.
+template <typename T>
+inline typename std::enable_if<
+ std::is_same<T, float32x4_t>::value || std::is_same<T, int32x4_t>::value,
+ typename std::conditional<std::is_same<T, float32x4_t>::value, float32x2_t, int32x2_t>::type>::type
+calculate_max(T in)
+{
+ auto pmax = wrapper::vpmax(wrapper::vgethigh(in), wrapper::vgetlow(in));
+ return wrapper::vpmax(pmax, pmax);
+}
+
+// Helper function to calculate the maximum value of the input vector. All the elements in the output vector contain the max value.
+template <typename T>
+inline typename std::enable_if<
+ std::is_same<T, uint8x16_t>::value || std::is_same<T, int8x16_t>::value,
+ typename std::conditional<std::is_same<T, uint8x16_t>::value, uint8x8_t, int8x8_t>::type>::type
+calculate_max(T in)
+{
+ auto pmax = wrapper::vpmax(wrapper::vgethigh(in), wrapper::vgetlow(in));
+ pmax = wrapper::vpmax(pmax, pmax);
+ pmax = wrapper::vpmax(pmax, pmax);
+ return wrapper::vpmax(pmax, pmax);
+}
+
+template <typename T>
+uint32_t calculate_vector_index(uint32x4x4_t vec_res_idx, T vec_res_value, ReductionOperation op)
+{
+ uint32x4_t res_idx_mask{0};
+ uint32x4_t mask_ones = vdupq_n_u32(0xFFFFFFFF);
+
+ if (op == ReductionOperation::ARG_IDX_MIN)
+ {
+ auto pmin = calculate_min(vec_res_value);
+ auto mask = wrapper::vceq(vec_res_value, wrapper::vcombine(pmin, pmin));
+ res_idx_mask = wrapper::vand(vec_res_idx.val[0], mask);
+ }
+ else
+ {
+ auto pmax = calculate_max(vec_res_value);
+ auto mask = wrapper::vceq(vec_res_value, wrapper::vcombine(pmax, pmax));
+ res_idx_mask = wrapper::vand(vec_res_idx.val[0], mask);
+ }
+
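+    // Sentinel trick: adding 0xFFFFFFFF (i.e. subtracting 1 modulo 2^32) turns
+    // masked-off lanes into UINT32_MAX, so they always lose the pairwise min,
+    // while matching lanes become idx - 1; the subtraction in the return below
+    // adds the 1 back to recover the smallest matching index.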
+ res_idx_mask = wrapper::vadd(res_idx_mask, mask_ones);
+ auto pmin = wrapper::vpmin(wrapper::vgethigh(res_idx_mask), wrapper::vgetlow(res_idx_mask));
+ pmin = wrapper::vpmin(pmin, pmin);
+ uint32_t res = wrapper::vgetlane(pmin, 0);
+
+ return (res - 0xFFFFFFFF);
+}
+
+template <typename T>
+uint32_t calculate_vector_index_quantized(uint32x4x4_t vec_res_idx, T vec_res_value, ReductionOperation op)
+{
+ uint32x4x4_t res_idx_mask{{0}};
+ uint32x4_t mask_ones = vdupq_n_u32(0xFFFFFFFF);
+ uint8x16_t mask_u8{0};
+ if (op == ReductionOperation::ARG_IDX_MIN)
+ {
+ auto pmin = calculate_min(vec_res_value);
+ mask_u8 = wrapper::vceq(vec_res_value, wrapper::vcombine(pmin, pmin));
+ }
+ else
+ {
+ auto pmax = calculate_max(vec_res_value);
+ mask_u8 = wrapper::vceq(vec_res_value, wrapper::vcombine(pmax, pmax));
+ }
+
+ // Widen vectors
+ auto wide_u16_1 =
+ wrapper::vorr(vshll_n_u8(wrapper::vgetlow(mask_u8), 8), wrapper::vmovl(wrapper::vgetlow(mask_u8)));
+ auto wide_u16_2 =
+ wrapper::vorr(vshll_n_u8(wrapper::vgethigh(mask_u8), 8), wrapper::vmovl(wrapper::vgethigh(mask_u8)));
+ auto wide_u32_1 =
+ wrapper::vorr(vshll_n_u16(wrapper::vgetlow(wide_u16_1), 16), wrapper::vmovl(wrapper::vgetlow(wide_u16_1)));
+ auto wide_u32_2 =
+ wrapper::vorr(vshll_n_u16(wrapper::vgethigh(wide_u16_1), 16), wrapper::vmovl(wrapper::vgethigh(wide_u16_1)));
+ auto wide_u32_3 =
+ wrapper::vorr(vshll_n_u16(wrapper::vgetlow(wide_u16_2), 16), wrapper::vmovl(wrapper::vgetlow(wide_u16_2)));
+ auto wide_u32_4 =
+ wrapper::vorr(vshll_n_u16(wrapper::vgethigh(wide_u16_2), 16), wrapper::vmovl(wrapper::vgethigh(wide_u16_2)));
+ res_idx_mask.val[0] = wrapper::vand(vec_res_idx.val[0], wide_u32_1);
+ res_idx_mask.val[1] = wrapper::vand(vec_res_idx.val[1], wide_u32_2);
+ res_idx_mask.val[2] = wrapper::vand(vec_res_idx.val[2], wide_u32_3);
+ res_idx_mask.val[3] = wrapper::vand(vec_res_idx.val[3], wide_u32_4);
+ res_idx_mask.val[0] = wrapper::vadd(res_idx_mask.val[0], mask_ones);
+ res_idx_mask.val[1] = wrapper::vadd(res_idx_mask.val[1], mask_ones);
+ res_idx_mask.val[2] = wrapper::vadd(res_idx_mask.val[2], mask_ones);
+ res_idx_mask.val[3] = wrapper::vadd(res_idx_mask.val[3], mask_ones);
+
+ uint32_t res = 0xFFFFFFFF;
+ int iter = 0;
+ do
+ {
+ auto pmin = wrapper::vpmin(wrapper::vgethigh(res_idx_mask.val[iter]), wrapper::vgetlow(res_idx_mask.val[iter]));
+ pmin = wrapper::vpmin(pmin, pmin);
+ res = std::min(wrapper::vgetlane(pmin, 0), res);
+ iter++;
+ } while (iter < 4);
+
+ return (res - 0xFFFFFFFF);
+}
+
+#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+template <>
+inline uint32x4x4_t calculate_index(
+ uint32_t idx, float16x8_t a, float16x8_t b, uint32x4x4_t c, ReductionOperation op, int axis)
+{
+ uint32x4x2_t mask{0};
+ uint16x8_t mask_u16{0};
+ if (op == ReductionOperation::ARG_IDX_MIN)
+ {
+ mask_u16 = wrapper::vcgt(b, a);
+ }
+ else
+ {
+ mask_u16 = wrapper::vclt(b, a);
+ }
+ mask.val[0] = wrapper::vmovl(wrapper::vgetlow(mask_u16));
+ mask.val[1] = wrapper::vmovl(wrapper::vgethigh(mask_u16));
+ uint32x4x2_t vec_idx = {{{idx + 0, idx + 1, idx + 2, idx + 3}, {idx + 4, idx + 5, idx + 6, idx + 7}}};
+ if (axis != 0)
+ {
+ vec_idx.val[0] = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{});
+ vec_idx.val[1] = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{});
+ }
+ uint32x4x4_t res = {wrapper::vbsl(mask.val[0], vec_idx.val[0], c.val[0]),
+ wrapper::vbsl(mask.val[1], vec_idx.val[1], c.val[1]), 0, 0};
+
+ return res;
+}
+
+// Helper function to calculate the minimum value of the input vector. All the elements in the output vector contain the min value.
+inline float16x4_t calculate_min(float16x8_t in)
+{
+ auto pmin = wrapper::vpmin(wrapper::vgethigh(in), wrapper::vgetlow(in));
+ pmin = wrapper::vpmin(pmin, pmin);
+ return wrapper::vpmin(pmin, pmin);
+}
+// Helper function to calculate the maximum value of the input vector. All the elements in the output vector contain the max value.
+inline float16x4_t calculate_max(float16x8_t in)
+{
+ auto pmax = wrapper::vpmax(wrapper::vgethigh(in), wrapper::vgetlow(in));
+ pmax = wrapper::vpmax(pmax, pmax);
+ return wrapper::vpmax(pmax, pmax);
+}
+
+template <>
+inline uint32_t calculate_vector_index(uint32x4x4_t vec_res_idx, float16x8_t vec_res_value, ReductionOperation op)
+{
+ uint32x4x2_t res_idx_mask{0};
+ uint32x4_t mask_ones = vdupq_n_u32(0xFFFFFFFF);
+ uint16x8_t mask_u16;
+ if (op == ReductionOperation::ARG_IDX_MIN)
+ {
+ auto pmin = calculate_min(vec_res_value);
+ mask_u16 = wrapper::vceq(vec_res_value, wrapper::vcombine(pmin, pmin));
+ }
+ else
+ {
+ auto pmax = calculate_max(vec_res_value);
+ mask_u16 = wrapper::vceq(vec_res_value, wrapper::vcombine(pmax, pmax));
+ }
+
+ // Widen vectors
+ auto wide_u32_1 =
+ wrapper::vorr(vshll_n_u16(wrapper::vgetlow(mask_u16), 8), wrapper::vmovl(wrapper::vgetlow(mask_u16)));
+ auto wide_u32_2 =
+ wrapper::vorr(vshll_n_u16(wrapper::vgethigh(mask_u16), 8), wrapper::vmovl(wrapper::vgethigh(mask_u16)));
+ res_idx_mask.val[0] = wrapper::vand(vec_res_idx.val[0], wide_u32_1);
+ res_idx_mask.val[1] = wrapper::vand(vec_res_idx.val[1], wide_u32_2);
+ res_idx_mask.val[0] = wrapper::vadd(res_idx_mask.val[0], mask_ones);
+ res_idx_mask.val[1] = wrapper::vadd(res_idx_mask.val[1], mask_ones);
+
+ uint32_t res = 0xFFFFFFFF;
+ uint32_t iter = 0;
+ do
+ {
+ auto pmin = wrapper::vpmin(wrapper::vgethigh(res_idx_mask.val[iter]), wrapper::vgetlow(res_idx_mask.val[iter]));
+ pmin = wrapper::vpmin(pmin, pmin);
+ res = std::min(wrapper::vgetlane(pmin, 0), res);
+ iter++;
+ } while (iter < 2);
+
+ return (res - 0xFFFFFFFF);
+}
+#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+
+template <class F>
+class Reducer
+{
+public:
+ static void reduceX(const Window &window, const ITensor *input, ITensor *output, F f, const ReductionOperation op)
+ {
+ // Set out window
+ Window out_window(window);
+ out_window.set(Window::DimX, Window::Dimension(0, 1, 1));
+
+ f(window, out_window, input, output, op);
+ }
+ static void reduceY(const Window &window, const ITensor *input, ITensor *output, F f, const ReductionOperation op)
+ {
+ // Set in window
+ Window in_window(window);
+ Window out_window(window);
+
+ in_window.set(Window::DimY, Window::Dimension(0, 1, 1));
+ out_window.set(Window::DimY, Window::Dimension(0, output->info()->dimension(1), output->info()->dimension(1)));
+
+ f(in_window, out_window, input, output, 1, op);
+ }
+ static void reduceZ(const Window &window, const ITensor *input, ITensor *output, F f, const ReductionOperation op)
+ {
+ // Set in window
+ Window in_window(window);
+ Window out_window(window);
+
+ in_window.set(Window::DimZ, Window::Dimension(0, 1, 1));
+ out_window.set(Window::DimZ, Window::Dimension(0, output->info()->dimension(2), output->info()->dimension(2)));
+
+ f(in_window, out_window, input, output, 2, op);
+ }
+ static void reduceW(const Window &window, const ITensor *input, ITensor *output, F f, const ReductionOperation op)
+ {
+ // Set in/out window
+ Window in_window(window);
+ Window out_window(window);
+
+ in_window.set(3, Window::Dimension(0, 1, 1));
+ out_window.set(3, Window::Dimension(0, 1, 1));
+
+ f(in_window, out_window, input, output, 3, op);
+ }
+};
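+
+// Dispatch sketch: the .cpp files in this directory instantiate Reducer with a
+// concrete functor, e.g. an F32 reduction along y is issued as
+//   Reducer<RedOpYZW<float, 4>>::reduceY(window, input, output, RedOpYZW<float, 4>(), op);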
+
+template <typename T, int S>
+struct RedOpX
+{
+ /** SIMD vector tag type. */
+ using ExactTagType = typename wrapper::traits::neon_vector<T, S>::tag_type;
+
+ inline void operator()(
+ const Window &in_window, Window &out_window, const ITensor *in, ITensor *out, const ReductionOperation op)
+ {
+ const size_t input_dim_0 = in->info()->dimension(0);
+ const int window_step_x = 16 / sizeof(T);
+ const auto window_start_x = static_cast<int>(in_window.x().start());
+ const auto window_end_x = static_cast<int>(in_window.x().end());
+
+ Window in_win_no_pad = in_window;
+ in_win_no_pad.set(Window::DimX, Window::Dimension(0, 1, 1));
+
+ Iterator input(in, in_win_no_pad);
+ Iterator output(out, out_window);
+
+ execute_window_loop(
+ in_win_no_pad,
+ [&](const Coordinates &)
+ {
+ const auto input_ptr = reinterpret_cast<const T *>(input.ptr());
+
+ auto init_res_value = static_cast<T>(0.f);
+ switch (op)
+ {
+ case ReductionOperation::ARG_IDX_MAX:
+ case ReductionOperation::ARG_IDX_MIN:
+ case ReductionOperation::MIN:
+ case ReductionOperation::MAX:
+ {
+ init_res_value = static_cast<T>(*input_ptr);
+ break;
+ }
+ case ReductionOperation::PROD:
+ {
+ init_res_value = static_cast<T>(1.f);
+ break;
+ }
+ default:
+ break;
+ }
+ auto vec_res_value = wrapper::vdup_n(init_res_value, ExactTagType{});
+ uint32x4x4_t vec_res_idx{{0}};
+
+ // Compute window_step_x elements per iteration
+ int x = window_start_x;
+ for (; x <= (window_end_x - window_step_x); x += window_step_x)
+ {
+ const auto vec_elements = wrapper::vloadq(input_ptr + x);
+ switch (op)
+ {
+ case ReductionOperation::SUM_SQUARE:
+ vec_res_value = wrapper::vadd(wrapper::vmul(vec_elements, vec_elements), vec_res_value);
+ break;
+ case ReductionOperation::MEAN_SUM:
+ case ReductionOperation::SUM:
+ vec_res_value = wrapper::vadd(vec_elements, vec_res_value);
+ break;
+ case ReductionOperation::PROD:
+ vec_res_value = wrapper::vmul(vec_elements, vec_res_value);
+ break;
+ case ReductionOperation::ARG_IDX_MIN:
+ {
+ auto temp_vec_res_value = wrapper::vmin(vec_elements, vec_res_value);
+ vec_res_idx = calculate_index<decltype(vec_res_value)>(x, temp_vec_res_value, vec_res_value,
+ vec_res_idx, op, 0);
+ vec_res_value = temp_vec_res_value;
+ break;
+ }
+ case ReductionOperation::ARG_IDX_MAX:
+ {
+ auto temp_vec_res_value = wrapper::vmax(vec_elements, vec_res_value);
+ vec_res_idx = calculate_index<decltype(vec_res_value)>(x, temp_vec_res_value, vec_res_value,
+ vec_res_idx, op, 0);
+ vec_res_value = temp_vec_res_value;
+ break;
+ }
+ case ReductionOperation::MIN:
+ {
+ vec_res_value = wrapper::vmin(vec_elements, vec_res_value);
+ break;
+ }
+ case ReductionOperation::MAX:
+ {
+ vec_res_value = wrapper::vmax(vec_elements, vec_res_value);
+ break;
+ }
+ default:
+ ARM_COMPUTE_ERROR("Not supported");
+ }
+ }
+
+ switch (op)
+ {
+ case ReductionOperation::SUM:
+ case ReductionOperation::MEAN_SUM:
+ case ReductionOperation::SUM_SQUARE:
+ {
+#ifdef ARM_COMPUTE_DEBUG_ENABLED
+ auto res = static_cast<T>(0.f);
+ for (int i = 0; i < S; ++i)
+ {
+ res += wrapper::vgetlane(vec_res_value, i);
+ }
+#else // ARM_COMPUTE_DEBUG_ENABLED
+ auto carry_res =
+ wrapper::vpadd(wrapper::vgethigh(vec_res_value), wrapper::vgetlow(vec_res_value));
+ for (int i = 0; i < S / 4; ++i)
+ {
+ carry_res = wrapper::vpadd(carry_res, carry_res);
+ }
+ auto res = wrapper::vgetlane(carry_res, 0);
+#endif // ARM_COMPUTE_DEBUG_ENABLED
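+                        // (The debug path sums the lanes in scalar order, presumably to stay
+                        // bit-identical to the scalar reference; the release path uses the
+                        // faster pairwise adds above.)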
+ if (op == ReductionOperation::SUM_SQUARE)
+ {
+ // Compute left-over elements
+ for (; x < window_end_x; ++x)
+ {
+ res += (*(input_ptr + x)) * (*(input_ptr + x));
+ }
+ }
+ else
+ {
+ // Compute left-over elements
+ for (; x < window_end_x; ++x)
+ {
+ res += *(input_ptr + x);
+ }
+ }
+
+ if (op == ReductionOperation::MEAN_SUM)
+ {
+ res /= input_dim_0;
+ }
+
+ *(reinterpret_cast<T *>(output.ptr())) = res;
+ break;
+ }
+ case ReductionOperation::PROD:
+ {
+ auto carry_res =
+ wrapper::vmul(wrapper::vgethigh(vec_res_value), wrapper::vgetlow(vec_res_value));
+ T res = 1;
+ for (int i = 0; i < S / 2; ++i)
+ {
+ res *= wrapper::vgetlane(carry_res, i);
+ }
+
+ // Compute left-over elements
+ for (; x < window_end_x; ++x)
+ {
+ res *= *(input_ptr + x);
+ }
+
+ *(reinterpret_cast<T *>(output.ptr())) = res;
+ break;
+ }
+ case ReductionOperation::ARG_IDX_MIN:
+ {
+ auto idx = calculate_vector_index<decltype(vec_res_value)>(vec_res_idx, vec_res_value, op);
+ auto res = static_cast<T>(wrapper::vgetlane(calculate_min(vec_res_value), 0));
+
+ // Compute left-over elements
+ for (; x < window_end_x; ++x)
+ {
+ if (*(input_ptr + x) < res)
+ {
+ idx = x;
+ res = *(input_ptr + x);
+ }
+ }
+ *(reinterpret_cast<uint32_t *>(output.ptr())) = idx;
+ break;
+ }
+ case ReductionOperation::ARG_IDX_MAX:
+ {
+ auto idx = calculate_vector_index<decltype(vec_res_value)>(vec_res_idx, vec_res_value, op);
+ auto res = static_cast<T>(wrapper::vgetlane(calculate_max(vec_res_value), 0));
+
+ // Compute left-over elements
+ for (; x < window_end_x; ++x)
+ {
+ if (*(input_ptr + x) > res)
+ {
+ idx = x;
+ res = *(input_ptr + x);
+ }
+ }
+ *(reinterpret_cast<uint32_t *>(output.ptr())) = idx;
+ break;
+ }
+ case ReductionOperation::MIN:
+ {
+ auto res = static_cast<T>(wrapper::vgetlane(calculate_min(vec_res_value), 0));
+
+ // Compute left-over elements
+ for (; x < window_end_x; ++x)
+ {
+ res = *(input_ptr + x) < res ? *(input_ptr + x) : res;
+ }
+ *(reinterpret_cast<T *>(output.ptr())) = res;
+ break;
+ }
+ case ReductionOperation::MAX:
+ {
+ auto res = static_cast<T>(wrapper::vgetlane(calculate_max(vec_res_value), 0));
+
+ // Compute left-over elements
+ for (; x < window_end_x; ++x)
+ {
+ res = *(input_ptr + x) > res ? *(input_ptr + x) : res;
+ }
+ *(reinterpret_cast<T *>(output.ptr())) = res;
+ break;
+ }
+ default:
+ ARM_COMPUTE_ERROR("Not supported");
+ }
+ },
+ input, output);
+ }
+};
+
+template <typename T>
+struct RedOpX_quantized
+{
+ inline void operator()(
+ const Window &in_window, Window &out_window, const ITensor *in, ITensor *out, const ReductionOperation op)
+ {
+ using PromotedType = typename wrapper::traits::promote<typename wrapper::traits::promote<T>::type>::type;
+
+ const auto oq_info = out->info()->quantization_info().uniform();
+
+ const TensorInfo in_info = *(in->info());
+ const UniformQuantizationInfo iq_info = in_info.quantization_info().uniform();
+
+ const int window_step_x = 16 / sizeof(T);
+ const auto window_start_x = static_cast<int>(in_window.x().start());
+ const auto window_end_x = static_cast<int>(in_window.x().end());
+
+ Window in_win_no_pad = in_window;
+ in_win_no_pad.set(Window::DimX, Window::Dimension(0, 1, 1));
+
+ Iterator input(in, in_win_no_pad);
+ Iterator output(out, out_window);
+
+ const auto in_offset = static_cast<float>(iq_info.offset);
+ const float in_scale = iq_info.scale;
+
+ const auto out_offset = static_cast<float>(oq_info.offset);
+ const float out_scale = oq_info.scale;
+
+ const auto num_elements = static_cast<float>(in_info.dimension(0));
+
+ const float A = in_scale / (out_scale * num_elements);
+ const float B = out_offset - (in_scale * in_offset) / (out_scale);
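+        // A and B fold the MEAN_SUM dequantize / divide-by-N / requantize chain into
+        // one multiply-add (a sketch): the real mean of N inputs is
+        //   in_scale * (sum(q) - N * in_offset) / N,
+        // and requantizing it as mean / out_scale + out_offset gives
+        //   q_out = A * sum(q) + B.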
+
+ execute_window_loop(
+ in_win_no_pad,
+ [&](const Coordinates &)
+ {
+ const auto input_ptr = reinterpret_cast<T *>(input.ptr());
+
+ auto vec_res_value1 =
+ wrapper::vdup_n(static_cast<PromotedType>(0.f), wrapper::traits::vector_128_tag{});
+ auto vec_res_value2 =
+ wrapper::vdup_n(static_cast<PromotedType>(0.f), wrapper::traits::vector_128_tag{});
+ auto vec_res_value3 =
+ wrapper::vdup_n(static_cast<PromotedType>(0.f), wrapper::traits::vector_128_tag{});
+ auto vec_res_value4 =
+ wrapper::vdup_n(static_cast<PromotedType>(0.f), wrapper::traits::vector_128_tag{});
+
+ auto vec_res_value1_f = vdupq_n_f32(static_cast<float>(1.f));
+ auto vec_res_value2_f = vdupq_n_f32(static_cast<float>(1.f));
+ auto vec_res_value3_f = vdupq_n_f32(static_cast<float>(1.f));
+ auto vec_res_value4_f = vdupq_n_f32(static_cast<float>(1.f));
+
+ typename wrapper::traits::neon_vector<T, 16>::type vec_res_value = {0};
+
+ if (op == ReductionOperation::ARG_IDX_MAX || op == ReductionOperation::ARG_IDX_MIN ||
+ op == ReductionOperation::MIN || op == ReductionOperation::MAX)
+ {
+ vec_res_value = wrapper::vdup_n(*input_ptr, wrapper::traits::vector_128_tag{});
+ }
+
+ uint32x4x4_t vec_res_idx{{0}};
+ // Compute window_step_x elements per iteration
+ int x = window_start_x;
+ for (; x <= (window_end_x - window_step_x); x += window_step_x)
+ {
+ const auto vec_elements = wrapper::vloadq(input_ptr + x);
+ switch (op)
+ {
+ case ReductionOperation::SUM:
+ case ReductionOperation::MEAN_SUM:
+ {
+ const auto temp16x8t_1 = wrapper::vmovl(wrapper::vgetlow(vec_elements));
+ const auto temp16x8t_2 = wrapper::vmovl(wrapper::vgethigh(vec_elements));
+
+ const auto temp32x4t_1 = wrapper::vmovl(wrapper::vgetlow(temp16x8t_1));
+ const auto temp32x4t_2 = wrapper::vmovl(wrapper::vgethigh(temp16x8t_1));
+ const auto temp32x4t_3 = wrapper::vmovl(wrapper::vgetlow(temp16x8t_2));
+ const auto temp32x4t_4 = wrapper::vmovl(wrapper::vgethigh(temp16x8t_2));
+
+ vec_res_value1 = wrapper::vadd(temp32x4t_1, vec_res_value1);
+ vec_res_value2 = wrapper::vadd(temp32x4t_2, vec_res_value2);
+ vec_res_value3 = wrapper::vadd(temp32x4t_3, vec_res_value3);
+ vec_res_value4 = wrapper::vadd(temp32x4t_4, vec_res_value4);
+ break;
+ }
+ case ReductionOperation::PROD:
+ {
+ const auto offset32x4f_4 = vdupq_n_f32(iq_info.offset);
+ const auto scale32x4f_4 = vdupq_n_f32(iq_info.scale);
+
+ const auto temp16x8t_1 = wrapper::vmovl(wrapper::vgetlow(vec_elements));
+ const auto temp16x8t_2 = wrapper::vmovl(wrapper::vgethigh(vec_elements));
+
+ const auto temp32x4t_1 = wrapper::vmovl(wrapper::vgetlow(temp16x8t_1));
+ const auto temp32x4t_2 = wrapper::vmovl(wrapper::vgethigh(temp16x8t_1));
+ const auto temp32x4t_3 = wrapper::vmovl(wrapper::vgetlow(temp16x8t_2));
+ const auto temp32x4t_4 = wrapper::vmovl(wrapper::vgethigh(temp16x8t_2));
+
+ auto temp32x4f_1 = wrapper::vcvt<float>(temp32x4t_1);
+ auto temp32x4f_2 = wrapper::vcvt<float>(temp32x4t_2);
+ auto temp32x4f_3 = wrapper::vcvt<float>(temp32x4t_3);
+ auto temp32x4f_4 = wrapper::vcvt<float>(temp32x4t_4);
+
+ //de-quantize vec_elements
+ temp32x4f_1 = vmulq_f32(vsubq_f32(temp32x4f_1, offset32x4f_4), scale32x4f_4);
+ temp32x4f_2 = vmulq_f32(vsubq_f32(temp32x4f_2, offset32x4f_4), scale32x4f_4);
+ temp32x4f_3 = vmulq_f32(vsubq_f32(temp32x4f_3, offset32x4f_4), scale32x4f_4);
+ temp32x4f_4 = vmulq_f32(vsubq_f32(temp32x4f_4, offset32x4f_4), scale32x4f_4);
+
+ vec_res_value1_f = vmulq_f32(temp32x4f_1, vec_res_value1_f);
+ vec_res_value2_f = vmulq_f32(temp32x4f_2, vec_res_value2_f);
+ vec_res_value3_f = vmulq_f32(temp32x4f_3, vec_res_value3_f);
+ vec_res_value4_f = vmulq_f32(temp32x4f_4, vec_res_value4_f);
+ break;
+ }
+ case ReductionOperation::ARG_IDX_MIN:
+ {
+ auto temp_vec_res_value = wrapper::vmin(vec_elements, vec_res_value);
+ vec_res_idx = calculate_index_quantized<decltype(vec_res_value)>(
+ x, temp_vec_res_value, vec_res_value, vec_res_idx, op, 0);
+ vec_res_value = temp_vec_res_value;
+ break;
+ }
+ case ReductionOperation::ARG_IDX_MAX:
+ {
+ auto temp_vec_res_value = wrapper::vmax(vec_elements, vec_res_value);
+ vec_res_idx = calculate_index_quantized<decltype(vec_res_value)>(
+ x, temp_vec_res_value, vec_res_value, vec_res_idx, op, 0);
+ vec_res_value = temp_vec_res_value;
+ break;
+ }
+ case ReductionOperation::MIN:
+ {
+ vec_res_value = wrapper::vmin(vec_elements, vec_res_value);
+ break;
+ }
+ case ReductionOperation::MAX:
+ {
+ vec_res_value = wrapper::vmax(vec_elements, vec_res_value);
+ break;
+ }
+ default:
+ ARM_COMPUTE_ERROR("Not supported");
+ }
+ }
+
+ switch (op)
+ {
+ case ReductionOperation::ARG_IDX_MIN:
+ {
+ auto idx =
+ calculate_vector_index_quantized<decltype(vec_res_value)>(vec_res_idx, vec_res_value, op);
+ auto res = static_cast<T>(wrapper::vgetlane(calculate_min(vec_res_value), 0));
+
+ // Compute left-over elements
+ for (; x < window_end_x; ++x)
+ {
+ if (*(input_ptr + x) < res)
+ {
+ idx = x;
+ res = *(input_ptr + x);
+ }
+ }
+ *(reinterpret_cast<uint32_t *>(output.ptr())) = idx;
+ break;
+ }
+ case ReductionOperation::ARG_IDX_MAX:
+ {
+ auto idx =
+ calculate_vector_index_quantized<decltype(vec_res_value)>(vec_res_idx, vec_res_value, op);
+ auto res = static_cast<T>(wrapper::vgetlane(calculate_max(vec_res_value), 0));
+
+ // Compute left-over elements
+ for (; x < window_end_x; ++x)
+ {
+ if (*(input_ptr + x) > res)
+ {
+ idx = x;
+ res = *(input_ptr + x);
+ }
+ }
+ *(reinterpret_cast<uint32_t *>(output.ptr())) = idx;
+ break;
+ }
+ case ReductionOperation::MIN:
+ {
+ auto res = static_cast<T>(wrapper::vgetlane(calculate_min(vec_res_value), 0));
+
+ // Compute left-over elements
+ for (; x < window_end_x; ++x)
+ {
+ res = *(input_ptr + x) < res ? *(input_ptr + x) : res;
+ }
+ *(reinterpret_cast<T *>(output.ptr())) = res;
+ break;
+ }
+ case ReductionOperation::MAX:
+ {
+ auto res = static_cast<T>(wrapper::vgetlane(calculate_max(vec_res_value), 0));
+
+ // Compute left-over elements
+ for (; x < window_end_x; ++x)
+ {
+ res = *(input_ptr + x) > res ? *(input_ptr + x) : res;
+ }
+ *(reinterpret_cast<T *>(output.ptr())) = res;
+ break;
+ }
+ case ReductionOperation::PROD:
+ {
+ auto carry_res = wrapper::vmul(vec_res_value1_f, vec_res_value2_f);
+ carry_res = wrapper::vmul(carry_res, vec_res_value3_f);
+ carry_res = wrapper::vmul(carry_res, vec_res_value4_f);
+
+ float res = wrapper::vgetlane(carry_res, 0);
+ res *= wrapper::vgetlane(carry_res, 1);
+ res *= wrapper::vgetlane(carry_res, 2);
+ res *= wrapper::vgetlane(carry_res, 3);
+
+ // Compute left-over elements
+ for (; x < window_end_x; ++x)
+ {
+ //de-quantize input
+ if (std::is_same<T, uint8_t>::value)
+ {
+ res *= dequantize_qasymm8(*(input_ptr + x), iq_info);
+ }
+ else
+ {
+ res *= dequantize_qasymm8_signed(*(input_ptr + x), iq_info);
+ }
+ }
+
+ //re-quantize result
+ if (std::is_same<T, uint8_t>::value)
+ {
+ res = quantize_qasymm8(res, iq_info);
+ }
+ else
+ {
+ res = quantize_qasymm8_signed(res, iq_info);
+ }
+
+ *reinterpret_cast<T *>(output.ptr()) = static_cast<T>(res);
+ break;
+ }
+ case ReductionOperation::SUM:
+ case ReductionOperation::MEAN_SUM:
+ {
+ auto carry_res = wrapper::vadd(vec_res_value1, vec_res_value2);
+ carry_res = wrapper::vadd(carry_res, vec_res_value3);
+ carry_res = wrapper::vadd(carry_res, vec_res_value4);
+
+ auto carry_paddition =
+ wrapper::vpadd(wrapper::vgethigh(carry_res), wrapper::vgetlow(carry_res));
+ carry_paddition = wrapper::vpadd(carry_paddition, carry_paddition);
+ auto res = static_cast<int32_t>(wrapper::vgetlane(carry_paddition, 0));
+
+ // Compute left-over elements
+ for (; x < window_end_x; ++x)
+ {
+ res += *(input_ptr + x);
+ }
+
+ if (op == ReductionOperation::MEAN_SUM)
+ {
+ const int32_t resFinal = A * (static_cast<float>(res)) + B;
+
+ *reinterpret_cast<T *>(output.ptr()) = utils::cast::saturate_cast<T>(resFinal);
+ }
+ else
+ {
+ // Subtract accumulated offsets
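+                            // (the quantized SUM keeps the input quantization; summing N values
+                            // accumulates N copies of the zero-point where the result should
+                            // carry only one, so the extra N - 1 are removed)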
+ res -= (in_info.dimension(0) - 1) * iq_info.offset;
+ *reinterpret_cast<T *>(output.ptr()) = utils::cast::saturate_cast<T>(res);
+ }
+
+ break;
+ }
+ default:
+ ARM_COMPUTE_ERROR("Not supported");
+ }
+ },
+ input, output);
+ }
+};
+
+template <typename T, int S>
+struct RedOpYZW
+{
+ /** SIMD vector tag type. */
+ using ExactTagType = typename wrapper::traits::neon_vector<T, S>::tag_type;
+ using neon_vector = typename wrapper::traits::neon_vector<T, S>::type;
+
+ inline void operator()(const Window &in_window,
+ Window &out_window,
+ const ITensor *in,
+ ITensor *out,
+ int axis,
+ const ReductionOperation op)
+ {
+ const TensorInfo in_info = *(in->info());
+ const int window_step_x = 16 / sizeof(T);
+ const auto window_start_x_tmp = static_cast<int>(in_window.x().start());
+ const auto window_end_x_tmp = static_cast<int>(in_window.x().end());
+        // As the window is split over the x-axis, set the correct start and end for the split window.
+ const auto window_start_x = static_cast<int>(0);
+ const auto window_end_x = static_cast<int>(in_window.shape().x());
+
+ Window in_win_no_pad = in_window;
+ in_win_no_pad.set(Window::DimX, Window::Dimension(window_start_x_tmp, window_end_x_tmp, in_window.shape().x()));
+ Window out_win_no_pad = out_window;
+ out_win_no_pad.set(Window::DimX,
+ Window::Dimension(window_start_x_tmp, window_end_x_tmp, out_window.shape().x()));
+
+ Iterator input(in, in_win_no_pad);
+ Iterator output(out, out_win_no_pad);
+
+ execute_window_loop(
+ in_win_no_pad,
+ [&](const Coordinates &)
+ {
+ const auto input_ptr = reinterpret_cast<T *>(input.ptr());
+
+ // Compute window_step_x elements per iteration
+ int x = window_start_x;
+ for (; x <= (window_end_x - window_step_x); x += window_step_x)
+ {
+ neon_vector vec_res_value = {0};
+ switch (op)
+ {
+ case ReductionOperation::ARG_IDX_MAX:
+ case ReductionOperation::ARG_IDX_MIN:
+ case ReductionOperation::MIN:
+ case ReductionOperation::MAX:
+ {
+ vec_res_value = wrapper::vloadq(input_ptr + x);
+ break;
+ }
+ case ReductionOperation::PROD:
+ {
+ vec_res_value = wrapper::vdup_n(static_cast<T>(1.f), ExactTagType{});
+ break;
+ }
+ default:
+ {
+ vec_res_value = wrapper::vdup_n(static_cast<T>(0.f), ExactTagType{});
+ break;
+ }
+ }
+ uint32x4x4_t vec_res_idx{{0}};
+
+ for (unsigned int dim = 0; dim < in_info.dimension(axis); ++dim)
+ {
+ const T *in_ptr =
+ reinterpret_cast<T *>(input.ptr() + x * sizeof(T) + in_info.strides_in_bytes()[axis] * dim);
+ const auto vec_elements = wrapper::vloadq(in_ptr);
+ switch (op)
+ {
+ case ReductionOperation::SUM:
+ case ReductionOperation::MEAN_SUM:
+ vec_res_value = wrapper::vadd(vec_elements, vec_res_value);
+ break;
+ case ReductionOperation::SUM_SQUARE:
+ vec_res_value = wrapper::vadd(wrapper::vmul(vec_elements, vec_elements), vec_res_value);
+ break;
+ case ReductionOperation::PROD:
+ vec_res_value = wrapper::vmul(vec_elements, vec_res_value);
+ break;
+ case ReductionOperation::ARG_IDX_MIN:
+ {
+ auto temp_vec_res_value = wrapper::vmin(vec_elements, vec_res_value);
+ vec_res_idx =
+ calculate_index(dim, temp_vec_res_value, vec_res_value, vec_res_idx, op, axis);
+ vec_res_value = temp_vec_res_value;
+ break;
+ }
+ case ReductionOperation::ARG_IDX_MAX:
+ {
+ auto temp_vec_res_value = wrapper::vmax(vec_elements, vec_res_value);
+ vec_res_idx =
+ calculate_index(dim, temp_vec_res_value, vec_res_value, vec_res_idx, op, axis);
+ vec_res_value = temp_vec_res_value;
+ break;
+ }
+ case ReductionOperation::MIN:
+ {
+ vec_res_value = wrapper::vmin(vec_elements, vec_res_value);
+ break;
+ }
+ case ReductionOperation::MAX:
+ {
+ vec_res_value = wrapper::vmax(vec_elements, vec_res_value);
+ break;
+ }
+ default:
+ ARM_COMPUTE_ERROR("Not supported");
+ }
+ }
+
+ if (op == ReductionOperation::MEAN_SUM)
+ {
+ auto vec_width_inv =
+ wrapper::vinv(wrapper::vdup_n(static_cast<T>(in_info.dimension(axis)), ExactTagType{}));
+ vec_res_value = wrapper::vmul(vec_res_value, vec_width_inv);
+ }
+
+ if (op == ReductionOperation::ARG_IDX_MIN || op == ReductionOperation::ARG_IDX_MAX)
+ {
+ wrapper::vstore(reinterpret_cast<uint32_t *>(output.ptr()) + x, vec_res_idx.val[0]);
+#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+ if (std::is_same<T, float16_t>::value)
+ {
+ wrapper::vstore(reinterpret_cast<uint32_t *>(output.ptr()) + x + 4, vec_res_idx.val[1]);
+ }
+#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+ }
+ else
+ {
+ wrapper::vstore(reinterpret_cast<T *>(output.ptr() + x * sizeof(T)), vec_res_value);
+ }
+ }
+
+ // Compute left-over elements
+ for (; x < window_end_x; ++x)
+ {
+ auto res_value = 0.f;
+ switch (op)
+ {
+ case ReductionOperation::ARG_IDX_MAX:
+ case ReductionOperation::ARG_IDX_MIN:
+ case ReductionOperation::MIN:
+ case ReductionOperation::MAX:
+ {
+ res_value = *(input_ptr + x);
+ break;
+ }
+ case ReductionOperation::PROD:
+ {
+ res_value = static_cast<T>(1.f);
+ break;
+ }
+ default:
+ {
+ res_value = static_cast<T>(0.f);
+ break;
+ }
+ }
+
+ uint32_t res_idx = 0;
+ for (unsigned int dim = 0; dim < in_info.dimension(axis); ++dim)
+ {
+ const T *in_ptr =
+ reinterpret_cast<T *>(input.ptr() + x * sizeof(T) + in_info.strides_in_bytes()[axis] * dim);
+
+ switch (op)
+ {
+ case ReductionOperation::SUM:
+ case ReductionOperation::MEAN_SUM:
+ res_value += *in_ptr;
+ break;
+ case ReductionOperation::SUM_SQUARE:
+ res_value += *in_ptr * *in_ptr;
+ break;
+ case ReductionOperation::PROD:
+ res_value *= *in_ptr;
+ break;
+ case ReductionOperation::ARG_IDX_MIN:
+ {
+ if (*in_ptr < res_value)
+ {
+ res_value = *in_ptr;
+ res_idx = dim;
+ }
+ break;
+ }
+ case ReductionOperation::ARG_IDX_MAX:
+ {
+ if (*in_ptr > res_value)
+ {
+ res_value = *in_ptr;
+ res_idx = dim;
+ }
+ break;
+ }
+ case ReductionOperation::MIN:
+ {
+ res_value = *in_ptr < res_value ? *in_ptr : res_value;
+ break;
+ }
+ case ReductionOperation::MAX:
+ {
+ res_value = *in_ptr > res_value ? *in_ptr : res_value;
+ break;
+ }
+ default:
+ ARM_COMPUTE_ERROR("Not supported");
+ }
+ }
+
+ if (op == ReductionOperation::MEAN_SUM)
+ {
+ res_value /= in_info.dimension(axis);
+ }
+
+ if (op == ReductionOperation::ARG_IDX_MIN || op == ReductionOperation::ARG_IDX_MAX)
+ {
+ *(reinterpret_cast<uint32_t *>(output.ptr()) + x) = res_idx;
+ }
+ else
+ {
+ *(reinterpret_cast<T *>(output.ptr() + x * sizeof(T))) = res_value;
+ }
+ }
+ },
+ input, output);
+ }
+};
+
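+// Reduces complex tensors stored as interleaved (real, imaginary) pairs, which is
+// why each logical x element below spans 2 * sizeof(T) bytes and two vector
+// accumulators are carried per iteration.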
+template <typename T, int S, int axis, ReductionOperation op>
+struct RedOpYZW_complex
+{
+ /** SIMD vector tag type. */
+ using ExactTagType = typename wrapper::traits::neon_vector<T, S>::tag_type;
+ using neon_vector = typename wrapper::traits::neon_vector<T, S>::type;
+
+ inline void operator()(
+ const Window &in_window, Window &out_window, const ITensor *in, ITensor *out, int, const ReductionOperation)
+ {
+ ARM_COMPUTE_ERROR_ON(axis != 2);
+ ARM_COMPUTE_ERROR_ON(op != ReductionOperation::SUM);
+
+ const TensorInfo in_info = *(in->info());
+ const size_t stride_z = in_info.strides_in_bytes()[axis];
+ const int window_step_x = 16 / sizeof(T);
+ const auto window_start_x_tmp = static_cast<int>(in_window.x().start());
+ const auto window_end_x_tmp = static_cast<int>(in_window.x().end());
+        // As the window is split over the x-axis, set the correct start and end for the split window.
+ const auto window_start_x = static_cast<int>(0);
+ const auto window_end_x = static_cast<int>(in_window.shape().x());
+
+ Window in_win_no_pad = in_window;
+ in_win_no_pad.set(Window::DimX, Window::Dimension(window_start_x_tmp, window_end_x_tmp, in_window.shape().x()));
+ Window out_win_no_pad = out_window;
+ out_win_no_pad.set(Window::DimX,
+ Window::Dimension(window_start_x_tmp, window_end_x_tmp, out_window.shape().x()));
+
+ Iterator input(in, in_win_no_pad);
+ Iterator output(out, out_win_no_pad);
+
+ execute_window_loop(
+ in_win_no_pad,
+ [&](const Coordinates &)
+ {
+ // Compute window_step_x elements per iteration
+ int x = window_start_x;
+ for (; x <= (window_end_x - window_step_x); x += window_step_x)
+ {
+ neon_vector vec_res_value_0 = {0};
+ neon_vector vec_res_value_1 = {0};
+
+ vec_res_value_0 = wrapper::vdup_n(static_cast<T>(0.f), ExactTagType{});
+ vec_res_value_1 = wrapper::vdup_n(static_cast<T>(0.f), ExactTagType{});
+
+ T *out_ptr = reinterpret_cast<T *>(output.ptr() + 2 * x * sizeof(T));
+ for (unsigned int dim = 0; dim < in_info.dimension(axis); ++dim)
+ {
+ T *in_ptr_0 = reinterpret_cast<T *>(input.ptr() + 2 * x * sizeof(T) + stride_z * dim);
+ T *in_ptr_1 = reinterpret_cast<T *>(input.ptr() + 2 * x * sizeof(T) + 16 + stride_z * dim);
+
+ const auto vec_elements_0 = wrapper::vloadq(in_ptr_0);
+ const auto vec_elements_1 = wrapper::vloadq(in_ptr_1);
+
+ vec_res_value_0 = wrapper::vadd(vec_elements_0, vec_res_value_0);
+ vec_res_value_1 = wrapper::vadd(vec_elements_1, vec_res_value_1);
+ }
+
+ wrapper::vstore(out_ptr, vec_res_value_0);
+ wrapper::vstore(out_ptr + 4, vec_res_value_1);
+ }
+
+ // Compute left-over elements
+ for (; x < window_end_x; ++x)
+ {
+ auto res_value_0 = 0.f;
+ auto res_value_1 = 0.f;
+
+ T *out_ptr = reinterpret_cast<T *>(output.ptr() + 2 * x * sizeof(T));
+ for (unsigned int dim = 0; dim < in_info.dimension(axis); ++dim)
+ {
+ T *in_ptr = reinterpret_cast<T *>(input.ptr() + 2 * x * sizeof(T) + stride_z * dim);
+ res_value_0 += *in_ptr;
+ res_value_1 += *(in_ptr + 1);
+ }
+ *out_ptr = res_value_0;
+ *(out_ptr + 1) = res_value_1;
+ }
+ },
+ input, output);
+ }
+};
+
+template <typename T>
+struct RedOpYZW_quantized
+{
+ inline void operator()(const Window &in_window,
+ Window &out_window,
+ const ITensor *in,
+ ITensor *out,
+ int axis,
+ const ReductionOperation op)
+ {
+ const TensorInfo in_info = *(in->info());
+ const UniformQuantizationInfo iq_info = in_info.quantization_info().uniform();
+ using PromotedType = typename wrapper::traits::promote<typename wrapper::traits::promote<T>::type>::type;
+
+ const auto oq_info = out->info()->quantization_info().uniform();
+
+ const int window_step_x = 16 / sizeof(T);
+ const auto window_start_x_tmp = static_cast<int>(in_window.x().start());
+ const auto window_end_x_tmp = static_cast<int>(in_window.x().end());
+        // As the window is split over the x-axis, set the correct start and end for the split window.
+ const auto window_start_x = static_cast<int>(0);
+ const auto window_end_x = static_cast<int>(in_window.shape().x());
+
+ Window in_win_no_pad = in_window;
+ in_win_no_pad.set(Window::DimX, Window::Dimension(window_start_x_tmp, window_end_x_tmp, in_window.shape().x()));
+ Window out_win_no_pad = out_window;
+ out_win_no_pad.set(Window::DimX,
+ Window::Dimension(window_start_x_tmp, window_end_x_tmp, out_window.shape().x()));
+
+ Iterator input(in, in_win_no_pad);
+ Iterator output(out, out_win_no_pad);
+
+ using vector_type =
+ typename wrapper::traits::neon_bitvector<PromotedType, wrapper::traits::BitWidth::W128>::type;
+ using vector_type_f = typename wrapper::traits::neon_vector<float, 4>::type;
+
+ vector_type vec_res_value1{};
+ vector_type vec_res_value2{};
+ vector_type vec_res_value3{};
+ vector_type vec_res_value4{};
+
+ vector_type_f vec_res_value1_f{};
+ vector_type_f vec_res_value2_f{};
+ vector_type_f vec_res_value3_f{};
+ vector_type_f vec_res_value4_f{};
+
+ const float in_offset = static_cast<float>(iq_info.offset);
+ const float in_scale = iq_info.scale;
+
+ const float out_offset = static_cast<float>(oq_info.offset);
+ const float out_scale = oq_info.scale;
+
+ const float num_elements = static_cast<float>(in_info.dimension(axis));
+
+ const float A = in_scale / (out_scale * num_elements);
+ const float B = out_offset - (in_scale * in_offset) / (out_scale);
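+        // Same folding as in RedOpX_quantized above: q_out = A * sum(q) + B performs
+        // dequantize, mean over the reduced axis and requantize in one multiply-add.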
+
+ const auto vec_A = wrapper::vdup_n(static_cast<float>(A), wrapper::traits::vector_128_tag{});
+ const auto vec_B = wrapper::vdup_n(static_cast<float>(B), wrapper::traits::vector_128_tag{});
+
+ execute_window_loop(
+ in_win_no_pad,
+ [&](const Coordinates &)
+ {
+ const auto input_ptr = reinterpret_cast<T *>(input.ptr());
+
+ // Compute window_step_x elements per iteration
+ int x = window_start_x;
+ for (; x <= (window_end_x - window_step_x); x += window_step_x)
+ {
+ uint32x4x4_t vec_res_idx{{0}};
+ vec_res_value1 = wrapper::vdup_n(static_cast<PromotedType>(0), wrapper::traits::vector_128_tag{});
+ vec_res_value2 = wrapper::vdup_n(static_cast<PromotedType>(0), wrapper::traits::vector_128_tag{});
+ vec_res_value3 = wrapper::vdup_n(static_cast<PromotedType>(0), wrapper::traits::vector_128_tag{});
+ vec_res_value4 = wrapper::vdup_n(static_cast<PromotedType>(0), wrapper::traits::vector_128_tag{});
+
+ vec_res_value1_f = wrapper::vdup_n(static_cast<float>(1), wrapper::traits::vector_128_tag{});
+ vec_res_value2_f = wrapper::vdup_n(static_cast<float>(1), wrapper::traits::vector_128_tag{});
+ vec_res_value3_f = wrapper::vdup_n(static_cast<float>(1), wrapper::traits::vector_128_tag{});
+ vec_res_value4_f = wrapper::vdup_n(static_cast<float>(1), wrapper::traits::vector_128_tag{});
+
+ auto vec_res_value = wrapper::vloadq(input_ptr + x);
+
+ for (unsigned int index_dim = 0; index_dim < in_info.dimension(axis); ++index_dim)
+ {
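+                        // Byte-level pointer arithmetic is safe here: T is a 1-byte quantized
+                        // type, so element and byte offsets coincide.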
+ const T *in_ptr = input_ptr + x + in_info.strides_in_bytes()[axis] * index_dim;
+ const auto vec_elements = wrapper::vloadq(in_ptr);
+ switch (op)
+ {
+ case ReductionOperation::SUM:
+ case ReductionOperation::MEAN_SUM:
+ {
+ const auto temp16x8t_1 = wrapper::vmovl(wrapper::vgetlow(vec_elements));
+ const auto temp16x8t_2 = wrapper::vmovl(wrapper::vgethigh(vec_elements));
+
+ const auto temp32x4t_1 = wrapper::vmovl(wrapper::vgetlow(temp16x8t_1));
+ const auto temp32x4t_2 = wrapper::vmovl(wrapper::vgethigh(temp16x8t_1));
+ const auto temp32x4t_3 = wrapper::vmovl(wrapper::vgetlow(temp16x8t_2));
+ const auto temp32x4t_4 = wrapper::vmovl(wrapper::vgethigh(temp16x8t_2));
+
+ vec_res_value1 = wrapper::vadd(temp32x4t_1, vec_res_value1);
+ vec_res_value2 = wrapper::vadd(temp32x4t_2, vec_res_value2);
+ vec_res_value3 = wrapper::vadd(temp32x4t_3, vec_res_value3);
+ vec_res_value4 = wrapper::vadd(temp32x4t_4, vec_res_value4);
+ break;
+ }
+ case ReductionOperation::PROD:
+ {
+ const auto offset32x4f_4 = wrapper::vdup_n(static_cast<float>(iq_info.offset),
+ wrapper::traits::vector_128_tag{});
+ const auto scale32x4f_4 =
+ wrapper::vdup_n(iq_info.scale, wrapper::traits::vector_128_tag{});
+
+ const auto temp16x8t_1 = wrapper::vmovl(wrapper::vgetlow(vec_elements));
+ const auto temp16x8t_2 = wrapper::vmovl(wrapper::vgethigh(vec_elements));
+
+ const auto temp32x4t_1 = wrapper::vmovl(wrapper::vgetlow(temp16x8t_1));
+ const auto temp32x4t_2 = wrapper::vmovl(wrapper::vgethigh(temp16x8t_1));
+ const auto temp32x4t_3 = wrapper::vmovl(wrapper::vgetlow(temp16x8t_2));
+ const auto temp32x4t_4 = wrapper::vmovl(wrapper::vgethigh(temp16x8t_2));
+
+ auto temp32x4f_1 = wrapper::vcvt<float>(temp32x4t_1);
+ auto temp32x4f_2 = wrapper::vcvt<float>(temp32x4t_2);
+ auto temp32x4f_3 = wrapper::vcvt<float>(temp32x4t_3);
+ auto temp32x4f_4 = wrapper::vcvt<float>(temp32x4t_4);
+
+ //de-quantize vec_elements
+ temp32x4f_1 = wrapper::vmul(wrapper::vsub(temp32x4f_1, offset32x4f_4), scale32x4f_4);
+ temp32x4f_2 = wrapper::vmul(wrapper::vsub(temp32x4f_2, offset32x4f_4), scale32x4f_4);
+ temp32x4f_3 = wrapper::vmul(wrapper::vsub(temp32x4f_3, offset32x4f_4), scale32x4f_4);
+ temp32x4f_4 = wrapper::vmul(wrapper::vsub(temp32x4f_4, offset32x4f_4), scale32x4f_4);
+
+ vec_res_value1_f = wrapper::vmul(temp32x4f_1, vec_res_value1_f);
+ vec_res_value2_f = wrapper::vmul(temp32x4f_2, vec_res_value2_f);
+ vec_res_value3_f = wrapper::vmul(temp32x4f_3, vec_res_value3_f);
+ vec_res_value4_f = wrapper::vmul(temp32x4f_4, vec_res_value4_f);
+ break;
+ }
+ case ReductionOperation::ARG_IDX_MIN:
+ {
+ auto temp_vec_res_value = wrapper::vmin(vec_elements, vec_res_value);
+ vec_res_idx = calculate_index_quantized(index_dim, temp_vec_res_value, vec_res_value,
+ vec_res_idx, op, axis);
+ vec_res_value = temp_vec_res_value;
+ break;
+ }
+ case ReductionOperation::ARG_IDX_MAX:
+ {
+ auto temp_vec_res_value = wrapper::vmax(vec_elements, vec_res_value);
+ vec_res_idx = calculate_index_quantized(index_dim, temp_vec_res_value, vec_res_value,
+ vec_res_idx, op, axis);
+ vec_res_value = temp_vec_res_value;
+ break;
+ }
+ case ReductionOperation::MIN:
+ {
+ vec_res_value = wrapper::vmin(vec_elements, vec_res_value);
+ break;
+ }
+ case ReductionOperation::MAX:
+ {
+ vec_res_value = wrapper::vmax(vec_elements, vec_res_value);
+ break;
+ }
+ default:
+ ARM_COMPUTE_ERROR("Not supported");
+ }
+ }
+
+ switch (op)
+ {
+ case ReductionOperation::ARG_IDX_MIN:
+ case ReductionOperation::ARG_IDX_MAX:
+ {
+ wrapper::vstore(reinterpret_cast<uint32_t *>(output.ptr() + 4 * x), vec_res_idx.val[0]);
+ wrapper::vstore(reinterpret_cast<uint32_t *>(output.ptr() + 4 * x) + 4, vec_res_idx.val[1]);
+ wrapper::vstore(reinterpret_cast<uint32_t *>(output.ptr() + 4 * x) + 8, vec_res_idx.val[2]);
+ wrapper::vstore(reinterpret_cast<uint32_t *>(output.ptr() + 4 * x) + 12,
+ vec_res_idx.val[3]);
+ break;
+ }
+ case ReductionOperation::MIN:
+ case ReductionOperation::MAX:
+ {
+ wrapper::vstore(reinterpret_cast<T *>(output.ptr() + x), vec_res_value);
+ break;
+ }
+ case ReductionOperation::SUM:
+ {
+ // Subtract offsets
+ auto offsets = vdupq_n_s32((in_info.dimension(axis) - 1) * iq_info.offset);
+
+ auto vec_res_s_value1 = wrapper::vreinterpret(vec_res_value1);
+ auto vec_res_s_value2 = wrapper::vreinterpret(vec_res_value2);
+ auto vec_res_s_value3 = wrapper::vreinterpret(vec_res_value3);
+ auto vec_res_s_value4 = wrapper::vreinterpret(vec_res_value4);
+
+ vec_res_s_value1 = wrapper::vsub(vec_res_s_value1, offsets);
+ vec_res_s_value2 = wrapper::vsub(vec_res_s_value2, offsets);
+ vec_res_s_value3 = wrapper::vsub(vec_res_s_value3, offsets);
+ vec_res_s_value4 = wrapper::vsub(vec_res_s_value4, offsets);
+
+ const auto temp16x8t_1 =
+ wrapper::vcombine(wrapper::vqmovn(vec_res_s_value1), wrapper::vqmovn(vec_res_s_value2));
+ const auto temp16x8t_2 =
+ wrapper::vcombine(wrapper::vqmovn(vec_res_s_value3), wrapper::vqmovn(vec_res_s_value4));
+
+ combine_and_store<T>(temp16x8t_1, temp16x8t_2, output, x);
+ break;
+ }
+ case ReductionOperation::MEAN_SUM:
+ {
+ vec_res_value1_f = wrapper::vmla(vec_B, wrapper::vcvt<float>(vec_res_value1), vec_A);
+ vec_res_value2_f = wrapper::vmla(vec_B, wrapper::vcvt<float>(vec_res_value2), vec_A);
+ vec_res_value3_f = wrapper::vmla(vec_B, wrapper::vcvt<float>(vec_res_value3), vec_A);
+ vec_res_value4_f = wrapper::vmla(vec_B, wrapper::vcvt<float>(vec_res_value4), vec_A);
+
+#ifdef __aarch64__
+ vec_res_value1 = wrapper::vcvta<PromotedType>(vec_res_value1_f);
+ vec_res_value2 = wrapper::vcvta<PromotedType>(vec_res_value2_f);
+ vec_res_value3 = wrapper::vcvta<PromotedType>(vec_res_value3_f);
+ vec_res_value4 = wrapper::vcvta<PromotedType>(vec_res_value4_f);
+#else // defined(__aarch64__)
+ vec_res_value1 = wrapper::vcvt<PromotedType>(vec_res_value1_f);
+ vec_res_value2 = wrapper::vcvt<PromotedType>(vec_res_value2_f);
+ vec_res_value3 = wrapper::vcvt<PromotedType>(vec_res_value3_f);
+ vec_res_value4 = wrapper::vcvt<PromotedType>(vec_res_value4_f);
+#endif // __aarch64__
+
+ const auto temp16x8t_1 =
+ wrapper::vcombine(wrapper::vqmovn(vec_res_value1), wrapper::vqmovn(vec_res_value2));
+ const auto temp16x8t_2 =
+ wrapper::vcombine(wrapper::vqmovn(vec_res_value3), wrapper::vqmovn(vec_res_value4));
+ auto res = wrapper::vcombine(wrapper::vqmovn(temp16x8t_1), wrapper::vqmovn(temp16x8t_2));
+
+ wrapper::vstore(reinterpret_cast<T *>(output.ptr() + x), res);
+ break;
+ }
+ case ReductionOperation::PROD:
+ {
+ const auto offset32x4f_4 =
+ wrapper::vdup_n(static_cast<float>(iq_info.offset), wrapper::traits::vector_128_tag{});
+ const auto iscale32x4f_4 = vinvq_f32(vdupq_n_f32(iq_info.scale));
+
+ //re-quantize
+ vec_res_value1_f =
+ wrapper::vadd(wrapper::vmul(vec_res_value1_f, iscale32x4f_4), offset32x4f_4);
+ vec_res_value2_f =
+ wrapper::vadd(wrapper::vmul(vec_res_value2_f, iscale32x4f_4), offset32x4f_4);
+ vec_res_value3_f =
+ wrapper::vadd(wrapper::vmul(vec_res_value3_f, iscale32x4f_4), offset32x4f_4);
+ vec_res_value4_f =
+ wrapper::vadd(wrapper::vmul(vec_res_value4_f, iscale32x4f_4), offset32x4f_4);
+
+ vec_res_value1 = wrapper::vcvt<T>(vec_res_value1_f);
+ vec_res_value2 = wrapper::vcvt<T>(vec_res_value2_f);
+ vec_res_value3 = wrapper::vcvt<T>(vec_res_value3_f);
+ vec_res_value4 = wrapper::vcvt<T>(vec_res_value4_f);
+
+ const auto temp16x8t_1 =
+ wrapper::vcombine(wrapper::vqmovn(vec_res_value1), wrapper::vqmovn(vec_res_value2));
+ const auto temp16x8t_2 =
+ wrapper::vcombine(wrapper::vqmovn(vec_res_value3), wrapper::vqmovn(vec_res_value4));
+ auto res = wrapper::vcombine(wrapper::vqmovn(temp16x8t_1), wrapper::vqmovn(temp16x8t_2));
+
+ wrapper::vstore(reinterpret_cast<T *>(output.ptr() + x), res);
+ break;
+ }
+ default:
+ ARM_COMPUTE_ERROR("Not supported");
+ }
+ }
+
+ // Compute left-over elements
+ for (; x < window_end_x; ++x)
+ {
+ float res_value = 0.f;
+ int32_t res_value_q = 0;
+
+ switch (op)
+ {
+ case ReductionOperation::ARG_IDX_MAX:
+ case ReductionOperation::ARG_IDX_MIN:
+ case ReductionOperation::MIN:
+ case ReductionOperation::MAX:
+ {
+ res_value = *(input_ptr + x);
+ break;
+ }
+ case ReductionOperation::PROD:
+ {
+ res_value = static_cast<T>(1.0f);
+ break;
+ }
+ default:
+ {
+ res_value = static_cast<T>(0.0f);
+ break;
+ }
+ }
+ uint32_t res_idx = 0;
+
+ for (unsigned int dim = 0; dim < in_info.dimension(axis); ++dim)
+ {
+ const T *in_ptr =
+ reinterpret_cast<T *>(input.ptr() + x + in_info.strides_in_bytes()[axis] * dim);
+ switch (op)
+ {
+ case ReductionOperation::SUM:
+ {
+ res_value += *in_ptr;
+ break;
+ }
+ case ReductionOperation::MEAN_SUM:
+ {
+ res_value_q += *in_ptr;
+ break;
+ }
+ case ReductionOperation::SUM_SQUARE:
+ {
+ res_value += *in_ptr * *in_ptr;
+ break;
+ }
+ case ReductionOperation::PROD:
+ {
+ //de-quantize input
+ if (std::is_same<T, uint8_t>::value)
+ {
+ res_value *= dequantize_qasymm8(*in_ptr, iq_info);
+ }
+ else
+ {
+ res_value *= dequantize_qasymm8_signed(*in_ptr, iq_info);
+ }
+ break;
+ }
+ case ReductionOperation::ARG_IDX_MIN:
+ {
+ if (*in_ptr < res_value)
+ {
+ res_value = *in_ptr;
+ res_idx = dim;
+ }
+ break;
+ }
+ case ReductionOperation::ARG_IDX_MAX:
+ {
+ if (*in_ptr > res_value)
+ {
+ res_value = *in_ptr;
+ res_idx = dim;
+ }
+ break;
+ }
+ case ReductionOperation::MIN:
+ {
+ res_value = *in_ptr < res_value ? *in_ptr : res_value;
+ break;
+ }
+ case ReductionOperation::MAX:
+ {
+ res_value = *in_ptr > res_value ? *in_ptr : res_value;
+ break;
+ }
+ default:
+ ARM_COMPUTE_ERROR("Not supported");
+ }
+ }
+
+ switch (op)
+ {
+ case ReductionOperation::MEAN_SUM:
+ {
+ // Apply previously calculated coefficients (with rounding on aarch64)
+#ifdef __aarch64__
+ const int32_t res =
+ arm_compute::support::cpp11::round(A * (static_cast<float>(res_value_q)) + B);
+#else // defined(__aarch64__)
+ const int32_t res = A * (static_cast<float>(res_value_q)) + B;
+#endif // __aarch64__
+ *reinterpret_cast<T *>(output.ptr() + x) = utils::cast::saturate_cast<T>(res);
+ break;
+ }
+ case ReductionOperation::SUM:
+ {
+ // Subtract accumulated offsets
+ res_value -= (in_info.dimension(axis) - 1) * iq_info.offset;
+ *reinterpret_cast<T *>(output.ptr() + x) = utils::cast::saturate_cast<T>(res_value);
+ break;
+ }
+ case ReductionOperation::PROD:
+ {
+ //re-quantize result
+ T res = 0;
+ if (std::is_same<T, uint8_t>::value)
+ {
+ res = quantize_qasymm8(res_value, iq_info);
+ }
+ else
+ {
+ res = quantize_qasymm8_signed(res_value, iq_info);
+ }
+ *(reinterpret_cast<T *>(output.ptr() + x)) = res;
+ break;
+ }
+ case ReductionOperation::ARG_IDX_MIN:
+ case ReductionOperation::ARG_IDX_MAX:
+ {
+ *(reinterpret_cast<uint32_t *>(output.ptr() + x * 4)) = res_idx;
+ break;
+ }
+ default:
+ *(reinterpret_cast<T *>(output.ptr() + x)) = res_value;
+ }
+ }
+ },
+ input, output);
+ }
+};
+
+} // namespace arm_compute
+#endif // ACL_SRC_CPU_KERNELS_REDUCTION_LAYER_GENERIC_NEON_IMPL_H
diff --git a/src/cpu/kernels/reduction_layer/generic/neon/integer.cpp b/src/cpu/kernels/reduction_layer/generic/neon/integer.cpp
new file mode 100644
index 0000000000..ad66b456ac
--- /dev/null
+++ b/src/cpu/kernels/reduction_layer/generic/neon/integer.cpp
@@ -0,0 +1,62 @@
+/*
+ * Copyright (c) 2024 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "src/cpu/kernels/reduction_layer/generic/neon/impl.h"
+
+namespace arm_compute
+{
+namespace cpu
+{
+void reduce_RedOpX_reduceX_S32_4(const Window &window,
+ const ITensor *input,
+ ITensor *output,
+ const ReductionOperation op)
+{
+ return Reducer<RedOpX<int32_t, 4>>::reduceX(window, input, output, RedOpX<int32_t, 4>(), op);
+}
+
+void reduce_RedOpYZW_reduceY_S32_4(const Window &window,
+ const ITensor *input,
+ ITensor *output,
+ const ReductionOperation op)
+{
+ return Reducer<RedOpYZW<int32_t, 4>>::reduceY(window, input, output, RedOpYZW<int32_t, 4>(), op);
+}
+void reduce_RedOpYZW_reduceZ_S32_4(const Window &window,
+ const ITensor *input,
+ ITensor *output,
+ const ReductionOperation op)
+{
+ return Reducer<RedOpYZW<int32_t, 4>>::reduceZ(window, input, output, RedOpYZW<int32_t, 4>(), op);
+}
+
+void reduce_RedOpYZW_reduceW_S32_4(const Window &window,
+ const ITensor *input,
+ ITensor *output,
+ const ReductionOperation op)
+{
+ return Reducer<RedOpYZW<int32_t, 4>>::reduceW(window, input, output, RedOpYZW<int32_t, 4>(), op);
+}
+} // namespace cpu
+} // namespace arm_compute
diff --git a/src/cpu/kernels/reduction_layer/generic/neon/list.h b/src/cpu/kernels/reduction_layer/generic/neon/list.h
new file mode 100644
index 0000000000..947c28a130
--- /dev/null
+++ b/src/cpu/kernels/reduction_layer/generic/neon/list.h
@@ -0,0 +1,66 @@
+/*
+ * Copyright (c) 2024 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#ifndef ACL_SRC_CPU_KERNELS_REDUCTION_LAYER_GENERIC_NEON_LIST_H
+#define ACL_SRC_CPU_KERNELS_REDUCTION_LAYER_GENERIC_NEON_LIST_H
+
+#include "arm_compute/core/Helpers.h"
+
+namespace arm_compute
+{
+namespace cpu
+{
+
+#define DECLARE_REDUCTION_KERNEL(func_name) \
+ void func_name(const Window &window, const ITensor *in, ITensor *out, const ReductionOperation op)
+
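+// For example, DECLARE_REDUCTION_KERNEL(reduce_RedOpX_reduceX_float32_4)
+// declares:
+//   void reduce_RedOpX_reduceX_float32_4(const Window &window, const ITensor *in,
+//                                        ITensor *out, const ReductionOperation op);
+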
+DECLARE_REDUCTION_KERNEL(reduce_RedOpYZW_complex_reduceZ_float32_4_2_SUM);
+DECLARE_REDUCTION_KERNEL(reduce_RedOpX_reduceX_float32_4);
+DECLARE_REDUCTION_KERNEL(reduce_RedOpYZW_reduceY_float32_4);
+DECLARE_REDUCTION_KERNEL(reduce_RedOpYZW_reduceZ_float32_4);
+DECLARE_REDUCTION_KERNEL(reduce_RedOpYZW_reduceW_float32_4);
+
+DECLARE_REDUCTION_KERNEL(reduce_RedOpX_reduceX_float16_8);
+DECLARE_REDUCTION_KERNEL(reduce_RedOpYZW_reduceY_float16_8);
+DECLARE_REDUCTION_KERNEL(reduce_RedOpYZW_reduceZ_float16_8);
+DECLARE_REDUCTION_KERNEL(reduce_RedOpYZW_reduceW_float16_8);
+
+DECLARE_REDUCTION_KERNEL(reduce_RedOpX_reduceX_S32_4);
+DECLARE_REDUCTION_KERNEL(reduce_RedOpYZW_reduceY_S32_4);
+DECLARE_REDUCTION_KERNEL(reduce_RedOpYZW_reduceZ_S32_4);
+DECLARE_REDUCTION_KERNEL(reduce_RedOpYZW_reduceW_S32_4);
+
+DECLARE_REDUCTION_KERNEL(reduce_RedOpX_reduceX_qasymm8);
+DECLARE_REDUCTION_KERNEL(reduce_RedOpYZW_reduceY_qasymm8);
+DECLARE_REDUCTION_KERNEL(reduce_RedOpYZW_reduceZ_qasymm8);
+DECLARE_REDUCTION_KERNEL(reduce_RedOpYZW_reduceW_qasymm8);
+
+DECLARE_REDUCTION_KERNEL(reduce_RedOpX_reduceX_qasymm8_signed);
+DECLARE_REDUCTION_KERNEL(reduce_RedOpYZW_reduceY_qasymm8_signed);
+DECLARE_REDUCTION_KERNEL(reduce_RedOpYZW_reduceZ_qasymm8_signed);
+DECLARE_REDUCTION_KERNEL(reduce_RedOpYZW_reduceW_qasymm8_signed);
+
+#undef DECLARE_REDUCTION_KERNEL
+} // namespace cpu
+} // namespace arm_compute
+#endif // ACL_SRC_CPU_KERNELS_REDUCTION_LAYER_GENERIC_NEON_LIST_H
diff --git a/src/cpu/kernels/reduction_layer/generic/neon/qasymm8.cpp b/src/cpu/kernels/reduction_layer/generic/neon/qasymm8.cpp
new file mode 100644
index 0000000000..bc711c6855
--- /dev/null
+++ b/src/cpu/kernels/reduction_layer/generic/neon/qasymm8.cpp
@@ -0,0 +1,63 @@
+/*
+ * Copyright (c) 2024 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "src/cpu/kernels/reduction_layer/generic/neon/impl.h"
+
+namespace arm_compute
+{
+namespace cpu
+{
+void reduce_RedOpX_reduceX_qasymm8(const Window &window,
+ const ITensor *input,
+ ITensor *output,
+ const ReductionOperation op)
+{
+ return Reducer<RedOpX_quantized<uint8_t>>::reduceX(window, input, output, RedOpX_quantized<uint8_t>(), op);
+}
+
+void reduce_RedOpYZW_reduceY_qasymm8(const Window &window,
+ const ITensor *input,
+ ITensor *output,
+ const ReductionOperation op)
+{
+ return Reducer<RedOpYZW_quantized<uint8_t>>::reduceY(window, input, output, RedOpYZW_quantized<uint8_t>(), op);
+}
+
+void reduce_RedOpYZW_reduceZ_qasymm8(const Window &window,
+ const ITensor *input,
+ ITensor *output,
+ const ReductionOperation op)
+{
+ return Reducer<RedOpYZW_quantized<uint8_t>>::reduceZ(window, input, output, RedOpYZW_quantized<uint8_t>(), op);
+}
+
+void reduce_RedOpYZW_reduceW_qasymm8(const Window &window,
+ const ITensor *input,
+ ITensor *output,
+ const ReductionOperation op)
+{
+ return Reducer<RedOpYZW_quantized<uint8_t>>::reduceW(window, input, output, RedOpYZW_quantized<uint8_t>(), op);
+}
+} // namespace cpu
+} // namespace arm_compute
diff --git a/src/cpu/kernels/reduction_layer/generic/neon/qasymm8_signed.cpp b/src/cpu/kernels/reduction_layer/generic/neon/qasymm8_signed.cpp
new file mode 100644
index 0000000000..10ac3d6715
--- /dev/null
+++ b/src/cpu/kernels/reduction_layer/generic/neon/qasymm8_signed.cpp
@@ -0,0 +1,63 @@
+/*
+ * Copyright (c) 2024 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "src/cpu/kernels/reduction_layer/generic/neon/impl.h"
+
+namespace arm_compute
+{
+namespace cpu
+{
+void reduce_RedOpX_reduceX_qasymm8_signed(const Window &window,
+ const ITensor *input,
+ ITensor *output,
+ const ReductionOperation op)
+{
+ return Reducer<RedOpX_quantized<int8_t>>::reduceX(window, input, output, RedOpX_quantized<int8_t>(), op);
+}
+
+void reduce_RedOpYZW_reduceY_qasymm8_signed(const Window &window,
+ const ITensor *input,
+ ITensor *output,
+ const ReductionOperation op)
+{
+ return Reducer<RedOpYZW_quantized<int8_t>>::reduceY(window, input, output, RedOpYZW_quantized<int8_t>(), op);
+}
+
+void reduce_RedOpYZW_reduceZ_qasymm8_signed(const Window &window,
+ const ITensor *input,
+ ITensor *output,
+ const ReductionOperation op)
+{
+ return Reducer<RedOpYZW_quantized<int8_t>>::reduceZ(window, input, output, RedOpYZW_quantized<int8_t>(), op);
+}
+
+void reduce_RedOpYZW_reduceW_qasymm8_signed(const Window &window,
+ const ITensor *input,
+ ITensor *output,
+ const ReductionOperation op)
+{
+ return Reducer<RedOpYZW_quantized<int8_t>>::reduceW(window, input, output, RedOpYZW_quantized<int8_t>(), op);
+}
+} // namespace cpu
+} // namespace arm_compute
diff --git a/src/cpu/kernels/softmax/generic/neon/fp16.cpp b/src/cpu/kernels/softmax/generic/neon/fp16.cpp
index da62d2d614..425fcf7ac6 100644
--- a/src/cpu/kernels/softmax/generic/neon/fp16.cpp
+++ b/src/cpu/kernels/softmax/generic/neon/fp16.cpp
@@ -33,9 +33,15 @@ namespace cpu
{
template <bool IS_LOG>
-void neon_fp16_softmax(
- const ITensor *in, void *const tmp, ITensor *out, const float beta, int axis, const Window &window)
+void neon_fp16_softmax(const ITensor *in,
+ void *const tmp,
+ ITensor *out,
+ const float beta,
+ int axis,
+ const Window &window,
+ const float *lut_ptr)
{
+ ARM_COMPUTE_UNUSED(lut_ptr);
if (axis == 0)
{
return neon_softmax_x_float<float16_t, IS_LOG>(in, tmp, out, beta, axis, window);
@@ -46,10 +52,20 @@ void neon_fp16_softmax(
}
}
-template void neon_fp16_softmax<true>(
- const ITensor *in, void *const tmp, ITensor *out, const float beta, int axis, const Window &window);
-template void neon_fp16_softmax<false>(
- const ITensor *in, void *const tmp, ITensor *out, const float beta, int axis, const Window &window);
+template void neon_fp16_softmax<true>(const ITensor *in,
+ void *const tmp,
+ ITensor *out,
+ const float beta,
+ int axis,
+ const Window &window,
+ const float *lut_ptr);
+template void neon_fp16_softmax<false>(const ITensor *in,
+ void *const tmp,
+ ITensor *out,
+ const float beta,
+ int axis,
+ const Window &window,
+ const float *lut_ptr);
} // namespace cpu
} // namespace arm_compute
diff --git a/src/cpu/kernels/softmax/generic/neon/fp32.cpp b/src/cpu/kernels/softmax/generic/neon/fp32.cpp
index 0701620636..a64946eb74 100644
--- a/src/cpu/kernels/softmax/generic/neon/fp32.cpp
+++ b/src/cpu/kernels/softmax/generic/neon/fp32.cpp
@@ -31,9 +31,15 @@ namespace cpu
{
template <bool IS_LOG>
-void neon_fp32_softmax(
- const ITensor *in, void *const tmp, ITensor *out, const float beta, int axis, const Window &window)
+void neon_fp32_softmax(const ITensor *in,
+ void *const tmp,
+ ITensor *out,
+ const float beta,
+ int axis,
+ const Window &window,
+ const float *lut_ptr)
{
+ ARM_COMPUTE_UNUSED(lut_ptr);
if (axis == 0)
{
return neon_softmax_x_float<float, IS_LOG>(in, tmp, out, beta, axis, window);
@@ -44,10 +50,20 @@ void neon_fp32_softmax(
}
}
-template void neon_fp32_softmax<true>(
- const ITensor *in, void *const tmp, ITensor *out, const float beta, int axis, const Window &window);
-template void neon_fp32_softmax<false>(
- const ITensor *in, void *const tmp, ITensor *out, const float beta, int axis, const Window &window);
+template void neon_fp32_softmax<true>(const ITensor *in,
+ void *const tmp,
+ ITensor *out,
+ const float beta,
+ int axis,
+ const Window &window,
+ const float *lut_ptr);
+template void neon_fp32_softmax<false>(const ITensor *in,
+ void *const tmp,
+ ITensor *out,
+ const float beta,
+ int axis,
+ const Window &window,
+ const float *lut_ptr);
} // namespace cpu
} // namespace arm_compute
diff --git a/src/cpu/kernels/softmax/generic/neon/qasymm8.cpp b/src/cpu/kernels/softmax/generic/neon/qasymm8.cpp
index d39240bb38..369f9bb005 100644
--- a/src/cpu/kernels/softmax/generic/neon/qasymm8.cpp
+++ b/src/cpu/kernels/softmax/generic/neon/qasymm8.cpp
@@ -30,9 +30,15 @@ namespace arm_compute
namespace cpu
{
template <bool IS_LOG>
-void neon_qasymm8_softmax(
- const ITensor *in, void *const tmp, ITensor *out, const float beta, int axis, const Window &window)
+void neon_qasymm8_softmax(const ITensor *in,
+ void *const tmp,
+ ITensor *out,
+ const float beta,
+ int axis,
+ const Window &window,
+ const float *lut_ptr)
{
+ ARM_COMPUTE_UNUSED(lut_ptr);
if (axis == 0)
{
return neon_softmax_x_quantized<qasymm8_t, IS_LOG>(in, tmp, out, beta, axis, window);
@@ -43,10 +49,20 @@ void neon_qasymm8_softmax(
}
}
-template void neon_qasymm8_softmax<true>(
- const ITensor *in, void *const tmp, ITensor *out, const float beta, int axis, const Window &window);
-template void neon_qasymm8_softmax<false>(
- const ITensor *in, void *const tmp, ITensor *out, const float beta, int axis, const Window &window);
+template void neon_qasymm8_softmax<true>(const ITensor *in,
+ void *const tmp,
+ ITensor *out,
+ const float beta,
+ int axis,
+ const Window &window,
+ const float *lut_ptr);
+template void neon_qasymm8_softmax<false>(const ITensor *in,
+ void *const tmp,
+ ITensor *out,
+ const float beta,
+ int axis,
+ const Window &window,
+ const float *lut_ptr);
} // namespace cpu
} // namespace arm_compute
diff --git a/src/cpu/kernels/softmax/generic/neon/qasymm8_signed.cpp b/src/cpu/kernels/softmax/generic/neon/qasymm8_signed.cpp
index 26fd5dbfa0..594ceb7654 100644
--- a/src/cpu/kernels/softmax/generic/neon/qasymm8_signed.cpp
+++ b/src/cpu/kernels/softmax/generic/neon/qasymm8_signed.cpp
@@ -30,9 +30,15 @@ namespace arm_compute
namespace cpu
{
template <bool IS_LOG>
-void neon_qasymm8_signed_softmax(
- const ITensor *in, void *const tmp, ITensor *out, const float beta, int axis, const Window &window)
+void neon_qasymm8_signed_softmax(const ITensor *in,
+ void *const tmp,
+ ITensor *out,
+ const float beta,
+ int axis,
+ const Window &window,
+ const float *lut_ptr)
{
+ ARM_COMPUTE_UNUSED(lut_ptr);
if (axis == 0)
{
return neon_softmax_x_quantized<qasymm8_signed_t, IS_LOG>(in, tmp, out, beta, axis, window);
@@ -43,10 +49,20 @@ void neon_qasymm8_signed_softmax(
}
}
-template void neon_qasymm8_signed_softmax<true>(
- const ITensor *in, void *const tmp, ITensor *out, const float beta, int axis, const Window &window);
-template void neon_qasymm8_signed_softmax<false>(
- const ITensor *in, void *const tmp, ITensor *out, const float beta, int axis, const Window &window);
+template void neon_qasymm8_signed_softmax<true>(const ITensor *in,
+ void *const tmp,
+ ITensor *out,
+ const float beta,
+ int axis,
+ const Window &window,
+ const float *lut_ptr);
+template void neon_qasymm8_signed_softmax<false>(const ITensor *in,
+ void *const tmp,
+ ITensor *out,
+ const float beta,
+ int axis,
+ const Window &window,
+ const float *lut_ptr);
} // namespace cpu
} // namespace arm_compute
diff --git a/src/cpu/kernels/softmax/generic/sme2/fp16.cpp b/src/cpu/kernels/softmax/generic/sme2/fp16.cpp
new file mode 100644
index 0000000000..e70c9f4793
--- /dev/null
+++ b/src/cpu/kernels/softmax/generic/sme2/fp16.cpp
@@ -0,0 +1,781 @@
+/*
+ * Copyright (c) 2024 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#ifdef ARM_COMPUTE_ENABLE_SME2
+
+#include "arm_compute/core/ITensor.h"
+#include "arm_compute/core/Window.h"
+
+namespace arm_compute
+{
+namespace cpu
+{
+
+// SoftMax
+//
+// Steps:
+// * Find max: max_value = max(src)
+// * Regularize: dst[i] = exp((src[i] - max_value) * beta)
+// sum_value = sum(dst)
+// * Normalize: dst[i] = dst[i] / sum_value
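+//
+// A minimal scalar reference of these steps (an illustrative sketch, not part
+// of this kernel), widening to fp32 for the arithmetic as the kernel does:
+//
+//   void softmax_ref(const float16_t *src, float16_t *dst, size_t n, float beta)
+//   {
+//       float max_value = -std::numeric_limits<float>::infinity();
+//       for (size_t i = 0; i < n; ++i) max_value = std::max(max_value, float(src[i]));
+//       float sum_value = 0.f;
+//       for (size_t i = 0; i < n; ++i)
+//       {
+//           const float e = std::exp((float(src[i]) - max_value) * beta);
+//           dst[i]        = float16_t(e);
+//           sum_value += e;
+//       }
+//       for (size_t i = 0; i < n; ++i) dst[i] = float16_t(float(dst[i]) / sum_value);
+//   }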
+void sme2_f16_softmax_kernel( //
+ const float16_t *src,
+ float16_t *dst,
+ float beta,
+ const uintptr_t shape[4],
+ const uintptr_t src_strides[4],
+ const uintptr_t dst_strides[4])
+{
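+    // Precondition (mirroring the fp32 kernel; the loads/stores below index
+    // with LSL #1, i.e. contiguous fp16 elements along dimension 0):
+    // * src_strides[0] == sizeof(float16_t)
+    // * dst_strides[0] == sizeof(float16_t)
+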
+ __asm__ volatile(
+ R"(
+ .inst 0xd503477f // smstart
+
+ // Registers
+ //
+ // * x9: temporary, index
+ // * x10: temporary, -inf
+ // * x11: temporary, 0
+ // * x12: temporary, 1.0f
+ // * x13: temporary, body_length
+ //
+ // * x20: index_3
+ // * x21: src_3
+ // * x22: dst_3
+ // * x23: index_2
+ // * x24: src_2
+ // * x25: dst_2
+ // * x26: index_1
+ // * x27: src_1
+ // * x28: dst_1
+ //
+ // * z0: c1
+ // * z1: c2
+ // * z2: c3
+ // * z3: c4
+ // * z4: c5
+ // * z5: shift
+ // * z6: inv_ln2
+ // * z7: neg_ln2_hi
+ // * z8: neg_ln2_lo
+ // * z9: min_input
+ // * z10: 23, 0
+ // * z11: max_value
+ // * z12-z15: x, x_fp32_lower_halves, r_hi, r, r2
+ // * z16-z19: max_value, shift, z, scale, poly
+ // * z20-z21: n, p1, p12345
+ // * z22-z23: n, p23, p2345
+ // * z24-z25: p45
+ // * z26: beta
+ // * z28-z31: sum_value, x_fp32_upper_halves
+ //
+ // * za0-za3: sum_value
+ //
+ // * p0: all-true
+ // * p1: left-over predicate for find-max & normalize loops
+ // * p2-p4: left-over predicates for regularize loop
+ // * p4-p7: underflow in vector loop
+ // * p5-p6: underflow in leftover loop
+ // *
+ // * pn9: all-true
+
+ // Prepares all constant values
+
+ ptrue p0.b
+ .inst 0x25207811 // ptrue pn9.b
+
+ mov w9, #0xfff6 // c1: 0x1.ffffecp-1f = 0x3f7ffff6
+ mov w10, #0xfedb // c2: 0x1.fffdb6p-2f = 0x3efffedb
+ mov w11, #0xaf33 // c3: 0x1.555e66p-3f = 0x3e2aaf33
+ mov w12, #0x9f17 // c4: 0x1.573e2ep-5f = 0x3d2b9f17
+ mov w13, #0x2010 // c5: 0x1.0e4020p-7f = 0x3c072010
+
+ movk w9, #0x3f7f, LSL #16 // c1: 0x1.ffffecp-1f = 0x3f7ffff6
+ movk w10, #0x3eff, LSL #16 // c2: 0x1.fffdb6p-2f = 0x3efffedb
+    movk w11, #0x3e2a, LSL #16 // c3: 0x1.555e66p-3f = 0x3e2aaf33
+ movk w12, #0x3d2b, LSL #16 // c4: 0x1.573e2ep-5f = 0x3d2b9f17
+ movk w13, #0x3c07, LSL #16 // c5: 0x1.0e4020p-7f = 0x3c072010
+
+ dup z0.s, w9 // c1.
+ dup z1.s, w10 // c2.
+ dup z2.s, w11 // c3.
+ dup z3.s, w12 // c4.
+ dup z4.s, w13 // c5.
+
+ mov w9, #0x007f // shift: 2^23 + 127 = 0x1.0000fep23f = 0x4b00007f
+ mov w10, #0xaa3b // inv_ln2: 1 / ln(2) = 0x1.715476p+0f = 0x3fb8aa3b
+ mov w11, #0x7200 // neg_ln2_hi: -ln(2) from bits -1 to -19 = -0x1.62e400p-1f = 0xbf317200
+ mov w12, #0xbe8e // neg_ln2_lo: -ln(2) from bits -20 to -42 = -0x1.7f7d1cp-20f = 0xb5bfbe8e
+ mov w13, #0x47ae // min_input (Approximately ln 2^-125): -86.64 = 0xc2ad47ae
+
+ movk w9, #0x4b00, LSL #16 // shift: 2^23 + 127 = 0x1.0000fep23f = 0x4b00007f
+ movk w10, #0x3fb8, LSL #16 // inv_ln2: 1 / ln(2) = 0x1.715476p+0f = 0x3fb8aa3b
+ movk w11, #0xbf31, LSL #16 // neg_ln2_hi: -ln(2) from bits -1 to -19 = -0x1.62e400p-1f = 0xbf317200
+ movk w12, #0xb5bf, LSL #16 // neg_ln2_lo: -ln(2) from bits -20 to -42 = -0x1.7f7d1cp-20f = 0xb5bfbe8e
+ movk w13, #0xc2ad, LSL #16 // min_input (Approximately ln 2^-125): -86.64 = 0xc2ad47ae
+
+ dup z5.s, w9 // shift
+ dup z6.s, w10 // inv_ln2
+ dup z7.s, w11 // neg_ln2_hi
+ dup z8.s, w12 // neg_ln2_lo
+ dup z9.s, w13 // min_input
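+
+    // The constants above evaluate exp(x) as 2^n * poly(r), where
+    //   z = shift + x * inv_ln2 (n = z - shift; 2^n is formed as z << 23,
+    //                            reusing the fp32 exponent bias held in shift)
+    //   r = x + n * neg_ln2_hi + n * neg_ln2_lo (a split, higher-precision x - n*ln(2))
+    //   poly(r) = 1 + c1*r + c2*r^2 + c3*r^3 + c4*r^4 + c5*r^5
+    // Inputs below min_input would underflow, so their results are forced to 0.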
+
+ dup z26.s, %w[beta] // beta
+ fcvt h26, s26
+ dup z26.h, z26.h[0]
+
+ mov w10, #0xfc00 // -inf: 0xfc00 for fp16
+
+ mov w11, #0 // 0
+
+ // ---------------------------------------------------------------- x13: body_length = (length / vl) * vl
+ cnth x13, ALL, MUL #4
+ udiv x9, %x[length], x13
+ mul x13, x13, x9
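+    // (e.g. with 128-bit vectors, cnth ALL, MUL #4 = 32 fp16 lanes per iteration,
+    // so a row of length 100 gives body_length = 96 with 4 elements left over)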
+
+ // ==================================================
+ // 3D loop opening
+ // ==================================================
+
+ mov x20, %x[shape_3]
+ mov x21, %x[src]
+ mov x22, %x[dst]
+
+loop_3_start%=:
+ // for index_3 in shape_3 downto 1
+ cmp x20, #0
+ b.eq loop_3_end%=
+ sub x20, x20, #1
+
+ mov x23, %x[shape_2]
+ mov x24, x21
+ mov x25, x22
+
+loop_2_start%=:
+ // for index_2 in shape_2 downto 1
+ cmp x23, #0
+ b.eq loop_2_end%=
+ sub x23, x23, #1
+
+ mov x26, %x[shape_1]
+ mov x27, x24
+ mov x28, x25
+
+loop_1_start%=:
+    // for index_1 in shape_1 downto 1
+ cmp x26, #0
+ b.eq loop_1_end%=
+ sub x26, x26, #1
+
+ // ==================================================
+ // Step 1: Find max
+ // ==================================================
+
+ // ---------------------------------------------------------------- z16-z19: max_value = -inf
+ dup z16.h, w10
+ dup z17.h, w10
+ dup z18.h, w10
+ dup z19.h, w10
+
+ // Loop for processing 4 vectors per iteration.
+ mov x9, #0 // x9: index
+ dup z11.h, w10 // z11: max_value = -inf
+
+find_max_body_start%=:
+ cmp x9, x13
+ b.eq find_max_body_end%=
+
+ .inst 0xa009a76c // ld1h {z12.h-z15.h}, pn9/z, [x27, x9, LSL #1] // z12-z15: x
+ .inst 0xc16cb910 // fmax {z16.h-z19.h}, {z16.h-z19.h}, {z12.h-z15.h} // z16-z19: max_value = max(max_value, x)
+
+ inch x9, ALL, MUL #4
+ b find_max_body_start%=
+find_max_body_end%=:
+
+ // Loop for processing the leftover part.
+find_max_leftover_start%=:
+ whilelo p1.h, x9, %x[length]
+ b.none find_max_leftover_end%=
+
+ ld1h z12.h, p1/z, [x27, x9, LSL #1] // z12: x
+ fmax z16.h, p1/m, z16.h, z12.h // z16: max_value = max(max_value, x)
+
+ inch x9
+ b find_max_leftover_start%=
+find_max_leftover_end%=:
+
+ // ---------------------------------------------------------------- z16: max_value
+    .inst 0xc172b110 // fmax {z16.h-z17.h}, {z16.h-z17.h}, {z18.h-z19.h}
+ fmax z16.h, p0/m, z16.h, z17.h
+ fmaxv h16, p0, z16.h
+
+ // ---------------------------------------------------------------- z11: max_value
+ dup z11.h, z16.h[0]
+
+ // ==================================================
+    // Step 2: Regularize, i.e. calculate exp((x - max(x)) * beta)
+ // ==================================================
+
+ .inst 0xc00800ff // zero {za0.s, za1.s, za2.s, za3.s} za0-za3: sum_value (in fp32)
+
+ // Loop for processing 4 vectors per iteration.
+ mov x9, #0 // ---------------------------------------------------- x9: index
+
+regularize_body_start%=:
+ cmp x9, x13
+ b.eq regularize_body_end%=
+
+ // Loads the input data to 4 consecutive registers ---------------- z12-z15: input_data
+ .inst 0xa009a76c // ld1h {z12.h-z15.h}, pn9/z, [x27, x9, LSL #1] // z12-z15: x
+
+ // ---------------------------------------------------------------- z12-z15: x = input_data - max_value
+ fsub z12.h, z12.h, z11.h
+ fsub z13.h, z13.h, z11.h
+ fsub z14.h, z14.h, z11.h
+ fsub z15.h, z15.h, z11.h
+
+ // ---------------------------------------------------------------- z12-z15: x = (input_data - max_value) * beta
+ fmul z12.h, z12.h, z26.h
+ fmul z13.h, z13.h, z26.h
+ fmul z14.h, z14.h, z26.h
+ fmul z15.h, z15.h, z26.h
+
+ // ----------------------------------------------------------------
+ // Convert fp16 values to fp32. This results in four more registers.
+ // z12 --> z12, z28
+ fcvtlt z28.s, p0/m, z12.h
+ fcvt z12.s, p0/m, z12.h
+
+ // z13 --> z13, z29
+ fcvtlt z29.s, p0/m, z13.h
+ fcvt z13.s, p0/m, z13.h
+
+ // z14 --> z14, z30
+ fcvtlt z30.s, p0/m, z14.h
+ fcvt z14.s, p0/m, z14.h
+
+ // z15 --> z15, z31
+ fcvtlt z31.s, p0/m, z15.h
+ fcvt z15.s, p0/m, z15.h
+
+ // ----------------------------------------------------------------
+ // Process z12-z15
+ // ----------------------------------------------------------------
+ // ---------------------------------------------------------------- z16-z19: shift
+ mov z16.d, z5.d
+ mov z17.d, z5.d
+ mov z18.d, z5.d
+ mov z19.d, z5.d
+
+ // ---------------------------------------------------------------- p4-p7: underflow = x < min_input
+ fcmlt p4.s, p0/z, z12.s, z9.s
+ fcmlt p5.s, p0/z, z13.s, z9.s
+ fcmlt p6.s, p0/z, z14.s, z9.s
+ fcmlt p7.s, p0/z, z15.s, z9.s
+
+ // ---------------------------------------------------------------- z16-z19: z = shift + x * inv_ln2
+ fmla z16.s, p0/m, z12.s, z6.s
+ fmla z17.s, p0/m, z13.s, z6.s
+ fmla z18.s, p0/m, z14.s, z6.s
+ fmla z19.s, p0/m, z15.s, z6.s
+
+ // ---------------------------------------------------------------- z20-z23: n = z - shift
+ fsub z20.s, z16.s, z5.s
+ fsub z21.s, z17.s, z5.s
+ fsub z22.s, z18.s, z5.s
+ fsub z23.s, z19.s, z5.s
+
+ // ---------------------------------------------------------------- z12-z15: r_hi = x + n * neg_ln2_hi
+ fmla z12.s, p0/m, z20.s, z7.s
+ fmla z13.s, p0/m, z21.s, z7.s
+ fmla z14.s, p0/m, z22.s, z7.s
+ fmla z15.s, p0/m, z23.s, z7.s
+
+ // ---------------------------------------------------------------- z12-z15: r = r_hi + n * neg_ln2_lo
+ fmla z12.s, p0/m, z20.s, z8.s
+ fmla z13.s, p0/m, z21.s, z8.s
+ fmla z14.s, p0/m, z22.s, z8.s
+ fmla z15.s, p0/m, z23.s, z8.s
+
+ // ---------------------------------------------------------------- z16-z19: scale = z << 23 (2^n)
+ dup z10.s, #23
+ urshl z16.s, p0/m, z16.s, z10.s
+ urshl z17.s, p0/m, z17.s, z10.s
+ urshl z18.s, p0/m, z18.s, z10.s
+ urshl z19.s, p0/m, z19.s, z10.s
+
+ // Processes the first 2 vectors. (z12-z13)
+
+ // ---------------------------------------------------------------- z20-z21: p1 = r * c1
+ fmul z20.s, z12.s, z0.s
+ fmul z21.s, z13.s, z0.s
+
+ // ---------------------------------------------------------------- z22-z23: p23 = c2
+ mov z22.d, z1.d
+ mov z23.d, z1.d
+
+ // ---------------------------------------------------------------- z22-z23: p23 = c2 + r * c3
+ fmla z22.s, p0/m, z12.s, z2.s
+ fmla z23.s, p0/m, z13.s, z2.s
+
+    // ---------------------------------------------------------------- z24-z25: c4
+ mov z24.d, z3.d
+ mov z25.d, z3.d
+
+ // ---------------------------------------------------------------- z24-z25: p45 = c4 + r * c5
+ fmla z24.s, p0/m, z12.s, z4.s
+ fmla z25.s, p0/m, z13.s, z4.s
+
+ // ---------------------------------------------------------------- z12-z13: r2 = r * r
+ fmul z12.s, z12.s, z12.s
+ fmul z13.s, z13.s, z13.s
+
+ // ---------------------------------------------------------------- z22-z23: p2345 = p23 + r2 * p45
+ fmla z22.s, p0/m, z12.s, z24.s
+ fmla z23.s, p0/m, z13.s, z25.s
+
+ // ---------------------------------------------------------------- z20-z21: p12345 = p1 + r2 * p2345
+ fmla z20.s, p0/m, z12.s, z22.s
+ fmla z21.s, p0/m, z13.s, z23.s
+
+ // ---------------------------------------------------------------- z16-z17: poly = scale + p12345 * scale
+ fmla z16.s, p0/m, z20.s, z16.s
+ fmla z17.s, p0/m, z21.s, z17.s
+
+ // Processes the last 2 vectors (z14-z15)
+
+ // ---------------------------------------------------------------- z20-z21: p1 = r * c1
+ fmul z20.s, z14.s, z0.s
+ fmul z21.s, z15.s, z0.s
+
+ // ---------------------------------------------------------------- z22-z23: p23 = c2
+ mov z22.d, z1.d
+ mov z23.d, z1.d
+
+ // ---------------------------------------------------------------- z22-z23: p23 = c2 + r * c3
+ fmla z22.s, p0/m, z14.s, z2.s
+ fmla z23.s, p0/m, z15.s, z2.s
+
+    // ---------------------------------------------------------------- z24-z25: c4
+ mov z24.d, z3.d
+ mov z25.d, z3.d
+
+ // ---------------------------------------------------------------- z24-z25: p45 = c4 + r * c5
+ fmla z24.s, p0/m, z14.s, z4.s
+ fmla z25.s, p0/m, z15.s, z4.s
+
+ // ---------------------------------------------------------------- z14-z15: r2 = r * r
+ fmul z14.s, z14.s, z14.s
+ fmul z15.s, z15.s, z15.s
+
+ // ---------------------------------------------------------------- z22-z23: p2345 = p23 + r2 * p45
+ fmla z22.s, p0/m, z14.s, z24.s
+ fmla z23.s, p0/m, z15.s, z25.s
+
+ // ---------------------------------------------------------------- z20-z21: p12345 = p1 + r2 * p2345
+ fmla z20.s, p0/m, z14.s, z22.s
+ fmla z21.s, p0/m, z15.s, z23.s
+
+ // ---------------------------------------------------------------- z18-z19: poly = scale + p12345 * scale
+ fmla z18.s, p0/m, z20.s, z18.s
+ fmla z19.s, p0/m, z21.s, z19.s
+
+ // ---------------------------------------------------------------- z16-z19: poly = underflow ? 0 : poly
+ dup z10.s, #0
+ sel z12.s, p4, z10.s, z16.s
+ sel z13.s, p5, z10.s, z17.s
+ sel z14.s, p6, z10.s, z18.s
+ sel z15.s, p7, z10.s, z19.s
+
+ // ---------------------------------------------------------------- sum in fp32
+ .inst 0xc1a17d80 // fadd za.s[w11, #0, VGx4], {z12.s-z15.s} za0-za3: sum_value = sum_value + poly
+
+ // ----------------------------------------------------------------
+ // Process z28-z31
+ // ----------------------------------------------------------------
+ // ---------------------------------------------------------------- z16-z19: shift
+ mov z16.d, z5.d
+ mov z17.d, z5.d
+ mov z18.d, z5.d
+ mov z19.d, z5.d
+
+ // ---------------------------------------------------------------- p4-p7: underflow = x < min_input
+ fcmlt p4.s, p0/z, z28.s, z9.s
+ fcmlt p5.s, p0/z, z29.s, z9.s
+ fcmlt p6.s, p0/z, z30.s, z9.s
+ fcmlt p7.s, p0/z, z31.s, z9.s
+
+ // ---------------------------------------------------------------- z16-z19: z = shift + x * inv_ln2
+ fmla z16.s, p0/m, z28.s, z6.s
+ fmla z17.s, p0/m, z29.s, z6.s
+ fmla z18.s, p0/m, z30.s, z6.s
+ fmla z19.s, p0/m, z31.s, z6.s
+
+ // ---------------------------------------------------------------- z20-z23: n = z - shift
+ fsub z20.s, z16.s, z5.s
+ fsub z21.s, z17.s, z5.s
+ fsub z22.s, z18.s, z5.s
+ fsub z23.s, z19.s, z5.s
+
+    // ---------------------------------------------------------------- z28-z31: r_hi = x + n * neg_ln2_hi
+ fmla z28.s, p0/m, z20.s, z7.s
+ fmla z29.s, p0/m, z21.s, z7.s
+ fmla z30.s, p0/m, z22.s, z7.s
+ fmla z31.s, p0/m, z23.s, z7.s
+
+    // ---------------------------------------------------------------- z28-z31: r = r_hi + n * neg_ln2_lo
+ fmla z28.s, p0/m, z20.s, z8.s
+ fmla z29.s, p0/m, z21.s, z8.s
+ fmla z30.s, p0/m, z22.s, z8.s
+ fmla z31.s, p0/m, z23.s, z8.s
+
+ // ---------------------------------------------------------------- z16-z19: scale = z << 23 (2^n)
+ dup z10.s, #23
+ urshl z16.s, p0/m, z16.s, z10.s
+ urshl z17.s, p0/m, z17.s, z10.s
+ urshl z18.s, p0/m, z18.s, z10.s
+ urshl z19.s, p0/m, z19.s, z10.s
+
+ // Processes the first 2 vectors. (z28-z29)
+
+ // ---------------------------------------------------------------- z20-z21: p1 = r * c1
+ fmul z20.s, z28.s, z0.s
+ fmul z21.s, z29.s, z0.s
+
+ // ---------------------------------------------------------------- z22-z23: p23 = c2
+ mov z22.d, z1.d
+ mov z23.d, z1.d
+
+ // ---------------------------------------------------------------- z22-z23: p23 = c2 + r * c3
+ fmla z22.s, p0/m, z28.s, z2.s
+ fmla z23.s, p0/m, z29.s, z2.s
+
+ // ---------------------------------------------------------------- z24-z25: c4
+ mov z24.d, z3.d
+ mov z25.d, z3.d
+
+ // ---------------------------------------------------------------- z24-z25: p45 = c4 + r * c5
+ fmla z24.s, p0/m, z28.s, z4.s
+ fmla z25.s, p0/m, z29.s, z4.s
+
+ // ---------------------------------------------------------------- z28-z29: r2 = r * r
+ fmul z28.s, z28.s, z28.s
+ fmul z29.s, z29.s, z29.s
+
+ // ---------------------------------------------------------------- z22-z23: p2345 = p23 + r2 * p45
+ fmla z22.s, p0/m, z28.s, z24.s
+ fmla z23.s, p0/m, z29.s, z25.s
+
+ // ---------------------------------------------------------------- z20-z21: p12345 = p1 + r2 * p2345
+ fmla z20.s, p0/m, z28.s, z22.s
+ fmla z21.s, p0/m, z29.s, z23.s
+
+ // ---------------------------------------------------------------- z16-z17: poly = scale + p12345 * scale
+ fmla z16.s, p0/m, z20.s, z16.s
+ fmla z17.s, p0/m, z21.s, z17.s
+
+ // Processes the last 2 vectors (z30-z31)
+
+ // ---------------------------------------------------------------- z20-z21: p1 = r * c1
+ fmul z20.s, z30.s, z0.s
+ fmul z21.s, z31.s, z0.s
+
+ // ---------------------------------------------------------------- z22-z23: p23 = c2
+ mov z22.d, z1.d
+ mov z23.d, z1.d
+
+ // ---------------------------------------------------------------- z22-z23: p23 = c2 + r * c3
+ fmla z22.s, p0/m, z30.s, z2.s
+ fmla z23.s, p0/m, z31.s, z2.s
+
+    // ---------------------------------------------------------------- z24-z25: c4
+ mov z24.d, z3.d
+ mov z25.d, z3.d
+
+ // ---------------------------------------------------------------- z24-z25: p45 = c4 + r * c5
+ fmla z24.s, p0/m, z30.s, z4.s
+ fmla z25.s, p0/m, z31.s, z4.s
+
+ // ---------------------------------------------------------------- z30-z31: r2 = r * r
+ fmul z30.s, z30.s, z30.s
+ fmul z31.s, z31.s, z31.s
+
+ // ---------------------------------------------------------------- z22-z23: p2345 = p23 + r2 * p45
+ fmla z22.s, p0/m, z30.s, z24.s
+ fmla z23.s, p0/m, z31.s, z25.s
+
+ // ---------------------------------------------------------------- z20-z21: p12345 = p1 + r2 * p2345
+ fmla z20.s, p0/m, z30.s, z22.s
+ fmla z21.s, p0/m, z31.s, z23.s
+
+ // ---------------------------------------------------------------- z18-z19: poly = scale + p12345 * scale
+ fmla z18.s, p0/m, z20.s, z18.s
+ fmla z19.s, p0/m, z21.s, z19.s
+
+ // ---------------------------------------------------------------- z16-z19: poly = underflow ? 0 : poly
+ dup z10.s, #0
+ sel z28.s, p4, z10.s, z16.s
+ sel z29.s, p5, z10.s, z17.s
+ sel z30.s, p6, z10.s, z18.s
+ sel z31.s, p7, z10.s, z19.s
+
+ // ---------------------------------------------------------------- sum in fp32
+ .inst 0xc1a17f80 // fadd za.s[w11, #0, VGx4], {z28.s-z31.s} za0-za3: sum_value = sum_value + poly
+
+ fcvt z12.h, p0/m, z12.s
+ fcvtnt z12.h, p0/m, z28.s
+
+ fcvt z13.h, p0/m, z13.s
+ fcvtnt z13.h, p0/m, z29.s
+
+ fcvt z14.h, p0/m, z14.s
+ fcvtnt z14.h, p0/m, z30.s
+
+ fcvt z15.h, p0/m, z15.s
+ fcvtnt z15.h, p0/m, z31.s
+
+ // Stores 4 consecutive registers to the output
+ .inst 0xa029a78c // st1h {z12.h-z15.h}, pn9, [x28, x9, LSL #1]
+
+ inch x9, ALL, MUL #4
+ b regularize_body_start%=
+regularize_body_end%=:
+
+ // ---------------------------------------------------------------- z28: sum_value
+ .inst 0xc0066c1c // mova {z28.s-z31.s}, za.s[w11, #0, VGx4]
+ fadd z28.s, z28.s, z29.s
+ fadd z30.s, z30.s, z31.s
+ fadd z28.s, z28.s, z30.s
+
+ // Loop for processing the leftover part.
+regularize_leftover_start%=:
+ whilelo p2.h, x9, %x[length]
+ b.none regularize_leftover_end%=
+
+    ld1h z12.h, p2/z, [x27, x9, LSL #1] // z12: input_data
+
+ fsub z12.h, z12.h, z11.h // z12: x = input_data - max_value
+ fmul z12.h, z12.h, z26.h // z12: x = (input_data - max_value) * beta
+
+ // ---------------------------------------------------------------- z12.h --> z12.s, z13.s
+ fcvtlt z13.s, p2/m, z12.h
+ fcvt z12.s, p2/m, z12.h
+
+    // ---------------------------------------------------------------- p3, p4: predicates for z12, z13
+ pfalse p1.b
+ trn1 p3.h, p2.h, p1.h // for z12
+ trn2 p4.h, p2.h, p1.h // for z13
+
+ mov z16.d, z5.d // z16: shift
+ mov z17.d, z5.d // z17: shift
+ fcmlt p5.s, p3/z, z12.s, z9.s // p5: underflow = x < min_input
+ fcmlt p6.s, p4/z, z13.s, z9.s // p6: underflow = x < min_input
+ fmla z16.s, p3/m, z12.s, z6.s // z16: z = shift + x * inv_ln2
+ fmla z17.s, p4/m, z13.s, z6.s // z17: z = shift + x * inv_ln2
+ fsub z20.s, z16.s, z5.s // z20: n = z - shift
+ fsub z21.s, z17.s, z5.s // z21: n = z - shift
+ fmla z12.s, p3/m, z20.s, z7.s // z12: r_hi = x + n * neg_ln2_hi
+ fmla z13.s, p4/m, z21.s, z7.s // z13: r_hi = x + n * neg_ln2_hi
+ fmla z12.s, p3/m, z20.s, z8.s // z12: r = r_hi + n * neg_ln2_lo
+ fmla z13.s, p4/m, z21.s, z8.s // z13: r = r_hi + n * neg_ln2_lo
+ dup z10.s, #23 // z10: 23
+ urshl z16.s, p3/m, z16.s, z10.s // z16: scale = z << 23 (2^n)
+ urshl z17.s, p4/m, z17.s, z10.s // z17: scale = z << 23 (2^n)
+ fmul z20.s, z12.s, z0.s // z20: p1 = r * c1
+ fmul z21.s, z13.s, z0.s // z21: p1 = r * c1
+ mov z22.d, z1.d // z22: p23 = c2
+ mov z23.d, z1.d // z23: p23 = c2
+ fmla z22.s, p3/m, z12.s, z2.s // z22: p23 = c2 + r * c3
+ fmla z23.s, p4/m, z13.s, z2.s // z23: p23 = c2 + r * c3
+ mov z24.d, z3.d // z24: c4
+ mov z25.d, z3.d // z25: c4
+ fmla z24.s, p3/m, z12.s, z4.s // z24: p45 = c4 + r * c5
+ fmla z25.s, p4/m, z13.s, z4.s // z25: p45 = c4 + r * c5
+ fmul z12.s, z12.s, z12.s // z12: r2 = r * r
+ fmul z13.s, z13.s, z13.s // z13: r2 = r * r
+ fmla z22.s, p3/m, z12.s, z24.s // z22: p2345 = p23 + r2 * p45
+ fmla z23.s, p4/m, z13.s, z25.s // z23: p2345 = p23 + r2 * p45
+ fmla z20.s, p3/m, z12.s, z22.s // z20: p12345 = p1 + r2 * p2345
+ fmla z21.s, p4/m, z13.s, z23.s // z21: p12345 = p1 + r2 * p2345
+ fmla z16.s, p3/m, z20.s, z16.s // z16: poly = scale + p12345 * scale
+ fmla z17.s, p4/m, z21.s, z17.s // z17: poly = scale + p12345 * scale
+ dup z10.s, #0 // z10: 0
+ sel z16.s, p5, z10.s, z16.s // z16: poly = underflow ? 0 : poly
+ sel z17.s, p6, z10.s, z17.s // z17: poly = underflow ? 0 : poly
+ fadd z28.s, p3/m, z28.s, z16.s // z28: sum_value = sum_value + poly
+ fadd z28.s, p4/m, z28.s, z17.s // z28: sum_value = sum_value + poly
+
+ fcvt z16.h, p3/m, z16.s
+ fcvtnt z16.h, p4/m, z17.s
+ st1h z16.h, p2, [x28, x9, LSL #1]
+
+ inch x9
+ b regularize_leftover_start%=
+regularize_leftover_end%=:
+
+ // ==================================================
+ // Step 3: Normalize
+ // ==================================================
+
+ // ---------------------------------------------------------------- z28: inv_sum_value = 1 / sum_value
+ faddv s28, p0, z28.s
+ fmov s29, #1.0 // 1.0f
+ fdiv s28, s29, s28
+ fcvt h28, s28
+
+ dup z28.h, z28.h[0]
+
+ // Loop for processing 4 vectors per iteration.
+ mov x9, #0 // x9: index
+
+normalize_body_start%=:
+ cmp x9, x13
+ b.eq normalize_body_end%=
+
+ .inst 0xa009a78c // ld1h {z12.h-z15.h}, pn9/z, [x28, x9, LSL #1]
+
+ // ---------------------------------------------------------------- z12-z15: result = x * inv_sum_value
+ fmul z12.h, z12.h, z28.h
+ fmul z13.h, z13.h, z28.h
+ fmul z14.h, z14.h, z28.h
+ fmul z15.h, z15.h, z28.h
+
+ .inst 0xa029a78c // st1h {z12.h-z15.h}, pn9, [x28, x9, LSL #1]
+
+ inch x9, ALL, MUL #4
+ b normalize_body_start%=
+normalize_body_end%=:
+
+ // Loop for processing the leftover part.
+normalize_leftover_start%=:
+ whilelo p1.h, x9, %x[length]
+ b.none normalize_leftover_end%=
+
+ ld1h z12.h, p1/z, [x28, x9, LSL #1] // z12: x
+ fmul z12.h, z12.h, z28.h // z12: result = x * inv_sum_value
+
+ st1h z12.h, p1, [x28, x9, LSL #1]
+
+ inch x9
+ b normalize_leftover_start%=
+normalize_leftover_end%=:
+
+ // ==================================================
+ // 3D loop closing
+ // ==================================================
+
+ add x27, x27, %x[src_stride_1]
+ add x28, x28, %x[dst_stride_1]
+ b loop_1_start%=
+loop_1_end%=:
+
+ add x24, x24, %x[src_stride_2]
+ add x25, x25, %x[dst_stride_2]
+ b loop_2_start%=
+loop_2_end%=:
+
+ add x21, x21, %x[src_stride_3]
+ add x22, x22, %x[dst_stride_3]
+ b loop_3_start%=
+loop_3_end%=:
+
+ .inst 0xd503467f // smstop
+ )"
+ :
+ : [src] "r"(src), [dst] "r"(dst), [beta] "r"(beta), //
+ [shape_1] "r"(shape[1]), [shape_2] "r"(shape[2]), [shape_3] "r"(shape[3]), //
+ [src_stride_1] "r"(src_strides[1]), [src_stride_2] "r"(src_strides[2]),
+ [src_stride_3] "r"(src_strides[3]), //
+ [dst_stride_1] "r"(dst_strides[1]), [dst_stride_2] "r"(dst_strides[2]),
+ [dst_stride_3] "r"(dst_strides[3]), //
+ [length] "r"(shape[0]) //
+ : "cc", "memory", //
+ "p0", "p1", "p2", "p3", "p4", "p5", "p6", "p7", "p9", //
+ "x9", "x10", "x11", "x12", "x13", "x14", //
+ "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28", //
+ "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", //
+ "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", //
+ "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", //
+ "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31" //
+ );
+}
+
+void sme2_fp16_softmax(const ITensor *in,
+ void *const,
+ ITensor *out,
+ const float beta,
+ int axis,
+ const Window &window,
+ const float *lut_ptr)
+{
+ ARM_COMPUTE_UNUSED(lut_ptr);
+ ARM_COMPUTE_UNUSED(axis);
+
+ const auto *src_info = in->info();
+ const auto *dst_info = out->info();
+
+ const auto &full_shape = dst_info->tensor_shape();
+ const auto &src_strides = src_info->strides_in_bytes();
+ const auto &dst_strides = dst_info->strides_in_bytes();
+
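+    // Note: the softmax axis (dimension 0) is always traversed in full, while
+    // dimensions 1-3 are limited to this execution window.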
+ const uintptr_t k_shape[] = {
+ full_shape[0],
+ window.num_iterations(1),
+ window.num_iterations(2),
+ window.num_iterations(3),
+ };
+
+ const uintptr_t k_src_strides[] = {
+ src_strides[0],
+ src_strides[1],
+ src_strides[2],
+ src_strides[3],
+ };
+
+ const uintptr_t k_dst_strides[] = {
+ dst_strides[0],
+ dst_strides[1],
+ dst_strides[2],
+ dst_strides[3],
+ };
+
+ const uintptr_t k_src_offset = window[0].start() * src_strides[0] + //
+ window[1].start() * src_strides[1] + //
+ window[2].start() * src_strides[2] + //
+ window[3].start() * src_strides[3];
+
+ const uintptr_t k_dst_offset = window[0].start() * dst_strides[0] + //
+ window[1].start() * dst_strides[1] + //
+ window[2].start() * dst_strides[2] + //
+ window[3].start() * dst_strides[3];
+
+ const auto *k_src = reinterpret_cast<const float16_t *>(in->buffer() + k_src_offset);
+ auto *k_dst = reinterpret_cast<float16_t *>(out->buffer() + k_dst_offset);
+
+ sme2_f16_softmax_kernel(k_src, k_dst, beta, k_shape, k_src_strides, k_dst_strides);
+}
+
+} // namespace cpu
+} // namespace arm_compute
+
+#endif // ARM_COMPUTE_ENABLE_SME2
diff --git a/src/cpu/kernels/softmax/generic/sme2/fp32.cpp b/src/cpu/kernels/softmax/generic/sme2/fp32.cpp
new file mode 100644
index 0000000000..5e29d51746
--- /dev/null
+++ b/src/cpu/kernels/softmax/generic/sme2/fp32.cpp
@@ -0,0 +1,585 @@
+/*
+ * Copyright (c) 2023-2024 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#ifdef ARM_COMPUTE_ENABLE_SME2
+
+#include "arm_compute/core/ITensor.h"
+#include "arm_compute/core/Window.h"
+
+namespace arm_compute
+{
+namespace cpu
+{
+
+// SoftMax
+//
+// Steps:
+// * Find max: max_value = max(src)
+// * Regularize: dst[i] = exp((src[i] - max_value) * beta)
+// sum_value = sum(dst)
+// * Normalize: dst[i] = dst[i] / sum_value
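+//
+// (A scalar reference sketch of these steps is given in fp16.cpp above; the
+// fp32 version is identical except that it loads and stores float directly.)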
+void sme2_f32_softmax_kernel( //
+ const float *src,
+ float *dst,
+ float beta,
+ const uintptr_t shape[4],
+ const uintptr_t src_strides[4],
+ const uintptr_t dst_strides[4])
+{
+ // Precondition:
+ // * src_strides[0] == sizeof(float)
+ // * dst_strides[0] == sizeof(float)
+
+ __asm__ volatile(
+ R"(
+ .inst 0xd503477f // smstart
+
+ // Registers
+ //
+ // * x9: temporary, index
+ // * x10: temporary, -inf
+ // * x11: temporary, 0
+ // * x12: temporary, 1.0f
+ // * x13: temporary, body_length
+ //
+ // * x20: index_3
+ // * x21: src_3
+ // * x22: dst_3
+ // * x23: index_2
+ // * x24: src_2
+ // * x25: dst_2
+ // * x26: index_1
+ // * x27: src_1
+ // * x28: dst_1
+ //
+ // * z0: c1
+ // * z1: c2
+ // * z2: c3
+ // * z3: c4
+ // * z4: c5
+ // * z5: shift
+ // * z6: inv_ln2
+ // * z7: neg_ln2_hi
+ // * z8: neg_ln2_lo
+ // * z9: min_input
+ // * z10: 23, 0
+ // * z11: max_value
+ // * z12-z15: x, r_hi, r, r2
+ // * z16-z19: max_value, shift, z, scale, poly
+ // * z20-z21: n, p1, p12345
+ // * z22-z23: n, p23, p2345
+ // * z24-z25: p45
+ // * z26: beta
+ // * z28-z31: sum_value
+ //
+ // * za0-za3: sum_value
+ //
+ // * p0: all-true
+ // * p1: left-over predicate
+ // * p4-p7: underflow
+ // * pn9: all-true
+
+ // Prepares all constant values
+
+ ptrue p0.b
+ .inst 0x25207811 // ptrue pn9.b
+
+ mov w9, #0xfff6 // c1: 0x1.ffffecp-1f = 0x3f7ffff6
+ mov w10, #0xfedb // c2: 0x1.fffdb6p-2f = 0x3efffedb
+ mov w11, #0xaf33 // c3: 0x1.555e66p-3f = 0x3e2aaf33
+ mov w12, #0x9f17 // c4: 0x1.573e2ep-5f = 0x3d2b9f17
+ mov w13, #0x2010 // c5: 0x1.0e4020p-7f = 0x3c072010
+
+ movk w9, #0x3f7f, LSL #16 // c1: 0x1.ffffecp-1f = 0x3f7ffff6
+ movk w10, #0x3eff, LSL #16 // c2: 0x1.fffdb6p-2f = 0x3efffedb
+    movk w11, #0x3e2a, LSL #16 // c3: 0x1.555e66p-3f = 0x3e2aaf33
+ movk w12, #0x3d2b, LSL #16 // c4: 0x1.573e2ep-5f = 0x3d2b9f17
+ movk w13, #0x3c07, LSL #16 // c5: 0x1.0e4020p-7f = 0x3c072010
+
+ dup z0.s, w9 // c1.
+ dup z1.s, w10 // c2.
+ dup z2.s, w11 // c3.
+ dup z3.s, w12 // c4.
+ dup z4.s, w13 // c5.
+
+ mov w9, #0x007f // shift: 2^23 + 127 = 0x1.0000fep23f = 0x4b00007f
+ mov w10, #0xaa3b // inv_ln2: 1 / ln(2) = 0x1.715476p+0f = 0x3fb8aa3b
+ mov w11, #0x7200 // neg_ln2_hi: -ln(2) from bits -1 to -19 = -0x1.62e400p-1f = 0xbf317200
+ mov w12, #0xbe8e // neg_ln2_lo: -ln(2) from bits -20 to -42 = -0x1.7f7d1cp-20f = 0xb5bfbe8e
+ mov w13, #0x47ae // min_input (Approximately ln 2^-125): -86.64 = 0xc2ad47ae
+
+ movk w9, #0x4b00, LSL #16 // shift: 2^23 + 127 = 0x1.0000fep23f = 0x4b00007f
+ movk w10, #0x3fb8, LSL #16 // inv_ln2: 1 / ln(2) = 0x1.715476p+0f = 0x3fb8aa3b
+ movk w11, #0xbf31, LSL #16 // neg_ln2_hi: -ln(2) from bits -1 to -19 = -0x1.62e400p-1f = 0xbf317200
+ movk w12, #0xb5bf, LSL #16 // neg_ln2_lo: -ln(2) from bits -20 to -42 = -0x1.7f7d1cp-20f = 0xb5bfbe8e
+ movk w13, #0xc2ad, LSL #16 // min_input (Approximately ln 2^-125): -86.64 = 0xc2ad47ae
+
+ dup z5.s, w9 // shift
+ dup z6.s, w10 // inv_ln2
+ dup z7.s, w11 // neg_ln2_hi
+ dup z8.s, w12 // neg_ln2_lo
+ dup z9.s, w13 // min_input
+
+ dup z26.s, %w[beta] // beta
+
+ mov w10, #0x0000 // -inf: 0xff800000
+ movk w10, #0xff80 // -inf: 0xff800000
+
+ mov w11, #0 // 0
+
+ // ---------------------------------------------------------------- x13: body_length = (length / vl) * vl
+ cntw x13, ALL, MUL #4
+ udiv x9, %x[length], x13
+ mul x13, x13, x9
+
+ // ==================================================
+ // 3D loop opening
+ // ==================================================
+
+ mov x20, %x[shape_3]
+ mov x21, %x[src]
+ mov x22, %x[dst]
+
+loop_3_start%=:
+ // for index_3 in shape_3 downto 1
+ cmp x20, #0
+ b.eq loop_3_end%=
+ sub x20, x20, #1
+
+ mov x23, %x[shape_2]
+ mov x24, x21
+ mov x25, x22
+
+loop_2_start%=:
+ // for index_2 in shape_2 downto 1
+ cmp x23, #0
+ b.eq loop_2_end%=
+ sub x23, x23, #1
+
+ mov x26, %x[shape_1]
+ mov x27, x24
+ mov x28, x25
+
+loop_1_start%=:
+    // for index_1 in shape_1 downto 1
+ cmp x26, #0
+ b.eq loop_1_end%=
+ sub x26, x26, #1
+
+ // ==================================================
+ // Step 1: Find max
+ // ==================================================
+
+ // Loop for processing 4 vectors per iteration.
+ mov x9, #0 // x9: index
+ dup z11.s, w10 // z11: max_value = -inf
+
+ // ---------------------------------------------------------------- z16-z19: max_value = -inf
+ mov z16.d, z11.d
+ mov z17.d, z11.d
+ mov z18.d, z11.d
+ mov z19.d, z11.d
+
+find_max_body_start%=:
+ cmp x9, x13
+ b.eq find_max_body_end%=
+
+ .inst 0xa009c76c // ld1w {z12.s-z15.s}, pn9/z, [x27, x9, LSL #2] // z12-z15: x
+ .inst 0xc1acb910 // fmax {z16.s-z19.s}, {z16.s-z19.s}, {z12.s-z15.s} // z16-z19: max_value = max(max_value, x)
+
+ incw x9, ALL, MUL #4
+ b find_max_body_start%=
+find_max_body_end%=:
+
+ // Loop for processing the leftover part.
+find_max_leftover_start%=:
+ whilelo p1.s, x9, %x[length]
+ b.none find_max_leftover_end%=
+
+ ld1w z12.s, p1/z, [x27, x9, LSL #2] // z12: x
+ fmax z16.s, p1/m, z16.s, z12.s // z16: max_value = max(max_value, x)
+
+ incw x9
+ b find_max_leftover_start%=
+find_max_leftover_end%=:
+
+ // ---------------------------------------------------------------- z16: max_value
+ .inst 0xc1b2b110 // fmax {z16.s-z17.s}, {z16.s-z17.s}, {z18.s-z19.s}
+ fmax z16.s, p0/m, z16.s, z17.s
+ fmaxv s16, p0, z16.s
+
+ // ---------------------------------------------------------------- z11: max_value
+ dup z11.s, z16.s[0]
+
+ // ==================================================
+ // Step 2: Regularize
+ // ==================================================
+
+ .inst 0xc00800ff // zero {za0.s, za1.s, za2.s, za3.s} za0-za3: sum_value
+
+ // Loop for processing 4 vectors per iteration.
+ mov x9, #0 // ---------------------------------------------------- x9: index
+
+regularize_body_start%=:
+ cmp x9, x13
+ b.eq regularize_body_end%=
+
+ // Loads the input data to 4 consecutive registers ---------------- z12-z15: input_data
+ .inst 0xa009c76c // ld1w {z12.s-z15.s}, pn9/z, [x27, x9, LSL #2]
+
+ // ---------------------------------------------------------------- z12-z15: x = input_data - max_value
+ fsub z12.s, z12.s, z11.s
+ fsub z13.s, z13.s, z11.s
+ fsub z14.s, z14.s, z11.s
+ fsub z15.s, z15.s, z11.s
+
+ // ---------------------------------------------------------------- z12-z15: x = (input_data - max_value) * beta
+ fmul z12.s, z12.s, z26.s
+ fmul z13.s, z13.s, z26.s
+ fmul z14.s, z14.s, z26.s
+ fmul z15.s, z15.s, z26.s
+
+ // ---------------------------------------------------------------- z16-z19: shift
+ mov z16.d, z5.d
+ mov z17.d, z5.d
+ mov z18.d, z5.d
+ mov z19.d, z5.d
+
+ // ---------------------------------------------------------------- p4-p7: underflow = x < min_input
+ fcmlt p4.s, p0/z, z12.s, z9.s
+ fcmlt p5.s, p0/z, z13.s, z9.s
+ fcmlt p6.s, p0/z, z14.s, z9.s
+ fcmlt p7.s, p0/z, z15.s, z9.s
+
+ // ---------------------------------------------------------------- z16-z19: z = shift + x * inv_ln2
+ fmla z16.s, p0/m, z12.s, z6.s
+ fmla z17.s, p0/m, z13.s, z6.s
+ fmla z18.s, p0/m, z14.s, z6.s
+ fmla z19.s, p0/m, z15.s, z6.s
+
+ // ---------------------------------------------------------------- z20-z23: n = z - shift
+ fsub z20.s, z16.s, z5.s
+ fsub z21.s, z17.s, z5.s
+ fsub z22.s, z18.s, z5.s
+ fsub z23.s, z19.s, z5.s
+
+ // ---------------------------------------------------------------- z12-z15: r_hi = x + n * neg_ln2_hi
+ fmla z12.s, p0/m, z20.s, z7.s
+ fmla z13.s, p0/m, z21.s, z7.s
+ fmla z14.s, p0/m, z22.s, z7.s
+ fmla z15.s, p0/m, z23.s, z7.s
+
+ // ---------------------------------------------------------------- z12-z15: r = r_hi + n * neg_ln2_lo
+ fmla z12.s, p0/m, z20.s, z8.s
+ fmla z13.s, p0/m, z21.s, z8.s
+ fmla z14.s, p0/m, z22.s, z8.s
+ fmla z15.s, p0/m, z23.s, z8.s
+
+ // ---------------------------------------------------------------- z16-z19: scale = z << 23 (2^n)
+ dup z10.s, #23
+ urshl z16.s, p0/m, z16.s, z10.s
+ urshl z17.s, p0/m, z17.s, z10.s
+ urshl z18.s, p0/m, z18.s, z10.s
+ urshl z19.s, p0/m, z19.s, z10.s
+
+ // Processes the first 2 vectors.
+
+ // ---------------------------------------------------------------- z20-z21: p1 = r * c1
+ fmul z20.s, z12.s, z0.s
+ fmul z21.s, z13.s, z0.s
+
+ // ---------------------------------------------------------------- z22-z23: p23 = c2
+ mov z22.d, z1.d
+ mov z23.d, z1.d
+
+ // ---------------------------------------------------------------- z22-z23: p23 = c2 + r * c3
+ fmla z22.s, p0/m, z12.s, z2.s
+ fmla z23.s, p0/m, z13.s, z2.s
+
+    // ---------------------------------------------------------------- z24-z25: c4
+ mov z24.d, z3.d
+ mov z25.d, z3.d
+
+ // ---------------------------------------------------------------- z24-z25: p45 = c4 + r * c5
+ fmla z24.s, p0/m, z12.s, z4.s
+ fmla z25.s, p0/m, z13.s, z4.s
+
+ // ---------------------------------------------------------------- z12-z13: r2 = r * r
+ fmul z12.s, z12.s, z12.s
+ fmul z13.s, z13.s, z13.s
+
+ // ---------------------------------------------------------------- z22-z23: p2345 = p23 + r2 * p45
+ fmla z22.s, p0/m, z12.s, z24.s
+ fmla z23.s, p0/m, z13.s, z25.s
+
+ // ---------------------------------------------------------------- z20-z21: p12345 = p1 + r2 * p2345
+ fmla z20.s, p0/m, z12.s, z22.s
+ fmla z21.s, p0/m, z13.s, z23.s
+
+ // ---------------------------------------------------------------- z16-z17: poly = scale + p12345 * scale
+ fmla z16.s, p0/m, z20.s, z16.s
+ fmla z17.s, p0/m, z21.s, z17.s
+
+ // Processes the last 2 vectors
+
+ // ---------------------------------------------------------------- z20-z21: p1 = r * c1
+ fmul z20.s, z14.s, z0.s
+ fmul z21.s, z15.s, z0.s
+
+ // ---------------------------------------------------------------- z22-z23: p23 = c2
+ mov z22.d, z1.d
+ mov z23.d, z1.d
+
+ // ---------------------------------------------------------------- z22-z23: p23 = c2 + r * c3
+ fmla z22.s, p0/m, z14.s, z2.s
+ fmla z23.s, p0/m, z15.s, z2.s
+
+    // ---------------------------------------------------------------- z24-z25: c4
+ mov z24.d, z3.d
+ mov z25.d, z3.d
+
+ // ---------------------------------------------------------------- z24-z25: p45 = c4 + r * c5
+ fmla z24.s, p0/m, z14.s, z4.s
+ fmla z25.s, p0/m, z15.s, z4.s
+
+ // ---------------------------------------------------------------- z14-z15: r2 = r * r
+ fmul z14.s, z14.s, z14.s
+ fmul z15.s, z15.s, z15.s
+
+ // ---------------------------------------------------------------- z22-z23: p2345 = p23 + r2 * p45
+ fmla z22.s, p0/m, z14.s, z24.s
+ fmla z23.s, p0/m, z15.s, z25.s
+
+ // ---------------------------------------------------------------- z20-z21: p12345 = p1 + r2 * p2345
+ fmla z20.s, p0/m, z14.s, z22.s
+ fmla z21.s, p0/m, z15.s, z23.s
+
+ // ---------------------------------------------------------------- z18-z19: poly = scale + p12345 * scale
+ fmla z18.s, p0/m, z20.s, z18.s
+ fmla z19.s, p0/m, z21.s, z19.s
+
+ // ---------------------------------------------------------------- z16-z19: poly = underflow ? 0 : poly
+ dup z10.s, #0
+ sel z16.s, p4, z10.s, z16.s
+ sel z17.s, p5, z10.s, z17.s
+ sel z18.s, p6, z10.s, z18.s
+ sel z19.s, p7, z10.s, z19.s
+
+ // Stores 4 consecutive registers to the output
+ .inst 0xa029c790 // st1w {z16.s-z19.s}, pn9, [x28, x9, LSL #2]
+
+ .inst 0xc1a17e00 // fadd za.s[w11, #0, VGx4], {z16.s-z19.s} za0-za3: sum_value = sum_value + poly
+
+ incw x9, ALL, MUL #4
+ b regularize_body_start%=
+regularize_body_end%=:
+
+ // ---------------------------------------------------------------- z28: sum_value
+ .inst 0xc0066c1c // mova {z28.s-z31.s}, za.s[w11, #0, VGx4]
+ fadd z28.s, z28.s, z29.s
+ fadd z30.s, z30.s, z31.s
+ fadd z28.s, z28.s, z30.s
+
+ // Loop for processing the leftover part.
+regularize_leftover_start%=:
+ whilelo p1.s, x9, %x[length]
+ b.none regularize_leftover_end%=
+
+    ld1w z12.s, p1/z, [x27, x9, LSL #2] // z12: input_data
+
+ fsub z12.s, z12.s, z11.s // z12: x = input_data - max_value
+ fmul z12.s, z12.s, z26.s // z12: x = (input_data - max_value) * beta
+
+ mov z16.d, z5.d // z16: shift
+ fcmlt p4.s, p1/z, z12.s, z9.s // p4: underflow = x < min_input
+ fmla z16.s, p1/m, z12.s, z6.s // z16: z = shift + x * inv_ln2
+ fsub z20.s, z16.s, z5.s // z20: n = z - shift
+ fmla z12.s, p1/m, z20.s, z7.s // z12: r_hi = x + n * neg_ln2_hi
+ fmla z12.s, p1/m, z20.s, z8.s // z12: r = r_hi + n * neg_ln2_lo
+ dup z10.s, #23 // z10: 23
+ urshl z16.s, p1/m, z16.s, z10.s // z16: scale = z << 23 (2^n)
+ fmul z20.s, z12.s, z0.s // z20: p1 = r * c1
+ mov z22.d, z1.d // z22: p23 = c2
+ fmla z22.s, p1/m, z12.s, z2.s // z22: p23 = c2 + r * c3
+ mov z24.d, z3.d // z24: c4
+ fmla z24.s, p1/m, z12.s, z4.s // z24: p45 = c4 + r * c5
+ fmul z12.s, z12.s, z12.s // z12: r2 = r * r
+ fmla z22.s, p1/m, z12.s, z24.s // z22: p2345 = p23 + r2 * p45
+ fmla z20.s, p1/m, z12.s, z22.s // z20: p12345 = p1 + r2 * p2345
+ fmla z16.s, p1/m, z20.s, z16.s // z16: poly = scale + p12345 * scale
+ dup z10.s, #0 // z10: 0
+ sel z16.s, p4, z10.s, z16.s // z16: poly = underflow ? 0 : poly
+
+ st1w z16.s, p1, [x28, x9, LSL #2]
+
+ fadd z28.s, p1/m, z28.s, z16.s // z28: sum_value = sum_value + poly
+
+ incw x9
+ b regularize_leftover_start%=
+regularize_leftover_end%=:
+
+ // ==================================================
+ // Step 3: Normalize
+ // ==================================================
+
+ // ---------------------------------------------------------------- z28: inv_sum_value = 1 / sum_value
+ fmov s29, #1.0 // 1.0f
+ faddv s28, p0, z28.s
+ fdiv s28, s29, s28
+ dup z28.s, z28.s[0]
+
+ // Loop for processing 4 vectors per iteration.
+ mov x9, #0 // x9: index
+
+normalize_body_start%=:
+ cmp x9, x13
+ b.eq normalize_body_end%=
+
+ .inst 0xa009c78c // ld1w {z12.s-z15.s}, pn9/z, [x28, x9, LSL #2] // z12-z15: x
+
+ // ---------------------------------------------------------------- z12-z15: result = x * inv_sum_value
+ fmul z12.s, z12.s, z28.s
+ fmul z13.s, z13.s, z28.s
+ fmul z14.s, z14.s, z28.s
+ fmul z15.s, z15.s, z28.s
+
+ .inst 0xa029c78c // st1w {z12.s-z15.s}, pn9, [x28, x9, LSL #2]
+
+ incw x9, ALL, MUL #4
+ b normalize_body_start%=
+normalize_body_end%=:
+
+ // Loop for processing the leftover part.
+normalize_leftover_start%=:
+ whilelo p1.s, x9, %x[length]
+ b.none normalize_leftover_end%=
+
+ ld1w z12.s, p1/z, [x28, x9, LSL #2] // z12: x
+ fmul z12.s, z12.s, z28.s // z12: result = x * inv_sum_value
+
+ st1w z12.s, p1, [x28, x9, LSL #2]
+
+ incw x9
+ b normalize_leftover_start%=
+normalize_leftover_end%=:
+
+ // ==================================================
+ // 3D loop closing
+ // ==================================================
+
+ add x27, x27, %x[src_stride_1]
+ add x28, x28, %x[dst_stride_1]
+ b loop_1_start%=
+loop_1_end%=:
+
+ add x24, x24, %x[src_stride_2]
+ add x25, x25, %x[dst_stride_2]
+ b loop_2_start%=
+loop_2_end%=:
+
+ add x21, x21, %x[src_stride_3]
+ add x22, x22, %x[dst_stride_3]
+ b loop_3_start%=
+loop_3_end%=:
+
+ .inst 0xd503467f // smstop
+ )"
+ :
+ : [src] "r"(src), [dst] "r"(dst), [beta] "r"(beta), //
+ [shape_1] "r"(shape[1]), [shape_2] "r"(shape[2]), [shape_3] "r"(shape[3]), //
+ [src_stride_1] "r"(src_strides[1]), [src_stride_2] "r"(src_strides[2]),
+ [src_stride_3] "r"(src_strides[3]), //
+ [dst_stride_1] "r"(dst_strides[1]), [dst_stride_2] "r"(dst_strides[2]),
+ [dst_stride_3] "r"(dst_strides[3]), //
+ [length] "r"(shape[0]) //
+ : "cc", "memory", //
+          "p0", "p1", "p4", "p5", "p6", "p7", "p9", //
+ "x9", "x10", "x11", "x12", "x13", //
+ "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28", //
+ "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", //
+ "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", //
+ "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", //
+ "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31" //
+ );
+}
+
+void sme2_fp32_softmax(const ITensor *in,
+ void *const,
+ ITensor *out,
+ const float beta,
+ int axis,
+ const Window &window,
+ const float *lut_ptr)
+{
+ ARM_COMPUTE_UNUSED(lut_ptr);
+ ARM_COMPUTE_UNUSED(axis);
+
+ const auto *src_info = in->info();
+ const auto *dst_info = out->info();
+
+ const auto &full_shape = dst_info->tensor_shape();
+ const auto &src_strides = src_info->strides_in_bytes();
+ const auto &dst_strides = dst_info->strides_in_bytes();
+
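+    // The innermost dimension (the softmax axis) is always processed in full,
+    // while dimensions 1-3 are limited to this window's iteration counts.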
+ const uintptr_t k_shape[] = {
+ full_shape[0],
+ window.num_iterations(1),
+ window.num_iterations(2),
+ window.num_iterations(3),
+ };
+
+ const uintptr_t k_src_strides[] = {
+ src_strides[0],
+ src_strides[1],
+ src_strides[2],
+ src_strides[3],
+ };
+
+ const uintptr_t k_dst_strides[] = {
+ dst_strides[0],
+ dst_strides[1],
+ dst_strides[2],
+ dst_strides[3],
+ };
+
+ const uintptr_t k_src_offset = window[0].start() * src_strides[0] + //
+ window[1].start() * src_strides[1] + //
+ window[2].start() * src_strides[2] + //
+ window[3].start() * src_strides[3];
+
+ const uintptr_t k_dst_offset = window[0].start() * dst_strides[0] + //
+ window[1].start() * dst_strides[1] + //
+ window[2].start() * dst_strides[2] + //
+ window[3].start() * dst_strides[3];
+
+ const auto *k_src = reinterpret_cast<const float *>(in->buffer() + k_src_offset);
+ auto *k_dst = reinterpret_cast<float *>(out->buffer() + k_dst_offset);
+
+ sme2_f32_softmax_kernel(k_src, k_dst, beta, k_shape, k_src_strides, k_dst_strides);
+}
+
+} // namespace cpu
+} // namespace arm_compute
+
+#endif // ARM_COMPUTE_ENABLE_SME2
diff --git a/src/cpu/kernels/softmax/generic/sme2/qasymm8.cpp b/src/cpu/kernels/softmax/generic/sme2/qasymm8.cpp
new file mode 100644
index 0000000000..9feb669f7c
--- /dev/null
+++ b/src/cpu/kernels/softmax/generic/sme2/qasymm8.cpp
@@ -0,0 +1,634 @@
+/*
+ * Copyright (c) 2023-2024 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#ifdef ARM_COMPUTE_ENABLE_SME2
+
+#include "arm_compute/core/ITensor.h"
+#include "arm_compute/core/Window.h"
+
+namespace arm_compute
+{
+namespace cpu
+{
+
+// SoftMax
+//
+// Steps:
+// * Find max: max_value = max(src)
+// * Regularize: dst[i] = exp(src[i] - max_value)
+// sum_value = sum(dst)
+// * Normalize: dst[i] = dst[i] / sum_value
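+//
+// For intuition, a scalar sketch of the LUT-based variant implemented below
+// (illustrative only; assumes `lut[i]` holds exp(-scale*beta*i) for
+// i in [0, 255], where scale is the input quantization scale):
+//
+//   uint8_t max_value = 0;
+//   for (size_t i = 0; i < length; ++i)
+//       max_value = std::max(max_value, src[i]);
+//   float sum = 0.f;
+//   for (size_t i = 0; i < length; ++i)
+//   {
+//       tmp[i] = lut[max_value - src[i]]; // exp applied via table look-up
+//       sum += tmp[i];
+//   }
+//   const float inv_sum = 256.f / sum; // 256 maps [0,1] onto the 8-bit range
+//   for (size_t i = 0; i < length; ++i)
+//       dst[i] = static_cast<uint8_t>(std::min(tmp[i] * inv_sum, 255.f));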
+void sme2_qasymm8_softmax_kernel_512VL( //
+ const uint8_t *src,
+ uint8_t *dst,
+ float beta,
+ const uintptr_t shape[4],
+ const uintptr_t src_strides[4],
+ const uintptr_t dst_strides[4],
+ const float *lut,
+ float *tmp)
+{
+ // Precondition:
+ // * src_strides[0] == sizeof(uint8_t)
+ // * dst_strides[0] == sizeof(uint8_t)
+ // * tmp_strides[0] == sizeof(float)
+
+ __asm__ volatile(
+ R"(
+ .inst 0xd503477f // smstart
+
+ // Registers
+ //
+ // * x1: Loop index
+ // * x2: LUT index
+ // * x13: temporary, body_length
+ //
+ // * x20: index_3
+ // * x21: src_3
+ // * x22: dst_3
+ // * x23: index_2
+ // * x24: src_2
+ // * x25: dst_2
+ // * x26: index_1
+ // * x27: src_1
+ // * x28: dst_1
+ // * x29: tmp
+ //
+ //
+ // * p0: all-true
+ // * p1: predicate for QASYMM8 values
+ // * p2: predicate 0 for FP32 values (first quarter of expanded/unpacked p1)
+ // * p3: predicate 1 for FP32 values (second quarter of expanded/unpacked p1)
+ // * p4: predicate 2 for FP32 values (third quarter of expanded/unpacked p1)
+ // * p5: predicate 3 for FP32 values (fourth quarter of expanded/unpacked p1)
+ // * pn9: all-true for 32 bit values
+ // * pn8: all-true for 8-bit values
+ //
+ // * z0-z15: the 256 LUT values of exp(-scale*beta*x) for x in QASYMM8, stored as FP32 values
+
+ // Prepares all constant values
+
+ ptrue p0.b
+ .inst 0x25a07811 // ptrue pn9.s
+ .inst 0x25207810 // ptrue pn8.b
+
+ // ---------------------------------------------------------------- x13: body_length = (length / vl) * vl
+ cntb x13, ALL, MUL #4
+ udiv x9, %x[length], x13
+ mul x13, x13, x9
+
+ // ==================================================
+ // 3D loop opening
+ // ==================================================
+
+ mov x20, %x[shape_3]
+ mov x21, %x[src]
+ mov x22, %x[dst]
+ mov x19, %x[lut]
+ mov x29, %x[tmp]
+
+ // Load the LUT to the register file.
+ mov x2, %x[lut]
+ .inst 0xa040c440 //ld1w { z0.s - z3.s }, pn9/z, [x2]
+ add x2, x2, #256
+ .inst 0xa040c444 //ld1w { z4.s - z7.s }, pn9/z, [x2]
+ add x2, x2, #256
+ .inst 0xa040c448 //ld1w { z8.s - z11.s }, pn9/z, [x2]
+ add x2, x2, #256
+ .inst 0xa040c44c //ld1w { z12.s - z15.s }, pn9/z, [x2]
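+
+ // At 512-bit VL each Z register holds 16 FP32 values, so the 256-entry
+ // LUT occupies exactly z0-z15; the TBX cascade below walks it in
+ // 16-entry windows, subtracting 16 from the indices between steps.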
+
+
+loop_3_start%=:
+ // for index_3 in shape_3 downto 1
+ cmp x20, #0
+ b.eq loop_3_end%=
+ sub x20, x20, #1
+
+ mov x23, %x[shape_2]
+ mov x24, x21
+ mov x25, x22
+
+loop_2_start%=:
+ // for index_2 in shape_2 downto 1
+ cmp x23, #0
+ b.eq loop_2_end%=
+ sub x23, x23, #1
+
+ mov x26, %x[shape_1]
+ mov x27, x24
+ mov x28, x25
+
+loop_1_start%=:
+ // for index_1 in shape_1 downto 1
+ cmp x26, #0
+ b.eq loop_1_end%=
+ sub x26, x26, #1
+
+ // ==================================================
+ // Step 1: Find max
+ // ==================================================
+ // z16-z19 = minimum QASYMM8 value (0), used as the initial value for the max reduction.
+ dup z16.b, #0
+ dup z17.b, #0
+ dup z18.b, #0
+ dup z19.b, #0
+ mov x1, #0 // x1: index
+find_max_body_start%=:
+ cmp x1, x13
+ b.eq find_max_body_end%=
+ .inst 0xa0018374 // ld1b { z20.b - z23.b }, pn8/z, [x27, x1] // z20-z23: x
+ .inst 0xc134b811 // umax { z16.b - z19.b }, { z16.b - z19.b }, { z20.b - z23.b } // z16-z19: max_value = max(max_value, x)
+ add x1, x1, #256 // Advance the index by 256 bytes: four 512-bit Z registers = 2048 bits = 256 8-bit values.
+ b find_max_body_start%=
+find_max_body_end%=:
+
+ // Loop for processing the leftover part.
+find_max_leftover_start%=:
+ whilelo p1.b, x1, %x[length]
+ b.none find_max_leftover_end%=
+
+ ld1b z30.b, p1/z, [x27, x1] // z30: x
+ umax z16.b, p1/m, z16.b, z30.b // z16: max_value = max(max_value, x)
+
+ add x1, x1, #64
+
+ b find_max_leftover_start%=
+find_max_leftover_end%=:
+
+ .inst 0xc132b011 // umax { z16.b, z17.b }, { z16.b, z17.b }, { z18.b, z19.b }
+ umax z16.b, p0/m, z16.b, z17.b
+ umaxv b16, p0, z16.b // Reduction unsigned max operation to get maximum_value
+ dup z16.b, z16.b[0]
+ uunpklo z16.h, z16.b // Using unpack instructions to align the max value with the FP32 entries in the LUT for use in the TBX instruction
+ uunpklo z16.s, z16.h
+
+ mov x1, #0 // x1: reset index
+ dup z25.s, #0 // z25: sum_value = 0
+
+regularize_start%=:
+ whilelo p1.b, x1, %x[length]
+ b.none regularize_end%=
+
+ // p2-p5 are - together - the 32-bit version of p1, the instructions below unpack p1 into those four predicate registers to allow for the 32-bit loads below to be correctly predicated
+ punpklo p2.h, p1.b
+ punpkhi p4.h, p1.b
+
+ punpkhi p3.h, p2.b
+ punpklo p2.h, p2.b
+
+ punpkhi p5.h, p4.b
+ punpklo p4.h, p4.b
+
+ ld1b z17.b, p1/z, [x27, x1] //z17: input data
+
+ uunpklo z18.h, z17.b //Using unpack instructions to align the input QASYMM8 values with the FP32 entries in the LUT for use in the TBX instruction
+ uunpkhi z19.h, z17.b
+
+ uunpklo z17.s, z18.h // z17 = low low input QASYMM8 values
+ uunpkhi z18.s, z18.h // z18 = low high input QASYMM8 values
+
+ uunpkhi z20.s, z19.h // z20 = high high input QASYMM8 values
+ uunpklo z19.s, z19.h // z19 = high low input QASYMM8 values
+
+ sub z17.s, z16.s, z17.s // z17: x = max_value - input_data
+ sub z18.s, z16.s, z18.s // z18: x = max_value - input_data
+ sub z19.s, z16.s, z19.s // z19: x = max_value - input_data
+ sub z20.s, z16.s, z20.s // z20: x = max_value - input_data
+
+ tbx z21.s, z0.s, z17.s // Look-up entries 0-15 in the LUT.
+ tbx z22.s, z0.s, z18.s
+ tbx z23.s, z0.s, z19.s
+ tbx z24.s, z0.s, z20.s
+
+ sub z17.s, z17.s, #16
+ sub z18.s, z18.s, #16
+ sub z19.s, z19.s, #16
+ sub z20.s, z20.s, #16
+
+ tbx z21.s, z1.s, z17.s // Look-up entries 16-31 in the LUT.
+ tbx z22.s, z1.s, z18.s
+ tbx z23.s, z1.s, z19.s
+ tbx z24.s, z1.s, z20.s
+
+ sub z17.s, z17.s, #16
+ sub z18.s, z18.s, #16
+ sub z19.s, z19.s, #16
+ sub z20.s, z20.s, #16
+
+ tbx z21.s, z2.s, z17.s // Look-up entries 32-47 in the LUT.
+ tbx z22.s, z2.s, z18.s
+ tbx z23.s, z2.s, z19.s
+ tbx z24.s, z2.s, z20.s
+
+ sub z17.s, z17.s, #16
+ sub z18.s, z18.s, #16
+ sub z19.s, z19.s, #16
+ sub z20.s, z20.s, #16
+
+ tbx z21.s, z3.s, z17.s // Look-up entries 48-63 in the LUT.
+ tbx z22.s, z3.s, z18.s
+ tbx z23.s, z3.s, z19.s
+ tbx z24.s, z3.s, z20.s
+
+ sub z17.s, z17.s, #16
+ sub z18.s, z18.s, #16
+ sub z19.s, z19.s, #16
+ sub z20.s, z20.s, #16
+
+ tbx z21.s, z4.s, z17.s // Look-up entries 64-79 in the LUT.
+ tbx z22.s, z4.s, z18.s
+ tbx z23.s, z4.s, z19.s
+ tbx z24.s, z4.s, z20.s
+
+ sub z17.s, z17.s, #16
+ sub z18.s, z18.s, #16
+ sub z19.s, z19.s, #16
+ sub z20.s, z20.s, #16
+
+ tbx z21.s, z5.s, z17.s // Look-up entries 80-95 in the LUT.
+ tbx z22.s, z5.s, z18.s
+ tbx z23.s, z5.s, z19.s
+ tbx z24.s, z5.s, z20.s
+
+ sub z17.s, z17.s, #16
+ sub z18.s, z18.s, #16
+ sub z19.s, z19.s, #16
+ sub z20.s, z20.s, #16
+
+ tbx z21.s, z6.s, z17.s // Look-up entries 96-111 in the LUT.
+ tbx z22.s, z6.s, z18.s
+ tbx z23.s, z6.s, z19.s
+ tbx z24.s, z6.s, z20.s
+
+ sub z17.s, z17.s, #16
+ sub z18.s, z18.s, #16
+ sub z19.s, z19.s, #16
+ sub z20.s, z20.s, #16
+
+ tbx z21.s, z7.s, z17.s // Look-up entries 112-127 in the LUT.
+ tbx z22.s, z7.s, z18.s
+ tbx z23.s, z7.s, z19.s
+ tbx z24.s, z7.s, z20.s
+
+ sub z17.s, z17.s, #16
+ sub z18.s, z18.s, #16
+ sub z19.s, z19.s, #16
+ sub z20.s, z20.s, #16
+
+ tbx z21.s, z8.s, z17.s // Look-up entries 128-143 in the LUT.
+ tbx z22.s, z8.s, z18.s
+ tbx z23.s, z8.s, z19.s
+ tbx z24.s, z8.s, z20.s
+
+ sub z17.s, z17.s, #16
+ sub z18.s, z18.s, #16
+ sub z19.s, z19.s, #16
+ sub z20.s, z20.s, #16
+
+ tbx z21.s, z9.s, z17.s // Look-up entries 144-159 in the LUT.
+ tbx z22.s, z9.s, z18.s
+ tbx z23.s, z9.s, z19.s
+ tbx z24.s, z9.s, z20.s
+
+ sub z17.s, z17.s, #16
+ sub z18.s, z18.s, #16
+ sub z19.s, z19.s, #16
+ sub z20.s, z20.s, #16
+
+ tbx z21.s, z10.s, z17.s // Look-up entries 160-175 in the LUT.
+ tbx z22.s, z10.s, z18.s
+ tbx z23.s, z10.s, z19.s
+ tbx z24.s, z10.s, z20.s
+
+ sub z17.s, z17.s, #16
+ sub z18.s, z18.s, #16
+ sub z19.s, z19.s, #16
+ sub z20.s, z20.s, #16
+
+ tbx z21.s, z11.s, z17.s // Look-up entries 176-191 in the LUT.
+ tbx z22.s, z11.s, z18.s
+ tbx z23.s, z11.s, z19.s
+ tbx z24.s, z11.s, z20.s
+
+ sub z17.s, z17.s, #16
+ sub z18.s, z18.s, #16
+ sub z19.s, z19.s, #16
+ sub z20.s, z20.s, #16
+
+ tbx z21.s, z12.s, z17.s // Look-up entries 192-207 in the LUT.
+ tbx z22.s, z12.s, z18.s
+ tbx z23.s, z12.s, z19.s
+ tbx z24.s, z12.s, z20.s
+
+ sub z17.s, z17.s, #16
+ sub z18.s, z18.s, #16
+ sub z19.s, z19.s, #16
+ sub z20.s, z20.s, #16
+
+ tbx z21.s, z13.s, z17.s // Look-up entries 208-223 in the LUT.
+ tbx z22.s, z13.s, z18.s
+ tbx z23.s, z13.s, z19.s
+ tbx z24.s, z13.s, z20.s
+
+ sub z17.s, z17.s, #16
+ sub z18.s, z18.s, #16
+ sub z19.s, z19.s, #16
+ sub z20.s, z20.s, #16
+
+ tbx z21.s, z14.s, z17.s // Look-up entries 224-239 in the LUT.
+ tbx z22.s, z14.s, z18.s
+ tbx z23.s, z14.s, z19.s
+ tbx z24.s, z14.s, z20.s
+
+ sub z17.s, z17.s, #16
+ sub z18.s, z18.s, #16
+ sub z19.s, z19.s, #16
+ sub z20.s, z20.s, #16
+
+ tbx z21.s, z15.s, z17.s // Look-up entries 240-255 in the LUT.
+ tbx z22.s, z15.s, z18.s
+ tbx z23.s, z15.s, z19.s
+ tbx z24.s, z15.s, z20.s
+
+
+ st1w z21.s, p2, [x29, x1, LSL #2] // z21: store exp(-scale*beta*x) into the tmp tensor
+ fadd z25.s, p2/m, z25.s, z21.s
+ add x1, x1, #16
+
+ st1w z22.s, p3, [x29, x1, LSL #2] // z22: store exp(-scale*beta*x) into the tmp tensor
+ fadd z25.s, p3/m, z25.s, z22.s
+ add x1, x1, #16
+
+ st1w z23.s, p4, [x29, x1, LSL #2] // z23: store exp(-scale*beta*x) into the tmp tensor
+ fadd z25.s, p4/m, z25.s, z23.s
+ add x1, x1, #16
+
+ st1w z24.s, p5, [x29, x1, LSL #2] // z24: store exp(-scale*beta*x) into the tmp tensor
+ fadd z25.s, p5/m, z25.s, z24.s
+ add x1, x1, #16
+
+ b regularize_start%=
+regularize_end%=:
+
+ mov w9, 0x0000
+ movk w9, 0x4380, LSL #16 // w9: 256.f, used to scale the [0,1] softmax outputs to the [0,255] QASYMM8 range via multiplication by 256/sum
+ dup z29.s, w9
+ faddv s25, p0, z25.s
+ fdiv s25, s29, s25
+ dup z25.s, z25.s[0] // z25: 256.f/sum. 256 is needed to get the full range and 1/sum is part of softmax.
+
+ // ==================================================
+ // Step 3: Normalize
+ // ==================================================
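+
+ // From here on each output byte is effectively
+ // dst[i] = saturate_u8(trunc(tmp[i] * 256.f / sum)),
+ // with the fmul, fcvtzu and uqcvt steps below implementing that expression.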
+ mov x1, #0
+normalize_body_start%=:
+ cmp x1, x13
+ b.eq normalize_body_end%=
+
+ mov x2, x1 // Preserve the index into x2 for the final store to dst.
+ .inst 0xa001c7b0 // ld1w { z16.s - z19.s }, pn9/z, [x29, x1, lsl #2]
+ add x1, x1, #64
+ .inst 0xa001c7b4 // ld1w { z20.s - z23.s }, pn9/z, [x29, x1, lsl #2]
+ add x1, x1, #64
+
+ // z16-z23: effectively divides exp(-scale*beta*x) by the sum of the exponentials for the current row and multiplies by 256.
+ fmul z16.s, z25.s, z16.s
+ fmul z17.s, z25.s, z17.s
+ fmul z18.s, z25.s, z18.s
+ fmul z19.s, z25.s, z19.s
+ fmul z20.s, z25.s, z20.s
+ fmul z21.s, z25.s, z21.s
+ fmul z22.s, z25.s, z22.s
+ fmul z23.s, z25.s, z23.s
+
+ // z16-z23: convert the FP32 values from the tmp tensor to uint32.
+ fcvtzu z16.s, p0/m, z16.s
+ fcvtzu z17.s, p0/m, z17.s
+ fcvtzu z18.s, p0/m, z18.s
+ fcvtzu z19.s, p0/m, z19.s
+ fcvtzu z20.s, p0/m, z20.s
+ fcvtzu z21.s, p0/m, z21.s
+ fcvtzu z22.s, p0/m, z22.s
+ fcvtzu z23.s, p0/m, z23.s
+
+ // z16-z17: narrow the uint32 values into uint8 and saturate them.
+ .inst 0xc133e230 // uqcvt z16.b, { z16.s - z19.s }
+ .inst 0xc133e2b1 // uqcvt z17.b, { z20.s - z23.s }
+
+ dup z20.s, z25.s[0] // Preserve the scale factor in z20, as z25 will be overwritten by the load below.
+
+ .inst 0xa001c7b8 // ld1w { z24.s - z27.s }, pn9/z, [x29, x1, lsl #2]
+ add x1, x1, #64
+ .inst 0xa001c7bc // ld1w { z28.s - z31.s }, pn9/z, [x29, x1, lsl #2]
+ add x1, x1, #64
+
+ // z24-z31: effectively divides exp(-scale*beta*x) by the sum of the exponentials for the current row and multiplies by 256.
+ fmul z24.s, z20.s, z24.s
+ fmul z25.s, z20.s, z25.s
+ fmul z26.s, z20.s, z26.s
+ fmul z27.s, z20.s, z27.s
+ fmul z28.s, z20.s, z28.s
+ fmul z29.s, z20.s, z29.s
+ fmul z30.s, z20.s, z30.s
+ fmul z31.s, z20.s, z31.s
+
+ // z24-z31: convert the FP32 values from the tmp tensor to uint32.
+ fcvtzu z24.s, p0/m, z24.s
+ fcvtzu z25.s, p0/m, z25.s
+ fcvtzu z26.s, p0/m, z26.s
+ fcvtzu z27.s, p0/m, z27.s
+ fcvtzu z28.s, p0/m, z28.s
+ fcvtzu z29.s, p0/m, z29.s
+ fcvtzu z30.s, p0/m, z30.s
+ fcvtzu z31.s, p0/m, z31.s
+
+ // z18-z19: narrow the uint32 values into uint8 and saturate them.
+ .inst 0xc133e332 // uqcvt z18.b, { z24.s - z27.s }
+ .inst 0xc133e3b3 // uqcvt z19.b, { z28.s - z31.s }
+
+ .inst 0xa0228390 // st1b { z16.b - z19.b }, pn8, [x28, x2]
+
+ dup z25.s, z20.s[0] // Restore the scale factor to z25: z20 is overwritten each iteration, and the leftover loop below reads z25.
+
+ b normalize_body_start%=
+normalize_body_end%=:
+
+normalize_leftover_start%=:
+ whilelo p1.b, x1, %x[length]
+ b.none normalize_leftover_end%=
+
+ // p2-p5 are - together - the 32-bit version of p1, the instructions below unpack p1 into those four predicate registers to allow for the 32-bit loads below to be correctly predicated
+ punpklo p2.h, p1.b
+ punpkhi p4.h, p1.b
+
+ punpkhi p3.h, p2.b
+ punpklo p2.h, p2.b
+
+ punpkhi p5.h, p4.b
+ punpklo p4.h, p4.b
+
+ mov x2, x1 // Preserve the index into x2 for the final store to dst.
+
+ // z20-z23: load exp(-scale*beta*x) from the tmp tensor
+ ld1w z20.s, p2/z, [x29, x1, LSL #2]
+ add x1, x1, #16
+
+ ld1w z21.s, p3/z, [x29, x1, LSL #2]
+ add x1, x1, #16
+
+ ld1w z22.s, p4/z, [x29, x1, LSL #2]
+ add x1, x1, #16
+
+ ld1w z23.s, p5/z, [x29, x1, LSL #2]
+ add x1, x1, #16
+
+ // z20-z23: effectively divides exp(-scale*beta*x) by the sum of the exponentials for the current row and multiplies by 256.
+ fmul z20.s, z25.s, z20.s
+ fmul z21.s, z25.s, z21.s
+ fmul z22.s, z25.s, z22.s
+ fmul z23.s, z25.s, z23.s
+
+ // z20-z23: convert the FP32 values from the tmp tensor to uint32.
+ fcvtzu z20.s, p0/m, z20.s
+ fcvtzu z21.s, p0/m, z21.s
+ fcvtzu z22.s, p0/m, z22.s
+ fcvtzu z23.s, p0/m, z23.s
+
+ .inst 0xc133e2b3 // uqcvt z19.b, { z20.s - z23.s } // narrow the uint32 values into uint8 and saturate them into z19.
+
+ st1b z19.b, p1, [x28, x2]
+
+ b normalize_leftover_start%=
+normalize_leftover_end%=:
+ // ==================================================
+ // 3D loop closing
+ // ==================================================
+ add x27, x27, %x[src_stride_1]
+ add x28, x28, %x[dst_stride_1]
+ b loop_1_start%=
+loop_1_end%=:
+
+ add x24, x24, %x[src_stride_2]
+ add x25, x25, %x[dst_stride_2]
+ b loop_2_start%=
+loop_2_end%=:
+
+ add x21, x21, %x[src_stride_3]
+ add x22, x22, %x[dst_stride_3]
+ b loop_3_start%=
+loop_3_end%=:
+ .inst 0xd503467f // smstop
+ )"
+ :
+ : [src] "r"(src), [tmp] "r"(tmp), [dst] "r"(dst), [beta] "r"(beta), [lut] "r"(lut), //
+ [shape_1] "r"(shape[1]), [shape_2] "r"(shape[2]), [shape_3] "r"(shape[3]), //
+ [src_stride_1] "r"(src_strides[1]), [src_stride_2] "r"(src_strides[2]),
+ [src_stride_3] "r"(src_strides[3]), //
+ [dst_stride_1] "r"(dst_strides[1]), [dst_stride_2] "r"(dst_strides[2]),
+ [dst_stride_3] "r"(dst_strides[3]), //
+ [length] "r"(shape[0]) //
+ : "cc", "memory", //
+ "p0", "p1", "p2", "p3", "p4", //
+ "x2", "x9", "x13", //
+ "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28", "x19", //
+ "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", //
+ "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", //
+ "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", //
+ "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31" //
+ );
+}
+
+void sme2_qasymm8_softmax_lut_512VL(const ITensor *in,
+ void *const tmp,
+ ITensor *out,
+ const float beta,
+ int axis,
+ const Window &window,
+ const float *lut_ptr)
+{
+ ARM_COMPUTE_UNUSED(axis);
+
+ const auto *src_info = in->info();
+ const auto *dst_info = out->info();
+
+ const auto &full_shape = dst_info->tensor_shape();
+ const auto &src_strides = src_info->strides_in_bytes();
+ const auto &dst_strides = dst_info->strides_in_bytes();
+ Strides tmp_strides;
+
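+    // tmp holds FP32 values, so each stride is 4x the corresponding 8-bit src stride.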
+ tmp_strides[0] = src_strides[0] * 4;
+ tmp_strides[1] = src_strides[1] * 4;
+ tmp_strides[2] = src_strides[2] * 4;
+ tmp_strides[3] = src_strides[3] * 4;
+
+ const uintptr_t k_shape[] = {
+ full_shape[0],
+ window.num_iterations(1),
+ window.num_iterations(2),
+ window.num_iterations(3),
+ };
+
+ const uintptr_t k_src_strides[] = {
+ src_strides[0],
+ src_strides[1],
+ src_strides[2],
+ src_strides[3],
+ };
+
+ const uintptr_t k_dst_strides[] = {
+ dst_strides[0],
+ dst_strides[1],
+ dst_strides[2],
+ dst_strides[3],
+ };
+
+ const uintptr_t k_src_offset = window[0].start() * src_strides[0] + //
+ window[1].start() * src_strides[1] + //
+ window[2].start() * src_strides[2] + //
+ window[3].start() * src_strides[3];
+
+ const uintptr_t k_dst_offset = window[0].start() * dst_strides[0] + //
+ window[1].start() * dst_strides[1] + //
+ window[2].start() * dst_strides[2] + //
+ window[3].start() * dst_strides[3];
+
+ const uintptr_t k_tmp_offset = window[0].start() * tmp_strides[0] + //
+ window[1].start() * tmp_strides[1] + //
+ window[2].start() * tmp_strides[2] + //
+ window[3].start() * tmp_strides[3];
+
+ const auto *k_src = reinterpret_cast<const uint8_t *>(in->buffer() + k_src_offset);
+    // k_tmp_offset is in bytes (tmp_strides are byte strides), so apply it on a byte pointer.
+    auto *k_tmp = reinterpret_cast<float *>(static_cast<uint8_t *>(tmp) + k_tmp_offset);
+ auto *k_dst = reinterpret_cast<uint8_t *>(out->buffer() + k_dst_offset);
+
+ sme2_qasymm8_softmax_kernel_512VL(k_src, k_dst, beta, k_shape, k_src_strides, k_dst_strides, lut_ptr, k_tmp);
+}
+
+} // namespace cpu
+} // namespace arm_compute
+
+#endif // ARM_COMPUTE_ENABLE_SME2
diff --git a/src/cpu/kernels/softmax/generic/sme2/qasymm8_signed.cpp b/src/cpu/kernels/softmax/generic/sme2/qasymm8_signed.cpp
new file mode 100644
index 0000000000..14c0f6c327
--- /dev/null
+++ b/src/cpu/kernels/softmax/generic/sme2/qasymm8_signed.cpp
@@ -0,0 +1,655 @@
+/*
+ * Copyright (c) 2023-2024 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#ifdef ARM_COMPUTE_ENABLE_SME2
+
+#include "arm_compute/core/ITensor.h"
+#include "arm_compute/core/Window.h"
+
+namespace arm_compute
+{
+namespace cpu
+{
+
+// SoftMax
+//
+// Steps:
+// * Find max: max_value = max(src)
+// * Regularize: dst[i] = exp(src[i] - max_value)
+// sum_value = sum(dst)
+// * Normalize: dst[i] = dst[i] / sum_value
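+//
+// This mirrors the unsigned kernel in qasymm8.cpp; the signed-specific
+// differences are that the inputs are sign-extended, the LUT indices are
+// biased by +128 before the table walk, and 128.f is subtracted after
+// normalization to land in the [-128, 127] QASYMM8_SIGNED range.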
+void sme2_qasymm8_signed_softmax_kernel_512VL( //
+ const int8_t *src,
+ int8_t *dst,
+ float beta,
+ const uintptr_t shape[4],
+ const uintptr_t src_strides[4],
+ const uintptr_t dst_strides[4],
+ const float *lut,
+ float *tmp)
+{
+ // Precondition:
+ // * src_strides[0] == sizeof(int8_t)
+ // * dst_strides[0] == sizeof(int8_t)
+ // * tmp_strides[0] == sizeof(float)
+
+ __asm__ volatile(
+ R"(
+ .inst 0xd503477f // smstart
+
+ // For register list explanation refer to qasymm8.cpp.
+
+ // Prepares all constant values
+
+ ptrue p0.b
+ .inst 0x25a07811 // ptrue pn9.s
+ .inst 0x25207810 // ptrue pn8.b
+
+ // ---------------------------------------------------------------- x13: body_length = (length / vl) * vl
+ cntb x13, ALL, MUL #4
+ udiv x9, %x[length], x13
+ mul x13, x13, x9
+
+ // ==================================================
+ // 3D loop opening
+ // ==================================================
+
+ mov x20, %x[shape_3]
+ mov x21, %x[src]
+ mov x22, %x[dst]
+ mov x19, %x[lut]
+ mov x29, %x[tmp]
+
+ // Load the LUT to the register file.
+ mov x2, %x[lut]
+ .inst 0xa040c440 //ld1w { z0.s - z3.s }, pn9/z, [x2]
+ add x2, x2, #256
+ .inst 0xa040c444 //ld1w { z4.s - z7.s }, pn9/z, [x2]
+ add x2, x2, #256
+ .inst 0xa040c448 //ld1w { z8.s - z11.s }, pn9/z, [x2]
+ add x2, x2, #256
+ .inst 0xa040c44c //ld1w { z12.s - z15.s }, pn9/z, [x2]
+
+
+loop_3_start%=:
+ // for index_3 in shape_3 downto 1
+ cmp x20, #0
+ b.eq loop_3_end%=
+ sub x20, x20, #1
+
+ mov x23, %x[shape_2]
+ mov x24, x21
+ mov x25, x22
+
+loop_2_start%=:
+ // for index_2 in shape_2 downto 1
+ cmp x23, #0
+ b.eq loop_2_end%=
+ sub x23, x23, #1
+
+ mov x26, %x[shape_1]
+ mov x27, x24
+ mov x28, x25
+
+loop_1_start%=:
+ // for index_1 in shape_1 downto 1
+ cmp x26, #0
+ b.eq loop_1_end%=
+ sub x26, x26, #1
+
+ // ==================================================
+ // Step 1: Find max
+ // ==================================================
+ // z16-z19 = minimum QASYMM8_SIGNED value (-128), used as the initial value for the max reduction.
+ dup z16.b, #0x80
+ dup z17.b, #0x80
+ dup z18.b, #0x80
+ dup z19.b, #0x80
+
+ mov x1, #0 // x1: index
+find_max_body_start%=:
+ cmp x1, x13
+ b.eq find_max_body_end%=
+ .inst 0xa0018374 // ld1b { z20.b - z23.b }, pn8/z, [x27, x1] // z20-z23: x
+ .inst 0xc134b810 // smax { z16.b - z19.b }, { z16.b - z19.b }, { z20.b - z23.b } // z16-z19: max_value = max(max_value, x)
+ add x1, x1, #256 // Advance the index by 256 bytes: four 512-bit Z registers = 2048 bits = 256 8-bit values.
+ b find_max_body_start%=
+find_max_body_end%=:
+
+ // Loop for processing the leftover part.
+find_max_leftover_start%=:
+ whilelo p1.b, x1, %x[length]
+ b.none find_max_leftover_end%=
+
+ ld1b z30.b, p1/z, [x27, x1] // z30: x
+ smax z16.b, p1/m, z16.b, z30.b // z16: max_value = max(max_value, x)
+
+ add x1, x1, #64
+
+ b find_max_leftover_start%=
+find_max_leftover_end%=:
+ .inst 0xc132b010 // smax { z16.b, z17.b }, { z16.b, z17.b }, { z18.b, z19.b }
+ smax z16.b, p0/m, z16.b, z17.b
+ smaxv b16, p0, z16.b // Reduction signed max operation to get maximum_value
+ mov z16.b, b16 // z16: duplicated max_value for current row
+
+ sunpklo z16.h, z16.b // Using unpack instructions to align the max value with the FP32 entries in the LUT for use in the TBX instruction
+ sunpklo z16.s, z16.h
+
+ mov x1, #0 // reset index
+ dup z25.s, #0
+
+
+regularize_start%=:
+ whilelo p1.b, x1, %x[length]
+ b.none regularize_end%=
+
+ mov w9, 0xFF80
+ movk w9, 0xFFFF, LSL #16 // w9: 0xFFFFFF80 = -128 as int32, the minimum QASYMM8_SIGNED value, used to initialize the registers below
+ dup z17.s, w9
+ dup z18.s, w9
+ dup z19.s, w9
+ dup z20.s, w9
+
+ dup z21.s, #0x0
+ dup z22.s, #0x0
+ dup z23.s, #0x0
+ dup z24.s, #0x0
+
+ // p2-p5 are - together - the 32-bit version of p1, the instructions below unpack p1 into those four predicate registers to allow for the 32-bit loads below to be correctly predicated
+ punpklo p2.h, p1.b
+ punpkhi p4.h, p1.b
+
+ punpkhi p3.h, p2.b
+ punpklo p2.h, p2.b
+
+ punpkhi p5.h, p4.b
+ punpklo p4.h, p4.b
+
+ ld1b z17.b, p1/z, [x27, x1] //z17: input data
+
+ sunpklo z18.h, z17.b // Using unpack instructions to align the input QASYMM8_SIGNED values with the FP32 entries in the LUT for use in the TBX instruction
+ sunpkhi z19.h, z17.b
+
+ sunpklo z17.s, z18.h // z17 = low low input QASYMM8_SIGNED values
+ sunpkhi z18.s, z18.h // z18 = low high input QASYMM8_SIGNED values
+
+ sunpkhi z20.s, z19.h // z20 = high high input QASYMM8_SIGNED values
+ sunpklo z19.s, z19.h // z19 = high low input QASYMM8_SIGNED values
+
+ sub z17.s, z16.s, z17.s // z17: x = max_value - input_data
+ sub z18.s, z16.s, z18.s // z18: x = max_value - input_data
+ sub z19.s, z16.s, z19.s // z19: x = max_value - input_data
+ sub z20.s, z16.s, z20.s // z20: x = max_value - input_data
+
+ add z17.s, z17.s, #128
+ add z18.s, z18.s, #128
+ add z19.s, z19.s, #128
+ add z20.s, z20.s, #128
+
+ tbx z21.s, z0.s, z17.s // Look-up entries 0-15 in the LUT.
+ tbx z22.s, z0.s, z18.s
+ tbx z23.s, z0.s, z19.s
+ tbx z24.s, z0.s, z20.s
+
+ sub z17.s, z17.s, #16
+ sub z18.s, z18.s, #16
+ sub z19.s, z19.s, #16
+ sub z20.s, z20.s, #16
+
+ tbx z21.s, z1.s, z17.s // Look-up entries 16-31 in the LUT.
+ tbx z22.s, z1.s, z18.s
+ tbx z23.s, z1.s, z19.s
+ tbx z24.s, z1.s, z20.s
+
+ sub z17.s, z17.s, #16
+ sub z18.s, z18.s, #16
+ sub z19.s, z19.s, #16
+ sub z20.s, z20.s, #16
+
+ tbx z21.s, z2.s, z17.s // Look-up entries 32-47 in the LUT.
+ tbx z22.s, z2.s, z18.s
+ tbx z23.s, z2.s, z19.s
+ tbx z24.s, z2.s, z20.s
+
+ sub z17.s, z17.s, #16
+ sub z18.s, z18.s, #16
+ sub z19.s, z19.s, #16
+ sub z20.s, z20.s, #16
+
+ tbx z21.s, z3.s, z17.s // Look-up entries 48-63 in the LUT.
+ tbx z22.s, z3.s, z18.s
+ tbx z23.s, z3.s, z19.s
+ tbx z24.s, z3.s, z20.s
+
+ sub z17.s, z17.s, #16
+ sub z18.s, z18.s, #16
+ sub z19.s, z19.s, #16
+ sub z20.s, z20.s, #16
+
+ tbx z21.s, z4.s, z17.s // Look-up entries 64-79 in the LUT.
+ tbx z22.s, z4.s, z18.s
+ tbx z23.s, z4.s, z19.s
+ tbx z24.s, z4.s, z20.s
+
+ sub z17.s, z17.s, #16
+ sub z18.s, z18.s, #16
+ sub z19.s, z19.s, #16
+ sub z20.s, z20.s, #16
+
+ tbx z21.s, z5.s, z17.s // Look-up entries 80-95 in the LUT.
+ tbx z22.s, z5.s, z18.s
+ tbx z23.s, z5.s, z19.s
+ tbx z24.s, z5.s, z20.s
+
+ sub z17.s, z17.s, #16
+ sub z18.s, z18.s, #16
+ sub z19.s, z19.s, #16
+ sub z20.s, z20.s, #16
+
+ tbx z21.s, z6.s, z17.s // Look-up entries 96-111 in the LUT.
+ tbx z22.s, z6.s, z18.s
+ tbx z23.s, z6.s, z19.s
+ tbx z24.s, z6.s, z20.s
+
+ sub z17.s, z17.s, #16
+ sub z18.s, z18.s, #16
+ sub z19.s, z19.s, #16
+ sub z20.s, z20.s, #16
+
+ tbx z21.s, z7.s, z17.s // Look-up entries 112-127 in the LUT.
+ tbx z22.s, z7.s, z18.s
+ tbx z23.s, z7.s, z19.s
+ tbx z24.s, z7.s, z20.s
+
+ sub z17.s, z17.s, #16
+ sub z18.s, z18.s, #16
+ sub z19.s, z19.s, #16
+ sub z20.s, z20.s, #16
+
+ tbx z21.s, z8.s, z17.s // Look-up entries 128-143 in the LUT.
+ tbx z22.s, z8.s, z18.s
+ tbx z23.s, z8.s, z19.s
+ tbx z24.s, z8.s, z20.s
+
+ sub z17.s, z17.s, #16
+ sub z18.s, z18.s, #16
+ sub z19.s, z19.s, #16
+ sub z20.s, z20.s, #16
+
+ tbx z21.s, z9.s, z17.s // Look-up entries 144-159 in the LUT.
+ tbx z22.s, z9.s, z18.s
+ tbx z23.s, z9.s, z19.s
+ tbx z24.s, z9.s, z20.s
+
+ sub z17.s, z17.s, #16
+ sub z18.s, z18.s, #16
+ sub z19.s, z19.s, #16
+ sub z20.s, z20.s, #16
+
+ tbx z21.s, z10.s, z17.s // Look-up entries 160-175 in the LUT.
+ tbx z22.s, z10.s, z18.s
+ tbx z23.s, z10.s, z19.s
+ tbx z24.s, z10.s, z20.s
+
+ sub z17.s, z17.s, #16
+ sub z18.s, z18.s, #16
+ sub z19.s, z19.s, #16
+ sub z20.s, z20.s, #16
+
+ tbx z21.s, z11.s, z17.s // Look-up entries 176-191 in the LUT.
+ tbx z22.s, z11.s, z18.s
+ tbx z23.s, z11.s, z19.s
+ tbx z24.s, z11.s, z20.s
+
+ sub z17.s, z17.s, #16
+ sub z18.s, z18.s, #16
+ sub z19.s, z19.s, #16
+ sub z20.s, z20.s, #16
+
+ tbx z21.s, z12.s, z17.s // Look-up entries 192-207 in the LUT.
+ tbx z22.s, z12.s, z18.s
+ tbx z23.s, z12.s, z19.s
+ tbx z24.s, z12.s, z20.s
+
+ sub z17.s, z17.s, #16
+ sub z18.s, z18.s, #16
+ sub z19.s, z19.s, #16
+ sub z20.s, z20.s, #16
+
+ tbx z21.s, z13.s, z17.s // Look-up entries 208-223 in the LUT.
+ tbx z22.s, z13.s, z18.s
+ tbx z23.s, z13.s, z19.s
+ tbx z24.s, z13.s, z20.s
+
+ sub z17.s, z17.s, #16
+ sub z18.s, z18.s, #16
+ sub z19.s, z19.s, #16
+ sub z20.s, z20.s, #16
+
+ tbx z21.s, z14.s, z17.s // Look-up entries 224-239 in the LUT.
+ tbx z22.s, z14.s, z18.s
+ tbx z23.s, z14.s, z19.s
+ tbx z24.s, z14.s, z20.s
+
+ sub z17.s, z17.s, #16
+ sub z18.s, z18.s, #16
+ sub z19.s, z19.s, #16
+ sub z20.s, z20.s, #16
+
+ tbx z21.s, z15.s, z17.s // Look-up entries 240-255 in the LUT.
+ tbx z22.s, z15.s, z18.s
+ tbx z23.s, z15.s, z19.s
+ tbx z24.s, z15.s, z20.s
+
+
+ st1w z21.s, p2, [x29, x1, LSL #2] // z21: store exp(-scale*beta*x) into the tmp tensor
+ fadd z25.s, p2/m, z25.s, z21.s
+ add x1, x1, #16
+
+ st1w z22.s, p3, [x29, x1, LSL #2] // z22: store exp(-scale*beta*x) into the tmp tensor
+ fadd z25.s, p3/m, z25.s, z22.s
+ add x1, x1, #16
+
+ st1w z23.s, p4, [x29, x1, LSL #2] // z23: store exp(-scale*beta*x) into the tmp tensor
+ fadd z25.s, p4/m, z25.s, z23.s
+ add x1, x1, #16
+
+ st1w z24.s, p5, [x29, x1, LSL #2] // z24: store exp(-scale*beta*x) into the tmp tensor
+ fadd z25.s, p5/m, z25.s, z24.s
+ add x1, x1, #16
+
+ b regularize_start%=
+regularize_end%=:
+
+ mov w9, 0x0000
+ movk w9, 0x4380, LSL #16 // w9: 256.f, used to scale the [0,1] softmax outputs to the full 8-bit range via multiplication by 256/sum
+ mov w10, 0x0000
+ movk w10, 0x4300, LSL #16 // w10: 128.f, subtracted to shift the results from [0,255] down to the [-128,127] QASYMM8_SIGNED range
+ dup z29.s, w9
+ dup z30.s, w10
+ faddv s25, p0, z25.s
+ fdiv s25, s29, s25
+ dup z25.s, z25.s[0] // z25: 256.f/sum. 256 is needed to get the full range and 1/sum is part of softmax.
+
+ // ==================================================
+ // Step 3: Normalize
+ // ==================================================
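+
+ // From here on each output byte is effectively
+ // dst[i] = saturate_s8(trunc(tmp[i] * 256.f / sum - 128.f)),
+ // with the fmul, fsub, fcvtzs and sqcvt steps below implementing that expression.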
+ mov x1, #0
+normalize_body_start%=:
+ cmp x1, x13
+ b.eq normalize_body_end%=
+
+ mov x2, x1 // Preserve the index into x2 for the final store to dst.
+ .inst 0xa001c7b0 // ld1w { z16.s - z19.s }, pn9/z, [x29, x1, lsl #2]
+ add x1, x1, #64
+ .inst 0xa001c7b4 // ld1w { z20.s - z23.s }, pn9/z, [x29, x1, lsl #2]
+ add x1, x1, #64
+
+ // z16-z23: effectively divides exp(-scale*beta*x) by the sum of the exponentials for the current row and multiplies by 256.
+ fmul z16.s, z25.s, z16.s
+ fmul z17.s, z25.s, z17.s
+ fmul z18.s, z25.s, z18.s
+ fmul z19.s, z25.s, z19.s
+ fmul z20.s, z25.s, z20.s
+ fmul z21.s, z25.s, z21.s
+ fmul z22.s, z25.s, z22.s
+ fmul z23.s, z25.s, z23.s
+
+ // z16-z23: subtract 128.f.
+ fsub z16.s, z16.s, z30.s // Subtract 128.f
+ fsub z17.s, z17.s, z30.s // Subtract 128.f
+ fsub z18.s, z18.s, z30.s // Subtract 128.f
+ fsub z19.s, z19.s, z30.s // Subtract 128.f
+ fsub z20.s, z20.s, z30.s // Subtract 128.f
+ fsub z21.s, z21.s, z30.s // Subtract 128.f
+ fsub z22.s, z22.s, z30.s // Subtract 128.f
+ fsub z23.s, z23.s, z30.s // Subtract 128.f
+
+ // z16-z23: convert the FP32 values from the tmp tensor to int32.
+ fcvtzs z16.s, p0/m, z16.s
+ fcvtzs z17.s, p0/m, z17.s
+ fcvtzs z18.s, p0/m, z18.s
+ fcvtzs z19.s, p0/m, z19.s
+ fcvtzs z20.s, p0/m, z20.s
+ fcvtzs z21.s, p0/m, z21.s
+ fcvtzs z22.s, p0/m, z22.s
+ fcvtzs z23.s, p0/m, z23.s
+
+ // z16-z17: narrow the int32 values into int8 and saturate them.
+ .inst 0xc133e210 // sqcvt z16.b, { z16.s - z19.s }
+ .inst 0xc133e291 // sqcvt z17.b, { z20.s - z23.s }
+
+ // Preserve the scale factor in z20 and the 128.f bias in z21, as z25 and z30 will be overwritten by the loads below.
+ dup z20.s, z25.s[0]
+ dup z21.s, z30.s[0]
+
+ .inst 0xa001c7b8 // ld1w { z24.s - z27.s }, pn9/z, [x29, x1, lsl #2]
+ add x1, x1, #64
+ .inst 0xa001c7bc // ld1w { z28.s - z31.s }, pn9/z, [x29, x1, lsl #2]
+ add x1, x1, #64
+
+ // z24-z31: effectively divides exp(-scale*beta*x) by the sum of the exponentials for the current row and multiplies by 256.
+ fmul z24.s, z20.s, z24.s
+ fmul z25.s, z20.s, z25.s
+ fmul z26.s, z20.s, z26.s
+ fmul z27.s, z20.s, z27.s
+ fmul z28.s, z20.s, z28.s
+ fmul z29.s, z20.s, z29.s
+ fmul z30.s, z20.s, z30.s
+ fmul z31.s, z20.s, z31.s
+
+ // z24-z31: subtract 128.f.
+ fsub z24.s, z24.s, z21.s
+ fsub z25.s, z25.s, z21.s
+ fsub z26.s, z26.s, z21.s
+ fsub z27.s, z27.s, z21.s
+ fsub z28.s, z28.s, z21.s
+ fsub z29.s, z29.s, z21.s
+ fsub z30.s, z30.s, z21.s
+ fsub z31.s, z31.s, z21.s
+
+ // z24-z31: convert the FP32 values from the tmp tensor to int32.
+ fcvtzs z24.s, p0/m, z24.s
+ fcvtzs z25.s, p0/m, z25.s
+ fcvtzs z26.s, p0/m, z26.s
+ fcvtzs z27.s, p0/m, z27.s
+ fcvtzs z28.s, p0/m, z28.s
+ fcvtzs z29.s, p0/m, z29.s
+ fcvtzs z30.s, p0/m, z30.s
+ fcvtzs z31.s, p0/m, z31.s
+
+ // z18-z19: narrow the int32 values into int8 and saturate them.
+ .inst 0xc133e312 // sqcvt z18.b, { z24.s - z27.s }
+ .inst 0xc133e393 // sqcvt z19.b, { z28.s - z31.s }
+
+ .inst 0xa0228390 // st1b { z16.b - z19.b }, pn8, [x28, x2]
+
+ // Restore the scale factor to z25 and the 128.f bias to z30: z20 and z21 are overwritten each iteration, and the leftover loop below reads z25 and z30.
+ dup z25.s, z20.s[0]
+ dup z30.s, z21.s[0]
+ b normalize_body_start%=
+normalize_body_end%=:
+normalize_leftover_start%=:
+ whilelo p1.b, x1, %x[length]
+ b.none normalize_leftover_end%=
+
+ // p2-p5 are - together - the 32-bit version of p1, the instructions below unpack p1 into those four predicate registers to allow for the 32-bit loads below to be correctly predicated
+ punpklo p2.h, p1.b
+ punpkhi p4.h, p1.b
+
+ punpkhi p3.h, p2.b
+ punpklo p2.h, p2.b
+
+ punpkhi p5.h, p4.b
+ punpklo p4.h, p4.b
+
+ mov x2, x1 // Preserve the index into x2 for the final store to dst.
+
+ // z20-z23: load exp(-scale*beta*x) from the tmp tensor
+ ld1w z20.s, p2/z, [x29, x1, LSL #2]
+ add x1, x1, #16
+
+ ld1w z21.s, p3/z, [x29, x1, LSL #2]
+ add x1, x1, #16
+
+ ld1w z22.s, p4/z, [x29, x1, LSL #2]
+ add x1, x1, #16
+
+ ld1w z23.s, p5/z, [x29, x1, LSL #2]
+ add x1, x1, #16
+
+ // z20-z23: effectively divides exp(-scale*beta*x) by the sum of the exponentials for the current row and multiplies by 256.
+ fmul z20.s, z25.s, z20.s
+ fmul z21.s, z25.s, z21.s
+ fmul z22.s, z25.s, z22.s
+ fmul z23.s, z25.s, z23.s
+
+ // z20-z23: subtract 128.f.
+ fsub z20.s, z20.s, z30.s
+ fsub z21.s, z21.s, z30.s
+ fsub z22.s, z22.s, z30.s
+ fsub z23.s, z23.s, z30.s
+
+ // z20-z23: convert the FP32 values from the tmp tensor to int32.
+ fcvtzs z20.s, p0/m, z20.s
+ fcvtzs z21.s, p0/m, z21.s
+ fcvtzs z22.s, p0/m, z22.s
+ fcvtzs z23.s, p0/m, z23.s
+
+ .inst 0xc133e293 // sqcvt z19.b, { z20.s - z23.s } // narrow the int32 values into int8 and saturate them into z19.
+
+ st1b z19.b, p1, [x28, x2]
+
+ b normalize_leftover_start%=
+normalize_leftover_end%=:
+ // ==================================================
+ // 3D loop closing
+ // ==================================================
+ add x27, x27, %x[src_stride_1]
+ add x28, x28, %x[dst_stride_1]
+ b loop_1_start%=
+loop_1_end%=:
+
+ add x24, x24, %x[src_stride_2]
+ add x25, x25, %x[dst_stride_2]
+ b loop_2_start%=
+loop_2_end%=:
+
+ add x21, x21, %x[src_stride_3]
+ add x22, x22, %x[dst_stride_3]
+ b loop_3_start%=
+loop_3_end%=:
+ .inst 0xd503467f // smstop
+ )"
+ :
+ : [src] "r"(src), [tmp] "r"(tmp), [dst] "r"(dst), [beta] "r"(beta), [lut] "r"(lut), //
+ [shape_1] "r"(shape[1]), [shape_2] "r"(shape[2]), [shape_3] "r"(shape[3]), //
+ [src_stride_1] "r"(src_strides[1]), [src_stride_2] "r"(src_strides[2]),
+ [src_stride_3] "r"(src_strides[3]), //
+ [dst_stride_1] "r"(dst_strides[1]), [dst_stride_2] "r"(dst_strides[2]),
+ [dst_stride_3] "r"(dst_strides[3]), //
+ [length] "r"(shape[0]) //
+ : "cc", "memory", //
+ "p0", "p1", "p2", "p3", "p4", //
+ "x2", "x9", "x13", //
+ "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28", "x19", //
+ "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", //
+ "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", //
+ "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", //
+ "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31" //
+ );
+}
+
+void sme2_qasymm8_signed_softmax_lut_512VL(const ITensor *in,
+ void *const tmp,
+ ITensor *out,
+ const float beta,
+ int axis,
+ const Window &window,
+ const float *lut_ptr)
+{
+ ARM_COMPUTE_UNUSED(axis);
+
+ const auto *src_info = in->info();
+ const auto *dst_info = out->info();
+
+ const auto &full_shape = dst_info->tensor_shape();
+ const auto &src_strides = src_info->strides_in_bytes();
+ const auto &dst_strides = dst_info->strides_in_bytes();
+ Strides tmp_strides;
+
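+    // tmp holds FP32 values, so each stride is 4x the corresponding 8-bit src stride.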
+ tmp_strides[0] = src_strides[0] * 4;
+ tmp_strides[1] = src_strides[1] * 4;
+ tmp_strides[2] = src_strides[2] * 4;
+ tmp_strides[3] = src_strides[3] * 4;
+
+ const uintptr_t k_shape[] = {
+ full_shape[0],
+ window.num_iterations(1),
+ window.num_iterations(2),
+ window.num_iterations(3),
+ };
+
+ const uintptr_t k_src_strides[] = {
+ src_strides[0],
+ src_strides[1],
+ src_strides[2],
+ src_strides[3],
+ };
+
+ const uintptr_t k_dst_strides[] = {
+ dst_strides[0],
+ dst_strides[1],
+ dst_strides[2],
+ dst_strides[3],
+ };
+
+ const uintptr_t k_src_offset = window[0].start() * src_strides[0] + //
+ window[1].start() * src_strides[1] + //
+ window[2].start() * src_strides[2] + //
+ window[3].start() * src_strides[3];
+
+ const uintptr_t k_dst_offset = window[0].start() * dst_strides[0] + //
+ window[1].start() * dst_strides[1] + //
+ window[2].start() * dst_strides[2] + //
+ window[3].start() * dst_strides[3];
+
+ const uintptr_t k_tmp_offset = window[0].start() * tmp_strides[0] + //
+ window[1].start() * tmp_strides[1] + //
+ window[2].start() * tmp_strides[2] + //
+ window[3].start() * tmp_strides[3];
+
+ const auto *k_src = reinterpret_cast<const int8_t *>(in->buffer() + k_src_offset);
+    // k_tmp_offset is in bytes (tmp_strides are byte strides), so apply it on a byte pointer.
+    auto *k_tmp = reinterpret_cast<float *>(static_cast<uint8_t *>(tmp) + k_tmp_offset);
+ auto *k_dst = reinterpret_cast<int8_t *>(out->buffer() + k_dst_offset);
+
+ sme2_qasymm8_signed_softmax_kernel_512VL(k_src, k_dst, beta, k_shape, k_src_strides, k_dst_strides, lut_ptr, k_tmp);
+}
+
+} // namespace cpu
+} // namespace arm_compute
+
+#endif // ARM_COMPUTE_ENABLE_SME2
diff --git a/src/cpu/kernels/softmax/list.h b/src/cpu/kernels/softmax/list.h
index f9295ebbcc..7bbb265022 100644
--- a/src/cpu/kernels/softmax/list.h
+++ b/src/cpu/kernels/softmax/list.h
@@ -28,15 +28,52 @@ namespace arm_compute
{
namespace cpu
{
-#define DECLARE_SOFTMAX_KERNEL(func_name) \
- template <bool IS_LOG> \
- void func_name(const ITensor *in, void *const tmp, ITensor *out, const float beta, int axis, const Window &window)
+#define DECLARE_SOFTMAX_KERNEL(func_name) \
+ template <bool IS_LOG> \
+ void func_name(const ITensor *in, void *const tmp, ITensor *out, const float beta, int axis, const Window &window, \
+ const float *lut_ptr)
DECLARE_SOFTMAX_KERNEL(neon_fp32_softmax);
DECLARE_SOFTMAX_KERNEL(neon_fp16_softmax);
DECLARE_SOFTMAX_KERNEL(neon_qasymm8_softmax);
DECLARE_SOFTMAX_KERNEL(neon_qasymm8_signed_softmax);
+#ifdef ARM_COMPUTE_ENABLE_SME2
+
+void sme2_fp32_softmax(const ITensor *in,
+ void *const tmp,
+ ITensor *out,
+ const float beta,
+ int axis,
+ const Window &window,
+ const float *lut_ptr);
+
+void sme2_fp16_softmax(const ITensor *in,
+ void *const tmp,
+ ITensor *out,
+ const float beta,
+ int axis,
+ const Window &window,
+ const float *lut_ptr);
+
+void sme2_qasymm8_softmax_lut_512VL(const ITensor *in,
+ void *const tmp,
+ ITensor *out,
+ const float beta,
+ int axis,
+ const Window &window,
+ const float *lut_ptr);
+
+void sme2_qasymm8_signed_softmax_lut_512VL(const ITensor *in,
+ void *const tmp,
+ ITensor *out,
+ const float beta,
+ int axis,
+ const Window &window,
+ const float *lut_ptr);
+
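+// Note: the SME2 entry points above deliberately share the NEON kernel
+// signature, so a selector can store either behind one pointer type.
+// Illustrative only:
+//
+//   using SoftmaxKernelPtr = void (*)(const ITensor *, void *const, ITensor *,
+//                                     float, int, const Window &, const float *);
+//   SoftmaxKernelPtr k = sme2_fp32_softmax; // when SME2 is available
+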
+#endif // ARM_COMPUTE_ENABLE_SME2
+
#undef DECLARE_SOFTMAX_KERNEL
} // namespace cpu
} // namespace arm_compute