Diffstat (limited to 'src/cpu')
-rw-r--r--  src/cpu/kernels/CpuActivationKernel.cpp | 12
-rw-r--r--  src/cpu/kernels/CpuDirectConv2dOutputStageKernel.cpp | 349
-rw-r--r--  src/cpu/kernels/CpuDirectConv3dKernel.cpp | 20
-rw-r--r--  src/cpu/kernels/CpuGemmLowpMatrixMultiplyKernel.cpp | 4
-rw-r--r--  src/cpu/kernels/CpuKernelSelectionTypes.h | 9
-rw-r--r--  src/cpu/kernels/CpuMulKernel.cpp | 44
-rw-r--r--  src/cpu/kernels/CpuPermuteKernel.cpp | 65
-rw-r--r--  src/cpu/kernels/CpuReshapeKernel.cpp | 16
-rw-r--r--  src/cpu/kernels/CpuReshapeKernel.h | 19
-rw-r--r--  src/cpu/kernels/CpuScatterKernel.cpp | 98
-rw-r--r--  src/cpu/kernels/CpuScatterKernel.h | 91
-rw-r--r--  src/cpu/kernels/CpuWeightsReshapeKernel.cpp | 7
-rw-r--r--  src/cpu/kernels/assembly/CpuGemmAssemblyWrapperKernel.h | 8
-rw-r--r--  src/cpu/kernels/assembly/arm_gemm.hpp | 12
-rw-r--r--  src/cpu/kernels/assembly/convolution_parameters.hpp | 2
-rw-r--r--  src/cpu/kernels/assembly/gemm_common.hpp | 25
-rw-r--r--  src/cpu/kernels/conv3d/generic/neon/float_impl.h (renamed from src/cpu/kernels/conv3d/neon/list.h) | 17
-rw-r--r--  src/cpu/kernels/conv3d/generic/neon/fp16.cpp | 49
-rw-r--r--  src/cpu/kernels/conv3d/generic/neon/fp32.cpp | 46
-rw-r--r--  src/cpu/kernels/conv3d/generic/neon/qasymm8.cpp | 46
-rw-r--r--  src/cpu/kernels/conv3d/generic/neon/qasymm8_signed.cpp | 46
-rw-r--r--  src/cpu/kernels/conv3d/generic/neon/quantized_impl.h (renamed from src/cpu/kernels/conv3d/neon/quantized.h) | 12
-rw-r--r--  src/cpu/kernels/conv3d/list.h | 47
-rw-r--r--  src/cpu/kernels/directconv2d_output_stage/generic/neon/float_impl.h | 186
-rw-r--r--  src/cpu/kernels/directconv2d_output_stage/generic/neon/fp16.cpp | 63
-rw-r--r--  src/cpu/kernels/directconv2d_output_stage/generic/neon/fp32.cpp | 59
-rw-r--r--  src/cpu/kernels/directconv2d_output_stage/generic/neon/qasymm8.cpp | 59
-rw-r--r--  src/cpu/kernels/directconv2d_output_stage/generic/neon/qasymm8_signed.cpp | 59
-rw-r--r--  src/cpu/kernels/directconv2d_output_stage/generic/neon/quantized_impl.h | 213
-rw-r--r--  src/cpu/kernels/directconv2d_output_stage/list.h | 57
-rw-r--r--  src/cpu/kernels/gemm_matrix_mul/generic/neon/impl.cpp | 572
-rw-r--r--  src/cpu/kernels/logistic/generic/sme2/fp32.cpp | 429
-rw-r--r--  src/cpu/kernels/logistic/list.h | 42
-rw-r--r--  src/cpu/kernels/meanstddevnorm/generic/neon/fp16.cpp | 13
-rw-r--r--  src/cpu/kernels/mul/generic/sme2/list.h | 38
-rw-r--r--  src/cpu/kernels/mul/generic/sme2/qasymm8_signed.cpp | 410
-rw-r--r--  src/cpu/operators/CpuConv2d.h | 5
-rw-r--r--  src/cpu/operators/CpuGemmConv2d.cpp | 67
-rw-r--r--  src/cpu/operators/CpuGemmConv2d.h | 8
-rw-r--r--  src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.cpp | 98
-rw-r--r--  src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.h | 7
-rw-r--r--  src/cpu/operators/CpuMatMul.cpp | 7
-rw-r--r--  src/cpu/operators/CpuPermute.cpp | 76
-rw-r--r--  src/cpu/operators/CpuReshape.cpp | 6
-rw-r--r--  src/cpu/operators/CpuScatter.cpp | 70
-rw-r--r--  src/cpu/operators/CpuScatter.h | 81
-rw-r--r--  src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp | 311
-rw-r--r--  src/cpu/operators/internal/CpuGemmAssemblyDispatch.h | 24
48 files changed, 3179 insertions(+), 825 deletions(-)
diff --git a/src/cpu/kernels/CpuActivationKernel.cpp b/src/cpu/kernels/CpuActivationKernel.cpp
index 4253027231..555705bd45 100644
--- a/src/cpu/kernels/CpuActivationKernel.cpp
+++ b/src/cpu/kernels/CpuActivationKernel.cpp
@@ -32,6 +32,7 @@
#include "src/core/helpers/AutoConfiguration.h"
#include "src/core/helpers/WindowHelpers.h"
#include "src/cpu/kernels/activation/list.h"
+#include "src/cpu/kernels/logistic/list.h"
#include <array>
@@ -71,6 +72,12 @@ static const std::vector<CpuActivationKernel::ActivationKernel> available_kernel
},
REGISTER_Q8_NEON(arm_compute::cpu::neon_q8_activation_lut)},
#endif // __aarch64__
+ {"sme2_fp32_logistic",
+ [](const ActivationDataTypeISASelectorData &data) {
+ return data.dt == DataType::F32 && data.f == ActivationLayerInfo::ActivationFunction::LOGISTIC &&
+ data.isa.sme2;
+ },
+ REGISTER_FP32_SME2(arm_compute::cpu::sme2_fp32_logistic)},
{"sve2_qu8_activation",
[](const ActivationDataTypeISASelectorData &data) {
return data.dt == DataType::QASYMM8 && data.isa.sve2 &&
@@ -316,6 +323,11 @@ void CpuActivationKernel::configure(const ITensorInfo *src, ITensorInfo *dst, Ac
// Use squashed window
std::tie(win, _split_dimension) = calculate_squashed_or_max_window(*src);
+ // Collapse window with SME kernels in Y-Dim
+ if (std::string(uk->name) == "sme2_fp32_logistic")
+ {
+ win = win.collapse(win, Window::DimY);
+ }
ICPPKernel::configure(win);
}
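
The new "sme2_fp32_logistic" entry above plugs into the kernel-selection table: each entry pairs a predicate over the selector data (data type, activation function, ISA flags) with a micro-kernel pointer, and the first matching entry is used. Below is a minimal standalone sketch of that pattern with simplified stand-in types, not the real ActivationDataTypeISASelectorData or REGISTER_* macros:

```cpp
#include <functional>
#include <iostream>
#include <vector>

// Simplified stand-ins for the ACL selector data and registration macros.
struct SelectorData
{
    bool is_f32;      // stands in for data.dt == DataType::F32
    bool is_logistic; // stands in for data.f == ActivationFunction::LOGISTIC
    bool has_sme2;    // stands in for data.isa.sme2
};

struct KernelEntry
{
    const char                               *name;
    std::function<bool(const SelectorData &)> is_selected;
    void (*ukernel)();
};

void sme2_logistic_stub() { std::cout << "sme2_fp32_logistic selected\n"; }
void generic_stub()       { std::cout << "generic kernel selected\n"; }

int main()
{
    // Order matters: more specialised entries are listed before the fallback.
    const std::vector<KernelEntry> kernels = {
        {"sme2_fp32_logistic",
         [](const SelectorData &d) { return d.is_f32 && d.is_logistic && d.has_sme2; },
         sme2_logistic_stub},
        {"generic", [](const SelectorData &) { return true; }, generic_stub},
    };

    const SelectorData data{true, true, true};
    for (const auto &k : kernels)
    {
        if (k.is_selected(data))
        {
            k.ukernel(); // first match wins
            break;
        }
    }
    return 0;
}
```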
diff --git a/src/cpu/kernels/CpuDirectConv2dOutputStageKernel.cpp b/src/cpu/kernels/CpuDirectConv2dOutputStageKernel.cpp
index d4af8bedaf..9e9137a266 100644
--- a/src/cpu/kernels/CpuDirectConv2dOutputStageKernel.cpp
+++ b/src/cpu/kernels/CpuDirectConv2dOutputStageKernel.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2021 Arm Limited.
+ * Copyright (c) 2017-2021, 2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -37,6 +37,7 @@
#include "src/core/NEON/NEAsymm.h"
#include "src/core/NEON/NEFixedPoint.h"
#include "src/core/NEON/wrapper/wrapper.h"
+#include "src/cpu/kernels/directconv2d_output_stage/list.h"
#include <arm_neon.h>
#include <cstddef>
@@ -95,316 +96,6 @@ Status validate_arguments(const ITensorInfo *src
return Status{};
}
-
-template <typename T>
-typename std::enable_if<arm_compute::utils::traits::is_floating_point<T>::value, void>::type
-output_stage_nchw(ITensor *src,
- const ITensor *bias,
- const Window &window,
- ITensor *dst,
- int result_fixedpoint_multiplier,
- int result_shift,
- int result_offset_after_shift)
-{
- const bool has_bias = bias != nullptr;
- /** SIMD vector tag type. */
- using ExactTagType = typename wrapper::traits::neon_bitvector_tag_t<T, wrapper::traits::BitWidth::W128>;
-
- ARM_COMPUTE_ERROR_ON(src->info()->data_layout() == DataLayout::UNKNOWN);
- ARM_COMPUTE_UNUSED(result_fixedpoint_multiplier);
- ARM_COMPUTE_UNUSED(result_shift);
- ARM_COMPUTE_UNUSED(result_offset_after_shift);
-
- const int window_start_x = window.x().start();
- const int window_end_x = window.x().end();
- const int window_step_x = 16 / src->info()->element_size();
- Window win = window;
- win.set(Window::DimX, Window::Dimension(0, 1, 1));
-
- Iterator in(src, win);
- Iterator out(dst, win);
- execute_window_loop(
- win,
- [&](const Coordinates &id)
- {
- int x = window_start_x;
- for (; x <= (window_end_x - window_step_x); x += window_step_x)
- {
- // Get bias and pointer to input
- const auto in_ptr = reinterpret_cast<const T *>(in.ptr()) + x;
- auto v_in = wrapper::vloadq(in_ptr);
-
- // Accumulate bias
- if (has_bias)
- {
- const auto vb = wrapper::vdup_n(
- *reinterpret_cast<const T *>(bias->ptr_to_element(Coordinates(id.z()))), ExactTagType{});
- v_in = wrapper::vadd(v_in, vb);
- }
-
- const auto out_ptr = reinterpret_cast<T *>(out.ptr()) + x;
- wrapper::vstore(out_ptr, v_in);
- }
-
- // Left-overs loop
- for (; x < window_end_x; ++x)
- {
- // Get bias and pointer to input
- auto s_in = *(reinterpret_cast<const T *>(in.ptr()) + x);
-
- // Accumulate bias
- if (has_bias)
- {
- const auto b = *reinterpret_cast<const T *>(bias->ptr_to_element(Coordinates(id.z())));
- s_in += b;
- }
-
- *(reinterpret_cast<T *>(out.ptr()) + x) = s_in;
- }
- },
- in, out);
-}
-
-template <typename T>
-typename std::enable_if<arm_compute::utils::traits::is_floating_point<T>::value, void>::type
-output_stage_nhwc(ITensor *src,
- const ITensor *bias,
- const Window &window,
- ITensor *dst,
- int result_fixedpoint_multiplier,
- int result_shift,
- int result_offset_after_shift)
-{
- const bool has_bias = bias != nullptr;
- ARM_COMPUTE_UNUSED(result_fixedpoint_multiplier);
- ARM_COMPUTE_UNUSED(result_shift);
- ARM_COMPUTE_UNUSED(result_offset_after_shift);
-
- Window window_bias = window;
- window_bias.set(Window::DimX, Window::Dimension(0, 1, 1));
- window_bias.set(Window::DimY, Window::Dimension(0, 0, 0));
- window_bias.set(Window::DimZ, Window::Dimension(0, 0, 0));
- window_bias.set(3, Window::Dimension(0, 0, 0));
-
- const int window_start_x = window.x().start();
- const int window_end_x = window.x().end();
- const int window_step_x = 16 / src->info()->element_size();
- Window win = window;
- win.set(Window::DimX, Window::Dimension(0, 1, 1));
-
- Iterator in(src, win);
- Iterator bi(bias, window_bias);
- Iterator out(dst, win);
-
- execute_window_loop(
- win,
- [&](const Coordinates &)
- {
- int x = window_start_x;
- for (; x <= (window_end_x - window_step_x); x += window_step_x)
- {
- // Get bias and pointer to input
- const auto in_ptr = reinterpret_cast<const T *>(in.ptr());
- auto v_in = wrapper::vloadq(in_ptr + x);
-
- // Accumulate bias
- if (has_bias)
- {
- const auto bias_ptr = reinterpret_cast<T *>(bi.ptr()) + x;
- v_in = wrapper::vadd(v_in, wrapper::vloadq(bias_ptr));
- }
-
- const auto out_ptr = reinterpret_cast<T *>(out.ptr());
- wrapper::vstore(out_ptr + x, v_in);
- }
-
- // Left-overs loop
- for (; x < window_end_x; ++x)
- {
- // Get bias and pointer to input
- auto s_in = *(reinterpret_cast<const T *>(in.ptr()) + x);
-
- // Accumulate bias
- if (has_bias)
- {
- const auto bias_ptr = reinterpret_cast<T *>(bi.ptr()) + x;
- s_in += *bias_ptr;
- }
-
- const auto out_ptr = reinterpret_cast<T *>(out.ptr());
- *(out_ptr + x) = s_in;
- }
- },
- in, bi, out);
-}
-
-// Quantized case
-template <
- typename TOut,
- typename std::enable_if<std::is_same<TOut, uint8_t>::value || std::is_same<TOut, int8_t>::value, int>::type = 0>
-void output_stage_nchw(ITensor *src,
- const ITensor *bias,
- const Window &window,
- ITensor *dst,
- int result_fixedpoint_multiplier,
- int result_shift,
- int result_offset_after_shift)
-{
- const bool has_bias = bias != nullptr;
- using VectorType = typename wrapper::traits::neon_bitvector_t<TOut, wrapper::traits::BitWidth::W128>;
- using TagType = typename wrapper::traits::neon_bitvector_tag_t<TOut, wrapper::traits::BitWidth::W128>;
-
- const int32x4_t result_offset_after_shift_s32 = vdupq_n_s32(result_offset_after_shift);
-
- const VectorType min = wrapper::vdup_n(std::numeric_limits<TOut>::lowest(), TagType{});
- const VectorType max = wrapper::vdup_n(std::numeric_limits<TOut>::max(), TagType{});
-
- const int window_start_x = window.x().start();
- const int window_end_x = window.x().end();
- const int window_step_x = 16 / src->info()->element_size();
- Window win = window;
- win.set(Window::DimX, Window::Dimension(0, 1, 1));
-
- Iterator in(src, win);
- Iterator out(dst, win);
-
- execute_window_loop(
- win,
- [&](const Coordinates &id)
- {
- int x = window_start_x;
- for (; x <= (window_end_x - window_step_x); x += window_step_x)
- {
- // Get bias and pointer to input
- const auto in_ptr = reinterpret_cast<int32_t *>(in.ptr()) + x;
- int32x4x4_t v_in = {{wrapper::vloadq(in_ptr), wrapper::vloadq(in_ptr + 4), wrapper::vloadq(in_ptr + 8),
- wrapper::vloadq(in_ptr + 12)}};
-
- // Accumulate bias
- if (has_bias)
- {
- const auto vb = wrapper::vdup_n(
- *reinterpret_cast<const int32_t *>(bias->ptr_to_element(Coordinates(id.z()))), TagType{});
- v_in = {{wrapper::vadd(v_in.val[0], vb), wrapper::vadd(v_in.val[1], vb),
- wrapper::vadd(v_in.val[2], vb), wrapper::vadd(v_in.val[3], vb)}};
- }
-
- const auto out_ptr = reinterpret_cast<TOut *>(out.ptr()) + x;
- wrapper::vstore(out_ptr, finalize_quantization(v_in, result_fixedpoint_multiplier, result_shift,
- result_offset_after_shift_s32, min, max, false));
- }
-
- // Left-overs loop
- for (; x < window_end_x; ++x)
- {
- // Get bias and pointer to input
- int32_t s_in = *(reinterpret_cast<const int32_t *>(in.ptr()) + x);
-
- // Accumulate bias
- if (has_bias)
- {
- const auto b = *reinterpret_cast<const int32_t *>(bias->ptr_to_element(Coordinates(id.z())));
- s_in += b;
- }
-
- const auto out_ptr = reinterpret_cast<TOut *>(out.ptr()) + x;
- *out_ptr =
- finalize_quantization(s_in, result_fixedpoint_multiplier, result_shift, result_offset_after_shift,
- std::numeric_limits<TOut>::lowest(), std::numeric_limits<TOut>::max(), false);
- }
- },
- in, out);
-}
-template <
- typename TOut,
- typename std::enable_if<std::is_same<TOut, uint8_t>::value || std::is_same<TOut, int8_t>::value, int>::type = 0>
-void output_stage_nhwc(ITensor *src,
- const ITensor *bias,
- const Window &window,
- ITensor *dst,
- int result_fixedpoint_multiplier,
- int result_shift,
- int result_offset_after_shift)
-{
- const bool has_bias = bias != nullptr;
- using VectorType = typename wrapper::traits::neon_bitvector_t<TOut, wrapper::traits::BitWidth::W128>;
- using TagType = typename wrapper::traits::neon_bitvector_tag_t<TOut, wrapper::traits::BitWidth::W128>;
-
- const int32x4_t result_offset_after_shift_s32 = vdupq_n_s32(result_offset_after_shift);
-
- const VectorType min = wrapper::vdup_n(std::numeric_limits<TOut>::lowest(), TagType{});
- const VectorType max = wrapper::vdup_n(std::numeric_limits<TOut>::max(), TagType{});
-
- Window window_bias = window;
- window_bias.set(Window::DimX, Window::Dimension(0, 1, 1));
- window_bias.set(Window::DimY, Window::Dimension(0, 0, 0));
- window_bias.set(Window::DimZ, Window::Dimension(0, 0, 0));
- window_bias.set(3, Window::Dimension(0, 0, 0));
-
- const int window_start_x = window.x().start();
- const int window_end_x = window.x().end();
- const int window_step_x = 16 / src->info()->element_size();
- Window win = window;
- win.set(Window::DimX, Window::Dimension(0, 1, 1));
-
- Iterator in(src, win);
- Iterator bi(bias, window_bias);
- Iterator out(dst, win);
-
- execute_window_loop(
- win,
- [&](const Coordinates &)
- {
- int x = window_start_x;
- for (; x <= (window_end_x - window_step_x); x += window_step_x)
- {
- // Get bias and pointer to input
- const auto in_ptr = reinterpret_cast<int32_t *>(in.ptr()) + x;
- int32x4x4_t v_in = {{
- wrapper::vloadq(in_ptr),
- wrapper::vloadq(in_ptr + 4),
- wrapper::vloadq(in_ptr + 8),
- wrapper::vloadq(in_ptr + 12),
- }};
-
- // Accumulate bias
- if (has_bias)
- {
- const auto bias_ptr = reinterpret_cast<int32_t *>(bi.ptr()) + x;
-
- wrapper::vadd(v_in.val[0], wrapper::vloadq(bias_ptr));
- wrapper::vadd(v_in.val[1], wrapper::vloadq(bias_ptr + 4));
- wrapper::vadd(v_in.val[2], wrapper::vloadq(bias_ptr + 8));
- wrapper::vadd(v_in.val[3], wrapper::vloadq(bias_ptr + 12));
- }
-
- const auto out_ptr = reinterpret_cast<TOut *>(out.ptr()) + x;
- wrapper::vstore(out_ptr, finalize_quantization(v_in, result_fixedpoint_multiplier, result_shift,
- result_offset_after_shift_s32, min, max, false));
- }
-
- // Left-overs loop
- for (; x < window_end_x; ++x)
- {
- // Get bias and pointer to input
- const auto in_ptr = reinterpret_cast<int32_t *>(in.ptr()) + x;
- int32_t s_in = *in_ptr;
-
- // Accumulate bias
- if (has_bias)
- {
- const auto bias_ptr = reinterpret_cast<int32_t *>(bi.ptr()) + x;
- s_in += *bias_ptr;
- }
-
- const auto out_ptr = reinterpret_cast<TOut *>(out.ptr()) + x;
- *out_ptr =
- finalize_quantization(s_in, result_fixedpoint_multiplier, result_shift, result_offset_after_shift,
- std::numeric_limits<TOut>::lowest(), std::numeric_limits<TOut>::max(), false);
- }
- },
- in, bi, out);
-}
} // namespace
void CpuDirectConv2dOutputStageKernel::configure(ITensorInfo *src,
@@ -447,24 +138,30 @@ void CpuDirectConv2dOutputStageKernel::configure(ITensorInfo
{
if (is_qasymm8_signed)
{
- _func = &output_stage_nchw<int8_t>;
+#ifdef ENABLE_QASYMM8_SIGNED_KERNELS
+ _func = &output_stage_nchw_qs8;
+#endif // ENABLE_QASYMM8_SIGNED_KERNELS
}
else
{
- _func = &output_stage_nchw<uint8_t>;
+#ifdef ENABLE_QASYMM8_KERNELS
+ _func = &output_stage_nchw_qu8;
+#endif // ENABLE_QASYMM8_KERNELS
}
break;
}
-#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+#ifdef ENABLE_FP16_KERNELS
case DataType::F16:
{
- _func = &output_stage_nchw<float16_t>;
+ _func = &output_stage_nchw_fp16;
break;
}
-#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
+#endif // ENABLE_FP16_KERNELS
case DataType::F32:
{
- _func = &output_stage_nchw<float>;
+#ifdef ENABLE_FP32_KERNELS
+ _func = &output_stage_nchw_fp32;
+#endif // ENABLE_FP32_KERNELS
break;
}
default:
@@ -481,24 +178,30 @@ void CpuDirectConv2dOutputStageKernel::configure(ITensorInfo
{
if (is_qasymm8_signed)
{
- _func = &output_stage_nhwc<int8_t>;
+#ifdef ENABLE_QASYMM8_SIGNED_KERNELS
+ _func = &output_stage_nhwc_qs8;
+#endif // ENABLE_QASYMM8_SIGNED_KERNELS
}
else
{
- _func = &output_stage_nhwc<uint8_t>;
+#ifdef ENABLE_QASYMM8_KERNELS
+ _func = &output_stage_nhwc_qu8;
+#endif // ENABLE_QASYMM8_KERNELS
}
break;
}
-#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+#ifdef ENABLE_FP16_KERNELS
case DataType::F16:
{
- _func = &output_stage_nhwc<float16_t>;
+ _func = &output_stage_nhwc_fp16;
break;
}
-#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
+#endif // ENABLE_FP16_KERNELS
case DataType::F32:
{
- _func = &output_stage_nhwc<float>;
+#ifdef ENABLE_FP32_KERNELS
+ _func = &output_stage_nhwc_fp32;
+#endif // ENABLE_FP32_KERNELS
break;
}
default:
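
The refactor above removes the in-file templates and routes configure() to per-type entry points (output_stage_nchw_fp32, _fp16, _qu8, _qs8 and their NHWC counterparts) that are only compiled when the corresponding ENABLE_*_KERNELS flag is set; the shared template bodies move into float_impl.h/quantized_impl.h later in this patch. A standalone sketch of that split, with an illustrative flag and simplified signatures:

```cpp
#include <cstdio>

#define ENABLE_FP32_KERNELS // normally provided by the build system; defined here only for illustration

// Shared template body (stands in for the code moved to float_impl.h).
template <typename T>
void output_stage_impl(const T *in, T *out, int n, T bias)
{
    for (int i = 0; i < n; ++i)
    {
        out[i] = in[i] + bias; // bias accumulation, as in the removed templates
    }
}

#ifdef ENABLE_FP32_KERNELS
// Thin non-template wrapper: one such symbol per data type, compiled only when enabled.
void output_stage_fp32(const float *in, float *out, int n, float bias)
{
    output_stage_impl<float>(in, out, n, bias);
}
#endif // ENABLE_FP32_KERNELS

int main()
{
    const float in[4] = {1.f, 2.f, 3.f, 4.f};
    float       out[4];
#ifdef ENABLE_FP32_KERNELS
    output_stage_fp32(in, out, 4, 0.5f);
    std::printf("%.1f %.1f %.1f %.1f\n", out[0], out[1], out[2], out[3]);
#endif // ENABLE_FP32_KERNELS
    return 0;
}
```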
diff --git a/src/cpu/kernels/CpuDirectConv3dKernel.cpp b/src/cpu/kernels/CpuDirectConv3dKernel.cpp
index b5b2aed1ba..9c37ece3dd 100644
--- a/src/cpu/kernels/CpuDirectConv3dKernel.cpp
+++ b/src/cpu/kernels/CpuDirectConv3dKernel.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2021-2022 Arm Limited.
+ * Copyright (c) 2021-2022, 2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -25,8 +25,8 @@
#include "arm_compute/core/Error.h"
#include "arm_compute/core/Helpers.h"
-#include "arm_compute/core/IAccessWindow.h"
#include "arm_compute/core/ITensor.h"
+#include "arm_compute/core/Steps.h"
#include "arm_compute/core/Types.h"
#include "arm_compute/core/Utils.h"
#include "arm_compute/core/utils/misc/ShapeCalculator.h"
@@ -35,10 +35,8 @@
#include "src/core/common/Registrars.h"
#include "src/core/CPP/Validate.h"
#include "src/core/helpers/AutoConfiguration.h"
-#include "src/core/NEON/wrapper/wrapper.h"
-#include "src/cpu/kernels/conv3d/neon/list.h"
-
-#include <algorithm>
+#include "src/core/helpers/WindowHelpers.h"
+#include "src/cpu/kernels/conv3d/list.h"
using namespace arm_compute::detail;
@@ -51,18 +49,16 @@ namespace kernels
namespace
{
static const std::vector<CpuDirectConv3dKernel::DirectConv3dKernel> available_kernels = {
-#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
{"neon_fp16_directconv3d",
[](const DataTypeISASelectorData &data) { return data.dt == DataType::F16 && data.isa.fp16; },
- REGISTER_FP16_NEON(arm_compute::cpu::directconv3d_float_neon_ndhwc<float16_t>)},
-#endif /* !defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) */
+ REGISTER_FP16_NEON(directconv3d_fp16_neon_ndhwc)},
{"neon_fp32_directconv3d", [](const DataTypeISASelectorData &data) { return data.dt == DataType::F32; },
- REGISTER_FP32_NEON(arm_compute::cpu::directconv3d_float_neon_ndhwc<float>)},
+ REGISTER_FP32_NEON(directconv3d_fp32_neon_ndhwc)},
{"neon_qasymm8_directconv3d", [](const DataTypeISASelectorData &data) { return data.dt == DataType::QASYMM8; },
- REGISTER_QASYMM8_NEON(arm_compute::cpu::directconv3d_quantized_neon_ndhwc<uint8_t>)},
+ REGISTER_QASYMM8_NEON(directconv3d_qu8_neon_ndhwc)},
{"neon_qasymm8_signed_directconv3d",
[](const DataTypeISASelectorData &data) { return data.dt == DataType::QASYMM8_SIGNED; },
- REGISTER_QASYMM8_SIGNED_NEON(arm_compute::cpu::directconv3d_quantized_neon_ndhwc<int8_t>)}};
+ REGISTER_QASYMM8_SIGNED_NEON(directconv3d_qs8_neon_ndhwc)}};
Status validate_arguments(const ITensorInfo *src0,
const ITensorInfo *src1,
diff --git a/src/cpu/kernels/CpuGemmLowpMatrixMultiplyKernel.cpp b/src/cpu/kernels/CpuGemmLowpMatrixMultiplyKernel.cpp
index 5b88735e7a..87340e566e 100644
--- a/src/cpu/kernels/CpuGemmLowpMatrixMultiplyKernel.cpp
+++ b/src/cpu/kernels/CpuGemmLowpMatrixMultiplyKernel.cpp
@@ -684,6 +684,10 @@ Status validate_arguments(const ITensorInfo *src0, const ITensorInfo *src1, cons
DataType::U8);
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(dst, 1, DataType::S32);
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(src0->data_type() == DataType::QASYMM8_SIGNED &&
+ src1->data_type() == DataType::QASYMM8,
+ "QASYMM8_SIGNED input with QASYMM8 weights not supported");
+
TensorShape in0_shape = src0->tensor_shape();
TensorShape in1_shape = src1->tensor_shape();
TensorShape out_shape = dst->tensor_shape();
diff --git a/src/cpu/kernels/CpuKernelSelectionTypes.h b/src/cpu/kernels/CpuKernelSelectionTypes.h
index 7c1e4772a6..96ddad9d19 100644
--- a/src/cpu/kernels/CpuKernelSelectionTypes.h
+++ b/src/cpu/kernels/CpuKernelSelectionTypes.h
@@ -105,6 +105,13 @@ struct SoftmaxKernelDataTypeISASelectorData
cpuinfo::CpuIsaInfo isa;
bool is_log;
int axis;
+ uint64_t sme2_vector_length;
+};
+
+struct ScatterKernelDataTypeISASelectorData
+{
+ DataType dt;
+ cpuinfo::CpuIsaInfo isa;
unsigned long sme2_vector_length;
};
@@ -124,6 +131,8 @@ using ScaleKernelDataTypeISASelectorDataPtr =
std::add_pointer<bool(const ScaleKernelDataTypeISASelectorData &data)>::type;
using SoftmaxKernelDataTypeISASelectorDataPtr =
std::add_pointer<bool(const SoftmaxKernelDataTypeISASelectorData &data)>::type;
+using ScatterKernelDataTypeISASelectorDataPtr =
+ std::add_pointer<bool(const ScatterKernelDataTypeISASelectorData &data)>::type;
} // namespace kernels
} // namespace cpu
} // namespace arm_compute
diff --git a/src/cpu/kernels/CpuMulKernel.cpp b/src/cpu/kernels/CpuMulKernel.cpp
index 8001482154..d7a3a77d51 100644
--- a/src/cpu/kernels/CpuMulKernel.cpp
+++ b/src/cpu/kernels/CpuMulKernel.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2016-2023 Arm Limited.
+ * Copyright (c) 2016-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -34,6 +34,7 @@
#include "src/core/NEON/NESymm.h"
#include "src/core/NEON/wrapper/wrapper.h"
#include "src/cpu/kernels/mul/generic/neon/list.h"
+#include "src/cpu/kernels/mul/generic/sme2/list.h"
#include <arm_neon.h>
@@ -317,6 +318,41 @@ void mul_saturate_quantized_8(const ITensor *src1, const ITensor *src2, ITensor
}
}
+bool mul_q8_sme_possible(const ITensorInfo *src0, const ITensorInfo *src1, const ITensorInfo *dst, float scale)
+{
+ const auto &in0_shape = src0->tensor_shape();
+ const auto &in1_shape = src1->tensor_shape();
+ const unsigned int dst_dims = dst->num_dimensions();
+
+ // Calculate Scale
+ const auto iq0 = src0->quantization_info().uniform();
+ const auto iq1 = src1->quantization_info().uniform();
+ const auto oq = dst->quantization_info().uniform();
+ const auto multiplier = ((iq0.scale * iq1.scale) / oq.scale) * scale;
+ const auto max_result = multiplier * (127) * (127) + static_cast<float>(oq.offset);
+ const auto min_result = multiplier * (-128) * (-128) + static_cast<float>(oq.offset);
+
+ // Does not support broadcasting on x
+ // Does not support dims > 4D output, unless input shapes are identical (therefore collapsible)
+ // Checks whether CPU has SME2 Available
+ if (in0_shape.x() == in1_shape.x() && CPUInfo::get().has_sme2() && (in0_shape == in1_shape || dst_dims <= 4))
+ {
+ // Check if multiplier cannot be stored as a 14.18 signed fixed-point number
+ if (multiplier < -8191.f || multiplier > 8191.f)
+ {
+ return false;
+ }
+ // It might not be possible to store the result as a 14.18 signed fixed-point number.
+ if (max_result > 8191.f || min_result < -8191.f)
+ {
+ return false;
+ }
+ // Passed all checks
+ return true;
+ }
+ return false;
+}
+
bool mul_q8_neon_fixedpoint_possible(const ITensorInfo *src0,
const ITensorInfo *src1,
const ITensorInfo *dst,
@@ -1563,7 +1599,11 @@ void CpuMulKernel::configure(ITensorInfo *src1,
case DataType::QASYMM8_SIGNED:
if (dt_input2 == DataType::QASYMM8_SIGNED)
{
- if (mul_q8_neon_fixedpoint_possible(src1, src2, dst, scale))
+ if (mul_q8_sme_possible(src1, src2, dst, scale) && rounding_policy == RoundingPolicy::TO_ZERO)
+ {
+ _func_quantized = REGISTER_QASYMM8_SIGNED_SME2(arm_compute::cpu::sme2_q8_signed_mul);
+ }
+ else if (mul_q8_neon_fixedpoint_possible(src1, src2, dst, scale))
{
_func_quantized = &mul_q8_neon_fixedpoint<int8_t>;
}
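
mul_q8_sme_possible() gates the SME2 path on whether the combined multiplier, and the worst-case results it can produce, fit in a 14.18 signed fixed-point value (1 sign bit, 13 integer bits, 18 fractional bits, so a magnitude just under 2^13 = 8192). A standalone sketch of that feasibility check, using made-up quantization parameters rather than values from the patch:

```cpp
#include <cstdio>

// True if v fits the integer range of a 14.18 signed fixed-point number.
bool fits_14_18(float v)
{
    return v >= -8191.f && v <= 8191.f;
}

int main()
{
    // Illustrative uniform quantization parameters (not taken from the patch).
    const float in0_scale = 0.02f, in1_scale = 0.03f, out_scale = 0.01f;
    const int   out_offset = 10;
    const float scale      = 1.f; // the operator's extra scale argument

    const float multiplier = ((in0_scale * in1_scale) / out_scale) * scale;
    const float max_result = multiplier * 127.f * 127.f + static_cast<float>(out_offset);
    const float min_result = multiplier * -128.f * -128.f + static_cast<float>(out_offset);

    const bool ok = fits_14_18(multiplier) && fits_14_18(max_result) && fits_14_18(min_result);
    std::printf("multiplier=%.4f max=%.2f min=%.2f -> SME2 fixed-point path %s\n", multiplier, max_result,
                min_result, ok ? "possible" : "not possible");
    return 0;
}
```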
diff --git a/src/cpu/kernels/CpuPermuteKernel.cpp b/src/cpu/kernels/CpuPermuteKernel.cpp
index b444a25ff7..c6e0dd3a5e 100644
--- a/src/cpu/kernels/CpuPermuteKernel.cpp
+++ b/src/cpu/kernels/CpuPermuteKernel.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018-2021 Arm Limited.
+ * Copyright (c) 2018-2021, 2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -97,15 +97,12 @@ Status validate_arguments(const ITensorInfo *src, const ITensorInfo *dst, const
template <typename T>
void run_permute(const Window &window, const ITensor *src, const ITensor *dst, const PermutationVector &perm)
{
- const DataLayout src_layout = src->info()->data_layout();
-
// Source window
Window window_src = window;
// we only support these two configs in src/core/NEON/kernels/convolution/common/shims.hpp, for all others
// we have to fall back to C++
- if ((src_layout == DataLayout::NCHW && perm == PermutationVector{2U, 0U, 1U}) ||
- (src_layout == DataLayout::NHWC && perm == PermutationVector{1U, 2U, 0U}))
+ if (perm == PermutationVector{2U, 0U, 1U} || perm == PermutationVector{1U, 2U, 0U})
{
window_src.set(Window::DimX,
Window::Dimension(window.x().start(), window.x().end(), window.x().end() - window.x().start()));
@@ -128,49 +125,16 @@ void run_permute(const Window &window, const ITensor *src, const ITensor *dst, c
Iterator src_it(src, window_src);
Iterator dst_it(dst, window_dst);
- int in_row_stride = 0;
- int in_col_stride = 0;
- int in_channel_stride = 0;
- int in_batch_stride = 0;
- int n_cols = 0;
- int n_rows = 0;
- int n_channels = 0;
- int n_batches = 0;
-
- switch (src_layout)
- {
- case DataLayout::NCHW:
- {
- in_row_stride = src->info()->strides_in_bytes().y() / sizeof(T);
- in_channel_stride = src->info()->strides_in_bytes().z() / sizeof(T);
- in_batch_stride = src->info()->strides_in_bytes()[3] / sizeof(T);
- n_cols = src->info()->tensor_shape().x();
- n_rows = window_src.y().step();
- n_channels = src->info()->tensor_shape().z();
- n_batches = src->info()->tensor_shape()[3];
- break;
- }
- case DataLayout::NHWC:
- {
- in_col_stride = src->info()->strides_in_bytes().y() / sizeof(T);
- in_row_stride = src->info()->strides_in_bytes().z() / sizeof(T);
- in_batch_stride = src->info()->strides_in_bytes()[3] / sizeof(T);
- n_channels = src->info()->tensor_shape().x();
- n_cols = window_src.y().step();
- n_rows = src->info()->tensor_shape().z();
- n_batches = src->info()->tensor_shape()[3];
- break;
- }
- default:
- {
- ARM_COMPUTE_ERROR("Invalid source data layout.");
- break;
- }
- }
-
// CHW -> HWC
- if (src_layout == DataLayout::NCHW && perm == PermutationVector{2U, 0U, 1U})
+ if (perm == PermutationVector{2U, 0U, 1U})
{
+ const int in_row_stride = src->info()->strides_in_bytes().y() / sizeof(T);
+ const int in_channel_stride = src->info()->strides_in_bytes().z() / sizeof(T);
+ const int in_batch_stride = src->info()->strides_in_bytes()[3] / sizeof(T);
+ const int n_cols = src->info()->tensor_shape().x();
+ const int n_rows = window_src.y().step();
+ const int n_channels = src->info()->tensor_shape().z();
+ const int n_batches = src->info()->tensor_shape()[3];
const int out_channel_stride = dst->info()->strides_in_bytes().x() / sizeof(T);
const int out_col_stride = dst->info()->strides_in_bytes().y() / sizeof(T);
const int out_row_stride = dst->info()->strides_in_bytes().z() / sizeof(T);
@@ -188,8 +152,15 @@ void run_permute(const Window &window, const ITensor *src, const ITensor *dst, c
src_it, dst_it);
}
// HWC -> CHW
- else if (src_layout == DataLayout::NHWC && perm == PermutationVector{1U, 2U, 0U})
+ else if (perm == PermutationVector{1U, 2U, 0U})
{
+ const int in_col_stride = src->info()->strides_in_bytes().y() / sizeof(T);
+ const int in_row_stride = src->info()->strides_in_bytes().z() / sizeof(T);
+ const int in_batch_stride = src->info()->strides_in_bytes()[3] / sizeof(T);
+ const int n_channels = src->info()->tensor_shape().x();
+ const int n_cols = window_src.y().step();
+ const int n_rows = src->info()->tensor_shape().z();
+ const int n_batches = src->info()->tensor_shape()[3];
const int out_col_stride = dst->info()->strides_in_bytes().x() / sizeof(T);
const int out_row_stride = dst->info()->strides_in_bytes().y() / sizeof(T);
const int out_channel_stride = dst->info()->strides_in_bytes().z() / sizeof(T);
diff --git a/src/cpu/kernels/CpuReshapeKernel.cpp b/src/cpu/kernels/CpuReshapeKernel.cpp
index 241e58fbce..78b08f19d2 100644
--- a/src/cpu/kernels/CpuReshapeKernel.cpp
+++ b/src/cpu/kernels/CpuReshapeKernel.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2023 Arm Limited.
+ * Copyright (c) 2017-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -233,7 +233,9 @@ void CpuReshapeKernel::prepare(ITensorPack &tensors)
if (!src_has_holes && !dst_has_holes)
{
- std::tie(win, _split_dimension) = calculate_squashed_or_max_window(*dst_info);
+ size_t split_dimension;
+
+ std::tie(win, split_dimension) = calculate_squashed_or_max_window(*dst_info);
/*
Copy the tensor per window. If the src and dst tensors
are contiguous memory allocations without any holes or
@@ -241,7 +243,15 @@ void CpuReshapeKernel::prepare(ITensorPack &tensors)
we can use a single memcpy call to copy the whole
window in reshape_tensor_per_window fn
*/
- _reshape_tensor_fn = reshape_tensor_per_window;
+ if (split_dimension != Window::DimY)
+ {
+ // Fall back when split dimension doesn't equal Window::DimY
+ _reshape_tensor_fn = reshape_tensor_per_row;
+ }
+ else
+ {
+ _reshape_tensor_fn = reshape_tensor_per_window;
+ }
}
else
{
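
The reshape change above picks the copy granularity from the squashed window: when the scheduler splits along Window::DimY, reshape_tensor_per_window can copy each window with a single memcpy as the comment describes, while any other split dimension now falls back to a per-row copy. Reduced to a standalone sketch of that contiguous-vs-strided copy trade-off (illustrative only, not the actual reshape_tensor_per_window/per_row functions):

```cpp
#include <cstdio>
#include <cstring>

// Copy a rows x row_len region: one memcpy when rows are packed back to back
// (stride == row_len), otherwise one memcpy per row.
void copy_region(const char *src, char *dst, int rows, int row_len, int src_row_stride)
{
    if (src_row_stride == row_len)
    {
        std::memcpy(dst, src, static_cast<size_t>(rows) * row_len); // whole-block copy
    }
    else
    {
        for (int r = 0; r < rows; ++r)
        {
            std::memcpy(dst + r * row_len, src + r * src_row_stride, row_len); // per-row fallback
        }
    }
}

int main()
{
    const char src[] = "abcdefgh"; // two packed rows of four bytes
    char       dst[9] = {};
    copy_region(src, dst, 2, 4, 4);
    std::printf("%s\n", dst);
    return 0;
}
```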
diff --git a/src/cpu/kernels/CpuReshapeKernel.h b/src/cpu/kernels/CpuReshapeKernel.h
index ce566fd9e2..acad441b76 100644
--- a/src/cpu/kernels/CpuReshapeKernel.h
+++ b/src/cpu/kernels/CpuReshapeKernel.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2023 Arm Limited.
+ * Copyright (c) 2017-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -21,8 +21,8 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
-#ifndef ARM_COMPUTE_CPU_RESHAPE_KERNEL_H
-#define ARM_COMPUTE_CPU_RESHAPE_KERNEL_H
+#ifndef ACL_SRC_CPU_KERNELS_CPURESHAPEKERNEL_H
+#define ACL_SRC_CPU_KERNELS_CPURESHAPEKERNEL_H
#include "src/core/common/Macros.h"
#include "src/cpu/ICpuKernel.h"
@@ -74,21 +74,10 @@ public:
*/
size_t get_mws(const CPUInfo &platform, size_t thread_count) const override;
- /** Get the preferred dimension in which the scheduler splits the work into multiple jobs.
- *
- * @return The split dimension.
- */
- size_t get_split_dimension() const
- {
- return _split_dimension;
- }
-
private:
- size_t _split_dimension{Window::DimY};
-
std::function<void(const Window &window, const ITensor *src, ITensor *dst)> _reshape_tensor_fn{};
};
} // namespace kernels
} // namespace cpu
} // namespace arm_compute
-#endif /* ARM_COMPUTE_CPU_RESHAPE_KERNEL_H */
+#endif // ACL_SRC_CPU_KERNELS_CPURESHAPEKERNEL_H
diff --git a/src/cpu/kernels/CpuScatterKernel.cpp b/src/cpu/kernels/CpuScatterKernel.cpp
new file mode 100644
index 0000000000..bc0fa724b0
--- /dev/null
+++ b/src/cpu/kernels/CpuScatterKernel.cpp
@@ -0,0 +1,98 @@
+/*
+ * Copyright (c) 2024 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#include "src/cpu/kernels/CpuScatterKernel.h"
+
+#include "arm_compute/core/Error.h"
+#include "arm_compute/core/TensorInfo.h"
+
+#include <vector>
+
+namespace arm_compute
+{
+namespace cpu
+{
+namespace kernels
+{
+namespace
+{
+
+/* Scatter */
+static const std::vector<typename CpuScatterKernel::ScatterKernel> available_kernels = {
+
+};
+
+} // namespace
+
+const std::vector<typename CpuScatterKernel::ScatterKernel> &CpuScatterKernel::get_available_kernels()
+{
+ return available_kernels;
+}
+
+void CpuScatterKernel::configure(const ITensorInfo *src,
+ const ITensorInfo *updates,
+ const ITensorInfo *indices,
+ ITensorInfo *dst,
+ const ScatterInfo &info)
+{
+ ARM_COMPUTE_UNUSED(src);
+ ARM_COMPUTE_UNUSED(updates);
+ ARM_COMPUTE_UNUSED(indices);
+ ARM_COMPUTE_UNUSED(dst);
+ ARM_COMPUTE_UNUSED(info);
+
+ ARM_COMPUTE_UNUSED(_run_method);
+}
+
+Status CpuScatterKernel::validate(const ITensorInfo *src,
+ const ITensorInfo *updates,
+ const ITensorInfo *indices,
+ const ITensorInfo *dst,
+ const ScatterInfo &info)
+{
+ ARM_COMPUTE_UNUSED(src);
+ ARM_COMPUTE_UNUSED(updates);
+ ARM_COMPUTE_UNUSED(indices);
+ ARM_COMPUTE_UNUSED(dst);
+ ARM_COMPUTE_UNUSED(info);
+
+ return Status{ErrorCode::RUNTIME_ERROR, "No configuration implemented yet."};
+}
+
+void CpuScatterKernel::run_op(ITensorPack &tensors, const Window &window, const ThreadInfo &info)
+{
+ ARM_COMPUTE_UNUSED(tensors);
+ ARM_COMPUTE_UNUSED(window);
+ ARM_COMPUTE_UNUSED(info);
+
+ ARM_COMPUTE_UNUSED(_run_method);
+}
+
+const char *CpuScatterKernel::name() const
+{
+ return _name.c_str();
+}
+
+} // namespace kernels
+} // namespace cpu
+} // namespace arm_compute
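
CpuScatterKernel is introduced here as a placeholder: the kernel table is empty, configure() does nothing yet, and validate() unconditionally returns a RUNTIME_ERROR status. A usage sketch against this new interface (it assumes the headers added in this patch and is not standalone):

```cpp
#include "arm_compute/core/Error.h"
#include "arm_compute/core/ITensorInfo.h"

#include "src/cpu/kernels/CpuScatterKernel.h"

// With this patch applied, validation always fails until real micro-kernels land.
bool scatter_supported(const arm_compute::ITensorInfo *src,
                       const arm_compute::ITensorInfo *updates,
                       const arm_compute::ITensorInfo *indices,
                       const arm_compute::ITensorInfo *dst,
                       const arm_compute::ScatterInfo &info)
{
    const arm_compute::Status status =
        arm_compute::cpu::kernels::CpuScatterKernel::validate(src, updates, indices, dst, info);
    // Currently false: error_code() is ErrorCode::RUNTIME_ERROR ("No configuration implemented yet.").
    return status.error_code() == arm_compute::ErrorCode::OK;
}
```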
diff --git a/src/cpu/kernels/CpuScatterKernel.h b/src/cpu/kernels/CpuScatterKernel.h
new file mode 100644
index 0000000000..77c436034f
--- /dev/null
+++ b/src/cpu/kernels/CpuScatterKernel.h
@@ -0,0 +1,91 @@
+/*
+ * Copyright (c) 2024 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#ifndef ACL_SRC_CPU_KERNELS_CPUSCATTERKERNEL_H
+#define ACL_SRC_CPU_KERNELS_CPUSCATTERKERNEL_H
+
+#include "arm_compute/function_info/ScatterInfo.h"
+
+#include "src/core/common/Macros.h"
+#include "src/cpu/ICpuKernel.h"
+
+namespace arm_compute
+{
+namespace cpu
+{
+namespace kernels
+{
+/** Arm(R) Neon(TM) kernel to perform the ScatterND operation */
+class CpuScatterKernel : public ICpuKernel<CpuScatterKernel>
+{
+private:
+ using ScatterKernelPtr = std::add_pointer<void(
+ const ITensor *, const ITensor *, const ITensor *, ITensor *, const ScatterInfo, const Window &)>::type;
+
+public:
+ CpuScatterKernel() = default;
+ ARM_COMPUTE_DISALLOW_COPY_ALLOW_MOVE(CpuScatterKernel);
+ /** Initialise the kernel's input and output.
+ *
+ * @param[in] src Input tensor info for the source matrix.
+ * @param[in] updates Input tensor info for the Update matrix. Data type supported: same as @p src
+ * @param[in] indices Input tensor info for the Indices matrix. Data type supported: U32.
+ * @param[out] dst Output tensor info. Data type supported: same as @p src
+ * @param[in] info Attributes for Scatter Kernel
+ */
+ void configure(const ITensorInfo *src,
+ const ITensorInfo *updates,
+ const ITensorInfo *indices,
+ ITensorInfo *dst,
+ const ScatterInfo &info);
+ /** Static function to check if given info will lead to a valid configuration
+ *
+ * Similar to @ref CpuScatterKernel::configure()
+ *
+ * @return a status
+ */
+ static Status validate(const ITensorInfo *src,
+ const ITensorInfo *updates,
+ const ITensorInfo *indices,
+ const ITensorInfo *dst,
+ const ScatterInfo &info);
+
+ // Inherited methods overridden:
+ void run_op(ITensorPack &tensors, const Window &window, const ThreadInfo &info) override;
+ const char *name() const override;
+ struct ScatterKernel
+ {
+ const char *name;
+ ScatterKernelPtr ukernel;
+ };
+
+ static const std::vector<ScatterKernel> &get_available_kernels();
+
+private:
+ ScatterKernelPtr _run_method{nullptr};
+ std::string _name{};
+};
+} // namespace kernels
+} // namespace cpu
+} // namespace arm_compute
+#endif // ACL_SRC_CPU_KERNELS_CPUSCATTERKERNEL_H
diff --git a/src/cpu/kernels/CpuWeightsReshapeKernel.cpp b/src/cpu/kernels/CpuWeightsReshapeKernel.cpp
index 297ba63826..f8e9d123b5 100644
--- a/src/cpu/kernels/CpuWeightsReshapeKernel.cpp
+++ b/src/cpu/kernels/CpuWeightsReshapeKernel.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2021 Arm Limited.
+ * Copyright (c) 2017-2021,2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -72,7 +72,10 @@ Status validate_arguments(const ITensorInfo *src, const ITensorInfo *biases, con
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DIMENSIONS(dst->tensor_shape(),
get_output_shape(src, biases != nullptr));
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(src, dst);
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_QUANTIZATION_INFO(src, dst);
+ if (!src->quantization_info().is_dynamic())
+ {
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_QUANTIZATION_INFO(src, dst);
+ }
}
return Status{};
diff --git a/src/cpu/kernels/assembly/CpuGemmAssemblyWrapperKernel.h b/src/cpu/kernels/assembly/CpuGemmAssemblyWrapperKernel.h
index e2a27675b3..72fafca1bb 100644
--- a/src/cpu/kernels/assembly/CpuGemmAssemblyWrapperKernel.h
+++ b/src/cpu/kernels/assembly/CpuGemmAssemblyWrapperKernel.h
@@ -52,7 +52,7 @@ namespace kernel
*
*
*/
-template <typename TypeInput, typename TypeOutput>
+template <typename TypeInput, typename TypeWeight, typename TypeOutput>
class CpuGemmAssemblyWrapperKernel final : public INEKernel
{
public:
@@ -101,7 +101,7 @@ public:
* @param[in] kernel Pointer to an assembly kernel implementation.
* @param[in] kernel_name_tag Tag to be attached to the kernel's name.
*/
- void configure(arm_gemm::GemmCommon<TypeInput, TypeOutput> *kernel, std::string kernel_name_tag)
+ void configure(arm_gemm::GemmCommon<TypeInput, TypeWeight, TypeOutput> *kernel, std::string kernel_name_tag)
{
ARM_COMPUTE_ERROR_ON_NULLPTR((reinterpret_cast<void *>(kernel)));
_kernel = kernel;
@@ -131,8 +131,8 @@ public:
}
private:
- arm_gemm::GemmCommon<TypeInput, TypeOutput> *_kernel;
- std::string _name;
+ arm_gemm::GemmCommon<TypeInput, TypeWeight, TypeOutput> *_kernel;
+ std::string _name;
};
} // namespace kernel
} // namespace cpu
diff --git a/src/cpu/kernels/assembly/arm_gemm.hpp b/src/cpu/kernels/assembly/arm_gemm.hpp
index 941fed0ba8..cbc8be416e 100644
--- a/src/cpu/kernels/assembly/arm_gemm.hpp
+++ b/src/cpu/kernels/assembly/arm_gemm.hpp
@@ -277,8 +277,8 @@ struct Nothing
{
};
-template <typename Top, typename Tret>
-using UniqueGemmCommon = std::unique_ptr<GemmCommon<Top, Tret>>;
+template <typename Tlop, typename Trop, typename Tret>
+using UniqueGemmCommon = std::unique_ptr<GemmCommon<Tlop, Trop, Tret>>;
/* Low level API calls.
* These are implemented as 'GemmArgs' versions, or with the arguments explicitly listed. */
@@ -288,13 +288,13 @@ using UniqueGemmCommon = std::unique_ptr<GemmCommon<Top, Tret>>;
template <typename Top, typename Tret, class OutputStage = Nothing>
KernelDescription get_gemm_method(const GemmArgs &args, const OutputStage & = {});
-template <typename Top, typename Tret, class OutputStage = Nothing>
-UniqueGemmCommon<Top, Tret> gemm(const GemmArgs &args, const OutputStage & = {});
+template <typename Tlop, typename Trop, typename Tret, class OutputStage = Nothing>
+UniqueGemmCommon<Tlop, Trop, Tret> gemm(const GemmArgs &args, const OutputStage & = {});
-template <typename Top, typename Tret, class OutputStage = Nothing>
+template <typename Tlop, typename Trop, typename Tret, class OutputStage = Nothing>
std::vector<KernelDescription> get_compatible_kernels(const GemmArgs &args, const OutputStage & = {});
-template <typename Top, typename Tret, class OutputStage = Nothing>
+template <typename Tlop, typename Trop, typename Tret, class OutputStage = Nothing>
bool has_opt_gemm(WeightFormat &weight_format, const GemmArgs &args, const OutputStage & = {});
} // namespace arm_gemm
diff --git a/src/cpu/kernels/assembly/convolution_parameters.hpp b/src/cpu/kernels/assembly/convolution_parameters.hpp
index a6cf96344c..09b73ca409 100644
--- a/src/cpu/kernels/assembly/convolution_parameters.hpp
+++ b/src/cpu/kernels/assembly/convolution_parameters.hpp
@@ -61,6 +61,8 @@ struct ConvolutionParameters
int64_t output_stride_w;
int64_t output_stride_h;
// output_channels not included as they do not affect the input.
+ int64_t dilation_w;
+ int64_t dilation_h;
int64_t padding_top;
int64_t padding_left;
float padding_value;
diff --git a/src/cpu/kernels/assembly/gemm_common.hpp b/src/cpu/kernels/assembly/gemm_common.hpp
index 45d1e43274..d5676c134e 100644
--- a/src/cpu/kernels/assembly/gemm_common.hpp
+++ b/src/cpu/kernels/assembly/gemm_common.hpp
@@ -35,6 +35,7 @@ namespace arm_gemm
{
// Avoid circular dependency with arm_gemm.hpp
struct GemmConfig;
+struct Requantize32;
// Abstract class for the GEMM/GEMV functions.
//
@@ -160,6 +161,12 @@ public:
{
}
+ /*** "Quantization update" interface (optional) ***/
+ /* Update quantization parameters at run time */
+ virtual void update_quantization_parameters(const Requantize32 &)
+ {
+ }
+
/*** Convolution interface (optional) ***/
/* Set the convolution parameters. */
virtual void set_convolution_parameters(ConvolutionParameters)
@@ -189,7 +196,7 @@ public:
* 'set_arrays' to capture the provided arguments in protected class
* members, as essentially any implementation will need these.
*/
-template <typename To, typename Tr>
+template <typename To, typename Tw, typename Tr>
class GemmCommon : public IGemmCommon
{
protected:
@@ -197,7 +204,7 @@ protected:
int _lda = 0;
int _A_batch_stride = 0;
int _A_multi_stride = 0;
- const To *_Bptr = nullptr;
+ const Tw *_Bptr = nullptr;
int _ldb = 0;
int _B_multi_stride = 0;
Tr *_Cptr = nullptr;
@@ -214,7 +221,7 @@ public:
const int lda,
const int A_batch_stride,
const int A_multi_stride,
- const To *B,
+ const Tw *B,
const int ldb,
/* batches share B */ const int B_multi_stride,
Tr *C,
@@ -254,7 +261,7 @@ public:
const void *bias,
/* no row or batch stride needed */ const int bias_multi_stride) override
{
- set_arrays(static_cast<const To *>(A), lda, A_batch_stride, A_multi_stride, static_cast<const To *>(B), ldb,
+ set_arrays(static_cast<const To *>(A), lda, A_batch_stride, A_multi_stride, static_cast<const Tw *>(B), ldb,
B_multi_stride, static_cast<Tr *>(C), ldc, C_batch_stride, C_multi_stride,
static_cast<const Tr *>(bias), bias_multi_stride);
}
@@ -262,17 +269,17 @@ public:
/*** "Pretransposed" interface ***/
/* Compute col sums over all columns */
- virtual void requantize_bias(void *, const To *, const int, const int){};
+ virtual void requantize_bias(void *, const Tw *, const int, const int){};
/* Perform pretranspose - the void * passed in must remain allocated for the duration of any execute calls. */
/* Arguments are: output buffer pointer, source pointer, source row stride, source multi stride */
- virtual void pretranspose_B_array(void *, const To *, const int, const int, bool){};
+ virtual void pretranspose_B_array(void *, const Tw *, const int, const int, bool){};
/* Implementation of the void * overload which casts its arguments to the appropriate type. */
void pretranspose_B_array_generic(
void *out, const void *in, const int row_stride, const int multi_stride, bool transposed) override
{
- pretranspose_B_array(out, static_cast<const To *>(in), row_stride, multi_stride, transposed);
+ pretranspose_B_array(out, static_cast<const Tw *>(in), row_stride, multi_stride, transposed);
}
/* Threaded versions of the above.
@@ -280,7 +287,7 @@ public:
* just calls the non-threaded functions to do the work. This is valid as with window size of 1 the only
* legal values for start and end are 0 and 1 respectively. */
virtual void pretranspose_B_array_part(
- void *out, const To *in, const int row_stride, const int multi_stride, bool transposed, size_t, size_t)
+ void *out, const Tw *in, const int row_stride, const int multi_stride, bool transposed, size_t, size_t)
{
pretranspose_B_array(out, in, row_stride, multi_stride, transposed);
};
@@ -293,7 +300,7 @@ public:
size_t start,
size_t end) override
{
- pretranspose_B_array_part(out, static_cast<const To *>(in), row_stride, multi_stride, transposed, start, end);
+ pretranspose_B_array_part(out, static_cast<const Tw *>(in), row_stride, multi_stride, transposed, start, end);
}
/*** Indirect interface ***/
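
The gemm_common.hpp changes thread a separate weight type Tw through GemmCommon, so the B operand (weights) may use a different element type from the A operand, and add an optional update_quantization_parameters() hook. A standalone sketch of what adding a third type parameter to such an interface looks like (types and names here are illustrative, not the arm_gemm ones):

```cpp
#include <cstdint>
#include <cstdio>
#include <memory>

// Two operand types (To for A, Tw for B/weights) and a result type Tr.
template <typename To, typename Tw, typename Tr>
class GemmSketch
{
public:
    void set_arrays(const To *a, const Tw *b, Tr *c) // B now has its own type
    {
        _a = a;
        _b = b;
        _c = c;
    }
    void execute(int k) const
    {
        Tr acc = 0;
        for (int i = 0; i < k; ++i)
        {
            acc += static_cast<Tr>(_a[i]) * static_cast<Tr>(_b[i]); // mixed-type dot product
        }
        *_c = acc;
    }

private:
    const To *_a = nullptr;
    const Tw *_b = nullptr;
    Tr       *_c = nullptr;
};

int main()
{
    // Signed activations, unsigned weights, 32-bit accumulator.
    const int8_t  a[3] = {-1, 2, 3};
    const uint8_t b[3] = {4, 5, 6};
    int32_t       c    = 0;

    auto gemm = std::make_unique<GemmSketch<int8_t, uint8_t, int32_t>>();
    gemm->set_arrays(a, b, &c);
    gemm->execute(3);
    std::printf("%d\n", c); // -4 + 10 + 18 = 24
    return 0;
}
```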
diff --git a/src/cpu/kernels/conv3d/neon/list.h b/src/cpu/kernels/conv3d/generic/neon/float_impl.h
index 082c60be29..5b5611a02f 100644
--- a/src/cpu/kernels/conv3d/neon/list.h
+++ b/src/cpu/kernels/conv3d/generic/neon/float_impl.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2021 Arm Limited.
+ * Copyright (c) 2021, 2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -21,21 +21,25 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
-#ifndef SRC_CORE_NEON_KERNELS_CONV3D_LIST_H
-#define SRC_CORE_NEON_KERNELS_CONV3D_LIST_H
+#ifndef ACL_SRC_CPU_KERNELS_CONV3D_GENERIC_NEON_FLOAT_IMPL_H
+#define ACL_SRC_CPU_KERNELS_CONV3D_GENERIC_NEON_FLOAT_IMPL_H
+#include "arm_compute/core/Helpers.h"
+#include "arm_compute/core/ITensor.h"
#include "arm_compute/core/Types.h"
-#include "arm_compute/core/utils/misc/Traits.h"
+#include "arm_compute/core/Window.h"
#include "arm_compute/runtime/FunctionDescriptors.h"
#include "src/core/helpers/WindowHelpers.h"
#include "src/core/NEON/wrapper/wrapper.h"
-#include "src/cpu/kernels/conv3d/neon/quantized.h"
namespace arm_compute
{
namespace cpu
{
+namespace kernels
+{
+
template <typename T>
void directconv3d_float_neon_ndhwc(const ITensor *src0,
const ITensor *src1,
@@ -192,6 +196,7 @@ void directconv3d_float_neon_ndhwc(const ITensor *src0,
out);
}
+} // namespace kernels
} // namespace cpu
} // namespace arm_compute
-#endif // SRC_CORE_NEON_KERNELS_CONV3D_LIST_H
+#endif // ACL_SRC_CPU_KERNELS_CONV3D_GENERIC_NEON_FLOAT_IMPL_H
diff --git a/src/cpu/kernels/conv3d/generic/neon/fp16.cpp b/src/cpu/kernels/conv3d/generic/neon/fp16.cpp
new file mode 100644
index 0000000000..1737556e51
--- /dev/null
+++ b/src/cpu/kernels/conv3d/generic/neon/fp16.cpp
@@ -0,0 +1,49 @@
+/*
+ * Copyright (c) 2024 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "src/cpu/kernels/conv3d/generic/neon/float_impl.h"
+
+#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && defined(ENABLE_FP16_KERNELS)
+
+namespace arm_compute
+{
+namespace cpu
+{
+namespace kernels
+{
+void directconv3d_fp16_neon_ndhwc(const ITensor *src0,
+ const ITensor *src1,
+ const ITensor *src2,
+ ITensor *dst,
+ const Conv3dInfo &conv_info,
+ const Window &window)
+{
+ directconv3d_float_neon_ndhwc<float16_t>(src0, src1, src2, dst, conv_info, window);
+}
+
+} // namespace kernels
+} // namespace cpu
+} // namespace arm_compute
+
+#endif // defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && defined(ENABLE_FP16_KERNELS)
diff --git a/src/cpu/kernels/conv3d/generic/neon/fp32.cpp b/src/cpu/kernels/conv3d/generic/neon/fp32.cpp
new file mode 100644
index 0000000000..1cd0793442
--- /dev/null
+++ b/src/cpu/kernels/conv3d/generic/neon/fp32.cpp
@@ -0,0 +1,46 @@
+/*
+ * Copyright (c) 2024 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "src/cpu/kernels/conv3d/generic/neon/float_impl.h"
+
+namespace arm_compute
+{
+namespace cpu
+{
+namespace kernels
+{
+
+void directconv3d_fp32_neon_ndhwc(const ITensor *src0,
+ const ITensor *src1,
+ const ITensor *src2,
+ ITensor *dst,
+ const Conv3dInfo &conv_info,
+ const Window &window)
+{
+ directconv3d_float_neon_ndhwc<float>(src0, src1, src2, dst, conv_info, window);
+}
+
+} // namespace kernels
+} // namespace cpu
+} // namespace arm_compute
diff --git a/src/cpu/kernels/conv3d/generic/neon/qasymm8.cpp b/src/cpu/kernels/conv3d/generic/neon/qasymm8.cpp
new file mode 100644
index 0000000000..d0cb6fc1c1
--- /dev/null
+++ b/src/cpu/kernels/conv3d/generic/neon/qasymm8.cpp
@@ -0,0 +1,46 @@
+/*
+ * Copyright (c) 2024 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "src/cpu/kernels/conv3d/generic/neon/quantized_impl.h"
+
+namespace arm_compute
+{
+namespace cpu
+{
+namespace kernels
+{
+
+void directconv3d_qu8_neon_ndhwc(const ITensor *src0,
+ const ITensor *src1,
+ const ITensor *src2,
+ ITensor *dst,
+ const Conv3dInfo &conv_info,
+ const Window &window)
+{
+ directconv3d_quantized_neon_ndhwc<uint8_t>(src0, src1, src2, dst, conv_info, window);
+}
+
+} // namespace kernels
+} // namespace cpu
+} // namespace arm_compute
diff --git a/src/cpu/kernels/conv3d/generic/neon/qasymm8_signed.cpp b/src/cpu/kernels/conv3d/generic/neon/qasymm8_signed.cpp
new file mode 100644
index 0000000000..adffc1a3f8
--- /dev/null
+++ b/src/cpu/kernels/conv3d/generic/neon/qasymm8_signed.cpp
@@ -0,0 +1,46 @@
+/*
+ * Copyright (c) 2024 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "src/cpu/kernels/conv3d/generic/neon/quantized_impl.h"
+
+namespace arm_compute
+{
+namespace cpu
+{
+namespace kernels
+{
+
+void directconv3d_qs8_neon_ndhwc(const ITensor *src0,
+ const ITensor *src1,
+ const ITensor *src2,
+ ITensor *dst,
+ const Conv3dInfo &conv_info,
+ const Window &window)
+{
+ directconv3d_quantized_neon_ndhwc<int8_t>(src0, src1, src2, dst, conv_info, window);
+}
+
+} // namespace kernels
+} // namespace cpu
+} // namespace arm_compute
diff --git a/src/cpu/kernels/conv3d/neon/quantized.h b/src/cpu/kernels/conv3d/generic/neon/quantized_impl.h
index f0fc9b5a71..b6b41035f8 100644
--- a/src/cpu/kernels/conv3d/neon/quantized.h
+++ b/src/cpu/kernels/conv3d/generic/neon/quantized_impl.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2021 Arm Limited.
+ * Copyright (c) 2021, 2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -21,9 +21,10 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
-#ifndef SRC_CORE_NEON_KERNELS_CONV3D_QUANTIZED_H
-#define SRC_CORE_NEON_KERNELS_CONV3D_QUANTIZED_H
+#ifndef ACL_SRC_CPU_KERNELS_CONV3D_GENERIC_NEON_QUANTIZED_IMPL_H
+#define ACL_SRC_CPU_KERNELS_CONV3D_GENERIC_NEON_QUANTIZED_IMPL_H
+#include "arm_compute/core/Helpers.h"
#include "arm_compute/core/Types.h"
#include "arm_compute/core/utils/misc/Traits.h"
#include "arm_compute/core/utils/quantization/AsymmHelpers.h"
@@ -37,6 +38,8 @@ namespace arm_compute
{
namespace cpu
{
+namespace kernels
+{
template <typename T>
void directconv3d_quantized_neon_ndhwc(const ITensor *src0,
const ITensor *src1,
@@ -270,6 +273,7 @@ void directconv3d_quantized_neon_ndhwc(const ITensor *src0,
},
out);
}
+} // namespace kernels
} // namespace cpu
} // namespace arm_compute
-#endif // SRC_CORE_NEON_KERNELS_CONV3D_QUANTIZED_H
+#endif // ACL_SRC_CPU_KERNELS_CONV3D_GENERIC_NEON_QUANTIZED_IMPL_H
diff --git a/src/cpu/kernels/conv3d/list.h b/src/cpu/kernels/conv3d/list.h
new file mode 100644
index 0000000000..256d28825d
--- /dev/null
+++ b/src/cpu/kernels/conv3d/list.h
@@ -0,0 +1,47 @@
+/*
+ * Copyright (c) 2024 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#ifndef ACL_SRC_CPU_KERNELS_CONV3D_LIST_H
+#define ACL_SRC_CPU_KERNELS_CONV3D_LIST_H
+
+namespace arm_compute
+{
+namespace cpu
+{
+namespace kernels
+{
+
+#define DECLARE_CONV3D_KERNEL(func_name) \
+ void func_name(const ITensor *src0, const ITensor *src1, const ITensor *src2, ITensor *dst, \
+ const Conv3dInfo &conv_info, const Window &window)
+
+DECLARE_CONV3D_KERNEL(directconv3d_fp16_neon_ndhwc);
+DECLARE_CONV3D_KERNEL(directconv3d_fp32_neon_ndhwc);
+DECLARE_CONV3D_KERNEL(directconv3d_qu8_neon_ndhwc);
+DECLARE_CONV3D_KERNEL(directconv3d_qs8_neon_ndhwc);
+#undef DECLARE_CONV3D_KERNEL
+
+} // namespace kernels
+} // namespace cpu
+} // namespace arm_compute
+#endif // ACL_SRC_CPU_KERNELS_CONV3D_LIST_H
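Each DECLARE_CONV3D_KERNEL(...) entry above is shorthand for a declaration with the shared direct-conv3d signature; for reference, the fp32 entry expands to:

    void directconv3d_fp32_neon_ndhwc(const ITensor *src0, const ITensor *src1, const ITensor *src2, ITensor *dst,
                                      const Conv3dInfo &conv_info, const Window &window);

The per-type translation units then provide these symbols by forwarding to the templated implementations, as the qasymm8_signed file above does via directconv3d_quantized_neon_ndhwc<int8_t>.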
diff --git a/src/cpu/kernels/directconv2d_output_stage/generic/neon/float_impl.h b/src/cpu/kernels/directconv2d_output_stage/generic/neon/float_impl.h
new file mode 100644
index 0000000000..266fa68ab2
--- /dev/null
+++ b/src/cpu/kernels/directconv2d_output_stage/generic/neon/float_impl.h
@@ -0,0 +1,186 @@
+/*
+ * Copyright (c) 2017-2021, 2024 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#ifndef ACL_SRC_CPU_KERNELS_DIRECTCONV2D_OUTPUT_STAGE_GENERIC_NEON_FLOAT_IMPL_H
+#define ACL_SRC_CPU_KERNELS_DIRECTCONV2D_OUTPUT_STAGE_GENERIC_NEON_FLOAT_IMPL_H
+
+#include "arm_compute/core/Helpers.h" // Iterator
+#include "arm_compute/core/ITensor.h"
+#include "arm_compute/core/Window.h"
+
+#include "src/core/NEON/wrapper/wrapper.h"
+
+#include <cstdint>
+
+namespace arm_compute
+{
+namespace cpu
+{
+namespace kernels
+{
+template <typename T>
+void output_stage_nchw_fp(ITensor *src,
+ const ITensor *bias,
+ const Window &window,
+ ITensor *dst,
+ int result_fixedpoint_multiplier,
+ int result_shift,
+ int result_offset_after_shift)
+{
+ const bool has_bias = bias != nullptr;
+ /** SIMD vector tag type. */
+ using ExactTagType = typename wrapper::traits::neon_bitvector_tag_t<T, wrapper::traits::BitWidth::W128>;
+
+ ARM_COMPUTE_ERROR_ON(src->info()->data_layout() == DataLayout::UNKNOWN);
+ ARM_COMPUTE_UNUSED(result_fixedpoint_multiplier);
+ ARM_COMPUTE_UNUSED(result_shift);
+ ARM_COMPUTE_UNUSED(result_offset_after_shift);
+
+ const int window_start_x = window.x().start();
+ const int window_end_x = window.x().end();
+ const int window_step_x = 16 / src->info()->element_size();
+ Window win = window;
+ win.set(Window::DimX, Window::Dimension(0, 1, 1));
+
+ Iterator in(src, win);
+ Iterator out(dst, win);
+ execute_window_loop(
+ win,
+ [&](const Coordinates &id)
+ {
+ int x = window_start_x;
+ for (; x <= (window_end_x - window_step_x); x += window_step_x)
+ {
+ // Get bias and pointer to input
+ const auto in_ptr = reinterpret_cast<const T *>(in.ptr()) + x;
+ auto v_in = wrapper::vloadq(in_ptr);
+
+ // Accumulate bias
+ if (has_bias)
+ {
+ const auto vb = wrapper::vdup_n(
+ *reinterpret_cast<const T *>(bias->ptr_to_element(Coordinates(id.z()))), ExactTagType{});
+ v_in = wrapper::vadd(v_in, vb);
+ }
+
+ const auto out_ptr = reinterpret_cast<T *>(out.ptr()) + x;
+ wrapper::vstore(out_ptr, v_in);
+ }
+
+ // Left-overs loop
+ for (; x < window_end_x; ++x)
+ {
+ // Get bias and pointer to input
+ auto s_in = *(reinterpret_cast<const T *>(in.ptr()) + x);
+
+ // Accumulate bias
+ if (has_bias)
+ {
+ const auto b = *reinterpret_cast<const T *>(bias->ptr_to_element(Coordinates(id.z())));
+ s_in += b;
+ }
+
+ *(reinterpret_cast<T *>(out.ptr()) + x) = s_in;
+ }
+ },
+ in, out);
+}
+
+template <typename T>
+void output_stage_nhwc_fp(ITensor *src,
+ const ITensor *bias,
+ const Window &window,
+ ITensor *dst,
+ int result_fixedpoint_multiplier,
+ int result_shift,
+ int result_offset_after_shift)
+{
+ const bool has_bias = bias != nullptr;
+ ARM_COMPUTE_UNUSED(result_fixedpoint_multiplier);
+ ARM_COMPUTE_UNUSED(result_shift);
+ ARM_COMPUTE_UNUSED(result_offset_after_shift);
+
+ Window window_bias = window;
+ window_bias.set(Window::DimX, Window::Dimension(0, 1, 1));
+ window_bias.set(Window::DimY, Window::Dimension(0, 0, 0));
+ window_bias.set(Window::DimZ, Window::Dimension(0, 0, 0));
+ window_bias.set(3, Window::Dimension(0, 0, 0));
+
+ const int window_start_x = window.x().start();
+ const int window_end_x = window.x().end();
+ const int window_step_x = 16 / src->info()->element_size();
+ Window win = window;
+ win.set(Window::DimX, Window::Dimension(0, 1, 1));
+
+ Iterator in(src, win);
+ Iterator bi(bias, window_bias);
+ Iterator out(dst, win);
+
+ execute_window_loop(
+ win,
+ [&](const Coordinates &)
+ {
+ int x = window_start_x;
+ for (; x <= (window_end_x - window_step_x); x += window_step_x)
+ {
+ // Get bias and pointer to input
+ const auto in_ptr = reinterpret_cast<const T *>(in.ptr());
+ auto v_in = wrapper::vloadq(in_ptr + x);
+
+ // Accumulate bias
+ if (has_bias)
+ {
+ const auto bias_ptr = reinterpret_cast<T *>(bi.ptr()) + x;
+ v_in = wrapper::vadd(v_in, wrapper::vloadq(bias_ptr));
+ }
+
+ const auto out_ptr = reinterpret_cast<T *>(out.ptr());
+ wrapper::vstore(out_ptr + x, v_in);
+ }
+
+ // Left-overs loop
+ for (; x < window_end_x; ++x)
+ {
+ // Get bias and pointer to input
+ auto s_in = *(reinterpret_cast<const T *>(in.ptr()) + x);
+
+ // Accumulate bias
+ if (has_bias)
+ {
+ const auto bias_ptr = reinterpret_cast<T *>(bi.ptr()) + x;
+ s_in += *bias_ptr;
+ }
+
+ const auto out_ptr = reinterpret_cast<T *>(out.ptr());
+ *(out_ptr + x) = s_in;
+ }
+ },
+ in, bi, out);
+}
+
+} // namespace kernels
+} // namespace cpu
+} // namespace arm_compute
+
+#endif // ACL_SRC_CPU_KERNELS_DIRECTCONV2D_OUTPUT_STAGE_GENERIC_NEON_FLOAT_IMPL_H
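The two templates above perform the same element-wise work and differ only in how the bias is addressed (one scalar per output channel for NCHW, a contiguous vector along x for NHWC); the fixed-point parameters are intentionally unused for floating-point types. A minimal scalar sketch of what one row of the NHWC variant computes, for illustration only and not part of the patch:

    // Reference form of one NHWC row: dst[x] = src[x] + bias[x] when a bias tensor
    // is present, otherwise a plain copy.
    template <typename T>
    void output_stage_nhwc_fp_row_reference(const T *src, const T *bias, T *dst, int length)
    {
        for (int x = 0; x < length; ++x)
        {
            dst[x] = (bias != nullptr) ? static_cast<T>(src[x] + bias[x]) : src[x];
        }
    }

The NEON kernel above does the same thing 16 / sizeof(T) elements at a time and falls back to this scalar form in its left-overs loop.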
diff --git a/src/cpu/kernels/directconv2d_output_stage/generic/neon/fp16.cpp b/src/cpu/kernels/directconv2d_output_stage/generic/neon/fp16.cpp
new file mode 100644
index 0000000000..d7550da721
--- /dev/null
+++ b/src/cpu/kernels/directconv2d_output_stage/generic/neon/fp16.cpp
@@ -0,0 +1,63 @@
+/*
+ * Copyright (c) 2024 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "src/cpu/kernels/directconv2d_output_stage/generic/neon/float_impl.h"
+
+#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && defined(ENABLE_FP16_KERNELS)
+
+namespace arm_compute
+{
+namespace cpu
+{
+namespace kernels
+{
+void output_stage_nhwc_fp16(ITensor *src,
+ const ITensor *bias,
+ const Window &window,
+ ITensor *dst,
+ int result_fixedpoint_multiplier,
+ int result_shift,
+ int result_offset_after_shift)
+{
+ output_stage_nhwc_fp<float16_t>(src, bias, window, dst, result_fixedpoint_multiplier, result_shift,
+ result_offset_after_shift);
+}
+
+void output_stage_nchw_fp16(ITensor *src,
+ const ITensor *bias,
+ const Window &window,
+ ITensor *dst,
+ int result_fixedpoint_multiplier,
+ int result_shift,
+ int result_offset_after_shift)
+{
+ output_stage_nchw_fp<float16_t>(src, bias, window, dst, result_fixedpoint_multiplier, result_shift,
+ result_offset_after_shift);
+}
+
+} // namespace kernels
+} // namespace cpu
+} // namespace arm_compute
+
+#endif // defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && defined(ENABLE_FP16_KERNELS)
diff --git a/src/cpu/kernels/directconv2d_output_stage/generic/neon/fp32.cpp b/src/cpu/kernels/directconv2d_output_stage/generic/neon/fp32.cpp
new file mode 100644
index 0000000000..05dec370b9
--- /dev/null
+++ b/src/cpu/kernels/directconv2d_output_stage/generic/neon/fp32.cpp
@@ -0,0 +1,59 @@
+/*
+ * Copyright (c) 2024 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "src/cpu/kernels/directconv2d_output_stage/generic/neon/float_impl.h"
+
+namespace arm_compute
+{
+namespace cpu
+{
+namespace kernels
+{
+void output_stage_nhwc_fp32(ITensor *src,
+ const ITensor *bias,
+ const Window &window,
+ ITensor *dst,
+ int result_fixedpoint_multiplier,
+ int result_shift,
+ int result_offset_after_shift)
+{
+ output_stage_nhwc_fp<float>(src, bias, window, dst, result_fixedpoint_multiplier, result_shift,
+ result_offset_after_shift);
+}
+
+void output_stage_nchw_fp32(ITensor *src,
+ const ITensor *bias,
+ const Window &window,
+ ITensor *dst,
+ int result_fixedpoint_multiplier,
+ int result_shift,
+ int result_offset_after_shift)
+{
+ output_stage_nchw_fp<float>(src, bias, window, dst, result_fixedpoint_multiplier, result_shift,
+ result_offset_after_shift);
+}
+
+} // namespace kernels
+} // namespace cpu
+} // namespace arm_compute
diff --git a/src/cpu/kernels/directconv2d_output_stage/generic/neon/qasymm8.cpp b/src/cpu/kernels/directconv2d_output_stage/generic/neon/qasymm8.cpp
new file mode 100644
index 0000000000..dbdf951b6f
--- /dev/null
+++ b/src/cpu/kernels/directconv2d_output_stage/generic/neon/qasymm8.cpp
@@ -0,0 +1,59 @@
+/*
+ * Copyright (c) 2024 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "src/cpu/kernels/directconv2d_output_stage/generic/neon/quantized_impl.h"
+
+namespace arm_compute
+{
+namespace cpu
+{
+namespace kernels
+{
+void output_stage_nhwc_qu8(ITensor *src,
+ const ITensor *bias,
+ const Window &window,
+ ITensor *dst,
+ int result_fixedpoint_multiplier,
+ int result_shift,
+ int result_offset_after_shift)
+{
+ output_stage_nhwc_quant<uint8_t>(src, bias, window, dst, result_fixedpoint_multiplier, result_shift,
+ result_offset_after_shift);
+}
+
+void output_stage_nchw_qu8(ITensor *src,
+ const ITensor *bias,
+ const Window &window,
+ ITensor *dst,
+ int result_fixedpoint_multiplier,
+ int result_shift,
+ int result_offset_after_shift)
+{
+ output_stage_nchw_quant<uint8_t>(src, bias, window, dst, result_fixedpoint_multiplier, result_shift,
+ result_offset_after_shift);
+}
+
+} // namespace kernels
+} // namespace cpu
+} // namespace arm_compute
diff --git a/src/cpu/kernels/directconv2d_output_stage/generic/neon/qasymm8_signed.cpp b/src/cpu/kernels/directconv2d_output_stage/generic/neon/qasymm8_signed.cpp
new file mode 100644
index 0000000000..c00fe87161
--- /dev/null
+++ b/src/cpu/kernels/directconv2d_output_stage/generic/neon/qasymm8_signed.cpp
@@ -0,0 +1,59 @@
+/*
+ * Copyright (c) 2024 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "src/cpu/kernels/directconv2d_output_stage/generic/neon/quantized_impl.h"
+
+namespace arm_compute
+{
+namespace cpu
+{
+namespace kernels
+{
+void output_stage_nhwc_qs8(ITensor *src,
+ const ITensor *bias,
+ const Window &window,
+ ITensor *dst,
+ int result_fixedpoint_multiplier,
+ int result_shift,
+ int result_offset_after_shift)
+{
+ output_stage_nhwc_quant<int8_t>(src, bias, window, dst, result_fixedpoint_multiplier, result_shift,
+ result_offset_after_shift);
+}
+
+void output_stage_nchw_qs8(ITensor *src,
+ const ITensor *bias,
+ const Window &window,
+ ITensor *dst,
+ int result_fixedpoint_multiplier,
+ int result_shift,
+ int result_offset_after_shift)
+{
+ output_stage_nchw_quant<int8_t>(src, bias, window, dst, result_fixedpoint_multiplier, result_shift,
+ result_offset_after_shift);
+}
+
+} // namespace kernels
+} // namespace cpu
+} // namespace arm_compute
diff --git a/src/cpu/kernels/directconv2d_output_stage/generic/neon/quantized_impl.h b/src/cpu/kernels/directconv2d_output_stage/generic/neon/quantized_impl.h
new file mode 100644
index 0000000000..f74ed55ad0
--- /dev/null
+++ b/src/cpu/kernels/directconv2d_output_stage/generic/neon/quantized_impl.h
@@ -0,0 +1,213 @@
+/*
+ * Copyright (c) 2017-2021, 2024 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#ifndef ACL_SRC_CPU_KERNELS_DIRECTCONV2D_OUTPUT_STAGE_GENERIC_NEON_QUANTIZED_IMPL_H
+#define ACL_SRC_CPU_KERNELS_DIRECTCONV2D_OUTPUT_STAGE_GENERIC_NEON_QUANTIZED_IMPL_H
+
+#include "arm_compute/core/Helpers.h" // Iterator
+#include "arm_compute/core/ITensor.h"
+#include "arm_compute/core/Window.h"
+
+#include "src/core/NEON/NEAsymm.h"
+#include "src/core/NEON/wrapper/wrapper.h"
+
+#include <cstdint>
+
+namespace arm_compute
+{
+namespace cpu
+{
+namespace kernels
+{
+
+template <typename TOut>
+void output_stage_nchw_quant(ITensor *src,
+ const ITensor *bias,
+ const Window &window,
+ ITensor *dst,
+ int result_fixedpoint_multiplier,
+ int result_shift,
+ int result_offset_after_shift)
+{
+ const bool has_bias = bias != nullptr;
+ using VectorType = typename wrapper::traits::neon_bitvector_t<TOut, wrapper::traits::BitWidth::W128>;
+ using TagType = typename wrapper::traits::neon_bitvector_tag_t<TOut, wrapper::traits::BitWidth::W128>;
+
+ const int32x4_t result_offset_after_shift_s32 = vdupq_n_s32(result_offset_after_shift);
+
+ const VectorType min = wrapper::vdup_n(std::numeric_limits<TOut>::lowest(), TagType{});
+ const VectorType max = wrapper::vdup_n(std::numeric_limits<TOut>::max(), TagType{});
+
+ const int window_start_x = window.x().start();
+ const int window_end_x = window.x().end();
+ const int window_step_x = 16 / src->info()->element_size();
+ Window win = window;
+ win.set(Window::DimX, Window::Dimension(0, 1, 1));
+
+ Iterator in(src, win);
+ Iterator out(dst, win);
+
+ execute_window_loop(
+ win,
+ [&](const Coordinates &id)
+ {
+ int x = window_start_x;
+ for (; x <= (window_end_x - window_step_x); x += window_step_x)
+ {
+ // Get bias and pointer to input
+ const auto in_ptr = reinterpret_cast<int32_t *>(in.ptr()) + x;
+ int32x4x4_t v_in = {{wrapper::vloadq(in_ptr), wrapper::vloadq(in_ptr + 4), wrapper::vloadq(in_ptr + 8),
+ wrapper::vloadq(in_ptr + 12)}};
+
+ // Accumulate bias
+ if (has_bias)
+ {
+ const auto vb = wrapper::vdup_n(
+ *reinterpret_cast<const int32_t *>(bias->ptr_to_element(Coordinates(id.z()))), TagType{});
+ v_in = {{wrapper::vadd(v_in.val[0], vb), wrapper::vadd(v_in.val[1], vb),
+ wrapper::vadd(v_in.val[2], vb), wrapper::vadd(v_in.val[3], vb)}};
+ }
+
+ const auto out_ptr = reinterpret_cast<TOut *>(out.ptr()) + x;
+ wrapper::vstore(out_ptr, finalize_quantization(v_in, result_fixedpoint_multiplier, result_shift,
+ result_offset_after_shift_s32, min, max, false));
+ }
+
+ // Left-overs loop
+ for (; x < window_end_x; ++x)
+ {
+ // Get bias and pointer to input
+ int32_t s_in = *(reinterpret_cast<const int32_t *>(in.ptr()) + x);
+
+ // Accumulate bias
+ if (has_bias)
+ {
+ const auto b = *reinterpret_cast<const int32_t *>(bias->ptr_to_element(Coordinates(id.z())));
+ s_in += b;
+ }
+
+ const auto out_ptr = reinterpret_cast<TOut *>(out.ptr()) + x;
+ *out_ptr =
+ finalize_quantization(s_in, result_fixedpoint_multiplier, result_shift, result_offset_after_shift,
+ std::numeric_limits<TOut>::lowest(), std::numeric_limits<TOut>::max(), false);
+ }
+ },
+ in, out);
+}
+template <
+ typename TOut,
+ typename std::enable_if<std::is_same<TOut, uint8_t>::value || std::is_same<TOut, int8_t>::value, int>::type = 0>
+void output_stage_nhwc_quant(ITensor *src,
+ const ITensor *bias,
+ const Window &window,
+ ITensor *dst,
+ int result_fixedpoint_multiplier,
+ int result_shift,
+ int result_offset_after_shift)
+{
+ const bool has_bias = bias != nullptr;
+ using VectorType = typename wrapper::traits::neon_bitvector_t<TOut, wrapper::traits::BitWidth::W128>;
+ using TagType = typename wrapper::traits::neon_bitvector_tag_t<TOut, wrapper::traits::BitWidth::W128>;
+
+ const int32x4_t result_offset_after_shift_s32 = vdupq_n_s32(result_offset_after_shift);
+
+ const VectorType min = wrapper::vdup_n(std::numeric_limits<TOut>::lowest(), TagType{});
+ const VectorType max = wrapper::vdup_n(std::numeric_limits<TOut>::max(), TagType{});
+
+ Window window_bias = window;
+ window_bias.set(Window::DimX, Window::Dimension(0, 1, 1));
+ window_bias.set(Window::DimY, Window::Dimension(0, 0, 0));
+ window_bias.set(Window::DimZ, Window::Dimension(0, 0, 0));
+ window_bias.set(3, Window::Dimension(0, 0, 0));
+
+ const int window_start_x = window.x().start();
+ const int window_end_x = window.x().end();
+ const int window_step_x = 16 / src->info()->element_size();
+ Window win = window;
+ win.set(Window::DimX, Window::Dimension(0, 1, 1));
+
+ Iterator in(src, win);
+ Iterator bi(bias, window_bias);
+ Iterator out(dst, win);
+
+ execute_window_loop(
+ win,
+ [&](const Coordinates &)
+ {
+ int x = window_start_x;
+ for (; x <= (window_end_x - window_step_x); x += window_step_x)
+ {
+ // Get bias and pointer to input
+ const auto in_ptr = reinterpret_cast<int32_t *>(in.ptr()) + x;
+ int32x4x4_t v_in = {{
+ wrapper::vloadq(in_ptr),
+ wrapper::vloadq(in_ptr + 4),
+ wrapper::vloadq(in_ptr + 8),
+ wrapper::vloadq(in_ptr + 12),
+ }};
+
+ // Accumulate bias
+ if (has_bias)
+ {
+ const auto bias_ptr = reinterpret_cast<int32_t *>(bi.ptr()) + x;
+
+                        v_in.val[0] = wrapper::vadd(v_in.val[0], wrapper::vloadq(bias_ptr));
+                        v_in.val[1] = wrapper::vadd(v_in.val[1], wrapper::vloadq(bias_ptr + 4));
+                        v_in.val[2] = wrapper::vadd(v_in.val[2], wrapper::vloadq(bias_ptr + 8));
+                        v_in.val[3] = wrapper::vadd(v_in.val[3], wrapper::vloadq(bias_ptr + 12));

+ }
+
+ const auto out_ptr = reinterpret_cast<TOut *>(out.ptr()) + x;
+ wrapper::vstore(out_ptr, finalize_quantization(v_in, result_fixedpoint_multiplier, result_shift,
+ result_offset_after_shift_s32, min, max, false));
+ }
+
+ // Left-overs loop
+ for (; x < window_end_x; ++x)
+ {
+ // Get bias and pointer to input
+ const auto in_ptr = reinterpret_cast<int32_t *>(in.ptr()) + x;
+ int32_t s_in = *in_ptr;
+
+ // Accumulate bias
+ if (has_bias)
+ {
+ const auto bias_ptr = reinterpret_cast<int32_t *>(bi.ptr()) + x;
+ s_in += *bias_ptr;
+ }
+
+ const auto out_ptr = reinterpret_cast<TOut *>(out.ptr()) + x;
+ *out_ptr =
+ finalize_quantization(s_in, result_fixedpoint_multiplier, result_shift, result_offset_after_shift,
+ std::numeric_limits<TOut>::lowest(), std::numeric_limits<TOut>::max(), false);
+ }
+ },
+ in, bi, out);
+}
+
+} // namespace kernels
+} // namespace cpu
+} // namespace arm_compute
+
+#endif // ACL_SRC_CPU_KERNELS_DIRECTCONV2D_OUTPUT_STAGE_GENERIC_NEON_QUANTIZED_IMPL_H
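Both quantized paths defer the requantization itself to finalize_quantization() from src/core/NEON/NEAsymm.h, which maps the bias-accumulated int32 value back to the 8-bit output using the fixed-point multiplier, shift and offset. Roughly, and glossing over the exact rounding behaviour of that helper, the per-element math looks like the following simplified sketch (not the library routine):

    #include <algorithm>
    #include <cstdint>
    #include <limits>

    // Simplified requantization sketch: fixed-point multiply, right shift, offset, saturate.
    // The real kernels call finalize_quantization(), whose rounding differs in detail.
    template <typename TOut>
    TOut requantize_sketch(int32_t acc, int result_fixedpoint_multiplier, int result_shift, int result_offset_after_shift)
    {
        int64_t v = (static_cast<int64_t>(acc) * result_fixedpoint_multiplier) >> 31; // Q31 multiply
        v >>= result_shift;                                                           // scale down
        v += result_offset_after_shift;                                               // re-centre on the zero point
        const int64_t lo = std::numeric_limits<TOut>::lowest();
        const int64_t hi = std::numeric_limits<TOut>::max();
        return static_cast<TOut>(std::min(hi, std::max(lo, v)));                      // saturate to uint8_t / int8_t
    }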
diff --git a/src/cpu/kernels/directconv2d_output_stage/list.h b/src/cpu/kernels/directconv2d_output_stage/list.h
new file mode 100644
index 0000000000..9372269bca
--- /dev/null
+++ b/src/cpu/kernels/directconv2d_output_stage/list.h
@@ -0,0 +1,57 @@
+/*
+ * Copyright (c) 2024 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#ifndef ACL_SRC_CPU_KERNELS_DIRECTCONV2D_OUTPUT_STAGE_LIST_H
+#define ACL_SRC_CPU_KERNELS_DIRECTCONV2D_OUTPUT_STAGE_LIST_H
+
+#include "arm_compute/core/ITensor.h"
+#include "arm_compute/core/Window.h"
+
+namespace arm_compute
+{
+namespace cpu
+{
+namespace kernels
+{
+
+#define DECLARE_DIRECTCONV2D_OUTPUT_STAGE_KERNEL(func_name) \
+ void func_name(ITensor *src, const ITensor *bias, const Window &window, ITensor *dst, \
+ int result_fixedpoint_multiplier, int result_shift, int result_offset_after_shift)
+
+DECLARE_DIRECTCONV2D_OUTPUT_STAGE_KERNEL(output_stage_nhwc_fp32);
+DECLARE_DIRECTCONV2D_OUTPUT_STAGE_KERNEL(output_stage_nhwc_fp16);
+DECLARE_DIRECTCONV2D_OUTPUT_STAGE_KERNEL(output_stage_nchw_fp32);
+DECLARE_DIRECTCONV2D_OUTPUT_STAGE_KERNEL(output_stage_nchw_fp16);
+DECLARE_DIRECTCONV2D_OUTPUT_STAGE_KERNEL(output_stage_nhwc_qs8);
+DECLARE_DIRECTCONV2D_OUTPUT_STAGE_KERNEL(output_stage_nhwc_qu8);
+DECLARE_DIRECTCONV2D_OUTPUT_STAGE_KERNEL(output_stage_nchw_qs8);
+DECLARE_DIRECTCONV2D_OUTPUT_STAGE_KERNEL(output_stage_nchw_qu8);
+
+#undef DECLARE_DIRECTCONV2D_OUTPUT_STAGE_KERNEL
+
+} // namespace kernels
+} // namespace cpu
+} // namespace arm_compute
+
+#endif // ACL_SRC_CPU_KERNELS_DIRECTCONV2D_OUTPUT_STAGE_LIST_H
diff --git a/src/cpu/kernels/gemm_matrix_mul/generic/neon/impl.cpp b/src/cpu/kernels/gemm_matrix_mul/generic/neon/impl.cpp
index 404d070a37..580fdc3e8f 100644
--- a/src/cpu/kernels/gemm_matrix_mul/generic/neon/impl.cpp
+++ b/src/cpu/kernels/gemm_matrix_mul/generic/neon/impl.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2023 Arm Limited.
+ * Copyright (c) 2017-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -81,7 +81,7 @@ void vector_matrix_multiply_f32(
// window_end_x is computed above which may cause out-of-bound writes to the dst.
for (; x < (window_end_x - window_step_x); x += window_step_x)
{
- if (x > width_matrix_b)
+ if (x >= width_matrix_b)
{
return;
}
@@ -203,7 +203,7 @@ void vector_matrix_multiply_f32(
// Left-over loop
for (; x < window_end_x; ++x)
{
- if (x > width_matrix_b)
+ if (x >= width_matrix_b)
{
return;
}
@@ -309,9 +309,21 @@ void matrix_matrix_multiply_f32(
Iterator inb(rhs, win_b);
Iterator out(dst, window);
- const bool multiply_alpha = !(helpers::float_ops::is_one(alpha));
+ // End address of matrix B at batch number n
+ const float *end_addr_mtx_b_at_batch_n =
+ reinterpret_cast<const float *>(inb.ptr()) + rhs->info()->dimension(0) * rhs->info()->dimension(1);
+ std::vector<const float *> end_addr_mtx_b_per_batch = {};
+ const bool multiply_alpha = !(helpers::float_ops::is_one(alpha));
+ const float32x4_t alpha_f32 = vdupq_n_f32(alpha);
+ const size_t out_dim2 = static_cast<int>(dst->info()->dimension(2));
- const float32x4_t alpha_f32 = vdupq_n_f32(alpha);
+ for (size_t b = 0; b < out_dim2; ++b)
+ {
+ // Store the ptrs to the last elem in the tensor for each batch
+ end_addr_mtx_b_per_batch.push_back(end_addr_mtx_b_at_batch_n);
+ end_addr_mtx_b_at_batch_n +=
+ rhs->info()->dimension(2) != 1 ? rhs->info()->dimension(0) * rhs->info()->dimension(1) : 0;
+ }
// The implementation assumes that the matrix A and Matrix B have been reshaped respectively with CpuGemmInterleave4x4 and CpuGemmTranspose1xW
// The reshaping of the matrices helps to have a cache friendly implementation and helps to avoid the data re-arrangements needed for computing 16x4 elements per iteration
@@ -341,220 +353,374 @@ void matrix_matrix_multiply_f32(
#endif /* __arm__ */
auto mtx_b0_end_addr = mtx_b0 + num_elems_matrix_b_x;
- for (; mtx_b0 <= (mtx_b0_end_addr - 32);)
+
+ ARM_COMPUTE_ERROR_ON(end_addr_mtx_b_per_batch.size() == 0);
+ if (mtx_b1 < end_addr_mtx_b_per_batch[id.z()])
{
- float32x4_t a0 = vld1q_dup_f32(mtx_a0 + 0);
- float32x4_t a1 = vld1q_dup_f32(mtx_a0 + 1);
- float32x4_t a2 = vld1q_dup_f32(mtx_a0 + 2);
- float32x4_t a3 = vld1q_dup_f32(mtx_a0 + 3);
+ for (; mtx_b0 < (mtx_b0_end_addr - 32);)
+ {
+ float32x4_t a0 = vld1q_dup_f32(mtx_a0 + 0);
+ float32x4_t a1 = vld1q_dup_f32(mtx_a0 + 1);
+ float32x4_t a2 = vld1q_dup_f32(mtx_a0 + 2);
+ float32x4_t a3 = vld1q_dup_f32(mtx_a0 + 3);
- float32x4_t b00 = vld1q_f32(mtx_b0);
- float32x4_t b10 = vld1q_f32(mtx_b1);
- float32x4_t b01 = vld1q_f32(mtx_b0 + 4);
- float32x4_t b11 = vld1q_f32(mtx_b1 + 4);
+ float32x4_t b00 = vld1q_f32(mtx_b0);
+ float32x4_t b10 = vld1q_f32(mtx_b1);
+ float32x4_t b01 = vld1q_f32(mtx_b0 + 4);
+ float32x4_t b11 = vld1q_f32(mtx_b1 + 4);
#if __arm__
- asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_a0)));
- asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_b0)));
- asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_b1)));
+ asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_a0)));
+ asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_b0)));
+ asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_b1)));
#endif /* __arm__ */
- // 4x4 block 0
- acc00 = vmlaq_f32(acc00, b00, a0);
- acc10 = vmlaq_f32(acc10, b00, a1);
- acc20 = vmlaq_f32(acc20, b00, a2);
- acc30 = vmlaq_f32(acc30, b00, a3);
-
- float32x4_t a4 = vld1q_dup_f32(mtx_a0 + 4);
- float32x4_t a5 = vld1q_dup_f32(mtx_a0 + 5);
- float32x4_t a6 = vld1q_dup_f32(mtx_a0 + 6);
- float32x4_t a7 = vld1q_dup_f32(mtx_a0 + 7);
-
- // 4x4 block 1
- acc01 = vmlaq_f32(acc01, b10, a0);
- acc11 = vmlaq_f32(acc11, b10, a1);
- acc21 = vmlaq_f32(acc21, b10, a2);
- acc31 = vmlaq_f32(acc31, b10, a3);
-
- // 4x4 block 0
- acc00 = vmlaq_f32(acc00, b01, a4);
- acc10 = vmlaq_f32(acc10, b01, a5);
- acc20 = vmlaq_f32(acc20, b01, a6);
- acc30 = vmlaq_f32(acc30, b01, a7);
-
- // 4x4 block 1
- acc01 = vmlaq_f32(acc01, b11, a4);
- acc11 = vmlaq_f32(acc11, b11, a5);
- acc21 = vmlaq_f32(acc21, b11, a6);
- acc31 = vmlaq_f32(acc31, b11, a7);
-
- mtx_a0 += 8;
- mtx_b0 += 8;
- mtx_b1 += 8;
-
- a0 = vld1q_dup_f32(mtx_a0 + 0);
- a1 = vld1q_dup_f32(mtx_a0 + 1);
- a2 = vld1q_dup_f32(mtx_a0 + 2);
- a3 = vld1q_dup_f32(mtx_a0 + 3);
-
- b00 = vld1q_f32(mtx_b0);
- b10 = vld1q_f32(mtx_b1);
- b01 = vld1q_f32(mtx_b0 + 4);
- b11 = vld1q_f32(mtx_b1 + 4);
-
- // 4x4 block 0
- acc00 = vmlaq_f32(acc00, b00, a0);
- acc10 = vmlaq_f32(acc10, b00, a1);
- acc20 = vmlaq_f32(acc20, b00, a2);
- acc30 = vmlaq_f32(acc30, b00, a3);
-
- a4 = vld1q_dup_f32(mtx_a0 + 4);
- a5 = vld1q_dup_f32(mtx_a0 + 5);
- a6 = vld1q_dup_f32(mtx_a0 + 6);
- a7 = vld1q_dup_f32(mtx_a0 + 7);
-
- // 4x4 block 1
- acc01 = vmlaq_f32(acc01, b10, a0);
- acc11 = vmlaq_f32(acc11, b10, a1);
- acc21 = vmlaq_f32(acc21, b10, a2);
- acc31 = vmlaq_f32(acc31, b10, a3);
-
- // 4x4 block 0
- acc00 = vmlaq_f32(acc00, b01, a4);
- acc10 = vmlaq_f32(acc10, b01, a5);
- acc20 = vmlaq_f32(acc20, b01, a6);
- acc30 = vmlaq_f32(acc30, b01, a7);
-
- // 4x4 block 1
- acc01 = vmlaq_f32(acc01, b11, a4);
- acc11 = vmlaq_f32(acc11, b11, a5);
- acc21 = vmlaq_f32(acc21, b11, a6);
- acc31 = vmlaq_f32(acc31, b11, a7);
-
- mtx_a0 += 8;
- mtx_b0 += 8;
- mtx_b1 += 8;
-
- a0 = vld1q_dup_f32(mtx_a0 + 0);
- a1 = vld1q_dup_f32(mtx_a0 + 1);
- a2 = vld1q_dup_f32(mtx_a0 + 2);
- a3 = vld1q_dup_f32(mtx_a0 + 3);
- b00 = vld1q_f32(mtx_b0);
- b10 = vld1q_f32(mtx_b1);
- b01 = vld1q_f32(mtx_b0 + 4);
- b11 = vld1q_f32(mtx_b1 + 4);
+ // 4x4 block 0
+ acc00 = vmlaq_f32(acc00, b00, a0);
+ acc10 = vmlaq_f32(acc10, b00, a1);
+ acc20 = vmlaq_f32(acc20, b00, a2);
+ acc30 = vmlaq_f32(acc30, b00, a3);
+
+ float32x4_t a4 = vld1q_dup_f32(mtx_a0 + 4);
+ float32x4_t a5 = vld1q_dup_f32(mtx_a0 + 5);
+ float32x4_t a6 = vld1q_dup_f32(mtx_a0 + 6);
+ float32x4_t a7 = vld1q_dup_f32(mtx_a0 + 7);
+
+ // 4x4 block 1
+ acc01 = vmlaq_f32(acc01, b10, a0);
+ acc11 = vmlaq_f32(acc11, b10, a1);
+ acc21 = vmlaq_f32(acc21, b10, a2);
+ acc31 = vmlaq_f32(acc31, b10, a3);
+
+ // 4x4 block 0
+ acc00 = vmlaq_f32(acc00, b01, a4);
+ acc10 = vmlaq_f32(acc10, b01, a5);
+ acc20 = vmlaq_f32(acc20, b01, a6);
+ acc30 = vmlaq_f32(acc30, b01, a7);
+
+ // 4x4 block 1
+ acc01 = vmlaq_f32(acc01, b11, a4);
+ acc11 = vmlaq_f32(acc11, b11, a5);
+ acc21 = vmlaq_f32(acc21, b11, a6);
+ acc31 = vmlaq_f32(acc31, b11, a7);
+
+ mtx_a0 += 8;
+ mtx_b0 += 8;
+ mtx_b1 += 8;
+
+ a0 = vld1q_dup_f32(mtx_a0 + 0);
+ a1 = vld1q_dup_f32(mtx_a0 + 1);
+ a2 = vld1q_dup_f32(mtx_a0 + 2);
+ a3 = vld1q_dup_f32(mtx_a0 + 3);
+
+ b00 = vld1q_f32(mtx_b0);
+ b10 = vld1q_f32(mtx_b1);
+ b01 = vld1q_f32(mtx_b0 + 4);
+ b11 = vld1q_f32(mtx_b1 + 4);
+
+ // 4x4 block 0
+ acc00 = vmlaq_f32(acc00, b00, a0);
+ acc10 = vmlaq_f32(acc10, b00, a1);
+ acc20 = vmlaq_f32(acc20, b00, a2);
+ acc30 = vmlaq_f32(acc30, b00, a3);
+
+ a4 = vld1q_dup_f32(mtx_a0 + 4);
+ a5 = vld1q_dup_f32(mtx_a0 + 5);
+ a6 = vld1q_dup_f32(mtx_a0 + 6);
+ a7 = vld1q_dup_f32(mtx_a0 + 7);
+
+ // 4x4 block 1
+ acc01 = vmlaq_f32(acc01, b10, a0);
+ acc11 = vmlaq_f32(acc11, b10, a1);
+ acc21 = vmlaq_f32(acc21, b10, a2);
+ acc31 = vmlaq_f32(acc31, b10, a3);
+
+ // 4x4 block 0
+ acc00 = vmlaq_f32(acc00, b01, a4);
+ acc10 = vmlaq_f32(acc10, b01, a5);
+ acc20 = vmlaq_f32(acc20, b01, a6);
+ acc30 = vmlaq_f32(acc30, b01, a7);
+
+ // 4x4 block 1
+ acc01 = vmlaq_f32(acc01, b11, a4);
+ acc11 = vmlaq_f32(acc11, b11, a5);
+ acc21 = vmlaq_f32(acc21, b11, a6);
+ acc31 = vmlaq_f32(acc31, b11, a7);
+
+ mtx_a0 += 8;
+ mtx_b0 += 8;
+ mtx_b1 += 8;
+
+ a0 = vld1q_dup_f32(mtx_a0 + 0);
+ a1 = vld1q_dup_f32(mtx_a0 + 1);
+ a2 = vld1q_dup_f32(mtx_a0 + 2);
+ a3 = vld1q_dup_f32(mtx_a0 + 3);
+ b00 = vld1q_f32(mtx_b0);
+ b10 = vld1q_f32(mtx_b1);
+ b01 = vld1q_f32(mtx_b0 + 4);
+ b11 = vld1q_f32(mtx_b1 + 4);
#if __arm__
- asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_a0)));
- asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_b0)));
- asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_b1)));
+ asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_a0)));
+ asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_b0)));
+ asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_b1)));
#endif /* __arm__ */
- // 4x4 block 0
- acc00 = vmlaq_f32(acc00, b00, a0);
- acc10 = vmlaq_f32(acc10, b00, a1);
- acc20 = vmlaq_f32(acc20, b00, a2);
- acc30 = vmlaq_f32(acc30, b00, a3);
-
- a4 = vld1q_dup_f32(mtx_a0 + 4);
- a5 = vld1q_dup_f32(mtx_a0 + 5);
- a6 = vld1q_dup_f32(mtx_a0 + 6);
- a7 = vld1q_dup_f32(mtx_a0 + 7);
-
- // 4x4 block 1
- acc01 = vmlaq_f32(acc01, b10, a0);
- acc11 = vmlaq_f32(acc11, b10, a1);
- acc21 = vmlaq_f32(acc21, b10, a2);
- acc31 = vmlaq_f32(acc31, b10, a3);
-
- // 4x4 block 0
- acc00 = vmlaq_f32(acc00, b01, a4);
- acc10 = vmlaq_f32(acc10, b01, a5);
- acc20 = vmlaq_f32(acc20, b01, a6);
- acc30 = vmlaq_f32(acc30, b01, a7);
-
- // 4x4 block 1
- acc01 = vmlaq_f32(acc01, b11, a4);
- acc11 = vmlaq_f32(acc11, b11, a5);
- acc21 = vmlaq_f32(acc21, b11, a6);
- acc31 = vmlaq_f32(acc31, b11, a7);
-
- mtx_a0 += 8;
- mtx_b0 += 8;
- mtx_b1 += 8;
-
- a0 = vld1q_dup_f32(mtx_a0 + 0);
- a1 = vld1q_dup_f32(mtx_a0 + 1);
- a2 = vld1q_dup_f32(mtx_a0 + 2);
- a3 = vld1q_dup_f32(mtx_a0 + 3);
- b00 = vld1q_f32(mtx_b0);
- b10 = vld1q_f32(mtx_b1);
- b01 = vld1q_f32(mtx_b0 + 4);
- b11 = vld1q_f32(mtx_b1 + 4);
-
- // 4x4 block 0
- acc00 = vmlaq_f32(acc00, b00, a0);
- acc10 = vmlaq_f32(acc10, b00, a1);
- acc20 = vmlaq_f32(acc20, b00, a2);
- acc30 = vmlaq_f32(acc30, b00, a3);
-
- a4 = vld1q_dup_f32(mtx_a0 + 4);
- a5 = vld1q_dup_f32(mtx_a0 + 5);
- a6 = vld1q_dup_f32(mtx_a0 + 6);
- a7 = vld1q_dup_f32(mtx_a0 + 7);
-
- // 4x4 block 1
- acc01 = vmlaq_f32(acc01, b10, a0);
- acc11 = vmlaq_f32(acc11, b10, a1);
- acc21 = vmlaq_f32(acc21, b10, a2);
- acc31 = vmlaq_f32(acc31, b10, a3);
-
- // 4x4 block 0
- acc00 = vmlaq_f32(acc00, b01, a4);
- acc10 = vmlaq_f32(acc10, b01, a5);
- acc20 = vmlaq_f32(acc20, b01, a6);
- acc30 = vmlaq_f32(acc30, b01, a7);
-
- // 4x4 block 1
- acc01 = vmlaq_f32(acc01, b11, a4);
- acc11 = vmlaq_f32(acc11, b11, a5);
- acc21 = vmlaq_f32(acc21, b11, a6);
- acc31 = vmlaq_f32(acc31, b11, a7);
-
- mtx_a0 += 8;
- mtx_b0 += 8;
- mtx_b1 += 8;
+ // 4x4 block 0
+ acc00 = vmlaq_f32(acc00, b00, a0);
+ acc10 = vmlaq_f32(acc10, b00, a1);
+ acc20 = vmlaq_f32(acc20, b00, a2);
+ acc30 = vmlaq_f32(acc30, b00, a3);
+
+ a4 = vld1q_dup_f32(mtx_a0 + 4);
+ a5 = vld1q_dup_f32(mtx_a0 + 5);
+ a6 = vld1q_dup_f32(mtx_a0 + 6);
+ a7 = vld1q_dup_f32(mtx_a0 + 7);
+
+ // 4x4 block 1
+ acc01 = vmlaq_f32(acc01, b10, a0);
+ acc11 = vmlaq_f32(acc11, b10, a1);
+ acc21 = vmlaq_f32(acc21, b10, a2);
+ acc31 = vmlaq_f32(acc31, b10, a3);
+
+ // 4x4 block 0
+ acc00 = vmlaq_f32(acc00, b01, a4);
+ acc10 = vmlaq_f32(acc10, b01, a5);
+ acc20 = vmlaq_f32(acc20, b01, a6);
+ acc30 = vmlaq_f32(acc30, b01, a7);
+
+ // 4x4 block 1
+ acc01 = vmlaq_f32(acc01, b11, a4);
+ acc11 = vmlaq_f32(acc11, b11, a5);
+ acc21 = vmlaq_f32(acc21, b11, a6);
+ acc31 = vmlaq_f32(acc31, b11, a7);
+
+ mtx_a0 += 8;
+ mtx_b0 += 8;
+ mtx_b1 += 8;
+
+ a0 = vld1q_dup_f32(mtx_a0 + 0);
+ a1 = vld1q_dup_f32(mtx_a0 + 1);
+ a2 = vld1q_dup_f32(mtx_a0 + 2);
+ a3 = vld1q_dup_f32(mtx_a0 + 3);
+ b00 = vld1q_f32(mtx_b0);
+ b10 = vld1q_f32(mtx_b1);
+ b01 = vld1q_f32(mtx_b0 + 4);
+ b11 = vld1q_f32(mtx_b1 + 4);
+
+ // 4x4 block 0
+ acc00 = vmlaq_f32(acc00, b00, a0);
+ acc10 = vmlaq_f32(acc10, b00, a1);
+ acc20 = vmlaq_f32(acc20, b00, a2);
+ acc30 = vmlaq_f32(acc30, b00, a3);
+
+ a4 = vld1q_dup_f32(mtx_a0 + 4);
+ a5 = vld1q_dup_f32(mtx_a0 + 5);
+ a6 = vld1q_dup_f32(mtx_a0 + 6);
+ a7 = vld1q_dup_f32(mtx_a0 + 7);
+
+ // 4x4 block 1
+ acc01 = vmlaq_f32(acc01, b10, a0);
+ acc11 = vmlaq_f32(acc11, b10, a1);
+ acc21 = vmlaq_f32(acc21, b10, a2);
+ acc31 = vmlaq_f32(acc31, b10, a3);
+
+ // 4x4 block 0
+ acc00 = vmlaq_f32(acc00, b01, a4);
+ acc10 = vmlaq_f32(acc10, b01, a5);
+ acc20 = vmlaq_f32(acc20, b01, a6);
+ acc30 = vmlaq_f32(acc30, b01, a7);
+
+ // 4x4 block 1
+ acc01 = vmlaq_f32(acc01, b11, a4);
+ acc11 = vmlaq_f32(acc11, b11, a5);
+ acc21 = vmlaq_f32(acc21, b11, a6);
+ acc31 = vmlaq_f32(acc31, b11, a7);
+
+ mtx_a0 += 8;
+ mtx_b0 += 8;
+ mtx_b1 += 8;
+ }
+
+                    // Only consider one row of matrix B if the subsequent row is out of bounds.
+ for (; mtx_b0 < mtx_b0_end_addr;)
+ {
+ float32x4_t a0 = vld1q_dup_f32(mtx_a0 + 0);
+ float32x4_t a1 = vld1q_dup_f32(mtx_a0 + 1);
+ float32x4_t a2 = vld1q_dup_f32(mtx_a0 + 2);
+ float32x4_t a3 = vld1q_dup_f32(mtx_a0 + 3);
+ float32x4_t b00 = vld1q_f32(mtx_b0);
+ float32x4_t b10 = vld1q_f32(mtx_b1);
+
+#if __arm__
+ asm volatile("PLD [%0, #128*2]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_a0)));
+ asm volatile("PLD [%0, #128*2]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_b0)));
+ asm volatile("PLD [%0, #128*2]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_b1)));
+#endif /* __arm__ */
+ // 4x4 block 0
+ acc00 = vmlaq_f32(acc00, b00, a0);
+ acc10 = vmlaq_f32(acc10, b00, a1);
+ acc20 = vmlaq_f32(acc20, b00, a2);
+ acc30 = vmlaq_f32(acc30, b00, a3);
+
+ // 4x4 block 1
+ acc01 = vmlaq_f32(acc01, b10, a0);
+ acc11 = vmlaq_f32(acc11, b10, a1);
+ acc21 = vmlaq_f32(acc21, b10, a2);
+ acc31 = vmlaq_f32(acc31, b10, a3);
+
+ mtx_a0 += 4;
+ mtx_b0 += 4;
+ mtx_b1 += 4;
+ }
}
- for (; mtx_b0 < mtx_b0_end_addr;)
+                // Leftover last row of matrix B, in case there is an odd number of rows in matrix B
+ else if (mtx_b0 < end_addr_mtx_b_per_batch[id.z()])
{
- float32x4_t a0 = vld1q_dup_f32(mtx_a0 + 0);
- float32x4_t a1 = vld1q_dup_f32(mtx_a0 + 1);
- float32x4_t a2 = vld1q_dup_f32(mtx_a0 + 2);
- float32x4_t a3 = vld1q_dup_f32(mtx_a0 + 3);
- float32x4_t b00 = vld1q_f32(mtx_b0);
- float32x4_t b10 = vld1q_f32(mtx_b1);
+ for (; mtx_b0 < (mtx_b0_end_addr - 32);)
+ {
+ float32x4_t a0 = vld1q_dup_f32(mtx_a0 + 0);
+ float32x4_t a1 = vld1q_dup_f32(mtx_a0 + 1);
+ float32x4_t a2 = vld1q_dup_f32(mtx_a0 + 2);
+ float32x4_t a3 = vld1q_dup_f32(mtx_a0 + 3);
+
+ float32x4_t b00 = vld1q_f32(mtx_b0);
+ float32x4_t b01 = vld1q_f32(mtx_b0 + 4);
#if __arm__
- asm volatile("PLD [%0, #128*2]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_a0)));
- asm volatile("PLD [%0, #128*2]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_b0)));
- asm volatile("PLD [%0, #128*2]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_b1)));
+ asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_a0)));
+ asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_b0)));
#endif /* __arm__ */
- // 4x4 block 0
- acc00 = vmlaq_f32(acc00, b00, a0);
- acc10 = vmlaq_f32(acc10, b00, a1);
- acc20 = vmlaq_f32(acc20, b00, a2);
- acc30 = vmlaq_f32(acc30, b00, a3);
-
- // 4x4 block 1
- acc01 = vmlaq_f32(acc01, b10, a0);
- acc11 = vmlaq_f32(acc11, b10, a1);
- acc21 = vmlaq_f32(acc21, b10, a2);
- acc31 = vmlaq_f32(acc31, b10, a3);
-
- mtx_a0 += 4;
- mtx_b0 += 4;
- mtx_b1 += 4;
+
+ // 4x4 block 0
+ acc00 = vmlaq_f32(acc00, b00, a0);
+ acc10 = vmlaq_f32(acc10, b00, a1);
+ acc20 = vmlaq_f32(acc20, b00, a2);
+ acc30 = vmlaq_f32(acc30, b00, a3);
+
+ float32x4_t a4 = vld1q_dup_f32(mtx_a0 + 4);
+ float32x4_t a5 = vld1q_dup_f32(mtx_a0 + 5);
+ float32x4_t a6 = vld1q_dup_f32(mtx_a0 + 6);
+ float32x4_t a7 = vld1q_dup_f32(mtx_a0 + 7);
+
+ // 4x4 block 0
+ acc00 = vmlaq_f32(acc00, b01, a4);
+ acc10 = vmlaq_f32(acc10, b01, a5);
+ acc20 = vmlaq_f32(acc20, b01, a6);
+ acc30 = vmlaq_f32(acc30, b01, a7);
+
+ mtx_a0 += 8;
+ mtx_b0 += 8;
+
+ a0 = vld1q_dup_f32(mtx_a0 + 0);
+ a1 = vld1q_dup_f32(mtx_a0 + 1);
+ a2 = vld1q_dup_f32(mtx_a0 + 2);
+ a3 = vld1q_dup_f32(mtx_a0 + 3);
+
+ b00 = vld1q_f32(mtx_b0);
+ b01 = vld1q_f32(mtx_b0 + 4);
+
+ // 4x4 block 0
+ acc00 = vmlaq_f32(acc00, b00, a0);
+ acc10 = vmlaq_f32(acc10, b00, a1);
+ acc20 = vmlaq_f32(acc20, b00, a2);
+ acc30 = vmlaq_f32(acc30, b00, a3);
+
+ a4 = vld1q_dup_f32(mtx_a0 + 4);
+ a5 = vld1q_dup_f32(mtx_a0 + 5);
+ a6 = vld1q_dup_f32(mtx_a0 + 6);
+ a7 = vld1q_dup_f32(mtx_a0 + 7);
+
+ // 4x4 block 0
+ acc00 = vmlaq_f32(acc00, b01, a4);
+ acc10 = vmlaq_f32(acc10, b01, a5);
+ acc20 = vmlaq_f32(acc20, b01, a6);
+ acc30 = vmlaq_f32(acc30, b01, a7);
+
+ mtx_a0 += 8;
+ mtx_b0 += 8;
+
+ a0 = vld1q_dup_f32(mtx_a0 + 0);
+ a1 = vld1q_dup_f32(mtx_a0 + 1);
+ a2 = vld1q_dup_f32(mtx_a0 + 2);
+ a3 = vld1q_dup_f32(mtx_a0 + 3);
+ b00 = vld1q_f32(mtx_b0);
+ b01 = vld1q_f32(mtx_b0 + 4);
+
+#if __arm__
+ asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_a0)));
+ asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_b0)));
+#endif /* __arm__ */
+
+ // 4x4 block 0
+ acc00 = vmlaq_f32(acc00, b00, a0);
+ acc10 = vmlaq_f32(acc10, b00, a1);
+ acc20 = vmlaq_f32(acc20, b00, a2);
+ acc30 = vmlaq_f32(acc30, b00, a3);
+
+ a4 = vld1q_dup_f32(mtx_a0 + 4);
+ a5 = vld1q_dup_f32(mtx_a0 + 5);
+ a6 = vld1q_dup_f32(mtx_a0 + 6);
+ a7 = vld1q_dup_f32(mtx_a0 + 7);
+
+ // 4x4 block 0
+ acc00 = vmlaq_f32(acc00, b01, a4);
+ acc10 = vmlaq_f32(acc10, b01, a5);
+ acc20 = vmlaq_f32(acc20, b01, a6);
+ acc30 = vmlaq_f32(acc30, b01, a7);
+
+ mtx_a0 += 8;
+ mtx_b0 += 8;
+
+ a0 = vld1q_dup_f32(mtx_a0 + 0);
+ a1 = vld1q_dup_f32(mtx_a0 + 1);
+ a2 = vld1q_dup_f32(mtx_a0 + 2);
+ a3 = vld1q_dup_f32(mtx_a0 + 3);
+ b00 = vld1q_f32(mtx_b0);
+ b01 = vld1q_f32(mtx_b0 + 4);
+
+ // 4x4 block 0
+ acc00 = vmlaq_f32(acc00, b00, a0);
+ acc10 = vmlaq_f32(acc10, b00, a1);
+ acc20 = vmlaq_f32(acc20, b00, a2);
+ acc30 = vmlaq_f32(acc30, b00, a3);
+
+ a4 = vld1q_dup_f32(mtx_a0 + 4);
+ a5 = vld1q_dup_f32(mtx_a0 + 5);
+ a6 = vld1q_dup_f32(mtx_a0 + 6);
+ a7 = vld1q_dup_f32(mtx_a0 + 7);
+
+ // 4x4 block 0
+ acc00 = vmlaq_f32(acc00, b01, a4);
+ acc10 = vmlaq_f32(acc10, b01, a5);
+ acc20 = vmlaq_f32(acc20, b01, a6);
+ acc30 = vmlaq_f32(acc30, b01, a7);
+
+ mtx_a0 += 8;
+ mtx_b0 += 8;
+ }
+ for (; mtx_b0 < mtx_b0_end_addr;)
+ {
+ float32x4_t a0 = vld1q_dup_f32(mtx_a0 + 0);
+ float32x4_t a1 = vld1q_dup_f32(mtx_a0 + 1);
+ float32x4_t a2 = vld1q_dup_f32(mtx_a0 + 2);
+ float32x4_t a3 = vld1q_dup_f32(mtx_a0 + 3);
+ float32x4_t b00 = vld1q_f32(mtx_b0);
+
+#if __arm__
+ asm volatile("PLD [%0, #128*2]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_a0)));
+ asm volatile("PLD [%0, #128*2]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_b0)));
+#endif /* __arm__ */
+ // 4x4 block 0
+ acc00 = vmlaq_f32(acc00, b00, a0);
+ acc10 = vmlaq_f32(acc10, b00, a1);
+ acc20 = vmlaq_f32(acc20, b00, a2);
+ acc30 = vmlaq_f32(acc30, b00, a3);
+
+ mtx_a0 += 4;
+ mtx_b0 += 4;
+ }
}
// Multiply by the weight of matrix product (alpha)
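The bulk of the change above is re-indentation; the functional part is the guard built from end_addr_mtx_b_per_batch, which records one-past-the-end of matrix B for every batch. The inner loops consume two reshaped rows of B at a time (mtx_b0 and mtx_b1), and the new else-if branch keeps the kernel from reading through mtx_b1 when only a single trailing row remains in the current batch. A condensed, illustrative view of that decision, with names following the diff and the accumulation inside each branch unchanged:

    #include <cstddef>
    #include <vector>

    enum class BRowsInRange { Two, One, None };

    // Decide how many reshaped rows of matrix B the 4x8 inner loop may read for batch z.
    BRowsInRange rows_of_b_in_range(const float *mtx_b0, const float *mtx_b1,
                                    const std::vector<const float *> &end_addr_mtx_b_per_batch, std::size_t z)
    {
        if (mtx_b1 < end_addr_mtx_b_per_batch[z])
        {
            return BRowsInRange::Two; // accumulate 4x4 block 0 (mtx_b0) and block 1 (mtx_b1)
        }
        if (mtx_b0 < end_addr_mtx_b_per_batch[z])
        {
            return BRowsInRange::One; // odd trailing row: accumulate block 0 only
        }
        return BRowsInRange::None;    // nothing left to read for this batch
    }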
diff --git a/src/cpu/kernels/logistic/generic/sme2/fp32.cpp b/src/cpu/kernels/logistic/generic/sme2/fp32.cpp
new file mode 100644
index 0000000000..876e466594
--- /dev/null
+++ b/src/cpu/kernels/logistic/generic/sme2/fp32.cpp
@@ -0,0 +1,429 @@
+/*
+ * Copyright (c) 2024 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#ifdef ARM_COMPUTE_ENABLE_SME2
+
+#include "arm_compute/core/Helpers.h"
+#include "arm_compute/core/ITensor.h"
+#include "arm_compute/core/Window.h"
+
+namespace arm_compute
+{
+namespace cpu
+{
+
+// This function expects a collapsed 2D shape.
+void sme2_f32_logistic_kernel(const float *src,
+ float *dst,
+ const uintptr_t shape[2],
+ const uintptr_t src_strides[2],
+ const uintptr_t dst_strides[2])
+{
+ // Precondition:
+ assert(src_strides[0] == sizeof(float));
+ assert(dst_strides[0] == sizeof(float));
+ __asm__ volatile(
+ R"(
+ .inst 0xd503477f // smstart
+
+ ptrue p0.b
+ .inst 0x25207811 // ptrue pn9.b
+
+ // Registers
+ //
+ // * x9: temporary, index
+ // * x10: temporary, inf
+ // * x11: temporary, 0
+ // * x12: temporary, 1.0f
+ // * x13: temporary, body_length
+ //
+ // * x26: index_1
+ // * x27: src_1
+ // * x28: dst_1
+ //
+ // * z0: c1
+ // * z1: c2
+ // * z2: c3
+ // * z3: c4
+ // * z4: c5
+ // * z5: shift
+ // * z6: inv_ln2
+ // * z7: neg_ln2_hi
+ // * z8: neg_ln2_lo
+        // * z9: min_input
+ // * z10: 23, 0, 1, inf
+ // * z11: max_value
+ // * z12-z15: x, r_hi, r, r2
+ // * z16-z19: max_value, shift, z, scale, poly
+ // * z20-z21: n, p1, p12345
+ // * z22-z23: n, p23, p2345
+ // * z24-z25: p45
+ // * z26: max_input
+ // * z28-z31: sum_value
+ //
+ // * za0-za3: sum_value
+ //
+ // * p0: all-true
+        // * p1-p4: underflow predicates
+        // * p4: leftover predicate
+        // * p5-p8: overflow predicates
+ // * pn9: all-true
+
+        // TAYLOR SERIES CONSTANTS
+ mov w9, #0xfff6 // c1: 0x1.ffffecp-1f = 0x3f7ffff6
+ mov w10, #0xfedb // c2: 0x1.fffdb6p-2f = 0x3efffedb
+ mov w11, #0xaf33 // c3: 0x1.555e66p-3f = 0x3e2aaf33
+ mov w12, #0x9f17 // c4: 0x1.573e2ep-5f = 0x3d2b9f17
+ mov w13, #0x2010 // c5: 0x1.0e4020p-7f = 0x3c072010
+
+ movk w9, #0x3f7f, LSL #16 // c1: 0x1.ffffecp-1f = 0x3f7ffff6
+ movk w10, #0x3eff, LSL #16 // c2: 0x1.fffdb6p-2f = 0x3efffedb
+ movk x11, #0x3e2a, LSL #16 // c3: 0x1.555e66p-3f = 0x3e2aaf33
+ movk w12, #0x3d2b, LSL #16 // c4: 0x1.573e2ep-5f = 0x3d2b9f17
+ movk w13, #0x3c07, LSL #16 // c5: 0x1.0e4020p-7f = 0x3c072010
+
+ dup z0.s, w9 // c1.
+ dup z1.s, w10 // c2.
+ dup z2.s, w11 // c3.
+ dup z3.s, w12 // c4.
+ dup z4.s, w13 // c5.
+
+ mov w9, #0x007f // shift: 2^23 + 127 = 0x1.0000fep23f = 0x4b00007f
+ mov w10, #0xaa3b // inv_ln2: 1 / ln(2) = 0x1.715476p+0f = 0x3fb8aa3b
+ mov w11, #0x7200 // neg_ln2_hi: -ln(2) from bits -1 to -19 = -0x1.62e400p-1f = 0xbf317200
+ mov w12, #0xbe8e // neg_ln2_lo: -ln(2) from bits -20 to -42 = -0x1.7f7d1cp-20f = 0xb5bfbe8e
+ mov w13, #0x47ae // min_input (Approximately ln 2^-125): -86.64 = 0xc2ad47ae
+ mov w14, #0xBD71 // max_input 88.37 = 0x42B0BD71
+
+ movk w9, #0x4b00, LSL #16 // shift: 2^23 + 127 = 0x1.0000fep23f = 0x4b00007f
+ movk w10, #0x3fb8, LSL #16 // inv_ln2: 1 / ln(2) = 0x1.715476p+0f = 0x3fb8aa3b
+ movk w11, #0xbf31, LSL #16 // neg_ln2_hi: -ln(2) from bits -1 to -19 = -0x1.62e400p-1f = 0xbf317200
+ movk w12, #0xb5bf, LSL #16 // neg_ln2_lo: -ln(2) from bits -20 to -42 = -0x1.7f7d1cp-20f = 0xb5bfbe8e
+ movk w13, #0xc2ad, LSL #16 // min_input (Approximately ln 2^-125): -86.64 = 0xc2ad47ae
+ movk w14, #0x42B0, LSL #16 // max_input (88.37) = 0x42B0BD71
+
+ dup z5.s, w9 // shift
+ dup z6.s, w10 // inv_ln2
+ dup z7.s, w11 // neg_ln2_hi
+ dup z8.s, w12 // neg_ln2_lo
+ dup z9.s, w13 // min_input
+ dup z26.s, w14 // max_input
+
+ mov w10, #0x0000 // inf: 0x7F800000
+ movk w10, #0x7F80, LSL #16 // inf: 0x7F800000
+
+ mov w15, #0x0000
+ movk w15, #0x3F80, LSL #16 // 1
+
+ mov w11, #0 // 0
+
+ // ---------------------------------------------------------------- x13: body_length = (length / vl) * vl
+        cntw x13, ALL, MUL #4 // x13: vl (number of words processed per iteration, 4 vectors)
+        udiv x9, %x[length], x13 // x9: length / vl
+        mul x13, x13, x9 // x13: body_length = (length / vl) * vl
+
+ // ==================================================
+ // Outer loop opening
+ // ==================================================
+
+ mov x27, %x[src] // starting point of pointers for src.
+ mov x28, %x[dst] // starting point of pointers for dst.
+ mov x26, %x[shape_1]
+
+outer_loop_start%=:
+ // for index_1 in shape_1 downto 1
+ cmp x26, #0
+ b.eq outer_loop_end%=
+ sub x26, x26, #1
+
+ mov x9, #0 // x9: index
+
+inner_body_start%=:
+ cmp x9, x13
+ b.eq inner_body_end%=
+
+ // Loads the input data to 4 consecutive registers ---------------- z12-z15: input_data
+ .inst 0xa009c76c // ld1w {z12.s-z15.s}, pn9/z, [x27, x9, LSL #2]
+
+ // ---------------------------------------------------------------- z12-z15: x = neg(x)
+ fneg z12.s, p0/m, z12.s
+ fneg z13.s, p0/m, z13.s
+ fneg z14.s, p0/m, z14.s
+ fneg z15.s, p0/m, z15.s
+
+        // ---------------------------------------------------------------- p1-p4: underflow = x < min_input
+ fcmlt p1.s, p0/z, z12.s, z9.s
+ fcmlt p2.s, p0/z, z13.s, z9.s
+ fcmlt p3.s, p0/z, z14.s, z9.s
+ fcmlt p4.s, p0/z, z15.s, z9.s
+
+        // ---------------------------------------------------------------- p5-p8: overflow = x > max_input
+ fcmlt p5.s, p0/z, z26.s, z12.s
+ fcmlt p6.s, p0/z, z26.s, z13.s
+ fcmlt p7.s, p0/z, z26.s, z14.s
+ fcmlt p8.s, p0/z, z26.s, z15.s
+
+ // ---------------------------------------------------------------- z16-z19: shift
+ mov z16.d, z5.d
+ mov z17.d, z5.d
+ mov z18.d, z5.d
+ mov z19.d, z5.d
+
+ // ---------------------------------------------------------------- z16-z19: z = shift + x * inv_ln2
+ fmla z16.s, p0/m, z12.s, z6.s
+ fmla z17.s, p0/m, z13.s, z6.s
+ fmla z18.s, p0/m, z14.s, z6.s
+ fmla z19.s, p0/m, z15.s, z6.s
+
+ // ---------------------------------------------------------------- z20-z23: n = z - shift
+ fsub z20.s, z16.s, z5.s
+ fsub z21.s, z17.s, z5.s
+ fsub z22.s, z18.s, z5.s
+ fsub z23.s, z19.s, z5.s
+
+ // ---------------------------------------------------------------- z12-z15: r_hi = x + n * neg_ln2_hi
+ fmla z12.s, p0/m, z20.s, z7.s
+ fmla z13.s, p0/m, z21.s, z7.s
+ fmla z14.s, p0/m, z22.s, z7.s
+ fmla z15.s, p0/m, z23.s, z7.s
+
+ // ---------------------------------------------------------------- z12-z15: r = r_hi + n * neg_ln2_lo
+ fmla z12.s, p0/m, z20.s, z8.s
+ fmla z13.s, p0/m, z21.s, z8.s
+ fmla z14.s, p0/m, z22.s, z8.s
+ fmla z15.s, p0/m, z23.s, z8.s
+
+ // ---------------------------------------------------------------- z16-z19: scale = z << 23 (2^n)
+ dup z10.s, #23
+ urshl z16.s, p0/m, z16.s, z10.s
+ urshl z17.s, p0/m, z17.s, z10.s
+ urshl z18.s, p0/m, z18.s, z10.s
+ urshl z19.s, p0/m, z19.s, z10.s
+
+ // Processes the first 2 vectors.
+
+ // ---------------------------------------------------------------- z20-z21: p1 = r * c1
+ fmul z20.s, z12.s, z0.s
+ fmul z21.s, z13.s, z0.s
+
+ // ---------------------------------------------------------------- z22-z23: p23 = c2
+ mov z22.d, z1.d
+ mov z23.d, z1.d
+
+ // ---------------------------------------------------------------- z22-z23: p23 = c2 + r * c3
+ fmla z22.s, p0/m, z12.s, z2.s
+ fmla z23.s, p0/m, z13.s, z2.s
+
+        // ---------------------------------------------------------------- z24-z25: c4
+ mov z24.d, z3.d
+ mov z25.d, z3.d
+
+ // ---------------------------------------------------------------- z24-z25: p45 = c4 + r * c5
+ fmla z24.s, p0/m, z12.s, z4.s
+ fmla z25.s, p0/m, z13.s, z4.s
+
+ // ---------------------------------------------------------------- z12-z13: r2 = r * r
+ fmul z12.s, z12.s, z12.s
+ fmul z13.s, z13.s, z13.s
+
+ // ---------------------------------------------------------------- z22-z23: p2345 = p23 + r2 * p45
+ fmla z22.s, p0/m, z12.s, z24.s
+ fmla z23.s, p0/m, z13.s, z25.s
+
+ // ---------------------------------------------------------------- z20-z21: p12345 = p1 + r2 * p2345
+ fmla z20.s, p0/m, z12.s, z22.s
+ fmla z21.s, p0/m, z13.s, z23.s
+
+ // ---------------------------------------------------------------- z16-z17: poly = scale + p12345 * scale
+ fmla z16.s, p0/m, z20.s, z16.s
+ fmla z17.s, p0/m, z21.s, z17.s
+
+ // Processes the last 2 vectors
+
+ // ---------------------------------------------------------------- z20-z21: p1 = r * c1
+ fmul z20.s, z14.s, z0.s
+ fmul z21.s, z15.s, z0.s
+
+ // ---------------------------------------------------------------- z22-z23: p23 = c2
+ mov z22.d, z1.d
+ mov z23.d, z1.d
+
+ // ---------------------------------------------------------------- z22-z23: p23 = c2 + r * c3
+ fmla z22.s, p0/m, z14.s, z2.s
+ fmla z23.s, p0/m, z15.s, z2.s
+
+        // ---------------------------------------------------------------- z24-z25: c4
+ mov z24.d, z3.d
+ mov z25.d, z3.d
+
+ // ---------------------------------------------------------------- z24-z25: p45 = c4 + r * c5
+ fmla z24.s, p0/m, z14.s, z4.s
+ fmla z25.s, p0/m, z15.s, z4.s
+
+ // ---------------------------------------------------------------- z14-z15: r2 = r * r
+ fmul z14.s, z14.s, z14.s
+ fmul z15.s, z15.s, z15.s
+
+ // ---------------------------------------------------------------- z22-z23: p2345 = p23 + r2 * p45
+ fmla z22.s, p0/m, z14.s, z24.s
+ fmla z23.s, p0/m, z15.s, z25.s
+
+ // ---------------------------------------------------------------- z20-z21: p12345 = p1 + r2 * p2345
+ fmla z20.s, p0/m, z14.s, z22.s
+ fmla z21.s, p0/m, z15.s, z23.s
+
+ // ---------------------------------------------------------------- z18-z19: poly = scale + p12345 * scale
+ fmla z18.s, p0/m, z20.s, z18.s
+ fmla z19.s, p0/m, z21.s, z19.s
+
+ // ---------------------------------------------------------------- z16-z19: poly = underflow ? 0 : poly
+ dup z10.s, #0
+ sel z16.s, p1, z10.s, z16.s
+ sel z17.s, p2, z10.s, z17.s
+ sel z18.s, p3, z10.s, z18.s
+ sel z19.s, p4, z10.s, z19.s
+
+ // ---------------------------------------------------------------- z16-z19: poly = overflow ? inf : poly
+ dup z10.s, w10 // z10: inf
+ sel z16.s, p5, z10.s, z16.s
+ sel z17.s, p6, z10.s, z17.s
+ sel z18.s, p7, z10.s, z18.s
+ sel z19.s, p8, z10.s, z19.s
+
+        // 1 / (1 + poly)
+ dup z10.s, w15 // z10: 1
+ fadd z16.s, z10.s, z16.s // poly + 1
+ fadd z17.s, z10.s, z17.s // poly + 1
+ fadd z18.s, z10.s, z18.s // poly + 1
+ fadd z19.s, z10.s, z19.s // poly + 1
+
+ fdivr z16.s, p0/m, z16.s, z10.s // z16: 1/(poly+1)
+        fdivr z17.s, p0/m, z17.s, z10.s // z17: 1/(poly+1)
+        fdivr z18.s, p0/m, z18.s, z10.s // z18: 1/(poly+1)
+        fdivr z19.s, p0/m, z19.s, z10.s // z19: 1/(poly+1)
+
+ // Stores 4 consecutive registers to the output
+ .inst 0xa029c790 // st1w {z16.s-z19.s}, pn9, [x28, x9, LSL #2]
+
+ incw x9, ALL, MUL #4
+ b inner_body_start%=
+inner_body_end%=:
+
+inner_leftover_start%=:
+    // Largely ordinary SVE code handling the Taylor-series evaluation of 1/(1+e^-x) for the leftover loop.
+ whilelo p1.s, x9, %x[length] // While x9<length
+ b.none inner_leftover_end%=
+
+        ld1w z12.s, p1/z, [x27, x9, LSL #2] // z12: input_data
+ fneg z12.s, p1/m, z12.s
+
+ mov z16.d, z5.d // z16: shift
+ fcmlt p4.s, p1/z, z12.s, z9.s // p4: underflow = x < min_input
+ fcmlt p5.s, p1/z, z26.s, z12.s // p5: overflow = x > max_input
+
+ fmla z16.s, p1/m, z12.s, z6.s // z16: z = shift + x * inv_ln2
+ fsub z20.s, z16.s, z5.s // z20: n = z - shift
+ fmla z12.s, p1/m, z20.s, z7.s // z12: r_hi = x + n * neg_ln2_hi
+ fmla z12.s, p1/m, z20.s, z8.s // z12: r = r_hi + n * neg_ln2_lo
+ dup z10.s, #23 // z10: 23
+ urshl z16.s, p1/m, z16.s, z10.s // z16: scale = z << 23 (2^n)
+ fmul z20.s, z12.s, z0.s // z20: p1 = r * c1
+ mov z22.d, z1.d // z22: p23 = c2
+ fmla z22.s, p1/m, z12.s, z2.s // z22: p23 = c2 + r * c3
+ mov z24.d, z3.d // z24: c4
+ fmla z24.s, p1/m, z12.s, z4.s // z24: p45 = c4 + r * c5
+ fmul z12.s, z12.s, z12.s // z12: r2 = r * r
+ fmla z22.s, p1/m, z12.s, z24.s // z22: p2345 = p23 + r2 * p45
+ fmla z20.s, p1/m, z12.s, z22.s // z20: p12345 = p1 + r2 * p2345
+ fmla z16.s, p1/m, z20.s, z16.s // z16: poly = scale + p12345 * scale
+ dup z10.s, #0 // z10: 0
+ sel z16.s, p4, z10.s, z16.s // z16: poly = underflow ? 0 : poly
+        dup z10.s, w10 // z10: inf
+ sel z16.s, p5, z10.s, z16.s // z16: poly = overflow ? inf : poly
+
+        // 1 / (1 + poly)
+        dup z10.s, w15 // z10: 1
+        fadd z16.s, z10.s, z16.s // z16: poly + 1
+ fdivr z16.s, p0/m, z16.s, z10.s // z16: 1/(poly+1)
+
+ st1w z16.s, p1, [x28, x9, LSL #2]
+
+        incw x9 // x9 += number of 32-bit lanes per vector
+ b inner_leftover_start%=
+inner_leftover_end%=:
+
+ // ==================================================
+ // Outer loop closing
+ // ==================================================
+
+ add x27, x27, %x[src_stride_1]
+ add x28, x28, %x[dst_stride_1]
+ b outer_loop_start%=
+outer_loop_end%=:
+
+ .inst 0xd503467f // smstop
+ )"
+ :
+ : [src] "r"(src), [dst] "r"(dst), [shape_1] "r"(shape[1]), [src_stride_1] "r"(src_strides[1]),
+ [dst_stride_1] "r"(dst_strides[1]), [length] "r"(shape[0])
+ : "cc", "memory", //
+ "p0", "p1", "p2", "p3", "p4", "p5", "p6", "p7", "p9", //
+ "x9", "x10", "x11", "x12", "x13", "x14", "x15", //
+ "x26", "x27", "x28", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", //
+ "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", //
+ "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", //
+ "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31" //
+ );
+}
+
+void sme2_fp32_logistic(const ITensor *in, ITensor *out, const ActivationLayerInfo &act_info, const Window &window)
+{
+ ARM_COMPUTE_UNUSED(act_info);
+ const auto *src_info = in->info();
+ const auto *dst_info = out->info();
+
+ const auto &src_strides = src_info->strides_in_bytes();
+ const auto &dst_strides = dst_info->strides_in_bytes();
+
+ // Iterator calculates pointer offsets and takes into account padding.
+ Iterator input(in, window);
+ Iterator output(out, window);
+
+ // NOTE: This kernel uses collapsed 2D shapes.
+    // The execution window is expected to be pre-collapsed in the kernel's configure() function.
+ const uintptr_t k_shape[] = {window.num_iterations(0), window.num_iterations(1)};
+
+ const uintptr_t k_src_strides[] = {src_strides[0], src_strides[1]};
+ const uintptr_t k_dst_strides[] = {dst_strides[0], dst_strides[1]};
+
+ const auto *k_src = reinterpret_cast<const float *>(input.ptr());
+ auto *k_dst = reinterpret_cast<float *>(output.ptr());
+
+ sme2_f32_logistic_kernel(k_src, k_dst, k_shape, k_src_strides, k_dst_strides);
+}
+
+} // namespace cpu
+} // namespace arm_compute
+
+#endif // ARM_COMPUTE_ENABLE_SME2
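For reference, the vector code above follows the usual range-reduction recipe for the sigmoid 1/(1+e^-x): split -x into n*ln2 + r, approximate e^r with a degree-5 polynomial, rebuild e^(-x) as 2^n * (1 + p), clamp under/overflow lanes, then take 1/(1+e^(-x)). A minimal scalar sketch in C++, assuming c1..c5 and the min/max input bounds match the constants the kernel materialises in z0-z4, z9 and z26 before the section shown (std::ldexp stands in for the kernel's "z << 23" exponent trick):

#include <cmath>

// Scalar sketch of the vector algorithm above (assumptions stated in the lead-in).
float logistic_ref(float x, const float c[5], float min_input, float max_input)
{
    const float ln2   = std::log(2.0f);
    const float neg_x = -x;                                    // the kernel evaluates e^(-x)
    const float n     = std::nearbyint(neg_x / ln2);           // range reduction: -x = n*ln2 + r
    const float r     = neg_x - n * ln2;                       // (the kernel splits ln2 into hi/lo parts)

    const float p1    = r * c[0];                              // p1     = r * c1
    const float p23   = c[1] + r * c[2];                       // p23    = c2 + r * c3
    const float p45   = c[3] + r * c[4];                       // p45    = c4 + r * c5
    const float r2    = r * r;
    const float p2345 = p23 + r2 * p45;                        // p2345  = p23 + r2 * p45
    const float p     = p1 + r2 * p2345;                       // p12345 = p1 + r2 * p2345

    const float scale = std::ldexp(1.0f, static_cast<int>(n)); // scale = 2^n
    float poly        = scale + p * scale;                     // poly ~= e^(-x) = 2^n * (1 + p)

    if (neg_x < min_input) { poly = 0.0f; }                    // underflow lanes -> 0
    if (neg_x > max_input) { poly = INFINITY; }                // overflow lanes  -> inf

    return 1.0f / (1.0f + poly);                               // sigmoid(x)
}

The leftover loop above is the one-vector-at-a-time version of this same sequence; the main body simply unrolls it over four vectors per iteration.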
diff --git a/src/cpu/kernels/logistic/list.h b/src/cpu/kernels/logistic/list.h
new file mode 100644
index 0000000000..7893a61248
--- /dev/null
+++ b/src/cpu/kernels/logistic/list.h
@@ -0,0 +1,42 @@
+/*
+ * Copyright (c) 2024 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#ifndef ACL_SRC_CPU_KERNELS_LOGISTIC_LIST_H
+#define ACL_SRC_CPU_KERNELS_LOGISTIC_LIST_H
+
+namespace arm_compute
+{
+namespace cpu
+{
+#define DECLARE_LOGISTIC_KERNEL(func_name) \
+ void func_name(const ITensor *src, ITensor *dst, const ActivationLayerInfo &act_info, const Window &window)
+
+#ifdef __aarch64__
+DECLARE_LOGISTIC_KERNEL(sme2_fp32_logistic);
+#endif // __aarch64__
+
+#undef DECLARE_LOGISTIC_KERNEL
+} // namespace cpu
+} // namespace arm_compute
+
+#endif // ACL_SRC_CPU_KERNELS_LOGISTIC_LIST_H
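The macro only stamps out prototypes; for the kernel declared above it expands to:

void sme2_fp32_logistic(const ITensor *src, ITensor *dst, const ActivationLayerInfo &act_info, const Window &window);

The definition (the SME2 FP32 kernel whose tail is shown above) is compiled only when ARM_COMPUTE_ENABLE_SME2 is defined, while the declaration here is guarded by __aarch64__.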
diff --git a/src/cpu/kernels/meanstddevnorm/generic/neon/fp16.cpp b/src/cpu/kernels/meanstddevnorm/generic/neon/fp16.cpp
index 344b9df0c8..c73d1def6b 100644
--- a/src/cpu/kernels/meanstddevnorm/generic/neon/fp16.cpp
+++ b/src/cpu/kernels/meanstddevnorm/generic/neon/fp16.cpp
@@ -53,23 +53,24 @@ void mean_stddev_normalization<float16_t, 8>(ITensor *input, ITensor *output, fl
auto in_ptr = reinterpret_cast<const float16_t *>(input_itr.ptr());
auto out_ptr = reinterpret_cast<float16_t *>(output_itr.ptr());
- float16x8_t sum_vec = vdupq_n_f16(static_cast<float16_t>(0.0f));
+ float32x4x2_t sum_vec = {vdupq_n_f32(0.0f), vdupq_n_f32(0.0f)};
+
float32x4_t sum_sq_vec = vdupq_n_f32(0.0f);
for (; x <= (window_end_x - window_step_x); x += window_step_x)
{
float16x8_t data = vld1q_f16(in_ptr + x);
- sum_vec = vaddq_f16(sum_vec, data);
float32x4_t dl = vcvt_f32_f16(vget_low_f16(data));
float32x4_t dh = vcvt_f32_f16(vget_high_f16(data));
+ sum_vec.val[0] = vaddq_f32(sum_vec.val[0], dl);
+ sum_vec.val[1] = vaddq_f32(sum_vec.val[1], dh);
sum_sq_vec = vaddq_f32(sum_sq_vec, vmulq_f32(dl, dl));
sum_sq_vec = vaddq_f32(sum_sq_vec, vmulq_f32(dh, dh));
}
- float32x4_t sum_carry_res =
- vpaddq_f32(vcvt_f32_f16(vget_high_f16(sum_vec)), vcvt_f32_f16(vget_low_f16(sum_vec)));
- float sum = vaddvq_f32(sum_carry_res);
- float sum_sq = vaddvq_f32(sum_sq_vec);
+ float32x4_t sum_carry_res = vpaddq_f32(sum_vec.val[0], sum_vec.val[1]);
+ float sum = vaddvq_f32(sum_carry_res);
+ float sum_sq = vaddvq_f32(sum_sq_vec);
// Compute left-over elements
for (; x < window_end_x; ++x)
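The change above keeps the squared-sum accumulation in fp32 and additionally widens the plain sum from a single fp16 accumulator to a pair of fp32 accumulators before the horizontal reduction. A small scalar illustration of the precision issue being avoided (a sketch, assuming an aarch64 toolchain that provides float16_t, as the kernel itself does):

#include <arm_neon.h> // provides float16_t on aarch64 (assumption: same toolchain as the kernel)

static float demo_fp16_accumulation_loss()
{
    float16_t acc_f16 = static_cast<float16_t>(0.0f);
    float     acc_f32 = 0.0f;
    for (int i = 0; i < 4096; ++i)
    {
        // fp16 has a 10-bit mantissa: above 2048 the spacing of representable
        // values is 2, so adding 1 stops changing the fp16 running sum.
        acc_f16 = static_cast<float16_t>(static_cast<float>(acc_f16) + 1.0f);
        acc_f32 += 1.0f;
    }
    // acc_f16 ends at 2048 while acc_f32 reaches 4096.
    return acc_f32 - static_cast<float>(acc_f16);
}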
diff --git a/src/cpu/kernels/mul/generic/sme2/list.h b/src/cpu/kernels/mul/generic/sme2/list.h
new file mode 100644
index 0000000000..f6aecfb91c
--- /dev/null
+++ b/src/cpu/kernels/mul/generic/sme2/list.h
@@ -0,0 +1,38 @@
+/*
+ * Copyright (c) 2024 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#ifndef ACL_SRC_CPU_KERNELS_MUL_GENERIC_SME2_LIST_H
+#define ACL_SRC_CPU_KERNELS_MUL_GENERIC_SME2_LIST_H
+namespace arm_compute
+{
+namespace cpu
+{
+#define DECLARE_MUL_KERNEL(func_name) \
+ void func_name(const ITensor *src1, const ITensor *src2, ITensor *out, const Window &window, float scale)
+
+DECLARE_MUL_KERNEL(sme2_q8_signed_mul);
+
+#undef DECLARE_MUL_KERNEL
+} // namespace cpu
+} // namespace arm_compute
+#endif // ACL_SRC_CPU_KERNELS_MUL_GENERIC_SME2_LIST_H
diff --git a/src/cpu/kernels/mul/generic/sme2/qasymm8_signed.cpp b/src/cpu/kernels/mul/generic/sme2/qasymm8_signed.cpp
new file mode 100644
index 0000000000..94fbaa1bb2
--- /dev/null
+++ b/src/cpu/kernels/mul/generic/sme2/qasymm8_signed.cpp
@@ -0,0 +1,410 @@
+/*
+ * Copyright (c) 2024 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#ifdef ARM_COMPUTE_ENABLE_SME2
+
+#include "arm_compute/core/ITensor.h"
+#include "arm_compute/core/Window.h"
+
+namespace arm_compute
+{
+namespace cpu
+{
+
+// Mul SME2 kernel
+void sme2_q8_signed_mul_kernel( //
+ const int8_t *src,
+ const int8_t *weights,
+ int8_t *dst,
+ const int16_t offset_a,
+ const int16_t offset_b,
+ const int16_t offset_c,
+ const float multiplier, // = (scale_a * scale_b * mul) / scale_c
+ const uintptr_t win_shape[4],
+ const uintptr_t src_strides[4],
+ const uintptr_t wei_strides[4],
+ const uintptr_t dst_strides[4])
+{
+ struct Args
+ {
+ uintptr_t shape1;
+ uintptr_t shape2;
+ uintptr_t shape3;
+ const int8_t *src;
+ const int8_t *wei;
+ int8_t *dst;
+ int multiplier14p18;
+ int offsetC14p18;
+ int16_t offsetA;
+ int16_t offsetB;
+ } args;
+
+ // Constants used to express values in the 14p18 fixed point format
+ constexpr int32_t two_pwr18i = 262144;
+ constexpr float two_pwr18f = 262144.f;
+
+ args.shape1 = win_shape[1];
+ args.shape2 = win_shape[2];
+ args.shape3 = win_shape[3];
+ args.src = src;
+ args.wei = weights;
+ args.dst = dst;
+ args.multiplier14p18 = static_cast<int>(multiplier * two_pwr18f);
+ args.offsetC14p18 = static_cast<int>(offset_c * two_pwr18i);
+    // Offsets a/b need to be negated as the assembly kernel uses addition instructions where subtraction is needed.
+ // Offset C is not negated as it needs to be added rather than subtracted.
+ args.offsetA = offset_a * -1;
+ args.offsetB = offset_b * -1;
+
+ // Precondition:
+ assert(src_strides[0] == sizeof(int8_t));
+ assert(wei_strides[0] == sizeof(int8_t));
+ assert(dst_strides[0] == sizeof(int8_t));
+ __asm__ volatile(
+ R"(
+ .inst 0xd503477f // smstart
+ .inst 0x25207811 // ptrue pn9.b
+ ptrue p0.b
+
+ // ==================================================
+ // 3D loop opening
+ // ==================================================
+
+        // ---------------------------------------------------------------- x8: body_length = (length / step) * step
+        cntb x8, ALL, MUL #2 // x8: step = number of 8-bit elements processed per iteration (2 vectors)
+        udiv x9, %x[length], x8 // x9: length / step
+        mul x8, x8, x9 // x8: body_length = step * (length / step)
+
+
+        // ---------------------------------------------------------------- Load outer loop count and data pointers
+ ldr x10, [%[args_ptr], %[offset_shape_3]]
+ ldr x11, [%[args_ptr], %[offset_src_ptr]]
+ ldr x12, [%[args_ptr], %[offset_wei_ptr]]
+ ldr x13, [%[args_ptr], %[offset_dst_ptr]]
+
+ // Could potentially be replaced with explicit loads.
+ ld1rh {z1.h}, p0/z, [%[args_ptr], %[offset_A_offset]]
+ ld1rh {z2.h}, p0/z, [%[args_ptr], %[offset_B_offset]]
+ ld1rw {z3.s}, p0/z, [%[args_ptr], %[multiplier_offset]]
+
+loop_3_start%=:
+ // for index_3 in shape_3 downto 1
+ cmp x10, #0
+ b.eq loop_3_end%=
+ sub x10, x10, #1
+
+ ldr x14, [%[args_ptr], %[offset_shape_2]]
+ mov x15, x11
+ mov x16, x12
+ mov x17, x13
+
+loop_2_start%=:
+ // for index_2 in shape_2 downto 1
+ cmp x14, #0
+ b.eq loop_2_end%=
+ sub x14, x14, #1
+
+ ldr x7, [%[args_ptr], %[offset_shape_1]]
+ mov x20, x15
+ mov x21, x16
+ mov x22, x17
+
+loop_1_start%=:
+    // for index_1 in shape_1 downto 1
+ cmp x7, #0
+ b.eq loop_1_end%=
+ sub x7, x7, #1
+
+ mov x9, #0 // x9: index/count
+
+inner_loop_body_start%=:
+ cmp x9, x8
+ b.eq inner_loop_body_end%=
+
+        // Widening loads for A and B.
+
+        // NOTE: Loads 2 Z registers per operand instead of 4 due to register pressure.
+ .inst 0xa0090684 // ld1b {z4.b-z5.b}, pn9/z, [x20, x9]
+ .inst 0xa00906a6 // ld1b {z6.b-z7.b}, pn9/z, [x21, x9]
+
+ // Widen to 16 bits
+ .inst 0xc175e08c // sunpk {z12.h-z15.h}, {z4.b-z5.b} // (a)
+ .inst 0xc175e0d0 // sunpk {z16.h-z19.h}, {z6.b-z7.b} // (b)
+
+ // Apply offset to all registers in 16-bit
+ .inst 0xc161ab0c // add {z12.h-z15.h}, {z12.h-z15.h}, z1.h //a
+ .inst 0xc162ab10 // add {z16.h-z19.h}, {z16.h-z19.h}, z2.h //b
+
+ // Widen to 32-bit now.
+ // 12-19 are taken
+ // 4-11 a, 20-27 b
+ .inst 0xc1b5e184 // sunpk {z4.s-z7.s}, {z12.h-z13.h} //a
+ .inst 0xc1b5e1c8 // sunpk {z8.s-z11.s}, {z14.h-z15.h}
+ .inst 0xc1b5e214 // sunpk {z20.s-z23.s}, {z16.h-z17.h} //b
+ .inst 0xc1b5e258 // sunpk {z24.s-z27.s}, {z18.h-z19.h}
+
+ // Multiply a*b in int32
+ // Output in z4-z11
+ MUL z4.s, z4.s, z20.s
+ MUL z5.s, z5.s, z21.s
+ MUL z6.s, z6.s, z22.s
+ MUL z7.s, z7.s, z23.s
+ MUL z8.s, z8.s, z24.s
+ MUL z9.s, z9.s, z25.s
+ MUL z10.s, z10.s, z26.s
+ MUL z11.s, z11.s, z27.s
+
+        // z12-z19: broadcast offset_C (14.18 fixed point)
+ dup z12.s, %w[offset_C]
+ dup z13.s, %w[offset_C]
+ dup z14.s, %w[offset_C]
+ dup z15.s, %w[offset_C]
+ dup z16.s, %w[offset_C]
+ dup z17.s, %w[offset_C]
+ dup z18.s, %w[offset_C]
+ dup z19.s, %w[offset_C]
+
+        // MLA: multiply-accumulate the int32 products by the 14.18 multiplier onto the offset_C accumulators
+ MLA z12.s, p0/m, z4.s, z3.s
+ MLA z13.s, p0/m, z5.s, z3.s
+ MLA z14.s, p0/m, z6.s, z3.s
+ MLA z15.s, p0/m, z7.s, z3.s
+ MLA z16.s, p0/m, z8.s, z3.s
+ MLA z17.s, p0/m, z9.s, z3.s
+ MLA z18.s, p0/m, z10.s, z3.s
+ MLA z19.s, p0/m, z11.s, z3.s
+
+ // Int32 to Int8 saturate
+ .inst 0xc16eda05 // sqrshr z5.b, {z16.s-z19.s}, #18
+ .inst 0xc16ed984 // sqrshr z4.b, {z12.s-z15.s}, #18
+ // Store
+ .inst 0xa02906c4 // st1b {z4.b-z5.b}, pn9, [x22, x9]
+
+ incb x9, ALL, MUL #2
+ b inner_loop_body_start%=
+inner_loop_body_end%=:
+
+inner_loop_leftover_start%=:
+ whilelo p1.b, x9, %x[length] // While x9<length
+ b.none inner_loop_leftover_end%=
+
+        // Handle the leftover multiplication here.
+ ld1b z4.b, p1/z, [x20, x9] // z4: a input_data
+ ld1b z5.b, p1/z, [x21, x9] // z5: b input_data
+
+ // Widen register z4 (a)
+ sunpklo z6.h, z4.b // lower as 16 bits
+ sunpkhi z7.h, z4.b // upper as 16 bits
+
+ // Widen register z5 (b)
+ sunpklo z8.h, z5.b // lower as 16 bits
+ sunpkhi z9.h, z5.b // upper as 16 bits
+
+ // Apply offset in 16bit maths to all resulting vectors.
+ add z6.h, z6.h, z1.h //a
+ add z7.h, z7.h, z1.h
+ add z8.h, z8.h, z2.h //b
+ add z9.h, z9.h, z2.h
+
+ // Widen a,b to 32-bit z-registers.
+ // Multiply a and b and store result as 32 bit int.
+ // a lower - 32-bit
+ sunpklo z10.s, z6.h
+ sunpkhi z11.s, z6.h
+ // a upper - 32-bit
+ sunpklo z12.s, z7.h
+ sunpkhi z13.s, z7.h
+
+ // b lower - 32-bit
+ sunpklo z14.s, z8.h
+ sunpkhi z15.s, z8.h
+ // b upper - 32-bit
+ sunpklo z16.s, z9.h
+ sunpkhi z17.s, z9.h
+
+        // z4-z7: broadcast offset_C (14.18 fixed point)
+ dup z4.s, %w[offset_C]
+ dup z5.s, %w[offset_C]
+ dup z6.s, %w[offset_C]
+ dup z7.s, %w[offset_C]
+
+ // Multiply a*b (lower) in int32
+ MUL z10.s, z10.s, z14.s
+ MUL z11.s, z11.s, z15.s
+
+ // Multiply a*b (upper) in int32
+ MUL z12.s, z12.s, z16.s
+ MUL z13.s, z13.s, z17.s
+
+ // Still int32 here.
+ // Now MLA in fixed point
+ MLA z4.s, p0/m, z10.s, z3.s
+ MLA z5.s, p0/m, z11.s, z3.s
+ MLA z6.s, p0/m, z12.s, z3.s
+ MLA z7.s, p0/m, z13.s, z3.s
+
+ // Right shift, no narrow
+ LSR z20.s, z4.s, #8
+ LSR z21.s, z5.s, #8
+ LSR z22.s, z6.s, #8
+ LSR z23.s, z7.s, #8
+
+ // Right shift rounding (lower)
+ // Do not saturate.
+ RSHRNB z20.h, z20.s, #8
+ RSHRNB z21.h, z21.s, #8
+ UZP1 z25.h, z20.h, z21.h
+ // Right shift upper.
+ RSHRNB z22.h, z22.s, #8
+ RSHRNB z23.h, z23.s, #8
+ UZP1 z26.h, z22.h, z23.h
+
+ // Shift again to 8 bit both vectors. Recombine.
+ SQRSHRNB z25.b, z25.h, #2
+ SQRSHRNB z26.b, z26.h, #2
+ UZP1 z27.b, z25.b, z26.b
+
+ st1b z27.b, p1, [x22, x9]
+
+        incb x9 // x9 += vector length in bytes (the tail is handled by the predicate)
+ b inner_loop_leftover_start%=
+inner_loop_leftover_end%=:
+
+ // ==================================================
+ // 3D loop closing
+ // ==================================================
+
+ add x20, x20, %[src_stride_1]
+ add x21, x21, %[wei_stride_1]
+ add x22, x22, %[dst_stride_1]
+ b loop_1_start%=
+loop_1_end%=:
+
+ add x15, x15, %[src_stride_2]
+ add x16, x16, %[wei_stride_2]
+ add x17, x17, %[dst_stride_2]
+ b loop_2_start%=
+loop_2_end%=:
+
+ add x11, x11, %[src_stride_3]
+ add x12, x12, %[wei_stride_3]
+ add x13, x13, %[dst_stride_3]
+ b loop_3_start%=
+loop_3_end%=:
+
+ .inst 0xd503467f // smstop
+ )"
+ :
+ : // The following arguments are loaded via arg ptr values and a constant offset.
+ [args_ptr] "r"(&args), [offset_src_ptr] "I"(offsetof(Args, src)), [offset_wei_ptr] "I"(offsetof(Args, wei)),
+ [offset_dst_ptr] "I"(offsetof(Args, dst)), [offset_shape_1] "I"(offsetof(Args, shape1)),
+ [offset_shape_2] "I"(offsetof(Args, shape2)), [offset_shape_3] "I"(offsetof(Args, shape3)),
+ [multiplier_offset] "I"(offsetof(Args, multiplier14p18)), //
+ [offset_A_offset] "I"(offsetof(Args, offsetA)), //
+ [offset_B_offset] "I"(offsetof(Args, offsetB)), //
+        // Use registers for efficiency's sake.
+ [src_stride_1] "r"(src_strides[1]), [src_stride_2] "r"(src_strides[2]), [src_stride_3] "r"(src_strides[3]),
+ [wei_stride_1] "r"(wei_strides[1]), [wei_stride_2] "r"(wei_strides[2]), [wei_stride_3] "r"(wei_strides[3]),
+ [dst_stride_1] "r"(dst_strides[1]), [dst_stride_2] "r"(dst_strides[2]), [dst_stride_3] "r"(dst_strides[3]),
+ [offset_C] "r"(args.offsetC14p18), //
+ [length] "r"(win_shape[0])
+ : "cc", "memory", //
+ "p0", "p1", "x7", "x8", "x9", "x10", "x11", "x12", "x13", "x14", "x15", "x16", "x17", "x20", "x21", "x22",
+ "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", //
+ "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", //
+ "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", //
+ "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31" //
+ );
+}
+
+void sme2_q8_signed_mul(const ITensor *in0, const ITensor *in1, ITensor *out, const Window &window, const float scale)
+{
+ const auto *src_info = in0->info();
+ const auto *src2_info = in1->info();
+ const auto *dst_info = out->info();
+
+ const UniformQuantizationInfo src_q_info = src_info->quantization_info().uniform();
+ const UniformQuantizationInfo src2_q_info = src2_info->quantization_info().uniform();
+ const UniformQuantizationInfo dst_q_info = dst_info->quantization_info().uniform();
+
+ const auto &src_strides_bytes = src_info->strides_in_bytes();
+ const auto &wei_strides_bytes = src2_info->strides_in_bytes();
+ const auto &dst_strides_bytes = dst_info->strides_in_bytes();
+
+    // NOTE: This kernel does not support shapes above 4D (unless the execution window has been collapsed).
+ assert(window.num_iterations(4) == 1 && window.num_iterations(5) == 1);
+
+    // Note: The window is expected to handle y-broadcasting by setting relevant strides to 0.
+ const uintptr_t shape[] = {
+ window.num_iterations(0),
+ window.num_iterations(1),
+ window.num_iterations(2),
+ window.num_iterations(3),
+ };
+
+ Window input1_win = window.broadcast_if_dimension_le_one(src_info->tensor_shape());
+ Window input2_win = window.broadcast_if_dimension_le_one(src2_info->tensor_shape());
+
+    // The first-dim stride is always the element size. If broadcasting in other dims, the stride is set to 0.
+ uintptr_t src_strides[] = {src_strides_bytes[0], (input1_win.is_broadcasted(1)) ? 0 : src_strides_bytes[1],
+ (input1_win.is_broadcasted(2)) ? 0 : src_strides_bytes[2],
+ (input1_win.is_broadcasted(3)) ? 0 : src_strides_bytes[3]};
+ uintptr_t wei_strides[] = {wei_strides_bytes[0], (input2_win.is_broadcasted(1)) ? 0 : wei_strides_bytes[1],
+ (input2_win.is_broadcasted(2)) ? 0 : wei_strides_bytes[2],
+ (input2_win.is_broadcasted(3)) ? 0 : wei_strides_bytes[3]};
+
+ const uintptr_t dst_strides[] = {
+ dst_strides_bytes[0],
+ dst_strides_bytes[1],
+ dst_strides_bytes[2],
+ dst_strides_bytes[3],
+ };
+
+ const uintptr_t src_offset = window[0].start() * src_strides[0] + window[1].start() * src_strides[1] +
+ window[2].start() * src_strides[2] + window[3].start() * src_strides[3] +
+ in0->info()->offset_first_element_in_bytes();
+ const uintptr_t src2_offset = window[0].start() * wei_strides[0] + window[1].start() * wei_strides[1] +
+ window[2].start() * wei_strides[2] + window[3].start() * wei_strides[3] +
+ in1->info()->offset_first_element_in_bytes();
+ const uintptr_t dst_offset = window[0].start() * dst_strides[0] + window[1].start() * dst_strides[1] +
+ window[2].start() * dst_strides[2] + window[3].start() * dst_strides[3] +
+ out->info()->offset_first_element_in_bytes();
+
+ const auto *src = reinterpret_cast<const int8_t *>(in0->buffer() + src_offset);
+ const auto *src2 = reinterpret_cast<const int8_t *>(in1->buffer() + src2_offset);
+ auto *dst = reinterpret_cast<int8_t *>(out->buffer() + dst_offset);
+
+ // Calculate or retrieve necessary offsets/scale values.
+ const int16_t offset_a = src_q_info.offset;
+ const int16_t offset_b = src2_q_info.offset;
+ float multiplier = (src_q_info.scale * src2_q_info.scale * scale) / dst_q_info.scale;
+
+ sme2_q8_signed_mul_kernel(src, src2, dst, offset_a, offset_b, dst_q_info.offset, multiplier, shape, src_strides,
+ wei_strides, dst_strides);
+}
+
+} // namespace cpu
+} // namespace arm_compute
+
+#endif // ARM_COMPUTE_ENABLE_SME2
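Element-wise, the kernel above computes dst = sat8(round((src - offset_a) * (weights - offset_b) * multiplier + offset_c)), with the multiplier and offset_c carried in 14.18 fixed point and a final rounding shift by 18 (sqrshr #18 in the main loop; the leftover loop splits the same 18-bit shift into 8 + 8 + 2). A scalar sketch of that arithmetic, following the Args set-up above and assuming C++17 for std::clamp:

#include <algorithm>
#include <cstdint>

// Scalar reference of the per-element arithmetic the SME2 kernel performs
// (a sketch of the math, not of the exact instruction flow).
int8_t q8_mul_ref(int8_t a, int8_t b,
                  int16_t offset_a, int16_t offset_b, int16_t offset_c,
                  float multiplier /* = (scale_a * scale_b * scale) / scale_c */)
{
    const int32_t multiplier14p18 = static_cast<int32_t>(multiplier * 262144.f); // * 2^18
    const int32_t offset_c14p18   = static_cast<int32_t>(offset_c) * 262144;

    // The kernel adds the *negated* offsets stored in Args, which is the same
    // as subtracting them here.
    const int32_t prod = (static_cast<int32_t>(a) - offset_a) * (static_cast<int32_t>(b) - offset_b);

    // MLA in 14.18 fixed point, then a rounding shift right by 18 with
    // saturation (the vector code uses sqrshr #18). Arithmetic right shift of
    // negative values is assumed, as on the targeted compilers.
    const int64_t acc     = static_cast<int64_t>(offset_c14p18) + static_cast<int64_t>(prod) * multiplier14p18;
    const int64_t rounded = (acc + (1 << 17)) >> 18;

    return static_cast<int8_t>(std::clamp<int64_t>(rounded, -128, 127));
}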
diff --git a/src/cpu/operators/CpuConv2d.h b/src/cpu/operators/CpuConv2d.h
index 3f98e71896..0012ff6609 100644
--- a/src/cpu/operators/CpuConv2d.h
+++ b/src/cpu/operators/CpuConv2d.h
@@ -85,6 +85,7 @@ public:
* |F16 |F16 |F16 |F16 |
* |F32 |F32 |F32 |F32 |
* |QASYMM8 |QASYMM8 |S32 |QASYMM8 |
+ * |QASYMM8 |QASYMM8_SIGNED |S32 |QASYMM8 |
* |QASYMM8 |QSYMM8_PER_CHANNEL |S32 |QASYMM8 |
* |QASYMM8_SIGNED |QASYMM8_SIGNED |S32 |QASYMM8_SIGNED |
* |QASYMM8_SIGNED |QSYMM8_PER_CHANNEL |S32 |QASYMM8_SIGNED |
@@ -93,7 +94,7 @@ public:
* while every optional dimension from 4 and above represent a batch of inputs.
* Data types supported: QASYMM8/QASYMM8_SIGNED/F16/F32.
* @param[in] weights Weights tensor info. Weights are 4D tensor with dimensions [kernel_x, kernel_y, IFM, OFM].
- * Data type supported: Same as @p src, also could be QSYMM8_PER_CHANNEL if input is QASYMM8/QASYMM8_SIGNED.
+     * Data type supported: Same as @p src; QSYMM8_PER_CHANNEL is also supported if input is QASYMM8/QASYMM8_SIGNED, and QASYMM8_SIGNED if input is QASYMM8.
* @param[in] biases Biases tensor info. Shared biases supported. Biases are 1D tensor with dimensions [OFM].
* Data type supported: Same as @p src, except for input of QASYMM8/QASYMM8_SIGNED type where biases should be of S32 type.
* @param[out] dst Destination tensor info. 3 lower dimensions represent a single output [width, height, OFM], while the rest represent batch of outputs.
@@ -139,7 +140,7 @@ public:
* while every optional dimension from 4 and above represent a batch of inputs.
* Data types supported: QASYMM8/QASYMM8_SIGNED/F16/F32.
* @param[in] weights Weights tensor info. Weights are 4D tensor with dimensions [kernel_x, kernel_y, IFM, OFM].
+     * Data type supported: Same as @p src; QSYMM8_PER_CHANNEL is also supported if input is QASYMM8/QASYMM8_SIGNED, and QASYMM8_SIGNED if input is QASYMM8.
+ * Data type supported:Same as @p src, also could be QSYMM8_PER_CHANNEL or QASYMM8_SIGNED if input is QASYMM8/QASYMM8_SIGNED.
* @param[in] dst Destination tensor info. 3 lower dimensions represent a single output [width, height, OFM], while the rest represent batch of outputs.
* Data types supported: Same as @p src.
* @param[in] conv_info Contains padding and stride information described in @ref PadStrideInfo.
diff --git a/src/cpu/operators/CpuGemmConv2d.cpp b/src/cpu/operators/CpuGemmConv2d.cpp
index f3b78f8885..4ef5722438 100644
--- a/src/cpu/operators/CpuGemmConv2d.cpp
+++ b/src/cpu/operators/CpuGemmConv2d.cpp
@@ -227,6 +227,7 @@ CpuGemmConv2d::CpuGemmConv2d()
_is_prepared(false),
_wt_method(WeightTransformMethod::ReshapeThenTranspose),
_run_wt(true),
+ _act_info(),
_aux_mem(AuxTensorIdx::Count)
{
}
@@ -278,6 +279,7 @@ void CpuGemmConv2d::configure_mm(const ITensorInfo *src,
int32_t min_activation = type_min.get<int32_t>();
int32_t max_activation = type_max.get<int32_t>();
+ _act_info = act_info;
if (supported_acts.count(act_info.activation()) != 0)
{
std::tie(min_activation, max_activation) = get_quantized_activation_min_max(act_info, data_type, uoqinfo);
@@ -291,11 +293,14 @@ void CpuGemmConv2d::configure_mm(const ITensorInfo *src,
output_info.is_quantized_per_channel = (tmp_weights.data_type() == DataType::QSYMM8_PER_CHANNEL);
quantization::calculate_quantized_multipliers(iqinfo, wqinfo, oqinfo, output_info);
+ const GEMMInfo gemm_info =
+ GEMMInfo(false /* is_a_reshaped */, false /* is_b_reshaped */, true /* reshape_b_only_on_first_run */,
+ gemm_3d_depth, _skip_im2col, false /* retain_internal_weights */, output_info,
+ false /* fp_mixed_precision */, enable_fast_math, false /* broadcast_bias */, act_info,
+ fixed_format, weight_format, false /* pretranspose_B. TODO: COMPMID-6596 */);
+
_mm_gemmlowp = std::make_unique<CpuGemmLowpMatrixMultiplyCore>();
- _mm_gemmlowp->configure(&tmp_src, &tmp_weights, biases, dst,
- GEMMInfo(false, false, true, gemm_3d_depth, _skip_im2col, false, output_info, false,
- enable_fast_math, false, act_info, fixed_format, weight_format,
- false /* pretranspose_B. TODO: COMPMID-6596 */));
+ _mm_gemmlowp->configure(&tmp_src, &tmp_weights, biases, dst, gemm_info);
auto mm_mem_req = _mm_gemmlowp->workspace();
for (unsigned int cont = 0; cont < mm_mem_req.size(); ++cont)
@@ -306,7 +311,7 @@ void CpuGemmConv2d::configure_mm(const ITensorInfo *src,
else
{
// Create GEMMInfo structure
- const GEMMInfo &gemm_info =
+ const GEMMInfo gemm_info =
GEMMInfo(false, false, true /* Reshape weights only for the first run */, gemm_3d_depth,
_skip_im2col /* Reinterpret the input as 3D if im2col is skipped */, false,
GEMMLowpOutputStageInfo(), false, enable_fast_math, false, act_info, fixed_format, weight_format,
@@ -800,6 +805,58 @@ Status CpuGemmConv2d::validate(const ITensorInfo *src,
return Status{};
}
+void CpuGemmConv2d::update_quantization_parameters(ITensorPack &tensors)
+{
+ // Supported activations in GEMM
+ const std::set<ActivationLayerInfo::ActivationFunction> supported_acts = {
+ ActivationLayerInfo::ActivationFunction::RELU, ActivationLayerInfo::ActivationFunction::BOUNDED_RELU,
+ ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU};
+
+ auto src = tensors.get_const_tensor(ACL_SRC_0);
+ auto dst = tensors.get_tensor(ACL_DST);
+
+ auto wei = tensors.get_const_tensor(TensorType::ACL_SRC_1);
+ TensorInfo tmp_src{*src->info()};
+ TensorInfo tmp_weights{*wei->info()};
+
+ const QuantizationInfo iqinfo = src->info()->quantization_info();
+ const QuantizationInfo wqinfo = wei->info()->quantization_info();
+ const QuantizationInfo oqinfo = (dst->info()->total_size() == 0) ? iqinfo : dst->info()->quantization_info();
+ const UniformQuantizationInfo uiqinfo = iqinfo.uniform();
+ const UniformQuantizationInfo uoqinfo = oqinfo.uniform();
+ const DataType data_type = src->info()->data_type();
+
+ tmp_src.set_quantization_info(QuantizationInfo(uiqinfo.scale, -uiqinfo.offset));
+ if (!is_data_type_quantized_per_channel(tmp_weights.data_type()))
+ {
+ const UniformQuantizationInfo uwqinfo = wqinfo.uniform();
+ tmp_weights.set_quantization_info(QuantizationInfo(uwqinfo.scale, -uwqinfo.offset));
+ }
+
+ // Merge activation with output stage
+ PixelValue type_min{};
+ PixelValue type_max{};
+ std::tie(type_min, type_max) = get_min_max(data_type);
+ int32_t min_activation = type_min.get<int32_t>();
+ int32_t max_activation = type_max.get<int32_t>();
+
+ if (supported_acts.count(_act_info.activation()) != 0)
+ {
+ std::tie(min_activation, max_activation) = get_quantized_activation_min_max(_act_info, data_type, uoqinfo);
+ }
+
+ GEMMLowpOutputStageInfo output_info;
+ output_info.type = GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT;
+ output_info.gemmlowp_offset = uoqinfo.offset;
+ output_info.gemmlowp_min_bound = min_activation;
+ output_info.gemmlowp_max_bound = max_activation;
+ output_info.is_quantized_per_channel = (tmp_weights.data_type() == DataType::QSYMM8_PER_CHANNEL);
+ quantization::calculate_quantized_multipliers(iqinfo, wqinfo, oqinfo, output_info);
+
+ _mm_gemmlowp->update_quantization_parameters(output_info, tmp_src.quantization_info(),
+ tmp_weights.quantization_info(), _is_prepared, true);
+}
+
void CpuGemmConv2d::run(ITensorPack &tensors)
{
prepare(tensors);
diff --git a/src/cpu/operators/CpuGemmConv2d.h b/src/cpu/operators/CpuGemmConv2d.h
index fa16ce860b..e4e34cc7c5 100644
--- a/src/cpu/operators/CpuGemmConv2d.h
+++ b/src/cpu/operators/CpuGemmConv2d.h
@@ -76,6 +76,7 @@ public:
* |F32 |F32 |F32 |F32 |
* |BFLOAT16 |BFLOAT16 |BFLOAT16 |BFLOAT16 |
* |QASYMM8 |QASYMM8 |S32 |QASYMM8 |
+ * |QASYMM8 |QASYMM8_SIGNED |S32 |QASYMM8 |
* |QASYMM8 |QSYMM8_PER_CHANNEL |S32 |QASYMM8 |
* |QASYMM8_SIGNED |QASYMM8_SIGNED |S32 |QASYMM8_SIGNED |
* |QASYMM8_SIGNED |QSYMM8_PER_CHANNEL |S32 |QASYMM8_SIGNED |
@@ -142,6 +143,12 @@ public:
const ActivationLayerInfo &act_info = ActivationLayerInfo(),
const bool enable_fast_math = false);
+    /** Updates the quantization parameters at run time for convolution so that the quantization multipliers can be properly calculated.
+ *
+ * @param[in] tensors Vector that contains the tensors to operate on.
+ */
+ void update_quantization_parameters(ITensorPack &tensors);
+
// Inherited methods overridden:
void run(ITensorPack &tensors) override;
void prepare(ITensorPack &tensors) override;
@@ -292,6 +299,7 @@ private:
bool _is_prepared;
WeightTransformMethod _wt_method;
bool _run_wt;
+ ActivationLayerInfo _act_info;
experimental::MemoryRequirements _aux_mem{Count};
};
diff --git a/src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.cpp b/src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.cpp
index f3396fbb5c..0ea3c249df 100644
--- a/src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.cpp
+++ b/src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.cpp
@@ -128,24 +128,31 @@ void CpuGemmLowpMatrixMultiplyCore::configure(
_reshape_b_only_on_first_run;
_gemm_info = gemm_info;
- // Offset kernel is need if offset is non-zero or it may change (i.e. dynamic).
- // It is not needed if the datatype is symmetric, because there is no offset
- bool a_offset_kernel_needed = _a_offset != 0 || a->quantization_info().is_dynamic();
- bool b_offset_kernel_needed = _b_offset != 0 || b->quantization_info().is_dynamic();
+ const ITensorInfo *a_to_use = a;
- _asm_glue = std::make_unique<cpu::CpuGemmAssemblyDispatch>();
+ // Initialize assembly kernel meta-data
+ const cpu::AsmGemmInfo asm_info = init_assembly_metadata(gemm_info);
- const ITensorInfo *a_to_use = a;
+ const int32_t offset_correction = 128;
+ const DataType dt = DataType::QASYMM8_SIGNED;
+ const UniformQuantizationInfo iqinfo = a_to_use->quantization_info().uniform();
+
+ _signed_a = a_to_use->clone()->set_data_type(dt).set_quantization_info(
+ QuantizationInfo(iqinfo.scale, iqinfo.offset + offset_correction));
+
+ // If inputs are mixed-sign but this machine does not support mixed sign kernels,
+ // flip the sign so matched-sign kernels can be used.
+ if (!_flip_signedness && a->data_type() == DataType::QASYMM8 && b->data_type() == DataType::QASYMM8_SIGNED &&
+ !bool(CpuGemmAssemblyDispatch::validate(a_to_use, b, c, dst, asm_info)))
+ {
+ _flip_signedness = true;
+ }
+
+ _asm_glue = std::make_unique<cpu::CpuGemmAssemblyDispatch>();
// Convert to QASYMM8 -> QASYMM8_SIGNED and back
if (_flip_signedness)
{
- const int32_t offset_correction = 128;
- const DataType dt = DataType::QASYMM8_SIGNED;
- const UniformQuantizationInfo iqinfo = a_to_use->quantization_info().uniform();
-
- _signed_a = a_to_use->clone()->set_data_type(dt).set_quantization_info(
- QuantizationInfo(iqinfo.scale, iqinfo.offset + offset_correction));
_convert_to_signed_asymm = std::make_unique<kernels::CpuConvertQuantizedSignednessKernel>();
_convert_to_signed_asymm->configure(a_to_use, &_signed_a);
a_to_use = &_signed_a;
@@ -166,6 +173,11 @@ void CpuGemmLowpMatrixMultiplyCore::configure(
matrix_a = &_signed_a;
}
+    // Offset kernel is needed if offset is non-zero or it may change (i.e. dynamic).
+ // It is not needed if the datatype is symmetric, because there is no offset
+ bool a_offset_kernel_needed = _a_offset != 0 || a->quantization_info().is_dynamic();
+ bool b_offset_kernel_needed = _b_offset != 0 || b->quantization_info().is_dynamic();
+
// If GEMMLowpOutputStage != NONE, fuse the offset contribution with the output stage
if (info.gemmlowp_output_stage().type != GEMMLowpOutputStageType::NONE)
{
@@ -173,8 +185,6 @@ void CpuGemmLowpMatrixMultiplyCore::configure(
_mm_result_s32 = TensorInfo(dst->tensor_shape(), 1, DataType::S32);
}
- // Initialize assembly kernel meta-data
- const cpu::AsmGemmInfo asm_info = init_assembly_metadata(gemm_info);
#ifdef __aarch64__
if (!(!b->are_values_constant() &&
b->tensor_shape().z() > 1)) // Disable batch matmul as optimized GeMM handles batching differently.
@@ -375,10 +385,6 @@ Status CpuGemmLowpMatrixMultiplyCore::validate(const ITensorInfo *a,
int32_t a_offset = a->quantization_info().uniform().offset;
int32_t b_offset = b->quantization_info().uniform().offset;
- // Offset kernel is need if offset is non-zero or it may change (i.e. dynamic).
- bool a_offset_kernel_needed = a_offset != 0 || a->quantization_info().is_dynamic();
- bool b_offset_kernel_needed = b_offset != 0 || b->quantization_info().is_dynamic();
-
bool fuse_output_stage = info.gemmlowp_output_stage().type != GEMMLowpOutputStageType::NONE;
if (fuse_output_stage)
{
@@ -386,19 +392,31 @@ Status CpuGemmLowpMatrixMultiplyCore::validate(const ITensorInfo *a,
a->clone()->set_tensor_shape(output->tensor_shape()).set_data_type(DataType::S32));
}
+ // Initialize assembly kernel meta-data
+ const AsmGemmInfo asm_info = init_assembly_metadata(info);
+
// Convert QASYMM8->QASYMM8_SIGNED
- TensorInfo signed_a{};
+ const int32_t offset_correction = 128;
+ const DataType dt = DataType::QASYMM8_SIGNED;
+ const UniformQuantizationInfo iqinfo = a_to_use->quantization_info().uniform();
+
+ TensorInfo signed_a = a_to_use->clone()->set_data_type(dt).set_quantization_info(
+ QuantizationInfo(iqinfo.scale, iqinfo.offset + offset_correction));
TensorInfo signed_output{};
- bool flip_signedness = is_data_type_quantized_per_channel(b->data_type()) &&
+
+ bool flip_signedness = is_data_type_quantized_per_channel(b->data_type()) &&
(a->data_type() == DataType::QASYMM8) && info.reshape_b_only_on_first_run();
- if (flip_signedness)
+
+ // If inputs are mixed-sign but this machine does not support mixed sign kernels,
+ // flip the sign so matched-sign kernels can be used.
+ if (!flip_signedness && a->data_type() == DataType::QASYMM8 && b->data_type() == DataType::QASYMM8_SIGNED &&
+ !bool(CpuGemmAssemblyDispatch::validate(a_to_use, b, c, output, asm_info)))
{
- const int32_t offset_correction = 128;
- const DataType dt = DataType::QASYMM8_SIGNED;
- const UniformQuantizationInfo iqinfo = a_to_use->quantization_info().uniform();
+ flip_signedness = true;
+ }
- signed_a = a_to_use->clone()->set_data_type(dt).set_quantization_info(
- QuantizationInfo(iqinfo.scale, iqinfo.offset + offset_correction));
+ if (flip_signedness)
+ {
ARM_COMPUTE_RETURN_ON_ERROR(kernels::CpuConvertQuantizedSignednessKernel::validate(a_to_use, &signed_a));
a_to_use = &signed_a;
a_offset = signed_a.quantization_info().uniform().offset;
@@ -418,8 +436,9 @@ Status CpuGemmLowpMatrixMultiplyCore::validate(const ITensorInfo *a,
matrix_a_info = &signed_a;
}
- // Initialize assembly kernel meta-data
- const AsmGemmInfo asm_info = init_assembly_metadata(info);
+    // Offset kernel is needed if offset is non-zero or it may change (i.e. dynamic).
+ bool a_offset_kernel_needed = a_offset != 0 || a->quantization_info().is_dynamic();
+ bool b_offset_kernel_needed = b_offset != 0 || b->quantization_info().is_dynamic();
// Check if we need to run the optimized assembly kernel
bool run_optimised = false;
@@ -556,9 +575,12 @@ Status CpuGemmLowpMatrixMultiplyCore::validate(const ITensorInfo *a,
kernels::CpuGemmLowpMatrixMultiplyKernel::validate(matrix_a_info, matrix_b_info, output));
}
// Validate offset contribution kernel
- ARM_COMPUTE_RETURN_ON_ERROR(kernels::CpuGemmLowpOffsetContributionKernel::validate(
- output, a_offset_kernel_needed ? &info_vector_sum_col : nullptr,
- b_offset_kernel_needed ? &info_vector_sum_row : nullptr, a_offset, b_offset));
+ if (output->data_type() != DataType::QASYMM8 && output->data_type() != DataType::QASYMM8_SIGNED)
+ {
+ ARM_COMPUTE_RETURN_ON_ERROR(kernels::CpuGemmLowpOffsetContributionKernel::validate(
+ output, a_offset_kernel_needed ? &info_vector_sum_col : nullptr,
+ b_offset_kernel_needed ? &info_vector_sum_row : nullptr, a_offset, b_offset));
+ }
}
}
@@ -614,7 +636,6 @@ void CpuGemmLowpMatrixMultiplyCore::run(ITensorPack &tensors)
if (_asm_glue->is_configured())
{
ITensorPack asm_glue_tensors = tensors;
- auto output_to_use = (_fuse_output_stage ? mm_result_s32.get() : dst);
if (is_data_type_quantized_asymmetric(a_to_use->info()->data_type()) &&
_gemm_info.gemmlowp_output_stage().type == GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT)
{
@@ -625,6 +646,7 @@ void CpuGemmLowpMatrixMultiplyCore::run(ITensorPack &tensors)
}
else
{
+ auto output_to_use = (_fuse_output_stage ? mm_result_s32.get() : dst);
asm_glue_tensors.add_const_tensor(TensorType::ACL_SRC_0, a_to_use);
asm_glue_tensors.add_const_tensor(TensorType::ACL_SRC_1, b);
asm_glue_tensors.add_tensor(TensorType::ACL_DST, output_to_use);
@@ -775,5 +797,17 @@ experimental::MemoryRequirements CpuGemmLowpMatrixMultiplyCore::workspace() cons
{
return _aux_mem;
}
+
+void CpuGemmLowpMatrixMultiplyCore::update_quantization_parameters(const GEMMLowpOutputStageInfo &output_info,
+ const QuantizationInfo &a,
+ const QuantizationInfo &b,
+ const bool is_prepared,
+ const bool negated_offsets)
+{
+ auto lowp_os = output_info;
+ _gemm_info.set_gemmlowp_output_stage(lowp_os);
+ _asm_glue->update_quantization_parameters(output_info, a, b, is_prepared, negated_offsets);
+ _is_prepared = is_prepared;
+}
} // namespace cpu
} // namespace arm_compute
diff --git a/src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.h b/src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.h
index 38121c9bb4..033979e93f 100644
--- a/src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.h
+++ b/src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.h
@@ -81,11 +81,13 @@ public:
* |src0 |src1 |src2 |dst |
* |:--------------|:------------------|:--------|:--------------|
* |QASYMM8 |QASYMM8 |S32 |QASYMM8 |
+ * |QASYMM8 |QASYMM8_SIGNED |S32 |QASYMM8 |
* |QASYMM8 |QSYMM8_PER_CHANNEL |S32 |QASYMM8 |
* |QASYMM8 |QSYMM8 |S32 |QASYMM8 |
* |QASYMM8 |QASYMM8 |S32 |S32 |
* |QASYMM8 |QSYMM8_PER_CHANNEL |S32 |S32 |
* |QASYMM8 |QSYMM8 |S32 |S32 |
+ * |QASYMM8 |QASYMM8_SIGNED |F32 |F32 |
* |QASYMM8_SIGNED |QASYMM8_SIGNED |S32 |QASYMM8_SIGNED |
* |QASYMM8_SIGNED |QSYMM8_PER_CHANNEL |S32 |QASYMM8_SIGNED |
* |QASYMM8_SIGNED |QSYMM8 |S32 |QASYMM8_SIGNED |
@@ -131,6 +133,11 @@ public:
void run(ITensorPack &tensors) override;
void prepare(ITensorPack &tensors) override;
experimental::MemoryRequirements workspace() const override;
+ void update_quantization_parameters(const GEMMLowpOutputStageInfo &output_info,
+ const QuantizationInfo &a,
+ const QuantizationInfo &b,
+ const bool is_prepared,
+ const bool negated_offsets);
private:
enum AuxTensorIdx
diff --git a/src/cpu/operators/CpuMatMul.cpp b/src/cpu/operators/CpuMatMul.cpp
index f68ae9883f..acc620edc6 100644
--- a/src/cpu/operators/CpuMatMul.cpp
+++ b/src/cpu/operators/CpuMatMul.cpp
@@ -215,6 +215,8 @@ void CpuMatMul::configure(ITensorInfo *lhs,
// Setup transpose LHS
_transpose_kernel_lhs = std::make_unique<cpu::kernels::CpuTransposeKernel>();
_transpose_kernel_lhs->configure(&lhs_to_use, &_lhs_transposed);
+
+ _aux_mem[TransposeLHS] = MemoryInfo(offset_int_vec(TransposeLHS), MemoryLifetime::Temporary, lhs->total_size());
}
if (_adj_rhs)
@@ -222,6 +224,8 @@ void CpuMatMul::configure(ITensorInfo *lhs,
// Setup transpose RHS
_transpose_kernel_rhs = std::make_unique<cpu::kernels::CpuTransposeKernel>();
_transpose_kernel_rhs->configure(&rhs_to_use, &_rhs_transposed);
+
+ _aux_mem[TransposeRHS] = MemoryInfo(offset_int_vec(TransposeRHS), MemoryLifetime::Temporary, rhs->total_size());
}
// 3. Configure assembly kernel using transposed tensors.
@@ -269,9 +273,6 @@ void CpuMatMul::configure(ITensorInfo *lhs,
_aux_mem[idx] = aux;
idx++;
}
- // Memory requirements for transposed tensors
- _aux_mem[TransposeLHS] = MemoryInfo(offset_int_vec(TransposeLHS), MemoryLifetime::Temporary, lhs->total_size());
- _aux_mem[TransposeRHS] = MemoryInfo(offset_int_vec(TransposeRHS), MemoryLifetime::Temporary, rhs->total_size());
}
void CpuMatMul::run(ITensorPack &tensors)
diff --git a/src/cpu/operators/CpuPermute.cpp b/src/cpu/operators/CpuPermute.cpp
index 25acc92d00..2d4e009d51 100644
--- a/src/cpu/operators/CpuPermute.cpp
+++ b/src/cpu/operators/CpuPermute.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018-2021 Arm Limited.
+ * Copyright (c) 2018-2021, 2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -23,23 +23,91 @@
*/
#include "src/cpu/operators/CpuPermute.h"
+#include "arm_compute/core/CoreTypes.h"
+#include "arm_compute/core/Error.h"
+#include "arm_compute/core/ITensorInfo.h"
+
#include "src/common/utils/Log.h"
+#include "src/cpu/kernels/CpuCopyKernel.h"
#include "src/cpu/kernels/CpuPermuteKernel.h"
+#include "src/cpu/kernels/CpuTransposeKernel.h"
+
+#include <algorithm>
+#include <array>
+#include <memory>
namespace arm_compute
{
namespace cpu
{
+namespace
+{
+// Handle "No-op" cases
+bool prefer_copy(const PermutationVector &v)
+{
+ static const std::array<PermutationVector, 6> permutations = {{
+ PermutationVector(0U),
+ PermutationVector(0U, 1U),
+ PermutationVector(0U, 1U, 2U),
+ PermutationVector(0U, 1U, 2U, 3U),
+ PermutationVector(0U, 1U, 2U, 3U, 4U),
+ PermutationVector(0U, 1U, 2U, 3U, 4U, 5U),
+ }};
+
+ return std::find(permutations.begin(), permutations.end(), v) != permutations.end();
+}
+
+// Transpose kernel is optimized for permuting the first two dimensions of a tensor
+bool prefer_transpose(const PermutationVector &v)
+{
+ static const std::array<PermutationVector, 5> permutations = {{
+ PermutationVector(1U, 0U),
+ PermutationVector(1U, 0U, 2U),
+ PermutationVector(1U, 0U, 2U, 3U),
+ PermutationVector(1U, 0U, 2U, 3U, 4U),
+ PermutationVector(1U, 0U, 2U, 3U, 4U, 5U),
+ }};
+
+ return std::find(permutations.begin(), permutations.end(), v) != permutations.end();
+}
+} // namespace
+
void CpuPermute::configure(const ITensorInfo *src, ITensorInfo *dst, const PermutationVector &perm)
{
ARM_COMPUTE_LOG_PARAMS(src, dst, perm);
- auto k = std::make_unique<kernels::CpuPermuteKernel>();
- k->configure(src, dst, perm);
- _kernel = std::move(k);
+
+ if (prefer_copy(perm))
+ {
+ auto k = std::make_unique<kernels::CpuCopyKernel>();
+ k->configure(src, dst);
+ _kernel = std::move(k);
+ }
+ else if (prefer_transpose(perm))
+ {
+ auto k = std::make_unique<kernels::CpuTransposeKernel>();
+ k->configure(src, dst);
+ _kernel = std::move(k);
+ }
+ else
+ {
+ auto k = std::make_unique<kernels::CpuPermuteKernel>();
+ k->configure(src, dst, perm);
+ _kernel = std::move(k);
+ }
}
Status CpuPermute::validate(const ITensorInfo *src, const ITensorInfo *dst, const PermutationVector &perm)
{
+ if (prefer_copy(perm))
+ {
+ return kernels::CpuCopyKernel::validate(src, dst);
+ }
+
+ if (prefer_transpose(perm))
+ {
+ return kernels::CpuTransposeKernel::validate(src, dst);
+ }
+
return kernels::CpuPermuteKernel::validate(src, dst, perm);
}
} // namespace cpu
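To make the new dispatch concrete, a few examples of how validate() now routes permutation vectors (an illustrative sketch; the TensorInfo objects are assumed to already describe shapes compatible with each permutation):

// Identity-like permutations take the copy path.
Status s0 = CpuPermute::validate(&src_info, &dst_info_copy, PermutationVector(0U, 1U, 2U));        // CpuCopyKernel
// Swapping only the first two dimensions takes the transpose path.
Status s1 = CpuPermute::validate(&src_info, &dst_info_transposed, PermutationVector(1U, 0U, 2U));  // CpuTransposeKernel
// Anything else falls back to the generic permute kernel.
Status s2 = CpuPermute::validate(&src_info, &dst_info_permuted, PermutationVector(2U, 0U, 1U));    // CpuPermuteKernel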
diff --git a/src/cpu/operators/CpuReshape.cpp b/src/cpu/operators/CpuReshape.cpp
index a423abb49a..fe4d0f9d7d 100644
--- a/src/cpu/operators/CpuReshape.cpp
+++ b/src/cpu/operators/CpuReshape.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2021, 2023 Arm Limited.
+ * Copyright (c) 2021, 2023-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -53,8 +53,8 @@ void CpuReshape::run(ITensorPack &tensors)
static_cast<kernels::CpuReshapeKernel *>(_kernel.get())->prepare(tensors);
_is_prepared = true;
}
- const auto split_dimension = static_cast<kernels::CpuReshapeKernel *>(_kernel.get())->get_split_dimension();
- NEScheduler::get().schedule_op(_kernel.get(), split_dimension, _kernel->window(), tensors);
+
+ ICpuOperator::run(tensors);
}
} // namespace cpu
} // namespace arm_compute
diff --git a/src/cpu/operators/CpuScatter.cpp b/src/cpu/operators/CpuScatter.cpp
new file mode 100644
index 0000000000..f82413d201
--- /dev/null
+++ b/src/cpu/operators/CpuScatter.cpp
@@ -0,0 +1,70 @@
+/*
+ * Copyright (c) 2024 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#include "src/cpu/operators/CpuScatter.h"
+
+#include "arm_compute/core/Helpers.h"
+#include "arm_compute/core/TensorInfo.h"
+
+#include "src/cpu/kernels/CpuScatterKernel.h"
+
+namespace arm_compute
+{
+namespace cpu
+{
+
+void CpuScatter::configure(const ITensorInfo *src,
+ const ITensorInfo *updates,
+ const ITensorInfo *indices,
+ ITensorInfo *dst,
+ const ScatterInfo &Scatter_info)
+{
+ ARM_COMPUTE_ERROR_ON_NULLPTR(src, indices, dst);
+ ARM_COMPUTE_UNUSED(src);
+ ARM_COMPUTE_UNUSED(updates);
+ ARM_COMPUTE_UNUSED(indices);
+ ARM_COMPUTE_UNUSED(dst);
+ ARM_COMPUTE_UNUSED(Scatter_info);
+}
+
+Status CpuScatter::validate(const ITensorInfo *src,
+ const ITensorInfo *updates,
+ const ITensorInfo *indices,
+ const ITensorInfo *dst,
+ const ScatterInfo &Scatter_info)
+{
+ ARM_COMPUTE_UNUSED(src);
+ ARM_COMPUTE_UNUSED(updates);
+ ARM_COMPUTE_UNUSED(indices);
+ ARM_COMPUTE_UNUSED(dst);
+ ARM_COMPUTE_UNUSED(Scatter_info);
+
+ return Status{ErrorCode::RUNTIME_ERROR, "No configuration implemented yet."};
+}
+
+void CpuScatter::run(ITensorPack &tensors)
+{
+ ARM_COMPUTE_UNUSED(tensors);
+}
+} // namespace cpu
+} // namespace arm_compute
diff --git a/src/cpu/operators/CpuScatter.h b/src/cpu/operators/CpuScatter.h
new file mode 100644
index 0000000000..d0161a778d
--- /dev/null
+++ b/src/cpu/operators/CpuScatter.h
@@ -0,0 +1,81 @@
+/*
+ * Copyright (c) 2024 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#ifndef ACL_SRC_CPU_OPERATORS_CPUSCATTER_H
+#define ACL_SRC_CPU_OPERATORS_CPUSCATTER_H
+
+#include "arm_compute/core/ITensorInfo.h"
+#include "arm_compute/function_info/ScatterInfo.h"
+
+#include "src/cpu/ICpuKernel.h"
+#include "src/cpu/ICpuOperator.h"
+
+namespace arm_compute
+{
+namespace cpu
+{
+/** Basic function to execute Scatter on Arm® Neon™ */
+class CpuScatter : public ICpuOperator
+{
+public:
+ /** Initialise the kernel's inputs and output
+ *
+ * Valid data layouts:
+ * - All
+ *
+ * @note indices must always be U32
+ * @note src, updates and dst tensors must be same datatype.
+ *
+ * @param[in] src Source input tensor info. Can be nullptr when using "Add" Scatter Function with zero initialization.
+ * @param[in] updates Tensor info for tensor storing update values to use for scatter function. Data types supported: same as @p src.
+ * @param[in] indices Tensor info for tensor storing indices to use for scatter function. Data types supported: U32 only.
+ * @param[out] dst Output tensor to store the result of the Scatter Function. Data types supported: same as @p src and @p updates.
+ * @param[in] Scatter_info Contains Scatter operation information described in @ref ScatterInfo.
+ */
+ void configure(const ITensorInfo *src,
+ const ITensorInfo *updates,
+ const ITensorInfo *indices,
+ ITensorInfo *dst,
+ const ScatterInfo &Scatter_info);
+ /** Static function to check if given info will lead to a valid configuration
+ *
+ * Similar to @ref CpuScatter::configure()
+ *
+ * @return a status
+ */
+ static Status validate(const ITensorInfo *src,
+ const ITensorInfo *updates,
+ const ITensorInfo *indices,
+ const ITensorInfo *dst,
+ const ScatterInfo &Scatter_info);
+
+ // Inherited methods overridden:
+ void run(ITensorPack &tensors) override;
+
+private:
+ std::unique_ptr<ICPPKernel> _scatter_kernel{nullptr};
+ std::unique_ptr<ICPPKernel> _fill_kernel{nullptr};
+};
+} // namespace cpu
+} // namespace arm_compute
+#endif // ACL_SRC_CPU_OPERATORS_CPUSCATTER_H
diff --git a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
index fb9bc15212..881142c374 100644
--- a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
+++ b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
@@ -45,6 +45,7 @@ namespace
/** Run pretranspose_B_array in parallel (1D static scheduling)
*
* @tparam TypeInput
+ * @tparam TypeWeight
* @tparam TypeOutput
*
* @param[in] gemm_asm GemmCommon kernel to run
@@ -54,14 +55,14 @@ namespace
* @param[in] src_multi_stride Stride in z ("multi")
* @param[in] num_threads Number of threads to run this method. Must be >= 1
*/
-template <typename TypeInput, typename TypeOutput>
-void run_parallel_pretranspose_B_array(arm_gemm::GemmCommon<TypeInput, TypeOutput> *gemm_asm,
- ITensor *dst,
- const TypeInput *src,
- int src_ld,
- int src_multi_stride,
- unsigned int num_threads,
- bool transpose)
+template <typename TypeInput, typename TypeWeight, typename TypeOutput>
+void run_parallel_pretranspose_B_array(arm_gemm::GemmCommon<TypeInput, TypeWeight, TypeOutput> *gemm_asm,
+ ITensor *dst,
+ const TypeWeight *src,
+ int src_ld,
+ int src_multi_stride,
+ unsigned int num_threads,
+ bool transpose)
{
ARM_COMPUTE_ERROR_ON(gemm_asm == nullptr);
ARM_COMPUTE_ERROR_ON(num_threads == 0);
@@ -91,14 +92,6 @@ using namespace arm_compute::experimental;
namespace
{
-struct free_delete
-{
- void operator()(void *x)
- {
- free(x);
- }
-};
-
struct Params
{
unsigned int M;
@@ -113,14 +106,13 @@ struct Params
Params extract_parameters(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *d, const AsmGemmInfo &info)
{
ARM_COMPUTE_ERROR_ON_NULLPTR(a, b, d);
- Params p;
- p.M = d->tensor_shape().y();
- p.K = a->tensor_shape().x();
- p.N = d->tensor_shape().x();
- p.batches = 1;
- p.multis = 1;
- p.sections = 1;
- p.indirect = false;
+ Params p{/* M */ static_cast<unsigned int>(d->tensor_shape().y()),
+ /* N */ static_cast<unsigned int>(d->tensor_shape().x()),
+ /* K */ static_cast<unsigned int>(a->tensor_shape().x()),
+ /* batches */ 1,
+ /* multis */ 1,
+ /* sections */ 1,
+ /* indirect */ false};
if (info.method == AsmConvMethod::Conv || info.method == AsmConvMethod::Indirect)
{
@@ -172,13 +164,10 @@ IScheduler::Hints scheduling_hint_heuristic(arm_gemm::GemmMethod method, DataTyp
}
/** Fallback in case ACL doesn't have a function */
-template <typename TypeInput, typename TypeOutput, class OutputStage = arm_gemm::Nothing>
+template <typename TypeInput, typename TypeWeight, typename TypeOutput, class OutputStage = arm_gemm::Nothing>
class Fallback : public CpuGemmAssemblyDispatch::IFallback
{
public:
- /** Destructor */
- ~Fallback() = default;
-
/** Initialise the functions's input and output.
*
* @param[in] a Input tensor containing the Matrix A.
@@ -222,12 +211,45 @@ public:
bool isVarWeightsKernel() const override
{
if (!_gemm_kernel_asm)
+ {
return false;
+ }
const arm_compute::WeightFormat wf =
assembly_utils::map_to_arm_compute_weight_format(_gemm_kernel_asm->get_config().weight_format);
return wf != arm_compute::WeightFormat::UNSPECIFIED && wf != arm_compute::WeightFormat::ANY;
}
+ void update_quantization_parameters(const GEMMLowpOutputStageInfo &output_info,
+ const QuantizationInfo &a,
+ const QuantizationInfo &b,
+ const bool is_prepared,
+ const bool negated_offsets) override
+ {
+ const int32_t negation = negated_offsets ? 1 : -1;
+ const int32_t a_offset = -a.uniform().offset * negation;
+ const int32_t b_offset = -b.uniform().offset * negation;
+
+ arm_gemm::Requantize32 gemm_requant_info{};
+ if (output_info.gemmlowp_shifts.size() > 1)
+ {
+ const auto requantize_data =
+ this->set_requantize_data(output_info.gemmlowp_multipliers, output_info.gemmlowp_shifts);
+ gemm_requant_info = arm_gemm::Requantize32(
+ nullptr, 0, a_offset, b_offset, output_info.gemmlowp_offset,
+ (std::get<0>(requantize_data)) ? std::get<1>(requantize_data) : nullptr, std::get<2>(requantize_data),
+ std::get<3>(requantize_data), output_info.gemmlowp_min_bound, output_info.gemmlowp_max_bound);
+ }
+ else
+ {
+ gemm_requant_info = arm_gemm::Requantize32(nullptr, 0, a_offset, b_offset, output_info.gemmlowp_offset,
+ -output_info.gemmlowp_shift, output_info.gemmlowp_multiplier,
+ output_info.gemmlowp_min_bound, output_info.gemmlowp_max_bound);
+ }
+
+ _gemm_kernel_asm->update_quantization_parameters(gemm_requant_info);
+ _is_prepared = is_prepared;
+ }
+
private:
enum AuxTensorIdx
{
@@ -251,7 +273,7 @@ private:
    /** Operator to transpose B before gemm or pretranspose_B_array */
std::unique_ptr<CpuTranspose> _pre_pretranspose_b{nullptr};
/** Assembly Gemm kernel */
- std::shared_ptr<arm_gemm::GemmCommon<TypeInput, TypeOutput>> _gemm_kernel_asm{nullptr};
+ std::shared_ptr<arm_gemm::GemmCommon<TypeInput, TypeWeight, TypeOutput>> _gemm_kernel_asm{nullptr};
/** Optimised Arm® Neon™ kernel */
std::unique_ptr<INEKernel> _optimised_kernel{nullptr};
/** Assembly GEMM workspace tensor info */
@@ -273,22 +295,22 @@ private:
/** Per channel quantization multipliers */
std::vector<int32_t> _multipliers{};
/** Indirect buffer */
- std::unique_ptr<const TypeInput *const *, free_delete> _indirect_arg{};
- std::unique_ptr<const TypeInput *, free_delete> _indirect_buf{};
- std::vector<TypeInput> _indirect_pad{};
- arm_gemm::ConvolutionParameters _cp{};
- experimental::MemoryRequirements _aux_mem{Count};
- bool _B_pretranspose_required{false};
- bool _is_b_constant{true};
- bool _is_c_constant{true};
- bool _run_pre_pretranspose_b{false};
- bool _B_pre_pretranspose_required{false};
+ std::vector<const TypeInput *const *> _indirect_arg{};
+ std::vector<const TypeInput *> _indirect_buf{};
+ std::vector<TypeInput> _indirect_pad{};
+ arm_gemm::ConvolutionParameters _cp{};
+ experimental::MemoryRequirements _aux_mem{Count};
+ bool _B_pretranspose_required{false};
+ bool _is_b_constant{true};
+ bool _is_c_constant{true};
+ bool _run_pre_pretranspose_b{false};
+ bool _B_pre_pretranspose_required{false};
};
-template <typename TypeInput, typename TypeOutput, class OutputStage>
+template <typename TypeInput, typename TypeWeight, typename TypeOutput, class OutputStage>
std::tuple<bool, const int32_t *, const int32_t *, const int32_t *>
-Fallback<TypeInput, TypeOutput, OutputStage>::set_requantize_data(const std::vector<int32_t> &shifts,
- const std::vector<int32_t> &multipliers)
+Fallback<TypeInput, TypeWeight, TypeOutput, OutputStage>::set_requantize_data(const std::vector<int32_t> &shifts,
+ const std::vector<int32_t> &multipliers)
{
_multipliers = multipliers;
_shifts = shifts;
@@ -305,8 +327,8 @@ Fallback<TypeInput, TypeOutput, OutputStage>::set_requantize_data(const std::vec
return std::make_tuple(need_left, left_shifts.data(), right_shifts.data(), _multipliers.data());
}
-template <typename TypeInput, typename TypeOutput, class OutputStage>
-void Fallback<TypeInput, TypeOutput, OutputStage>::prepare_indirect_buffer(ITensorPack &tensors)
+template <typename TypeInput, typename TypeWeight, typename TypeOutput, class OutputStage>
+void Fallback<TypeInput, TypeWeight, TypeOutput, OutputStage>::prepare_indirect_buffer(ITensorPack &tensors)
{
auto a = tensors.get_const_tensor(TensorType::ACL_SRC_0);
const TypeInput *A_ptr = reinterpret_cast<TypeInput *>(a->buffer());
@@ -343,14 +365,12 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::prepare_indirect_buffer(ITens
if (input_x < 0 || input_x >= _cp.input_width || input_y < 0 || input_y >= _cp.input_height)
{
- _indirect_buf
- .get()[m * multi_stride + b * batch_stride + kernel_xy * output_hw + output_xy] =
+ _indirect_buf[m * multi_stride + b * batch_stride + kernel_xy * output_hw + output_xy] =
_indirect_pad.data();
}
else
{
- _indirect_buf
- .get()[m * multi_stride + b * batch_stride + kernel_xy * output_hw + output_xy] =
+ _indirect_buf[m * multi_stride + b * batch_stride + kernel_xy * output_hw + output_xy] =
A_ptr + (m * multi_stride_A + b * batch_stride_A + input_xy * stride_A);
}
}
@@ -361,11 +381,11 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::prepare_indirect_buffer(ITens
}
}
-template <typename TypeInput, typename TypeOutput, class OutputStage>
-void Fallback<TypeInput, TypeOutput, OutputStage>::configure_indirect(const ITensorInfo *a,
- const ITensorInfo *b,
- const ITensorInfo *d,
- const AsmGemmInfo &info)
+template <typename TypeInput, typename TypeWeight, typename TypeOutput, class OutputStage>
+void Fallback<TypeInput, TypeWeight, TypeOutput, OutputStage>::configure_indirect(const ITensorInfo *a,
+ const ITensorInfo *b,
+ const ITensorInfo *d,
+ const AsmGemmInfo &info)
{
ARM_COMPUTE_ERROR_ON(!(info.method == AsmConvMethod::Conv || info.method == AsmConvMethod::Indirect));
@@ -375,13 +395,13 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::configure_indirect(const ITen
zeropad = a->quantization_info().uniform().offset;
}
- const int64_t input_width = static_cast<int64_t>(a->tensor_shape()[1]);
- const int64_t input_height = static_cast<int64_t>(a->tensor_shape()[2]);
- const int64_t input_channels = static_cast<int64_t>(a->tensor_shape()[0]);
- const int64_t kernel_width = static_cast<int64_t>(b->tensor_shape()[2]);
- const int64_t kernel_height = static_cast<int64_t>(b->tensor_shape()[3]);
- const int64_t output_width = static_cast<int64_t>(d->tensor_shape()[1]);
- const int64_t output_height = static_cast<int64_t>(d->tensor_shape()[2]);
+ const auto input_width = static_cast<int64_t>(a->tensor_shape()[1]);
+ const auto input_height = static_cast<int64_t>(a->tensor_shape()[2]);
+ const auto input_channels = static_cast<int64_t>(a->tensor_shape()[0]);
+ const auto kernel_width = static_cast<int64_t>(b->tensor_shape()[2]);
+ const auto kernel_height = static_cast<int64_t>(b->tensor_shape()[3]);
+ const auto output_width = static_cast<int64_t>(d->tensor_shape()[1]);
+ const auto output_height = static_cast<int64_t>(d->tensor_shape()[2]);
_cp = {input_width,
input_height,
@@ -392,6 +412,8 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::configure_indirect(const ITen
output_height,
info.ps_info.stride().first,
info.ps_info.stride().second,
+ 1,
+ 1,
info.padding_top,
info.padding_left,
zeropad};
@@ -414,10 +436,8 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::configure_indirect(const ITen
const int multi_size = batch_size * batches;
const size_t multi_stride = multi_size / sizeof(TypeInputPtr);
- _indirect_buf = std::unique_ptr<const TypeInput *, free_delete>(
- reinterpret_cast<const TypeInput **>(malloc(multi_size * multis)));
- _indirect_arg = std::unique_ptr<const TypeInput *const *, free_delete>(
- reinterpret_cast<const TypeInput *const **>(malloc(sizeof(TypeInput **) * kernel_hw * multis * batches)));
+ _indirect_buf = std::vector<const TypeInput *>(multi_size * multis);
+ _indirect_arg = std::vector<const TypeInput *const *>(sizeof(TypeInput **) * kernel_hw * multis * batches);
_indirect_pad = std::vector<TypeInput>(_cp.input_channels, TypeInput(zeropad));
// Set indirect argument
@@ -428,29 +448,28 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::configure_indirect(const ITen
{
for (int64_t kernel_xy = 0; kernel_xy < kernel_hw; kernel_xy++)
{
- (_indirect_arg.get())[pos++] =
- _indirect_buf.get() + m * multi_stride + b * batch_stride + kernel_xy * output_hw;
+ _indirect_arg[pos++] = &_indirect_buf[m * multi_stride + b * batch_stride + kernel_xy * output_hw];
}
}
}
- _gemm_kernel_asm->set_indirect_parameters(a->tensor_shape()[0], _indirect_arg.get());
+ _gemm_kernel_asm->set_indirect_parameters(a->tensor_shape()[0], _indirect_arg.data());
}
}
-template <typename TypeInput, typename TypeOutput, class OutputStage>
-void Fallback<TypeInput, TypeOutput, OutputStage>::configure(const ITensorInfo *a,
- const ITensorInfo *b,
- const ITensorInfo *c,
- ITensorInfo *d,
- arm_gemm::GemmArgs args,
- const AsmGemmInfo &gemm_info,
- const OutputStage &os)
+template <typename TypeInput, typename TypeWeight, typename TypeOutput, class OutputStage>
+void Fallback<TypeInput, TypeWeight, TypeOutput, OutputStage>::configure(const ITensorInfo *a,
+ const ITensorInfo *b,
+ const ITensorInfo *c,
+ ITensorInfo *d,
+ arm_gemm::GemmArgs args,
+ const AsmGemmInfo &gemm_info,
+ const OutputStage &os)
{
_is_b_constant = b->are_values_constant();
_is_c_constant = c ? c->are_values_constant() : true;
- _gemm_kernel_asm = arm_gemm::gemm<TypeInput, TypeOutput, OutputStage>(args, os);
+ _gemm_kernel_asm = arm_gemm::gemm<TypeInput, TypeWeight, TypeOutput, OutputStage>(args, os);
if (_gemm_kernel_asm == nullptr)
{
        // Configuration not supported: leave the function unconfigured.
@@ -460,7 +479,7 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::configure(const ITensorInfo *
arm_gemm::GemmConfig gemm_cfg = _gemm_kernel_asm->get_config();
// arm_compute wrapper for the Gemm object (see above)
- auto acl_gemm_wrapper = std::make_unique<kernel::CpuGemmAssemblyWrapperKernel<TypeInput, TypeOutput>>();
+ auto acl_gemm_wrapper = std::make_unique<kernel::CpuGemmAssemblyWrapperKernel<TypeInput, TypeWeight, TypeOutput>>();
ARM_COMPUTE_ERROR_ON(acl_gemm_wrapper == nullptr);
acl_gemm_wrapper->configure(_gemm_kernel_asm.get(), gemm_cfg.filter);
const size_t workspace_size = _gemm_kernel_asm->get_working_size();
@@ -549,8 +568,8 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::configure(const ITensorInfo *
}
}
-template <typename TypeInput, typename TypeOutput, class OutputStage>
-void Fallback<TypeInput, TypeOutput, OutputStage>::prepare(ITensorPack &tensors)
+template <typename TypeInput, typename TypeWeight, typename TypeOutput, class OutputStage>
+void Fallback<TypeInput, TypeWeight, TypeOutput, OutputStage>::prepare(ITensorPack &tensors)
{
if (!_is_prepared)
{
@@ -588,17 +607,17 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::prepare(ITensorPack &tensors)
// Fixed format kernels need no pretranspose.
ARM_COMPUTE_ERROR_ON(arm_compute::is_fixed_format(
assembly_utils::map_to_arm_compute_weight_format(_gemm_kernel_asm->get_config().weight_format)));
- const int ldb = b_to_use->info()->strides_in_bytes().y() / b_to_use->info()->element_size();
- const auto in1_ptr = reinterpret_cast<const TypeInput *>(b_to_use->buffer() +
- b_to_use->info()->offset_first_element_in_bytes());
- const int multi_stride_b = b_to_use->info()->strides_in_bytes().z() / b_to_use->info()->element_size();
+ const int ldb = b_to_use->info()->strides_in_bytes().y() / b_to_use->info()->element_size();
+ const auto in1_ptr = reinterpret_cast<const TypeWeight *>(
+ b_to_use->buffer() + b_to_use->info()->offset_first_element_in_bytes());
+ const int multi_stride_b = b_to_use->info()->strides_in_bytes().z() / b_to_use->info()->element_size();
CpuAuxTensorHandler pretranspose(offset_int_vec(Pretranspose), _pretranspose_info, tensors, false);
ARM_COMPUTE_ERROR_ON(pretranspose.get()->buffer() == nullptr);
const bool kernel_supports_transpose = _gemm_kernel_asm->B_pretranspose_supports_transpose();
- run_parallel_pretranspose_B_array<TypeInput, TypeOutput>(
+ run_parallel_pretranspose_B_array<TypeInput, TypeWeight, TypeOutput>(
_gemm_kernel_asm.get(), pretranspose.get(), in1_ptr, ldb, multi_stride_b,
NEScheduler::get().num_threads(), _B_pre_pretranspose_required && kernel_supports_transpose);
@@ -616,20 +635,20 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::prepare(ITensorPack &tensors)
}
}
-template <typename TypeInput, typename TypeOutput, class OutputStage>
-bool Fallback<TypeInput, TypeOutput, OutputStage>::is_configured() const
+template <typename TypeInput, typename TypeWeight, typename TypeOutput, class OutputStage>
+bool Fallback<TypeInput, TypeWeight, TypeOutput, OutputStage>::is_configured() const
{
return _optimised_kernel != nullptr;
}
-template <typename TypeInput, typename TypeOutput, class OutputStage>
-experimental::MemoryRequirements Fallback<TypeInput, TypeOutput, OutputStage>::workspace() const
+template <typename TypeInput, typename TypeWeight, typename TypeOutput, class OutputStage>
+experimental::MemoryRequirements Fallback<TypeInput, TypeWeight, TypeOutput, OutputStage>::workspace() const
{
return _aux_mem;
}
-template <typename TypeInput, typename TypeOutput, class OutputStage>
-void Fallback<TypeInput, TypeOutput, OutputStage>::run(ITensorPack &tensors)
+template <typename TypeInput, typename TypeWeight, typename TypeOutput, class OutputStage>
+void Fallback<TypeInput, TypeWeight, TypeOutput, OutputStage>::run(ITensorPack &tensors)
{
auto a = tensors.get_const_tensor(TensorType::ACL_SRC_0);
auto b = tensors.get_const_tensor(TensorType::ACL_SRC_1);
@@ -663,8 +682,8 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::run(ITensorPack &tensors)
const int multi_stride_d = d->info()->strides_in_bytes()[d_multi_idx] / d->info()->element_size();
auto in0_ptr = reinterpret_cast<const TypeInput *>(a->buffer() + a->info()->offset_first_element_in_bytes());
- const TypeInput *in1_ptr = nullptr;
- auto out_ptr = reinterpret_cast<TypeOutput *>(d->buffer() + d->info()->offset_first_element_in_bytes());
+ const TypeWeight *in1_ptr = nullptr;
+ auto out_ptr = reinterpret_cast<TypeOutput *>(d->buffer() + d->info()->offset_first_element_in_bytes());
const ITensor *b_to_use = b;
@@ -686,8 +705,8 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::run(ITensorPack &tensors)
{
ldb = b_to_use->info()->strides_in_bytes().y() / b_to_use->info()->element_size();
multi_stride_b = b_to_use->info()->strides_in_bytes().z() / b_to_use->info()->element_size();
- in1_ptr =
- reinterpret_cast<const TypeInput *>(b_to_use->buffer() + b_to_use->info()->offset_first_element_in_bytes());
+ in1_ptr = reinterpret_cast<const TypeWeight *>(b_to_use->buffer() +
+ b_to_use->info()->offset_first_element_in_bytes());
}
// If necessary, run pretranspose every time if either weights or biases are non-constant
@@ -706,8 +725,8 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::run(ITensorPack &tensors)
ARM_COMPUTE_ERROR_ON(arm_compute::is_fixed_format(
assembly_utils::map_to_arm_compute_weight_format(_gemm_kernel_asm->get_config().weight_format)));
const int ldb = b_to_use->info()->strides_in_bytes().y() / b_to_use->info()->element_size();
- const auto b_ptr = reinterpret_cast<const TypeInput *>(b_to_use->buffer() +
- b_to_use->info()->offset_first_element_in_bytes());
+ const auto b_ptr = reinterpret_cast<const TypeWeight *>(b_to_use->buffer() +
+ b_to_use->info()->offset_first_element_in_bytes());
const int multi_stride_b = b_to_use->info()->strides_in_bytes().z() / b_to_use->info()->element_size();
CpuAuxTensorHandler pretranspose(offset_int_vec(Pretranspose), _pretranspose_info, tensors, true);
@@ -720,7 +739,7 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::run(ITensorPack &tensors)
else
{
const bool kernel_supports_transpose = _gemm_kernel_asm->B_pretranspose_supports_transpose();
- run_parallel_pretranspose_B_array<TypeInput, TypeOutput>(
+ run_parallel_pretranspose_B_array<TypeInput, TypeWeight, TypeOutput>(
_gemm_kernel_asm.get(), pretranspose.get(), b_ptr, ldb, multi_stride_b,
NEScheduler::get().num_threads(), _B_pre_pretranspose_required && kernel_supports_transpose);
}
@@ -744,7 +763,7 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::run(ITensorPack &tensors)
if (split_dim != IScheduler::split_dimensions_all)
{
// Make sure the kernel does not expect more threads than we can actually spawn
- const unsigned int num_iterations = _optimised_kernel.get()->window().num_iterations(split_dim);
+ const unsigned int num_iterations = _optimised_kernel->window().num_iterations(split_dim);
num_threads = std::min(num_iterations, num_threads);
}
_gemm_kernel_asm->set_nthreads(num_threads);
@@ -775,7 +794,7 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::run(ITensorPack &tensors)
NEScheduler::get().schedule(_optimised_kernel.get(), scheduling_hint);
}
-template <typename TypeInput, typename TypeOutput>
+template <typename TypeInput, typename TypeWeight, typename TypeOutput>
void create_arm_gemm(std::unique_ptr<CpuGemmAssemblyDispatch::IFallback> &arm_gemm,
const ITensorInfo *a,
const ITensorInfo *b,
@@ -794,12 +813,12 @@ void create_arm_gemm(std::unique_ptr<CpuGemmAssemblyDispatch::IFallback> &arm_ge
info.fixed_format, info.fast_mode, info.accumulate, &cfg);
// Create arm_gemm fallback
- auto fallback = std::make_unique<Fallback<TypeInput, TypeOutput>>();
+ auto fallback = std::make_unique<Fallback<TypeInput, TypeWeight, TypeOutput>>();
fallback->configure(a, b, c, d, args, info);
arm_gemm = std::move(fallback);
}
-template <typename TypeInput, typename TypeOutput>
+template <typename TypeInput, typename TypeWeight, typename TypeOutput>
void create_arm_gemm_dequant(std::unique_ptr<CpuGemmAssemblyDispatch::IFallback> &arm_gemm,
const ITensorInfo *a,
const ITensorInfo *b,
@@ -820,7 +839,7 @@ void create_arm_gemm_dequant(std::unique_ptr<CpuGemmAssemblyDispatch::IFallback>
info.fixed_format, info.fast_mode, info.accumulate, &cfg);
// Create arm_gemm fallback
- auto fallback = std::make_unique<Fallback<TypeInput, TypeOutput, arm_gemm::DequantizeFloat>>();
+ auto fallback = std::make_unique<Fallback<TypeInput, TypeWeight, TypeOutput, arm_gemm::DequantizeFloat>>();
// Configure requantization info
const GEMMLowpOutputStageInfo os_info = info.output_stage;
@@ -832,7 +851,7 @@ void create_arm_gemm_dequant(std::unique_ptr<CpuGemmAssemblyDispatch::IFallback>
arm_gemm = std::move(fallback);
}
-template <typename TypeInput, typename TypeOutput>
+template <typename TypeInput, typename TypeWeight, typename TypeOutput>
void create_arm_gemm_quant(std::unique_ptr<CpuGemmAssemblyDispatch::IFallback> &arm_gemm,
const ITensorInfo *a,
const ITensorInfo *b,
@@ -852,7 +871,7 @@ void create_arm_gemm_quant(std::unique_ptr<CpuGemmAssemblyDispatch::IFallback> &
info.fixed_format, info.fast_mode, info.accumulate, &cfg);
// Create arm_gemm fallback
- auto fallback = std::make_unique<Fallback<TypeInput, TypeOutput, arm_gemm::Requantize32>>();
+ auto fallback = std::make_unique<Fallback<TypeInput, TypeWeight, TypeOutput, arm_gemm::Requantize32>>();
// Configure requantization info
const int32_t negation = info.negated_offsets ? 1 : -1;
@@ -905,12 +924,12 @@ Status CpuGemmAssemblyDispatch::has_opt_impl(arm_compute::WeightFormat &expected
arm_gemm::WeightFormat arm_gemm_expected_wf = assembly_utils::map_to_arm_gemm_weight_format(expected_weight_format);
arm_gemm::GemmArgs args(&ci, p.M, p.N, p.K, p.sections, p.batches, p.multis, p.indirect, act, num_threads,
info.fixed_format, info.fast_mode, info.accumulate, &cfg);
- // TODO: Incorporate info.transpose_b COMPMID-6595
+ // TODO(COMPMID-6595): Incorporate info.transpose_b
switch (a->data_type())
{
case DataType::F32:
ARM_COMPUTE_RETURN_ERROR_ON_MSG(
- !(arm_gemm::has_opt_gemm<float, float, arm_gemm::Nothing>(arm_gemm_expected_wf, args, {})),
+ !(arm_gemm::has_opt_gemm<float, float, float, arm_gemm::Nothing>(arm_gemm_expected_wf, args, {})),
"We could not find an optimized kernel for F32 input");
break;
#ifdef __aarch64__
@@ -919,13 +938,22 @@ Status CpuGemmAssemblyDispatch::has_opt_impl(arm_compute::WeightFormat &expected
if (d->data_type() == DataType::S32)
{
ARM_COMPUTE_RETURN_ERROR_ON_MSG(
- !(arm_gemm::has_opt_gemm<uint8_t, uint32_t, arm_gemm::Nothing>(arm_gemm_expected_wf, args, {})),
+ !(arm_gemm::has_opt_gemm<uint8_t, uint8_t, uint32_t, arm_gemm::Nothing>(arm_gemm_expected_wf, args,
+ {})),
"We could not find an optimized kernel for U8/QASYMM8 input and U32 output");
}
+ else if (b->data_type() == DataType::QASYMM8_SIGNED)
+ {
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(
+ !(arm_gemm::has_opt_gemm<uint8_t, int8_t, uint8_t, arm_gemm::Requantize32>(arm_gemm_expected_wf,
+ args, {})),
+ "We could not find an optimized kernel for U8 input with S8 weights and U8 output");
+ }
else
{
ARM_COMPUTE_RETURN_ERROR_ON_MSG(
- !(arm_gemm::has_opt_gemm<uint8_t, uint8_t, arm_gemm::Requantize32>(arm_gemm_expected_wf, args, {})),
+ !(arm_gemm::has_opt_gemm<uint8_t, uint8_t, uint8_t, arm_gemm::Requantize32>(arm_gemm_expected_wf,
+ args, {})),
"We could not find an optimized kernel for U8 input and U8 output");
}
break;
@@ -934,13 +962,15 @@ Status CpuGemmAssemblyDispatch::has_opt_impl(arm_compute::WeightFormat &expected
if (d->data_type() == DataType::S32)
{
ARM_COMPUTE_RETURN_ERROR_ON_MSG(
- !(arm_gemm::has_opt_gemm<int8_t, int32_t, arm_gemm::Nothing>(arm_gemm_expected_wf, args, {})),
+ !(arm_gemm::has_opt_gemm<int8_t, int8_t, int32_t, arm_gemm::Nothing>(arm_gemm_expected_wf, args,
+ {})),
"We could not find an optimized kernel for S8/QASYMM8_SIGNED input and S32 output");
}
else
{
ARM_COMPUTE_RETURN_ERROR_ON_MSG(
- !(arm_gemm::has_opt_gemm<int8_t, int8_t, arm_gemm::Requantize32>(arm_gemm_expected_wf, args, {})),
+ !(arm_gemm::has_opt_gemm<int8_t, int8_t, int8_t, arm_gemm::Requantize32>(arm_gemm_expected_wf, args,
+ {})),
"We could not find an optimized kernel for S8 input and S8 output");
}
break;
@@ -952,13 +982,15 @@ Status CpuGemmAssemblyDispatch::has_opt_impl(arm_compute::WeightFormat &expected
if (d->data_type() == DataType::BFLOAT16)
{
ARM_COMPUTE_RETURN_ERROR_ON_MSG(
- !(arm_gemm::has_opt_gemm<bfloat16, bfloat16, arm_gemm::Nothing>(arm_gemm_expected_wf, args, {})),
+ !(arm_gemm::has_opt_gemm<bfloat16, bfloat16, bfloat16, arm_gemm::Nothing>(arm_gemm_expected_wf,
+ args, {})),
"We could not find an optimized kernel for BFLOAT16 input and BFLOAT16 output");
}
else
{
ARM_COMPUTE_RETURN_ERROR_ON_MSG(
- !(arm_gemm::has_opt_gemm<bfloat16, float, arm_gemm::Nothing>(arm_gemm_expected_wf, args, {})),
+ !(arm_gemm::has_opt_gemm<bfloat16, bfloat16, float, arm_gemm::Nothing>(arm_gemm_expected_wf, args,
+ {})),
"We could not find an optimized kernel for BFLOAT16 input and F32 output");
}
break;
@@ -968,7 +1000,8 @@ Status CpuGemmAssemblyDispatch::has_opt_impl(arm_compute::WeightFormat &expected
#if defined(ENABLE_FP16_KERNELS)
case DataType::F16:
ARM_COMPUTE_RETURN_ERROR_ON_MSG(
- !(arm_gemm::has_opt_gemm<float16_t, float16_t, arm_gemm::Nothing>(arm_gemm_expected_wf, args, {})),
+ !(arm_gemm::has_opt_gemm<float16_t, float16_t, float16_t, arm_gemm::Nothing>(arm_gemm_expected_wf, args,
+ {})),
"We could not find an optimized kernel for F16 input and F16 output");
break;
#endif /* ENABLE_FP16_KERNELS */
@@ -1009,7 +1042,7 @@ Status CpuGemmAssemblyDispatch::validate(
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_NOT_IN(a, DataType::F32);
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_NOT_IN(b, DataType::BFLOAT16);
}
- else
+ else if (!(a->data_type() == DataType::QASYMM8 && b->data_type() == DataType::QASYMM8_SIGNED))
{
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(a, b);
}
@@ -1024,12 +1057,13 @@ Status CpuGemmAssemblyDispatch::validate(
"Only U32 output supported for U8 input");
ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::S8 && d->data_type() != DataType::S32,
"Only S32 output supported for S8 input");
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::QASYMM8 &&
- (d->data_type() != DataType::QASYMM8 && d->data_type() != DataType::S32),
- "Only QASYMM8/S32 output supported for QASYMM8 input");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(
+ a->data_type() == DataType::QASYMM8 &&
+ (d->data_type() != DataType::QASYMM8 && d->data_type() != DataType::S32 && d->data_type() != DataType::F32),
+ "Only QASYMM8/S32/F32 output supported for QASYMM8 input");
arm_compute::WeightFormat expected_weight_format = arm_compute::WeightFormat::UNSPECIFIED;
const Status ret = CpuGemmAssemblyDispatch::has_opt_impl(expected_weight_format, a, b, c, d, info);
- if ((bool)ret && expected_weight_format != arm_compute::WeightFormat::ANY)
+ if (bool(ret) && expected_weight_format != arm_compute::WeightFormat::ANY)
{
// Correctness check: if the format expected by the kernel is
// not "any", make sure that the one found matches the format
@@ -1062,33 +1096,44 @@ void CpuGemmAssemblyDispatch::configure(
switch (a->data_type())
{
case DataType::F32:
- create_arm_gemm<float, float>(_arm_gemm, a, b, c, d, act, info);
+ create_arm_gemm<float, float, float>(_arm_gemm, a, b, c, d, act, info);
break;
#ifdef __aarch64__
case DataType::U8:
case DataType::QASYMM8:
- if (d->data_type() == DataType::S32)
+ if (b->data_type() == DataType::S8 || b->data_type() == DataType::QASYMM8_SIGNED)
{
- create_arm_gemm<uint8_t, uint32_t>(_arm_gemm, a, b, c, d, act, info);
+ if (d->data_type() == DataType::F32)
+ {
+ create_arm_gemm_dequant<uint8_t, int8_t, float>(_arm_gemm, a, b, c, d, act, info);
+ }
+ else
+ {
+ create_arm_gemm_quant<uint8_t, int8_t, uint8_t>(_arm_gemm, a, b, c, d, act, info);
+ }
+ }
+ else if (d->data_type() == DataType::S32)
+ {
+ create_arm_gemm<uint8_t, uint8_t, uint32_t>(_arm_gemm, a, b, c, d, act, info);
}
else
{
- create_arm_gemm_quant<uint8_t, uint8_t>(_arm_gemm, a, b, c, d, act, info);
+ create_arm_gemm_quant<uint8_t, uint8_t, uint8_t>(_arm_gemm, a, b, c, d, act, info);
}
break;
case DataType::S8:
case DataType::QASYMM8_SIGNED:
if (d->data_type() == DataType::S32)
{
- create_arm_gemm<int8_t, int32_t>(_arm_gemm, a, b, c, d, act, info);
+ create_arm_gemm<int8_t, int8_t, int32_t>(_arm_gemm, a, b, c, d, act, info);
}
else if (d->data_type() == DataType::F32)
{
- create_arm_gemm_dequant<int8_t, float>(_arm_gemm, a, b, c, d, act, info);
+ create_arm_gemm_dequant<int8_t, int8_t, float>(_arm_gemm, a, b, c, d, act, info);
}
else
{
- create_arm_gemm_quant<int8_t, int8_t>(_arm_gemm, a, b, c, d, act, info);
+ create_arm_gemm_quant<int8_t, int8_t, int8_t>(_arm_gemm, a, b, c, d, act, info);
}
break;
#endif /* __aarch64__ */
@@ -1096,17 +1141,17 @@ void CpuGemmAssemblyDispatch::configure(
case DataType::BFLOAT16:
if (d->data_type() == DataType::BFLOAT16)
{
- create_arm_gemm<bfloat16, bfloat16>(_arm_gemm, a, b, c, d, act, info);
+ create_arm_gemm<bfloat16, bfloat16, bfloat16>(_arm_gemm, a, b, c, d, act, info);
}
else
{
- create_arm_gemm<bfloat16, float>(_arm_gemm, a, b, c, d, act, info);
+ create_arm_gemm<bfloat16, bfloat16, float>(_arm_gemm, a, b, c, d, act, info);
}
break;
#endif /* defined(ARM_COMPUTE_ENABLE_BF16) */
#ifdef ENABLE_FP16_KERNELS
case DataType::F16:
- create_arm_gemm<float16_t, float16_t>(_arm_gemm, a, b, c, d, act, info);
+ create_arm_gemm<float16_t, float16_t, float16_t>(_arm_gemm, a, b, c, d, act, info);
break;
#endif /* ENABLE_FP16_KERNELS */
default:
@@ -1136,5 +1181,15 @@ experimental::MemoryRequirements CpuGemmAssemblyDispatch::workspace() const
ARM_COMPUTE_ERROR_ON(_arm_gemm == nullptr);
return _arm_gemm->workspace();
}
+
+void CpuGemmAssemblyDispatch::update_quantization_parameters(const GEMMLowpOutputStageInfo &output_info,
+ const QuantizationInfo &a,
+ const QuantizationInfo &b,
+ const bool is_prepared,
+ const bool negated_offsets)
+{
+ ARM_COMPUTE_ERROR_ON(_arm_gemm == nullptr);
+ _arm_gemm->update_quantization_parameters(output_info, a, b, is_prepared, negated_offsets);
+}
} // namespace cpu
} // namespace arm_compute
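
The update_quantization_parameters() entry point added above forwards run-time quantization info to the assembly Fallback, which rebuilds its arm_gemm::Requantize32 descriptor and can force re-preparation. A minimal sketch of how a caller might drive it is given below; it is not part of this patch, the helper name and tensor arguments are invented for illustration, and only the dispatch API visible in this diff plus ITensorInfo::quantization_info() are assumed.

#include "arm_compute/core/ITensorInfo.h"
#include "src/cpu/operators/internal/CpuGemmAssemblyDispatch.h"

// Hypothetical helper (illustrative only): re-derive the requantization
// parameters from tensors whose quantization info only became known, or
// changed, at run time, and push them into the already-configured dispatch.
void refresh_runtime_quantization(arm_compute::cpu::CpuGemmAssemblyDispatch  &dispatch,
                                  const arm_compute::ITensorInfo             *src,
                                  const arm_compute::ITensorInfo             *weights,
                                  const arm_compute::GEMMLowpOutputStageInfo &output_stage)
{
    // The dispatch forwards these to the Fallback, which rebuilds the
    // arm_gemm::Requantize32 descriptor (per-channel when gemmlowp_shifts has
    // more than one entry) and applies the negated a/b offsets.
    dispatch.update_quantization_parameters(output_stage,
                                            src->quantization_info(),
                                            weights->quantization_info(),
                                            /* is_prepared     */ false, // request re-preparation with the new offsets
                                            /* negated_offsets */ true);
}
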
diff --git a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.h b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.h
index 44c5c189a5..0b6f22d45a 100644
--- a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.h
+++ b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.h
@@ -28,6 +28,7 @@
#include "src/core/common/Macros.h"
#include "src/cpu/ICpuOperator.h"
+#include "src/cpu/kernels/assembly/arm_gemm.hpp"
namespace arm_compute
{
@@ -81,12 +82,17 @@ public:
class IFallback
{
public:
- virtual void run(ITensorPack &tensors) = 0;
- virtual void prepare(ITensorPack &tensors) = 0;
- virtual experimental::MemoryRequirements workspace() const = 0;
- virtual bool is_configured() const = 0;
- virtual bool isVarWeightsKernel() const = 0;
- virtual ~IFallback() = default;
+ virtual void run(ITensorPack &tensors) = 0;
+ virtual void prepare(ITensorPack &tensors) = 0;
+ virtual experimental::MemoryRequirements workspace() const = 0;
+ virtual bool is_configured() const = 0;
+ virtual bool isVarWeightsKernel() const = 0;
+ virtual void update_quantization_parameters(const GEMMLowpOutputStageInfo &,
+ const QuantizationInfo &,
+ const QuantizationInfo &,
+ const bool,
+ const bool) = 0;
+ virtual ~IFallback() = default;
};
public:
@@ -185,6 +191,12 @@ public:
return _arm_gemm && _arm_gemm->isVarWeightsKernel();
}
+ void update_quantization_parameters(const GEMMLowpOutputStageInfo &output_info,
+ const QuantizationInfo &a,
+ const QuantizationInfo &b,
+ const bool is_prepared,
+ const bool negated_offsets);
+
// Inherited methods overridden:
void prepare(ITensorPack &tensors) override;
void run(ITensorPack &tensors) override;
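
With the extra TypeWeight template parameter threaded through the dispatch, QASYMM8 activations can now be paired with QASYMM8_SIGNED weights (routed internally to Fallback<uint8_t, int8_t, uint8_t, arm_gemm::Requantize32>). The sketch below configures that mixed-signedness case; it is illustrative only: shapes, quantization values and output-stage numbers are placeholders, and the validate()/configure() signatures are assumed to match the surrounding file.

#include "arm_compute/core/TensorInfo.h"
#include "src/cpu/operators/internal/CpuGemmAssemblyDispatch.h"

using namespace arm_compute;

// Hypothetical setup (illustrative only): u8 activations (QASYMM8) multiplied
// by s8 weights (QASYMM8_SIGNED), producing a QASYMM8 result.
void configure_mixed_sign_gemm(cpu::CpuGemmAssemblyDispatch &dispatch)
{
    const TensorInfo a(TensorShape(64U, 16U), 1, DataType::QASYMM8, QuantizationInfo(0.5f, 10));
    const TensorInfo b(TensorShape(32U, 64U), 1, DataType::QASYMM8_SIGNED, QuantizationInfo(0.25f, 0));
    TensorInfo       d(TensorShape(32U, 16U), 1, DataType::QASYMM8, QuantizationInfo(1.0f, 5));

    cpu::AsmGemmInfo info{};
    info.output_stage.type                = GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT;
    info.output_stage.gemmlowp_multiplier = 1 << 30; // placeholder fixed-point multiplier
    info.output_stage.gemmlowp_shift      = 1;       // placeholder shift
    info.output_stage.gemmlowp_offset     = 5;
    info.output_stage.gemmlowp_min_bound  = 0;
    info.output_stage.gemmlowp_max_bound  = 255;

    if (bool(cpu::CpuGemmAssemblyDispatch::validate(&a, &b, nullptr, &d, info)))
    {
        dispatch.configure(&a, &b, nullptr, &d, info);
    }
}
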