Diffstat (limited to 'src/cpu/kernels')
-rw-r--r--  src/cpu/kernels/CpuActivationKernel.cpp  12
-rw-r--r--  src/cpu/kernels/CpuDirectConv2dOutputStageKernel.cpp  349
-rw-r--r--  src/cpu/kernels/CpuDirectConv3dKernel.cpp  20
-rw-r--r--  src/cpu/kernels/CpuGemmLowpMatrixMultiplyKernel.cpp  4
-rw-r--r--  src/cpu/kernels/CpuKernelSelectionTypes.h  9
-rw-r--r--  src/cpu/kernels/CpuMulKernel.cpp  44
-rw-r--r--  src/cpu/kernels/CpuPermuteKernel.cpp  65
-rw-r--r--  src/cpu/kernels/CpuReshapeKernel.cpp  16
-rw-r--r--  src/cpu/kernels/CpuReshapeKernel.h  19
-rw-r--r--  src/cpu/kernels/CpuScatterKernel.cpp  98
-rw-r--r--  src/cpu/kernels/CpuScatterKernel.h  91
-rw-r--r--  src/cpu/kernels/CpuWeightsReshapeKernel.cpp  7
-rw-r--r--  src/cpu/kernels/assembly/CpuGemmAssemblyWrapperKernel.h  8
-rw-r--r--  src/cpu/kernels/assembly/arm_gemm.hpp  12
-rw-r--r--  src/cpu/kernels/assembly/convolution_parameters.hpp  2
-rw-r--r--  src/cpu/kernels/assembly/gemm_common.hpp  25
-rw-r--r--  src/cpu/kernels/conv3d/generic/neon/float_impl.h (renamed from src/cpu/kernels/conv3d/neon/list.h)  17
-rw-r--r--  src/cpu/kernels/conv3d/generic/neon/fp16.cpp  49
-rw-r--r--  src/cpu/kernels/conv3d/generic/neon/fp32.cpp  46
-rw-r--r--  src/cpu/kernels/conv3d/generic/neon/qasymm8.cpp  46
-rw-r--r--  src/cpu/kernels/conv3d/generic/neon/qasymm8_signed.cpp  46
-rw-r--r--  src/cpu/kernels/conv3d/generic/neon/quantized_impl.h (renamed from src/cpu/kernels/conv3d/neon/quantized.h)  12
-rw-r--r--  src/cpu/kernels/conv3d/list.h  47
-rw-r--r--  src/cpu/kernels/directconv2d_output_stage/generic/neon/float_impl.h  186
-rw-r--r--  src/cpu/kernels/directconv2d_output_stage/generic/neon/fp16.cpp  63
-rw-r--r--  src/cpu/kernels/directconv2d_output_stage/generic/neon/fp32.cpp  59
-rw-r--r--  src/cpu/kernels/directconv2d_output_stage/generic/neon/qasymm8.cpp  59
-rw-r--r--  src/cpu/kernels/directconv2d_output_stage/generic/neon/qasymm8_signed.cpp  59
-rw-r--r--  src/cpu/kernels/directconv2d_output_stage/generic/neon/quantized_impl.h  213
-rw-r--r--  src/cpu/kernels/directconv2d_output_stage/list.h  57
-rw-r--r--  src/cpu/kernels/gemm_matrix_mul/generic/neon/impl.cpp  572
-rw-r--r--  src/cpu/kernels/logistic/generic/sme2/fp32.cpp  429
-rw-r--r--  src/cpu/kernels/logistic/list.h  42
-rw-r--r--  src/cpu/kernels/meanstddevnorm/generic/neon/fp16.cpp  13
-rw-r--r--  src/cpu/kernels/mul/generic/sme2/list.h  38
-rw-r--r--  src/cpu/kernels/mul/generic/sme2/qasymm8_signed.cpp  410
36 files changed, 2602 insertions, 642 deletions
diff --git a/src/cpu/kernels/CpuActivationKernel.cpp b/src/cpu/kernels/CpuActivationKernel.cpp
index 4253027231..555705bd45 100644
--- a/src/cpu/kernels/CpuActivationKernel.cpp
+++ b/src/cpu/kernels/CpuActivationKernel.cpp
@@ -32,6 +32,7 @@
#include "src/core/helpers/AutoConfiguration.h"
#include "src/core/helpers/WindowHelpers.h"
#include "src/cpu/kernels/activation/list.h"
+#include "src/cpu/kernels/logistic/list.h"
#include <array>
@@ -71,6 +72,12 @@ static const std::vector<CpuActivationKernel::ActivationKernel> available_kernel
},
REGISTER_Q8_NEON(arm_compute::cpu::neon_q8_activation_lut)},
#endif // __aarch64__
+ {"sme2_fp32_logistic",
+ [](const ActivationDataTypeISASelectorData &data) {
+ return data.dt == DataType::F32 && data.f == ActivationLayerInfo::ActivationFunction::LOGISTIC &&
+ data.isa.sme2;
+ },
+ REGISTER_FP32_SME2(arm_compute::cpu::sme2_fp32_logistic)},
{"sve2_qu8_activation",
[](const ActivationDataTypeISASelectorData &data) {
return data.dt == DataType::QASYMM8 && data.isa.sve2 &&
@@ -316,6 +323,11 @@ void CpuActivationKernel::configure(const ITensorInfo *src, ITensorInfo *dst, Ac
// Use squashed window
std::tie(win, _split_dimension) = calculate_squashed_or_max_window(*src);
+ // Collapse window with SME kernels in Y-Dim
+ if (std::string(uk->name) == "sme2_fp32_logistic")
+ {
+ win = win.collapse(win, Window::DimY);
+ }
ICPPKernel::configure(win);
}
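
Note: the new sme2_fp32_logistic entry follows the same first-match selection scheme as the other entries in available_kernels: each candidate carries a predicate over the data type, activation function and detected ISA, and the first predicate that returns true picks the micro-kernel. A minimal standalone sketch of that pattern (the enum and struct names below are simplified stand-ins, not the library's actual selector types):

#include <functional>
#include <iostream>
#include <vector>

// Simplified stand-ins for the selector data; the real types live in
// CpuKernelSelectionTypes.h and carry more fields.
enum class DataType { F32, QASYMM8 };
enum class ActFunc  { LOGISTIC, RELU };
struct IsaInfo      { bool sme2 = false; bool sve2 = false; };
struct SelectorData { DataType dt; ActFunc f; IsaInfo isa; };

struct Kernel
{
    const char                               *name;
    std::function<bool(const SelectorData &)> is_selected;
};

// First predicate that returns true wins, mirroring the ordering of available_kernels.
const Kernel *select(const std::vector<Kernel> &kernels, const SelectorData &data)
{
    for (const auto &k : kernels)
    {
        if (k.is_selected(data))
        {
            return &k;
        }
    }
    return nullptr;
}

int main()
{
    const std::vector<Kernel> kernels = {
        {"sme2_fp32_logistic",
         [](const SelectorData &d) { return d.dt == DataType::F32 && d.f == ActFunc::LOGISTIC && d.isa.sme2; }},
        {"neon_fp32_activation", [](const SelectorData &d) { return d.dt == DataType::F32; }},
    };

    const SelectorData data{DataType::F32, ActFunc::LOGISTIC, IsaInfo{true /*sme2*/}};
    const Kernel      *uk = select(kernels, data);
    std::cout << (uk != nullptr ? uk->name : "none") << '\n'; // prints "sme2_fp32_logistic"
    return 0;
}
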
diff --git a/src/cpu/kernels/CpuDirectConv2dOutputStageKernel.cpp b/src/cpu/kernels/CpuDirectConv2dOutputStageKernel.cpp
index d4af8bedaf..9e9137a266 100644
--- a/src/cpu/kernels/CpuDirectConv2dOutputStageKernel.cpp
+++ b/src/cpu/kernels/CpuDirectConv2dOutputStageKernel.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2021 Arm Limited.
+ * Copyright (c) 2017-2021, 2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -37,6 +37,7 @@
#include "src/core/NEON/NEAsymm.h"
#include "src/core/NEON/NEFixedPoint.h"
#include "src/core/NEON/wrapper/wrapper.h"
+#include "src/cpu/kernels/directconv2d_output_stage/list.h"
#include <arm_neon.h>
#include <cstddef>
@@ -95,316 +96,6 @@ Status validate_arguments(const ITensorInfo *src
return Status{};
}
-
-template <typename T>
-typename std::enable_if<arm_compute::utils::traits::is_floating_point<T>::value, void>::type
-output_stage_nchw(ITensor *src,
- const ITensor *bias,
- const Window &window,
- ITensor *dst,
- int result_fixedpoint_multiplier,
- int result_shift,
- int result_offset_after_shift)
-{
- const bool has_bias = bias != nullptr;
- /** SIMD vector tag type. */
- using ExactTagType = typename wrapper::traits::neon_bitvector_tag_t<T, wrapper::traits::BitWidth::W128>;
-
- ARM_COMPUTE_ERROR_ON(src->info()->data_layout() == DataLayout::UNKNOWN);
- ARM_COMPUTE_UNUSED(result_fixedpoint_multiplier);
- ARM_COMPUTE_UNUSED(result_shift);
- ARM_COMPUTE_UNUSED(result_offset_after_shift);
-
- const int window_start_x = window.x().start();
- const int window_end_x = window.x().end();
- const int window_step_x = 16 / src->info()->element_size();
- Window win = window;
- win.set(Window::DimX, Window::Dimension(0, 1, 1));
-
- Iterator in(src, win);
- Iterator out(dst, win);
- execute_window_loop(
- win,
- [&](const Coordinates &id)
- {
- int x = window_start_x;
- for (; x <= (window_end_x - window_step_x); x += window_step_x)
- {
- // Get bias and pointer to input
- const auto in_ptr = reinterpret_cast<const T *>(in.ptr()) + x;
- auto v_in = wrapper::vloadq(in_ptr);
-
- // Accumulate bias
- if (has_bias)
- {
- const auto vb = wrapper::vdup_n(
- *reinterpret_cast<const T *>(bias->ptr_to_element(Coordinates(id.z()))), ExactTagType{});
- v_in = wrapper::vadd(v_in, vb);
- }
-
- const auto out_ptr = reinterpret_cast<T *>(out.ptr()) + x;
- wrapper::vstore(out_ptr, v_in);
- }
-
- // Left-overs loop
- for (; x < window_end_x; ++x)
- {
- // Get bias and pointer to input
- auto s_in = *(reinterpret_cast<const T *>(in.ptr()) + x);
-
- // Accumulate bias
- if (has_bias)
- {
- const auto b = *reinterpret_cast<const T *>(bias->ptr_to_element(Coordinates(id.z())));
- s_in += b;
- }
-
- *(reinterpret_cast<T *>(out.ptr()) + x) = s_in;
- }
- },
- in, out);
-}
-
-template <typename T>
-typename std::enable_if<arm_compute::utils::traits::is_floating_point<T>::value, void>::type
-output_stage_nhwc(ITensor *src,
- const ITensor *bias,
- const Window &window,
- ITensor *dst,
- int result_fixedpoint_multiplier,
- int result_shift,
- int result_offset_after_shift)
-{
- const bool has_bias = bias != nullptr;
- ARM_COMPUTE_UNUSED(result_fixedpoint_multiplier);
- ARM_COMPUTE_UNUSED(result_shift);
- ARM_COMPUTE_UNUSED(result_offset_after_shift);
-
- Window window_bias = window;
- window_bias.set(Window::DimX, Window::Dimension(0, 1, 1));
- window_bias.set(Window::DimY, Window::Dimension(0, 0, 0));
- window_bias.set(Window::DimZ, Window::Dimension(0, 0, 0));
- window_bias.set(3, Window::Dimension(0, 0, 0));
-
- const int window_start_x = window.x().start();
- const int window_end_x = window.x().end();
- const int window_step_x = 16 / src->info()->element_size();
- Window win = window;
- win.set(Window::DimX, Window::Dimension(0, 1, 1));
-
- Iterator in(src, win);
- Iterator bi(bias, window_bias);
- Iterator out(dst, win);
-
- execute_window_loop(
- win,
- [&](const Coordinates &)
- {
- int x = window_start_x;
- for (; x <= (window_end_x - window_step_x); x += window_step_x)
- {
- // Get bias and pointer to input
- const auto in_ptr = reinterpret_cast<const T *>(in.ptr());
- auto v_in = wrapper::vloadq(in_ptr + x);
-
- // Accumulate bias
- if (has_bias)
- {
- const auto bias_ptr = reinterpret_cast<T *>(bi.ptr()) + x;
- v_in = wrapper::vadd(v_in, wrapper::vloadq(bias_ptr));
- }
-
- const auto out_ptr = reinterpret_cast<T *>(out.ptr());
- wrapper::vstore(out_ptr + x, v_in);
- }
-
- // Left-overs loop
- for (; x < window_end_x; ++x)
- {
- // Get bias and pointer to input
- auto s_in = *(reinterpret_cast<const T *>(in.ptr()) + x);
-
- // Accumulate bias
- if (has_bias)
- {
- const auto bias_ptr = reinterpret_cast<T *>(bi.ptr()) + x;
- s_in += *bias_ptr;
- }
-
- const auto out_ptr = reinterpret_cast<T *>(out.ptr());
- *(out_ptr + x) = s_in;
- }
- },
- in, bi, out);
-}
-
-// Quantized case
-template <
- typename TOut,
- typename std::enable_if<std::is_same<TOut, uint8_t>::value || std::is_same<TOut, int8_t>::value, int>::type = 0>
-void output_stage_nchw(ITensor *src,
- const ITensor *bias,
- const Window &window,
- ITensor *dst,
- int result_fixedpoint_multiplier,
- int result_shift,
- int result_offset_after_shift)
-{
- const bool has_bias = bias != nullptr;
- using VectorType = typename wrapper::traits::neon_bitvector_t<TOut, wrapper::traits::BitWidth::W128>;
- using TagType = typename wrapper::traits::neon_bitvector_tag_t<TOut, wrapper::traits::BitWidth::W128>;
-
- const int32x4_t result_offset_after_shift_s32 = vdupq_n_s32(result_offset_after_shift);
-
- const VectorType min = wrapper::vdup_n(std::numeric_limits<TOut>::lowest(), TagType{});
- const VectorType max = wrapper::vdup_n(std::numeric_limits<TOut>::max(), TagType{});
-
- const int window_start_x = window.x().start();
- const int window_end_x = window.x().end();
- const int window_step_x = 16 / src->info()->element_size();
- Window win = window;
- win.set(Window::DimX, Window::Dimension(0, 1, 1));
-
- Iterator in(src, win);
- Iterator out(dst, win);
-
- execute_window_loop(
- win,
- [&](const Coordinates &id)
- {
- int x = window_start_x;
- for (; x <= (window_end_x - window_step_x); x += window_step_x)
- {
- // Get bias and pointer to input
- const auto in_ptr = reinterpret_cast<int32_t *>(in.ptr()) + x;
- int32x4x4_t v_in = {{wrapper::vloadq(in_ptr), wrapper::vloadq(in_ptr + 4), wrapper::vloadq(in_ptr + 8),
- wrapper::vloadq(in_ptr + 12)}};
-
- // Accumulate bias
- if (has_bias)
- {
- const auto vb = wrapper::vdup_n(
- *reinterpret_cast<const int32_t *>(bias->ptr_to_element(Coordinates(id.z()))), TagType{});
- v_in = {{wrapper::vadd(v_in.val[0], vb), wrapper::vadd(v_in.val[1], vb),
- wrapper::vadd(v_in.val[2], vb), wrapper::vadd(v_in.val[3], vb)}};
- }
-
- const auto out_ptr = reinterpret_cast<TOut *>(out.ptr()) + x;
- wrapper::vstore(out_ptr, finalize_quantization(v_in, result_fixedpoint_multiplier, result_shift,
- result_offset_after_shift_s32, min, max, false));
- }
-
- // Left-overs loop
- for (; x < window_end_x; ++x)
- {
- // Get bias and pointer to input
- int32_t s_in = *(reinterpret_cast<const int32_t *>(in.ptr()) + x);
-
- // Accumulate bias
- if (has_bias)
- {
- const auto b = *reinterpret_cast<const int32_t *>(bias->ptr_to_element(Coordinates(id.z())));
- s_in += b;
- }
-
- const auto out_ptr = reinterpret_cast<TOut *>(out.ptr()) + x;
- *out_ptr =
- finalize_quantization(s_in, result_fixedpoint_multiplier, result_shift, result_offset_after_shift,
- std::numeric_limits<TOut>::lowest(), std::numeric_limits<TOut>::max(), false);
- }
- },
- in, out);
-}
-template <
- typename TOut,
- typename std::enable_if<std::is_same<TOut, uint8_t>::value || std::is_same<TOut, int8_t>::value, int>::type = 0>
-void output_stage_nhwc(ITensor *src,
- const ITensor *bias,
- const Window &window,
- ITensor *dst,
- int result_fixedpoint_multiplier,
- int result_shift,
- int result_offset_after_shift)
-{
- const bool has_bias = bias != nullptr;
- using VectorType = typename wrapper::traits::neon_bitvector_t<TOut, wrapper::traits::BitWidth::W128>;
- using TagType = typename wrapper::traits::neon_bitvector_tag_t<TOut, wrapper::traits::BitWidth::W128>;
-
- const int32x4_t result_offset_after_shift_s32 = vdupq_n_s32(result_offset_after_shift);
-
- const VectorType min = wrapper::vdup_n(std::numeric_limits<TOut>::lowest(), TagType{});
- const VectorType max = wrapper::vdup_n(std::numeric_limits<TOut>::max(), TagType{});
-
- Window window_bias = window;
- window_bias.set(Window::DimX, Window::Dimension(0, 1, 1));
- window_bias.set(Window::DimY, Window::Dimension(0, 0, 0));
- window_bias.set(Window::DimZ, Window::Dimension(0, 0, 0));
- window_bias.set(3, Window::Dimension(0, 0, 0));
-
- const int window_start_x = window.x().start();
- const int window_end_x = window.x().end();
- const int window_step_x = 16 / src->info()->element_size();
- Window win = window;
- win.set(Window::DimX, Window::Dimension(0, 1, 1));
-
- Iterator in(src, win);
- Iterator bi(bias, window_bias);
- Iterator out(dst, win);
-
- execute_window_loop(
- win,
- [&](const Coordinates &)
- {
- int x = window_start_x;
- for (; x <= (window_end_x - window_step_x); x += window_step_x)
- {
- // Get bias and pointer to input
- const auto in_ptr = reinterpret_cast<int32_t *>(in.ptr()) + x;
- int32x4x4_t v_in = {{
- wrapper::vloadq(in_ptr),
- wrapper::vloadq(in_ptr + 4),
- wrapper::vloadq(in_ptr + 8),
- wrapper::vloadq(in_ptr + 12),
- }};
-
- // Accumulate bias
- if (has_bias)
- {
- const auto bias_ptr = reinterpret_cast<int32_t *>(bi.ptr()) + x;
-
- wrapper::vadd(v_in.val[0], wrapper::vloadq(bias_ptr));
- wrapper::vadd(v_in.val[1], wrapper::vloadq(bias_ptr + 4));
- wrapper::vadd(v_in.val[2], wrapper::vloadq(bias_ptr + 8));
- wrapper::vadd(v_in.val[3], wrapper::vloadq(bias_ptr + 12));
- }
-
- const auto out_ptr = reinterpret_cast<TOut *>(out.ptr()) + x;
- wrapper::vstore(out_ptr, finalize_quantization(v_in, result_fixedpoint_multiplier, result_shift,
- result_offset_after_shift_s32, min, max, false));
- }
-
- // Left-overs loop
- for (; x < window_end_x; ++x)
- {
- // Get bias and pointer to input
- const auto in_ptr = reinterpret_cast<int32_t *>(in.ptr()) + x;
- int32_t s_in = *in_ptr;
-
- // Accumulate bias
- if (has_bias)
- {
- const auto bias_ptr = reinterpret_cast<int32_t *>(bi.ptr()) + x;
- s_in += *bias_ptr;
- }
-
- const auto out_ptr = reinterpret_cast<TOut *>(out.ptr()) + x;
- *out_ptr =
- finalize_quantization(s_in, result_fixedpoint_multiplier, result_shift, result_offset_after_shift,
- std::numeric_limits<TOut>::lowest(), std::numeric_limits<TOut>::max(), false);
- }
- },
- in, bi, out);
-}
} // namespace
void CpuDirectConv2dOutputStageKernel::configure(ITensorInfo *src,
@@ -447,24 +138,30 @@ void CpuDirectConv2dOutputStageKernel::configure(ITensorInfo
{
if (is_qasymm8_signed)
{
- _func = &output_stage_nchw<int8_t>;
+#ifdef ENABLE_QASYMM8_SIGNED_KERNELS
+ _func = &output_stage_nchw_qs8;
+#endif // ENABLE_QASYMM8_SIGNED_KERNELS
}
else
{
- _func = &output_stage_nchw<uint8_t>;
+#ifdef ENABLE_QASYMM8_KERNELS
+ _func = &output_stage_nchw_qu8;
+#endif // ENABLE_QASYMM8_KERNELS
}
break;
}
-#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+#ifdef ENABLE_FP16_KERNELS
case DataType::F16:
{
- _func = &output_stage_nchw<float16_t>;
+ _func = &output_stage_nchw_fp16;
break;
}
-#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
+#endif // ENABLE_FP16_KERNELS
case DataType::F32:
{
- _func = &output_stage_nchw<float>;
+#ifdef ENABLE_FP32_KERNELS
+ _func = &output_stage_nchw_fp32;
+#endif // ENABLE_FP32_KERNELS
break;
}
default:
@@ -481,24 +178,30 @@ void CpuDirectConv2dOutputStageKernel::configure(ITensorInfo
{
if (is_qasymm8_signed)
{
- _func = &output_stage_nhwc<int8_t>;
+#ifdef ENABLE_QASYMM8_SIGNED_KERNELS
+ _func = &output_stage_nhwc_qs8;
+#endif // ENABLE_QASYMM8_SIGNED_KERNELS
}
else
{
- _func = &output_stage_nhwc<uint8_t>;
+#ifdef ENABLE_QASYMM8_KERNELS
+ _func = &output_stage_nhwc_qu8;
+#endif // ENABLE_QASYMM8_KERNELS
}
break;
}
-#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+#ifdef ENABLE_FP16_KERNELS
case DataType::F16:
{
- _func = &output_stage_nhwc<float16_t>;
+ _func = &output_stage_nhwc_fp16;
break;
}
-#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
+#endif // ENABLE_FP16_KERNELS
case DataType::F32:
{
- _func = &output_stage_nhwc<float>;
+#ifdef ENABLE_FP32_KERNELS
+ _func = &output_stage_nhwc_fp32;
+#endif // ENABLE_FP32_KERNELS
break;
}
default:
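
Note: the templated output_stage_nchw<T> / output_stage_nhwc<T> implementations removed above are replaced by per-type entry points (output_stage_nchw_fp32, output_stage_nhwc_qs8, and so on) pulled in through directconv2d_output_stage/list.h and compiled only when the matching ENABLE_*_KERNELS flag is defined. The contents of that list header are not part of this hunk; a plausible shape, assuming it mirrors the DECLARE_* macro style of the conv3d/list.h added later in this patch and reuses the signature of the removed functions, might be:

namespace arm_compute
{
namespace cpu
{
namespace kernels
{

#define DECLARE_OUTPUT_STAGE_KERNEL(func_name)                                             \
    void func_name(ITensor *src, const ITensor *bias, const Window &window, ITensor *dst, \
                   int result_fixedpoint_multiplier, int result_shift, int result_offset_after_shift)

DECLARE_OUTPUT_STAGE_KERNEL(output_stage_nchw_fp32);
DECLARE_OUTPUT_STAGE_KERNEL(output_stage_nchw_fp16);
DECLARE_OUTPUT_STAGE_KERNEL(output_stage_nchw_qu8);
DECLARE_OUTPUT_STAGE_KERNEL(output_stage_nchw_qs8);
DECLARE_OUTPUT_STAGE_KERNEL(output_stage_nhwc_fp32);
DECLARE_OUTPUT_STAGE_KERNEL(output_stage_nhwc_fp16);
DECLARE_OUTPUT_STAGE_KERNEL(output_stage_nhwc_qu8);
DECLARE_OUTPUT_STAGE_KERNEL(output_stage_nhwc_qs8);
#undef DECLARE_OUTPUT_STAGE_KERNEL

} // namespace kernels
} // namespace cpu
} // namespace arm_compute
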
diff --git a/src/cpu/kernels/CpuDirectConv3dKernel.cpp b/src/cpu/kernels/CpuDirectConv3dKernel.cpp
index b5b2aed1ba..9c37ece3dd 100644
--- a/src/cpu/kernels/CpuDirectConv3dKernel.cpp
+++ b/src/cpu/kernels/CpuDirectConv3dKernel.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2021-2022 Arm Limited.
+ * Copyright (c) 2021-2022, 2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -25,8 +25,8 @@
#include "arm_compute/core/Error.h"
#include "arm_compute/core/Helpers.h"
-#include "arm_compute/core/IAccessWindow.h"
#include "arm_compute/core/ITensor.h"
+#include "arm_compute/core/Steps.h"
#include "arm_compute/core/Types.h"
#include "arm_compute/core/Utils.h"
#include "arm_compute/core/utils/misc/ShapeCalculator.h"
@@ -35,10 +35,8 @@
#include "src/core/common/Registrars.h"
#include "src/core/CPP/Validate.h"
#include "src/core/helpers/AutoConfiguration.h"
-#include "src/core/NEON/wrapper/wrapper.h"
-#include "src/cpu/kernels/conv3d/neon/list.h"
-
-#include <algorithm>
+#include "src/core/helpers/WindowHelpers.h"
+#include "src/cpu/kernels/conv3d/list.h"
using namespace arm_compute::detail;
@@ -51,18 +49,16 @@ namespace kernels
namespace
{
static const std::vector<CpuDirectConv3dKernel::DirectConv3dKernel> available_kernels = {
-#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
{"neon_fp16_directconv3d",
[](const DataTypeISASelectorData &data) { return data.dt == DataType::F16 && data.isa.fp16; },
- REGISTER_FP16_NEON(arm_compute::cpu::directconv3d_float_neon_ndhwc<float16_t>)},
-#endif /* !defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) */
+ REGISTER_FP16_NEON(directconv3d_fp16_neon_ndhwc)},
{"neon_fp32_directconv3d", [](const DataTypeISASelectorData &data) { return data.dt == DataType::F32; },
- REGISTER_FP32_NEON(arm_compute::cpu::directconv3d_float_neon_ndhwc<float>)},
+ REGISTER_FP32_NEON(directconv3d_fp32_neon_ndhwc)},
{"neon_qasymm8_directconv3d", [](const DataTypeISASelectorData &data) { return data.dt == DataType::QASYMM8; },
- REGISTER_QASYMM8_NEON(arm_compute::cpu::directconv3d_quantized_neon_ndhwc<uint8_t>)},
+ REGISTER_QASYMM8_NEON(directconv3d_qu8_neon_ndhwc)},
{"neon_qasymm8_signed_directconv3d",
[](const DataTypeISASelectorData &data) { return data.dt == DataType::QASYMM8_SIGNED; },
- REGISTER_QASYMM8_SIGNED_NEON(arm_compute::cpu::directconv3d_quantized_neon_ndhwc<int8_t>)}};
+ REGISTER_QASYMM8_SIGNED_NEON(directconv3d_qs8_neon_ndhwc)}};
Status validate_arguments(const ITensorInfo *src0,
const ITensorInfo *src1,
diff --git a/src/cpu/kernels/CpuGemmLowpMatrixMultiplyKernel.cpp b/src/cpu/kernels/CpuGemmLowpMatrixMultiplyKernel.cpp
index 5b88735e7a..87340e566e 100644
--- a/src/cpu/kernels/CpuGemmLowpMatrixMultiplyKernel.cpp
+++ b/src/cpu/kernels/CpuGemmLowpMatrixMultiplyKernel.cpp
@@ -684,6 +684,10 @@ Status validate_arguments(const ITensorInfo *src0, const ITensorInfo *src1, cons
DataType::U8);
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(dst, 1, DataType::S32);
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(src0->data_type() == DataType::QASYMM8_SIGNED &&
+ src1->data_type() == DataType::QASYMM8,
+ "QASYMM8_SIGNED input with QASYMM8 weights not supported");
+
TensorShape in0_shape = src0->tensor_shape();
TensorShape in1_shape = src1->tensor_shape();
TensorShape out_shape = dst->tensor_shape();
diff --git a/src/cpu/kernels/CpuKernelSelectionTypes.h b/src/cpu/kernels/CpuKernelSelectionTypes.h
index 7c1e4772a6..96ddad9d19 100644
--- a/src/cpu/kernels/CpuKernelSelectionTypes.h
+++ b/src/cpu/kernels/CpuKernelSelectionTypes.h
@@ -105,6 +105,13 @@ struct SoftmaxKernelDataTypeISASelectorData
cpuinfo::CpuIsaInfo isa;
bool is_log;
int axis;
+ uint64_t sme2_vector_length;
+};
+
+struct ScatterKernelDataTypeISASelectorData
+{
+ DataType dt;
+ cpuinfo::CpuIsaInfo isa;
unsigned long sme2_vector_length;
};
@@ -124,6 +131,8 @@ using ScaleKernelDataTypeISASelectorDataPtr =
std::add_pointer<bool(const ScaleKernelDataTypeISASelectorData &data)>::type;
using SoftmaxKernelDataTypeISASelectorDataPtr =
std::add_pointer<bool(const SoftmaxKernelDataTypeISASelectorData &data)>::type;
+using ScatterKernelDataTypeISASelectorDataPtr =
+ std::add_pointer<bool(const ScatterKernelDataTypeISASelectorData &data)>::type;
} // namespace kernels
} // namespace cpu
} // namespace arm_compute
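
Note: ScatterKernelDataTypeISASelectorData mirrors the other selector structs, carrying a data type, the detected ISA features and the SME2 vector length. A selector predicate over it would look like the existing ones; the example below is purely hypothetical, since this patch registers no scatter micro-kernels yet:

#include "src/cpu/kernels/CpuKernelSelectionTypes.h"

// Hypothetical predicate over the new selector type, for illustration only.
bool scatter_fp32_sme2_selected(const arm_compute::cpu::kernels::ScatterKernelDataTypeISASelectorData &data)
{
    return data.dt == arm_compute::DataType::F32 && data.isa.sme2 && data.sme2_vector_length > 0;
}
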
diff --git a/src/cpu/kernels/CpuMulKernel.cpp b/src/cpu/kernels/CpuMulKernel.cpp
index 8001482154..d7a3a77d51 100644
--- a/src/cpu/kernels/CpuMulKernel.cpp
+++ b/src/cpu/kernels/CpuMulKernel.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2016-2023 Arm Limited.
+ * Copyright (c) 2016-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -34,6 +34,7 @@
#include "src/core/NEON/NESymm.h"
#include "src/core/NEON/wrapper/wrapper.h"
#include "src/cpu/kernels/mul/generic/neon/list.h"
+#include "src/cpu/kernels/mul/generic/sme2/list.h"
#include <arm_neon.h>
@@ -317,6 +318,41 @@ void mul_saturate_quantized_8(const ITensor *src1, const ITensor *src2, ITensor
}
}
+bool mul_q8_sme_possible(const ITensorInfo *src0, const ITensorInfo *src1, const ITensorInfo *dst, float scale)
+{
+ const auto &in0_shape = src0->tensor_shape();
+ const auto &in1_shape = src1->tensor_shape();
+ const unsigned int dst_dims = dst->num_dimensions();
+
+ // Calculate Scale
+ const auto iq0 = src0->quantization_info().uniform();
+ const auto iq1 = src1->quantization_info().uniform();
+ const auto oq = dst->quantization_info().uniform();
+ const auto multiplier = ((iq0.scale * iq1.scale) / oq.scale) * scale;
+ const auto max_result = multiplier * (127) * (127) + static_cast<float>(oq.offset);
+ const auto min_result = multiplier * (-128) * (-128) + static_cast<float>(oq.offset);
+
+ // Does not support broadcasting on x
+ // Does not support dims > 4D output, unless input shapes are identical (therefore collapsible)
+ // Checks whether CPU has SME2 Available
+ if (in0_shape.x() == in1_shape.x() && CPUInfo::get().has_sme2() && (in0_shape == in1_shape || dst_dims <= 4))
+ {
+ // Check if multiplier cannot be stored as a 14.18 signed fixed-point number
+ if (multiplier < -8191.f || multiplier > 8191.f)
+ {
+ return false;
+ }
+ // It might not be possible to store the result as a 14.18 signed fixed-point number.
+ if (max_result > 8191.f || min_result < -8191.f)
+ {
+ return false;
+ }
+ // Passed all checks
+ return true;
+ }
+ return false;
+}
+
bool mul_q8_neon_fixedpoint_possible(const ITensorInfo *src0,
const ITensorInfo *src1,
const ITensorInfo *dst,
@@ -1563,7 +1599,11 @@ void CpuMulKernel::configure(ITensorInfo *src1,
case DataType::QASYMM8_SIGNED:
if (dt_input2 == DataType::QASYMM8_SIGNED)
{
- if (mul_q8_neon_fixedpoint_possible(src1, src2, dst, scale))
+ if (mul_q8_sme_possible(src1, src2, dst, scale) && rounding_policy == RoundingPolicy::TO_ZERO)
+ {
+ _func_quantized = REGISTER_QASYMM8_SIGNED_SME2(arm_compute::cpu::sme2_q8_signed_mul);
+ }
+ else if (mul_q8_neon_fixedpoint_possible(src1, src2, dst, scale))
{
_func_quantized = &mul_q8_neon_fixedpoint<int8_t>;
}
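
Note: mul_q8_sme_possible guards against values that the SME2 multiply kernel presumably cannot represent in signed 14.18 fixed point: with 14 integer bits (including the sign) and 18 fractional bits the representable range is roughly [-8192, 8192), which the code approximates conservatively as +/-8191 for both the requantisation multiplier and the worst-case results. A standalone sketch of the same arithmetic, with made-up quantisation parameters:

#include <iostream>

// Standalone sketch of the representability check in mul_q8_sme_possible().
// Assumption: the SME2 kernel keeps the requantisation multiplier and the
// worst-case result as signed 14.18 fixed point, i.e. values in roughly
// [-8192, 8192), checked conservatively as +/-8191.
bool fits_q14_18(float value)
{
    return value >= -8191.f && value <= 8191.f;
}

int main()
{
    // Made-up quantisation parameters, purely for illustration.
    const float iq0_scale = 0.02f, iq1_scale = 0.03f, oq_scale = 0.05f, scale = 1.f;
    const int   oq_offset = 10;

    const float multiplier = ((iq0_scale * iq1_scale) / oq_scale) * scale;
    const float max_result = multiplier * 127.f * 127.f + static_cast<float>(oq_offset);
    const float min_result = multiplier * -128.f * -128.f + static_cast<float>(oq_offset);

    const bool sme2_path_ok = fits_q14_18(multiplier) && fits_q14_18(max_result) && fits_q14_18(min_result);
    std::cout << std::boolalpha << sme2_path_ok << '\n'; // true for these parameters
    return 0;
}
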
diff --git a/src/cpu/kernels/CpuPermuteKernel.cpp b/src/cpu/kernels/CpuPermuteKernel.cpp
index b444a25ff7..c6e0dd3a5e 100644
--- a/src/cpu/kernels/CpuPermuteKernel.cpp
+++ b/src/cpu/kernels/CpuPermuteKernel.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018-2021 Arm Limited.
+ * Copyright (c) 2018-2021, 2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -97,15 +97,12 @@ Status validate_arguments(const ITensorInfo *src, const ITensorInfo *dst, const
template <typename T>
void run_permute(const Window &window, const ITensor *src, const ITensor *dst, const PermutationVector &perm)
{
- const DataLayout src_layout = src->info()->data_layout();
-
// Source window
Window window_src = window;
// we only support these two configs in src/core/NEON/kernels/convolution/common/shims.hpp, for all others
// we have to fall back to C++
- if ((src_layout == DataLayout::NCHW && perm == PermutationVector{2U, 0U, 1U}) ||
- (src_layout == DataLayout::NHWC && perm == PermutationVector{1U, 2U, 0U}))
+ if (perm == PermutationVector{2U, 0U, 1U} || perm == PermutationVector{1U, 2U, 0U})
{
window_src.set(Window::DimX,
Window::Dimension(window.x().start(), window.x().end(), window.x().end() - window.x().start()));
@@ -128,49 +125,16 @@ void run_permute(const Window &window, const ITensor *src, const ITensor *dst, c
Iterator src_it(src, window_src);
Iterator dst_it(dst, window_dst);
- int in_row_stride = 0;
- int in_col_stride = 0;
- int in_channel_stride = 0;
- int in_batch_stride = 0;
- int n_cols = 0;
- int n_rows = 0;
- int n_channels = 0;
- int n_batches = 0;
-
- switch (src_layout)
- {
- case DataLayout::NCHW:
- {
- in_row_stride = src->info()->strides_in_bytes().y() / sizeof(T);
- in_channel_stride = src->info()->strides_in_bytes().z() / sizeof(T);
- in_batch_stride = src->info()->strides_in_bytes()[3] / sizeof(T);
- n_cols = src->info()->tensor_shape().x();
- n_rows = window_src.y().step();
- n_channels = src->info()->tensor_shape().z();
- n_batches = src->info()->tensor_shape()[3];
- break;
- }
- case DataLayout::NHWC:
- {
- in_col_stride = src->info()->strides_in_bytes().y() / sizeof(T);
- in_row_stride = src->info()->strides_in_bytes().z() / sizeof(T);
- in_batch_stride = src->info()->strides_in_bytes()[3] / sizeof(T);
- n_channels = src->info()->tensor_shape().x();
- n_cols = window_src.y().step();
- n_rows = src->info()->tensor_shape().z();
- n_batches = src->info()->tensor_shape()[3];
- break;
- }
- default:
- {
- ARM_COMPUTE_ERROR("Invalid source data layout.");
- break;
- }
- }
-
// CHW -> HWC
- if (src_layout == DataLayout::NCHW && perm == PermutationVector{2U, 0U, 1U})
+ if (perm == PermutationVector{2U, 0U, 1U})
{
+ const int in_row_stride = src->info()->strides_in_bytes().y() / sizeof(T);
+ const int in_channel_stride = src->info()->strides_in_bytes().z() / sizeof(T);
+ const int in_batch_stride = src->info()->strides_in_bytes()[3] / sizeof(T);
+ const int n_cols = src->info()->tensor_shape().x();
+ const int n_rows = window_src.y().step();
+ const int n_channels = src->info()->tensor_shape().z();
+ const int n_batches = src->info()->tensor_shape()[3];
const int out_channel_stride = dst->info()->strides_in_bytes().x() / sizeof(T);
const int out_col_stride = dst->info()->strides_in_bytes().y() / sizeof(T);
const int out_row_stride = dst->info()->strides_in_bytes().z() / sizeof(T);
@@ -188,8 +152,15 @@ void run_permute(const Window &window, const ITensor *src, const ITensor *dst, c
src_it, dst_it);
}
// HWC -> CHW
- else if (src_layout == DataLayout::NHWC && perm == PermutationVector{1U, 2U, 0U})
+ else if (perm == PermutationVector{1U, 2U, 0U})
{
+ const int in_col_stride = src->info()->strides_in_bytes().y() / sizeof(T);
+ const int in_row_stride = src->info()->strides_in_bytes().z() / sizeof(T);
+ const int in_batch_stride = src->info()->strides_in_bytes()[3] / sizeof(T);
+ const int n_channels = src->info()->tensor_shape().x();
+ const int n_cols = window_src.y().step();
+ const int n_rows = src->info()->tensor_shape().z();
+ const int n_batches = src->info()->tensor_shape()[3];
const int out_col_stride = dst->info()->strides_in_bytes().x() / sizeof(T);
const int out_row_stride = dst->info()->strides_in_bytes().y() / sizeof(T);
const int out_channel_stride = dst->info()->strides_in_bytes().z() / sizeof(T);
diff --git a/src/cpu/kernels/CpuReshapeKernel.cpp b/src/cpu/kernels/CpuReshapeKernel.cpp
index 241e58fbce..78b08f19d2 100644
--- a/src/cpu/kernels/CpuReshapeKernel.cpp
+++ b/src/cpu/kernels/CpuReshapeKernel.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2023 Arm Limited.
+ * Copyright (c) 2017-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -233,7 +233,9 @@ void CpuReshapeKernel::prepare(ITensorPack &tensors)
if (!src_has_holes && !dst_has_holes)
{
- std::tie(win, _split_dimension) = calculate_squashed_or_max_window(*dst_info);
+ size_t split_dimension;
+
+ std::tie(win, split_dimension) = calculate_squashed_or_max_window(*dst_info);
/*
Copy the tensor per window. If the src and dst tensors
are contiguous memory allocations without any holes or
@@ -241,7 +243,15 @@ void CpuReshapeKernel::prepare(ITensorPack &tensors)
we can use a single memcopy call to copy the whole
window in reshape_tensor_per_window fn
*/
- _reshape_tensor_fn = reshape_tensor_per_window;
+ if (split_dimension != Window::DimY)
+ {
+ // Fall back when split dimension doesn't equal Window::DimY
+ _reshape_tensor_fn = reshape_tensor_per_row;
+ }
+ else
+ {
+ _reshape_tensor_fn = reshape_tensor_per_window;
+ }
}
else
{
diff --git a/src/cpu/kernels/CpuReshapeKernel.h b/src/cpu/kernels/CpuReshapeKernel.h
index ce566fd9e2..acad441b76 100644
--- a/src/cpu/kernels/CpuReshapeKernel.h
+++ b/src/cpu/kernels/CpuReshapeKernel.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2023 Arm Limited.
+ * Copyright (c) 2017-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -21,8 +21,8 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
-#ifndef ARM_COMPUTE_CPU_RESHAPE_KERNEL_H
-#define ARM_COMPUTE_CPU_RESHAPE_KERNEL_H
+#ifndef ACL_SRC_CPU_KERNELS_CPURESHAPEKERNEL_H
+#define ACL_SRC_CPU_KERNELS_CPURESHAPEKERNEL_H
#include "src/core/common/Macros.h"
#include "src/cpu/ICpuKernel.h"
@@ -74,21 +74,10 @@ public:
*/
size_t get_mws(const CPUInfo &platform, size_t thread_count) const override;
- /** Get the preferred dimension in which the scheduler splits the work into multiple jobs.
- *
- * @return The split dimension.
- */
- size_t get_split_dimension() const
- {
- return _split_dimension;
- }
-
private:
- size_t _split_dimension{Window::DimY};
-
std::function<void(const Window &window, const ITensor *src, ITensor *dst)> _reshape_tensor_fn{};
};
} // namespace kernels
} // namespace cpu
} // namespace arm_compute
-#endif /* ARM_COMPUTE_CPU_RESHAPE_KERNEL_H */
+#endif // ACL_SRC_CPU_KERNELS_CPURESHAPEKERNEL_H
diff --git a/src/cpu/kernels/CpuScatterKernel.cpp b/src/cpu/kernels/CpuScatterKernel.cpp
new file mode 100644
index 0000000000..bc0fa724b0
--- /dev/null
+++ b/src/cpu/kernels/CpuScatterKernel.cpp
@@ -0,0 +1,98 @@
+/*
+ * Copyright (c) 2024 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#include "src/cpu/kernels/CpuScatterKernel.h"
+
+#include "arm_compute/core/Error.h"
+#include "arm_compute/core/TensorInfo.h"
+
+#include <vector>
+
+namespace arm_compute
+{
+namespace cpu
+{
+namespace kernels
+{
+namespace
+{
+
+/* Scatter */
+static const std::vector<typename CpuScatterKernel::ScatterKernel> available_kernels = {
+
+};
+
+} // namespace
+
+const std::vector<typename CpuScatterKernel::ScatterKernel> &CpuScatterKernel::get_available_kernels()
+{
+ return available_kernels;
+}
+
+void CpuScatterKernel::configure(const ITensorInfo *src,
+ const ITensorInfo *updates,
+ const ITensorInfo *indices,
+ ITensorInfo *dst,
+ const ScatterInfo &info)
+{
+ ARM_COMPUTE_UNUSED(src);
+ ARM_COMPUTE_UNUSED(updates);
+ ARM_COMPUTE_UNUSED(indices);
+ ARM_COMPUTE_UNUSED(dst);
+ ARM_COMPUTE_UNUSED(info);
+
+ ARM_COMPUTE_UNUSED(_run_method);
+}
+
+Status CpuScatterKernel::validate(const ITensorInfo *src,
+ const ITensorInfo *updates,
+ const ITensorInfo *indices,
+ const ITensorInfo *dst,
+ const ScatterInfo &info)
+{
+ ARM_COMPUTE_UNUSED(src);
+ ARM_COMPUTE_UNUSED(updates);
+ ARM_COMPUTE_UNUSED(indices);
+ ARM_COMPUTE_UNUSED(dst);
+ ARM_COMPUTE_UNUSED(info);
+
+ return Status{ErrorCode::RUNTIME_ERROR, "No configuration implemented yet."};
+}
+
+void CpuScatterKernel::run_op(ITensorPack &tensors, const Window &window, const ThreadInfo &info)
+{
+ ARM_COMPUTE_UNUSED(tensors);
+ ARM_COMPUTE_UNUSED(window);
+ ARM_COMPUTE_UNUSED(info);
+
+ ARM_COMPUTE_UNUSED(_run_method);
+}
+
+const char *CpuScatterKernel::name() const
+{
+ return _name.c_str();
+}
+
+} // namespace kernels
+} // namespace cpu
+} // namespace arm_compute
diff --git a/src/cpu/kernels/CpuScatterKernel.h b/src/cpu/kernels/CpuScatterKernel.h
new file mode 100644
index 0000000000..77c436034f
--- /dev/null
+++ b/src/cpu/kernels/CpuScatterKernel.h
@@ -0,0 +1,91 @@
+/*
+ * Copyright (c) 2024 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#ifndef ACL_SRC_CPU_KERNELS_CPUSCATTERKERNEL_H
+#define ACL_SRC_CPU_KERNELS_CPUSCATTERKERNEL_H
+
+#include "arm_compute/function_info/ScatterInfo.h"
+
+#include "src/core/common/Macros.h"
+#include "src/cpu/ICpuKernel.h"
+
+namespace arm_compute
+{
+namespace cpu
+{
+namespace kernels
+{
+/** Arm(R) Neon(TM) kernel to perform the ScatterND operation */
+class CpuScatterKernel : public ICpuKernel<CpuScatterKernel>
+{
+private:
+ using ScatterKernelPtr = std::add_pointer<void(
+ const ITensor *, const ITensor *, const ITensor *, ITensor *, const ScatterInfo, const Window &)>::type;
+
+public:
+ CpuScatterKernel() = default;
+ ARM_COMPUTE_DISALLOW_COPY_ALLOW_MOVE(CpuScatterKernel);
+ /** Initialise the kernel's input and output.
+ *
+ * @param[in] src Input tensor info for the source matrix.
+ * @param[in] updates Input tensor info for the Update matrix. Data type supported: same as @p src
+ * @param[in] indices Input tensor info for the Indices matrix. Data type supported: U32.
+ * @param[out] dst Output tensor info. Data type supported: same as @p src
+ * @param[in] info Attributes for Scatter Kernel
+ */
+ void configure(const ITensorInfo *src,
+ const ITensorInfo *updates,
+ const ITensorInfo *indices,
+ ITensorInfo *dst,
+ const ScatterInfo &info);
+ /** Static function to check if given info will lead to a valid configuration
+ *
+ * Similar to @ref CpuScatterKernel::configure()
+ *
+ * @return a status
+ */
+ static Status validate(const ITensorInfo *src,
+ const ITensorInfo *updates,
+ const ITensorInfo *indices,
+ const ITensorInfo *dst,
+ const ScatterInfo &info);
+
+ // Inherited methods overridden:
+ void run_op(ITensorPack &tensors, const Window &window, const ThreadInfo &info) override;
+ const char *name() const override;
+ struct ScatterKernel
+ {
+ const char *name;
+ ScatterKernelPtr ukernel;
+ };
+
+ static const std::vector<ScatterKernel> &get_available_kernels();
+
+private:
+ ScatterKernelPtr _run_method{nullptr};
+ std::string _name{};
+};
+} // namespace kernels
+} // namespace cpu
+} // namespace arm_compute
+#endif // ACL_SRC_CPU_KERNELS_CPUSCATTERKERNEL_H
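
Note: at this stage CpuScatterKernel is scaffolding: the available-kernels vector is empty, configure() and run_op() are no-ops, and validate() unconditionally returns RUNTIME_ERROR. A sketch of what a caller would currently observe; the tensor-info and ScatterInfo setup is assumed to happen elsewhere and the names below are placeholders:

#include "arm_compute/function_info/ScatterInfo.h"

#include "src/cpu/kernels/CpuScatterKernel.h"

void probe_scatter_kernel(const arm_compute::ITensorInfo &src_info,
                          const arm_compute::ITensorInfo &updates_info,
                          const arm_compute::ITensorInfo &indices_info,
                          const arm_compute::ITensorInfo &dst_info,
                          const arm_compute::ScatterInfo &scatter_info)
{
    using arm_compute::cpu::kernels::CpuScatterKernel;

    // Empty in this patch: no micro-kernels are registered yet.
    const auto &kernels = CpuScatterKernel::get_available_kernels();
    (void)kernels;

    // Expected to report RUNTIME_ERROR ("No configuration implemented yet.")
    // until a real configuration lands.
    const arm_compute::Status status =
        CpuScatterKernel::validate(&src_info, &updates_info, &indices_info, &dst_info, scatter_info);
    (void)status;
}
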
diff --git a/src/cpu/kernels/CpuWeightsReshapeKernel.cpp b/src/cpu/kernels/CpuWeightsReshapeKernel.cpp
index 297ba63826..f8e9d123b5 100644
--- a/src/cpu/kernels/CpuWeightsReshapeKernel.cpp
+++ b/src/cpu/kernels/CpuWeightsReshapeKernel.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2021 Arm Limited.
+ * Copyright (c) 2017-2021, 2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -72,7 +72,10 @@ Status validate_arguments(const ITensorInfo *src, const ITensorInfo *biases, con
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DIMENSIONS(dst->tensor_shape(),
get_output_shape(src, biases != nullptr));
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(src, dst);
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_QUANTIZATION_INFO(src, dst);
+ if (!src->quantization_info().is_dynamic())
+ {
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_QUANTIZATION_INFO(src, dst);
+ }
}
return Status{};
diff --git a/src/cpu/kernels/assembly/CpuGemmAssemblyWrapperKernel.h b/src/cpu/kernels/assembly/CpuGemmAssemblyWrapperKernel.h
index e2a27675b3..72fafca1bb 100644
--- a/src/cpu/kernels/assembly/CpuGemmAssemblyWrapperKernel.h
+++ b/src/cpu/kernels/assembly/CpuGemmAssemblyWrapperKernel.h
@@ -52,7 +52,7 @@ namespace kernel
*
*
*/
-template <typename TypeInput, typename TypeOutput>
+template <typename TypeInput, typename TypeWeight, typename TypeOutput>
class CpuGemmAssemblyWrapperKernel final : public INEKernel
{
public:
@@ -101,7 +101,7 @@ public:
* @param[in] kernel Pointer to an assembly kernel implementation.
* @param[in] kernel_name_tag Tag to be attached to the kernel's name.
*/
- void configure(arm_gemm::GemmCommon<TypeInput, TypeOutput> *kernel, std::string kernel_name_tag)
+ void configure(arm_gemm::GemmCommon<TypeInput, TypeWeight, TypeOutput> *kernel, std::string kernel_name_tag)
{
ARM_COMPUTE_ERROR_ON_NULLPTR((reinterpret_cast<void *>(kernel)));
_kernel = kernel;
@@ -131,8 +131,8 @@ public:
}
private:
- arm_gemm::GemmCommon<TypeInput, TypeOutput> *_kernel;
- std::string _name;
+ arm_gemm::GemmCommon<TypeInput, TypeWeight, TypeOutput> *_kernel;
+ std::string _name;
};
} // namespace kernel
} // namespace cpu
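
Note: CpuGemmAssemblyWrapperKernel now carries an explicit weight type alongside the input and output types, so the LHS and the weights of an assembly GEMM no longer have to share a data type. A hypothetical alias showing the shape of a mixed-type instantiation; whether this exact combination is used anywhere is not shown in this patch:

#include <cstdint>

#include "src/cpu/kernels/assembly/CpuGemmAssemblyWrapperKernel.h"

// Hypothetical alias only: a wrapper around an assembly GEMM whose input and
// weight types differ.
using MixedSignGemmWrapper =
    arm_compute::cpu::kernel::CpuGemmAssemblyWrapperKernel<uint8_t /*TypeInput*/, int8_t /*TypeWeight*/, int32_t /*TypeOutput*/>;
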
diff --git a/src/cpu/kernels/assembly/arm_gemm.hpp b/src/cpu/kernels/assembly/arm_gemm.hpp
index 941fed0ba8..cbc8be416e 100644
--- a/src/cpu/kernels/assembly/arm_gemm.hpp
+++ b/src/cpu/kernels/assembly/arm_gemm.hpp
@@ -277,8 +277,8 @@ struct Nothing
{
};
-template <typename Top, typename Tret>
-using UniqueGemmCommon = std::unique_ptr<GemmCommon<Top, Tret>>;
+template <typename Tlop, typename Trop, typename Tret>
+using UniqueGemmCommon = std::unique_ptr<GemmCommon<Tlop, Trop, Tret>>;
/* Low level API calls.
* These are implemented as 'GemmArgs' versions, or with the arguments explicitly listed. */
@@ -288,13 +288,13 @@ using UniqueGemmCommon = std::unique_ptr<GemmCommon<Top, Tret>>;
template <typename Top, typename Tret, class OutputStage = Nothing>
KernelDescription get_gemm_method(const GemmArgs &args, const OutputStage & = {});
-template <typename Top, typename Tret, class OutputStage = Nothing>
-UniqueGemmCommon<Top, Tret> gemm(const GemmArgs &args, const OutputStage & = {});
+template <typename Tlop, typename Trop, typename Tret, class OutputStage = Nothing>
+UniqueGemmCommon<Tlop, Trop, Tret> gemm(const GemmArgs &args, const OutputStage & = {});
-template <typename Top, typename Tret, class OutputStage = Nothing>
+template <typename Tlop, typename Trop, typename Tret, class OutputStage = Nothing>
std::vector<KernelDescription> get_compatible_kernels(const GemmArgs &args, const OutputStage & = {});
-template <typename Top, typename Tret, class OutputStage = Nothing>
+template <typename Tlop, typename Trop, typename Tret, class OutputStage = Nothing>
bool has_opt_gemm(WeightFormat &weight_format, const GemmArgs &args, const OutputStage & = {});
} // namespace arm_gemm
diff --git a/src/cpu/kernels/assembly/convolution_parameters.hpp b/src/cpu/kernels/assembly/convolution_parameters.hpp
index a6cf96344c..09b73ca409 100644
--- a/src/cpu/kernels/assembly/convolution_parameters.hpp
+++ b/src/cpu/kernels/assembly/convolution_parameters.hpp
@@ -61,6 +61,8 @@ struct ConvolutionParameters
int64_t output_stride_w;
int64_t output_stride_h;
// output_channels not included as they do not affect the input.
+ int64_t dilation_w;
+ int64_t dilation_h;
int64_t padding_top;
int64_t padding_left;
float padding_value;
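
Note: the new dilation_w / dilation_h fields let the assembly convolution describe dilated kernels. The relation they presumably feed into is that a dilated kernel's effective footprint along one axis is (kernel_size - 1) * dilation + 1; a tiny self-contained illustration:

#include <cstdint>

// Effective footprint of a dilated kernel along one axis: a 3-tap kernel with
// dilation 2 covers (3 - 1) * 2 + 1 = 5 input positions.
constexpr int64_t effective_kernel_size(int64_t kernel_size, int64_t dilation)
{
    return (kernel_size - 1) * dilation + 1;
}

static_assert(effective_kernel_size(3, 1) == 3, "dilation 1 is a dense kernel");
static_assert(effective_kernel_size(3, 2) == 5, "3-tap kernel with dilation 2 spans 5 inputs");
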
diff --git a/src/cpu/kernels/assembly/gemm_common.hpp b/src/cpu/kernels/assembly/gemm_common.hpp
index 45d1e43274..d5676c134e 100644
--- a/src/cpu/kernels/assembly/gemm_common.hpp
+++ b/src/cpu/kernels/assembly/gemm_common.hpp
@@ -35,6 +35,7 @@ namespace arm_gemm
{
// Avoid circular dependency with arm_gemm.hpp
struct GemmConfig;
+struct Requantize32;
// Abstract class for the GEMM/GEMV functions.
//
@@ -160,6 +161,12 @@ public:
{
}
+ /*** "Quantization update" interface (optional) ***/
+ /* Update quantization parameters at run time */
+ virtual void update_quantization_parameters(const Requantize32 &)
+ {
+ }
+
/*** Convolution interface (optional) ***/
/* Set the convolution parameters. */
virtual void set_convolution_parameters(ConvolutionParameters)
@@ -189,7 +196,7 @@ public:
* 'set_arrays' to capture the provided arguments in protected class
* members, as essentially any implementation will need these.
*/
-template <typename To, typename Tr>
+template <typename To, typename Tw, typename Tr>
class GemmCommon : public IGemmCommon
{
protected:
@@ -197,7 +204,7 @@ protected:
int _lda = 0;
int _A_batch_stride = 0;
int _A_multi_stride = 0;
- const To *_Bptr = nullptr;
+ const Tw *_Bptr = nullptr;
int _ldb = 0;
int _B_multi_stride = 0;
Tr *_Cptr = nullptr;
@@ -214,7 +221,7 @@ public:
const int lda,
const int A_batch_stride,
const int A_multi_stride,
- const To *B,
+ const Tw *B,
const int ldb,
/* batches share B */ const int B_multi_stride,
Tr *C,
@@ -254,7 +261,7 @@ public:
const void *bias,
/* no row or batch stride needed */ const int bias_multi_stride) override
{
- set_arrays(static_cast<const To *>(A), lda, A_batch_stride, A_multi_stride, static_cast<const To *>(B), ldb,
+ set_arrays(static_cast<const To *>(A), lda, A_batch_stride, A_multi_stride, static_cast<const Tw *>(B), ldb,
B_multi_stride, static_cast<Tr *>(C), ldc, C_batch_stride, C_multi_stride,
static_cast<const Tr *>(bias), bias_multi_stride);
}
@@ -262,17 +269,17 @@ public:
/*** "Pretransposed" interface ***/
/* Compute col sums over all columns */
- virtual void requantize_bias(void *, const To *, const int, const int){};
+ virtual void requantize_bias(void *, const Tw *, const int, const int){};
/* Perform pretranspose - the void * passed in must remain allocated for the duration of any execute calls. */
/* Arguments are: output buffer pointer, source pointer, source row stride, source multi stride */
- virtual void pretranspose_B_array(void *, const To *, const int, const int, bool){};
+ virtual void pretranspose_B_array(void *, const Tw *, const int, const int, bool){};
/* Implementation of the void * overload which casts its arguments to the appropriate type. */
void pretranspose_B_array_generic(
void *out, const void *in, const int row_stride, const int multi_stride, bool transposed) override
{
- pretranspose_B_array(out, static_cast<const To *>(in), row_stride, multi_stride, transposed);
+ pretranspose_B_array(out, static_cast<const Tw *>(in), row_stride, multi_stride, transposed);
}
/* Threaded versions of the above.
@@ -280,7 +287,7 @@ public:
* just calls the non-threaded functions to do the work. This is valid as with window size of 1 the only
* legal values for start and end are 0 and 1 respectively. */
virtual void pretranspose_B_array_part(
- void *out, const To *in, const int row_stride, const int multi_stride, bool transposed, size_t, size_t)
+ void *out, const Tw *in, const int row_stride, const int multi_stride, bool transposed, size_t, size_t)
{
pretranspose_B_array(out, in, row_stride, multi_stride, transposed);
};
@@ -293,7 +300,7 @@ public:
size_t start,
size_t end) override
{
- pretranspose_B_array_part(out, static_cast<const To *>(in), row_stride, multi_stride, transposed, start, end);
+ pretranspose_B_array_part(out, static_cast<const Tw *>(in), row_stride, multi_stride, transposed, start, end);
}
/*** Indirect interface ***/
diff --git a/src/cpu/kernels/conv3d/neon/list.h b/src/cpu/kernels/conv3d/generic/neon/float_impl.h
index 082c60be29..5b5611a02f 100644
--- a/src/cpu/kernels/conv3d/neon/list.h
+++ b/src/cpu/kernels/conv3d/generic/neon/float_impl.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2021 Arm Limited.
+ * Copyright (c) 2021, 2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -21,21 +21,25 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
-#ifndef SRC_CORE_NEON_KERNELS_CONV3D_LIST_H
-#define SRC_CORE_NEON_KERNELS_CONV3D_LIST_H
+#ifndef ACL_SRC_CPU_KERNELS_CONV3D_GENERIC_NEON_FLOAT_IMPL_H
+#define ACL_SRC_CPU_KERNELS_CONV3D_GENERIC_NEON_FLOAT_IMPL_H
+#include "arm_compute/core/Helpers.h"
+#include "arm_compute/core/ITensor.h"
#include "arm_compute/core/Types.h"
-#include "arm_compute/core/utils/misc/Traits.h"
+#include "arm_compute/core/Window.h"
#include "arm_compute/runtime/FunctionDescriptors.h"
#include "src/core/helpers/WindowHelpers.h"
#include "src/core/NEON/wrapper/wrapper.h"
-#include "src/cpu/kernels/conv3d/neon/quantized.h"
namespace arm_compute
{
namespace cpu
{
+namespace kernels
+{
+
template <typename T>
void directconv3d_float_neon_ndhwc(const ITensor *src0,
const ITensor *src1,
@@ -192,6 +196,7 @@ void directconv3d_float_neon_ndhwc(const ITensor *src0,
out);
}
+} // namespace kernels
} // namespace cpu
} // namespace arm_compute
-#endif // SRC_CORE_NEON_KERNELS_CONV3D_LIST_H
+#endif // ACL_SRC_CPU_KERNELS_CONV3D_GENERIC_NEON_FLOAT_IMPL_H
diff --git a/src/cpu/kernels/conv3d/generic/neon/fp16.cpp b/src/cpu/kernels/conv3d/generic/neon/fp16.cpp
new file mode 100644
index 0000000000..1737556e51
--- /dev/null
+++ b/src/cpu/kernels/conv3d/generic/neon/fp16.cpp
@@ -0,0 +1,49 @@
+/*
+ * Copyright (c) 2024 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "src/cpu/kernels/conv3d/generic/neon/float_impl.h"
+
+#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && defined(ENABLE_FP16_KERNELS)
+
+namespace arm_compute
+{
+namespace cpu
+{
+namespace kernels
+{
+void directconv3d_fp16_neon_ndhwc(const ITensor *src0,
+ const ITensor *src1,
+ const ITensor *src2,
+ ITensor *dst,
+ const Conv3dInfo &conv_info,
+ const Window &window)
+{
+ directconv3d_float_neon_ndhwc<float16_t>(src0, src1, src2, dst, conv_info, window);
+}
+
+} // namespace kernels
+} // namespace cpu
+} // namespace arm_compute
+
+#endif // defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && defined(ENABLE_FP16_KERNELS)
diff --git a/src/cpu/kernels/conv3d/generic/neon/fp32.cpp b/src/cpu/kernels/conv3d/generic/neon/fp32.cpp
new file mode 100644
index 0000000000..1cd0793442
--- /dev/null
+++ b/src/cpu/kernels/conv3d/generic/neon/fp32.cpp
@@ -0,0 +1,46 @@
+/*
+ * Copyright (c) 2024 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "src/cpu/kernels/conv3d/generic/neon/float_impl.h"
+
+namespace arm_compute
+{
+namespace cpu
+{
+namespace kernels
+{
+
+void directconv3d_fp32_neon_ndhwc(const ITensor *src0,
+ const ITensor *src1,
+ const ITensor *src2,
+ ITensor *dst,
+ const Conv3dInfo &conv_info,
+ const Window &window)
+{
+ directconv3d_float_neon_ndhwc<float>(src0, src1, src2, dst, conv_info, window);
+}
+
+} // namespace kernels
+} // namespace cpu
+} // namespace arm_compute
diff --git a/src/cpu/kernels/conv3d/generic/neon/qasymm8.cpp b/src/cpu/kernels/conv3d/generic/neon/qasymm8.cpp
new file mode 100644
index 0000000000..d0cb6fc1c1
--- /dev/null
+++ b/src/cpu/kernels/conv3d/generic/neon/qasymm8.cpp
@@ -0,0 +1,46 @@
+/*
+ * Copyright (c) 2024 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "src/cpu/kernels/conv3d/generic/neon/quantized_impl.h"
+
+namespace arm_compute
+{
+namespace cpu
+{
+namespace kernels
+{
+
+void directconv3d_qu8_neon_ndhwc(const ITensor *src0,
+ const ITensor *src1,
+ const ITensor *src2,
+ ITensor *dst,
+ const Conv3dInfo &conv_info,
+ const Window &window)
+{
+ directconv3d_quantized_neon_ndhwc<uint8_t>(src0, src1, src2, dst, conv_info, window);
+}
+
+} // namespace kernels
+} // namespace cpu
+} // namespace arm_compute
diff --git a/src/cpu/kernels/conv3d/generic/neon/qasymm8_signed.cpp b/src/cpu/kernels/conv3d/generic/neon/qasymm8_signed.cpp
new file mode 100644
index 0000000000..adffc1a3f8
--- /dev/null
+++ b/src/cpu/kernels/conv3d/generic/neon/qasymm8_signed.cpp
@@ -0,0 +1,46 @@
+/*
+ * Copyright (c) 2024 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "src/cpu/kernels/conv3d/generic/neon/quantized_impl.h"
+
+namespace arm_compute
+{
+namespace cpu
+{
+namespace kernels
+{
+
+void directconv3d_qs8_neon_ndhwc(const ITensor *src0,
+ const ITensor *src1,
+ const ITensor *src2,
+ ITensor *dst,
+ const Conv3dInfo &conv_info,
+ const Window &window)
+{
+ directconv3d_quantized_neon_ndhwc<int8_t>(src0, src1, src2, dst, conv_info, window);
+}
+
+} // namespace kernels
+} // namespace cpu
+} // namespace arm_compute
diff --git a/src/cpu/kernels/conv3d/neon/quantized.h b/src/cpu/kernels/conv3d/generic/neon/quantized_impl.h
index f0fc9b5a71..b6b41035f8 100644
--- a/src/cpu/kernels/conv3d/neon/quantized.h
+++ b/src/cpu/kernels/conv3d/generic/neon/quantized_impl.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2021 Arm Limited.
+ * Copyright (c) 2021, 2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -21,9 +21,10 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
-#ifndef SRC_CORE_NEON_KERNELS_CONV3D_QUANTIZED_H
-#define SRC_CORE_NEON_KERNELS_CONV3D_QUANTIZED_H
+#ifndef ACL_SRC_CPU_KERNELS_CONV3D_GENERIC_NEON_QUANTIZED_IMPL_H
+#define ACL_SRC_CPU_KERNELS_CONV3D_GENERIC_NEON_QUANTIZED_IMPL_H
+#include "arm_compute/core/Helpers.h"
#include "arm_compute/core/Types.h"
#include "arm_compute/core/utils/misc/Traits.h"
#include "arm_compute/core/utils/quantization/AsymmHelpers.h"
@@ -37,6 +38,8 @@ namespace arm_compute
{
namespace cpu
{
+namespace kernels
+{
template <typename T>
void directconv3d_quantized_neon_ndhwc(const ITensor *src0,
const ITensor *src1,
@@ -270,6 +273,7 @@ void directconv3d_quantized_neon_ndhwc(const ITensor *src0,
},
out);
}
+} // namespace kernels
} // namespace cpu
} // namespace arm_compute
-#endif // SRC_CORE_NEON_KERNELS_CONV3D_QUANTIZED_H
+#endif // ACL_SRC_CPU_KERNELS_CONV3D_GENERIC_NEON_QUANTIZED_IMPL_H
diff --git a/src/cpu/kernels/conv3d/list.h b/src/cpu/kernels/conv3d/list.h
new file mode 100644
index 0000000000..256d28825d
--- /dev/null
+++ b/src/cpu/kernels/conv3d/list.h
@@ -0,0 +1,47 @@
+/*
+ * Copyright (c) 2024 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#ifndef ACL_SRC_CPU_KERNELS_CONV3D_LIST_H
+#define ACL_SRC_CPU_KERNELS_CONV3D_LIST_H
+
+namespace arm_compute
+{
+namespace cpu
+{
+namespace kernels
+{
+
+#define DECLARE_CONV3D_KERNEL(func_name) \
+ void func_name(const ITensor *src0, const ITensor *src1, const ITensor *src2, ITensor *dst, \
+ const Conv3dInfo &conv_info, const Window &window)
+
+DECLARE_CONV3D_KERNEL(directconv3d_fp16_neon_ndhwc);
+DECLARE_CONV3D_KERNEL(directconv3d_fp32_neon_ndhwc);
+DECLARE_CONV3D_KERNEL(directconv3d_qu8_neon_ndhwc);
+DECLARE_CONV3D_KERNEL(directconv3d_qs8_neon_ndhwc);
+#undef DECLARE_CONV3D_KERNEL
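+
+// For reference, each DECLARE_CONV3D_KERNEL(name) above expands to a plain
+// declaration of the form (shown here for the fp32 variant):
+//
+//   void directconv3d_fp32_neon_ndhwc(const ITensor *src0, const ITensor *src1, const ITensor *src2,
+//                                     ITensor *dst, const Conv3dInfo &conv_info, const Window &window);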
+
+} // namespace kernels
+} // namespace cpu
+} // namespace arm_compute
+#endif // ACL_SRC_CPU_KERNELS_CONV3D_LIST_H
diff --git a/src/cpu/kernels/directconv2d_output_stage/generic/neon/float_impl.h b/src/cpu/kernels/directconv2d_output_stage/generic/neon/float_impl.h
new file mode 100644
index 0000000000..266fa68ab2
--- /dev/null
+++ b/src/cpu/kernels/directconv2d_output_stage/generic/neon/float_impl.h
@@ -0,0 +1,186 @@
+/*
+ * Copyright (c) 2017-2021, 2024 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#ifndef ACL_SRC_CPU_KERNELS_DIRECTCONV2D_OUTPUT_STAGE_GENERIC_NEON_FLOAT_IMPL_H
+#define ACL_SRC_CPU_KERNELS_DIRECTCONV2D_OUTPUT_STAGE_GENERIC_NEON_FLOAT_IMPL_H
+
+#include "arm_compute/core/Helpers.h" // Iterator
+#include "arm_compute/core/ITensor.h"
+#include "arm_compute/core/Window.h"
+
+#include "src/core/NEON/wrapper/wrapper.h"
+
+#include <cstdint>
+
+namespace arm_compute
+{
+namespace cpu
+{
+namespace kernels
+{
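+// The two templates in this file implement the floating-point output stage of
+// direct convolution. The fixed-point parameters are only accepted to keep the
+// signature in line with the quantized path and are unused here: each kernel
+// adds the optional per-output-channel bias to the accumulator tensor and
+// stores the result.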
+template <typename T>
+void output_stage_nchw_fp(ITensor *src,
+ const ITensor *bias,
+ const Window &window,
+ ITensor *dst,
+ int result_fixedpoint_multiplier,
+ int result_shift,
+ int result_offset_after_shift)
+{
+ const bool has_bias = bias != nullptr;
+ /** SIMD vector tag type. */
+ using ExactTagType = typename wrapper::traits::neon_bitvector_tag_t<T, wrapper::traits::BitWidth::W128>;
+
+ ARM_COMPUTE_ERROR_ON(src->info()->data_layout() == DataLayout::UNKNOWN);
+ ARM_COMPUTE_UNUSED(result_fixedpoint_multiplier);
+ ARM_COMPUTE_UNUSED(result_shift);
+ ARM_COMPUTE_UNUSED(result_offset_after_shift);
+
+ const int window_start_x = window.x().start();
+ const int window_end_x = window.x().end();
+ const int window_step_x = 16 / src->info()->element_size();
+ Window win = window;
+ win.set(Window::DimX, Window::Dimension(0, 1, 1));
+
+ Iterator in(src, win);
+ Iterator out(dst, win);
+ execute_window_loop(
+ win,
+ [&](const Coordinates &id)
+ {
+ int x = window_start_x;
+ for (; x <= (window_end_x - window_step_x); x += window_step_x)
+ {
+ // Get bias and pointer to input
+ const auto in_ptr = reinterpret_cast<const T *>(in.ptr()) + x;
+ auto v_in = wrapper::vloadq(in_ptr);
+
+ // Accumulate bias
+ if (has_bias)
+ {
+ const auto vb = wrapper::vdup_n(
+ *reinterpret_cast<const T *>(bias->ptr_to_element(Coordinates(id.z()))), ExactTagType{});
+ v_in = wrapper::vadd(v_in, vb);
+ }
+
+ const auto out_ptr = reinterpret_cast<T *>(out.ptr()) + x;
+ wrapper::vstore(out_ptr, v_in);
+ }
+
+ // Left-overs loop
+ for (; x < window_end_x; ++x)
+ {
+ // Get bias and pointer to input
+ auto s_in = *(reinterpret_cast<const T *>(in.ptr()) + x);
+
+ // Accumulate bias
+ if (has_bias)
+ {
+ const auto b = *reinterpret_cast<const T *>(bias->ptr_to_element(Coordinates(id.z())));
+ s_in += b;
+ }
+
+ *(reinterpret_cast<T *>(out.ptr()) + x) = s_in;
+ }
+ },
+ in, out);
+}
+
+template <typename T>
+void output_stage_nhwc_fp(ITensor *src,
+ const ITensor *bias,
+ const Window &window,
+ ITensor *dst,
+ int result_fixedpoint_multiplier,
+ int result_shift,
+ int result_offset_after_shift)
+{
+ const bool has_bias = bias != nullptr;
+ ARM_COMPUTE_UNUSED(result_fixedpoint_multiplier);
+ ARM_COMPUTE_UNUSED(result_shift);
+ ARM_COMPUTE_UNUSED(result_offset_after_shift);
+
+ Window window_bias = window;
+ window_bias.set(Window::DimX, Window::Dimension(0, 1, 1));
+ window_bias.set(Window::DimY, Window::Dimension(0, 0, 0));
+ window_bias.set(Window::DimZ, Window::Dimension(0, 0, 0));
+ window_bias.set(3, Window::Dimension(0, 0, 0));
+
+ const int window_start_x = window.x().start();
+ const int window_end_x = window.x().end();
+ const int window_step_x = 16 / src->info()->element_size();
+ Window win = window;
+ win.set(Window::DimX, Window::Dimension(0, 1, 1));
+
+ Iterator in(src, win);
+ Iterator bi(bias, window_bias);
+ Iterator out(dst, win);
+
+ execute_window_loop(
+ win,
+ [&](const Coordinates &)
+ {
+ int x = window_start_x;
+ for (; x <= (window_end_x - window_step_x); x += window_step_x)
+ {
+ // Get bias and pointer to input
+ const auto in_ptr = reinterpret_cast<const T *>(in.ptr());
+ auto v_in = wrapper::vloadq(in_ptr + x);
+
+ // Accumulate bias
+ if (has_bias)
+ {
+ const auto bias_ptr = reinterpret_cast<T *>(bi.ptr()) + x;
+ v_in = wrapper::vadd(v_in, wrapper::vloadq(bias_ptr));
+ }
+
+ const auto out_ptr = reinterpret_cast<T *>(out.ptr());
+ wrapper::vstore(out_ptr + x, v_in);
+ }
+
+ // Left-overs loop
+ for (; x < window_end_x; ++x)
+ {
+ // Get bias and pointer to input
+ auto s_in = *(reinterpret_cast<const T *>(in.ptr()) + x);
+
+ // Accumulate bias
+ if (has_bias)
+ {
+ const auto bias_ptr = reinterpret_cast<T *>(bi.ptr()) + x;
+ s_in += *bias_ptr;
+ }
+
+ const auto out_ptr = reinterpret_cast<T *>(out.ptr());
+ *(out_ptr + x) = s_in;
+ }
+ },
+ in, bi, out);
+}
+
+} // namespace kernels
+} // namespace cpu
+} // namespace arm_compute
+
+#endif // ACL_SRC_CPU_KERNELS_DIRECTCONV2D_OUTPUT_STAGE_GENERIC_NEON_FLOAT_IMPL_H
diff --git a/src/cpu/kernels/directconv2d_output_stage/generic/neon/fp16.cpp b/src/cpu/kernels/directconv2d_output_stage/generic/neon/fp16.cpp
new file mode 100644
index 0000000000..d7550da721
--- /dev/null
+++ b/src/cpu/kernels/directconv2d_output_stage/generic/neon/fp16.cpp
@@ -0,0 +1,63 @@
+/*
+ * Copyright (c) 2024 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "src/cpu/kernels/directconv2d_output_stage/generic/neon/float_impl.h"
+
+#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && defined(ENABLE_FP16_KERNELS)
+
+namespace arm_compute
+{
+namespace cpu
+{
+namespace kernels
+{
+void output_stage_nhwc_fp16(ITensor *src,
+ const ITensor *bias,
+ const Window &window,
+ ITensor *dst,
+ int result_fixedpoint_multiplier,
+ int result_shift,
+ int result_offset_after_shift)
+{
+ output_stage_nhwc_fp<float16_t>(src, bias, window, dst, result_fixedpoint_multiplier, result_shift,
+ result_offset_after_shift);
+}
+
+void output_stage_nchw_fp16(ITensor *src,
+ const ITensor *bias,
+ const Window &window,
+ ITensor *dst,
+ int result_fixedpoint_multiplier,
+ int result_shift,
+ int result_offset_after_shift)
+{
+ output_stage_nchw_fp<float16_t>(src, bias, window, dst, result_fixedpoint_multiplier, result_shift,
+ result_offset_after_shift);
+}
+
+} // namespace kernels
+} // namespace cpu
+} // namespace arm_compute
+
+#endif // defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && defined(ENABLE_FP16_KERNELS)
diff --git a/src/cpu/kernels/directconv2d_output_stage/generic/neon/fp32.cpp b/src/cpu/kernels/directconv2d_output_stage/generic/neon/fp32.cpp
new file mode 100644
index 0000000000..05dec370b9
--- /dev/null
+++ b/src/cpu/kernels/directconv2d_output_stage/generic/neon/fp32.cpp
@@ -0,0 +1,59 @@
+/*
+ * Copyright (c) 2024 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "src/cpu/kernels/directconv2d_output_stage/generic/neon/float_impl.h"
+
+namespace arm_compute
+{
+namespace cpu
+{
+namespace kernels
+{
+void output_stage_nhwc_fp32(ITensor *src,
+ const ITensor *bias,
+ const Window &window,
+ ITensor *dst,
+ int result_fixedpoint_multiplier,
+ int result_shift,
+ int result_offset_after_shift)
+{
+ output_stage_nhwc_fp<float>(src, bias, window, dst, result_fixedpoint_multiplier, result_shift,
+ result_offset_after_shift);
+}
+
+void output_stage_nchw_fp32(ITensor *src,
+ const ITensor *bias,
+ const Window &window,
+ ITensor *dst,
+ int result_fixedpoint_multiplier,
+ int result_shift,
+ int result_offset_after_shift)
+{
+ output_stage_nchw_fp<float>(src, bias, window, dst, result_fixedpoint_multiplier, result_shift,
+ result_offset_after_shift);
+}
+
+} // namespace kernels
+} // namespace cpu
+} // namespace arm_compute
diff --git a/src/cpu/kernels/directconv2d_output_stage/generic/neon/qasymm8.cpp b/src/cpu/kernels/directconv2d_output_stage/generic/neon/qasymm8.cpp
new file mode 100644
index 0000000000..dbdf951b6f
--- /dev/null
+++ b/src/cpu/kernels/directconv2d_output_stage/generic/neon/qasymm8.cpp
@@ -0,0 +1,59 @@
+/*
+ * Copyright (c) 2024 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "src/cpu/kernels/directconv2d_output_stage/generic/neon/quantized_impl.h"
+
+namespace arm_compute
+{
+namespace cpu
+{
+namespace kernels
+{
+void output_stage_nhwc_qu8(ITensor *src,
+ const ITensor *bias,
+ const Window &window,
+ ITensor *dst,
+ int result_fixedpoint_multiplier,
+ int result_shift,
+ int result_offset_after_shift)
+{
+ output_stage_nhwc_quant<uint8_t>(src, bias, window, dst, result_fixedpoint_multiplier, result_shift,
+ result_offset_after_shift);
+}
+
+void output_stage_nchw_qu8(ITensor *src,
+ const ITensor *bias,
+ const Window &window,
+ ITensor *dst,
+ int result_fixedpoint_multiplier,
+ int result_shift,
+ int result_offset_after_shift)
+{
+ output_stage_nchw_quant<uint8_t>(src, bias, window, dst, result_fixedpoint_multiplier, result_shift,
+ result_offset_after_shift);
+}
+
+} // namespace kernels
+} // namespace cpu
+} // namespace arm_compute
diff --git a/src/cpu/kernels/directconv2d_output_stage/generic/neon/qasymm8_signed.cpp b/src/cpu/kernels/directconv2d_output_stage/generic/neon/qasymm8_signed.cpp
new file mode 100644
index 0000000000..c00fe87161
--- /dev/null
+++ b/src/cpu/kernels/directconv2d_output_stage/generic/neon/qasymm8_signed.cpp
@@ -0,0 +1,59 @@
+/*
+ * Copyright (c) 2024 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "src/cpu/kernels/directconv2d_output_stage/generic/neon/quantized_impl.h"
+
+namespace arm_compute
+{
+namespace cpu
+{
+namespace kernels
+{
+void output_stage_nhwc_qs8(ITensor *src,
+ const ITensor *bias,
+ const Window &window,
+ ITensor *dst,
+ int result_fixedpoint_multiplier,
+ int result_shift,
+ int result_offset_after_shift)
+{
+ output_stage_nhwc_quant<int8_t>(src, bias, window, dst, result_fixedpoint_multiplier, result_shift,
+ result_offset_after_shift);
+}
+
+void output_stage_nchw_qs8(ITensor *src,
+ const ITensor *bias,
+ const Window &window,
+ ITensor *dst,
+ int result_fixedpoint_multiplier,
+ int result_shift,
+ int result_offset_after_shift)
+{
+ output_stage_nchw_quant<int8_t>(src, bias, window, dst, result_fixedpoint_multiplier, result_shift,
+ result_offset_after_shift);
+}
+
+} // namespace kernels
+} // namespace cpu
+} // namespace arm_compute
diff --git a/src/cpu/kernels/directconv2d_output_stage/generic/neon/quantized_impl.h b/src/cpu/kernels/directconv2d_output_stage/generic/neon/quantized_impl.h
new file mode 100644
index 0000000000..f74ed55ad0
--- /dev/null
+++ b/src/cpu/kernels/directconv2d_output_stage/generic/neon/quantized_impl.h
@@ -0,0 +1,213 @@
+/*
+ * Copyright (c) 2017-2021, 2024 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#ifndef ACL_SRC_CPU_KERNELS_DIRECTCONV2D_OUTPUT_STAGE_GENERIC_NEON_QUANTIZED_IMPL_H
+#define ACL_SRC_CPU_KERNELS_DIRECTCONV2D_OUTPUT_STAGE_GENERIC_NEON_QUANTIZED_IMPL_H
+
+#include "arm_compute/core/Helpers.h" // Iterator
+#include "arm_compute/core/ITensor.h"
+#include "arm_compute/core/Window.h"
+
+#include "src/core/NEON/NEAsymm.h"
+#include "src/core/NEON/wrapper/wrapper.h"
+
+#include <cstdint>
+#include <limits>
+#include <type_traits>
+
+namespace arm_compute
+{
+namespace cpu
+{
+namespace kernels
+{
+
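+// The two templates below implement the quantized output stage of direct
+// convolution: the int32 accumulators are (optionally) biased per output
+// channel and then requantized by finalize_quantization() using the given
+// fixed-point multiplier, shift and output offset, clamping to the TOut
+// (uint8_t / int8_t) range before the narrowing store.
+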
+template <typename TOut>
+void output_stage_nchw_quant(ITensor *src,
+ const ITensor *bias,
+ const Window &window,
+ ITensor *dst,
+ int result_fixedpoint_multiplier,
+ int result_shift,
+ int result_offset_after_shift)
+{
+ const bool has_bias = bias != nullptr;
+ using VectorType = typename wrapper::traits::neon_bitvector_t<TOut, wrapper::traits::BitWidth::W128>;
+ using TagType = typename wrapper::traits::neon_bitvector_tag_t<TOut, wrapper::traits::BitWidth::W128>;
+
+ const int32x4_t result_offset_after_shift_s32 = vdupq_n_s32(result_offset_after_shift);
+
+ const VectorType min = wrapper::vdup_n(std::numeric_limits<TOut>::lowest(), TagType{});
+ const VectorType max = wrapper::vdup_n(std::numeric_limits<TOut>::max(), TagType{});
+
+ const int window_start_x = window.x().start();
+ const int window_end_x = window.x().end();
+ const int window_step_x = 16 / src->info()->element_size();
+ Window win = window;
+ win.set(Window::DimX, Window::Dimension(0, 1, 1));
+
+ Iterator in(src, win);
+ Iterator out(dst, win);
+
+ execute_window_loop(
+ win,
+ [&](const Coordinates &id)
+ {
+ int x = window_start_x;
+ for (; x <= (window_end_x - window_step_x); x += window_step_x)
+ {
+ // Get bias and pointer to input
+ const auto in_ptr = reinterpret_cast<int32_t *>(in.ptr()) + x;
+ int32x4x4_t v_in = {{wrapper::vloadq(in_ptr), wrapper::vloadq(in_ptr + 4), wrapper::vloadq(in_ptr + 8),
+ wrapper::vloadq(in_ptr + 12)}};
+
+ // Accumulate bias
+ if (has_bias)
+ {
+ const auto vb = wrapper::vdup_n(
+ *reinterpret_cast<const int32_t *>(bias->ptr_to_element(Coordinates(id.z()))), TagType{});
+ v_in = {{wrapper::vadd(v_in.val[0], vb), wrapper::vadd(v_in.val[1], vb),
+ wrapper::vadd(v_in.val[2], vb), wrapper::vadd(v_in.val[3], vb)}};
+ }
+
+ const auto out_ptr = reinterpret_cast<TOut *>(out.ptr()) + x;
+ wrapper::vstore(out_ptr, finalize_quantization(v_in, result_fixedpoint_multiplier, result_shift,
+ result_offset_after_shift_s32, min, max, false));
+ }
+
+ // Left-overs loop
+ for (; x < window_end_x; ++x)
+ {
+ // Get bias and pointer to input
+ int32_t s_in = *(reinterpret_cast<const int32_t *>(in.ptr()) + x);
+
+ // Accumulate bias
+ if (has_bias)
+ {
+ const auto b = *reinterpret_cast<const int32_t *>(bias->ptr_to_element(Coordinates(id.z())));
+ s_in += b;
+ }
+
+ const auto out_ptr = reinterpret_cast<TOut *>(out.ptr()) + x;
+ *out_ptr =
+ finalize_quantization(s_in, result_fixedpoint_multiplier, result_shift, result_offset_after_shift,
+ std::numeric_limits<TOut>::lowest(), std::numeric_limits<TOut>::max(), false);
+ }
+ },
+ in, out);
+}
+template <
+ typename TOut,
+ typename std::enable_if<std::is_same<TOut, uint8_t>::value || std::is_same<TOut, int8_t>::value, int>::type = 0>
+void output_stage_nhwc_quant(ITensor *src,
+ const ITensor *bias,
+ const Window &window,
+ ITensor *dst,
+ int result_fixedpoint_multiplier,
+ int result_shift,
+ int result_offset_after_shift)
+{
+ const bool has_bias = bias != nullptr;
+ using VectorType = typename wrapper::traits::neon_bitvector_t<TOut, wrapper::traits::BitWidth::W128>;
+ using TagType = typename wrapper::traits::neon_bitvector_tag_t<TOut, wrapper::traits::BitWidth::W128>;
+
+ const int32x4_t result_offset_after_shift_s32 = vdupq_n_s32(result_offset_after_shift);
+
+ const VectorType min = wrapper::vdup_n(std::numeric_limits<TOut>::lowest(), TagType{});
+ const VectorType max = wrapper::vdup_n(std::numeric_limits<TOut>::max(), TagType{});
+
+ Window window_bias = window;
+ window_bias.set(Window::DimX, Window::Dimension(0, 1, 1));
+ window_bias.set(Window::DimY, Window::Dimension(0, 0, 0));
+ window_bias.set(Window::DimZ, Window::Dimension(0, 0, 0));
+ window_bias.set(3, Window::Dimension(0, 0, 0));
+
+ const int window_start_x = window.x().start();
+ const int window_end_x = window.x().end();
+ const int window_step_x = 16 / src->info()->element_size();
+ Window win = window;
+ win.set(Window::DimX, Window::Dimension(0, 1, 1));
+
+ Iterator in(src, win);
+ Iterator bi(bias, window_bias);
+ Iterator out(dst, win);
+
+ execute_window_loop(
+ win,
+ [&](const Coordinates &)
+ {
+ int x = window_start_x;
+ for (; x <= (window_end_x - window_step_x); x += window_step_x)
+ {
+ // Get bias and pointer to input
+ const auto in_ptr = reinterpret_cast<int32_t *>(in.ptr()) + x;
+ int32x4x4_t v_in = {{
+ wrapper::vloadq(in_ptr),
+ wrapper::vloadq(in_ptr + 4),
+ wrapper::vloadq(in_ptr + 8),
+ wrapper::vloadq(in_ptr + 12),
+ }};
+
+ // Accumulate bias
+ if (has_bias)
+ {
+ const auto bias_ptr = reinterpret_cast<int32_t *>(bi.ptr()) + x;
+
+                    v_in.val[0] = wrapper::vadd(v_in.val[0], wrapper::vloadq(bias_ptr));
+                    v_in.val[1] = wrapper::vadd(v_in.val[1], wrapper::vloadq(bias_ptr + 4));
+                    v_in.val[2] = wrapper::vadd(v_in.val[2], wrapper::vloadq(bias_ptr + 8));
+                    v_in.val[3] = wrapper::vadd(v_in.val[3], wrapper::vloadq(bias_ptr + 12));
+ }
+
+ const auto out_ptr = reinterpret_cast<TOut *>(out.ptr()) + x;
+ wrapper::vstore(out_ptr, finalize_quantization(v_in, result_fixedpoint_multiplier, result_shift,
+ result_offset_after_shift_s32, min, max, false));
+ }
+
+ // Left-overs loop
+ for (; x < window_end_x; ++x)
+ {
+ // Get bias and pointer to input
+ const auto in_ptr = reinterpret_cast<int32_t *>(in.ptr()) + x;
+ int32_t s_in = *in_ptr;
+
+ // Accumulate bias
+ if (has_bias)
+ {
+ const auto bias_ptr = reinterpret_cast<int32_t *>(bi.ptr()) + x;
+ s_in += *bias_ptr;
+ }
+
+ const auto out_ptr = reinterpret_cast<TOut *>(out.ptr()) + x;
+ *out_ptr =
+ finalize_quantization(s_in, result_fixedpoint_multiplier, result_shift, result_offset_after_shift,
+ std::numeric_limits<TOut>::lowest(), std::numeric_limits<TOut>::max(), false);
+ }
+ },
+ in, bi, out);
+}
+
+} // namespace kernels
+} // namespace cpu
+} // namespace arm_compute
+
+#endif // ACL_SRC_CPU_KERNELS_DIRECTCONV2D_OUTPUT_STAGE_GENERIC_NEON_QUANTIZED_IMPL_H
diff --git a/src/cpu/kernels/directconv2d_output_stage/list.h b/src/cpu/kernels/directconv2d_output_stage/list.h
new file mode 100644
index 0000000000..9372269bca
--- /dev/null
+++ b/src/cpu/kernels/directconv2d_output_stage/list.h
@@ -0,0 +1,57 @@
+/*
+ * Copyright (c) 2024 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#ifndef ACL_SRC_CPU_KERNELS_DIRECTCONV2D_OUTPUT_STAGE_LIST_H
+#define ACL_SRC_CPU_KERNELS_DIRECTCONV2D_OUTPUT_STAGE_LIST_H
+
+#include "arm_compute/core/ITensor.h"
+#include "arm_compute/core/Window.h"
+
+namespace arm_compute
+{
+namespace cpu
+{
+namespace kernels
+{
+
+#define DECLARE_DIRECTCONV2D_OUTPUT_STAGE_KERNEL(func_name) \
+ void func_name(ITensor *src, const ITensor *bias, const Window &window, ITensor *dst, \
+ int result_fixedpoint_multiplier, int result_shift, int result_offset_after_shift)
+
+DECLARE_DIRECTCONV2D_OUTPUT_STAGE_KERNEL(output_stage_nhwc_fp32);
+DECLARE_DIRECTCONV2D_OUTPUT_STAGE_KERNEL(output_stage_nhwc_fp16);
+DECLARE_DIRECTCONV2D_OUTPUT_STAGE_KERNEL(output_stage_nchw_fp32);
+DECLARE_DIRECTCONV2D_OUTPUT_STAGE_KERNEL(output_stage_nchw_fp16);
+DECLARE_DIRECTCONV2D_OUTPUT_STAGE_KERNEL(output_stage_nhwc_qs8);
+DECLARE_DIRECTCONV2D_OUTPUT_STAGE_KERNEL(output_stage_nhwc_qu8);
+DECLARE_DIRECTCONV2D_OUTPUT_STAGE_KERNEL(output_stage_nchw_qs8);
+DECLARE_DIRECTCONV2D_OUTPUT_STAGE_KERNEL(output_stage_nchw_qu8);
+
+#undef DECLARE_DIRECTCONV2D_OUTPUT_STAGE_KERNEL
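+
+// The declarations above are implemented in the per-data-type translation units
+// of this directory (fp32.cpp, fp16.cpp, qasymm8.cpp and qasymm8_signed.cpp), so
+// that each variant can be guarded and compiled independently (e.g. the fp16
+// kernels are only built when ENABLE_FP16_KERNELS is defined).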
+
+} // namespace kernels
+} // namespace cpu
+} // namespace arm_compute
+
+#endif // ACL_SRC_CPU_KERNELS_DIRECTCONV2D_OUTPUT_STAGE_LIST_H
diff --git a/src/cpu/kernels/gemm_matrix_mul/generic/neon/impl.cpp b/src/cpu/kernels/gemm_matrix_mul/generic/neon/impl.cpp
index 404d070a37..580fdc3e8f 100644
--- a/src/cpu/kernels/gemm_matrix_mul/generic/neon/impl.cpp
+++ b/src/cpu/kernels/gemm_matrix_mul/generic/neon/impl.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2023 Arm Limited.
+ * Copyright (c) 2017-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -81,7 +81,7 @@ void vector_matrix_multiply_f32(
// window_end_x is computed above which may cause out-of-bound writes to the dst.
for (; x < (window_end_x - window_step_x); x += window_step_x)
{
- if (x > width_matrix_b)
+ if (x >= width_matrix_b)
{
return;
}
@@ -203,7 +203,7 @@ void vector_matrix_multiply_f32(
// Left-over loop
for (; x < window_end_x; ++x)
{
- if (x > width_matrix_b)
+ if (x >= width_matrix_b)
{
return;
}
@@ -309,9 +309,21 @@ void matrix_matrix_multiply_f32(
Iterator inb(rhs, win_b);
Iterator out(dst, window);
- const bool multiply_alpha = !(helpers::float_ops::is_one(alpha));
+ // End address of matrix B at batch number n
+ const float *end_addr_mtx_b_at_batch_n =
+ reinterpret_cast<const float *>(inb.ptr()) + rhs->info()->dimension(0) * rhs->info()->dimension(1);
+ std::vector<const float *> end_addr_mtx_b_per_batch = {};
+ const bool multiply_alpha = !(helpers::float_ops::is_one(alpha));
+ const float32x4_t alpha_f32 = vdupq_n_f32(alpha);
+    const size_t      out_dim2       = static_cast<size_t>(dst->info()->dimension(2));
- const float32x4_t alpha_f32 = vdupq_n_f32(alpha);
+ for (size_t b = 0; b < out_dim2; ++b)
+ {
+        // Store, for each batch, the address one past the last element of matrix B
+ end_addr_mtx_b_per_batch.push_back(end_addr_mtx_b_at_batch_n);
+ end_addr_mtx_b_at_batch_n +=
+ rhs->info()->dimension(2) != 1 ? rhs->info()->dimension(0) * rhs->info()->dimension(1) : 0;
+ }
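+
+    // Rationale: when the reshaped matrix B has an odd number of rows in a batch,
+    // the second row pointer (mtx_b1) of the last row pair points past the end of
+    // that batch. The per-batch end addresses collected above let the inner loops
+    // below detect this case and fall back to accumulating a single B row rather
+    // than reading out of bounds.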
// The implementation assumes that the matrix A and Matrix B have been reshaped respectively with CpuGemmInterleave4x4 and CpuGemmTranspose1xW
// The reshaping of the matrices helps to have a cache friendly implementation and helps to avoid the data re-arrangements needed for computing 16x4 elements per iteration
@@ -341,220 +353,374 @@ void matrix_matrix_multiply_f32(
#endif /* __arm__ */
auto mtx_b0_end_addr = mtx_b0 + num_elems_matrix_b_x;
- for (; mtx_b0 <= (mtx_b0_end_addr - 32);)
+
+ ARM_COMPUTE_ERROR_ON(end_addr_mtx_b_per_batch.size() == 0);
+ if (mtx_b1 < end_addr_mtx_b_per_batch[id.z()])
{
- float32x4_t a0 = vld1q_dup_f32(mtx_a0 + 0);
- float32x4_t a1 = vld1q_dup_f32(mtx_a0 + 1);
- float32x4_t a2 = vld1q_dup_f32(mtx_a0 + 2);
- float32x4_t a3 = vld1q_dup_f32(mtx_a0 + 3);
+ for (; mtx_b0 < (mtx_b0_end_addr - 32);)
+ {
+ float32x4_t a0 = vld1q_dup_f32(mtx_a0 + 0);
+ float32x4_t a1 = vld1q_dup_f32(mtx_a0 + 1);
+ float32x4_t a2 = vld1q_dup_f32(mtx_a0 + 2);
+ float32x4_t a3 = vld1q_dup_f32(mtx_a0 + 3);
- float32x4_t b00 = vld1q_f32(mtx_b0);
- float32x4_t b10 = vld1q_f32(mtx_b1);
- float32x4_t b01 = vld1q_f32(mtx_b0 + 4);
- float32x4_t b11 = vld1q_f32(mtx_b1 + 4);
+ float32x4_t b00 = vld1q_f32(mtx_b0);
+ float32x4_t b10 = vld1q_f32(mtx_b1);
+ float32x4_t b01 = vld1q_f32(mtx_b0 + 4);
+ float32x4_t b11 = vld1q_f32(mtx_b1 + 4);
#if __arm__
- asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_a0)));
- asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_b0)));
- asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_b1)));
+ asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_a0)));
+ asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_b0)));
+ asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_b1)));
#endif /* __arm__ */
- // 4x4 block 0
- acc00 = vmlaq_f32(acc00, b00, a0);
- acc10 = vmlaq_f32(acc10, b00, a1);
- acc20 = vmlaq_f32(acc20, b00, a2);
- acc30 = vmlaq_f32(acc30, b00, a3);
-
- float32x4_t a4 = vld1q_dup_f32(mtx_a0 + 4);
- float32x4_t a5 = vld1q_dup_f32(mtx_a0 + 5);
- float32x4_t a6 = vld1q_dup_f32(mtx_a0 + 6);
- float32x4_t a7 = vld1q_dup_f32(mtx_a0 + 7);
-
- // 4x4 block 1
- acc01 = vmlaq_f32(acc01, b10, a0);
- acc11 = vmlaq_f32(acc11, b10, a1);
- acc21 = vmlaq_f32(acc21, b10, a2);
- acc31 = vmlaq_f32(acc31, b10, a3);
-
- // 4x4 block 0
- acc00 = vmlaq_f32(acc00, b01, a4);
- acc10 = vmlaq_f32(acc10, b01, a5);
- acc20 = vmlaq_f32(acc20, b01, a6);
- acc30 = vmlaq_f32(acc30, b01, a7);
-
- // 4x4 block 1
- acc01 = vmlaq_f32(acc01, b11, a4);
- acc11 = vmlaq_f32(acc11, b11, a5);
- acc21 = vmlaq_f32(acc21, b11, a6);
- acc31 = vmlaq_f32(acc31, b11, a7);
-
- mtx_a0 += 8;
- mtx_b0 += 8;
- mtx_b1 += 8;
-
- a0 = vld1q_dup_f32(mtx_a0 + 0);
- a1 = vld1q_dup_f32(mtx_a0 + 1);
- a2 = vld1q_dup_f32(mtx_a0 + 2);
- a3 = vld1q_dup_f32(mtx_a0 + 3);
-
- b00 = vld1q_f32(mtx_b0);
- b10 = vld1q_f32(mtx_b1);
- b01 = vld1q_f32(mtx_b0 + 4);
- b11 = vld1q_f32(mtx_b1 + 4);
-
- // 4x4 block 0
- acc00 = vmlaq_f32(acc00, b00, a0);
- acc10 = vmlaq_f32(acc10, b00, a1);
- acc20 = vmlaq_f32(acc20, b00, a2);
- acc30 = vmlaq_f32(acc30, b00, a3);
-
- a4 = vld1q_dup_f32(mtx_a0 + 4);
- a5 = vld1q_dup_f32(mtx_a0 + 5);
- a6 = vld1q_dup_f32(mtx_a0 + 6);
- a7 = vld1q_dup_f32(mtx_a0 + 7);
-
- // 4x4 block 1
- acc01 = vmlaq_f32(acc01, b10, a0);
- acc11 = vmlaq_f32(acc11, b10, a1);
- acc21 = vmlaq_f32(acc21, b10, a2);
- acc31 = vmlaq_f32(acc31, b10, a3);
-
- // 4x4 block 0
- acc00 = vmlaq_f32(acc00, b01, a4);
- acc10 = vmlaq_f32(acc10, b01, a5);
- acc20 = vmlaq_f32(acc20, b01, a6);
- acc30 = vmlaq_f32(acc30, b01, a7);
-
- // 4x4 block 1
- acc01 = vmlaq_f32(acc01, b11, a4);
- acc11 = vmlaq_f32(acc11, b11, a5);
- acc21 = vmlaq_f32(acc21, b11, a6);
- acc31 = vmlaq_f32(acc31, b11, a7);
-
- mtx_a0 += 8;
- mtx_b0 += 8;
- mtx_b1 += 8;
-
- a0 = vld1q_dup_f32(mtx_a0 + 0);
- a1 = vld1q_dup_f32(mtx_a0 + 1);
- a2 = vld1q_dup_f32(mtx_a0 + 2);
- a3 = vld1q_dup_f32(mtx_a0 + 3);
- b00 = vld1q_f32(mtx_b0);
- b10 = vld1q_f32(mtx_b1);
- b01 = vld1q_f32(mtx_b0 + 4);
- b11 = vld1q_f32(mtx_b1 + 4);
+ // 4x4 block 0
+ acc00 = vmlaq_f32(acc00, b00, a0);
+ acc10 = vmlaq_f32(acc10, b00, a1);
+ acc20 = vmlaq_f32(acc20, b00, a2);
+ acc30 = vmlaq_f32(acc30, b00, a3);
+
+ float32x4_t a4 = vld1q_dup_f32(mtx_a0 + 4);
+ float32x4_t a5 = vld1q_dup_f32(mtx_a0 + 5);
+ float32x4_t a6 = vld1q_dup_f32(mtx_a0 + 6);
+ float32x4_t a7 = vld1q_dup_f32(mtx_a0 + 7);
+
+ // 4x4 block 1
+ acc01 = vmlaq_f32(acc01, b10, a0);
+ acc11 = vmlaq_f32(acc11, b10, a1);
+ acc21 = vmlaq_f32(acc21, b10, a2);
+ acc31 = vmlaq_f32(acc31, b10, a3);
+
+ // 4x4 block 0
+ acc00 = vmlaq_f32(acc00, b01, a4);
+ acc10 = vmlaq_f32(acc10, b01, a5);
+ acc20 = vmlaq_f32(acc20, b01, a6);
+ acc30 = vmlaq_f32(acc30, b01, a7);
+
+ // 4x4 block 1
+ acc01 = vmlaq_f32(acc01, b11, a4);
+ acc11 = vmlaq_f32(acc11, b11, a5);
+ acc21 = vmlaq_f32(acc21, b11, a6);
+ acc31 = vmlaq_f32(acc31, b11, a7);
+
+ mtx_a0 += 8;
+ mtx_b0 += 8;
+ mtx_b1 += 8;
+
+ a0 = vld1q_dup_f32(mtx_a0 + 0);
+ a1 = vld1q_dup_f32(mtx_a0 + 1);
+ a2 = vld1q_dup_f32(mtx_a0 + 2);
+ a3 = vld1q_dup_f32(mtx_a0 + 3);
+
+ b00 = vld1q_f32(mtx_b0);
+ b10 = vld1q_f32(mtx_b1);
+ b01 = vld1q_f32(mtx_b0 + 4);
+ b11 = vld1q_f32(mtx_b1 + 4);
+
+ // 4x4 block 0
+ acc00 = vmlaq_f32(acc00, b00, a0);
+ acc10 = vmlaq_f32(acc10, b00, a1);
+ acc20 = vmlaq_f32(acc20, b00, a2);
+ acc30 = vmlaq_f32(acc30, b00, a3);
+
+ a4 = vld1q_dup_f32(mtx_a0 + 4);
+ a5 = vld1q_dup_f32(mtx_a0 + 5);
+ a6 = vld1q_dup_f32(mtx_a0 + 6);
+ a7 = vld1q_dup_f32(mtx_a0 + 7);
+
+ // 4x4 block 1
+ acc01 = vmlaq_f32(acc01, b10, a0);
+ acc11 = vmlaq_f32(acc11, b10, a1);
+ acc21 = vmlaq_f32(acc21, b10, a2);
+ acc31 = vmlaq_f32(acc31, b10, a3);
+
+ // 4x4 block 0
+ acc00 = vmlaq_f32(acc00, b01, a4);
+ acc10 = vmlaq_f32(acc10, b01, a5);
+ acc20 = vmlaq_f32(acc20, b01, a6);
+ acc30 = vmlaq_f32(acc30, b01, a7);
+
+ // 4x4 block 1
+ acc01 = vmlaq_f32(acc01, b11, a4);
+ acc11 = vmlaq_f32(acc11, b11, a5);
+ acc21 = vmlaq_f32(acc21, b11, a6);
+ acc31 = vmlaq_f32(acc31, b11, a7);
+
+ mtx_a0 += 8;
+ mtx_b0 += 8;
+ mtx_b1 += 8;
+
+ a0 = vld1q_dup_f32(mtx_a0 + 0);
+ a1 = vld1q_dup_f32(mtx_a0 + 1);
+ a2 = vld1q_dup_f32(mtx_a0 + 2);
+ a3 = vld1q_dup_f32(mtx_a0 + 3);
+ b00 = vld1q_f32(mtx_b0);
+ b10 = vld1q_f32(mtx_b1);
+ b01 = vld1q_f32(mtx_b0 + 4);
+ b11 = vld1q_f32(mtx_b1 + 4);
#if __arm__
- asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_a0)));
- asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_b0)));
- asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_b1)));
+ asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_a0)));
+ asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_b0)));
+ asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_b1)));
#endif /* __arm__ */
- // 4x4 block 0
- acc00 = vmlaq_f32(acc00, b00, a0);
- acc10 = vmlaq_f32(acc10, b00, a1);
- acc20 = vmlaq_f32(acc20, b00, a2);
- acc30 = vmlaq_f32(acc30, b00, a3);
-
- a4 = vld1q_dup_f32(mtx_a0 + 4);
- a5 = vld1q_dup_f32(mtx_a0 + 5);
- a6 = vld1q_dup_f32(mtx_a0 + 6);
- a7 = vld1q_dup_f32(mtx_a0 + 7);
-
- // 4x4 block 1
- acc01 = vmlaq_f32(acc01, b10, a0);
- acc11 = vmlaq_f32(acc11, b10, a1);
- acc21 = vmlaq_f32(acc21, b10, a2);
- acc31 = vmlaq_f32(acc31, b10, a3);
-
- // 4x4 block 0
- acc00 = vmlaq_f32(acc00, b01, a4);
- acc10 = vmlaq_f32(acc10, b01, a5);
- acc20 = vmlaq_f32(acc20, b01, a6);
- acc30 = vmlaq_f32(acc30, b01, a7);
-
- // 4x4 block 1
- acc01 = vmlaq_f32(acc01, b11, a4);
- acc11 = vmlaq_f32(acc11, b11, a5);
- acc21 = vmlaq_f32(acc21, b11, a6);
- acc31 = vmlaq_f32(acc31, b11, a7);
-
- mtx_a0 += 8;
- mtx_b0 += 8;
- mtx_b1 += 8;
-
- a0 = vld1q_dup_f32(mtx_a0 + 0);
- a1 = vld1q_dup_f32(mtx_a0 + 1);
- a2 = vld1q_dup_f32(mtx_a0 + 2);
- a3 = vld1q_dup_f32(mtx_a0 + 3);
- b00 = vld1q_f32(mtx_b0);
- b10 = vld1q_f32(mtx_b1);
- b01 = vld1q_f32(mtx_b0 + 4);
- b11 = vld1q_f32(mtx_b1 + 4);
-
- // 4x4 block 0
- acc00 = vmlaq_f32(acc00, b00, a0);
- acc10 = vmlaq_f32(acc10, b00, a1);
- acc20 = vmlaq_f32(acc20, b00, a2);
- acc30 = vmlaq_f32(acc30, b00, a3);
-
- a4 = vld1q_dup_f32(mtx_a0 + 4);
- a5 = vld1q_dup_f32(mtx_a0 + 5);
- a6 = vld1q_dup_f32(mtx_a0 + 6);
- a7 = vld1q_dup_f32(mtx_a0 + 7);
-
- // 4x4 block 1
- acc01 = vmlaq_f32(acc01, b10, a0);
- acc11 = vmlaq_f32(acc11, b10, a1);
- acc21 = vmlaq_f32(acc21, b10, a2);
- acc31 = vmlaq_f32(acc31, b10, a3);
-
- // 4x4 block 0
- acc00 = vmlaq_f32(acc00, b01, a4);
- acc10 = vmlaq_f32(acc10, b01, a5);
- acc20 = vmlaq_f32(acc20, b01, a6);
- acc30 = vmlaq_f32(acc30, b01, a7);
-
- // 4x4 block 1
- acc01 = vmlaq_f32(acc01, b11, a4);
- acc11 = vmlaq_f32(acc11, b11, a5);
- acc21 = vmlaq_f32(acc21, b11, a6);
- acc31 = vmlaq_f32(acc31, b11, a7);
-
- mtx_a0 += 8;
- mtx_b0 += 8;
- mtx_b1 += 8;
+ // 4x4 block 0
+ acc00 = vmlaq_f32(acc00, b00, a0);
+ acc10 = vmlaq_f32(acc10, b00, a1);
+ acc20 = vmlaq_f32(acc20, b00, a2);
+ acc30 = vmlaq_f32(acc30, b00, a3);
+
+ a4 = vld1q_dup_f32(mtx_a0 + 4);
+ a5 = vld1q_dup_f32(mtx_a0 + 5);
+ a6 = vld1q_dup_f32(mtx_a0 + 6);
+ a7 = vld1q_dup_f32(mtx_a0 + 7);
+
+ // 4x4 block 1
+ acc01 = vmlaq_f32(acc01, b10, a0);
+ acc11 = vmlaq_f32(acc11, b10, a1);
+ acc21 = vmlaq_f32(acc21, b10, a2);
+ acc31 = vmlaq_f32(acc31, b10, a3);
+
+ // 4x4 block 0
+ acc00 = vmlaq_f32(acc00, b01, a4);
+ acc10 = vmlaq_f32(acc10, b01, a5);
+ acc20 = vmlaq_f32(acc20, b01, a6);
+ acc30 = vmlaq_f32(acc30, b01, a7);
+
+ // 4x4 block 1
+ acc01 = vmlaq_f32(acc01, b11, a4);
+ acc11 = vmlaq_f32(acc11, b11, a5);
+ acc21 = vmlaq_f32(acc21, b11, a6);
+ acc31 = vmlaq_f32(acc31, b11, a7);
+
+ mtx_a0 += 8;
+ mtx_b0 += 8;
+ mtx_b1 += 8;
+
+ a0 = vld1q_dup_f32(mtx_a0 + 0);
+ a1 = vld1q_dup_f32(mtx_a0 + 1);
+ a2 = vld1q_dup_f32(mtx_a0 + 2);
+ a3 = vld1q_dup_f32(mtx_a0 + 3);
+ b00 = vld1q_f32(mtx_b0);
+ b10 = vld1q_f32(mtx_b1);
+ b01 = vld1q_f32(mtx_b0 + 4);
+ b11 = vld1q_f32(mtx_b1 + 4);
+
+ // 4x4 block 0
+ acc00 = vmlaq_f32(acc00, b00, a0);
+ acc10 = vmlaq_f32(acc10, b00, a1);
+ acc20 = vmlaq_f32(acc20, b00, a2);
+ acc30 = vmlaq_f32(acc30, b00, a3);
+
+ a4 = vld1q_dup_f32(mtx_a0 + 4);
+ a5 = vld1q_dup_f32(mtx_a0 + 5);
+ a6 = vld1q_dup_f32(mtx_a0 + 6);
+ a7 = vld1q_dup_f32(mtx_a0 + 7);
+
+ // 4x4 block 1
+ acc01 = vmlaq_f32(acc01, b10, a0);
+ acc11 = vmlaq_f32(acc11, b10, a1);
+ acc21 = vmlaq_f32(acc21, b10, a2);
+ acc31 = vmlaq_f32(acc31, b10, a3);
+
+ // 4x4 block 0
+ acc00 = vmlaq_f32(acc00, b01, a4);
+ acc10 = vmlaq_f32(acc10, b01, a5);
+ acc20 = vmlaq_f32(acc20, b01, a6);
+ acc30 = vmlaq_f32(acc30, b01, a7);
+
+ // 4x4 block 1
+ acc01 = vmlaq_f32(acc01, b11, a4);
+ acc11 = vmlaq_f32(acc11, b11, a5);
+ acc21 = vmlaq_f32(acc21, b11, a6);
+ acc31 = vmlaq_f32(acc31, b11, a7);
+
+ mtx_a0 += 8;
+ mtx_b0 += 8;
+ mtx_b1 += 8;
+ }
+
+                // Only consider one row from matrix B if the subsequent row is out of bounds.
+ for (; mtx_b0 < mtx_b0_end_addr;)
+ {
+ float32x4_t a0 = vld1q_dup_f32(mtx_a0 + 0);
+ float32x4_t a1 = vld1q_dup_f32(mtx_a0 + 1);
+ float32x4_t a2 = vld1q_dup_f32(mtx_a0 + 2);
+ float32x4_t a3 = vld1q_dup_f32(mtx_a0 + 3);
+ float32x4_t b00 = vld1q_f32(mtx_b0);
+ float32x4_t b10 = vld1q_f32(mtx_b1);
+
+#if __arm__
+ asm volatile("PLD [%0, #128*2]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_a0)));
+ asm volatile("PLD [%0, #128*2]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_b0)));
+ asm volatile("PLD [%0, #128*2]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_b1)));
+#endif /* __arm__ */
+ // 4x4 block 0
+ acc00 = vmlaq_f32(acc00, b00, a0);
+ acc10 = vmlaq_f32(acc10, b00, a1);
+ acc20 = vmlaq_f32(acc20, b00, a2);
+ acc30 = vmlaq_f32(acc30, b00, a3);
+
+ // 4x4 block 1
+ acc01 = vmlaq_f32(acc01, b10, a0);
+ acc11 = vmlaq_f32(acc11, b10, a1);
+ acc21 = vmlaq_f32(acc21, b10, a2);
+ acc31 = vmlaq_f32(acc31, b10, a3);
+
+ mtx_a0 += 4;
+ mtx_b0 += 4;
+ mtx_b1 += 4;
+ }
}
- for (; mtx_b0 < mtx_b0_end_addr;)
+            // Leftover last row of matrix B, in case matrix B has an odd number of rows
+ else if (mtx_b0 < end_addr_mtx_b_per_batch[id.z()])
{
- float32x4_t a0 = vld1q_dup_f32(mtx_a0 + 0);
- float32x4_t a1 = vld1q_dup_f32(mtx_a0 + 1);
- float32x4_t a2 = vld1q_dup_f32(mtx_a0 + 2);
- float32x4_t a3 = vld1q_dup_f32(mtx_a0 + 3);
- float32x4_t b00 = vld1q_f32(mtx_b0);
- float32x4_t b10 = vld1q_f32(mtx_b1);
+ for (; mtx_b0 < (mtx_b0_end_addr - 32);)
+ {
+ float32x4_t a0 = vld1q_dup_f32(mtx_a0 + 0);
+ float32x4_t a1 = vld1q_dup_f32(mtx_a0 + 1);
+ float32x4_t a2 = vld1q_dup_f32(mtx_a0 + 2);
+ float32x4_t a3 = vld1q_dup_f32(mtx_a0 + 3);
+
+ float32x4_t b00 = vld1q_f32(mtx_b0);
+ float32x4_t b01 = vld1q_f32(mtx_b0 + 4);
#if __arm__
- asm volatile("PLD [%0, #128*2]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_a0)));
- asm volatile("PLD [%0, #128*2]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_b0)));
- asm volatile("PLD [%0, #128*2]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_b1)));
+ asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_a0)));
+ asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_b0)));
#endif /* __arm__ */
- // 4x4 block 0
- acc00 = vmlaq_f32(acc00, b00, a0);
- acc10 = vmlaq_f32(acc10, b00, a1);
- acc20 = vmlaq_f32(acc20, b00, a2);
- acc30 = vmlaq_f32(acc30, b00, a3);
-
- // 4x4 block 1
- acc01 = vmlaq_f32(acc01, b10, a0);
- acc11 = vmlaq_f32(acc11, b10, a1);
- acc21 = vmlaq_f32(acc21, b10, a2);
- acc31 = vmlaq_f32(acc31, b10, a3);
-
- mtx_a0 += 4;
- mtx_b0 += 4;
- mtx_b1 += 4;
+
+ // 4x4 block 0
+ acc00 = vmlaq_f32(acc00, b00, a0);
+ acc10 = vmlaq_f32(acc10, b00, a1);
+ acc20 = vmlaq_f32(acc20, b00, a2);
+ acc30 = vmlaq_f32(acc30, b00, a3);
+
+ float32x4_t a4 = vld1q_dup_f32(mtx_a0 + 4);
+ float32x4_t a5 = vld1q_dup_f32(mtx_a0 + 5);
+ float32x4_t a6 = vld1q_dup_f32(mtx_a0 + 6);
+ float32x4_t a7 = vld1q_dup_f32(mtx_a0 + 7);
+
+ // 4x4 block 0
+ acc00 = vmlaq_f32(acc00, b01, a4);
+ acc10 = vmlaq_f32(acc10, b01, a5);
+ acc20 = vmlaq_f32(acc20, b01, a6);
+ acc30 = vmlaq_f32(acc30, b01, a7);
+
+ mtx_a0 += 8;
+ mtx_b0 += 8;
+
+ a0 = vld1q_dup_f32(mtx_a0 + 0);
+ a1 = vld1q_dup_f32(mtx_a0 + 1);
+ a2 = vld1q_dup_f32(mtx_a0 + 2);
+ a3 = vld1q_dup_f32(mtx_a0 + 3);
+
+ b00 = vld1q_f32(mtx_b0);
+ b01 = vld1q_f32(mtx_b0 + 4);
+
+ // 4x4 block 0
+ acc00 = vmlaq_f32(acc00, b00, a0);
+ acc10 = vmlaq_f32(acc10, b00, a1);
+ acc20 = vmlaq_f32(acc20, b00, a2);
+ acc30 = vmlaq_f32(acc30, b00, a3);
+
+ a4 = vld1q_dup_f32(mtx_a0 + 4);
+ a5 = vld1q_dup_f32(mtx_a0 + 5);
+ a6 = vld1q_dup_f32(mtx_a0 + 6);
+ a7 = vld1q_dup_f32(mtx_a0 + 7);
+
+ // 4x4 block 0
+ acc00 = vmlaq_f32(acc00, b01, a4);
+ acc10 = vmlaq_f32(acc10, b01, a5);
+ acc20 = vmlaq_f32(acc20, b01, a6);
+ acc30 = vmlaq_f32(acc30, b01, a7);
+
+ mtx_a0 += 8;
+ mtx_b0 += 8;
+
+ a0 = vld1q_dup_f32(mtx_a0 + 0);
+ a1 = vld1q_dup_f32(mtx_a0 + 1);
+ a2 = vld1q_dup_f32(mtx_a0 + 2);
+ a3 = vld1q_dup_f32(mtx_a0 + 3);
+ b00 = vld1q_f32(mtx_b0);
+ b01 = vld1q_f32(mtx_b0 + 4);
+
+#if __arm__
+ asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_a0)));
+ asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_b0)));
+#endif /* __arm__ */
+
+ // 4x4 block 0
+ acc00 = vmlaq_f32(acc00, b00, a0);
+ acc10 = vmlaq_f32(acc10, b00, a1);
+ acc20 = vmlaq_f32(acc20, b00, a2);
+ acc30 = vmlaq_f32(acc30, b00, a3);
+
+ a4 = vld1q_dup_f32(mtx_a0 + 4);
+ a5 = vld1q_dup_f32(mtx_a0 + 5);
+ a6 = vld1q_dup_f32(mtx_a0 + 6);
+ a7 = vld1q_dup_f32(mtx_a0 + 7);
+
+ // 4x4 block 0
+ acc00 = vmlaq_f32(acc00, b01, a4);
+ acc10 = vmlaq_f32(acc10, b01, a5);
+ acc20 = vmlaq_f32(acc20, b01, a6);
+ acc30 = vmlaq_f32(acc30, b01, a7);
+
+ mtx_a0 += 8;
+ mtx_b0 += 8;
+
+ a0 = vld1q_dup_f32(mtx_a0 + 0);
+ a1 = vld1q_dup_f32(mtx_a0 + 1);
+ a2 = vld1q_dup_f32(mtx_a0 + 2);
+ a3 = vld1q_dup_f32(mtx_a0 + 3);
+ b00 = vld1q_f32(mtx_b0);
+ b01 = vld1q_f32(mtx_b0 + 4);
+
+ // 4x4 block 0
+ acc00 = vmlaq_f32(acc00, b00, a0);
+ acc10 = vmlaq_f32(acc10, b00, a1);
+ acc20 = vmlaq_f32(acc20, b00, a2);
+ acc30 = vmlaq_f32(acc30, b00, a3);
+
+ a4 = vld1q_dup_f32(mtx_a0 + 4);
+ a5 = vld1q_dup_f32(mtx_a0 + 5);
+ a6 = vld1q_dup_f32(mtx_a0 + 6);
+ a7 = vld1q_dup_f32(mtx_a0 + 7);
+
+ // 4x4 block 0
+ acc00 = vmlaq_f32(acc00, b01, a4);
+ acc10 = vmlaq_f32(acc10, b01, a5);
+ acc20 = vmlaq_f32(acc20, b01, a6);
+ acc30 = vmlaq_f32(acc30, b01, a7);
+
+ mtx_a0 += 8;
+ mtx_b0 += 8;
+ }
+ for (; mtx_b0 < mtx_b0_end_addr;)
+ {
+ float32x4_t a0 = vld1q_dup_f32(mtx_a0 + 0);
+ float32x4_t a1 = vld1q_dup_f32(mtx_a0 + 1);
+ float32x4_t a2 = vld1q_dup_f32(mtx_a0 + 2);
+ float32x4_t a3 = vld1q_dup_f32(mtx_a0 + 3);
+ float32x4_t b00 = vld1q_f32(mtx_b0);
+
+#if __arm__
+ asm volatile("PLD [%0, #128*2]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_a0)));
+ asm volatile("PLD [%0, #128*2]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_b0)));
+#endif /* __arm__ */
+ // 4x4 block 0
+ acc00 = vmlaq_f32(acc00, b00, a0);
+ acc10 = vmlaq_f32(acc10, b00, a1);
+ acc20 = vmlaq_f32(acc20, b00, a2);
+ acc30 = vmlaq_f32(acc30, b00, a3);
+
+ mtx_a0 += 4;
+ mtx_b0 += 4;
+ }
}
// Multiply by the weight of matrix product (alpha)
diff --git a/src/cpu/kernels/logistic/generic/sme2/fp32.cpp b/src/cpu/kernels/logistic/generic/sme2/fp32.cpp
new file mode 100644
index 0000000000..876e466594
--- /dev/null
+++ b/src/cpu/kernels/logistic/generic/sme2/fp32.cpp
@@ -0,0 +1,429 @@
+/*
+ * Copyright (c) 2024 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#ifdef ARM_COMPUTE_ENABLE_SME2
+
+#include "arm_compute/core/Helpers.h"
+#include "arm_compute/core/ITensor.h"
+#include "arm_compute/core/Window.h"
+
+#include <cassert>
+#include <cstdint>
+
+namespace arm_compute
+{
+namespace cpu
+{
+
+// This function expects a collapsed 2D shape.
+void sme2_f32_logistic_kernel(const float *src,
+ float *dst,
+ const uintptr_t shape[2],
+ const uintptr_t src_strides[2],
+ const uintptr_t dst_strides[2])
+{
+ // Precondition:
+ assert(src_strides[0] == sizeof(float));
+ assert(dst_strides[0] == sizeof(float));
+ __asm__ volatile(
+ R"(
+ .inst 0xd503477f // smstart
+
+ ptrue p0.b
+ .inst 0x25207811 // ptrue pn9.b
+
+ // Registers
+ //
+ // * x9: temporary, index
+ // * x10: temporary, inf
+ // * x11: temporary, 0
+ // * x12: temporary, 1.0f
+ // * x13: temporary, body_length
+ //
+ // * x26: index_1
+ // * x27: src_1
+ // * x28: dst_1
+ //
+ // * z0: c1
+ // * z1: c2
+ // * z2: c3
+ // * z3: c4
+ // * z4: c5
+ // * z5: shift
+ // * z6: inv_ln2
+ // * z7: neg_ln2_hi
+ // * z8: neg_ln2_lo
+    // * z9: min_input
+ // * z10: 23, 0, 1, inf
+ // * z11: max_value
+ // * z12-z15: x, r_hi, r, r2
+    // * z16-z19: shift, z, scale, poly
+ // * z20-z21: n, p1, p12345
+ // * z22-z23: n, p23, p2345
+ // * z24-z25: p45
+ // * z26: max_input
+ // * z28-z31: sum_value
+ //
+ // * za0-za3: sum_value
+ //
+ // * p0: all-true
+    // * p1-p4: underflow
+    // * p1: leftover predicate (leftover loop only)
+    // * p5-p8: overflow
+ // * pn9: all-true
+
+    // TAYLOR SERIES CONSTANTS
+ mov w9, #0xfff6 // c1: 0x1.ffffecp-1f = 0x3f7ffff6
+ mov w10, #0xfedb // c2: 0x1.fffdb6p-2f = 0x3efffedb
+ mov w11, #0xaf33 // c3: 0x1.555e66p-3f = 0x3e2aaf33
+ mov w12, #0x9f17 // c4: 0x1.573e2ep-5f = 0x3d2b9f17
+ mov w13, #0x2010 // c5: 0x1.0e4020p-7f = 0x3c072010
+
+ movk w9, #0x3f7f, LSL #16 // c1: 0x1.ffffecp-1f = 0x3f7ffff6
+ movk w10, #0x3eff, LSL #16 // c2: 0x1.fffdb6p-2f = 0x3efffedb
+    movk w11, #0x3e2a, LSL #16 // c3: 0x1.555e66p-3f = 0x3e2aaf33
+ movk w12, #0x3d2b, LSL #16 // c4: 0x1.573e2ep-5f = 0x3d2b9f17
+ movk w13, #0x3c07, LSL #16 // c5: 0x1.0e4020p-7f = 0x3c072010
+
+ dup z0.s, w9 // c1.
+ dup z1.s, w10 // c2.
+ dup z2.s, w11 // c3.
+ dup z3.s, w12 // c4.
+ dup z4.s, w13 // c5.
+
+ mov w9, #0x007f // shift: 2^23 + 127 = 0x1.0000fep23f = 0x4b00007f
+ mov w10, #0xaa3b // inv_ln2: 1 / ln(2) = 0x1.715476p+0f = 0x3fb8aa3b
+ mov w11, #0x7200 // neg_ln2_hi: -ln(2) from bits -1 to -19 = -0x1.62e400p-1f = 0xbf317200
+ mov w12, #0xbe8e // neg_ln2_lo: -ln(2) from bits -20 to -42 = -0x1.7f7d1cp-20f = 0xb5bfbe8e
+ mov w13, #0x47ae // min_input (Approximately ln 2^-125): -86.64 = 0xc2ad47ae
+ mov w14, #0xBD71 // max_input 88.37 = 0x42B0BD71
+
+ movk w9, #0x4b00, LSL #16 // shift: 2^23 + 127 = 0x1.0000fep23f = 0x4b00007f
+ movk w10, #0x3fb8, LSL #16 // inv_ln2: 1 / ln(2) = 0x1.715476p+0f = 0x3fb8aa3b
+ movk w11, #0xbf31, LSL #16 // neg_ln2_hi: -ln(2) from bits -1 to -19 = -0x1.62e400p-1f = 0xbf317200
+ movk w12, #0xb5bf, LSL #16 // neg_ln2_lo: -ln(2) from bits -20 to -42 = -0x1.7f7d1cp-20f = 0xb5bfbe8e
+ movk w13, #0xc2ad, LSL #16 // min_input (Approximately ln 2^-125): -86.64 = 0xc2ad47ae
+ movk w14, #0x42B0, LSL #16 // max_input (88.37) = 0x42B0BD71
+
+ dup z5.s, w9 // shift
+ dup z6.s, w10 // inv_ln2
+ dup z7.s, w11 // neg_ln2_hi
+ dup z8.s, w12 // neg_ln2_lo
+ dup z9.s, w13 // min_input
+ dup z26.s, w14 // max_input
+
+ mov w10, #0x0000 // inf: 0x7F800000
+ movk w10, #0x7F80, LSL #16 // inf: 0x7F800000
+
+ mov w15, #0x0000
+ movk w15, #0x3F80, LSL #16 // 1
+
+ mov w11, #0 // 0
+
+ // ---------------------------------------------------------------- x13: body_length = (length / vl) * vl
+ cntw x13, ALL, MUL #4 // x13 is vl
+ udiv x9, %x[length], x13 // length/vl
+ mul x13, x13, x9 // x13 = vl * result
+
+ // ==================================================
+ // Outer loop opening
+ // ==================================================
+
+ mov x27, %x[src] // starting point of pointers for src.
+ mov x28, %x[dst] // starting point of pointers for dst.
+ mov x26, %x[shape_1]
+
+outer_loop_start%=:
+ // for index_1 in shape_1 downto 1
+ cmp x26, #0
+ b.eq outer_loop_end%=
+ sub x26, x26, #1
+
+ mov x9, #0 // x9: index
+
+inner_body_start%=:
+ cmp x9, x13
+ b.eq inner_body_end%=
+
+ // Loads the input data to 4 consecutive registers ---------------- z12-z15: input_data
+ .inst 0xa009c76c // ld1w {z12.s-z15.s}, pn9/z, [x27, x9, LSL #2]
+
+ // ---------------------------------------------------------------- z12-z15: x = neg(x)
+ fneg z12.s, p0/m, z12.s
+ fneg z13.s, p0/m, z13.s
+ fneg z14.s, p0/m, z14.s
+ fneg z15.s, p0/m, z15.s
+
+    // ---------------------------------------------------------------- p1-p4: underflow = x < min_input
+ fcmlt p1.s, p0/z, z12.s, z9.s
+ fcmlt p2.s, p0/z, z13.s, z9.s
+ fcmlt p3.s, p0/z, z14.s, z9.s
+ fcmlt p4.s, p0/z, z15.s, z9.s
+
+    // ---------------------------------------------------------------- p5-p8: overflow = x > max_input
+ fcmlt p5.s, p0/z, z26.s, z12.s
+ fcmlt p6.s, p0/z, z26.s, z13.s
+ fcmlt p7.s, p0/z, z26.s, z14.s
+ fcmlt p8.s, p0/z, z26.s, z15.s
+
+ // ---------------------------------------------------------------- z16-z19: shift
+ mov z16.d, z5.d
+ mov z17.d, z5.d
+ mov z18.d, z5.d
+ mov z19.d, z5.d
+
+ // ---------------------------------------------------------------- z16-z19: z = shift + x * inv_ln2
+ fmla z16.s, p0/m, z12.s, z6.s
+ fmla z17.s, p0/m, z13.s, z6.s
+ fmla z18.s, p0/m, z14.s, z6.s
+ fmla z19.s, p0/m, z15.s, z6.s
+
+ // ---------------------------------------------------------------- z20-z23: n = z - shift
+ fsub z20.s, z16.s, z5.s
+ fsub z21.s, z17.s, z5.s
+ fsub z22.s, z18.s, z5.s
+ fsub z23.s, z19.s, z5.s
+
+ // ---------------------------------------------------------------- z12-z15: r_hi = x + n * neg_ln2_hi
+ fmla z12.s, p0/m, z20.s, z7.s
+ fmla z13.s, p0/m, z21.s, z7.s
+ fmla z14.s, p0/m, z22.s, z7.s
+ fmla z15.s, p0/m, z23.s, z7.s
+
+ // ---------------------------------------------------------------- z12-z15: r = r_hi + n * neg_ln2_lo
+ fmla z12.s, p0/m, z20.s, z8.s
+ fmla z13.s, p0/m, z21.s, z8.s
+ fmla z14.s, p0/m, z22.s, z8.s
+ fmla z15.s, p0/m, z23.s, z8.s
+
+ // ---------------------------------------------------------------- z16-z19: scale = z << 23 (2^n)
+ dup z10.s, #23
+ urshl z16.s, p0/m, z16.s, z10.s
+ urshl z17.s, p0/m, z17.s, z10.s
+ urshl z18.s, p0/m, z18.s, z10.s
+ urshl z19.s, p0/m, z19.s, z10.s
+
+ // Processes the first 2 vectors.
+
+ // ---------------------------------------------------------------- z20-z21: p1 = r * c1
+ fmul z20.s, z12.s, z0.s
+ fmul z21.s, z13.s, z0.s
+
+ // ---------------------------------------------------------------- z22-z23: p23 = c2
+ mov z22.d, z1.d
+ mov z23.d, z1.d
+
+ // ---------------------------------------------------------------- z22-z23: p23 = c2 + r * c3
+ fmla z22.s, p0/m, z12.s, z2.s
+ fmla z23.s, p0/m, z13.s, z2.s
+
+    // ---------------------------------------------------------------- z24-z25: c4
+ mov z24.d, z3.d
+ mov z25.d, z3.d
+
+ // ---------------------------------------------------------------- z24-z25: p45 = c4 + r * c5
+ fmla z24.s, p0/m, z12.s, z4.s
+ fmla z25.s, p0/m, z13.s, z4.s
+
+ // ---------------------------------------------------------------- z12-z13: r2 = r * r
+ fmul z12.s, z12.s, z12.s
+ fmul z13.s, z13.s, z13.s
+
+ // ---------------------------------------------------------------- z22-z23: p2345 = p23 + r2 * p45
+ fmla z22.s, p0/m, z12.s, z24.s
+ fmla z23.s, p0/m, z13.s, z25.s
+
+ // ---------------------------------------------------------------- z20-z21: p12345 = p1 + r2 * p2345
+ fmla z20.s, p0/m, z12.s, z22.s
+ fmla z21.s, p0/m, z13.s, z23.s
+
+ // ---------------------------------------------------------------- z16-z17: poly = scale + p12345 * scale
+ fmla z16.s, p0/m, z20.s, z16.s
+ fmla z17.s, p0/m, z21.s, z17.s
+
+ // Processes the last 2 vectors
+
+ // ---------------------------------------------------------------- z20-z21: p1 = r * c1
+ fmul z20.s, z14.s, z0.s
+ fmul z21.s, z15.s, z0.s
+
+ // ---------------------------------------------------------------- z22-z23: p23 = c2
+ mov z22.d, z1.d
+ mov z23.d, z1.d
+
+ // ---------------------------------------------------------------- z22-z23: p23 = c2 + r * c3
+ fmla z22.s, p0/m, z14.s, z2.s
+ fmla z23.s, p0/m, z15.s, z2.s
+
+    // ---------------------------------------------------------------- z24-z25: c4
+ mov z24.d, z3.d
+ mov z25.d, z3.d
+
+ // ---------------------------------------------------------------- z24-z25: p45 = c4 + r * c5
+ fmla z24.s, p0/m, z14.s, z4.s
+ fmla z25.s, p0/m, z15.s, z4.s
+
+ // ---------------------------------------------------------------- z14-z15: r2 = r * r
+ fmul z14.s, z14.s, z14.s
+ fmul z15.s, z15.s, z15.s
+
+ // ---------------------------------------------------------------- z22-z23: p2345 = p23 + r2 * p45
+ fmla z22.s, p0/m, z14.s, z24.s
+ fmla z23.s, p0/m, z15.s, z25.s
+
+ // ---------------------------------------------------------------- z20-z21: p12345 = p1 + r2 * p2345
+ fmla z20.s, p0/m, z14.s, z22.s
+ fmla z21.s, p0/m, z15.s, z23.s
+
+ // ---------------------------------------------------------------- z18-z19: poly = scale + p12345 * scale
+ fmla z18.s, p0/m, z20.s, z18.s
+ fmla z19.s, p0/m, z21.s, z19.s
+
+ // ---------------------------------------------------------------- z16-z19: poly = underflow ? 0 : poly
+ dup z10.s, #0
+ sel z16.s, p1, z10.s, z16.s
+ sel z17.s, p2, z10.s, z17.s
+ sel z18.s, p3, z10.s, z18.s
+ sel z19.s, p4, z10.s, z19.s
+
+ // ---------------------------------------------------------------- z16-z19: poly = overflow ? inf : poly
+ dup z10.s, w10 // z10: inf
+ sel z16.s, p5, z10.s, z16.s
+ sel z17.s, p6, z10.s, z17.s
+ sel z18.s, p7, z10.s, z18.s
+ sel z19.s, p8, z10.s, z19.s
+
+    // 1 / (1 + poly)
+ dup z10.s, w15 // z10: 1
+ fadd z16.s, z10.s, z16.s // poly + 1
+ fadd z17.s, z10.s, z17.s // poly + 1
+ fadd z18.s, z10.s, z18.s // poly + 1
+ fadd z19.s, z10.s, z19.s // poly + 1
+
+ fdivr z16.s, p0/m, z16.s, z10.s // z16: 1/(poly+1)
+    fdivr z17.s, p0/m, z17.s, z10.s // z17: 1/(poly+1)
+    fdivr z18.s, p0/m, z18.s, z10.s // z18: 1/(poly+1)
+    fdivr z19.s, p0/m, z19.s, z10.s // z19: 1/(poly+1)
+
+ // Stores 4 consecutive registers to the output
+ .inst 0xa029c790 // st1w {z16.s-z19.s}, pn9, [x28, x9, LSL #2]
+
+ incw x9, ALL, MUL #4
+ b inner_body_start%=
+inner_body_end%=:
+
+inner_leftover_start%=:
+ // Largely ordinary SVE code handling the same 1/(1+e^-x) approximation for the leftover loop.
+ whilelo p1.s, x9, %x[length] // While x9<length
+ b.none inner_leftover_end%=
+
+ ld1w z12.s, p1/z, [x27, x9, LSL #2] // z12: input data
+ fneg z12.s, p1/m, z12.s
+
+ mov z16.d, z5.d // z16: shift
+ fcmlt p4.s, p1/z, z12.s, z9.s // p4: underflow = x < min_input
+ fcmlt p5.s, p1/z, z26.s, z12.s // p5: overflow = x > max_input
+
+ fmla z16.s, p1/m, z12.s, z6.s // z16: z = shift + x * inv_ln2
+ fsub z20.s, z16.s, z5.s // z20: n = z - shift
+ fmla z12.s, p1/m, z20.s, z7.s // z12: r_hi = x + n * neg_ln2_hi
+ fmla z12.s, p1/m, z20.s, z8.s // z12: r = r_hi + n * neg_ln2_lo
+ dup z10.s, #23 // z10: 23
+ urshl z16.s, p1/m, z16.s, z10.s // z16: scale = z << 23 (2^n)
+ fmul z20.s, z12.s, z0.s // z20: p1 = r * c1
+ mov z22.d, z1.d // z22: p23 = c2
+ fmla z22.s, p1/m, z12.s, z2.s // z22: p23 = c2 + r * c3
+ mov z24.d, z3.d // z24: c4
+ fmla z24.s, p1/m, z12.s, z4.s // z24: p45 = c4 + r * c5
+ fmul z12.s, z12.s, z12.s // z12: r2 = r * r
+ fmla z22.s, p1/m, z12.s, z24.s // z22: p2345 = p23 + r2 * p45
+ fmla z20.s, p1/m, z12.s, z22.s // z20: p12345 = p1 + r2 * p2345
+ fmla z16.s, p1/m, z20.s, z16.s // z16: poly = scale + p12345 * scale
+ dup z10.s, #0 // z10: 0
+ sel z16.s, p4, z10.s, z16.s // z16: poly = underflow ? 0 : poly
+ dup z10.s, w10 // z10: inf
+ sel z16.s, p5, z10.s, z16.s // z16: poly = overflow ? inf : poly
+
+ // 1 / (1 + poly)
+ dup z10.s, w15 // z10: 1
+ fadd z16.s, z10.s, z16.s // z16: z16 + 1
+ fdivr z16.s, p0/m, z16.s, z10.s // z16: 1/(poly+1)
+
+ st1w z16.s, p1, [x28, x9, LSL #2]
+
+ incw x9 // x9 += number of 32-bit elements per vector
+ b inner_leftover_start%=
+inner_leftover_end%=:
+
+ // ==================================================
+ // Outer loop closing
+ // ==================================================
+
+ add x27, x27, %x[src_stride_1]
+ add x28, x28, %x[dst_stride_1]
+ b outer_loop_start%=
+outer_loop_end%=:
+
+ .inst 0xd503467f // smstop
+ )"
+ :
+ : [src] "r"(src), [dst] "r"(dst), [shape_1] "r"(shape[1]), [src_stride_1] "r"(src_strides[1]),
+ [dst_stride_1] "r"(dst_strides[1]), [length] "r"(shape[0])
+ : "cc", "memory", //
+ "p0", "p1", "p2", "p3", "p4", "p5", "p6", "p7", "p9", //
+ "x9", "x10", "x11", "x12", "x13", "x14", "x15", //
+ "x26", "x27", "x28", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", //
+ "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", //
+ "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", //
+ "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31" //
+ );
+}
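+
+// Editorial summary of the approximation implemented above (comments only, nothing executed;
+// the constants c1..c5, inv_ln2, neg_ln2_hi/lo and shift are the ones loaded earlier in this file):
+//   sigmoid(x) = 1 / (1 + e^-x)
+//   with -x = n*ln2 + r (n integral), e^-x is evaluated as
+//     scale  = 2^n                                         (n shifted into the float exponent bits)
+//     p12345 = c1*r + r^2*((c2 + c3*r) + r^2*(c4 + c5*r))  ~= e^r - 1
+//     poly   = scale + p12345 * scale                      ~= e^-x
+//   and the stored value is 1 / (1 + poly), with poly forced to 0 on underflow and +inf on overflow.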
+
+void sme2_fp32_logistic(const ITensor *in, ITensor *out, const ActivationLayerInfo &act_info, const Window &window)
+{
+ ARM_COMPUTE_UNUSED(act_info);
+ const auto *src_info = in->info();
+ const auto *dst_info = out->info();
+
+ const auto &src_strides = src_info->strides_in_bytes();
+ const auto &dst_strides = dst_info->strides_in_bytes();
+
+ // Iterator calculates pointer offsets and takes into account padding.
+ Iterator input(in, window);
+ Iterator output(out, window);
+
+ // NOTE: This kernel uses collapsed 2D shapes.
+ // The execution window is expected to be pre-collapsed in the kernel's configure(...) function.
+ const uintptr_t k_shape[] = {window.num_iterations(0), window.num_iterations(1)};
+
+ const uintptr_t k_src_strides[] = {src_strides[0], src_strides[1]};
+ const uintptr_t k_dst_strides[] = {dst_strides[0], dst_strides[1]};
+
+ const auto *k_src = reinterpret_cast<const float *>(input.ptr());
+ auto *k_dst = reinterpret_cast<float *>(output.ptr());
+
+ sme2_f32_logistic_kernel(k_src, k_dst, k_shape, k_src_strides, k_dst_strides);
+}
+
+} // namespace cpu
+} // namespace arm_compute
+
+#endif // ARM_COMPUTE_ENABLE_SME2
diff --git a/src/cpu/kernels/logistic/list.h b/src/cpu/kernels/logistic/list.h
new file mode 100644
index 0000000000..7893a61248
--- /dev/null
+++ b/src/cpu/kernels/logistic/list.h
@@ -0,0 +1,42 @@
+/*
+ * Copyright (c) 2024 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#ifndef ACL_SRC_CPU_KERNELS_LOGISTIC_LIST_H
+#define ACL_SRC_CPU_KERNELS_LOGISTIC_LIST_H
+
+namespace arm_compute
+{
+namespace cpu
+{
+#define DECLARE_LOGISTIC_KERNEL(func_name) \
+ void func_name(const ITensor *src, ITensor *dst, const ActivationLayerInfo &act_info, const Window &window)
+
+#ifdef __aarch64__
+DECLARE_LOGISTIC_KERNEL(sme2_fp32_logistic);
+#endif // __aarch64__
+
+#undef DECLARE_LOGISTIC_KERNEL
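+
+// For reference (editorial note), DECLARE_LOGISTIC_KERNEL(sme2_fp32_logistic) above expands to:
+//   void sme2_fp32_logistic(const ITensor *src, ITensor *dst, const ActivationLayerInfo &act_info, const Window &window);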
+} // namespace cpu
+} // namespace arm_compute
+
+#endif // ACL_SRC_CPU_KERNELS_LOGISTIC_LIST_H
diff --git a/src/cpu/kernels/meanstddevnorm/generic/neon/fp16.cpp b/src/cpu/kernels/meanstddevnorm/generic/neon/fp16.cpp
index 344b9df0c8..c73d1def6b 100644
--- a/src/cpu/kernels/meanstddevnorm/generic/neon/fp16.cpp
+++ b/src/cpu/kernels/meanstddevnorm/generic/neon/fp16.cpp
@@ -53,23 +53,24 @@ void mean_stddev_normalization<float16_t, 8>(ITensor *input, ITensor *output, fl
auto in_ptr = reinterpret_cast<const float16_t *>(input_itr.ptr());
auto out_ptr = reinterpret_cast<float16_t *>(output_itr.ptr());
- float16x8_t sum_vec = vdupq_n_f16(static_cast<float16_t>(0.0f));
+ float32x4x2_t sum_vec = {vdupq_n_f32(0.0f), vdupq_n_f32(0.0f)};
+
float32x4_t sum_sq_vec = vdupq_n_f32(0.0f);
for (; x <= (window_end_x - window_step_x); x += window_step_x)
{
float16x8_t data = vld1q_f16(in_ptr + x);
- sum_vec = vaddq_f16(sum_vec, data);
float32x4_t dl = vcvt_f32_f16(vget_low_f16(data));
float32x4_t dh = vcvt_f32_f16(vget_high_f16(data));
+ sum_vec.val[0] = vaddq_f32(sum_vec.val[0], dl);
+ sum_vec.val[1] = vaddq_f32(sum_vec.val[1], dh);
sum_sq_vec = vaddq_f32(sum_sq_vec, vmulq_f32(dl, dl));
sum_sq_vec = vaddq_f32(sum_sq_vec, vmulq_f32(dh, dh));
}
- float32x4_t sum_carry_res =
- vpaddq_f32(vcvt_f32_f16(vget_high_f16(sum_vec)), vcvt_f32_f16(vget_low_f16(sum_vec)));
- float sum = vaddvq_f32(sum_carry_res);
- float sum_sq = vaddvq_f32(sum_sq_vec);
+ float32x4_t sum_carry_res = vpaddq_f32(sum_vec.val[0], sum_vec.val[1]);
+ float sum = vaddvq_f32(sum_carry_res);
+ float sum_sq = vaddvq_f32(sum_sq_vec);
// Compute left-over elements
for (; x < window_end_x; ++x)
diff --git a/src/cpu/kernels/mul/generic/sme2/list.h b/src/cpu/kernels/mul/generic/sme2/list.h
new file mode 100644
index 0000000000..f6aecfb91c
--- /dev/null
+++ b/src/cpu/kernels/mul/generic/sme2/list.h
@@ -0,0 +1,38 @@
+/*
+ * Copyright (c) 2024 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#ifndef ACL_SRC_CPU_KERNELS_MUL_GENERIC_SME2_LIST_H
+#define ACL_SRC_CPU_KERNELS_MUL_GENERIC_SME2_LIST_H
+namespace arm_compute
+{
+namespace cpu
+{
+#define DECLARE_MUL_KERNEL(func_name) \
+ void func_name(const ITensor *src1, const ITensor *src2, ITensor *out, const Window &window, float scale)
+
+DECLARE_MUL_KERNEL(sme2_q8_signed_mul);
+
+#undef DECLARE_MUL_KERNEL
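+
+// For reference (editorial note), DECLARE_MUL_KERNEL(sme2_q8_signed_mul) above expands to:
+//   void sme2_q8_signed_mul(const ITensor *src1, const ITensor *src2, ITensor *out, const Window &window, float scale);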
+} // namespace cpu
+} // namespace arm_compute
+#endif // ACL_SRC_CPU_KERNELS_MUL_GENERIC_SME2_LIST_H
diff --git a/src/cpu/kernels/mul/generic/sme2/qasymm8_signed.cpp b/src/cpu/kernels/mul/generic/sme2/qasymm8_signed.cpp
new file mode 100644
index 0000000000..94fbaa1bb2
--- /dev/null
+++ b/src/cpu/kernels/mul/generic/sme2/qasymm8_signed.cpp
@@ -0,0 +1,410 @@
+/*
+ * Copyright (c) 2024 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#ifdef ARM_COMPUTE_ENABLE_SME2
+
+#include "arm_compute/core/ITensor.h"
+#include "arm_compute/core/Window.h"
+
+namespace arm_compute
+{
+namespace cpu
+{
+
+// Mul SME kernel
+void sme2_q8_signed_mul_kernel( //
+ const int8_t *src,
+ const int8_t *weights,
+ int8_t *dst,
+ const int16_t offset_a,
+ const int16_t offset_b,
+ const int16_t offset_c,
+ const float multiplier, // = (scale_a * scale_b * mul) / scale_c
+ const uintptr_t win_shape[4],
+ const uintptr_t src_strides[4],
+ const uintptr_t wei_strides[4],
+ const uintptr_t dst_strides[4])
+{
+ struct Args
+ {
+ uintptr_t shape1;
+ uintptr_t shape2;
+ uintptr_t shape3;
+ const int8_t *src;
+ const int8_t *wei;
+ int8_t *dst;
+ int multiplier14p18;
+ int offsetC14p18;
+ int16_t offsetA;
+ int16_t offsetB;
+ } args;
+
+ // Constants used to express values in the 14p18 fixed point format
+ constexpr int32_t two_pwr18i = 262144;
+ constexpr float two_pwr18f = 262144.f;
+
+ args.shape1 = win_shape[1];
+ args.shape2 = win_shape[2];
+ args.shape3 = win_shape[3];
+ args.src = src;
+ args.wei = weights;
+ args.dst = dst;
+ args.multiplier14p18 = static_cast<int>(multiplier * two_pwr18f);
+ args.offsetC14p18 = static_cast<int>(offset_c * two_pwr18i);
+ // Offsets a/b need to be negated as assembly kernel uses addition instructions where subtraction is needed.
+ // Offset C is not negated as it needs to be added rather than subtracted.
+ args.offsetA = offset_a * -1;
+ args.offsetB = offset_b * -1;
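+
+ // Worked example (editorial, assuming multiplier == 0.25f and offset_c == 3):
+ //   multiplier14p18 = 0.25 * 262144 = 65536   (0.25 in 14.18 fixed point)
+ //   offsetC14p18    = 3 * 262144    = 786432  (3.0  in 14.18 fixed point)
+ // The assembly accumulates offsetC14p18 + product * multiplier14p18 and shifts right by 18 with
+ // rounding (sqrshr #18), recovering round(offset_c + product * multiplier) saturated to int8.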
+
+ // Precondition:
+ assert(src_strides[0] == sizeof(int8_t));
+ assert(wei_strides[0] == sizeof(int8_t));
+ assert(dst_strides[0] == sizeof(int8_t));
+ __asm__ volatile(
+ R"(
+ .inst 0xd503477f // smstart
+ .inst 0x25207811 // ptrue pn9.b
+ ptrue p0.b
+
+ // ==================================================
+ // 3D loop opening
+ // ==================================================
+
+ // ---------------------------------------------------------------- x8: body_length = (length / vl) * vl
+ cntb x8, ALL, MUL #2 // x8: vl = bytes per vector * 2 (elements processed per iteration)
+ udiv x9, %x[length], x8 // x9: length / vl
+ mul x8, x8, x9 // x8: body_length = vl * (length / vl)
+
+
+ // ---------------------------------------------------------------- Load outer loop count and data pointers from the args struct
+ ldr x10, [%[args_ptr], %[offset_shape_3]]
+ ldr x11, [%[args_ptr], %[offset_src_ptr]]
+ ldr x12, [%[args_ptr], %[offset_wei_ptr]]
+ ldr x13, [%[args_ptr], %[offset_dst_ptr]]
+
+ // Could potentially be replaced with explicit loads.
+ ld1rh {z1.h}, p0/z, [%[args_ptr], %[offset_A_offset]]
+ ld1rh {z2.h}, p0/z, [%[args_ptr], %[offset_B_offset]]
+ ld1rw {z3.s}, p0/z, [%[args_ptr], %[multiplier_offset]]
+
+loop_3_start%=:
+ // for index_3 in shape_3 downto 1
+ cmp x10, #0
+ b.eq loop_3_end%=
+ sub x10, x10, #1
+
+ ldr x14, [%[args_ptr], %[offset_shape_2]]
+ mov x15, x11
+ mov x16, x12
+ mov x17, x13
+
+loop_2_start%=:
+ // for index_2 in shape_2 downto 1
+ cmp x14, #0
+ b.eq loop_2_end%=
+ sub x14, x14, #1
+
+ ldr x7, [%[args_ptr], %[offset_shape_1]]
+ mov x20, x15
+ mov x21, x16
+ mov x22, x17
+
+loop_1_start%=:
+ // for index_1 in shape_1 downto 1
+ cmp x7, #0
+ b.eq loop_1_end%=
+ sub x7, x7, #1
+
+ mov x9, #0 // x9: index/count
+
+inner_loop_body_start%=:
+ cmp x9, x8
+ b.eq inner_loop_body_end%=
+
+ // Widening load: fetch a and b as bytes.
+
+ // NOTE: load 2 Z registers each (instead of 4) due to register pressure.
+ .inst 0xa0090684 // ld1b {z4.b-z5.b}, pn9/z, [x20, x9]
+ .inst 0xa00906a6 // ld1b {z6.b-z7.b}, pn9/z, [x21, x9]
+
+ // Widen to 16 bits
+ .inst 0xc175e08c // sunpk {z12.h-z15.h}, {z4.b-z5.b} // (a)
+ .inst 0xc175e0d0 // sunpk {z16.h-z19.h}, {z6.b-z7.b} // (b)
+
+ // Apply offset to all registers in 16-bit
+ .inst 0xc161ab0c // add {z12.h-z15.h}, {z12.h-z15.h}, z1.h //a
+ .inst 0xc162ab10 // add {z16.h-z19.h}, {z16.h-z19.h}, z2.h //b
+
+ // Widen to 32-bit now.
+ // 12-19 are taken
+ // 4-11 a, 20-27 b
+ .inst 0xc1b5e184 // sunpk {z4.s-z7.s}, {z12.h-z13.h} //a
+ .inst 0xc1b5e1c8 // sunpk {z8.s-z11.s}, {z14.h-z15.h}
+ .inst 0xc1b5e214 // sunpk {z20.s-z23.s}, {z16.h-z17.h} //b
+ .inst 0xc1b5e258 // sunpk {z24.s-z27.s}, {z18.h-z19.h}
+
+ // Multiply a*b in int32
+ // Output in z4-z11
+ MUL z4.s, z4.s, z20.s
+ MUL z5.s, z5.s, z21.s
+ MUL z6.s, z6.s, z22.s
+ MUL z7.s, z7.s, z23.s
+ MUL z8.s, z8.s, z24.s
+ MUL z9.s, z9.s, z25.s
+ MUL z10.s, z10.s, z26.s
+ MUL z11.s, z11.s, z27.s
+
+ // Initialise accumulators with offset_c (14.18 fixed point)
+ dup z12.s, %w[offset_C]
+ dup z13.s, %w[offset_C]
+ dup z14.s, %w[offset_C]
+ dup z15.s, %w[offset_C]
+ dup z16.s, %w[offset_C]
+ dup z17.s, %w[offset_C]
+ dup z18.s, %w[offset_C]
+ dup z19.s, %w[offset_C]
+
+ // MLA: accumulate product * multiplier in 14.18 fixed point
+ MLA z12.s, p0/m, z4.s, z3.s
+ MLA z13.s, p0/m, z5.s, z3.s
+ MLA z14.s, p0/m, z6.s, z3.s
+ MLA z15.s, p0/m, z7.s, z3.s
+ MLA z16.s, p0/m, z8.s, z3.s
+ MLA z17.s, p0/m, z9.s, z3.s
+ MLA z18.s, p0/m, z10.s, z3.s
+ MLA z19.s, p0/m, z11.s, z3.s
+
+ // Int32 to Int8 saturate
+ .inst 0xc16eda05 // sqrshr z5.b, {z16.s-z19.s}, #18
+ .inst 0xc16ed984 // sqrshr z4.b, {z12.s-z15.s}, #18
+ // Store
+ .inst 0xa02906c4 // st1b {z4.b-z5.b}, pn9, [x22, x9]
+
+ incb x9, ALL, MUL #2
+ b inner_loop_body_start%=
+inner_loop_body_end%=:
+
+inner_loop_leftover_start%=:
+ whilelo p1.b, x9, %x[length] // While x9<length
+ b.none inner_loop_leftover_end%=
+
+ // Handle the leftover elements: same multiplication, with predicated loads and stores
+ ld1b z4.b, p1/z, [x20, x9] // z4: a input_data
+ ld1b z5.b, p1/z, [x21, x9] // z5: b input_data
+
+ // Widen register z4 (a)
+ sunpklo z6.h, z4.b // lower as 16 bits
+ sunpkhi z7.h, z4.b // upper as 16 bits
+
+ // Widen register z5 (b)
+ sunpklo z8.h, z5.b // lower as 16 bits
+ sunpkhi z9.h, z5.b // upper as 16 bits
+
+ // Apply offset in 16bit maths to all resulting vectors.
+ add z6.h, z6.h, z1.h //a
+ add z7.h, z7.h, z1.h
+ add z8.h, z8.h, z2.h //b
+ add z9.h, z9.h, z2.h
+
+ // Widen a,b to 32-bit z-registers.
+ // Multiply a and b and store result as 32 bit int.
+ // a lower - 32-bit
+ sunpklo z10.s, z6.h
+ sunpkhi z11.s, z6.h
+ // a upper - 32-bit
+ sunpklo z12.s, z7.h
+ sunpkhi z13.s, z7.h
+
+ // b lower - 32-bit
+ sunpklo z14.s, z8.h
+ sunpkhi z15.s, z8.h
+ // b upper - 32-bit
+ sunpklo z16.s, z9.h
+ sunpkhi z17.s, z9.h
+
+ // Initialise accumulators with offset_c (14.18 fixed point)
+ dup z4.s, %w[offset_C]
+ dup z5.s, %w[offset_C]
+ dup z6.s, %w[offset_C]
+ dup z7.s, %w[offset_C]
+
+ // Multiply a*b (lower) in int32
+ MUL z10.s, z10.s, z14.s
+ MUL z11.s, z11.s, z15.s
+
+ // Multiply a*b (upper) in int32
+ MUL z12.s, z12.s, z16.s
+ MUL z13.s, z13.s, z17.s
+
+ // Still int32 here.
+ // Now MLA in fixed point
+ MLA z4.s, p0/m, z10.s, z3.s
+ MLA z5.s, p0/m, z11.s, z3.s
+ MLA z6.s, p0/m, z12.s, z3.s
+ MLA z7.s, p0/m, z13.s, z3.s
+
+ // Right shift, no narrow
+ LSR z20.s, z4.s, #8
+ LSR z21.s, z5.s, #8
+ LSR z22.s, z6.s, #8
+ LSR z23.s, z7.s, #8
+
+ // Right shift rounding (lower)
+ // Do not saturate.
+ RSHRNB z20.h, z20.s, #8
+ RSHRNB z21.h, z21.s, #8
+ UZP1 z25.h, z20.h, z21.h
+ // Right shift upper.
+ RSHRNB z22.h, z22.s, #8
+ RSHRNB z23.h, z23.s, #8
+ UZP1 z26.h, z22.h, z23.h
+
+ // Shift down to 8 bits for both vectors, then recombine.
+ SQRSHRNB z25.b, z25.h, #2
+ SQRSHRNB z26.b, z26.h, #2
+ UZP1 z27.b, z25.b, z26.b
+
+ st1b z27.b, p1, [x22, x9]
+
+ incb x9 // x9 += number of bytes per vector
+ b inner_loop_leftover_start%=
+inner_loop_leftover_end%=:
+
+ // ==================================================
+ // 3D loop closing
+ // ==================================================
+
+ add x20, x20, %[src_stride_1]
+ add x21, x21, %[wei_stride_1]
+ add x22, x22, %[dst_stride_1]
+ b loop_1_start%=
+loop_1_end%=:
+
+ add x15, x15, %[src_stride_2]
+ add x16, x16, %[wei_stride_2]
+ add x17, x17, %[dst_stride_2]
+ b loop_2_start%=
+loop_2_end%=:
+
+ add x11, x11, %[src_stride_3]
+ add x12, x12, %[wei_stride_3]
+ add x13, x13, %[dst_stride_3]
+ b loop_3_start%=
+loop_3_end%=:
+
+ .inst 0xd503467f // smstop
+ )"
+ :
+ : // The following arguments are loaded via arg ptr values and a constant offset.
+ [args_ptr] "r"(&args), [offset_src_ptr] "I"(offsetof(Args, src)), [offset_wei_ptr] "I"(offsetof(Args, wei)),
+ [offset_dst_ptr] "I"(offsetof(Args, dst)), [offset_shape_1] "I"(offsetof(Args, shape1)),
+ [offset_shape_2] "I"(offsetof(Args, shape2)), [offset_shape_3] "I"(offsetof(Args, shape3)),
+ [multiplier_offset] "I"(offsetof(Args, multiplier14p18)), //
+ [offset_A_offset] "I"(offsetof(Args, offsetA)), //
+ [offset_B_offset] "I"(offsetof(Args, offsetB)), //
+ // Use registers for efficiency sake.
+ [src_stride_1] "r"(src_strides[1]), [src_stride_2] "r"(src_strides[2]), [src_stride_3] "r"(src_strides[3]),
+ [wei_stride_1] "r"(wei_strides[1]), [wei_stride_2] "r"(wei_strides[2]), [wei_stride_3] "r"(wei_strides[3]),
+ [dst_stride_1] "r"(dst_strides[1]), [dst_stride_2] "r"(dst_strides[2]), [dst_stride_3] "r"(dst_strides[3]),
+ [offset_C] "r"(args.offsetC14p18), //
+ [length] "r"(win_shape[0])
+ : "cc", "memory", //
+ "p0", "p1", "x7", "x8", "x9", "x10", "x11", "x12", "x13", "x14", "x15", "x16", "x17", "x20", "x21", "x22",
+ "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", //
+ "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", //
+ "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", //
+ "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31" //
+ );
+}
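+
+// Editorial sketch (hypothetical helper, not called by the kernel above): a scalar reference for one
+// output element using the same 14.18 fixed-point encoding. The kernel stays in 32-bit accumulators;
+// 64-bit arithmetic is used here only for clarity, so this is a reference, not a bit-exact model.
+inline int8_t sme2_q8_signed_mul_scalar_reference(
+    int8_t a, int8_t b, int16_t offset_a, int16_t offset_b, int16_t offset_c, float multiplier)
+{
+    const int32_t prod = (static_cast<int32_t>(a) - offset_a) * (static_cast<int32_t>(b) - offset_b);
+    const int32_t multiplier_14p18 = static_cast<int32_t>(multiplier * 262144.f); // multiplier in 14.18
+    const int64_t acc_14p18 = static_cast<int64_t>(offset_c) * 262144             // offset_c in 14.18
+                            + static_cast<int64_t>(prod) * multiplier_14p18;      // + product * multiplier
+    int64_t rounded = (acc_14p18 + (1 << 17)) >> 18;                               // rounding shift (cf. sqrshr #18)
+    rounded = rounded < -128 ? -128 : (rounded > 127 ? 127 : rounded);             // saturate to the int8 range
+    return static_cast<int8_t>(rounded);
+}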
+
+void sme2_q8_signed_mul(const ITensor *in0, const ITensor *in1, ITensor *out, const Window &window, const float scale)
+{
+ const auto *src_info = in0->info();
+ const auto *src2_info = in1->info();
+ const auto *dst_info = out->info();
+
+ const UniformQuantizationInfo src_q_info = src_info->quantization_info().uniform();
+ const UniformQuantizationInfo src2_q_info = src2_info->quantization_info().uniform();
+ const UniformQuantizationInfo dst_q_info = dst_info->quantization_info().uniform();
+
+ const auto &src_strides_bytes = src_info->strides_in_bytes();
+ const auto &wei_strides_bytes = src2_info->strides_in_bytes();
+ const auto &dst_strides_bytes = dst_info->strides_in_bytes();
+
+ // NOTE: This kernel does not support shapes above 4D (unless the execution window has been collapsed).
+ assert(window.num_iterations(4) == 1 && window.num_iterations(5) == 1);
+
+ // Note: broadcasting along dims 1-3 is handled by setting the corresponding strides to 0 below.
+ const uintptr_t shape[] = {
+ window.num_iterations(0),
+ window.num_iterations(1),
+ window.num_iterations(2),
+ window.num_iterations(3),
+ };
+
+ Window input1_win = window.broadcast_if_dimension_le_one(src_info->tensor_shape());
+ Window input2_win = window.broadcast_if_dimension_le_one(src2_info->tensor_shape());
+
+ // The stride of dim 0 is always the element size. If broadcasting in other dims, set the stride to 0.
+ uintptr_t src_strides[] = {src_strides_bytes[0], (input1_win.is_broadcasted(1)) ? 0 : src_strides_bytes[1],
+ (input1_win.is_broadcasted(2)) ? 0 : src_strides_bytes[2],
+ (input1_win.is_broadcasted(3)) ? 0 : src_strides_bytes[3]};
+ uintptr_t wei_strides[] = {wei_strides_bytes[0], (input2_win.is_broadcasted(1)) ? 0 : wei_strides_bytes[1],
+ (input2_win.is_broadcasted(2)) ? 0 : wei_strides_bytes[2],
+ (input2_win.is_broadcasted(3)) ? 0 : wei_strides_bytes[3]};
+
+ const uintptr_t dst_strides[] = {
+ dst_strides_bytes[0],
+ dst_strides_bytes[1],
+ dst_strides_bytes[2],
+ dst_strides_bytes[3],
+ };
+
+ const uintptr_t src_offset = window[0].start() * src_strides[0] + window[1].start() * src_strides[1] +
+ window[2].start() * src_strides[2] + window[3].start() * src_strides[3] +
+ in0->info()->offset_first_element_in_bytes();
+ const uintptr_t src2_offset = window[0].start() * wei_strides[0] + window[1].start() * wei_strides[1] +
+ window[2].start() * wei_strides[2] + window[3].start() * wei_strides[3] +
+ in1->info()->offset_first_element_in_bytes();
+ const uintptr_t dst_offset = window[0].start() * dst_strides[0] + window[1].start() * dst_strides[1] +
+ window[2].start() * dst_strides[2] + window[3].start() * dst_strides[3] +
+ out->info()->offset_first_element_in_bytes();
+
+ const auto *src = reinterpret_cast<const int8_t *>(in0->buffer() + src_offset);
+ const auto *src2 = reinterpret_cast<const int8_t *>(in1->buffer() + src2_offset);
+ auto *dst = reinterpret_cast<int8_t *>(out->buffer() + dst_offset);
+
+ // Calculate or retrieve necessary offsets/scale values.
+ const int16_t offset_a = src_q_info.offset;
+ const int16_t offset_b = src2_q_info.offset;
+ float multiplier = (src_q_info.scale * src2_q_info.scale * scale) / dst_q_info.scale;
+
+ sme2_q8_signed_mul_kernel(src, src2, dst, offset_a, offset_b, dst_q_info.offset, multiplier, shape, src_strides,
+ wei_strides, dst_strides);
+}
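+
+// Worked example (editorial): with a src scale of 0.5, a src2 scale of 0.25, a requested scale of 2.0
+// and a dst scale of 0.1, the combined multiplier handed to the kernel is (0.5 * 0.25 * 2.0) / 0.1 = 2.5,
+// which sme2_q8_signed_mul_kernel then encodes as 2.5 * 262144 = 655360 in 14.18 fixed point.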
+
+} // namespace cpu
+} // namespace arm_compute
+
+#endif // ARM_COMPUTE_ENABLE_SME2