diff options
Diffstat (limited to 'src')
52 files changed, 5203 insertions, 2134 deletions
diff --git a/src/BUILD.bazel b/src/BUILD.bazel index e3cac07de1..499e5642a6 100644 --- a/src/BUILD.bazel +++ b/src/BUILD.bazel @@ -119,6 +119,8 @@ filegroup( "cpu/kernels/lut/generic/sve2/u8.cpp", "cpu/kernels/softmax/generic/sme2/fp16.cpp", "cpu/kernels/softmax/generic/sme2/fp32.cpp", + "cpu/kernels/softmax/generic/sme2/qasymm8.cpp", + "cpu/kernels/softmax/generic/sme2/qasymm8_signed.cpp", "cpu/kernels/softmax/generic/sve2/impl.cpp"] + glob(["**/*.h", "**/*.hpp", @@ -816,9 +818,18 @@ filegroup( "cpu/kernels/pool3d/neon/fp32.cpp", "cpu/kernels/pool3d/neon/qasymm8.cpp", "cpu/kernels/pool3d/neon/qasymm8_signed.cpp", + "cpu/kernels/quantize/generic/neon/fp16.cpp", + "cpu/kernels/quantize/generic/neon/fp32.cpp", + "cpu/kernels/quantize/generic/neon/integer.cpp", + "cpu/kernels/quantize/generic/neon/vquantize.cpp", "cpu/kernels/range/generic/neon/fp16.cpp", "cpu/kernels/range/generic/neon/fp32.cpp", "cpu/kernels/range/generic/neon/integer.cpp", + "cpu/kernels/reduction_layer/generic/neon/fp16.cpp", + "cpu/kernels/reduction_layer/generic/neon/fp32.cpp", + "cpu/kernels/reduction_layer/generic/neon/integer.cpp", + "cpu/kernels/reduction_layer/generic/neon/qasymm8.cpp", + "cpu/kernels/reduction_layer/generic/neon/qasymm8_signed.cpp", "cpu/kernels/roialign/generic/neon/fp16.cpp", "cpu/kernels/roialign/generic/neon/fp32.cpp", "cpu/kernels/roialign/generic/neon/qasymm8.cpp", diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 984db79c18..8d63ab57a3 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -340,6 +340,8 @@ target_sources( cpu/kernels/lut/generic/sve2/u8.cpp cpu/kernels/softmax/generic/sme2/fp16.cpp cpu/kernels/softmax/generic/sme2/fp32.cpp + cpu/kernels/softmax/generic/sme2/qasymm8.cpp + cpu/kernels/softmax/generic/sme2/qasymm8_signed.cpp cpu/kernels/softmax/generic/sve2/impl.cpp ) @@ -807,9 +809,18 @@ target_sources( cpu/kernels/pool3d/neon/fp32.cpp cpu/kernels/pool3d/neon/qasymm8.cpp cpu/kernels/pool3d/neon/qasymm8_signed.cpp + 
cpu/kernels/quantize/generic/neon/fp16.cpp + cpu/kernels/quantize/generic/neon/fp32.cpp + cpu/kernels/quantize/generic/neon/integer.cpp + cpu/kernels/quantize/generic/neon/vquantize.cpp cpu/kernels/range/generic/neon/fp16.cpp cpu/kernels/range/generic/neon/fp32.cpp cpu/kernels/range/generic/neon/integer.cpp + cpu/kernels/reduction_layer/generic/neon/fp16.cpp + cpu/kernels/reduction_layer/generic/neon/fp32.cpp + cpu/kernels/reduction_layer/generic/neon/integer.cpp + cpu/kernels/reduction_layer/generic/neon/qasymm8.cpp + cpu/kernels/reduction_layer/generic/neon/qasymm8_signed.cpp cpu/kernels/roialign/generic/neon/fp16.cpp cpu/kernels/roialign/generic/neon/fp32.cpp cpu/kernels/roialign/generic/neon/qasymm8.cpp diff --git a/src/common/cpuinfo/CpuInfo.cpp b/src/common/cpuinfo/CpuInfo.cpp index 93f51e599a..809ab3e2c3 100644 --- a/src/common/cpuinfo/CpuInfo.cpp +++ b/src/common/cpuinfo/CpuInfo.cpp @@ -363,6 +363,8 @@ CpuInfo CpuInfo::build() isainfo.neon = get_hw_capability("hw.optional.neon"); isainfo.fp16 = get_hw_capability("hw.optional.neon_fp16"); isainfo.dot = get_hw_capability("hw.optional.arm.FEAT_DotProd"); + isainfo.bf16 = get_hw_capability("hw.optional.arm.FEAT_BF16"); + isainfo.i8mm = get_hw_capability("hw.optional.arm.FEAT_I8MM"); CpuInfo info(isainfo, cpus_model); return info; #elif defined(__aarch64__) && defined(_WIN64) /* #elif defined(__aarch64__) && defined(__APPLE__) */ diff --git a/src/core/CPP/CPPTypes.cpp b/src/core/CPP/CPPTypes.cpp index 9980db42f3..f6761f27b0 100644 --- a/src/core/CPP/CPPTypes.cpp +++ b/src/core/CPP/CPPTypes.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2022 Arm Limited. + * Copyright (c) 2018-2022, 2024 Arm Limited. 
* * SPDX-License-Identifier: MIT * @@ -28,6 +28,7 @@ #include "src/common/cpuinfo/CpuInfo.h" #include "src/common/cpuinfo/CpuIsaInfo.h" +#include "src/core/NEON/kernels/arm_gemm/utils.hpp" namespace arm_compute { @@ -135,4 +136,14 @@ unsigned int CPUInfo::get_L2_cache_size() const { return _impl->L2_cache_size; } + +unsigned long CPUInfo::get_sme2_vector_length() const +{ +#ifdef ARM_COMPUTE_ENABLE_SME2 + return arm_gemm::utils::sme::get_vector_length<int8_t>(); +#else // ARM_COMPUTE_ENABLE_SME2 + return 0; +#endif // ARM_COMPUTE_ENABLE_SME2 +} + } // namespace arm_compute diff --git a/src/core/NEON/NEAsymm.h b/src/core/NEON/NEAsymm.h index 5f4d08d0f6..b93e64a0ef 100644 --- a/src/core/NEON/NEAsymm.h +++ b/src/core/NEON/NEAsymm.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2020, 2023 Arm Limited. + * Copyright (c) 2017-2020, 2023-2024 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -21,8 +21,8 @@ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE * SOFTWARE. 
*/ -#ifndef ARM_COMPUTE_NEASYMM_H -#define ARM_COMPUTE_NEASYMM_H +#ifndef ACL_SRC_CORE_NEON_NEASYMM_H +#define ACL_SRC_CORE_NEON_NEASYMM_H #include "src/core/NEON/NEMath.h" #include "src/core/NEON/wrapper/intrinsics/intrinsics.h" @@ -637,10 +637,10 @@ inline int32x4x4_t vquantize_internal(const float32x4x4_t &qv, float scale, int3 const float32x4_t vinvscale = vdupq_n_f32(1.f / scale); const int32x4x4_t rf = {{ #ifdef __aarch64__ - vaddq_s32(vcvtaq_s32_f32(vmulq_f32(qv.val[0], vinvscale)), voffset), - vaddq_s32(vcvtaq_s32_f32(vmulq_f32(qv.val[1], vinvscale)), voffset), - vaddq_s32(vcvtaq_s32_f32(vmulq_f32(qv.val[2], vinvscale)), voffset), - vaddq_s32(vcvtaq_s32_f32(vmulq_f32(qv.val[3], vinvscale)), voffset), + vaddq_s32(vcvtnq_s32_f32(vmulq_f32(qv.val[0], vinvscale)), voffset), + vaddq_s32(vcvtnq_s32_f32(vmulq_f32(qv.val[1], vinvscale)), voffset), + vaddq_s32(vcvtnq_s32_f32(vmulq_f32(qv.val[2], vinvscale)), voffset), + vaddq_s32(vcvtnq_s32_f32(vmulq_f32(qv.val[3], vinvscale)), voffset), #else //__aarch64__ vaddq_s32(vcvtq_s32_f32(vmulq_f32(qv.val[0], vinvscale)), voffset), vaddq_s32(vcvtq_s32_f32(vmulq_f32(qv.val[1], vinvscale)), voffset), @@ -698,4 +698,4 @@ inline uint16x8x2_t vquantize_qasymm16(const float32x4x4_t &qv, const UniformQua } // namespace arm_compute #include "src/core/NEON/NEAsymm.inl" -#endif // ARM_COMPUTE_NEASYMM_H +#endif // ACL_SRC_CORE_NEON_NEASYMM_H diff --git a/src/core/NEON/kernels/NEReductionOperationKernel.cpp b/src/core/NEON/kernels/NEReductionOperationKernel.cpp index 455d604b3b..5380e6ccce 100644 --- a/src/core/NEON/kernels/NEReductionOperationKernel.cpp +++ b/src/core/NEON/kernels/NEReductionOperationKernel.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2023 Arm Limited. + * Copyright (c) 2017-2024 Arm Limited. 
* * SPDX-License-Identifier: MIT * @@ -31,1747 +31,221 @@ #include "arm_compute/core/utils/misc/ShapeCalculator.h" #include "arm_compute/core/Validate.h" +#include "src/core/common/Registrars.h" #include "src/core/CPP/Validate.h" #include "src/core/helpers/AutoConfiguration.h" #include "src/core/helpers/WindowHelpers.h" #include "src/core/NEON/INEKernel.h" -#include "src/core/NEON/NEMath.h" #include "src/core/NEON/wrapper/wrapper.h" -#include "support/SaturateCast.h" - -#include <arm_neon.h> +#include "src/cpu/kernels/reduction_layer/generic/neon/list.h" namespace arm_compute { -namespace -{ -// Helper function that calls vqmovun/vqmvn, vcombine and vstore, allows templating of RedOpYZW_quantized -template <typename T> -void combine_and_store(int16x8_t t1, int16x8_t t2, Iterator &output, int offset = 0) -{ - if (std::is_same<T, uint8_t>::value) - { - auto res = wrapper::vcombine(wrapper::vqmovun(t1), wrapper::vqmovun(t2)); - wrapper::vstore(output.ptr() + offset, res); - } - else - { - auto res = wrapper::vcombine(wrapper::vqmovn(t1), wrapper::vqmovn(t2)); - wrapper::vstore(reinterpret_cast<int8_t *>(output.ptr() + offset), res); - } -} - -template <typename T> -uint32x4x4_t calculate_index(uint32_t idx, T a, T b, uint32x4x4_t c, ReductionOperation op, int axis) -{ - uint32x4_t mask{0}; - if (op == ReductionOperation::ARG_IDX_MIN) - { - mask = wrapper::vcgt(b, a); - } - else - { - mask = wrapper::vclt(b, a); - } - - uint32x4_t vec_idx = {idx, idx + 1, idx + 2, idx + 3}; - if (axis != 0) - { - vec_idx = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{}); - } - uint32x4x4_t res = {{wrapper::vbsl(mask, vec_idx, c.val[0]), 0, 0, 0}}; - - return res; -} - -template <typename T> -uint32x4x4_t calculate_index_quantized(uint32_t idx, T a, T b, uint32x4x4_t c, ReductionOperation op, int axis) -{ - uint32x4x4_t mask{{0}}; - uint8x16_t mask_u8{0}; - if (op == ReductionOperation::ARG_IDX_MIN) - { - mask_u8 = wrapper::vcgt(b, a); - } - else - { - mask_u8 = wrapper::vclt(b, 
a); - } - auto wide_u16_1 = - wrapper::vorr(vshll_n_u8(wrapper::vgetlow(mask_u8), 8), wrapper::vmovl(wrapper::vgetlow(mask_u8))); - auto wide_u16_2 = - wrapper::vorr(vshll_n_u8(wrapper::vgethigh(mask_u8), 8), wrapper::vmovl(wrapper::vgethigh(mask_u8))); - mask.val[0] = - wrapper::vorr(vshll_n_u16(wrapper::vgetlow(wide_u16_1), 16), wrapper::vmovl(wrapper::vgetlow(wide_u16_1))); - mask.val[1] = - wrapper::vorr(vshll_n_u16(wrapper::vgethigh(wide_u16_1), 16), wrapper::vmovl(wrapper::vgethigh(wide_u16_1))); - mask.val[2] = - wrapper::vorr(vshll_n_u16(wrapper::vgetlow(wide_u16_2), 16), wrapper::vmovl(wrapper::vgetlow(wide_u16_2))); - mask.val[3] = - wrapper::vorr(vshll_n_u16(wrapper::vgethigh(wide_u16_2), 16), wrapper::vmovl(wrapper::vgethigh(wide_u16_2))); - - uint32x4x4_t vec_idx = {{{idx + 0, idx + 1, idx + 2, idx + 3}, - {idx + 4, idx + 5, idx + 6, idx + 7}, - {idx + 8, idx + 9, idx + 10, idx + 11}, - {idx + 12, idx + 13, idx + 14, idx + 15}}}; - if (axis != 0) - { - vec_idx.val[0] = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{}); - vec_idx.val[1] = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{}); - vec_idx.val[2] = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{}); - vec_idx.val[3] = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{}); - } - uint32x4x4_t res = { - {vbslq_u32(mask.val[0], vec_idx.val[0], c.val[0]), vbslq_u32(mask.val[1], vec_idx.val[1], c.val[1]), - vbslq_u32(mask.val[2], vec_idx.val[2], c.val[2]), vbslq_u32(mask.val[3], vec_idx.val[3], c.val[3])}}; - - return res; -} - -// Helper function to calculate the minimum value of the input vector. All the elements in the output vector contain the min value. 
-template <typename T> -inline typename std::enable_if< - std::is_same<T, float32x4_t>::value || std::is_same<T, int32x4_t>::value, - typename std::conditional<std::is_same<T, float32x4_t>::value, float32x2_t, int32x2_t>::type>::type -calculate_min(T in) -{ - auto pmin = wrapper::vpmin(wrapper::vgethigh(in), wrapper::vgetlow(in)); - return wrapper::vpmin(pmin, pmin); -} - -// Helper function to calculate the minimum value of the input vector. All the elements in the output vector contain the min value. -template <typename T> -inline typename std::enable_if< - std::is_same<T, uint8x16_t>::value || std::is_same<T, int8x16_t>::value, - typename std::conditional<std::is_same<T, uint8x16_t>::value, uint8x8_t, int8x8_t>::type>::type -calculate_min(T in) -{ - auto pmin = wrapper::vpmin(wrapper::vgethigh(in), wrapper::vgetlow(in)); - pmin = wrapper::vpmin(pmin, pmin); - pmin = wrapper::vpmin(pmin, pmin); - return wrapper::vpmin(pmin, pmin); -} - -// Helper function to calculate the maximum value of the input vector. All the elements in the output vector contain the max value. -template <typename T> -inline typename std::enable_if< - std::is_same<T, float32x4_t>::value || std::is_same<T, int32x4_t>::value, - typename std::conditional<std::is_same<T, float32x4_t>::value, float32x2_t, int32x2_t>::type>::type -calculate_max(T in) -{ - auto pmax = wrapper::vpmax(wrapper::vgethigh(in), wrapper::vgetlow(in)); - return wrapper::vpmax(pmax, pmax); -} - -// Helper function to calculate the maximum value of the input vector. All the elements in the output vector contain the max value. 
-template <typename T> -inline typename std::enable_if< - std::is_same<T, uint8x16_t>::value || std::is_same<T, int8x16_t>::value, - typename std::conditional<std::is_same<T, uint8x16_t>::value, uint8x8_t, int8x8_t>::type>::type -calculate_max(T in) -{ - auto pmax = wrapper::vpmax(wrapper::vgethigh(in), wrapper::vgetlow(in)); - pmax = wrapper::vpmax(pmax, pmax); - pmax = wrapper::vpmax(pmax, pmax); - return wrapper::vpmax(pmax, pmax); -} - -template <typename T> -uint32_t calculate_vector_index(uint32x4x4_t vec_res_idx, T vec_res_value, ReductionOperation op) -{ - uint32x4_t res_idx_mask{0}; - uint32x4_t mask_ones = vdupq_n_u32(0xFFFFFFFF); - - if (op == ReductionOperation::ARG_IDX_MIN) - { - auto pmin = calculate_min(vec_res_value); - auto mask = wrapper::vceq(vec_res_value, wrapper::vcombine(pmin, pmin)); - res_idx_mask = wrapper::vand(vec_res_idx.val[0], mask); - } - else - { - auto pmax = calculate_max(vec_res_value); - auto mask = wrapper::vceq(vec_res_value, wrapper::vcombine(pmax, pmax)); - res_idx_mask = wrapper::vand(vec_res_idx.val[0], mask); - } - - res_idx_mask = wrapper::vadd(res_idx_mask, mask_ones); - auto pmin = wrapper::vpmin(wrapper::vgethigh(res_idx_mask), wrapper::vgetlow(res_idx_mask)); - pmin = wrapper::vpmin(pmin, pmin); - uint32_t res = wrapper::vgetlane(pmin, 0); - - return (res - 0xFFFFFFFF); -} - -template <typename T> -uint32_t calculate_vector_index_quantized(uint32x4x4_t vec_res_idx, T vec_res_value, ReductionOperation op) -{ - uint32x4x4_t res_idx_mask{{0}}; - uint32x4_t mask_ones = vdupq_n_u32(0xFFFFFFFF); - uint8x16_t mask_u8{0}; - if (op == ReductionOperation::ARG_IDX_MIN) - { - auto pmin = calculate_min(vec_res_value); - mask_u8 = wrapper::vceq(vec_res_value, wrapper::vcombine(pmin, pmin)); - } - else - { - auto pmax = calculate_max(vec_res_value); - mask_u8 = wrapper::vceq(vec_res_value, wrapper::vcombine(pmax, pmax)); - } - - // Widen vectors - auto wide_u16_1 = - wrapper::vorr(vshll_n_u8(wrapper::vgetlow(mask_u8), 8), 
wrapper::vmovl(wrapper::vgetlow(mask_u8))); - auto wide_u16_2 = - wrapper::vorr(vshll_n_u8(wrapper::vgethigh(mask_u8), 8), wrapper::vmovl(wrapper::vgethigh(mask_u8))); - auto wide_u32_1 = - wrapper::vorr(vshll_n_u16(wrapper::vgetlow(wide_u16_1), 16), wrapper::vmovl(wrapper::vgetlow(wide_u16_1))); - auto wide_u32_2 = - wrapper::vorr(vshll_n_u16(wrapper::vgethigh(wide_u16_1), 16), wrapper::vmovl(wrapper::vgethigh(wide_u16_1))); - auto wide_u32_3 = - wrapper::vorr(vshll_n_u16(wrapper::vgetlow(wide_u16_2), 16), wrapper::vmovl(wrapper::vgetlow(wide_u16_2))); - auto wide_u32_4 = - wrapper::vorr(vshll_n_u16(wrapper::vgethigh(wide_u16_2), 16), wrapper::vmovl(wrapper::vgethigh(wide_u16_2))); - res_idx_mask.val[0] = wrapper::vand(vec_res_idx.val[0], wide_u32_1); - res_idx_mask.val[1] = wrapper::vand(vec_res_idx.val[1], wide_u32_2); - res_idx_mask.val[2] = wrapper::vand(vec_res_idx.val[2], wide_u32_3); - res_idx_mask.val[3] = wrapper::vand(vec_res_idx.val[3], wide_u32_4); - res_idx_mask.val[0] = wrapper::vadd(res_idx_mask.val[0], mask_ones); - res_idx_mask.val[1] = wrapper::vadd(res_idx_mask.val[1], mask_ones); - res_idx_mask.val[2] = wrapper::vadd(res_idx_mask.val[2], mask_ones); - res_idx_mask.val[3] = wrapper::vadd(res_idx_mask.val[3], mask_ones); - - uint32_t res = 0xFFFFFFFF; - int iter = 0; - do - { - auto pmin = wrapper::vpmin(wrapper::vgethigh(res_idx_mask.val[iter]), wrapper::vgetlow(res_idx_mask.val[iter])); - pmin = wrapper::vpmin(pmin, pmin); - res = std::min(wrapper::vgetlane(pmin, 0), res); - iter++; - } while (iter < 4); - - return (res - 0xFFFFFFFF); -} - -#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC -template <> -uint32x4x4_t -calculate_index(uint32_t idx, float16x8_t a, float16x8_t b, uint32x4x4_t c, ReductionOperation op, int axis) -{ - uint32x4x2_t mask{0}; - uint16x8_t mask_u16{0}; - if (op == ReductionOperation::ARG_IDX_MIN) - { - mask_u16 = wrapper::vcgt(b, a); - } - else - { - mask_u16 = wrapper::vclt(b, a); - } - mask.val[0] = 
wrapper::vmovl(wrapper::vgetlow(mask_u16)); - mask.val[1] = wrapper::vmovl(wrapper::vgethigh(mask_u16)); - uint32x4x2_t vec_idx = {{{idx + 0, idx + 1, idx + 2, idx + 3}, {idx + 4, idx + 5, idx + 6, idx + 7}}}; - if (axis != 0) - { - vec_idx.val[0] = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{}); - vec_idx.val[1] = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{}); - } - uint32x4x4_t res = {wrapper::vbsl(mask.val[0], vec_idx.val[0], c.val[0]), - wrapper::vbsl(mask.val[1], vec_idx.val[1], c.val[1]), 0, 0}; - - return res; -} - -// Helper function to calculate the minimum value of the input vector. All the elements in the output vector contain the min value. -inline float16x4_t calculate_min(float16x8_t in) -{ - auto pmin = wrapper::vpmin(wrapper::vgethigh(in), wrapper::vgetlow(in)); - pmin = wrapper::vpmin(pmin, pmin); - return wrapper::vpmin(pmin, pmin); -} -// Helper function to calculate the maximum value of the input vector. All the elements in the output vector contain the max value. 
-inline float16x4_t calculate_max(float16x8_t in) -{ - auto pmax = wrapper::vpmax(wrapper::vgethigh(in), wrapper::vgetlow(in)); - pmax = wrapper::vpmax(pmax, pmax); - return wrapper::vpmax(pmax, pmax); -} - -template <> -uint32_t calculate_vector_index(uint32x4x4_t vec_res_idx, float16x8_t vec_res_value, ReductionOperation op) -{ - uint32x4x2_t res_idx_mask{0}; - uint32x4_t mask_ones = vdupq_n_u32(0xFFFFFFFF); - uint16x8_t mask_u16; - if (op == ReductionOperation::ARG_IDX_MIN) - { - auto pmin = calculate_min(vec_res_value); - mask_u16 = wrapper::vceq(vec_res_value, wrapper::vcombine(pmin, pmin)); - } - else - { - auto pmax = calculate_max(vec_res_value); - mask_u16 = wrapper::vceq(vec_res_value, wrapper::vcombine(pmax, pmax)); - } - - // Widen vectors - auto wide_u32_1 = - wrapper::vorr(vshll_n_u16(wrapper::vgetlow(mask_u16), 8), wrapper::vmovl(wrapper::vgetlow(mask_u16))); - auto wide_u32_2 = - wrapper::vorr(vshll_n_u16(wrapper::vgethigh(mask_u16), 8), wrapper::vmovl(wrapper::vgethigh(mask_u16))); - res_idx_mask.val[0] = wrapper::vand(vec_res_idx.val[0], wide_u32_1); - res_idx_mask.val[1] = wrapper::vand(vec_res_idx.val[1], wide_u32_2); - res_idx_mask.val[0] = wrapper::vadd(res_idx_mask.val[0], mask_ones); - res_idx_mask.val[1] = wrapper::vadd(res_idx_mask.val[1], mask_ones); - - uint32_t res = 0xFFFFFFFF; - uint32_t iter = 0; - do - { - auto pmin = wrapper::vpmin(wrapper::vgethigh(res_idx_mask.val[iter]), wrapper::vgetlow(res_idx_mask.val[iter])); - pmin = wrapper::vpmin(pmin, pmin); - res = std::min(wrapper::vgetlane(pmin, 0), res); - iter++; - } while (iter < 2); - - return (res - 0xFFFFFFFF); -} -#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC - -template <class F> -class Reducer -{ -public: - static void reduceX(const Window &window, const ITensor *input, ITensor *output, F f, const ReductionOperation op) - { - // Set out window - Window out_window(window); - out_window.set(Window::DimX, Window::Dimension(0, 1, 1)); - - f(window, out_window, input, output, 
op); - } - static void reduceY(const Window &window, const ITensor *input, ITensor *output, F f, const ReductionOperation op) - { - // Set in window - Window in_window(window); - Window out_window(window); - - in_window.set(Window::DimY, Window::Dimension(0, 1, 1)); - out_window.set(Window::DimY, Window::Dimension(0, output->info()->dimension(1), output->info()->dimension(1))); - - f(in_window, out_window, input, output, 1, op); - } - static void reduceZ(const Window &window, const ITensor *input, ITensor *output, F f, const ReductionOperation op) - { - // Set in window - Window in_window(window); - Window out_window(window); - - in_window.set(Window::DimZ, Window::Dimension(0, 1, 1)); - out_window.set(Window::DimZ, Window::Dimension(0, output->info()->dimension(2), output->info()->dimension(2))); - - f(in_window, out_window, input, output, 2, op); - } - static void reduceW(const Window &window, const ITensor *input, ITensor *output, F f, const ReductionOperation op) - { - // Set in/out window - Window in_window(window); - Window out_window(window); - - in_window.set(3, Window::Dimension(0, 1, 1)); - out_window.set(3, Window::Dimension(0, 1, 1)); - - f(in_window, out_window, input, output, 3, op); - } -}; - -template <typename T, int S> -struct RedOpX -{ - /** SIMD vector tag type. 
*/ - using ExactTagType = typename wrapper::traits::neon_vector<T, S>::tag_type; - - inline void operator()( - const Window &in_window, Window &out_window, const ITensor *in, ITensor *out, const ReductionOperation op) - { - const size_t input_dim_0 = in->info()->dimension(0); - const int window_step_x = 16 / sizeof(T); - const auto window_start_x = static_cast<int>(in_window.x().start()); - const auto window_end_x = static_cast<int>(in_window.x().end()); - - Window in_win_no_pad = in_window; - in_win_no_pad.set(Window::DimX, Window::Dimension(0, 1, 1)); - - Iterator input(in, in_win_no_pad); - Iterator output(out, out_window); - - execute_window_loop( - in_win_no_pad, - [&](const Coordinates &) - { - const auto input_ptr = reinterpret_cast<const T *>(input.ptr()); - - auto init_res_value = static_cast<T>(0.f); - switch (op) - { - case ReductionOperation::ARG_IDX_MAX: - case ReductionOperation::ARG_IDX_MIN: - case ReductionOperation::MIN: - case ReductionOperation::MAX: - { - init_res_value = static_cast<T>(*input_ptr); - break; - } - case ReductionOperation::PROD: - { - init_res_value = static_cast<T>(1.f); - break; - } - default: - break; - } - auto vec_res_value = wrapper::vdup_n(init_res_value, ExactTagType{}); - uint32x4x4_t vec_res_idx{{0}}; - - // Compute window_step_x elements per iteration - int x = window_start_x; - for (; x <= (window_end_x - window_step_x); x += window_step_x) - { - const auto vec_elements = wrapper::vloadq(input_ptr + x); - switch (op) - { - case ReductionOperation::SUM_SQUARE: - vec_res_value = wrapper::vadd(wrapper::vmul(vec_elements, vec_elements), vec_res_value); - break; - case ReductionOperation::MEAN_SUM: - case ReductionOperation::SUM: - vec_res_value = wrapper::vadd(vec_elements, vec_res_value); - break; - case ReductionOperation::PROD: - vec_res_value = wrapper::vmul(vec_elements, vec_res_value); - break; - case ReductionOperation::ARG_IDX_MIN: - { - auto temp_vec_res_value = wrapper::vmin(vec_elements, vec_res_value); - 
vec_res_idx = calculate_index<decltype(vec_res_value)>(x, temp_vec_res_value, vec_res_value, - vec_res_idx, op, 0); - vec_res_value = temp_vec_res_value; - break; - } - case ReductionOperation::ARG_IDX_MAX: - { - auto temp_vec_res_value = wrapper::vmax(vec_elements, vec_res_value); - vec_res_idx = calculate_index<decltype(vec_res_value)>(x, temp_vec_res_value, vec_res_value, - vec_res_idx, op, 0); - vec_res_value = temp_vec_res_value; - break; - } - case ReductionOperation::MIN: - { - vec_res_value = wrapper::vmin(vec_elements, vec_res_value); - break; - } - case ReductionOperation::MAX: - { - vec_res_value = wrapper::vmax(vec_elements, vec_res_value); - break; - } - default: - ARM_COMPUTE_ERROR("Not supported"); - } - } - - switch (op) - { - case ReductionOperation::SUM: - case ReductionOperation::MEAN_SUM: - case ReductionOperation::SUM_SQUARE: - { -#ifdef ARM_COMPUTE_DEBUG_ENABLED - auto res = static_cast<T>(0.f); - for (int i = 0; i < S; ++i) - { - res += wrapper::vgetlane(vec_res_value, i); - } -#else // ARM_COMPUTE_DEBUG_ENABLED - auto carry_res = - wrapper::vpadd(wrapper::vgethigh(vec_res_value), wrapper::vgetlow(vec_res_value)); - for (int i = 0; i < S / 4; ++i) - { - carry_res = wrapper::vpadd(carry_res, carry_res); - } - auto res = wrapper::vgetlane(carry_res, 0); -#endif // ARM_COMPUTE_DEBUG_ENABLED - if (op == ReductionOperation::SUM_SQUARE) - { - // Compute left-over elements - for (; x < window_end_x; ++x) - { - res += (*(input_ptr + x)) * (*(input_ptr + x)); - } - } - else - { - // Compute left-over elements - for (; x < window_end_x; ++x) - { - res += *(input_ptr + x); - } - } - - if (op == ReductionOperation::MEAN_SUM) - { - res /= input_dim_0; - } - - *(reinterpret_cast<T *>(output.ptr())) = res; - break; - } - case ReductionOperation::PROD: - { - auto carry_res = - wrapper::vmul(wrapper::vgethigh(vec_res_value), wrapper::vgetlow(vec_res_value)); - T res = 1; - for (int i = 0; i < S / 2; ++i) - { - res *= wrapper::vgetlane(carry_res, i); - } - - 
// Compute left-over elements - for (; x < window_end_x; ++x) - { - res *= *(input_ptr + x); - } - - *(reinterpret_cast<T *>(output.ptr())) = res; - break; - } - case ReductionOperation::ARG_IDX_MIN: - { - auto idx = calculate_vector_index<decltype(vec_res_value)>(vec_res_idx, vec_res_value, op); - auto res = static_cast<T>(wrapper::vgetlane(calculate_min(vec_res_value), 0)); - - // Compute left-over elements - for (; x < window_end_x; ++x) - { - if (*(input_ptr + x) < res) - { - idx = x; - res = *(input_ptr + x); - } - } - *(reinterpret_cast<uint32_t *>(output.ptr())) = idx; - break; - } - case ReductionOperation::ARG_IDX_MAX: - { - auto idx = calculate_vector_index<decltype(vec_res_value)>(vec_res_idx, vec_res_value, op); - auto res = static_cast<T>(wrapper::vgetlane(calculate_max(vec_res_value), 0)); - - // Compute left-over elements - for (; x < window_end_x; ++x) - { - if (*(input_ptr + x) > res) - { - idx = x; - res = *(input_ptr + x); - } - } - *(reinterpret_cast<uint32_t *>(output.ptr())) = idx; - break; - } - case ReductionOperation::MIN: - { - auto res = static_cast<T>(wrapper::vgetlane(calculate_min(vec_res_value), 0)); - - // Compute left-over elements - for (; x < window_end_x; ++x) - { - res = *(input_ptr + x) < res ? *(input_ptr + x) : res; - } - *(reinterpret_cast<T *>(output.ptr())) = res; - break; - } - case ReductionOperation::MAX: - { - auto res = static_cast<T>(wrapper::vgetlane(calculate_max(vec_res_value), 0)); - - // Compute left-over elements - for (; x < window_end_x; ++x) - { - res = *(input_ptr + x) > res ? 
*(input_ptr + x) : res; - } - *(reinterpret_cast<T *>(output.ptr())) = res; - break; - } - default: - ARM_COMPUTE_ERROR("Not supported"); - } - }, - input, output); - } -}; - -template <typename T> -struct RedOpX_quantized -{ - inline void operator()( - const Window &in_window, Window &out_window, const ITensor *in, ITensor *out, const ReductionOperation op) - { - using PromotedType = typename wrapper::traits::promote<typename wrapper::traits::promote<T>::type>::type; - - const auto oq_info = out->info()->quantization_info().uniform(); - - const TensorInfo in_info = *(in->info()); - const UniformQuantizationInfo iq_info = in_info.quantization_info().uniform(); - - const int window_step_x = 16 / sizeof(T); - const auto window_start_x = static_cast<int>(in_window.x().start()); - const auto window_end_x = static_cast<int>(in_window.x().end()); - - Window in_win_no_pad = in_window; - in_win_no_pad.set(Window::DimX, Window::Dimension(0, 1, 1)); - - Iterator input(in, in_win_no_pad); - Iterator output(out, out_window); - - const auto in_offset = static_cast<float>(iq_info.offset); - const float in_scale = iq_info.scale; - - const auto out_offset = static_cast<float>(oq_info.offset); - const float out_scale = oq_info.scale; - - const auto num_elements = static_cast<float>(in_info.dimension(0)); - - const float A = in_scale / (out_scale * num_elements); - const float B = out_offset - (in_scale * in_offset) / (out_scale); - - execute_window_loop( - in_win_no_pad, - [&](const Coordinates &) - { - const auto input_ptr = reinterpret_cast<T *>(input.ptr()); - - auto vec_res_value1 = - wrapper::vdup_n(static_cast<PromotedType>(0.f), wrapper::traits::vector_128_tag{}); - auto vec_res_value2 = - wrapper::vdup_n(static_cast<PromotedType>(0.f), wrapper::traits::vector_128_tag{}); - auto vec_res_value3 = - wrapper::vdup_n(static_cast<PromotedType>(0.f), wrapper::traits::vector_128_tag{}); - auto vec_res_value4 = - wrapper::vdup_n(static_cast<PromotedType>(0.f), 
wrapper::traits::vector_128_tag{}); - - auto vec_res_value1_f = vdupq_n_f32(static_cast<float>(1.f)); - auto vec_res_value2_f = vdupq_n_f32(static_cast<float>(1.f)); - auto vec_res_value3_f = vdupq_n_f32(static_cast<float>(1.f)); - auto vec_res_value4_f = vdupq_n_f32(static_cast<float>(1.f)); - - typename wrapper::traits::neon_vector<T, 16>::type vec_res_value = {0}; - - if (op == ReductionOperation::ARG_IDX_MAX || op == ReductionOperation::ARG_IDX_MIN || - op == ReductionOperation::MIN || op == ReductionOperation::MAX) - { - vec_res_value = wrapper::vdup_n(*input_ptr, wrapper::traits::vector_128_tag{}); - } - - uint32x4x4_t vec_res_idx{{0}}; - // Compute window_step_x elements per iteration - int x = window_start_x; - for (; x <= (window_end_x - window_step_x); x += window_step_x) - { - const auto vec_elements = wrapper::vloadq(input_ptr + x); - switch (op) - { - case ReductionOperation::SUM: - case ReductionOperation::MEAN_SUM: - { - const auto temp16x8t_1 = wrapper::vmovl(wrapper::vgetlow(vec_elements)); - const auto temp16x8t_2 = wrapper::vmovl(wrapper::vgethigh(vec_elements)); - - const auto temp32x4t_1 = wrapper::vmovl(wrapper::vgetlow(temp16x8t_1)); - const auto temp32x4t_2 = wrapper::vmovl(wrapper::vgethigh(temp16x8t_1)); - const auto temp32x4t_3 = wrapper::vmovl(wrapper::vgetlow(temp16x8t_2)); - const auto temp32x4t_4 = wrapper::vmovl(wrapper::vgethigh(temp16x8t_2)); - - vec_res_value1 = wrapper::vadd(temp32x4t_1, vec_res_value1); - vec_res_value2 = wrapper::vadd(temp32x4t_2, vec_res_value2); - vec_res_value3 = wrapper::vadd(temp32x4t_3, vec_res_value3); - vec_res_value4 = wrapper::vadd(temp32x4t_4, vec_res_value4); - break; - } - case ReductionOperation::PROD: - { - const auto offset32x4f_4 = vdupq_n_f32(iq_info.offset); - const auto scale32x4f_4 = vdupq_n_f32(iq_info.scale); - - const auto temp16x8t_1 = wrapper::vmovl(wrapper::vgetlow(vec_elements)); - const auto temp16x8t_2 = wrapper::vmovl(wrapper::vgethigh(vec_elements)); - - const auto temp32x4t_1 = 
wrapper::vmovl(wrapper::vgetlow(temp16x8t_1)); - const auto temp32x4t_2 = wrapper::vmovl(wrapper::vgethigh(temp16x8t_1)); - const auto temp32x4t_3 = wrapper::vmovl(wrapper::vgetlow(temp16x8t_2)); - const auto temp32x4t_4 = wrapper::vmovl(wrapper::vgethigh(temp16x8t_2)); - - auto temp32x4f_1 = wrapper::vcvt<float>(temp32x4t_1); - auto temp32x4f_2 = wrapper::vcvt<float>(temp32x4t_2); - auto temp32x4f_3 = wrapper::vcvt<float>(temp32x4t_3); - auto temp32x4f_4 = wrapper::vcvt<float>(temp32x4t_4); - - //de-quantize vec_elements - temp32x4f_1 = vmulq_f32(vsubq_f32(temp32x4f_1, offset32x4f_4), scale32x4f_4); - temp32x4f_2 = vmulq_f32(vsubq_f32(temp32x4f_2, offset32x4f_4), scale32x4f_4); - temp32x4f_3 = vmulq_f32(vsubq_f32(temp32x4f_3, offset32x4f_4), scale32x4f_4); - temp32x4f_4 = vmulq_f32(vsubq_f32(temp32x4f_4, offset32x4f_4), scale32x4f_4); - - vec_res_value1_f = vmulq_f32(temp32x4f_1, vec_res_value1_f); - vec_res_value2_f = vmulq_f32(temp32x4f_2, vec_res_value2_f); - vec_res_value3_f = vmulq_f32(temp32x4f_3, vec_res_value3_f); - vec_res_value4_f = vmulq_f32(temp32x4f_4, vec_res_value4_f); - break; - } - case ReductionOperation::ARG_IDX_MIN: - { - auto temp_vec_res_value = wrapper::vmin(vec_elements, vec_res_value); - vec_res_idx = calculate_index_quantized<decltype(vec_res_value)>( - x, temp_vec_res_value, vec_res_value, vec_res_idx, op, 0); - vec_res_value = temp_vec_res_value; - break; - } - case ReductionOperation::ARG_IDX_MAX: - { - auto temp_vec_res_value = wrapper::vmax(vec_elements, vec_res_value); - vec_res_idx = calculate_index_quantized<decltype(vec_res_value)>( - x, temp_vec_res_value, vec_res_value, vec_res_idx, op, 0); - vec_res_value = temp_vec_res_value; - break; - } - case ReductionOperation::MIN: - { - vec_res_value = wrapper::vmin(vec_elements, vec_res_value); - break; - } - case ReductionOperation::MAX: - { - vec_res_value = wrapper::vmax(vec_elements, vec_res_value); - break; - } - default: - ARM_COMPUTE_ERROR("Not supported"); - } - } - - switch 
(op) - { - case ReductionOperation::ARG_IDX_MIN: - { - auto idx = - calculate_vector_index_quantized<decltype(vec_res_value)>(vec_res_idx, vec_res_value, op); - auto res = static_cast<T>(wrapper::vgetlane(calculate_min(vec_res_value), 0)); - - // Compute left-over elements - for (; x < window_end_x; ++x) - { - if (*(input_ptr + x) < res) - { - idx = x; - res = *(input_ptr + x); - } - } - *(reinterpret_cast<uint32_t *>(output.ptr())) = idx; - break; - } - case ReductionOperation::ARG_IDX_MAX: - { - auto idx = - calculate_vector_index_quantized<decltype(vec_res_value)>(vec_res_idx, vec_res_value, op); - auto res = static_cast<T>(wrapper::vgetlane(calculate_max(vec_res_value), 0)); - - // Compute left-over elements - for (; x < window_end_x; ++x) - { - if (*(input_ptr + x) > res) - { - idx = x; - res = *(input_ptr + x); - } - } - *(reinterpret_cast<uint32_t *>(output.ptr())) = idx; - break; - } - case ReductionOperation::MIN: - { - auto res = static_cast<T>(wrapper::vgetlane(calculate_min(vec_res_value), 0)); - - // Compute left-over elements - for (; x < window_end_x; ++x) - { - res = *(input_ptr + x) < res ? *(input_ptr + x) : res; - } - *(reinterpret_cast<T *>(output.ptr())) = res; - break; - } - case ReductionOperation::MAX: - { - auto res = static_cast<T>(wrapper::vgetlane(calculate_max(vec_res_value), 0)); - - // Compute left-over elements - for (; x < window_end_x; ++x) - { - res = *(input_ptr + x) > res ? 
*(input_ptr + x) : res; - } - *(reinterpret_cast<T *>(output.ptr())) = res; - break; - } - case ReductionOperation::PROD: - { - auto carry_res = wrapper::vmul(vec_res_value1_f, vec_res_value2_f); - carry_res = wrapper::vmul(carry_res, vec_res_value3_f); - carry_res = wrapper::vmul(carry_res, vec_res_value4_f); - - float res = wrapper::vgetlane(carry_res, 0); - res *= wrapper::vgetlane(carry_res, 1); - res *= wrapper::vgetlane(carry_res, 2); - res *= wrapper::vgetlane(carry_res, 3); - - // Compute left-over elements - for (; x < window_end_x; ++x) - { - //de-quantize input - if (std::is_same<T, uint8_t>::value) - { - res *= dequantize_qasymm8(*(input_ptr + x), iq_info); - } - else - { - res *= dequantize_qasymm8_signed(*(input_ptr + x), iq_info); - } - } - - //re-quantize result - if (std::is_same<T, uint8_t>::value) - { - res = quantize_qasymm8(res, iq_info); - } - else - { - res = quantize_qasymm8_signed(res, iq_info); - } - - *reinterpret_cast<T *>(output.ptr()) = static_cast<T>(res); - break; - } - case ReductionOperation::SUM: - case ReductionOperation::MEAN_SUM: - { - auto carry_res = wrapper::vadd(vec_res_value1, vec_res_value2); - carry_res = wrapper::vadd(carry_res, vec_res_value3); - carry_res = wrapper::vadd(carry_res, vec_res_value4); - - auto carry_paddition = - wrapper::vpadd(wrapper::vgethigh(carry_res), wrapper::vgetlow(carry_res)); - carry_paddition = wrapper::vpadd(carry_paddition, carry_paddition); - auto res = static_cast<int32_t>(wrapper::vgetlane(carry_paddition, 0)); - - // Compute left-over elements - for (; x < window_end_x; ++x) - { - res += *(input_ptr + x); - } - - if (op == ReductionOperation::MEAN_SUM) - { - const int32_t resFinal = A * (static_cast<float>(res)) + B; - - *reinterpret_cast<T *>(output.ptr()) = utils::cast::saturate_cast<T>(resFinal); - } - else - { - // Subtract accumulated offsets - res -= (in_info.dimension(0) - 1) * iq_info.offset; - *reinterpret_cast<T *>(output.ptr()) = utils::cast::saturate_cast<T>(res); - } - - 
break; - } - default: - ARM_COMPUTE_ERROR("Not supported"); - } - }, - input, output); - } -}; - -template <typename T, int S> -struct RedOpYZW -{ - /** SIMD vector tag type. */ - using ExactTagType = typename wrapper::traits::neon_vector<T, S>::tag_type; - using neon_vector = typename wrapper::traits::neon_vector<T, S>::type; - - inline void operator()(const Window &in_window, - Window &out_window, - const ITensor *in, - ITensor *out, - int axis, - const ReductionOperation op) - { - const TensorInfo in_info = *(in->info()); - const int window_step_x = 16 / sizeof(T); - const auto window_start_x_tmp = static_cast<int>(in_window.x().start()); - const auto window_end_x_tmp = static_cast<int>(in_window.x().end()); - // As it split over x-axis, need to set the correct spiltted window start and end. - const auto window_start_x = static_cast<int>(0); - const auto window_end_x = static_cast<int>(in_window.shape().x()); - - Window in_win_no_pad = in_window; - in_win_no_pad.set(Window::DimX, Window::Dimension(window_start_x_tmp, window_end_x_tmp, in_window.shape().x())); - Window out_win_no_pad = out_window; - out_win_no_pad.set(Window::DimX, - Window::Dimension(window_start_x_tmp, window_end_x_tmp, out_window.shape().x())); - - Iterator input(in, in_win_no_pad); - Iterator output(out, out_win_no_pad); - - execute_window_loop( - in_win_no_pad, - [&](const Coordinates &) - { - const auto input_ptr = reinterpret_cast<T *>(input.ptr()); - - // Compute window_step_x elements per iteration - int x = window_start_x; - for (; x <= (window_end_x - window_step_x); x += window_step_x) - { - neon_vector vec_res_value = {0}; - switch (op) - { - case ReductionOperation::ARG_IDX_MAX: - case ReductionOperation::ARG_IDX_MIN: - case ReductionOperation::MIN: - case ReductionOperation::MAX: - { - vec_res_value = wrapper::vloadq(input_ptr + x); - break; - } - case ReductionOperation::PROD: - { - vec_res_value = wrapper::vdup_n(static_cast<T>(1.f), ExactTagType{}); - break; - } - default: - { - 
vec_res_value = wrapper::vdup_n(static_cast<T>(0.f), ExactTagType{}); - break; - } - } - uint32x4x4_t vec_res_idx{{0}}; - - for (unsigned int dim = 0; dim < in_info.dimension(axis); ++dim) - { - const T *in_ptr = - reinterpret_cast<T *>(input.ptr() + x * sizeof(T) + in_info.strides_in_bytes()[axis] * dim); - const auto vec_elements = wrapper::vloadq(in_ptr); - switch (op) - { - case ReductionOperation::SUM: - case ReductionOperation::MEAN_SUM: - vec_res_value = wrapper::vadd(vec_elements, vec_res_value); - break; - case ReductionOperation::SUM_SQUARE: - vec_res_value = wrapper::vadd(wrapper::vmul(vec_elements, vec_elements), vec_res_value); - break; - case ReductionOperation::PROD: - vec_res_value = wrapper::vmul(vec_elements, vec_res_value); - break; - case ReductionOperation::ARG_IDX_MIN: - { - auto temp_vec_res_value = wrapper::vmin(vec_elements, vec_res_value); - vec_res_idx = - calculate_index(dim, temp_vec_res_value, vec_res_value, vec_res_idx, op, axis); - vec_res_value = temp_vec_res_value; - break; - } - case ReductionOperation::ARG_IDX_MAX: - { - auto temp_vec_res_value = wrapper::vmax(vec_elements, vec_res_value); - vec_res_idx = - calculate_index(dim, temp_vec_res_value, vec_res_value, vec_res_idx, op, axis); - vec_res_value = temp_vec_res_value; - break; - } - case ReductionOperation::MIN: - { - vec_res_value = wrapper::vmin(vec_elements, vec_res_value); - break; - } - case ReductionOperation::MAX: - { - vec_res_value = wrapper::vmax(vec_elements, vec_res_value); - break; - } - default: - ARM_COMPUTE_ERROR("Not supported"); - } - } - - if (op == ReductionOperation::MEAN_SUM) - { - auto vec_width_inv = - wrapper::vinv(wrapper::vdup_n(static_cast<T>(in_info.dimension(axis)), ExactTagType{})); - vec_res_value = wrapper::vmul(vec_res_value, vec_width_inv); - } - - if (op == ReductionOperation::ARG_IDX_MIN || op == ReductionOperation::ARG_IDX_MAX) - { - wrapper::vstore(reinterpret_cast<uint32_t *>(output.ptr()) + x, vec_res_idx.val[0]); -#ifdef 
__ARM_FEATURE_FP16_VECTOR_ARITHMETIC - if (std::is_same<T, float16_t>::value) - { - wrapper::vstore(reinterpret_cast<uint32_t *>(output.ptr()) + x + 4, vec_res_idx.val[1]); - } -#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC - } - else - { - wrapper::vstore(reinterpret_cast<T *>(output.ptr() + x * sizeof(T)), vec_res_value); - } - } - - // Compute left-over elements - for (; x < window_end_x; ++x) - { - auto res_value = 0.f; - switch (op) - { - case ReductionOperation::ARG_IDX_MAX: - case ReductionOperation::ARG_IDX_MIN: - case ReductionOperation::MIN: - case ReductionOperation::MAX: - { - res_value = *(input_ptr + x); - break; - } - case ReductionOperation::PROD: - { - res_value = static_cast<T>(1.f); - break; - } - default: - { - res_value = static_cast<T>(0.f); - break; - } - } - - uint32_t res_idx = 0; - for (unsigned int dim = 0; dim < in_info.dimension(axis); ++dim) - { - const T *in_ptr = - reinterpret_cast<T *>(input.ptr() + x * sizeof(T) + in_info.strides_in_bytes()[axis] * dim); - - switch (op) - { - case ReductionOperation::SUM: - case ReductionOperation::MEAN_SUM: - res_value += *in_ptr; - break; - case ReductionOperation::SUM_SQUARE: - res_value += *in_ptr * *in_ptr; - break; - case ReductionOperation::PROD: - res_value *= *in_ptr; - break; - case ReductionOperation::ARG_IDX_MIN: - { - if (*in_ptr < res_value) - { - res_value = *in_ptr; - res_idx = dim; - } - break; - } - case ReductionOperation::ARG_IDX_MAX: - { - if (*in_ptr > res_value) - { - res_value = *in_ptr; - res_idx = dim; - } - break; - } - case ReductionOperation::MIN: - { - res_value = *in_ptr < res_value ? *in_ptr : res_value; - break; - } - case ReductionOperation::MAX: - { - res_value = *in_ptr > res_value ? 
*in_ptr : res_value; - break; - } - default: - ARM_COMPUTE_ERROR("Not supported"); - } - } - - if (op == ReductionOperation::MEAN_SUM) - { - res_value /= in_info.dimension(axis); - } - - if (op == ReductionOperation::ARG_IDX_MIN || op == ReductionOperation::ARG_IDX_MAX) - { - *(reinterpret_cast<uint32_t *>(output.ptr()) + x) = res_idx; - } - else - { - *(reinterpret_cast<T *>(output.ptr() + x * sizeof(T))) = res_value; - } - } - }, - input, output); - } -}; - -template <typename T, int S, int axis, ReductionOperation op> -struct RedOpYZW_complex -{ - /** SIMD vector tag type. */ - using ExactTagType = typename wrapper::traits::neon_vector<T, S>::tag_type; - using neon_vector = typename wrapper::traits::neon_vector<T, S>::type; - - inline void operator()( - const Window &in_window, Window &out_window, const ITensor *in, ITensor *out, int, const ReductionOperation) - { - ARM_COMPUTE_ERROR_ON(axis != 2); - ARM_COMPUTE_ERROR_ON(op != ReductionOperation::SUM); - - const TensorInfo in_info = *(in->info()); - const size_t stride_z = in_info.strides_in_bytes()[axis]; - const int window_step_x = 16 / sizeof(T); - const auto window_start_x_tmp = static_cast<int>(in_window.x().start()); - const auto window_end_x_tmp = static_cast<int>(in_window.x().end()); - // As it split over x-axis, need to set the correct spiltted window start and end. 
- const auto window_start_x = static_cast<int>(0); - const auto window_end_x = static_cast<int>(in_window.shape().x()); - - Window in_win_no_pad = in_window; - in_win_no_pad.set(Window::DimX, Window::Dimension(window_start_x_tmp, window_end_x_tmp, in_window.shape().x())); - Window out_win_no_pad = out_window; - out_win_no_pad.set(Window::DimX, - Window::Dimension(window_start_x_tmp, window_end_x_tmp, out_window.shape().x())); - - Iterator input(in, in_win_no_pad); - Iterator output(out, out_win_no_pad); - - execute_window_loop( - in_win_no_pad, - [&](const Coordinates &) - { - // Compute window_step_x elements per iteration - int x = window_start_x; - for (; x <= (window_end_x - window_step_x); x += window_step_x) - { - neon_vector vec_res_value_0 = {0}; - neon_vector vec_res_value_1 = {0}; - - vec_res_value_0 = wrapper::vdup_n(static_cast<T>(0.f), ExactTagType{}); - vec_res_value_1 = wrapper::vdup_n(static_cast<T>(0.f), ExactTagType{}); - - T *out_ptr = reinterpret_cast<T *>(output.ptr() + 2 * x * sizeof(T)); - for (unsigned int dim = 0; dim < in_info.dimension(axis); ++dim) - { - T *in_ptr_0 = reinterpret_cast<T *>(input.ptr() + 2 * x * sizeof(T) + stride_z * dim); - T *in_ptr_1 = reinterpret_cast<T *>(input.ptr() + 2 * x * sizeof(T) + 16 + stride_z * dim); - - const auto vec_elements_0 = wrapper::vloadq(in_ptr_0); - const auto vec_elements_1 = wrapper::vloadq(in_ptr_1); - - vec_res_value_0 = wrapper::vadd(vec_elements_0, vec_res_value_0); - vec_res_value_1 = wrapper::vadd(vec_elements_1, vec_res_value_1); - } - - wrapper::vstore(out_ptr, vec_res_value_0); - wrapper::vstore(out_ptr + 4, vec_res_value_1); - } - - // Compute left-over elements - for (; x < window_end_x; ++x) - { - auto res_value_0 = 0.f; - auto res_value_1 = 0.f; - - T *out_ptr = reinterpret_cast<T *>(output.ptr() + 2 * x * sizeof(T)); - for (unsigned int dim = 0; dim < in_info.dimension(axis); ++dim) - { - T *in_ptr = reinterpret_cast<T *>(input.ptr() + 2 * x * sizeof(T) + stride_z * dim); - 
res_value_0 += *in_ptr; - res_value_1 += *(in_ptr + 1); - } - *out_ptr = res_value_0; - *(out_ptr + 1) = res_value_1; - } - }, - input, output); - } -}; - -template <typename T> -struct RedOpYZW_quantized -{ - inline void operator()(const Window &in_window, - Window &out_window, - const ITensor *in, - ITensor *out, - int axis, - const ReductionOperation op) - { - const TensorInfo in_info = *(in->info()); - const UniformQuantizationInfo iq_info = in_info.quantization_info().uniform(); - using PromotedType = typename wrapper::traits::promote<typename wrapper::traits::promote<T>::type>::type; - - const auto oq_info = out->info()->quantization_info().uniform(); - - const int window_step_x = 16 / sizeof(T); - const auto window_start_x_tmp = static_cast<int>(in_window.x().start()); - const auto window_end_x_tmp = static_cast<int>(in_window.x().end()); - // As it split over x-axis, need to set the correct spiltted window start and end. - const auto window_start_x = static_cast<int>(0); - const auto window_end_x = static_cast<int>(in_window.shape().x()); - - Window in_win_no_pad = in_window; - in_win_no_pad.set(Window::DimX, Window::Dimension(window_start_x_tmp, window_end_x_tmp, in_window.shape().x())); - Window out_win_no_pad = out_window; - out_win_no_pad.set(Window::DimX, - Window::Dimension(window_start_x_tmp, window_end_x_tmp, out_window.shape().x())); - - Iterator input(in, in_win_no_pad); - Iterator output(out, out_win_no_pad); - - using vector_type = - typename wrapper::traits::neon_bitvector<PromotedType, wrapper::traits::BitWidth::W128>::type; - using vector_type_f = typename wrapper::traits::neon_vector<float, 4>::type; - - vector_type vec_res_value1{}; - vector_type vec_res_value2{}; - vector_type vec_res_value3{}; - vector_type vec_res_value4{}; - - vector_type_f vec_res_value1_f{}; - vector_type_f vec_res_value2_f{}; - vector_type_f vec_res_value3_f{}; - vector_type_f vec_res_value4_f{}; - - const float in_offset = static_cast<float>(iq_info.offset); - const 
float in_scale = iq_info.scale; - - const float out_offset = static_cast<float>(oq_info.offset); - const float out_scale = oq_info.scale; - - const float num_elements = static_cast<float>(in_info.dimension(axis)); - - const float A = in_scale / (out_scale * num_elements); - const float B = out_offset - (in_scale * in_offset) / (out_scale); - - const auto vec_A = wrapper::vdup_n(static_cast<float>(A), wrapper::traits::vector_128_tag{}); - const auto vec_B = wrapper::vdup_n(static_cast<float>(B), wrapper::traits::vector_128_tag{}); - - execute_window_loop( - in_win_no_pad, - [&](const Coordinates &) - { - const auto input_ptr = reinterpret_cast<T *>(input.ptr()); - - // Compute window_step_x elements per iteration - int x = window_start_x; - for (; x <= (window_end_x - window_step_x); x += window_step_x) - { - uint32x4x4_t vec_res_idx{{0}}; - vec_res_value1 = wrapper::vdup_n(static_cast<PromotedType>(0), wrapper::traits::vector_128_tag{}); - vec_res_value2 = wrapper::vdup_n(static_cast<PromotedType>(0), wrapper::traits::vector_128_tag{}); - vec_res_value3 = wrapper::vdup_n(static_cast<PromotedType>(0), wrapper::traits::vector_128_tag{}); - vec_res_value4 = wrapper::vdup_n(static_cast<PromotedType>(0), wrapper::traits::vector_128_tag{}); - - vec_res_value1_f = wrapper::vdup_n(static_cast<float>(1), wrapper::traits::vector_128_tag{}); - vec_res_value2_f = wrapper::vdup_n(static_cast<float>(1), wrapper::traits::vector_128_tag{}); - vec_res_value3_f = wrapper::vdup_n(static_cast<float>(1), wrapper::traits::vector_128_tag{}); - vec_res_value4_f = wrapper::vdup_n(static_cast<float>(1), wrapper::traits::vector_128_tag{}); - - auto vec_res_value = wrapper::vloadq(input_ptr + x); - - for (unsigned int index_dim = 0; index_dim < in_info.dimension(axis); ++index_dim) - { - const T *in_ptr = input_ptr + x + in_info.strides_in_bytes()[axis] * index_dim; - const auto vec_elements = wrapper::vloadq(in_ptr); - switch (op) - { - case ReductionOperation::SUM: - case 
ReductionOperation::MEAN_SUM: - { - const auto temp16x8t_1 = wrapper::vmovl(wrapper::vgetlow(vec_elements)); - const auto temp16x8t_2 = wrapper::vmovl(wrapper::vgethigh(vec_elements)); - - const auto temp32x4t_1 = wrapper::vmovl(wrapper::vgetlow(temp16x8t_1)); - const auto temp32x4t_2 = wrapper::vmovl(wrapper::vgethigh(temp16x8t_1)); - const auto temp32x4t_3 = wrapper::vmovl(wrapper::vgetlow(temp16x8t_2)); - const auto temp32x4t_4 = wrapper::vmovl(wrapper::vgethigh(temp16x8t_2)); - - vec_res_value1 = wrapper::vadd(temp32x4t_1, vec_res_value1); - vec_res_value2 = wrapper::vadd(temp32x4t_2, vec_res_value2); - vec_res_value3 = wrapper::vadd(temp32x4t_3, vec_res_value3); - vec_res_value4 = wrapper::vadd(temp32x4t_4, vec_res_value4); - break; - } - case ReductionOperation::PROD: - { - const auto offset32x4f_4 = wrapper::vdup_n(static_cast<float>(iq_info.offset), - wrapper::traits::vector_128_tag{}); - const auto scale32x4f_4 = - wrapper::vdup_n(iq_info.scale, wrapper::traits::vector_128_tag{}); - - const auto temp16x8t_1 = wrapper::vmovl(wrapper::vgetlow(vec_elements)); - const auto temp16x8t_2 = wrapper::vmovl(wrapper::vgethigh(vec_elements)); - - const auto temp32x4t_1 = wrapper::vmovl(wrapper::vgetlow(temp16x8t_1)); - const auto temp32x4t_2 = wrapper::vmovl(wrapper::vgethigh(temp16x8t_1)); - const auto temp32x4t_3 = wrapper::vmovl(wrapper::vgetlow(temp16x8t_2)); - const auto temp32x4t_4 = wrapper::vmovl(wrapper::vgethigh(temp16x8t_2)); - - auto temp32x4f_1 = wrapper::vcvt<float>(temp32x4t_1); - auto temp32x4f_2 = wrapper::vcvt<float>(temp32x4t_2); - auto temp32x4f_3 = wrapper::vcvt<float>(temp32x4t_3); - auto temp32x4f_4 = wrapper::vcvt<float>(temp32x4t_4); - - //de-quantize vec_elements - temp32x4f_1 = wrapper::vmul(wrapper::vsub(temp32x4f_1, offset32x4f_4), scale32x4f_4); - temp32x4f_2 = wrapper::vmul(wrapper::vsub(temp32x4f_2, offset32x4f_4), scale32x4f_4); - temp32x4f_3 = wrapper::vmul(wrapper::vsub(temp32x4f_3, offset32x4f_4), scale32x4f_4); - temp32x4f_4 = 
wrapper::vmul(wrapper::vsub(temp32x4f_4, offset32x4f_4), scale32x4f_4); - - vec_res_value1_f = wrapper::vmul(temp32x4f_1, vec_res_value1_f); - vec_res_value2_f = wrapper::vmul(temp32x4f_2, vec_res_value2_f); - vec_res_value3_f = wrapper::vmul(temp32x4f_3, vec_res_value3_f); - vec_res_value4_f = wrapper::vmul(temp32x4f_4, vec_res_value4_f); - break; - } - case ReductionOperation::ARG_IDX_MIN: - { - auto temp_vec_res_value = wrapper::vmin(vec_elements, vec_res_value); - vec_res_idx = calculate_index_quantized(index_dim, temp_vec_res_value, vec_res_value, - vec_res_idx, op, axis); - vec_res_value = temp_vec_res_value; - break; - } - case ReductionOperation::ARG_IDX_MAX: - { - auto temp_vec_res_value = wrapper::vmax(vec_elements, vec_res_value); - vec_res_idx = calculate_index_quantized(index_dim, temp_vec_res_value, vec_res_value, - vec_res_idx, op, axis); - vec_res_value = temp_vec_res_value; - break; - } - case ReductionOperation::MIN: - { - vec_res_value = wrapper::vmin(vec_elements, vec_res_value); - break; - } - case ReductionOperation::MAX: - { - vec_res_value = wrapper::vmax(vec_elements, vec_res_value); - break; - } - default: - ARM_COMPUTE_ERROR("Not supported"); - } - } - - switch (op) - { - case ReductionOperation::ARG_IDX_MIN: - case ReductionOperation::ARG_IDX_MAX: - { - wrapper::vstore(reinterpret_cast<uint32_t *>(output.ptr() + 4 * x), vec_res_idx.val[0]); - wrapper::vstore(reinterpret_cast<uint32_t *>(output.ptr() + 4 * x) + 4, vec_res_idx.val[1]); - wrapper::vstore(reinterpret_cast<uint32_t *>(output.ptr() + 4 * x) + 8, vec_res_idx.val[2]); - wrapper::vstore(reinterpret_cast<uint32_t *>(output.ptr() + 4 * x) + 12, - vec_res_idx.val[3]); - break; - } - case ReductionOperation::MIN: - case ReductionOperation::MAX: - { - wrapper::vstore(reinterpret_cast<T *>(output.ptr() + x), vec_res_value); - break; - } - case ReductionOperation::SUM: - { - // Subtract offsets - auto offsets = vdupq_n_s32((in_info.dimension(axis) - 1) * iq_info.offset); - - auto 
vec_res_s_value1 = wrapper::vreinterpret(vec_res_value1); - auto vec_res_s_value2 = wrapper::vreinterpret(vec_res_value2); - auto vec_res_s_value3 = wrapper::vreinterpret(vec_res_value3); - auto vec_res_s_value4 = wrapper::vreinterpret(vec_res_value4); - vec_res_s_value1 = wrapper::vsub(vec_res_s_value1, offsets); - vec_res_s_value2 = wrapper::vsub(vec_res_s_value2, offsets); - vec_res_s_value3 = wrapper::vsub(vec_res_s_value3, offsets); - vec_res_s_value4 = wrapper::vsub(vec_res_s_value4, offsets); - - const auto temp16x8t_1 = - wrapper::vcombine(wrapper::vqmovn(vec_res_s_value1), wrapper::vqmovn(vec_res_s_value2)); - const auto temp16x8t_2 = - wrapper::vcombine(wrapper::vqmovn(vec_res_s_value3), wrapper::vqmovn(vec_res_s_value4)); - - combine_and_store<T>(temp16x8t_1, temp16x8t_2, output, x); - break; - } - case ReductionOperation::MEAN_SUM: - { - vec_res_value1_f = wrapper::vmla(vec_B, wrapper::vcvt<float>(vec_res_value1), vec_A); - vec_res_value2_f = wrapper::vmla(vec_B, wrapper::vcvt<float>(vec_res_value2), vec_A); - vec_res_value3_f = wrapper::vmla(vec_B, wrapper::vcvt<float>(vec_res_value3), vec_A); - vec_res_value4_f = wrapper::vmla(vec_B, wrapper::vcvt<float>(vec_res_value4), vec_A); - -#ifdef __aarch64__ - vec_res_value1 = wrapper::vcvta<PromotedType>(vec_res_value1_f); - vec_res_value2 = wrapper::vcvta<PromotedType>(vec_res_value2_f); - vec_res_value3 = wrapper::vcvta<PromotedType>(vec_res_value3_f); - vec_res_value4 = wrapper::vcvta<PromotedType>(vec_res_value4_f); -#else // defined(__aarch64__) - vec_res_value1 = wrapper::vcvt<PromotedType>(vec_res_value1_f); - vec_res_value2 = wrapper::vcvt<PromotedType>(vec_res_value2_f); - vec_res_value3 = wrapper::vcvt<PromotedType>(vec_res_value3_f); - vec_res_value4 = wrapper::vcvt<PromotedType>(vec_res_value4_f); -#endif // __aarch64__ - - const auto temp16x8t_1 = - wrapper::vcombine(wrapper::vqmovn(vec_res_value1), wrapper::vqmovn(vec_res_value2)); - const auto temp16x8t_2 = - 
wrapper::vcombine(wrapper::vqmovn(vec_res_value3), wrapper::vqmovn(vec_res_value4)); - auto res = wrapper::vcombine(wrapper::vqmovn(temp16x8t_1), wrapper::vqmovn(temp16x8t_2)); - - wrapper::vstore(reinterpret_cast<T *>(output.ptr() + x), res); - break; - } - case ReductionOperation::PROD: - { - const auto offset32x4f_4 = - wrapper::vdup_n(static_cast<float>(iq_info.offset), wrapper::traits::vector_128_tag{}); - const auto iscale32x4f_4 = vinvq_f32(vdupq_n_f32(iq_info.scale)); - - //re-quantize - vec_res_value1_f = - wrapper::vadd(wrapper::vmul(vec_res_value1_f, iscale32x4f_4), offset32x4f_4); - vec_res_value2_f = - wrapper::vadd(wrapper::vmul(vec_res_value2_f, iscale32x4f_4), offset32x4f_4); - vec_res_value3_f = - wrapper::vadd(wrapper::vmul(vec_res_value3_f, iscale32x4f_4), offset32x4f_4); - vec_res_value4_f = - wrapper::vadd(wrapper::vmul(vec_res_value4_f, iscale32x4f_4), offset32x4f_4); - - vec_res_value1 = wrapper::vcvt<T>(vec_res_value1_f); - vec_res_value2 = wrapper::vcvt<T>(vec_res_value2_f); - vec_res_value3 = wrapper::vcvt<T>(vec_res_value3_f); - vec_res_value4 = wrapper::vcvt<T>(vec_res_value4_f); - - const auto temp16x8t_1 = - wrapper::vcombine(wrapper::vqmovn(vec_res_value1), wrapper::vqmovn(vec_res_value2)); - const auto temp16x8t_2 = - wrapper::vcombine(wrapper::vqmovn(vec_res_value3), wrapper::vqmovn(vec_res_value4)); - auto res = wrapper::vcombine(wrapper::vqmovn(temp16x8t_1), wrapper::vqmovn(temp16x8t_2)); - - wrapper::vstore(reinterpret_cast<T *>(output.ptr() + x), res); - break; - } - default: - ARM_COMPUTE_ERROR("Not supported"); - } - } - - // Compute left-over elements - for (; x < window_end_x; ++x) - { - float res_value = 0.f; - int32_t res_value_q = 0; - - switch (op) - { - case ReductionOperation::ARG_IDX_MAX: - case ReductionOperation::ARG_IDX_MIN: - case ReductionOperation::MIN: - case ReductionOperation::MAX: - { - res_value = *(input_ptr + x); - break; - } - case ReductionOperation::PROD: - { - res_value = static_cast<T>(1.0f); - 
break; - } - default: - { - res_value = static_cast<T>(0.0f); - break; - } - } - uint32_t res_idx = 0; - - for (unsigned int dim = 0; dim < in_info.dimension(axis); ++dim) - { - const T *in_ptr = - reinterpret_cast<T *>(input.ptr() + x + in_info.strides_in_bytes()[axis] * dim); - switch (op) - { - case ReductionOperation::SUM: - { - res_value += *in_ptr; - break; - } - case ReductionOperation::MEAN_SUM: - { - res_value_q += *in_ptr; - break; - } - case ReductionOperation::SUM_SQUARE: - { - res_value += *in_ptr * *in_ptr; - break; - } - case ReductionOperation::PROD: - { - //de-quantize input - if (std::is_same<T, uint8_t>::value) - { - res_value *= dequantize_qasymm8(*in_ptr, iq_info); - } - else - { - res_value *= dequantize_qasymm8_signed(*in_ptr, iq_info); - } - break; - } - case ReductionOperation::ARG_IDX_MIN: - { - if (*in_ptr < res_value) - { - res_value = *in_ptr; - res_idx = dim; - } - break; - } - case ReductionOperation::ARG_IDX_MAX: - { - if (*in_ptr > res_value) - { - res_value = *in_ptr; - res_idx = dim; - } - break; - } - case ReductionOperation::MIN: - { - res_value = *in_ptr < res_value ? *in_ptr : res_value; - break; - } - case ReductionOperation::MAX: - { - res_value = *in_ptr > res_value ? 
*in_ptr : res_value; - break; - } - default: - ARM_COMPUTE_ERROR("Not supported"); - } - } - - switch (op) - { - case ReductionOperation::MEAN_SUM: - { - // Apply previously calculated coefficients (with rounding on aarch64) -#ifdef __aarch64__ - const int32_t res = - arm_compute::support::cpp11::round(A * (static_cast<float>(res_value_q)) + B); -#else // defined(__aarch64__) - const int32_t res = A * (static_cast<float>(res_value_q)) + B; -#endif // __aarch64__ - *reinterpret_cast<T *>(output.ptr() + x) = utils::cast::saturate_cast<T>(res); - break; - } - case ReductionOperation::SUM: - { - // Subtract accumulated offsets - res_value -= (in_info.dimension(axis) - 1) * iq_info.offset; - *reinterpret_cast<T *>(output.ptr() + x) = utils::cast::saturate_cast<T>(res_value); - break; - } - case ReductionOperation::PROD: - { - //re-quantize result - T res = 0; - if (std::is_same<T, uint8_t>::value) - { - res = quantize_qasymm8(res_value, iq_info); - } - else - { - res = quantize_qasymm8_signed(res_value, iq_info); - } - *(reinterpret_cast<T *>(output.ptr() + x)) = res; - break; - } - case ReductionOperation::ARG_IDX_MIN: - case ReductionOperation::ARG_IDX_MAX: - { - *(reinterpret_cast<uint32_t *>(output.ptr() + x * 4)) = res_idx; - break; - } - default: - *(reinterpret_cast<T *>(output.ptr() + x)) = res_value; - } - } - }, - input, output); - } -}; - -void reduce_op( - const Window &window, const ITensor *input, ITensor *output, unsigned int axis, const ReductionOperation op) +void NEReductionOperationKernel::reduce_op() { - const bool is_complex = (input->info()->num_channels() == 2); + const bool is_complex = (_input->info()->num_channels() == 2); if (is_complex) { - switch (axis) + switch (_reduction_axis) { case 2: - switch (input->info()->data_type()) + switch (_input->info()->data_type()) { case DataType::F32: - switch (op) + { + switch (_op) { case ReductionOperation::SUM: - return Reducer<RedOpYZW_complex<float, 4, 2, ReductionOperation::SUM>>::reduceZ( - window, 
input, output, RedOpYZW_complex<float, 4, 2, ReductionOperation::SUM>(), - op); + _func = REGISTER_FP32_NEON(cpu::reduce_RedOpYZW_complex_reduceZ_float32_4_2_SUM); + break; default: ARM_COMPUTE_ERROR("Not supported"); + break; } + break; + } default: + { ARM_COMPUTE_ERROR("Not supported"); + break; + } } + break; default: + { ARM_COMPUTE_ERROR("Not supported"); + break; + } } return; } - switch (axis) + switch (_reduction_axis) { case 0: { - switch (input->info()->data_type()) + switch (_input->info()->data_type()) { case DataType::QASYMM8: { - return Reducer<RedOpX_quantized<uint8_t>>::reduceX(window, input, output, - RedOpX_quantized<uint8_t>(), op); + _func = REGISTER_QASYMM8_NEON(cpu::reduce_RedOpX_reduceX_qasymm8); + break; } case DataType::QASYMM8_SIGNED: { - return Reducer<RedOpX_quantized<int8_t>>::reduceX(window, input, output, RedOpX_quantized<int8_t>(), - op); + _func = REGISTER_QASYMM8_SIGNED_NEON(cpu::reduce_RedOpX_reduceX_qasymm8_signed); + break; } -#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#ifdef ARM_COMPUTE_ENABLE_FP16 case DataType::F16: - return Reducer<RedOpX<float16_t, 8>>::reduceX(window, input, output, RedOpX<float16_t, 8>(), op); -#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + { + _func = REGISTER_FP16_NEON(cpu::reduce_RedOpX_reduceX_float16_8); + break; + } +#endif // ARM_COMPUTE_ENABLE_FP16 case DataType::F32: { - return Reducer<RedOpX<float, 4>>::reduceX(window, input, output, RedOpX<float, 4>(), op); + _func = REGISTER_FP32_NEON(cpu::reduce_RedOpX_reduceX_float32_4); + break; } case DataType::S32: { - return Reducer<RedOpX<int32_t, 4>>::reduceX(window, input, output, RedOpX<int32_t, 4>(), op); + _func = REGISTER_INTEGER_NEON(cpu::reduce_RedOpX_reduceX_S32_4); + break; } default: { ARM_COMPUTE_ERROR("Not supported"); + break; } } + break; } case 1: - switch (input->info()->data_type()) + { + switch (_input->info()->data_type()) { case DataType::QASYMM8: { - return Reducer<RedOpYZW_quantized<uint8_t>>::reduceY(window, input, output, - 
RedOpYZW_quantized<uint8_t>(), op); + _func = REGISTER_QASYMM8_NEON(cpu::reduce_RedOpYZW_reduceY_qasymm8); + break; } case DataType::QASYMM8_SIGNED: { - return Reducer<RedOpYZW_quantized<int8_t>>::reduceY(window, input, output, - RedOpYZW_quantized<int8_t>(), op); + _func = REGISTER_QASYMM8_SIGNED_NEON(cpu::reduce_RedOpYZW_reduceY_qasymm8_signed); + break; } -#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#ifdef ARM_COMPUTE_ENABLE_FP16 case DataType::F16: - return Reducer<RedOpYZW<float16_t, 8>>::reduceY(window, input, output, RedOpYZW<float16_t, 8>(), - op); -#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + { + _func = REGISTER_FP16_NEON(cpu::reduce_RedOpYZW_reduceY_float16_8); + break; + } +#endif // ARM_COMPUTE_ENABLE_FP16 case DataType::F32: - return Reducer<RedOpYZW<float, 4>>::reduceY(window, input, output, RedOpYZW<float, 4>(), op); + { + _func = REGISTER_FP32_NEON(cpu::reduce_RedOpYZW_reduceY_float32_4); + break; + } case DataType::S32: - return Reducer<RedOpYZW<int32_t, 4>>::reduceY(window, input, output, RedOpYZW<int32_t, 4>(), op); + { + _func = REGISTER_INTEGER_NEON(cpu::reduce_RedOpYZW_reduceY_S32_4); + break; + } default: + { ARM_COMPUTE_ERROR("Not supported"); + break; + } } + break; + } case 2: - switch (input->info()->data_type()) + { + switch (_input->info()->data_type()) { case DataType::QASYMM8: - return Reducer<RedOpYZW_quantized<uint8_t>>::reduceZ(window, input, output, - RedOpYZW_quantized<uint8_t>(), op); + { + _func = REGISTER_QASYMM8_NEON(cpu::reduce_RedOpYZW_reduceZ_qasymm8); + break; + } case DataType::QASYMM8_SIGNED: - return Reducer<RedOpYZW_quantized<int8_t>>::reduceZ(window, input, output, - RedOpYZW_quantized<int8_t>(), op); -#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + { + _func = REGISTER_QASYMM8_SIGNED_NEON(cpu::reduce_RedOpYZW_reduceZ_qasymm8_signed); + break; + } +#ifdef ARM_COMPUTE_ENABLE_FP16 case DataType::F16: - return Reducer<RedOpYZW<float16_t, 8>>::reduceZ(window, input, output, RedOpYZW<float16_t, 8>(), - op); -#endif // 
__ARM_FEATURE_FP16_VECTOR_ARITHMETIC + { + _func = REGISTER_FP16_NEON(cpu::reduce_RedOpYZW_reduceZ_float16_8); + break; + } +#endif // ARM_COMPUTE_ENABLE_FP16 case DataType::F32: - return Reducer<RedOpYZW<float, 4>>::reduceZ(window, input, output, RedOpYZW<float, 4>(), op); + { + _func = REGISTER_FP32_NEON(cpu::reduce_RedOpYZW_reduceZ_float32_4); + break; + } case DataType::S32: - return Reducer<RedOpYZW<int32_t, 4>>::reduceZ(window, input, output, RedOpYZW<int32_t, 4>(), op); + { + _func = REGISTER_INTEGER_NEON(cpu::reduce_RedOpYZW_reduceZ_S32_4); + break; + } default: + { + std::cout << int(_input->info()->data_type()) << std::endl; ARM_COMPUTE_ERROR("Not supported"); + break; + } } + break; + } case 3: - switch (input->info()->data_type()) + { + switch (_input->info()->data_type()) { case DataType::QASYMM8: - return Reducer<RedOpYZW_quantized<uint8_t>>::reduceW(window, input, output, - RedOpYZW_quantized<uint8_t>(), op); + { + _func = REGISTER_QASYMM8_NEON(cpu::reduce_RedOpYZW_reduceW_qasymm8); + break; + } case DataType::QASYMM8_SIGNED: - return Reducer<RedOpYZW_quantized<int8_t>>::reduceW(window, input, output, - RedOpYZW_quantized<int8_t>(), op); -#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + { + _func = REGISTER_QASYMM8_SIGNED_NEON(cpu::reduce_RedOpYZW_reduceW_qasymm8_signed); + break; + } +#ifdef ARM_COMPUTE_ENABLE_FP16 case DataType::F16: - return Reducer<RedOpYZW<float16_t, 8>>::reduceW(window, input, output, RedOpYZW<float16_t, 8>(), - op); -#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + { + _func = REGISTER_FP16_NEON(cpu::reduce_RedOpYZW_reduceW_float16_8); + break; + } +#endif // ARM_COMPUTE_ENABLE_FP16 case DataType::F32: - return Reducer<RedOpYZW<float, 4>>::reduceW(window, input, output, RedOpYZW<float, 4>(), op); + { + _func = REGISTER_FP32_NEON(cpu::reduce_RedOpYZW_reduceW_float32_4); + break; + } case DataType::S32: - return Reducer<RedOpYZW<int32_t, 4>>::reduceW(window, input, output, RedOpYZW<int32_t, 4>(), op); + { + _func = 
REGISTER_INTEGER_NEON(cpu::reduce_RedOpYZW_reduceW_S32_4); + break; + } default: + { ARM_COMPUTE_ERROR("Not supported"); + break; + } } + break; + } default: + { ARM_COMPUTE_ERROR("Unsupported reduction axis"); + break; + } } } @@ -1819,10 +293,9 @@ Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, u return Status{}; } -} // namespace NEReductionOperationKernel::NEReductionOperationKernel() - : _input(nullptr), _output(nullptr), _reduction_axis(0), _op(ReductionOperation::SUM_SQUARE) + : _func(nullptr), _input(nullptr), _output(nullptr), _reduction_axis(0), _op(ReductionOperation::SUM_SQUARE) { } @@ -1856,6 +329,8 @@ void NEReductionOperationKernel::configure(const ITensor *input, .set_data_type(output_data_type) .reset_padding() .set_is_resizable(true)); + // Determine the reduction function + NEReductionOperationKernel::reduce_op(); } Status NEReductionOperationKernel::validate(const ITensorInfo *input, @@ -1874,6 +349,6 @@ void NEReductionOperationKernel::run(const Window &window, const ThreadInfo &inf ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this); ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window); - reduce_op(window, _input, _output, _reduction_axis, _op); + (*_func)(window, _input, _output, _op); } } // namespace arm_compute diff --git a/src/core/NEON/kernels/NEReductionOperationKernel.h b/src/core/NEON/kernels/NEReductionOperationKernel.h index 78bec62c14..407e5de6d6 100644 --- a/src/core/NEON/kernels/NEReductionOperationKernel.h +++ b/src/core/NEON/kernels/NEReductionOperationKernel.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2021 Arm Limited. + * Copyright (c) 2017-2021, 2024 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -21,8 +21,8 @@ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE * SOFTWARE. 
*/ -#ifndef ARM_COMPUTE_NEREDUCTIONOPERATIONKERNEL_H -#define ARM_COMPUTE_NEREDUCTIONOPERATIONKERNEL_H +#ifndef ACL_SRC_CORE_NEON_KERNELS_NEREDUCTIONOPERATIONKERNEL_H +#define ACL_SRC_CORE_NEON_KERNELS_NEREDUCTIONOPERATIONKERNEL_H #include "src/core/NEON/INEKernel.h" @@ -80,14 +80,24 @@ public: static Status validate(const ITensorInfo *input, const ITensorInfo *output, unsigned int axis, ReductionOperation op); +private: // Inherited methods overridden: void run(const Window &window, const ThreadInfo &info) override; + /** Common signature for all the specialized Reduction functions + * + * @param[in] window Region on which to execute the kernel. + */ + using ReductionFunction = void (*)(const Window &window, const ITensor *in, ITensor *out, ReductionOperation op); -private: + /** Populate the _func with the right reduction operation handler + */ + void reduce_op(); + + ReductionFunction _func; const ITensor *_input; ITensor *_output; unsigned int _reduction_axis; ReductionOperation _op; }; } // namespace arm_compute -#endif /*ARM_COMPUTE_NEREDUCTIONOPERATIONKERNEL_H */ +#endif // ACL_SRC_CORE_NEON_KERNELS_NEREDUCTIONOPERATIONKERNEL_H diff --git a/src/core/NEON/kernels/NEReorderKernel.cpp b/src/core/NEON/kernels/NEReorderKernel.cpp index f5bea3e163..fe8882f59f 100644 --- a/src/core/NEON/kernels/NEReorderKernel.cpp +++ b/src/core/NEON/kernels/NEReorderKernel.cpp @@ -27,6 +27,7 @@ #include "arm_compute/core/Helpers.h" #include "arm_compute/core/Validate.h" +#include "arm_compute/runtime/Scheduler.h" #include "src/common/utils/Log.h" #include "src/core/NEON/kernels/arm_gemm/transform.hpp" @@ -233,13 +234,20 @@ Status NEReorderKernel::validate(const ITensorInfo *input, } } - int ksize; + int ksize = 0; switch (output_wf) { #if defined(ARM_COMPUTE_ENABLE_SVE) case WeightFormat::OHWIo8: { - ksize = 8; + if (Scheduler::get().cpu_info().has_sve() && arm_gemm::utils::get_vector_length<float>() == 8) + { + ksize = 8; + } + else + { + ARM_COMPUTE_RETURN_ERROR_MSG("Unsupported 
weight format."); + } break; } #endif /* ARM_COMPUTE_ENABLE_SVE */ diff --git a/src/core/NEON/kernels/arm_gemm/gemm_bf16.cpp b/src/core/NEON/kernels/arm_gemm/gemm_bf16.cpp index 5c08e6137d..0ddca04846 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_bf16.cpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_bf16.cpp @@ -86,7 +86,7 @@ static const GemmImplementation<bfloat16, float> gemm_bf16_methods[] = "sme2_interleaved_nomerge_bf16fp32_mopa_1VLx4VL", [](const GemmArgs &args) { return args._ci->has_sme2(); }, [](const GemmArgs &args) { const auto VL = sme::get_vector_length<float>(); - return args._Msize <= VL || (2*VL < args._Msize && args._Msize <= 3*VL); }, + return args._Nsize >= 8*VL || args._Msize <= VL || (2*VL < args._Msize && args._Msize <= 3*VL); }, [](const GemmArgs &args) { return new GemmInterleavedNoMerge<cls_sme2_interleaved_nomerge_bf16fp32_mopa_1VLx4VL, bfloat16, float>(args); } }, { diff --git a/src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp b/src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp index 3b444ae333..c7adf8e4ac 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp @@ -69,19 +69,19 @@ static const GemmImplementation<__fp16, __fp16> gemm_fp16_methods[] = { }, { GemmMethod::GEMM_INTERLEAVED, - "sme2_interleaved_nomerge_fp16fp32fp16_mopa_4VLx1VL", + "sme2_interleaved_nomerge_fp16fp32fp16_mopa_1VLx4VL", [](const GemmArgs &args) { return args._ci->has_sme2(); }, [](const GemmArgs &args) { const auto VL = sme::get_vector_length<float>(); - return args._Nsize <= VL || (2*VL < args._Nsize && args._Nsize <= 3*VL); }, - [](const GemmArgs &args) { return new GemmInterleaved<cls_sme2_interleaved_nomerge_fp16fp32fp16_mopa_4VLx1VL, __fp16, __fp16, Nothing, false, false, false, true>(args); } + return args._Nsize >= 8*VL || args._Msize <= VL || (2*VL < args._Msize && args._Msize <= 3*VL); }, + [](const GemmArgs &args) { return new GemmInterleaved<cls_sme2_interleaved_nomerge_fp16fp32fp16_mopa_1VLx4VL, __fp16, 
__fp16, Nothing, false, false, false, true>(args); } }, { GemmMethod::GEMM_INTERLEAVED, - "sme2_interleaved_nomerge_fp16fp32fp16_mopa_1VLx4VL", + "sme2_interleaved_nomerge_fp16fp32fp16_mopa_4VLx1VL", [](const GemmArgs &args) { return args._ci->has_sme2(); }, [](const GemmArgs &args) { const auto VL = sme::get_vector_length<float>(); - return args._Msize <= VL || (2*VL < args._Msize && args._Msize <= 3*VL); }, - [](const GemmArgs &args) { return new GemmInterleaved<cls_sme2_interleaved_nomerge_fp16fp32fp16_mopa_1VLx4VL, __fp16, __fp16, Nothing, false, false, false, true>(args); } + return args._Nsize <= VL || (2*VL < args._Nsize && args._Nsize <= 3*VL); }, + [](const GemmArgs &args) { return new GemmInterleaved<cls_sme2_interleaved_nomerge_fp16fp32fp16_mopa_4VLx1VL, __fp16, __fp16, Nothing, false, false, false, true>(args); } }, { GemmMethod::GEMM_INTERLEAVED, diff --git a/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp b/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp index af0d38ec37..0c1d3a387b 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp @@ -141,7 +141,7 @@ GemmImplementation<float, float>::with_estimate( "sme2_interleaved_nomerge_bf16fp32_mopa_1VLx4VL", [](const GemmArgs &args) { return args._fast_mode && args._ci->has_sme2() && !args._accumulate; }, [](const GemmArgs &args) { const auto VL = sme::get_vector_length<float>(); - return args._Msize <= VL || (2*VL < args._Msize && args._Msize <= 3*VL); }, + return args._Nsize >= 8*VL || args._Msize <= VL || (2*VL < args._Msize && args._Msize <= 3*VL); }, [](const GemmArgs &args) { return new GemmInterleavedNoMerge<cls_sme2_interleaved_nomerge_bf16fp32_mopa_1VLx4VL, float, float>(args); } }, #endif // ARM_COMPUTE_ENABLE_BF16 @@ -150,7 +150,7 @@ GemmImplementation<float, float>::with_estimate( "sme2_interleaved_nomerge_fp32_mopa_1VLx4VL", [](const GemmArgs &args) { return args._ci->has_sme2() && !args._accumulate; }, [](const GemmArgs &args) { const auto VL = 
sme::get_vector_length<float>(); - return args._Msize <= VL || (2*VL < args._Msize && args._Msize <= 3*VL); }, + return args._Nsize >= 8*VL || args._Msize <= VL || (2*VL < args._Msize && args._Msize <= 3*VL); }, [](const GemmArgs &args) { return new GemmInterleavedNoMerge<cls_sme2_interleaved_nomerge_fp32_mopa_1VLx4VL, float, float>(args); } }, #ifdef ARM_COMPUTE_ENABLE_BF16 @@ -199,14 +199,14 @@ GemmImplementation<float, float>::with_estimate( GemmImplementation<float, float>::with_estimate( GemmMethod::GEMM_HYBRID, "sve_hybrid_fp32bf16fp32_mmla_6x4VL", - [](const GemmArgs &args) { return args._fast_mode && args._ci->has_bf16(); }, + [](const GemmArgs &args) { return args._fast_mode && args._ci->has_svebf16(); }, [](const GemmArgs &args) { return GemmHybridIndirect<cls_sve_hybrid_fp32bf16fp32_mmla_6x4VL, float, float>::estimate_cycles<float>(args); }, [](const GemmArgs &args) { return new GemmHybridIndirect<cls_sve_hybrid_fp32bf16fp32_mmla_6x4VL, float, float>(args); } ), GemmImplementation<float, float>::with_estimate( GemmMethod::GEMM_HYBRID, "sve_hybrid_fp32bf16fp32_mmla_4x6VL", - [](const GemmArgs &args) { return args._fast_mode && args._ci->has_bf16(); }, + [](const GemmArgs &args) { return args._fast_mode && args._ci->has_svebf16(); }, [](const GemmArgs &args) { return GemmHybridIndirect<cls_sve_hybrid_fp32bf16fp32_mmla_4x6VL, float, float>::estimate_cycles<float>(args); }, [](const GemmArgs &args) { return new GemmHybridIndirect<cls_sve_hybrid_fp32bf16fp32_mmla_4x6VL, float, float>(args); } ), diff --git a/src/core/NEON/kernels/arm_gemm/gemm_int8.cpp b/src/core/NEON/kernels/arm_gemm/gemm_int8.cpp index 0dc0d55b27..fedda3a47a 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_int8.cpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_int8.cpp @@ -63,7 +63,7 @@ static const GemmImplementation<int8_t, int32_t> gemm_s8_methods[] = { "sme2_interleaved_nomerge_s8s32_mopa_1VLx4VL", [](const GemmArgs &args) { return args._ci->has_sme2(); }, [](const GemmArgs &args) { const 
auto VL = sme::get_vector_length<int32_t>(); - return args._Msize <= VL || (2*VL < args._Msize && args._Msize <= 3*VL); }, + return args._Nsize >= 8*VL || args._Msize <= VL || (2*VL < args._Msize && args._Msize <= 3*VL); }, [](const GemmArgs &args) { return new GemmInterleavedNoMerge<cls_sme2_interleaved_nomerge_s8s32_mopa_1VLx4VL, int8_t, int32_t>(args); } }, { diff --git a/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp b/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp index ae344f09b5..897ec9d05f 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp @@ -190,10 +190,19 @@ void kernel_and_merge<false, false, Requantize32>::run( auto p=prof.ScopedProfiler(PROFILE_KERNEL, (m_max - m_0) * (n_max - n_0) * kern_k); #endif + // Offset C pointer in a similar way to non-quantized case above. + Tri *offset_c_ptr; + + if (c_ptr == nullptr) { + offset_c_ptr = nullptr; + } else { + offset_c_ptr = c_ptr + m_0 * ldc + n_0; + } + strat.kernel(// A and B pointers are just the packed panels. a_ptr, b_panel, // Provide relevant part of output array and row stride. - c_ptr + m_0 * ldc + n_0, ldc, + offset_c_ptr, ldc, // M, N, K sizes m_max-m_0, n_max - n_0, kern_k, // Bias, activation, accumulation. Need to offset the bias as needed. @@ -663,15 +672,27 @@ class GemmInterleaved : public GemmCommon<To, Tr> { return roundup(args._cfg->inner_block_size, strategy::k_unroll()); } - // K blocking not supported if we are requantizing. - if (std::is_same<OutputStage, Requantize32>::value) { + // K blocking not supported if we are requantizing with the merging + // kernels. 
+ if (std::is_same<OutputStage, Requantize32>::value && MergeStep) { return get_ktotal(args); } + const unsigned int L1_size = args._ci->get_L1_cache_size(); + // Special blocking for SME if (is_sme<strategy>::value) { - // Don't bother to block below this size threshold, experimentally determined to be 320 for FP32 - unsigned int scaling_threshold = 1280 / sizeof(Toi); + // Target 512 bytes for 64kB L1, or 1024 bytes for 128kB L1. + unsigned int target_bytes_per_block = L1_size / 128; + + // Default cache size in gemm-linux is 32kB though - so make + // sure minimum is 512 + if (target_bytes_per_block < 512) { + target_bytes_per_block = 512; + } + + // Don't bother to block below this size threshold (1.25X target size) + unsigned int scaling_threshold = ((target_bytes_per_block * 5) / 4) / sizeof(Toi); if (get_ktotal(args) <= scaling_threshold) { return get_ktotal(args); @@ -679,7 +700,7 @@ class GemmInterleaved : public GemmCommon<To, Tr> { // Once we are blocking, this (lower) threshold determines when we should use more blocks // NOTE: Could be that some factor-based solution would work better here. - unsigned int max_block_size = 1024 / sizeof(Toi); + unsigned int max_block_size = target_bytes_per_block / sizeof(Toi); unsigned int num_k_blocks = iceildiv(get_ktotal(args), max_block_size); @@ -688,7 +709,6 @@ class GemmInterleaved : public GemmCommon<To, Tr> { return k_block; } - const unsigned int L1_size = args._ci->get_L1_cache_size(); unsigned int k_block; // k_block: Find out how much of the larger array can be loaded into half the cache. @@ -723,6 +743,17 @@ class GemmInterleaved : public GemmCommon<To, Tr> { return roundup(args._cfg->outer_block_size, strategy::out_width()); } + // Special blocking for SME + if (is_sme<strategy>::value) { + // If total width is less than 4x kernel width, return the entire width. 
+ if (args._Nsize < strategy::out_width()*4) { + return roundup(args._Nsize, strategy::out_width()); + } + + // Otherwise block to single kernel width. + return strategy::out_width(); + } + unsigned int x_block; const unsigned int L2_size = args._ci->get_L2_cache_size(); const unsigned int k_block = get_k_block_size(args); diff --git a/src/core/NEON/kernels/arm_gemm/gemm_qint8.cpp b/src/core/NEON/kernels/arm_gemm/gemm_qint8.cpp index d1c4e49edb..321c97262f 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_qint8.cpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_qint8.cpp @@ -82,7 +82,7 @@ static const GemmImplementation<int8_t, int8_t, Requantize32> gemm_qint8_methods "sme2_interleaved_nomerge_s8q_mopa_1VLx4VL", [](const GemmArgs &args, const Requantize32 &qp) { return args._ci->has_sme2() && ((qp.per_channel_requant && (qp.per_channel_left_shifts == nullptr)) || (!qp.per_channel_requant && (qp.per_layer_left_shift == 0)));}, [](const GemmArgs &args, const Requantize32 &) { const auto VL = sme::get_vector_length<int32_t>(); - return args._Msize <= VL || (2*VL < args._Msize && args._Msize <= 3*VL); }, + return args._Nsize >= 8*VL || args._Msize <= VL || (2*VL < args._Msize && args._Msize <= 3*VL); }, [](const GemmArgs &args, const Requantize32 &qp) { return new GemmInterleavedPretransposedNoMergeQuantizedInline<cls_sme2_interleaved_nomerge_s8q_mopa_1VLx4VL, int8_t, int8_t>(args, qp); } }, { diff --git a/src/core/NEON/kernels/arm_gemm/gemm_quint8.cpp b/src/core/NEON/kernels/arm_gemm/gemm_quint8.cpp index b85b1c4fcf..93eecf991e 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_quint8.cpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_quint8.cpp @@ -78,7 +78,7 @@ static const GemmImplementation<uint8_t, uint8_t, Requantize32> gemm_quint8_meth "sme2_interleaved_nomerge_u8q_mopa_1VLx4VL", [](const GemmArgs &args, const Requantize32 &qp) { return args._ci->has_sme2() && ((qp.per_channel_requant && (qp.per_channel_left_shifts == nullptr)) || (!qp.per_channel_requant && 
(qp.per_layer_left_shift == 0)));}, [](const GemmArgs &args, const Requantize32 &) { const auto VL = sme::get_vector_length<uint32_t>(); - return args._Msize <= VL || (2*VL < args._Msize && args._Msize <= 3*VL); }, + return args._Nsize >= 8*VL || args._Msize <= VL || (2*VL < args._Msize && args._Msize <= 3*VL); }, [](const GemmArgs &args, const Requantize32 &qp) { return new GemmInterleavedPretransposedNoMergeQuantizedInline<cls_sme2_interleaved_nomerge_u8q_mopa_1VLx4VL, uint8_t, uint8_t>(args, qp); } }, { diff --git a/src/core/NEON/kernels/arm_gemm/interleave_indirect.cpp b/src/core/NEON/kernels/arm_gemm/interleave_indirect.cpp index 59591935cd..7c09608e3e 100644 --- a/src/core/NEON/kernels/arm_gemm/interleave_indirect.cpp +++ b/src/core/NEON/kernels/arm_gemm/interleave_indirect.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2022 Arm Limited. + * Copyright (c) 2020-2022, 2024 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -330,11 +330,11 @@ template void Interleave<8, 2, VLType::None>(float *, const float *, size_t, uns #endif // ARM_COMPUTE_ENABLE_SVE && ARM_COMPUTE_ENABLE_SVEF32MM /* FP16 */ -#if defined(FP16_KERNELS) || defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) +#if defined(FP16_KERNELS) || defined(ARM_COMPUTE_ENABLE_FP16) template void IndirectInterleave<8, 1, VLType::None>(__fp16 *, const __fp16 * const * const *, unsigned int, unsigned int, unsigned int, unsigned int, unsigned int, unsigned int, bool, int32_t); template void ConvolutionInterleave<8, 1, VLType::None>(__fp16 *, const __fp16 *, size_t, const convolver<__fp16> &, unsigned int, unsigned int, unsigned int, unsigned int, unsigned int, bool, int32_t); template void Interleave<8, 1, VLType::None>(__fp16 *, const __fp16 *, size_t, unsigned int, unsigned int, unsigned int, unsigned int, bool, int32_t); -#endif // FP16_KERNELS ar __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#endif // FP16_KERNELS or ARM_COMPUTE_ENABLE_FP16 template void IndirectInterleave<8, 1, VLType::None>(float *, const __fp16 *
const * const *, unsigned int, unsigned int, unsigned int, unsigned int, unsigned int, unsigned int, bool, int32_t); template void ConvolutionInterleave<8, 1, VLType::None>(float *, const __fp16 *, size_t, const convolver<__fp16> &, unsigned int, unsigned int, unsigned int, unsigned int, unsigned int, bool, int32_t); diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a64_hgemm_8x24.hpp b/src/core/NEON/kernels/arm_gemm/kernels/a64_hgemm_8x24.hpp index 586d6a64a4..d9668aae02 100644 --- a/src/core/NEON/kernels/arm_gemm/kernels/a64_hgemm_8x24.hpp +++ b/src/core/NEON/kernels/arm_gemm/kernels/a64_hgemm_8x24.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2021 Arm Limited. + * Copyright (c) 2017-2021, 2024 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -23,7 +23,7 @@ */ #pragma once -#if defined(__aarch64__) && (defined(FP16_KERNELS) || defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)) +#if defined(__aarch64__) && (defined(FP16_KERNELS) || defined(ARM_COMPUTE_ENABLE_FP16)) #include "../performance_parameters.hpp" #include "../std_transforms_fixed.hpp" @@ -89,4 +89,4 @@ public: } // namespace arm_gemm -#endif // __aarch64__ && (FP16_KERNELS || __ARM_FEATURE_FP16_VECTOR_ARITHMETIC) +#endif // __aarch64__ && (FP16_KERNELS || ARM_COMPUTE_ENABLE_FP16) diff --git a/src/core/NEON/kernels/arm_gemm/merges/a64_merge_fp16_24x8.hpp b/src/core/NEON/kernels/arm_gemm/merges/a64_merge_fp16_24x8.hpp index a81d4504ae..ba47e0aa54 100644 --- a/src/core/NEON/kernels/arm_gemm/merges/a64_merge_fp16_24x8.hpp +++ b/src/core/NEON/kernels/arm_gemm/merges/a64_merge_fp16_24x8.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2020 Arm Limited. + * Copyright (c) 2019-2020, 2024 Arm Limited. 
* * SPDX-License-Identifier: MIT * @@ -23,7 +23,7 @@ */ #pragma once -#if defined(__aarch64__) && (defined(FP16_KERNELS) || defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)) +#if defined(__aarch64__) && (defined(FP16_KERNELS) || defined(ARM_COMPUTE_ENABLE_FP16)) template<> void MergeResults<24, 8, false>(__fp16 *out, const __fp16 *in, const int ldout, const int y0, const int ymax, const int x0, const int xmax, const __fp16 *bias, Activation act, bool append) @@ -86,7 +86,7 @@ void MergeResults<24, 8, false>(__fp16 *out, const __fp16 *in, const int ldout, } else { /* Optimized routine to copy an entire block */ __asm __volatile ( -#ifndef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#ifndef ARM_COMPUTE_ENABLE_FP16 ".arch armv8.2-a+fp16\n" #endif "dup v0.8h, %[maxval].h[0]\n" @@ -140,7 +140,7 @@ void MergeResults<24, 8, false>(__fp16 *out, const __fp16 *in, const int ldout, } else { /* Optimized routine to copy an entire block */ __asm __volatile ( -#ifndef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#ifndef ARM_COMPUTE_ENABLE_FP16 ".arch armv8.2-a+fp16\n" #endif "dup v0.8h, %[maxval].h[0]\n" @@ -217,7 +217,7 @@ void MergeResults<24, 8, false>(__fp16 *out, const __fp16 *in, const int ldout, } else { /* Optimized routine to copy an entire block */ __asm __volatile ( -#ifndef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#ifndef ARM_COMPUTE_ENABLE_FP16 ".arch armv8.2-a+fp16\n" #endif "dup v0.8h, %[maxval].h[0]\n" @@ -317,7 +317,7 @@ void MergeResults<24, 8, false>(__fp16 *out, const __fp16 *in, const int ldout, } else { /* Optimized routine to copy an entire block */ __asm __volatile ( -#ifndef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#ifndef ARM_COMPUTE_ENABLE_FP16 ".arch armv8.2-a+fp16\n" #endif "dup v0.8h, %[maxval].h[0]\n" @@ -439,7 +439,7 @@ void MergeResults<24, 8, false>(__fp16 *out, const __fp16 *in, const int ldout, } else { /* Optimized routine to copy an entire block */ __asm __volatile ( -#ifndef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#ifndef ARM_COMPUTE_ENABLE_FP16 ".arch 
armv8.2-a+fp16\n" #endif "dup v0.8h, %[maxval].h[0]\n" @@ -584,7 +584,7 @@ void MergeResults<24, 8, false>(__fp16 *out, const __fp16 *in, const int ldout, } else { /* Optimized routine to copy an entire block */ __asm __volatile ( -#ifndef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#ifndef ARM_COMPUTE_ENABLE_FP16 ".arch armv8.2-a+fp16\n" #endif "dup v0.8h, %[maxval].h[0]\n" @@ -752,7 +752,7 @@ void MergeResults<24, 8, false>(__fp16 *out, const __fp16 *in, const int ldout, } else { /* Optimized routine to copy an entire block */ __asm __volatile ( -#ifndef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#ifndef ARM_COMPUTE_ENABLE_FP16 ".arch armv8.2-a+fp16\n" #endif "dup v0.8h, %[maxval].h[0]\n" @@ -944,7 +944,7 @@ void MergeResults<24, 8, false>(__fp16 *out, const __fp16 *in, const int ldout, } else { /* Optimized routine to copy an entire block */ __asm __volatile ( -#ifndef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#ifndef ARM_COMPUTE_ENABLE_FP16 ".arch armv8.2-a+fp16\n" #endif "dup v0.8h, %[maxval].h[0]\n" @@ -1150,7 +1150,7 @@ void MergeResults<24, 8, false>(__fp16 *out, const __fp16 *in, const int ldout, } else { /* Optimized routine to copy an entire block */ __asm __volatile ( -#ifndef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#ifndef ARM_COMPUTE_ENABLE_FP16 ".arch armv8.2-a+fp16\n" #endif "dup v0.8h, %[maxval].h[0]\n" @@ -1204,7 +1204,7 @@ void MergeResults<24, 8, false>(__fp16 *out, const __fp16 *in, const int ldout, } else { /* Optimized routine to copy an entire block */ __asm __volatile ( -#ifndef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#ifndef ARM_COMPUTE_ENABLE_FP16 ".arch armv8.2-a+fp16\n" #endif "dup v0.8h, %[maxval].h[0]\n" @@ -1278,7 +1278,7 @@ void MergeResults<24, 8, false>(__fp16 *out, const __fp16 *in, const int ldout, } else { /* Optimized routine to copy an entire block */ __asm __volatile ( -#ifndef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#ifndef ARM_COMPUTE_ENABLE_FP16 ".arch armv8.2-a+fp16\n" #endif "dup v0.8h, %[maxval].h[0]\n" @@ -1372,7 +1372,7 @@ void 
MergeResults<24, 8, false>(__fp16 *out, const __fp16 *in, const int ldout, } else { /* Optimized routine to copy an entire block */ __asm __volatile ( -#ifndef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#ifndef ARM_COMPUTE_ENABLE_FP16 ".arch armv8.2-a+fp16\n" #endif "dup v0.8h, %[maxval].h[0]\n" @@ -1485,7 +1485,7 @@ void MergeResults<24, 8, false>(__fp16 *out, const __fp16 *in, const int ldout, } else { /* Optimized routine to copy an entire block */ __asm __volatile ( -#ifndef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#ifndef ARM_COMPUTE_ENABLE_FP16 ".arch armv8.2-a+fp16\n" #endif "dup v0.8h, %[maxval].h[0]\n" @@ -1618,7 +1618,7 @@ void MergeResults<24, 8, false>(__fp16 *out, const __fp16 *in, const int ldout, } else { /* Optimized routine to copy an entire block */ __asm __volatile ( -#ifndef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#ifndef ARM_COMPUTE_ENABLE_FP16 ".arch armv8.2-a+fp16\n" #endif "dup v0.8h, %[maxval].h[0]\n" @@ -1771,7 +1771,7 @@ void MergeResults<24, 8, false>(__fp16 *out, const __fp16 *in, const int ldout, } else { /* Optimized routine to copy an entire block */ __asm __volatile ( -#ifndef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#ifndef ARM_COMPUTE_ENABLE_FP16 ".arch armv8.2-a+fp16\n" #endif "dup v0.8h, %[maxval].h[0]\n" @@ -1945,7 +1945,7 @@ void MergeResults<24, 8, false>(__fp16 *out, const __fp16 *in, const int ldout, } else { /* Optimized routine to copy an entire block */ __asm __volatile ( -#ifndef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#ifndef ARM_COMPUTE_ENABLE_FP16 ".arch armv8.2-a+fp16\n" #endif "dup v0.8h, %[maxval].h[0]\n" @@ -2112,4 +2112,4 @@ void MergeResults<24, 8, false>(__fp16 *out, const __fp16 *in, const int ldout, } } -#endif // __aarch64__ && (FP16_KERNELS || __ARM_FEATURE_FP16_VECTOR_ARITHMETIC) +#endif // __aarch64__ && (FP16_KERNELS || ARM_COMPUTE_ENABLE_FP16) diff --git a/src/core/NEON/kernels/arm_gemm/transform.cpp b/src/core/NEON/kernels/arm_gemm/transform.cpp index 45e4f0e1de..06d9e2416c 100644 --- 
a/src/core/NEON/kernels/arm_gemm/transform.cpp +++ b/src/core/NEON/kernels/arm_gemm/transform.cpp @@ -129,17 +129,17 @@ void Transform( // We don't have assembler transforms for AArch32, generate templated ones here. #ifdef __arm__ template void Transform<8, 1, true, VLType::None>(float *, const float *, int, int, int, int, int); -#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) +#if defined(ARM_COMPUTE_ENABLE_FP16) template void Transform<8, 1, true, VLType::None>(float *, const __fp16 *, int, int, int, int, int); -#endif // defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) +#endif // defined(ARM_COMPUTE_ENABLE_FP16) #ifdef ARM_COMPUTE_ENABLE_BF16 template void Transform<8, 1, true, VLType::None>(float *, const bfloat16 *, int, int, int, int, int); #endif // ARM_COMPUTE_ENABLE_BF16 #endif // AArch32 -#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) +#if defined(ARM_COMPUTE_ENABLE_FP16) template void Transform<12, 1, false, VLType::None>(float *, const __fp16 *, int, int, int, int, int); -#endif // defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) +#endif // defined(ARM_COMPUTE_ENABLE_FP16) #ifdef ARM_COMPUTE_ENABLE_BF16 template void Transform<12, 1, false, VLType::None>(float *, const bfloat16 *, int, int, int, int, int); #endif // ARM_COMPUTE_ENABLE_BF16 diff --git a/src/core/common/Registrars.h b/src/core/common/Registrars.h index a74316b486..cd849c3666 100644 --- a/src/core/common/Registrars.h +++ b/src/core/common/Registrars.h @@ -72,9 +72,13 @@ #endif /* defined(ARM_COMPUTE_ENABLE_SVE2) */ #if defined(ARM_COMPUTE_ENABLE_SME2) -#define REGISTER_FP32_SME2(func_name) &(func_name) +#define REGISTER_FP32_SME2(func_name) &(func_name) +#define REGISTER_QASYMM8_SME2(func_name) &(func_name) +#define REGISTER_QASYMM8_SIGNED_SME2(func_name) &(func_name) #else /* !defined(ARM_COMPUTE_ENABLE_SME2) */ -#define REGISTER_FP32_SME2(func_name) nullptr +#define REGISTER_FP32_SME2(func_name) nullptr +#define REGISTER_QASYMM8_SME2(func_name) nullptr +#define 
REGISTER_QASYMM8_SIGNED_SME2(func_name) nullptr #endif /* defined(ARM_COMPUTE_ENABLE_SME2) */ #if defined(ARM_COMPUTE_ENABLE_NEON) @@ -106,10 +110,17 @@ #define REGISTER_QASYMM8_SIGNED_SVE2(func_name) nullptr #endif /* defined(ARM_COMPUTE_ENABLE_SVE2) */ +#if defined(ARM_COMPUTE_ENABLE_SME2) +#define REGISTER_QASYMM8_SIGNED_SME2(func_name) &(func_name) +#else /* !defined(ARM_COMPUTE_ENABLE_SME2) */ +#define REGISTER_QASYMM8_SIGNED_SME2(func_name) nullptr +#endif /* defined(ARM_COMPUTE_ENABLE_SME2) */ + #else /* defined(ENABLE_QASYMM8_SIGNED_KERNELS) */ #define REGISTER_QASYMM8_SIGNED_NEON(func_name) nullptr #define REGISTER_QASYMM8_SIGNED_SVE(func_name) nullptr #define REGISTER_QASYMM8_SIGNED_SVE2(func_name) nullptr +#define REGISTER_QASYMM8_SIGNED_SME2(func_name) nullptr #endif /* defined(ENABLE_QASYMM8_SIGNED_KERNELS) */ #if defined(ENABLE_QASYMM8_KERNELS) @@ -127,10 +138,17 @@ #define REGISTER_QASYMM8_SVE2(func_name) nullptr #endif /* defined(ARM_COMPUTE_ENABLE_SVE2) */ +#if defined(ARM_COMPUTE_ENABLE_SME2) +#define REGISTER_QASYMM8_SME2(func_name) &(func_name) +#else /* !defined(ARM_COMPUTE_ENABLE_SME2) */ +#define REGISTER_QASYMM8_SME2(func_name) nullptr +#endif /* defined(ARM_COMPUTE_ENABLE_SME2) */ + #else /* defined(ENABLE_QASYMM8_KERNELS) */ #define REGISTER_QASYMM8_NEON(func_name) nullptr #define REGISTER_QASYMM8_SVE(func_name) nullptr #define REGISTER_QASYMM8_SVE2(func_name) nullptr +#define REGISTER_QASYMM8_SME2(func_name) nullptr #endif /* defined(ENABLE_QASYMM8_KERNELS) */ #if defined(ENABLE_QSYMM16_KERNELS) diff --git a/src/cpu/kernels/CpuKernelSelectionTypes.h b/src/cpu/kernels/CpuKernelSelectionTypes.h index d71789cc39..7c1e4772a6 100644 --- a/src/cpu/kernels/CpuKernelSelectionTypes.h +++ b/src/cpu/kernels/CpuKernelSelectionTypes.h @@ -105,6 +105,7 @@ struct SoftmaxKernelDataTypeISASelectorData cpuinfo::CpuIsaInfo isa; bool is_log; int axis; + unsigned long sme2_vector_length; }; // Selector pointer types diff --git 
a/src/cpu/kernels/CpuQuantizeKernel.cpp b/src/cpu/kernels/CpuQuantizeKernel.cpp index d2ac6cf8ac..ed4675ae3d 100644 --- a/src/cpu/kernels/CpuQuantizeKernel.cpp +++ b/src/cpu/kernels/CpuQuantizeKernel.cpp @@ -29,12 +29,12 @@ #include "arm_compute/core/Validate.h" #include "arm_compute/core/Window.h" +#include "src/core/common/Registrars.h" #include "src/core/CPP/Validate.h" #include "src/core/helpers/AutoConfiguration.h" #include "src/core/helpers/WindowHelpers.h" -#include "src/core/NEON/NEAsymm.h" -#include "src/core/NEON/NEMath.h" #include "src/core/NEON/wrapper/wrapper.h" +#include "src/cpu/kernels/quantize/generic/neon/list.h" #include <arm_neon.h> #include <map> @@ -47,7 +47,6 @@ namespace kernels { namespace { -constexpr auto window_step = 16; Status validate_arguments(const ITensorInfo *src, const ITensorInfo *dst) { @@ -63,59 +62,6 @@ Status validate_arguments(const ITensorInfo *src, const ITensorInfo *dst) return Status{}; } -template <typename T> -inline float32x4x4_t load_value(const T *input_ptr) -{ - using Tx16_t = typename wrapper::traits::neon_vector<T, 16>::type; - return arm_compute::convert_to_float32x4x4<Tx16_t>(wrapper::vloadq(input_ptr)); -} - -template <> -inline float32x4x4_t load_value(const float *input_ptr) -{ - return {wrapper::vloadq(input_ptr), wrapper::vloadq(input_ptr + 4), wrapper::vloadq(input_ptr + 8), - wrapper::vloadq(input_ptr + 12)}; -} -#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC -template <> -inline float32x4x4_t load_value(const float16_t *input_ptr) -{ - return {vcvt_f32_f16(wrapper::vload(input_ptr)), vcvt_f32_f16(wrapper::vload(input_ptr + 4)), - vcvt_f32_f16(wrapper::vload(input_ptr + 8)), vcvt_f32_f16(wrapper::vload(input_ptr + 12))}; -} - -#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC - -template <typename element_type> -using vector_type = wrapper::traits::neon_vector_t<element_type, window_step>; - -template <typename quantized_type> -vector_type<quantized_type> vquantize_qasymm8(const float32x4x4_t &qv, const 
UniformQuantizationInfo &qi); - -template <> -vector_type<uint8_t> vquantize_qasymm8<uint8_t>(const float32x4x4_t &qv, const UniformQuantizationInfo &qi) -{ - return vquantize(qv, qi); -} - -template <> -vector_type<int8_t> vquantize_qasymm8<int8_t>(const float32x4x4_t &qv, const UniformQuantizationInfo &qi) -{ - return vquantize_signed(qv, qi); -} - -template <typename TOut, typename = typename std::enable_if<std::is_signed<TOut>::value, bool>::type> -inline int8x16_t recombine_8_16(int16x8_t lower, int16x8_t upper) -{ - return wrapper::vcombine(wrapper::vqmovn(lower), wrapper::vqmovn(upper)); -} - -template <typename TOut, typename = typename std::enable_if<std::is_unsigned<TOut>::value, bool>::type> -inline uint8x16_t recombine_8_16(int16x8_t lower, int16x8_t upper) -{ - return wrapper::vcombine(wrapper::vqmovun(lower), wrapper::vqmovun(upper)); -} - } // namespace void CpuQuantizeKernel::configure(const ITensorInfo *src, ITensorInfo *dst) @@ -124,38 +70,36 @@ void CpuQuantizeKernel::configure(const ITensorInfo *src, ITensorInfo *dst) ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(src, dst)); static const std::map<std::string, QuantizeFunctionExecutorPtr> quant_map = { - {"op_QASYMM8_QASYMM8", &CpuQuantizeKernel::run_quantize_qasymm8<uint8_t, uint8_t>}, - {"op_QASYMM8_QASYMM8_SIGNED", &CpuQuantizeKernel::run_quantize_qasymm8<uint8_t, int8_t>}, - {"op_QASYMM8_QASYMM16", &CpuQuantizeKernel::run_quantize_qasymm16<uint8_t>}, + {"op_QASYMM8_QASYMM8", REGISTER_INTEGER_NEON(u8_u8_run_quantize_qasymm8)}, + {"op_QASYMM8_QASYMM8_SIGNED", REGISTER_INTEGER_NEON(u8_i8_run_quantize_qasymm8)}, + {"op_QASYMM8_QASYMM16", REGISTER_INTEGER_NEON(u8_run_quantize_qasymm16)}, - {"op_QASYMM8_SIGNED_QASYMM8", &CpuQuantizeKernel::run_quantize_qasymm8<int8_t, uint8_t>}, - {"op_QASYMM8_SIGNED_QASYMM8_SIGNED", &CpuQuantizeKernel::run_quantize_qasymm8<int8_t, int8_t>}, - {"op_QASYMM8_SIGNED_QASYMM16", &CpuQuantizeKernel::run_quantize_qasymm16<int8_t>}, + {"op_QASYMM8_SIGNED_QASYMM8", 
REGISTER_INTEGER_NEON(i8_u8_run_quantize_qasymm8)}, + {"op_QASYMM8_SIGNED_QASYMM8_SIGNED", REGISTER_INTEGER_NEON(i8_i8_run_quantize_qasymm8)}, + {"op_QASYMM8_SIGNED_QASYMM16", REGISTER_INTEGER_NEON(i8_run_quantize_qasymm16)}, // Functions for offset only requantization - {"op_OFFSET_ONLY_QASYMM8_QASYMM8", &CpuQuantizeKernel::run_requantize_offset_only<uint8_t, uint8_t>}, - {"op_OFFSET_ONLY_QASYMM8_QASYMM8_SIGNED", &CpuQuantizeKernel::run_requantize_offset_only<uint8_t, int8_t>}, - {"op_OFFSET_ONLY_QASYMM8_SIGNED_QASYMM8", &CpuQuantizeKernel::run_requantize_offset_only<int8_t, uint8_t>}, - {"op_OFFSET_ONLY_QASYMM8_SIGNED_QASYMM8_SIGNED", - &CpuQuantizeKernel::run_requantize_offset_only<int8_t, int8_t>}, + {"op_OFFSET_ONLY_QASYMM8_QASYMM8", REGISTER_INTEGER_NEON(u8_u8_run_requantize_offset_only)}, + {"op_OFFSET_ONLY_QASYMM8_QASYMM8_SIGNED", REGISTER_INTEGER_NEON(u8_i8_run_requantize_offset_only)}, + {"op_OFFSET_ONLY_QASYMM8_SIGNED_QASYMM8", REGISTER_INTEGER_NEON(i8_u8_run_requantize_offset_only)}, + {"op_OFFSET_ONLY_QASYMM8_SIGNED_QASYMM8_SIGNED", REGISTER_INTEGER_NEON(i8_i8_run_requantize_offset_only)}, // Functions for offset uint8 to int8 and vice versa quantization (no scale changes) {"op_OFFSET_ONLY_CONVERT_QASYMM8_SIGNED_QASYMM8", - &CpuQuantizeKernel::run_requantize_offset_only_convert<int8_t, uint8_t>}, + REGISTER_INTEGER_NEON(i8_u8_run_requantize_offset_only_convert)}, {"op_OFFSET_ONLY_CONVERT_QASYMM8_QASYMM8_SIGNED", - &CpuQuantizeKernel::run_requantize_offset_only_convert<uint8_t, int8_t>}, - - {"op_F32_QSYMM8", &CpuQuantizeKernel::run_quantize_qsymm8<float, int8_t>}, - - {"op_F32_QASYMM8", &CpuQuantizeKernel::run_quantize_qasymm8<float, uint8_t>}, - {"op_F32_QASYMM8_SIGNED", &CpuQuantizeKernel::run_quantize_qasymm8<float, int8_t>}, - {"op_F32_QASYMM16", &CpuQuantizeKernel::run_quantize_qasymm16<float>}, - -#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC - {"op_F16_QASYMM8", &CpuQuantizeKernel::run_quantize_qasymm8<float16_t, uint8_t>}, - 
{"op_F16_QASYMM8_SIGNED", &CpuQuantizeKernel::run_quantize_qasymm8<float16_t, int8_t>}, - {"op_F16_QASYMM16", &CpuQuantizeKernel::run_quantize_qasymm16<float16_t>}, -#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC*/ + REGISTER_INTEGER_NEON(u8_i8_run_requantize_offset_only_convert)}, + + {"op_F32_QSYMM8", REGISTER_FP32_NEON(fp32_i8_run_quantize_qsymm8)}, + {"op_F32_QASYMM8", REGISTER_FP32_NEON(fp32_u8_run_quantize_qasymm8)}, + {"op_F32_QASYMM8_SIGNED", REGISTER_FP32_NEON(fp32_i8_run_quantize_qasymm8)}, + {"op_F32_QASYMM16", REGISTER_FP32_NEON(fp32_run_quantize_qasymm16)}, + +#ifdef ARM_COMPUTE_ENABLE_FP16 + {"op_F16_QASYMM8", REGISTER_FP16_NEON(fp16_u8_run_quantize_qasymm8)}, + {"op_F16_QASYMM8_SIGNED", REGISTER_FP16_NEON(fp16_i8_run_quantize_qasymm8)}, + {"op_F16_QASYMM16", REGISTER_FP16_NEON(fp16_run_quantize_qasymm16)}, +#endif /* ARM_COMPUTE_ENABLE_FP16 */ }; std::string function_to_call("op_"); @@ -203,242 +147,6 @@ Status CpuQuantizeKernel::validate(const ITensorInfo *src, const ITensorInfo *ds return Status{}; } -template <typename TIn, typename TOut> -void CpuQuantizeKernel::run_quantize_qsymm8(const ITensor *src, ITensor *dst, const Window &window) -{ - const auto window_start_x = static_cast<int>(window.x().start()); - const auto window_end_x = static_cast<int>(window.x().end()); - - const UniformQuantizationInfo uqinfo_in = src->info()->quantization_info().uniform(); - UniformQuantizationInfo uqinfo = dst->info()->quantization_info().uniform(); - uqinfo = compute_requantization_scale_offset(uqinfo_in, uqinfo); - - // Collapse window and reset first dimension to handle tail calculations manually - Window win_collapsed = window.collapse_if_possible(window, Window::DimZ); - win_collapsed.set(Window::DimX, Window::Dimension(0, 1, 1)); - - Iterator input(src, win_collapsed); - Iterator output(dst, win_collapsed); - execute_window_loop( - win_collapsed, - [&](const Coordinates &) - { - auto input_ptr = reinterpret_cast<const TIn *>(input.ptr()); - auto 
output_ptr = reinterpret_cast<TOut *>(output.ptr()); - int x = window_start_x; - for (; x <= (window_end_x - window_step); x += window_step) - { - wrapper::vstore(&output_ptr[x], vquantize_qasymm8<TOut>(load_value(&input_ptr[x]), uqinfo)); - } - // Compute left-over elements - for (; x < window_end_x; ++x) - { - output_ptr[x] = quantize_qsymm8(input_ptr[x], dst->info()->quantization_info()); - } - }, - input, output); -} - -template <typename TIn, typename TOut> -void CpuQuantizeKernel::run_requantize_offset_only_convert(const ITensor *src, ITensor *dst, const Window &window) -{ - const auto window_start_x = static_cast<int>(window.x().start()); - const auto window_end_x = static_cast<int>(window.x().end()); - - // Calculate output offset difference. - const UniformQuantizationInfo uqinfo_in = src->info()->quantization_info().uniform(); - UniformQuantizationInfo uqinfo = dst->info()->quantization_info().uniform(); - uqinfo = compute_requantization_scale_offset(uqinfo_in, uqinfo); - - // Collapse window and reset first dimension to handle tail calculations manually - Window win_collapsed = window.collapse_if_possible(window, Window::DimZ); - - win_collapsed.set(Window::DimX, Window::Dimension(0, 1, 1)); - - // Duplicate offset in signed vector format - const int8x16_t offset = wrapper::vdup_n(static_cast<int8_t>(uqinfo.offset), wrapper::traits::vector_128_tag{}); - - Iterator input(src, win_collapsed); - Iterator output(dst, win_collapsed); - execute_window_loop( - win_collapsed, - [&](const Coordinates &) - { - auto input_ptr = reinterpret_cast<const TIn *>(input.ptr()); - auto output_ptr = reinterpret_cast<TOut *>(output.ptr()); - int x = window_start_x; - for (; x <= (window_end_x - window_step); x += window_step) - { - const wrapper::traits::neon_vector_t<TIn, window_step> qv = - wrapper::vloadq(input_ptr + x); // load 128 bit vector of 8 bit datatype - - // Signed addition. 
- auto res = vaddq_s8(reinterpret_cast<int8x16_t>(qv), offset); - - // Output is dependent on datatype. - wrapper::vstore(&output_ptr[x], - reinterpret_cast<wrapper::traits::neon_vector_t<TOut, window_step>>(res)); - } - // Compute left-over elements - for (; x < window_end_x; ++x) - { - auto result = uqinfo.offset + static_cast<int32_t>(input_ptr[x]); - output_ptr[x] = static_cast<TOut>(result); - } - }, - input, output); -} - -template <typename TIn, typename TOut> -void CpuQuantizeKernel::run_requantize_offset_only(const ITensor *src, ITensor *dst, const Window &window) -{ - const auto window_start_x = static_cast<int>(window.x().start()); - const auto window_end_x = static_cast<int>(window.x().end()); - - const UniformQuantizationInfo uqinfo_in = src->info()->quantization_info().uniform(); - UniformQuantizationInfo uqinfo = dst->info()->quantization_info().uniform(); - uqinfo = compute_requantization_scale_offset(uqinfo_in, uqinfo); - - // Collapse window and reset first dimension to handle tail calculations manually - Window win_collapsed = window.collapse_if_possible(window, Window::DimZ); - win_collapsed.set(Window::DimX, Window::Dimension(0, 1, 1)); - - // Duplicate offset in signed vector format - const int16x8_t offset = wrapper::vdup_n(static_cast<int16_t>(uqinfo.offset), wrapper::traits::vector_128_tag{}); - - const int32_t low_bound = (dst->info()->data_type() == DataType::QASYMM8) ? 0 : -128; - const int32_t upper_bound = (dst->info()->data_type() == DataType::QASYMM8) ? 
255 : 127; - - Iterator input(src, win_collapsed); - Iterator output(dst, win_collapsed); - execute_window_loop( - win_collapsed, - [&](const Coordinates &) - { - auto input_ptr = reinterpret_cast<const TIn *>(input.ptr()); - TOut *output_ptr = reinterpret_cast<TOut *>(output.ptr()); - - int x = window_start_x; - for (; x <= (window_end_x - window_step); x += window_step) - { - const auto qv = wrapper::vloadq(input_ptr + x); // load 128 bit vector of 8 bit datatype - int16x8_t lower = reinterpret_cast<int16x8_t>(wrapper::vmovl(wrapper::vgetlow(qv))); - int16x8_t upper = reinterpret_cast<int16x8_t>(wrapper::vmovl(wrapper::vgethigh(qv))); - - // Signed addition. - lower = wrapper::vqadd(lower, offset); - upper = wrapper::vqadd(upper, offset); - - // Output is dependent on datatype. - auto res = recombine_8_16<TOut>(lower, upper); - wrapper::vstore(&output_ptr[x], res); - } - // Compute left-over elements - for (; x < window_end_x; ++x) - { - // Add offset and clamp result to within the range of the output datatype. - int32_t result = uqinfo.offset + static_cast<int32_t>(input_ptr[x]); - result = utility::clamp<int32_t>(result, low_bound, upper_bound); - - // Cast result to output datatype. 
- output_ptr[x] = static_cast<TOut>(result); - } - }, - input, output); -} - -template <typename TIn, typename TOut> -void CpuQuantizeKernel::run_quantize_qasymm8(const ITensor *src, ITensor *dst, const Window &window) -{ - const auto window_start_x = static_cast<int>(window.x().start()); - const auto window_end_x = static_cast<int>(window.x().end()); - - const UniformQuantizationInfo uqinfo_in = src->info()->quantization_info().uniform(); - UniformQuantizationInfo uqinfo = dst->info()->quantization_info().uniform(); - if (is_data_type_quantized_asymmetric(src->info()->data_type())) - { - uqinfo = compute_requantization_scale_offset(uqinfo_in, uqinfo); - } -#ifdef __aarch64__ - constexpr RoundingPolicy rounding_policy = RoundingPolicy::TO_NEAREST_EVEN; -#else //__aarch64__ - constexpr RoundingPolicy rounding_policy = RoundingPolicy::TO_ZERO; -#endif //__aarch64__ - - // Collapse window and reset first dimension to handle tail calculations manually - Window win_collapsed = window.collapse_if_possible(window, Window::DimZ); - win_collapsed.set(Window::DimX, Window::Dimension(0, 1, 1)); - - Iterator input(src, win_collapsed); - Iterator output(dst, win_collapsed); - execute_window_loop( - win_collapsed, - [&](const Coordinates &) - { - auto input_ptr = reinterpret_cast<const TIn *>(input.ptr()); - auto output_ptr = reinterpret_cast<TOut *>(output.ptr()); - - int x = window_start_x; - for (; x <= (window_end_x - window_step); x += window_step) - { - wrapper::vstore(&output_ptr[x], vquantize_qasymm8<TOut>(load_value(&input_ptr[x]), uqinfo)); - } - // Compute left-over elements - for (; x < window_end_x; ++x) - { - output_ptr[x] = Qasymm8QuantizationHelper<TOut>::quantize(input_ptr[x], uqinfo, rounding_policy); - } - }, - input, output); -} - -template <typename T> -void CpuQuantizeKernel::run_quantize_qasymm16(const ITensor *src, ITensor *dst, const Window &window) -{ - const auto window_start_x = static_cast<int>(window.x().start()); - const auto window_end_x = 
static_cast<int>(window.x().end()); - - const UniformQuantizationInfo uqinfo_in = src->info()->quantization_info().uniform(); - UniformQuantizationInfo uqinfo = dst->info()->quantization_info().uniform(); - if (is_data_type_quantized_asymmetric(src->info()->data_type())) - { - uqinfo = compute_requantization_scale_offset(uqinfo_in, uqinfo); - } -#ifdef __aarch64__ - constexpr RoundingPolicy rounding_policy = RoundingPolicy::TO_NEAREST_EVEN; -#else //__aarch64__ - constexpr RoundingPolicy rounding_policy = RoundingPolicy::TO_ZERO; -#endif //__aarch64__ - - // Collapse window and reset first dimension to handle tail calculations manually - Window win_collapsed = window.collapse_if_possible(window, Window::DimZ); - win_collapsed.set(Window::DimX, Window::Dimension(0, 1, 1)); - - Iterator input(src, win_collapsed); - Iterator output(dst, win_collapsed); - execute_window_loop( - win_collapsed, - [&](const Coordinates &) - { - auto input_ptr = reinterpret_cast<const T *>(input.ptr()); - auto output_ptr = reinterpret_cast<uint16_t *>(output.ptr()); - - int x = window_start_x; - for (; x <= (window_end_x - window_step); x += window_step) - { - uint16x8x2_t tmp = vquantize_qasymm16(load_value(&input_ptr[x]), uqinfo); - vst1q_u16(&output_ptr[x], tmp.val[0]); - vst1q_u16(&output_ptr[x + 8], tmp.val[1]); - } - // Compute left-over elements - for (; x < window_end_x; ++x) - { - output_ptr[x] = quantize_qasymm16(input_ptr[x], uqinfo, rounding_policy); - } - }, - input, output); -} - void CpuQuantizeKernel::run_op(ITensorPack &tensors, const Window &window, const ThreadInfo &info) { ARM_COMPUTE_UNUSED(info); @@ -448,7 +156,7 @@ void CpuQuantizeKernel::run_op(ITensorPack &tensors, const Window &window, const const auto src = tensors.get_const_tensor(TensorType::ACL_SRC); auto dst = tensors.get_tensor(TensorType::ACL_DST); - (this->*_func)(src, dst, window); + (*_func)(src, dst, window); } const char *CpuQuantizeKernel::name() const diff --git a/src/cpu/kernels/CpuQuantizeKernel.h 
b/src/cpu/kernels/CpuQuantizeKernel.h index c2f7ac6d9d..750310c811 100644 --- a/src/cpu/kernels/CpuQuantizeKernel.h +++ b/src/cpu/kernels/CpuQuantizeKernel.h @@ -76,31 +76,7 @@ private: * * @param[in] window Region on which to execute the kernel. */ - using QuantizeFunctionExecutorPtr = void (CpuQuantizeKernel::*)(const ITensor *src, - ITensor *dst, - const Window &window); - /** Function to apply QASYMM8 or QASYMM8_SIGNED quantization on a tensor. - * - * @param[in] window Region on which to execute the kernel. - */ - template <typename TIn, typename TOut> - void run_quantize_qasymm8(const ITensor *src, ITensor *dst, const Window &window); - /** Function to apply QASYMM16 quantization on a tensor. - * - * @param[in] window Region on which to execute the kernel. - */ - template <typename T> - void run_quantize_qasymm16(const ITensor *src, ITensor *dst, const Window &window); - - template <typename TIn, typename TOut> - void run_quantize_qsymm8(const ITensor *src, ITensor *dst, const Window &window); - - template <typename TIn, typename TOut> - void run_requantize_offset_only(const ITensor *src, ITensor *dst, const Window &window); - - template <typename TIn, typename TOut> - void run_requantize_offset_only_convert(const ITensor *src, ITensor *dst, const Window &window); - + using QuantizeFunctionExecutorPtr = void (*)(const ITensor *src, ITensor *dst, const Window &window); QuantizeFunctionExecutorPtr _func{nullptr}; size_t _split_dimension{Window::DimY}; }; diff --git a/src/cpu/kernels/CpuSoftmaxKernel.cpp b/src/cpu/kernels/CpuSoftmaxKernel.cpp index 5cf81f815c..b7e395fb79 100644 --- a/src/cpu/kernels/CpuSoftmaxKernel.cpp +++ b/src/cpu/kernels/CpuSoftmaxKernel.cpp @@ -48,6 +48,7 @@ namespace kernels { namespace { + /* Softmax */ static const std::vector<typename CpuSoftmaxKernel::SoftmaxKernel> available_kernels = { {"sme2_fp32_softmax", @@ -65,9 +66,23 @@ static const std::vector<typename CpuSoftmaxKernel::SoftmaxKernel> available_ker [](const 
SoftmaxKernelDataTypeISASelectorData &data) { return (!data.is_log && data.dt == DataType::F16) && data.isa.fp16; }, REGISTER_FP16_NEON(neon_fp16_softmax<false>)}, + {"sme2_qu8_softmax_lut_512VL", + [](const SoftmaxKernelDataTypeISASelectorData &data) + { + return (!data.is_log && data.dt == DataType::QASYMM8 && data.isa.sme2 && data.axis == 0 && + data.sme2_vector_length == 512); + }, + REGISTER_QASYMM8_SME2(sme2_qasymm8_softmax_lut_512VL)}, {"neon_qu8_softmax", [](const SoftmaxKernelDataTypeISASelectorData &data) { return (!data.is_log && data.dt == DataType::QASYMM8); }, REGISTER_QASYMM8_NEON(arm_compute::cpu::neon_qasymm8_softmax<false>)}, + {"sme2_qs8_softmax_lut_512VL", + [](const SoftmaxKernelDataTypeISASelectorData &data) + { + return (!data.is_log && data.dt == DataType::QASYMM8_SIGNED && data.isa.sme2 && data.axis == 0 && + data.sme2_vector_length == 512); + }, + REGISTER_QASYMM8_SIGNED_SME2(sme2_qasymm8_signed_softmax_lut_512VL)}, {"neon_qs8_softmax", [](const SoftmaxKernelDataTypeISASelectorData &data) { return (!data.is_log && data.dt == DataType::QASYMM8_SIGNED); }, @@ -88,6 +103,28 @@ static const std::vector<typename CpuSoftmaxKernel::SoftmaxKernel> available_ker REGISTER_QASYMM8_SIGNED_NEON(arm_compute::cpu::neon_qasymm8_signed_softmax<true>)}, }; +void init_lut(std::vector<float> &lut, DataType type, float scale, float beta) +{ + if (type == DataType::QASYMM8) + { + for (int i = 0; i < 256; ++i) + { + lut.push_back(std::exp(-scale * beta * i)); + } + } + else if (type == DataType::QASYMM8_SIGNED) + { + for (int i = -128; i < 128; ++i) + { + lut.push_back(std::exp(-scale * beta * i)); + } + } + else + { + ARM_COMPUTE_ERROR("Invalid datatype for QASYMM8/QASYMM8_SIGNED softmax"); + } +} + Status validate_arguments_softmax( const ITensorInfo &src, const ITensorInfo &dst, float beta, int axis, const ITensorInfo &tmp, bool is_log) { @@ -157,8 +194,8 @@ void CpuSoftmaxKernel::configure( auto_init_if_empty(*tmp, 
TensorInfo(*src).set_data_type(DataType::F32).reset_padding()); } - const auto *uk = CpuSoftmaxKernel::get_implementation( - SoftmaxKernelDataTypeISASelectorData{src->data_type(), CPUInfo::get().get_isa(), is_log, axis}); + const auto *uk = CpuSoftmaxKernel::get_implementation(SoftmaxKernelDataTypeISASelectorData{ + src->data_type(), CPUInfo::get().get_isa(), is_log, axis, CPUInfo::get().get_sme2_vector_length()}); ARM_COMPUTE_ERROR_ON(uk == nullptr || uk->ukernel == nullptr); std::string kernel_name = is_log ? std::string("CpuLogSoftmaxKernel") : std::string("CpuSoftmaxKernel"); @@ -194,6 +231,13 @@ void CpuSoftmaxKernel::configure( win.set(_axis, Window::Dimension(0, 1, 1)); ICpuKernel<CpuSoftmaxKernel>::configure(win); + + const std::string uk_name = uk->name; + if (uk_name == "sme2_qu8_softmax_lut_512VL" || uk_name == "sme2_qs8_softmax_lut_512VL") + { + const float scale = src->quantization_info().uniform().scale; + init_lut(_lut, src->data_type(), scale, beta); + } } Status CpuSoftmaxKernel::validate( @@ -230,11 +274,11 @@ void CpuSoftmaxKernel::run_op(ITensorPack &tensors, const Window &window, const const unsigned int tmp_size_for_thread = tmp->info()->element_size() * num_elems_processed_per_iteration; void *tmp_for_thread = tmp->buffer() + (info.thread_id * tmp_size_for_thread); - _run_method(src, tmp_for_thread, dst, _beta, _axis, window); + _run_method(src, tmp_for_thread, dst, _beta, _axis, window, _lut.data()); } else { - _run_method(src, nullptr, dst, _beta, _axis, window); + _run_method(src, nullptr, dst, _beta, _axis, window, nullptr); } } diff --git a/src/cpu/kernels/CpuSoftmaxKernel.h b/src/cpu/kernels/CpuSoftmaxKernel.h index 043ad975d5..676e79782b 100644 --- a/src/cpu/kernels/CpuSoftmaxKernel.h +++ b/src/cpu/kernels/CpuSoftmaxKernel.h @@ -37,8 +37,8 @@ namespace kernels class CpuSoftmaxKernel : public ICpuKernel<CpuSoftmaxKernel> { private: - using SoftmaxKernelPtr = - std::add_pointer<void(const ITensor *, void *const, ITensor *, float, int, 
const Window &)>::type; + using SoftmaxKernelPtr = std::add_pointer<void( + const ITensor *, void *const, ITensor *, float, int, const Window &, const float *)>::type; public: CpuSoftmaxKernel() = default; @@ -78,10 +78,11 @@ public: static const std::vector<SoftmaxKernel> &get_available_kernels(); private: - float _beta{1.0f}; - SoftmaxKernelPtr _run_method{nullptr}; - std::string _name{}; - int _axis{}; + float _beta{1.0f}; + SoftmaxKernelPtr _run_method{nullptr}; + std::string _name{}; + int _axis{}; + std::vector<float> _lut = {}; }; } // namespace kernels } // namespace cpu diff --git a/src/cpu/kernels/quantize/generic/neon/fp16.cpp b/src/cpu/kernels/quantize/generic/neon/fp16.cpp new file mode 100644 index 0000000000..456a3bda31 --- /dev/null +++ b/src/cpu/kernels/quantize/generic/neon/fp16.cpp @@ -0,0 +1,45 @@ +/* + * Copyright (c) 2024 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ */ +#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && defined(ENABLE_FP16_KERNELS) +#include "src/cpu/kernels/quantize/generic/neon/impl_fp16.h" + +namespace arm_compute +{ +namespace cpu +{ +void fp16_u8_run_quantize_qasymm8(const ITensor *src, ITensor *dst, const Window &window) +{ + run_quantize_qasymm8<float16_t, uint8_t>(src, dst, window); +} +void fp16_i8_run_quantize_qasymm8(const ITensor *src, ITensor *dst, const Window &window) +{ + run_quantize_qasymm8<float16_t, int8_t>(src, dst, window); +} +void fp16_run_quantize_qasymm16(const ITensor *src, ITensor *dst, const Window &window) +{ + run_quantize_qasymm16<float16_t>(src, dst, window); +} +} // namespace cpu +} // namespace arm_compute +#endif /* defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && defined(ENABLE_FP16_KERNELS) */ diff --git a/src/cpu/kernels/quantize/generic/neon/fp32.cpp b/src/cpu/kernels/quantize/generic/neon/fp32.cpp new file mode 100644 index 0000000000..15f52b2238 --- /dev/null +++ b/src/cpu/kernels/quantize/generic/neon/fp32.cpp @@ -0,0 +1,48 @@ +/* + * Copyright (c) 2024 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#include "src/cpu/kernels/quantize/generic/neon/impl_fp32.h" + +namespace arm_compute +{ +namespace cpu +{ +void fp32_u8_run_quantize_qasymm8(const ITensor *src, ITensor *dst, const Window &window) +{ + run_quantize_qasymm8<float, uint8_t>(src, dst, window); +} +void fp32_i8_run_quantize_qasymm8(const ITensor *src, ITensor *dst, const Window &window) +{ + run_quantize_qasymm8<float, int8_t>(src, dst, window); +} +void fp32_run_quantize_qasymm16(const ITensor *src, ITensor *dst, const Window &window) +{ + run_quantize_qasymm16<float>(src, dst, window); +} + +void fp32_i8_run_quantize_qsymm8(const ITensor *src, ITensor *dst, const Window &window) +{ + run_quantize_qsymm8<float, int8_t>(src, dst, window); +} +} // namespace cpu +} // namespace arm_compute diff --git a/src/cpu/kernels/quantize/generic/neon/impl.h b/src/cpu/kernels/quantize/generic/neon/impl.h new file mode 100644 index 0000000000..1861fca391 --- /dev/null +++ b/src/cpu/kernels/quantize/generic/neon/impl.h @@ -0,0 +1,302 @@ +/* + * Copyright (c) 2024 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. 
+ * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#ifndef ACL_SRC_CPU_KERNELS_QUANTIZE_GENERIC_NEON_IMPL_H +#define ACL_SRC_CPU_KERNELS_QUANTIZE_GENERIC_NEON_IMPL_H + +#include "arm_compute/core/Helpers.h" + +#include "src/core/helpers/WindowHelpers.h" +#include "src/core/NEON/NEAsymm.h" +#include "src/core/NEON/wrapper/intrinsics/intrinsics.h" + +namespace arm_compute +{ +namespace cpu +{ +constexpr auto window_step = 16; + +template <typename T> +inline float32x4x4_t load_value(const T *input_ptr) +{ + using Tx16_t = typename wrapper::traits::neon_vector<T, 16>::type; + return arm_compute::convert_to_float32x4x4<Tx16_t>(wrapper::vloadq(input_ptr)); +} + +template <typename element_type> +using vector_type = wrapper::traits::neon_vector_t<element_type, window_step>; + +template <typename quantized_type> +vector_type<quantized_type> vquantize_qasymm8(const float32x4x4_t &qv, const UniformQuantizationInfo &qi); + +template <typename TOut, typename = typename std::enable_if<std::is_signed<TOut>::value, bool>::type> +inline int8x16_t recombine_8_16(int16x8_t lower, int16x8_t upper) +{ + return wrapper::vcombine(wrapper::vqmovn(lower), wrapper::vqmovn(upper)); +} + +template <typename TOut, typename = typename std::enable_if<std::is_unsigned<TOut>::value, bool>::type> +inline uint8x16_t recombine_8_16(int16x8_t lower, int16x8_t upper) +{ + return wrapper::vcombine(wrapper::vqmovun(lower), wrapper::vqmovun(upper)); +} + +template <typename TIn, typename TOut> +void run_quantize_qsymm8(const ITensor *src, ITensor *dst, const 
Window &window) +{ + const auto window_start_x = static_cast<int>(window.x().start()); + const auto window_end_x = static_cast<int>(window.x().end()); + + const UniformQuantizationInfo uqinfo_in = src->info()->quantization_info().uniform(); + UniformQuantizationInfo uqinfo = dst->info()->quantization_info().uniform(); + uqinfo = compute_requantization_scale_offset(uqinfo_in, uqinfo); + + // Collapse window and reset first dimension to handle tail calculations manually + Window win_collapsed = window.collapse_if_possible(window, Window::DimZ); + win_collapsed.set(Window::DimX, Window::Dimension(0, 1, 1)); + + Iterator input(src, win_collapsed); + Iterator output(dst, win_collapsed); + execute_window_loop( + win_collapsed, + [&](const Coordinates &) + { + auto input_ptr = reinterpret_cast<const TIn *>(input.ptr()); + auto output_ptr = reinterpret_cast<TOut *>(output.ptr()); + int x = window_start_x; + for (; x <= (window_end_x - window_step); x += window_step) + { + wrapper::vstore(&output_ptr[x], vquantize_qasymm8<TOut>(load_value(&input_ptr[x]), uqinfo)); + } + // Compute left-over elements + for (; x < window_end_x; ++x) + { + output_ptr[x] = quantize_qsymm8(input_ptr[x], dst->info()->quantization_info()); + } + }, + input, output); +} + +template <typename TIn, typename TOut> +void run_requantize_offset_only_convert(const ITensor *src, ITensor *dst, const Window &window) +{ + const auto window_start_x = static_cast<int>(window.x().start()); + const auto window_end_x = static_cast<int>(window.x().end()); + + // Calculate output offset difference. 
+ const UniformQuantizationInfo uqinfo_in = src->info()->quantization_info().uniform(); + UniformQuantizationInfo uqinfo = dst->info()->quantization_info().uniform(); + uqinfo = compute_requantization_scale_offset(uqinfo_in, uqinfo); + + // Collapse window and reset first dimension to handle tail calculations manually + Window win_collapsed = window.collapse_if_possible(window, Window::DimZ); + + win_collapsed.set(Window::DimX, Window::Dimension(0, 1, 1)); + + // Duplicate offset in signed vector format + const int8x16_t offset = wrapper::vdup_n(static_cast<int8_t>(uqinfo.offset), wrapper::traits::vector_128_tag{}); + + Iterator input(src, win_collapsed); + Iterator output(dst, win_collapsed); + execute_window_loop( + win_collapsed, + [&](const Coordinates &) + { + auto input_ptr = reinterpret_cast<const TIn *>(input.ptr()); + auto output_ptr = reinterpret_cast<TOut *>(output.ptr()); + int x = window_start_x; + for (; x <= (window_end_x - window_step); x += window_step) + { + const wrapper::traits::neon_vector_t<TIn, window_step> qv = + wrapper::vloadq(input_ptr + x); // load 128 bit vector of 8 bit datatype + + // Signed addition. + auto res = vaddq_s8(reinterpret_cast<int8x16_t>(qv), offset); + + // Output is dependent on datatype. 
+ wrapper::vstore(&output_ptr[x], + reinterpret_cast<wrapper::traits::neon_vector_t<TOut, window_step>>(res)); + } + // Compute left-over elements + for (; x < window_end_x; ++x) + { + auto result = uqinfo.offset + static_cast<int32_t>(input_ptr[x]); + output_ptr[x] = static_cast<TOut>(result); + } + }, + input, output); +} + +template <typename TIn, typename TOut> +void run_requantize_offset_only(const ITensor *src, ITensor *dst, const Window &window) +{ + const auto window_start_x = static_cast<int>(window.x().start()); + const auto window_end_x = static_cast<int>(window.x().end()); + + const UniformQuantizationInfo uqinfo_in = src->info()->quantization_info().uniform(); + UniformQuantizationInfo uqinfo = dst->info()->quantization_info().uniform(); + uqinfo = compute_requantization_scale_offset(uqinfo_in, uqinfo); + + // Collapse window and reset first dimension to handle tail calculations manually + Window win_collapsed = window.collapse_if_possible(window, Window::DimZ); + win_collapsed.set(Window::DimX, Window::Dimension(0, 1, 1)); + + // Duplicate offset in signed vector format + const int16x8_t offset = wrapper::vdup_n(static_cast<int16_t>(uqinfo.offset), wrapper::traits::vector_128_tag{}); + + const int32_t low_bound = (dst->info()->data_type() == DataType::QASYMM8) ? 0 : -128; + const int32_t upper_bound = (dst->info()->data_type() == DataType::QASYMM8) ? 
255 : 127; + + Iterator input(src, win_collapsed); + Iterator output(dst, win_collapsed); + execute_window_loop( + win_collapsed, + [&](const Coordinates &) + { + auto input_ptr = reinterpret_cast<const TIn *>(input.ptr()); + TOut *output_ptr = reinterpret_cast<TOut *>(output.ptr()); + + int x = window_start_x; + for (; x <= (window_end_x - window_step); x += window_step) + { + const auto qv = wrapper::vloadq(input_ptr + x); // load 128 bit vector of 8 bit datatype + int16x8_t lower = reinterpret_cast<int16x8_t>(wrapper::vmovl(wrapper::vgetlow(qv))); + int16x8_t upper = reinterpret_cast<int16x8_t>(wrapper::vmovl(wrapper::vgethigh(qv))); + + // Signed addition. + lower = wrapper::vqadd(lower, offset); + upper = wrapper::vqadd(upper, offset); + + // Output is dependent on datatype. + auto res = recombine_8_16<TOut>(lower, upper); + wrapper::vstore(&output_ptr[x], res); + } + // Compute left-over elements + for (; x < window_end_x; ++x) + { + // Add offset and clamp result to within the range of the output datatype. + int32_t result = uqinfo.offset + static_cast<int32_t>(input_ptr[x]); + result = utility::clamp<int32_t>(result, low_bound, upper_bound); + + // Cast result to output datatype. 
+ output_ptr[x] = static_cast<TOut>(result); + } + }, + input, output); +} + +template <typename TIn, typename TOut> +void run_quantize_qasymm8(const ITensor *src, ITensor *dst, const Window &window) +{ + const auto window_start_x = static_cast<int>(window.x().start()); + const auto window_end_x = static_cast<int>(window.x().end()); + + const UniformQuantizationInfo uqinfo_in = src->info()->quantization_info().uniform(); + UniformQuantizationInfo uqinfo = dst->info()->quantization_info().uniform(); + if (is_data_type_quantized_asymmetric(src->info()->data_type())) + { + uqinfo = compute_requantization_scale_offset(uqinfo_in, uqinfo); + } +#ifdef __aarch64__ + constexpr RoundingPolicy rounding_policy = RoundingPolicy::TO_NEAREST_EVEN; +#else //__aarch64__ + constexpr RoundingPolicy rounding_policy = RoundingPolicy::TO_ZERO; +#endif //__aarch64__ + + // Collapse window and reset first dimension to handle tail calculations manually + Window win_collapsed = window.collapse_if_possible(window, Window::DimZ); + win_collapsed.set(Window::DimX, Window::Dimension(0, 1, 1)); + + Iterator input(src, win_collapsed); + Iterator output(dst, win_collapsed); + execute_window_loop( + win_collapsed, + [&](const Coordinates &) + { + auto input_ptr = reinterpret_cast<const TIn *>(input.ptr()); + auto output_ptr = reinterpret_cast<TOut *>(output.ptr()); + + int x = window_start_x; + for (; x <= (window_end_x - window_step); x += window_step) + { + wrapper::vstore(&output_ptr[x], vquantize_qasymm8<TOut>(load_value(&input_ptr[x]), uqinfo)); + } + // Compute left-over elements + for (; x < window_end_x; ++x) + { + output_ptr[x] = Qasymm8QuantizationHelper<TOut>::quantize(input_ptr[x], uqinfo, rounding_policy); + } + }, + input, output); +} + +template <typename T> +void run_quantize_qasymm16(const ITensor *src, ITensor *dst, const Window &window) +{ + const auto window_start_x = static_cast<int>(window.x().start()); + const auto window_end_x = static_cast<int>(window.x().end()); + + const 
UniformQuantizationInfo uqinfo_in = src->info()->quantization_info().uniform(); + UniformQuantizationInfo uqinfo = dst->info()->quantization_info().uniform(); + if (is_data_type_quantized_asymmetric(src->info()->data_type())) + { + uqinfo = compute_requantization_scale_offset(uqinfo_in, uqinfo); + } +#ifdef __aarch64__ + constexpr RoundingPolicy rounding_policy = RoundingPolicy::TO_NEAREST_EVEN; +#else //__aarch64__ + constexpr RoundingPolicy rounding_policy = RoundingPolicy::TO_ZERO; +#endif //__aarch64__ + + // Collapse window and reset first dimension to handle tail calculations manually + Window win_collapsed = window.collapse_if_possible(window, Window::DimZ); + win_collapsed.set(Window::DimX, Window::Dimension(0, 1, 1)); + + Iterator input(src, win_collapsed); + Iterator output(dst, win_collapsed); + execute_window_loop( + win_collapsed, + [&](const Coordinates &) + { + auto input_ptr = reinterpret_cast<const T *>(input.ptr()); + auto output_ptr = reinterpret_cast<uint16_t *>(output.ptr()); + + int x = window_start_x; + for (; x <= (window_end_x - window_step); x += window_step) + { + uint16x8x2_t tmp = vquantize_qasymm16(load_value(&input_ptr[x]), uqinfo); + vst1q_u16(&output_ptr[x], tmp.val[0]); + vst1q_u16(&output_ptr[x + 8], tmp.val[1]); + } + // Compute left-over elements + for (; x < window_end_x; ++x) + { + output_ptr[x] = quantize_qasymm16(input_ptr[x], uqinfo, rounding_policy); + } + }, + input, output); +} +} // namespace cpu +} // namespace arm_compute + +#endif // ACL_SRC_CPU_KERNELS_QUANTIZE_GENERIC_NEON_IMPL_H diff --git a/src/cpu/kernels/quantize/generic/neon/impl_fp16.h b/src/cpu/kernels/quantize/generic/neon/impl_fp16.h new file mode 100644 index 0000000000..47f1b90abd --- /dev/null +++ b/src/cpu/kernels/quantize/generic/neon/impl_fp16.h @@ -0,0 +1,46 @@ +/* + * Copyright (c) 2024 Arm Limited. 
+ * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ */ +#ifndef ACL_SRC_CPU_KERNELS_QUANTIZE_GENERIC_NEON_IMPL_FP16_H +#define ACL_SRC_CPU_KERNELS_QUANTIZE_GENERIC_NEON_IMPL_FP16_H +#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + +#include "src/core/helpers/WindowHelpers.h" +#include "src/core/NEON/NEAsymm.h" + +namespace arm_compute +{ +namespace cpu +{ + +inline float32x4x4_t load_value(const float16_t *input_ptr) +{ + return {vcvt_f32_f16(wrapper::vload(input_ptr)), vcvt_f32_f16(wrapper::vload(input_ptr + 4)), + vcvt_f32_f16(wrapper::vload(input_ptr + 8)), vcvt_f32_f16(wrapper::vload(input_ptr + 12))}; +} + +} // namespace cpu +} // namespace arm_compute +#include "impl.h" +#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#endif // ACL_SRC_CPU_KERNELS_QUANTIZE_GENERIC_NEON_IMPL_FP16_H diff --git a/src/cpu/kernels/quantize/generic/neon/impl_fp32.h b/src/cpu/kernels/quantize/generic/neon/impl_fp32.h new file mode 100644 index 0000000000..00ae242567 --- /dev/null +++ b/src/cpu/kernels/quantize/generic/neon/impl_fp32.h @@ -0,0 +1,44 @@ +/* + * Copyright (c) 2024 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#ifndef ACL_SRC_CPU_KERNELS_QUANTIZE_GENERIC_NEON_IMPL_FP32_H +#define ACL_SRC_CPU_KERNELS_QUANTIZE_GENERIC_NEON_IMPL_FP32_H + +#include "src/core/helpers/WindowHelpers.h" +#include "src/core/NEON/NEAsymm.h" + +namespace arm_compute +{ +namespace cpu +{ +inline float32x4x4_t load_value(const float *input_ptr) +{ + return {wrapper::vloadq(input_ptr), wrapper::vloadq(input_ptr + 4), wrapper::vloadq(input_ptr + 8), + wrapper::vloadq(input_ptr + 12)}; +} + +} // namespace cpu +} // namespace arm_compute + +#include "impl.h" +#endif // ACL_SRC_CPU_KERNELS_QUANTIZE_GENERIC_NEON_IMPL_FP32_H diff --git a/src/cpu/kernels/quantize/generic/neon/integer.cpp b/src/cpu/kernels/quantize/generic/neon/integer.cpp new file mode 100644 index 0000000000..4e39afaaee --- /dev/null +++ b/src/cpu/kernels/quantize/generic/neon/integer.cpp @@ -0,0 +1,82 @@ +/* + * Copyright (c) 2024 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. 
+ * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#include "src/cpu/kernels/quantize/generic/neon/impl.h" + +namespace arm_compute +{ +namespace cpu +{ +void u8_u8_run_quantize_qasymm8(const ITensor *src, ITensor *dst, const Window &window) +{ + run_quantize_qasymm8<uint8_t, uint8_t>(src, dst, window); +} +void u8_i8_run_quantize_qasymm8(const ITensor *src, ITensor *dst, const Window &window) +{ + run_quantize_qasymm8<uint8_t, int8_t>(src, dst, window); +} +void i8_u8_run_quantize_qasymm8(const ITensor *src, ITensor *dst, const Window &window) +{ + run_quantize_qasymm8<int8_t, uint8_t>(src, dst, window); +} +void i8_i8_run_quantize_qasymm8(const ITensor *src, ITensor *dst, const Window &window) +{ + run_quantize_qasymm8<int8_t, int8_t>(src, dst, window); +} + +void u8_run_quantize_qasymm16(const ITensor *src, ITensor *dst, const Window &window) +{ + run_quantize_qasymm16<uint8_t>(src, dst, window); +} +void i8_run_quantize_qasymm16(const ITensor *src, ITensor *dst, const Window &window) +{ + run_quantize_qasymm16<int8_t>(src, dst, window); +} + +void u8_u8_run_requantize_offset_only(const ITensor *src, ITensor *dst, const Window &window) +{ + run_requantize_offset_only<uint8_t, uint8_t>(src, dst, window); +} +void u8_i8_run_requantize_offset_only(const ITensor *src, ITensor *dst, const Window &window) +{ + run_requantize_offset_only<uint8_t, int8_t>(src, dst, window); +} +void i8_u8_run_requantize_offset_only(const ITensor *src, ITensor *dst, const Window &window) +{ + run_requantize_offset_only<int8_t, uint8_t>(src, dst, 
window); +} +void i8_i8_run_requantize_offset_only(const ITensor *src, ITensor *dst, const Window &window) +{ + run_requantize_offset_only<int8_t, int8_t>(src, dst, window); +} + +void i8_u8_run_requantize_offset_only_convert(const ITensor *src, ITensor *dst, const Window &window) +{ + run_requantize_offset_only_convert<int8_t, uint8_t>(src, dst, window); +} +void u8_i8_run_requantize_offset_only_convert(const ITensor *src, ITensor *dst, const Window &window) +{ + run_requantize_offset_only_convert<uint8_t, int8_t>(src, dst, window); +} +} // namespace cpu +} // namespace arm_compute diff --git a/src/cpu/kernels/quantize/generic/neon/list.h b/src/cpu/kernels/quantize/generic/neon/list.h new file mode 100644 index 0000000000..c4fb1048eb --- /dev/null +++ b/src/cpu/kernels/quantize/generic/neon/list.h @@ -0,0 +1,66 @@ +/* + * Copyright (c) 2024 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ */ +#ifndef ACL_SRC_CPU_KERNELS_QUANTIZE_GENERIC_NEON_LIST_H +#define ACL_SRC_CPU_KERNELS_QUANTIZE_GENERIC_NEON_LIST_H + +#include "arm_compute/core/Helpers.h" + +namespace arm_compute +{ +namespace cpu +{ + +#define DECLARE_QUANTIZE_KERNEL(func_name) void func_name(const ITensor *src, ITensor *dst, const Window &window) + +DECLARE_QUANTIZE_KERNEL(u8_u8_run_quantize_qasymm8); +DECLARE_QUANTIZE_KERNEL(u8_i8_run_quantize_qasymm8); +DECLARE_QUANTIZE_KERNEL(i8_u8_run_quantize_qasymm8); +DECLARE_QUANTIZE_KERNEL(i8_i8_run_quantize_qasymm8); + +DECLARE_QUANTIZE_KERNEL(u8_u8_run_requantize_offset_only); +DECLARE_QUANTIZE_KERNEL(u8_i8_run_requantize_offset_only); +DECLARE_QUANTIZE_KERNEL(i8_u8_run_requantize_offset_only); +DECLARE_QUANTIZE_KERNEL(i8_i8_run_requantize_offset_only); + +DECLARE_QUANTIZE_KERNEL(i8_u8_run_requantize_offset_only_convert); +DECLARE_QUANTIZE_KERNEL(u8_i8_run_requantize_offset_only_convert); + +DECLARE_QUANTIZE_KERNEL(u8_run_quantize_qasymm16); +DECLARE_QUANTIZE_KERNEL(i8_run_quantize_qasymm16); + +DECLARE_QUANTIZE_KERNEL(fp32_u8_run_quantize_qasymm8); +DECLARE_QUANTIZE_KERNEL(fp32_i8_run_quantize_qasymm8); +DECLARE_QUANTIZE_KERNEL(fp32_run_quantize_qasymm16); + +DECLARE_QUANTIZE_KERNEL(fp32_i8_run_quantize_qsymm8); + +DECLARE_QUANTIZE_KERNEL(fp16_u8_run_quantize_qasymm8); +DECLARE_QUANTIZE_KERNEL(fp16_i8_run_quantize_qasymm8); +DECLARE_QUANTIZE_KERNEL(fp16_run_quantize_qasymm16); + +#undef DECLARE_QUANTIZE_KERNEL + +} // namespace cpu +} // namespace arm_compute +#endif // ACL_SRC_CPU_KERNELS_QUANTIZE_GENERIC_NEON_LIST_H diff --git a/src/cpu/kernels/quantize/generic/neon/vquantize.cpp b/src/cpu/kernels/quantize/generic/neon/vquantize.cpp new file mode 100644 index 0000000000..d40702bc88 --- /dev/null +++ b/src/cpu/kernels/quantize/generic/neon/vquantize.cpp @@ -0,0 +1,41 @@ +/* + * Copyright (c) 2024 Arm Limited. 
+ * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#include "impl.h" +namespace arm_compute +{ +namespace cpu +{ +template <> +vector_type<uint8_t> vquantize_qasymm8<uint8_t>(const float32x4x4_t &qv, const UniformQuantizationInfo &qi) +{ + return vquantize(qv, qi); +} + +template <> +vector_type<int8_t> vquantize_qasymm8<int8_t>(const float32x4x4_t &qv, const UniformQuantizationInfo &qi) +{ + return vquantize_signed(qv, qi); +} +} // namespace cpu +} // namespace arm_compute diff --git a/src/cpu/kernels/reduction_layer/generic/neon/fp16.cpp b/src/cpu/kernels/reduction_layer/generic/neon/fp16.cpp new file mode 100644 index 0000000000..41584e954b --- /dev/null +++ b/src/cpu/kernels/reduction_layer/generic/neon/fp16.cpp @@ -0,0 +1,65 @@ +/* + * Copyright (c) 2024 Arm Limited. 
+ * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ */ +#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && defined(ENABLE_FP16_KERNELS) + +#include "src/cpu/kernels/reduction_layer/generic/neon/impl_fp16.h" + +namespace arm_compute +{ +namespace cpu +{ +void reduce_RedOpX_reduceX_float16_8(const Window &window, + const ITensor *input, + ITensor *output, + const ReductionOperation op) +{ + return Reducer<RedOpX<float16_t, 8>>::reduceX(window, input, output, RedOpX<float16_t, 8>(), op); +} + +void reduce_RedOpYZW_reduceY_float16_8(const Window &window, + const ITensor *input, + ITensor *output, + const ReductionOperation op) +{ + return Reducer<RedOpYZW<float16_t, 8>>::reduceY(window, input, output, RedOpYZW<float16_t, 8>(), op); +} + +void reduce_RedOpYZW_reduceZ_float16_8(const Window &window, + const ITensor *input, + ITensor *output, + const ReductionOperation op) +{ + return Reducer<RedOpYZW<float16_t, 8>>::reduceZ(window, input, output, RedOpYZW<float16_t, 8>(), op); +} + +void reduce_RedOpYZW_reduceW_float16_8(const Window &window, + const ITensor *input, + ITensor *output, + const ReductionOperation op) +{ + return Reducer<RedOpYZW<float16_t, 8>>::reduceW(window, input, output, RedOpYZW<float16_t, 8>(), op); +} +} // namespace cpu +} // namespace arm_compute +#endif /* defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && defined(ENABLE_FP16_KERNELS) */ diff --git a/src/cpu/kernels/reduction_layer/generic/neon/fp32.cpp b/src/cpu/kernels/reduction_layer/generic/neon/fp32.cpp new file mode 100644 index 0000000000..6f5f13e571 --- /dev/null +++ b/src/cpu/kernels/reduction_layer/generic/neon/fp32.cpp @@ -0,0 +1,73 @@ +/* + * Copyright (c) 2024 Arm Limited. 
+ * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ */ + +#include "src/cpu/kernels/reduction_layer/generic/neon/impl.h" + +namespace arm_compute +{ +namespace cpu +{ +void reduce_RedOpYZW_complex_reduceZ_float32_4_2_SUM(const Window &window, + const ITensor *input, + ITensor *output, + const ReductionOperation op) +{ + Reducer<RedOpYZW_complex<float, 4, 2, ReductionOperation::SUM>>::reduceZ( + window, input, output, RedOpYZW_complex<float, 4, 2, ReductionOperation::SUM>(), op); +} + +void reduce_RedOpX_reduceX_float32_4(const Window &window, + const ITensor *input, + ITensor *output, + const ReductionOperation op) +{ + return Reducer<RedOpX<float, 4>>::reduceX(window, input, output, RedOpX<float, 4>(), op); +} + +void reduce_RedOpYZW_reduceY_float32_4(const Window &window, + const ITensor *input, + ITensor *output, + const ReductionOperation op) +{ + return Reducer<RedOpYZW<float, 4>>::reduceY(window, input, output, RedOpYZW<float, 4>(), op); +} + +void reduce_RedOpYZW_reduceZ_float32_4(const Window &window, + const ITensor *input, + ITensor *output, + const ReductionOperation op) +{ + return Reducer<RedOpYZW<float, 4>>::reduceZ(window, input, output, RedOpYZW<float, 4>(), op); +} + +void reduce_RedOpYZW_reduceW_float32_4(const Window &window, + const ITensor *input, + ITensor *output, + const ReductionOperation op) +{ + return Reducer<RedOpYZW<float, 4>>::reduceW(window, input, output, RedOpYZW<float, 4>(), op); +} + +} // namespace cpu +} // namespace arm_compute diff --git a/src/cpu/kernels/reduction_layer/generic/neon/impl.h b/src/cpu/kernels/reduction_layer/generic/neon/impl.h new file mode 100644 index 0000000000..611d83cf7e --- /dev/null +++ b/src/cpu/kernels/reduction_layer/generic/neon/impl.h @@ -0,0 +1,1543 @@ +/* + * Copyright (c) 2024 Arm Limited. 
+ * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ */ +#ifndef ACL_SRC_CPU_KERNELS_REDUCTION_LAYER_GENERIC_NEON_IMPL_H +#define ACL_SRC_CPU_KERNELS_REDUCTION_LAYER_GENERIC_NEON_IMPL_H + +#include "arm_compute/core/Coordinates.h" +#include "arm_compute/core/Helpers.h" +#include "arm_compute/core/ITensor.h" +#include "arm_compute/core/TensorInfo.h" + +#include "src/core/NEON/NEMath.h" +#include "src/core/NEON/wrapper/wrapper.h" +#include "support/SaturateCast.h" + +#include <arm_neon.h> + +namespace arm_compute +{ +// Helper function that calls vqmovun/vqmvn, vcombine and vstore, allows templating of RedOpYZW_quantized +template <typename T> +void combine_and_store(int16x8_t t1, int16x8_t t2, Iterator &output, int offset = 0) +{ + if (std::is_same<T, uint8_t>::value) + { + auto res = wrapper::vcombine(wrapper::vqmovun(t1), wrapper::vqmovun(t2)); + wrapper::vstore(output.ptr() + offset, res); + } + else + { + auto res = wrapper::vcombine(wrapper::vqmovn(t1), wrapper::vqmovn(t2)); + wrapper::vstore(reinterpret_cast<int8_t *>(output.ptr() + offset), res); + } +} + +template <typename T> +uint32x4x4_t calculate_index(uint32_t idx, T a, T b, uint32x4x4_t c, ReductionOperation op, int axis) +{ + uint32x4_t mask{0}; + if (op == ReductionOperation::ARG_IDX_MIN) + { + mask = wrapper::vcgt(b, a); + } + else + { + mask = wrapper::vclt(b, a); + } + + uint32x4_t vec_idx = {idx, idx + 1, idx + 2, idx + 3}; + if (axis != 0) + { + vec_idx = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{}); + } + uint32x4x4_t res = {{wrapper::vbsl(mask, vec_idx, c.val[0]), 0, 0, 0}}; + + return res; +} + +template <typename T> +uint32x4x4_t calculate_index_quantized(uint32_t idx, T a, T b, uint32x4x4_t c, ReductionOperation op, int axis) +{ + uint32x4x4_t mask{{0}}; + uint8x16_t mask_u8{0}; + if (op == ReductionOperation::ARG_IDX_MIN) + { + mask_u8 = wrapper::vcgt(b, a); + } + else + { + mask_u8 = wrapper::vclt(b, a); + } + auto wide_u16_1 = + wrapper::vorr(vshll_n_u8(wrapper::vgetlow(mask_u8), 8), wrapper::vmovl(wrapper::vgetlow(mask_u8))); 
+ auto wide_u16_2 = + wrapper::vorr(vshll_n_u8(wrapper::vgethigh(mask_u8), 8), wrapper::vmovl(wrapper::vgethigh(mask_u8))); + mask.val[0] = + wrapper::vorr(vshll_n_u16(wrapper::vgetlow(wide_u16_1), 16), wrapper::vmovl(wrapper::vgetlow(wide_u16_1))); + mask.val[1] = + wrapper::vorr(vshll_n_u16(wrapper::vgethigh(wide_u16_1), 16), wrapper::vmovl(wrapper::vgethigh(wide_u16_1))); + mask.val[2] = + wrapper::vorr(vshll_n_u16(wrapper::vgetlow(wide_u16_2), 16), wrapper::vmovl(wrapper::vgetlow(wide_u16_2))); + mask.val[3] = + wrapper::vorr(vshll_n_u16(wrapper::vgethigh(wide_u16_2), 16), wrapper::vmovl(wrapper::vgethigh(wide_u16_2))); + + uint32x4x4_t vec_idx = {{{idx + 0, idx + 1, idx + 2, idx + 3}, + {idx + 4, idx + 5, idx + 6, idx + 7}, + {idx + 8, idx + 9, idx + 10, idx + 11}, + {idx + 12, idx + 13, idx + 14, idx + 15}}}; + if (axis != 0) + { + vec_idx.val[0] = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{}); + vec_idx.val[1] = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{}); + vec_idx.val[2] = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{}); + vec_idx.val[3] = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{}); + } + uint32x4x4_t res = { + {vbslq_u32(mask.val[0], vec_idx.val[0], c.val[0]), vbslq_u32(mask.val[1], vec_idx.val[1], c.val[1]), + vbslq_u32(mask.val[2], vec_idx.val[2], c.val[2]), vbslq_u32(mask.val[3], vec_idx.val[3], c.val[3])}}; + + return res; +} + +// Helper function to calculate the minimum value of the input vector. All the elements in the output vector contain the min value. +template <typename T> +inline typename std::enable_if< + std::is_same<T, float32x4_t>::value || std::is_same<T, int32x4_t>::value, + typename std::conditional<std::is_same<T, float32x4_t>::value, float32x2_t, int32x2_t>::type>::type +calculate_min(T in) +{ + auto pmin = wrapper::vpmin(wrapper::vgethigh(in), wrapper::vgetlow(in)); + return wrapper::vpmin(pmin, pmin); +} + +// Helper function to calculate the minimum value of the input vector. 
All the elements in the output vector contain the min value. +template <typename T> +inline typename std::enable_if< + std::is_same<T, uint8x16_t>::value || std::is_same<T, int8x16_t>::value, + typename std::conditional<std::is_same<T, uint8x16_t>::value, uint8x8_t, int8x8_t>::type>::type +calculate_min(T in) +{ + auto pmin = wrapper::vpmin(wrapper::vgethigh(in), wrapper::vgetlow(in)); + pmin = wrapper::vpmin(pmin, pmin); + pmin = wrapper::vpmin(pmin, pmin); + return wrapper::vpmin(pmin, pmin); +} + +// Helper function to calculate the maximum value of the input vector. All the elements in the output vector contain the max value. +template <typename T> +inline typename std::enable_if< + std::is_same<T, float32x4_t>::value || std::is_same<T, int32x4_t>::value, + typename std::conditional<std::is_same<T, float32x4_t>::value, float32x2_t, int32x2_t>::type>::type +calculate_max(T in) +{ + auto pmax = wrapper::vpmax(wrapper::vgethigh(in), wrapper::vgetlow(in)); + return wrapper::vpmax(pmax, pmax); +} + +// Helper function to calculate the maximum value of the input vector. All the elements in the output vector contain the max value. 
+template <typename T> +inline typename std::enable_if< + std::is_same<T, uint8x16_t>::value || std::is_same<T, int8x16_t>::value, + typename std::conditional<std::is_same<T, uint8x16_t>::value, uint8x8_t, int8x8_t>::type>::type +calculate_max(T in) +{ + auto pmax = wrapper::vpmax(wrapper::vgethigh(in), wrapper::vgetlow(in)); + pmax = wrapper::vpmax(pmax, pmax); + pmax = wrapper::vpmax(pmax, pmax); + return wrapper::vpmax(pmax, pmax); +} + +template <typename T> +uint32_t calculate_vector_index(uint32x4x4_t vec_res_idx, T vec_res_value, ReductionOperation op) +{ + uint32x4_t res_idx_mask{0}; + uint32x4_t mask_ones = vdupq_n_u32(0xFFFFFFFF); + + if (op == ReductionOperation::ARG_IDX_MIN) + { + auto pmin = calculate_min(vec_res_value); + auto mask = wrapper::vceq(vec_res_value, wrapper::vcombine(pmin, pmin)); + res_idx_mask = wrapper::vand(vec_res_idx.val[0], mask); + } + else + { + auto pmax = calculate_max(vec_res_value); + auto mask = wrapper::vceq(vec_res_value, wrapper::vcombine(pmax, pmax)); + res_idx_mask = wrapper::vand(vec_res_idx.val[0], mask); + } + + res_idx_mask = wrapper::vadd(res_idx_mask, mask_ones); + auto pmin = wrapper::vpmin(wrapper::vgethigh(res_idx_mask), wrapper::vgetlow(res_idx_mask)); + pmin = wrapper::vpmin(pmin, pmin); + uint32_t res = wrapper::vgetlane(pmin, 0); + + return (res - 0xFFFFFFFF); +} + +template <typename T> +uint32_t calculate_vector_index_quantized(uint32x4x4_t vec_res_idx, T vec_res_value, ReductionOperation op) +{ + uint32x4x4_t res_idx_mask{{0}}; + uint32x4_t mask_ones = vdupq_n_u32(0xFFFFFFFF); + uint8x16_t mask_u8{0}; + if (op == ReductionOperation::ARG_IDX_MIN) + { + auto pmin = calculate_min(vec_res_value); + mask_u8 = wrapper::vceq(vec_res_value, wrapper::vcombine(pmin, pmin)); + } + else + { + auto pmax = calculate_max(vec_res_value); + mask_u8 = wrapper::vceq(vec_res_value, wrapper::vcombine(pmax, pmax)); + } + + // Widen vectors + auto wide_u16_1 = + wrapper::vorr(vshll_n_u8(wrapper::vgetlow(mask_u8), 8), 
wrapper::vmovl(wrapper::vgetlow(mask_u8))); + auto wide_u16_2 = + wrapper::vorr(vshll_n_u8(wrapper::vgethigh(mask_u8), 8), wrapper::vmovl(wrapper::vgethigh(mask_u8))); + auto wide_u32_1 = + wrapper::vorr(vshll_n_u16(wrapper::vgetlow(wide_u16_1), 16), wrapper::vmovl(wrapper::vgetlow(wide_u16_1))); + auto wide_u32_2 = + wrapper::vorr(vshll_n_u16(wrapper::vgethigh(wide_u16_1), 16), wrapper::vmovl(wrapper::vgethigh(wide_u16_1))); + auto wide_u32_3 = + wrapper::vorr(vshll_n_u16(wrapper::vgetlow(wide_u16_2), 16), wrapper::vmovl(wrapper::vgetlow(wide_u16_2))); + auto wide_u32_4 = + wrapper::vorr(vshll_n_u16(wrapper::vgethigh(wide_u16_2), 16), wrapper::vmovl(wrapper::vgethigh(wide_u16_2))); + res_idx_mask.val[0] = wrapper::vand(vec_res_idx.val[0], wide_u32_1); + res_idx_mask.val[1] = wrapper::vand(vec_res_idx.val[1], wide_u32_2); + res_idx_mask.val[2] = wrapper::vand(vec_res_idx.val[2], wide_u32_3); + res_idx_mask.val[3] = wrapper::vand(vec_res_idx.val[3], wide_u32_4); + res_idx_mask.val[0] = wrapper::vadd(res_idx_mask.val[0], mask_ones); + res_idx_mask.val[1] = wrapper::vadd(res_idx_mask.val[1], mask_ones); + res_idx_mask.val[2] = wrapper::vadd(res_idx_mask.val[2], mask_ones); + res_idx_mask.val[3] = wrapper::vadd(res_idx_mask.val[3], mask_ones); + + uint32_t res = 0xFFFFFFFF; + int iter = 0; + do + { + auto pmin = wrapper::vpmin(wrapper::vgethigh(res_idx_mask.val[iter]), wrapper::vgetlow(res_idx_mask.val[iter])); + pmin = wrapper::vpmin(pmin, pmin); + res = std::min(wrapper::vgetlane(pmin, 0), res); + iter++; + } while (iter < 4); + + return (res - 0xFFFFFFFF); +} + +template <class F> +class Reducer +{ +public: + static void reduceX(const Window &window, const ITensor *input, ITensor *output, F f, const ReductionOperation op) + { + // Set out window + Window out_window(window); + out_window.set(Window::DimX, Window::Dimension(0, 1, 1)); + + f(window, out_window, input, output, op); + } + static void reduceY(const Window &window, const ITensor *input, ITensor *output, F 
f, const ReductionOperation op) + { + // Set in window + Window in_window(window); + Window out_window(window); + + in_window.set(Window::DimY, Window::Dimension(0, 1, 1)); + out_window.set(Window::DimY, Window::Dimension(0, output->info()->dimension(1), output->info()->dimension(1))); + + f(in_window, out_window, input, output, 1, op); + } + static void reduceZ(const Window &window, const ITensor *input, ITensor *output, F f, const ReductionOperation op) + { + // Set in window + Window in_window(window); + Window out_window(window); + + in_window.set(Window::DimZ, Window::Dimension(0, 1, 1)); + out_window.set(Window::DimZ, Window::Dimension(0, output->info()->dimension(2), output->info()->dimension(2))); + + f(in_window, out_window, input, output, 2, op); + } + static void reduceW(const Window &window, const ITensor *input, ITensor *output, F f, const ReductionOperation op) + { + // Set in/out window + Window in_window(window); + Window out_window(window); + + in_window.set(3, Window::Dimension(0, 1, 1)); + out_window.set(3, Window::Dimension(0, 1, 1)); + + f(in_window, out_window, input, output, 3, op); + } +}; + +template <typename T, int S> +struct RedOpX +{ + /** SIMD vector tag type. 
*/ + using ExactTagType = typename wrapper::traits::neon_vector<T, S>::tag_type; + + inline void operator()( + const Window &in_window, Window &out_window, const ITensor *in, ITensor *out, const ReductionOperation op) + { + const size_t input_dim_0 = in->info()->dimension(0); + const int window_step_x = 16 / sizeof(T); + const auto window_start_x = static_cast<int>(in_window.x().start()); + const auto window_end_x = static_cast<int>(in_window.x().end()); + + Window in_win_no_pad = in_window; + in_win_no_pad.set(Window::DimX, Window::Dimension(0, 1, 1)); + + Iterator input(in, in_win_no_pad); + Iterator output(out, out_window); + + execute_window_loop( + in_win_no_pad, + [&](const Coordinates &) + { + const auto input_ptr = reinterpret_cast<const T *>(input.ptr()); + + auto init_res_value = static_cast<T>(0.f); + switch (op) + { + case ReductionOperation::ARG_IDX_MAX: + case ReductionOperation::ARG_IDX_MIN: + case ReductionOperation::MIN: + case ReductionOperation::MAX: + { + init_res_value = static_cast<T>(*input_ptr); + break; + } + case ReductionOperation::PROD: + { + init_res_value = static_cast<T>(1.f); + break; + } + default: + break; + } + auto vec_res_value = wrapper::vdup_n(init_res_value, ExactTagType{}); + uint32x4x4_t vec_res_idx{{0}}; + + // Compute window_step_x elements per iteration + int x = window_start_x; + for (; x <= (window_end_x - window_step_x); x += window_step_x) + { + const auto vec_elements = wrapper::vloadq(input_ptr + x); + switch (op) + { + case ReductionOperation::SUM_SQUARE: + vec_res_value = wrapper::vadd(wrapper::vmul(vec_elements, vec_elements), vec_res_value); + break; + case ReductionOperation::MEAN_SUM: + case ReductionOperation::SUM: + vec_res_value = wrapper::vadd(vec_elements, vec_res_value); + break; + case ReductionOperation::PROD: + vec_res_value = wrapper::vmul(vec_elements, vec_res_value); + break; + case ReductionOperation::ARG_IDX_MIN: + { + auto temp_vec_res_value = wrapper::vmin(vec_elements, vec_res_value); + 
vec_res_idx = calculate_index<decltype(vec_res_value)>(x, temp_vec_res_value, vec_res_value, + vec_res_idx, op, 0); + vec_res_value = temp_vec_res_value; + break; + } + case ReductionOperation::ARG_IDX_MAX: + { + auto temp_vec_res_value = wrapper::vmax(vec_elements, vec_res_value); + vec_res_idx = calculate_index<decltype(vec_res_value)>(x, temp_vec_res_value, vec_res_value, + vec_res_idx, op, 0); + vec_res_value = temp_vec_res_value; + break; + } + case ReductionOperation::MIN: + { + vec_res_value = wrapper::vmin(vec_elements, vec_res_value); + break; + } + case ReductionOperation::MAX: + { + vec_res_value = wrapper::vmax(vec_elements, vec_res_value); + break; + } + default: + ARM_COMPUTE_ERROR("Not supported"); + } + } + + switch (op) + { + case ReductionOperation::SUM: + case ReductionOperation::MEAN_SUM: + case ReductionOperation::SUM_SQUARE: + { +#ifdef ARM_COMPUTE_DEBUG_ENABLED + auto res = static_cast<T>(0.f); + for (int i = 0; i < S; ++i) + { + res += wrapper::vgetlane(vec_res_value, i); + } +#else // ARM_COMPUTE_DEBUG_ENABLED + auto carry_res = + wrapper::vpadd(wrapper::vgethigh(vec_res_value), wrapper::vgetlow(vec_res_value)); + for (int i = 0; i < S / 4; ++i) + { + carry_res = wrapper::vpadd(carry_res, carry_res); + } + auto res = wrapper::vgetlane(carry_res, 0); +#endif // ARM_COMPUTE_DEBUG_ENABLED + if (op == ReductionOperation::SUM_SQUARE) + { + // Compute left-over elements + for (; x < window_end_x; ++x) + { + res += (*(input_ptr + x)) * (*(input_ptr + x)); + } + } + else + { + // Compute left-over elements + for (; x < window_end_x; ++x) + { + res += *(input_ptr + x); + } + } + + if (op == ReductionOperation::MEAN_SUM) + { + res /= input_dim_0; + } + + *(reinterpret_cast<T *>(output.ptr())) = res; + break; + } + case ReductionOperation::PROD: + { + auto carry_res = + wrapper::vmul(wrapper::vgethigh(vec_res_value), wrapper::vgetlow(vec_res_value)); + T res = 1; + for (int i = 0; i < S / 2; ++i) + { + res *= wrapper::vgetlane(carry_res, i); + } + + 
// Compute left-over elements + for (; x < window_end_x; ++x) + { + res *= *(input_ptr + x); + } + + *(reinterpret_cast<T *>(output.ptr())) = res; + break; + } + case ReductionOperation::ARG_IDX_MIN: + { + auto idx = calculate_vector_index<decltype(vec_res_value)>(vec_res_idx, vec_res_value, op); + auto res = static_cast<T>(wrapper::vgetlane(calculate_min(vec_res_value), 0)); + + // Compute left-over elements + for (; x < window_end_x; ++x) + { + if (*(input_ptr + x) < res) + { + idx = x; + res = *(input_ptr + x); + } + } + *(reinterpret_cast<uint32_t *>(output.ptr())) = idx; + break; + } + case ReductionOperation::ARG_IDX_MAX: + { + auto idx = calculate_vector_index<decltype(vec_res_value)>(vec_res_idx, vec_res_value, op); + auto res = static_cast<T>(wrapper::vgetlane(calculate_max(vec_res_value), 0)); + + // Compute left-over elements + for (; x < window_end_x; ++x) + { + if (*(input_ptr + x) > res) + { + idx = x; + res = *(input_ptr + x); + } + } + *(reinterpret_cast<uint32_t *>(output.ptr())) = idx; + break; + } + case ReductionOperation::MIN: + { + auto res = static_cast<T>(wrapper::vgetlane(calculate_min(vec_res_value), 0)); + + // Compute left-over elements + for (; x < window_end_x; ++x) + { + res = *(input_ptr + x) < res ? *(input_ptr + x) : res; + } + *(reinterpret_cast<T *>(output.ptr())) = res; + break; + } + case ReductionOperation::MAX: + { + auto res = static_cast<T>(wrapper::vgetlane(calculate_max(vec_res_value), 0)); + + // Compute left-over elements + for (; x < window_end_x; ++x) + { + res = *(input_ptr + x) > res ? 
*(input_ptr + x) : res; + } + *(reinterpret_cast<T *>(output.ptr())) = res; + break; + } + default: + ARM_COMPUTE_ERROR("Not supported"); + } + }, + input, output); + } +}; + +template <typename T> +struct RedOpX_quantized +{ + inline void operator()( + const Window &in_window, Window &out_window, const ITensor *in, ITensor *out, const ReductionOperation op) + { + using PromotedType = typename wrapper::traits::promote<typename wrapper::traits::promote<T>::type>::type; + + const auto oq_info = out->info()->quantization_info().uniform(); + + const TensorInfo in_info = *(in->info()); + const UniformQuantizationInfo iq_info = in_info.quantization_info().uniform(); + + const int window_step_x = 16 / sizeof(T); + const auto window_start_x = static_cast<int>(in_window.x().start()); + const auto window_end_x = static_cast<int>(in_window.x().end()); + + Window in_win_no_pad = in_window; + in_win_no_pad.set(Window::DimX, Window::Dimension(0, 1, 1)); + + Iterator input(in, in_win_no_pad); + Iterator output(out, out_window); + + const auto in_offset = static_cast<float>(iq_info.offset); + const float in_scale = iq_info.scale; + + const auto out_offset = static_cast<float>(oq_info.offset); + const float out_scale = oq_info.scale; + + const auto num_elements = static_cast<float>(in_info.dimension(0)); + + const float A = in_scale / (out_scale * num_elements); + const float B = out_offset - (in_scale * in_offset) / (out_scale); + + execute_window_loop( + in_win_no_pad, + [&](const Coordinates &) + { + const auto input_ptr = reinterpret_cast<T *>(input.ptr()); + + auto vec_res_value1 = + wrapper::vdup_n(static_cast<PromotedType>(0.f), wrapper::traits::vector_128_tag{}); + auto vec_res_value2 = + wrapper::vdup_n(static_cast<PromotedType>(0.f), wrapper::traits::vector_128_tag{}); + auto vec_res_value3 = + wrapper::vdup_n(static_cast<PromotedType>(0.f), wrapper::traits::vector_128_tag{}); + auto vec_res_value4 = + wrapper::vdup_n(static_cast<PromotedType>(0.f), 
wrapper::traits::vector_128_tag{}); + + auto vec_res_value1_f = vdupq_n_f32(static_cast<float>(1.f)); + auto vec_res_value2_f = vdupq_n_f32(static_cast<float>(1.f)); + auto vec_res_value3_f = vdupq_n_f32(static_cast<float>(1.f)); + auto vec_res_value4_f = vdupq_n_f32(static_cast<float>(1.f)); + + typename wrapper::traits::neon_vector<T, 16>::type vec_res_value = {0}; + + if (op == ReductionOperation::ARG_IDX_MAX || op == ReductionOperation::ARG_IDX_MIN || + op == ReductionOperation::MIN || op == ReductionOperation::MAX) + { + vec_res_value = wrapper::vdup_n(*input_ptr, wrapper::traits::vector_128_tag{}); + } + + uint32x4x4_t vec_res_idx{{0}}; + // Compute window_step_x elements per iteration + int x = window_start_x; + for (; x <= (window_end_x - window_step_x); x += window_step_x) + { + const auto vec_elements = wrapper::vloadq(input_ptr + x); + switch (op) + { + case ReductionOperation::SUM: + case ReductionOperation::MEAN_SUM: + { + const auto temp16x8t_1 = wrapper::vmovl(wrapper::vgetlow(vec_elements)); + const auto temp16x8t_2 = wrapper::vmovl(wrapper::vgethigh(vec_elements)); + + const auto temp32x4t_1 = wrapper::vmovl(wrapper::vgetlow(temp16x8t_1)); + const auto temp32x4t_2 = wrapper::vmovl(wrapper::vgethigh(temp16x8t_1)); + const auto temp32x4t_3 = wrapper::vmovl(wrapper::vgetlow(temp16x8t_2)); + const auto temp32x4t_4 = wrapper::vmovl(wrapper::vgethigh(temp16x8t_2)); + + vec_res_value1 = wrapper::vadd(temp32x4t_1, vec_res_value1); + vec_res_value2 = wrapper::vadd(temp32x4t_2, vec_res_value2); + vec_res_value3 = wrapper::vadd(temp32x4t_3, vec_res_value3); + vec_res_value4 = wrapper::vadd(temp32x4t_4, vec_res_value4); + break; + } + case ReductionOperation::PROD: + { + const auto offset32x4f_4 = vdupq_n_f32(iq_info.offset); + const auto scale32x4f_4 = vdupq_n_f32(iq_info.scale); + + const auto temp16x8t_1 = wrapper::vmovl(wrapper::vgetlow(vec_elements)); + const auto temp16x8t_2 = wrapper::vmovl(wrapper::vgethigh(vec_elements)); + + const auto temp32x4t_1 = 
wrapper::vmovl(wrapper::vgetlow(temp16x8t_1)); + const auto temp32x4t_2 = wrapper::vmovl(wrapper::vgethigh(temp16x8t_1)); + const auto temp32x4t_3 = wrapper::vmovl(wrapper::vgetlow(temp16x8t_2)); + const auto temp32x4t_4 = wrapper::vmovl(wrapper::vgethigh(temp16x8t_2)); + + auto temp32x4f_1 = wrapper::vcvt<float>(temp32x4t_1); + auto temp32x4f_2 = wrapper::vcvt<float>(temp32x4t_2); + auto temp32x4f_3 = wrapper::vcvt<float>(temp32x4t_3); + auto temp32x4f_4 = wrapper::vcvt<float>(temp32x4t_4); + + //de-quantize vec_elements + temp32x4f_1 = vmulq_f32(vsubq_f32(temp32x4f_1, offset32x4f_4), scale32x4f_4); + temp32x4f_2 = vmulq_f32(vsubq_f32(temp32x4f_2, offset32x4f_4), scale32x4f_4); + temp32x4f_3 = vmulq_f32(vsubq_f32(temp32x4f_3, offset32x4f_4), scale32x4f_4); + temp32x4f_4 = vmulq_f32(vsubq_f32(temp32x4f_4, offset32x4f_4), scale32x4f_4); + + vec_res_value1_f = vmulq_f32(temp32x4f_1, vec_res_value1_f); + vec_res_value2_f = vmulq_f32(temp32x4f_2, vec_res_value2_f); + vec_res_value3_f = vmulq_f32(temp32x4f_3, vec_res_value3_f); + vec_res_value4_f = vmulq_f32(temp32x4f_4, vec_res_value4_f); + break; + } + case ReductionOperation::ARG_IDX_MIN: + { + auto temp_vec_res_value = wrapper::vmin(vec_elements, vec_res_value); + vec_res_idx = calculate_index_quantized<decltype(vec_res_value)>( + x, temp_vec_res_value, vec_res_value, vec_res_idx, op, 0); + vec_res_value = temp_vec_res_value; + break; + } + case ReductionOperation::ARG_IDX_MAX: + { + auto temp_vec_res_value = wrapper::vmax(vec_elements, vec_res_value); + vec_res_idx = calculate_index_quantized<decltype(vec_res_value)>( + x, temp_vec_res_value, vec_res_value, vec_res_idx, op, 0); + vec_res_value = temp_vec_res_value; + break; + } + case ReductionOperation::MIN: + { + vec_res_value = wrapper::vmin(vec_elements, vec_res_value); + break; + } + case ReductionOperation::MAX: + { + vec_res_value = wrapper::vmax(vec_elements, vec_res_value); + break; + } + default: + ARM_COMPUTE_ERROR("Not supported"); + } + } + + switch 
(op) + { + case ReductionOperation::ARG_IDX_MIN: + { + auto idx = + calculate_vector_index_quantized<decltype(vec_res_value)>(vec_res_idx, vec_res_value, op); + auto res = static_cast<T>(wrapper::vgetlane(calculate_min(vec_res_value), 0)); + + // Compute left-over elements + for (; x < window_end_x; ++x) + { + if (*(input_ptr + x) < res) + { + idx = x; + res = *(input_ptr + x); + } + } + *(reinterpret_cast<uint32_t *>(output.ptr())) = idx; + break; + } + case ReductionOperation::ARG_IDX_MAX: + { + auto idx = + calculate_vector_index_quantized<decltype(vec_res_value)>(vec_res_idx, vec_res_value, op); + auto res = static_cast<T>(wrapper::vgetlane(calculate_max(vec_res_value), 0)); + + // Compute left-over elements + for (; x < window_end_x; ++x) + { + if (*(input_ptr + x) > res) + { + idx = x; + res = *(input_ptr + x); + } + } + *(reinterpret_cast<uint32_t *>(output.ptr())) = idx; + break; + } + case ReductionOperation::MIN: + { + auto res = static_cast<T>(wrapper::vgetlane(calculate_min(vec_res_value), 0)); + + // Compute left-over elements + for (; x < window_end_x; ++x) + { + res = *(input_ptr + x) < res ? *(input_ptr + x) : res; + } + *(reinterpret_cast<T *>(output.ptr())) = res; + break; + } + case ReductionOperation::MAX: + { + auto res = static_cast<T>(wrapper::vgetlane(calculate_max(vec_res_value), 0)); + + // Compute left-over elements + for (; x < window_end_x; ++x) + { + res = *(input_ptr + x) > res ? 
*(input_ptr + x) : res; + } + *(reinterpret_cast<T *>(output.ptr())) = res; + break; + } + case ReductionOperation::PROD: + { + auto carry_res = wrapper::vmul(vec_res_value1_f, vec_res_value2_f); + carry_res = wrapper::vmul(carry_res, vec_res_value3_f); + carry_res = wrapper::vmul(carry_res, vec_res_value4_f); + + float res = wrapper::vgetlane(carry_res, 0); + res *= wrapper::vgetlane(carry_res, 1); + res *= wrapper::vgetlane(carry_res, 2); + res *= wrapper::vgetlane(carry_res, 3); + + // Compute left-over elements + for (; x < window_end_x; ++x) + { + //de-quantize input + if (std::is_same<T, uint8_t>::value) + { + res *= dequantize_qasymm8(*(input_ptr + x), iq_info); + } + else + { + res *= dequantize_qasymm8_signed(*(input_ptr + x), iq_info); + } + } + + //re-quantize result + if (std::is_same<T, uint8_t>::value) + { + res = quantize_qasymm8(res, iq_info); + } + else + { + res = quantize_qasymm8_signed(res, iq_info); + } + + *reinterpret_cast<T *>(output.ptr()) = static_cast<T>(res); + break; + } + case ReductionOperation::SUM: + case ReductionOperation::MEAN_SUM: + { + auto carry_res = wrapper::vadd(vec_res_value1, vec_res_value2); + carry_res = wrapper::vadd(carry_res, vec_res_value3); + carry_res = wrapper::vadd(carry_res, vec_res_value4); + + auto carry_paddition = + wrapper::vpadd(wrapper::vgethigh(carry_res), wrapper::vgetlow(carry_res)); + carry_paddition = wrapper::vpadd(carry_paddition, carry_paddition); + auto res = static_cast<int32_t>(wrapper::vgetlane(carry_paddition, 0)); + + // Compute left-over elements + for (; x < window_end_x; ++x) + { + res += *(input_ptr + x); + } + + if (op == ReductionOperation::MEAN_SUM) + { + const int32_t resFinal = A * (static_cast<float>(res)) + B; + + *reinterpret_cast<T *>(output.ptr()) = utils::cast::saturate_cast<T>(resFinal); + } + else + { + // Subtract accumulated offsets + res -= (in_info.dimension(0) - 1) * iq_info.offset; + *reinterpret_cast<T *>(output.ptr()) = utils::cast::saturate_cast<T>(res); + } + + 
break; + } + default: + ARM_COMPUTE_ERROR("Not supported"); + } + }, + input, output); + } +}; + +template <typename T, int S> +struct RedOpYZW +{ + /** SIMD vector tag type. */ + using ExactTagType = typename wrapper::traits::neon_vector<T, S>::tag_type; + using neon_vector = typename wrapper::traits::neon_vector<T, S>::type; + + inline void operator()(const Window &in_window, + Window &out_window, + const ITensor *in, + ITensor *out, + int axis, + const ReductionOperation op) + { + const TensorInfo in_info = *(in->info()); + const int window_step_x = 16 / sizeof(T); + const auto window_start_x_tmp = static_cast<int>(in_window.x().start()); + const auto window_end_x_tmp = static_cast<int>(in_window.x().end()); + // As it split over x-axis, need to set the correct spiltted window start and end. + const auto window_start_x = static_cast<int>(0); + const auto window_end_x = static_cast<int>(in_window.shape().x()); + + Window in_win_no_pad = in_window; + in_win_no_pad.set(Window::DimX, Window::Dimension(window_start_x_tmp, window_end_x_tmp, in_window.shape().x())); + Window out_win_no_pad = out_window; + out_win_no_pad.set(Window::DimX, + Window::Dimension(window_start_x_tmp, window_end_x_tmp, out_window.shape().x())); + + Iterator input(in, in_win_no_pad); + Iterator output(out, out_win_no_pad); + + execute_window_loop( + in_win_no_pad, + [&](const Coordinates &) + { + const auto input_ptr = reinterpret_cast<T *>(input.ptr()); + + // Compute window_step_x elements per iteration + int x = window_start_x; + for (; x <= (window_end_x - window_step_x); x += window_step_x) + { + neon_vector vec_res_value = {0}; + switch (op) + { + case ReductionOperation::ARG_IDX_MAX: + case ReductionOperation::ARG_IDX_MIN: + case ReductionOperation::MIN: + case ReductionOperation::MAX: + { + vec_res_value = wrapper::vloadq(input_ptr + x); + break; + } + case ReductionOperation::PROD: + { + vec_res_value = wrapper::vdup_n(static_cast<T>(1.f), ExactTagType{}); + break; + } + default: + { + 
vec_res_value = wrapper::vdup_n(static_cast<T>(0.f), ExactTagType{}); + break; + } + } + uint32x4x4_t vec_res_idx{{0}}; + + for (unsigned int dim = 0; dim < in_info.dimension(axis); ++dim) + { + const T *in_ptr = + reinterpret_cast<T *>(input.ptr() + x * sizeof(T) + in_info.strides_in_bytes()[axis] * dim); + const auto vec_elements = wrapper::vloadq(in_ptr); + switch (op) + { + case ReductionOperation::SUM: + case ReductionOperation::MEAN_SUM: + vec_res_value = wrapper::vadd(vec_elements, vec_res_value); + break; + case ReductionOperation::SUM_SQUARE: + vec_res_value = wrapper::vadd(wrapper::vmul(vec_elements, vec_elements), vec_res_value); + break; + case ReductionOperation::PROD: + vec_res_value = wrapper::vmul(vec_elements, vec_res_value); + break; + case ReductionOperation::ARG_IDX_MIN: + { + auto temp_vec_res_value = wrapper::vmin(vec_elements, vec_res_value); + vec_res_idx = + calculate_index(dim, temp_vec_res_value, vec_res_value, vec_res_idx, op, axis); + vec_res_value = temp_vec_res_value; + break; + } + case ReductionOperation::ARG_IDX_MAX: + { + auto temp_vec_res_value = wrapper::vmax(vec_elements, vec_res_value); + vec_res_idx = + calculate_index(dim, temp_vec_res_value, vec_res_value, vec_res_idx, op, axis); + vec_res_value = temp_vec_res_value; + break; + } + case ReductionOperation::MIN: + { + vec_res_value = wrapper::vmin(vec_elements, vec_res_value); + break; + } + case ReductionOperation::MAX: + { + vec_res_value = wrapper::vmax(vec_elements, vec_res_value); + break; + } + default: + ARM_COMPUTE_ERROR("Not supported"); + } + } + + if (op == ReductionOperation::MEAN_SUM) + { + auto vec_width_inv = + wrapper::vinv(wrapper::vdup_n(static_cast<T>(in_info.dimension(axis)), ExactTagType{})); + vec_res_value = wrapper::vmul(vec_res_value, vec_width_inv); + } + + if (op == ReductionOperation::ARG_IDX_MIN || op == ReductionOperation::ARG_IDX_MAX) + { + wrapper::vstore(reinterpret_cast<uint32_t *>(output.ptr()) + x, vec_res_idx.val[0]); + } + else + { + 
wrapper::vstore(reinterpret_cast<T *>(output.ptr() + x * sizeof(T)), vec_res_value); + } + } + + // Compute left-over elements + for (; x < window_end_x; ++x) + { + auto res_value = 0.f; + switch (op) + { + case ReductionOperation::ARG_IDX_MAX: + case ReductionOperation::ARG_IDX_MIN: + case ReductionOperation::MIN: + case ReductionOperation::MAX: + { + res_value = *(input_ptr + x); + break; + } + case ReductionOperation::PROD: + { + res_value = static_cast<T>(1.f); + break; + } + default: + { + res_value = static_cast<T>(0.f); + break; + } + } + + uint32_t res_idx = 0; + for (unsigned int dim = 0; dim < in_info.dimension(axis); ++dim) + { + const T *in_ptr = + reinterpret_cast<T *>(input.ptr() + x * sizeof(T) + in_info.strides_in_bytes()[axis] * dim); + + switch (op) + { + case ReductionOperation::SUM: + case ReductionOperation::MEAN_SUM: + res_value += *in_ptr; + break; + case ReductionOperation::SUM_SQUARE: + res_value += *in_ptr * *in_ptr; + break; + case ReductionOperation::PROD: + res_value *= *in_ptr; + break; + case ReductionOperation::ARG_IDX_MIN: + { + if (*in_ptr < res_value) + { + res_value = *in_ptr; + res_idx = dim; + } + break; + } + case ReductionOperation::ARG_IDX_MAX: + { + if (*in_ptr > res_value) + { + res_value = *in_ptr; + res_idx = dim; + } + break; + } + case ReductionOperation::MIN: + { + res_value = *in_ptr < res_value ? *in_ptr : res_value; + break; + } + case ReductionOperation::MAX: + { + res_value = *in_ptr > res_value ? 
*in_ptr : res_value; + break; + } + default: + ARM_COMPUTE_ERROR("Not supported"); + } + } + + if (op == ReductionOperation::MEAN_SUM) + { + res_value /= in_info.dimension(axis); + } + + if (op == ReductionOperation::ARG_IDX_MIN || op == ReductionOperation::ARG_IDX_MAX) + { + *(reinterpret_cast<uint32_t *>(output.ptr()) + x) = res_idx; + } + else + { + *(reinterpret_cast<T *>(output.ptr() + x * sizeof(T))) = res_value; + } + } + }, + input, output); + } +}; + +template <typename T, int S, int axis, ReductionOperation op> +struct RedOpYZW_complex +{ + /** SIMD vector tag type. */ + using ExactTagType = typename wrapper::traits::neon_vector<T, S>::tag_type; + using neon_vector = typename wrapper::traits::neon_vector<T, S>::type; + + inline void operator()( + const Window &in_window, Window &out_window, const ITensor *in, ITensor *out, int, const ReductionOperation) + { + ARM_COMPUTE_ERROR_ON(axis != 2); + ARM_COMPUTE_ERROR_ON(op != ReductionOperation::SUM); + + const TensorInfo in_info = *(in->info()); + const size_t stride_z = in_info.strides_in_bytes()[axis]; + const int window_step_x = 16 / sizeof(T); + const auto window_start_x_tmp = static_cast<int>(in_window.x().start()); + const auto window_end_x_tmp = static_cast<int>(in_window.x().end()); + // As it split over x-axis, need to set the correct spiltted window start and end. 
+ const auto window_start_x = static_cast<int>(0); + const auto window_end_x = static_cast<int>(in_window.shape().x()); + + Window in_win_no_pad = in_window; + in_win_no_pad.set(Window::DimX, Window::Dimension(window_start_x_tmp, window_end_x_tmp, in_window.shape().x())); + Window out_win_no_pad = out_window; + out_win_no_pad.set(Window::DimX, + Window::Dimension(window_start_x_tmp, window_end_x_tmp, out_window.shape().x())); + + Iterator input(in, in_win_no_pad); + Iterator output(out, out_win_no_pad); + + execute_window_loop( + in_win_no_pad, + [&](const Coordinates &) + { + // Compute window_step_x elements per iteration + int x = window_start_x; + for (; x <= (window_end_x - window_step_x); x += window_step_x) + { + neon_vector vec_res_value_0 = {0}; + neon_vector vec_res_value_1 = {0}; + + vec_res_value_0 = wrapper::vdup_n(static_cast<T>(0.f), ExactTagType{}); + vec_res_value_1 = wrapper::vdup_n(static_cast<T>(0.f), ExactTagType{}); + + T *out_ptr = reinterpret_cast<T *>(output.ptr() + 2 * x * sizeof(T)); + for (unsigned int dim = 0; dim < in_info.dimension(axis); ++dim) + { + T *in_ptr_0 = reinterpret_cast<T *>(input.ptr() + 2 * x * sizeof(T) + stride_z * dim); + T *in_ptr_1 = reinterpret_cast<T *>(input.ptr() + 2 * x * sizeof(T) + 16 + stride_z * dim); + + const auto vec_elements_0 = wrapper::vloadq(in_ptr_0); + const auto vec_elements_1 = wrapper::vloadq(in_ptr_1); + + vec_res_value_0 = wrapper::vadd(vec_elements_0, vec_res_value_0); + vec_res_value_1 = wrapper::vadd(vec_elements_1, vec_res_value_1); + } + + wrapper::vstore(out_ptr, vec_res_value_0); + wrapper::vstore(out_ptr + 4, vec_res_value_1); + } + + // Compute left-over elements + for (; x < window_end_x; ++x) + { + auto res_value_0 = 0.f; + auto res_value_1 = 0.f; + + T *out_ptr = reinterpret_cast<T *>(output.ptr() + 2 * x * sizeof(T)); + for (unsigned int dim = 0; dim < in_info.dimension(axis); ++dim) + { + T *in_ptr = reinterpret_cast<T *>(input.ptr() + 2 * x * sizeof(T) + stride_z * dim); + 
res_value_0 += *in_ptr; + res_value_1 += *(in_ptr + 1); + } + *out_ptr = res_value_0; + *(out_ptr + 1) = res_value_1; + } + }, + input, output); + } +}; + +template <typename T> +struct RedOpYZW_quantized +{ + inline void operator()(const Window &in_window, + Window &out_window, + const ITensor *in, + ITensor *out, + int axis, + const ReductionOperation op) + { + const TensorInfo in_info = *(in->info()); + const UniformQuantizationInfo iq_info = in_info.quantization_info().uniform(); + using PromotedType = typename wrapper::traits::promote<typename wrapper::traits::promote<T>::type>::type; + + const auto oq_info = out->info()->quantization_info().uniform(); + + const int window_step_x = 16 / sizeof(T); + const auto window_start_x_tmp = static_cast<int>(in_window.x().start()); + const auto window_end_x_tmp = static_cast<int>(in_window.x().end()); + // As it split over x-axis, need to set the correct spiltted window start and end. + const auto window_start_x = static_cast<int>(0); + const auto window_end_x = static_cast<int>(in_window.shape().x()); + + Window in_win_no_pad = in_window; + in_win_no_pad.set(Window::DimX, Window::Dimension(window_start_x_tmp, window_end_x_tmp, in_window.shape().x())); + Window out_win_no_pad = out_window; + out_win_no_pad.set(Window::DimX, + Window::Dimension(window_start_x_tmp, window_end_x_tmp, out_window.shape().x())); + + Iterator input(in, in_win_no_pad); + Iterator output(out, out_win_no_pad); + + using vector_type = + typename wrapper::traits::neon_bitvector<PromotedType, wrapper::traits::BitWidth::W128>::type; + using vector_type_f = typename wrapper::traits::neon_vector<float, 4>::type; + + vector_type vec_res_value1{}; + vector_type vec_res_value2{}; + vector_type vec_res_value3{}; + vector_type vec_res_value4{}; + + vector_type_f vec_res_value1_f{}; + vector_type_f vec_res_value2_f{}; + vector_type_f vec_res_value3_f{}; + vector_type_f vec_res_value4_f{}; + + const float in_offset = static_cast<float>(iq_info.offset); + const 
float in_scale = iq_info.scale; + + const float out_offset = static_cast<float>(oq_info.offset); + const float out_scale = oq_info.scale; + + const float num_elements = static_cast<float>(in_info.dimension(axis)); + + const float A = in_scale / (out_scale * num_elements); + const float B = out_offset - (in_scale * in_offset) / (out_scale); + + const auto vec_A = wrapper::vdup_n(static_cast<float>(A), wrapper::traits::vector_128_tag{}); + const auto vec_B = wrapper::vdup_n(static_cast<float>(B), wrapper::traits::vector_128_tag{}); + + execute_window_loop( + in_win_no_pad, + [&](const Coordinates &) + { + const auto input_ptr = reinterpret_cast<T *>(input.ptr()); + + // Compute window_step_x elements per iteration + int x = window_start_x; + for (; x <= (window_end_x - window_step_x); x += window_step_x) + { + uint32x4x4_t vec_res_idx{{0}}; + vec_res_value1 = wrapper::vdup_n(static_cast<PromotedType>(0), wrapper::traits::vector_128_tag{}); + vec_res_value2 = wrapper::vdup_n(static_cast<PromotedType>(0), wrapper::traits::vector_128_tag{}); + vec_res_value3 = wrapper::vdup_n(static_cast<PromotedType>(0), wrapper::traits::vector_128_tag{}); + vec_res_value4 = wrapper::vdup_n(static_cast<PromotedType>(0), wrapper::traits::vector_128_tag{}); + + vec_res_value1_f = wrapper::vdup_n(static_cast<float>(1), wrapper::traits::vector_128_tag{}); + vec_res_value2_f = wrapper::vdup_n(static_cast<float>(1), wrapper::traits::vector_128_tag{}); + vec_res_value3_f = wrapper::vdup_n(static_cast<float>(1), wrapper::traits::vector_128_tag{}); + vec_res_value4_f = wrapper::vdup_n(static_cast<float>(1), wrapper::traits::vector_128_tag{}); + + auto vec_res_value = wrapper::vloadq(input_ptr + x); + + for (unsigned int index_dim = 0; index_dim < in_info.dimension(axis); ++index_dim) + { + const T *in_ptr = input_ptr + x + in_info.strides_in_bytes()[axis] * index_dim; + const auto vec_elements = wrapper::vloadq(in_ptr); + switch (op) + { + case ReductionOperation::SUM: + case 
ReductionOperation::MEAN_SUM: + { + const auto temp16x8t_1 = wrapper::vmovl(wrapper::vgetlow(vec_elements)); + const auto temp16x8t_2 = wrapper::vmovl(wrapper::vgethigh(vec_elements)); + + const auto temp32x4t_1 = wrapper::vmovl(wrapper::vgetlow(temp16x8t_1)); + const auto temp32x4t_2 = wrapper::vmovl(wrapper::vgethigh(temp16x8t_1)); + const auto temp32x4t_3 = wrapper::vmovl(wrapper::vgetlow(temp16x8t_2)); + const auto temp32x4t_4 = wrapper::vmovl(wrapper::vgethigh(temp16x8t_2)); + + vec_res_value1 = wrapper::vadd(temp32x4t_1, vec_res_value1); + vec_res_value2 = wrapper::vadd(temp32x4t_2, vec_res_value2); + vec_res_value3 = wrapper::vadd(temp32x4t_3, vec_res_value3); + vec_res_value4 = wrapper::vadd(temp32x4t_4, vec_res_value4); + break; + } + case ReductionOperation::PROD: + { + const auto offset32x4f_4 = wrapper::vdup_n(static_cast<float>(iq_info.offset), + wrapper::traits::vector_128_tag{}); + const auto scale32x4f_4 = + wrapper::vdup_n(iq_info.scale, wrapper::traits::vector_128_tag{}); + + const auto temp16x8t_1 = wrapper::vmovl(wrapper::vgetlow(vec_elements)); + const auto temp16x8t_2 = wrapper::vmovl(wrapper::vgethigh(vec_elements)); + + const auto temp32x4t_1 = wrapper::vmovl(wrapper::vgetlow(temp16x8t_1)); + const auto temp32x4t_2 = wrapper::vmovl(wrapper::vgethigh(temp16x8t_1)); + const auto temp32x4t_3 = wrapper::vmovl(wrapper::vgetlow(temp16x8t_2)); + const auto temp32x4t_4 = wrapper::vmovl(wrapper::vgethigh(temp16x8t_2)); + + auto temp32x4f_1 = wrapper::vcvt<float>(temp32x4t_1); + auto temp32x4f_2 = wrapper::vcvt<float>(temp32x4t_2); + auto temp32x4f_3 = wrapper::vcvt<float>(temp32x4t_3); + auto temp32x4f_4 = wrapper::vcvt<float>(temp32x4t_4); + + //de-quantize vec_elements + temp32x4f_1 = wrapper::vmul(wrapper::vsub(temp32x4f_1, offset32x4f_4), scale32x4f_4); + temp32x4f_2 = wrapper::vmul(wrapper::vsub(temp32x4f_2, offset32x4f_4), scale32x4f_4); + temp32x4f_3 = wrapper::vmul(wrapper::vsub(temp32x4f_3, offset32x4f_4), scale32x4f_4); + temp32x4f_4 = 
wrapper::vmul(wrapper::vsub(temp32x4f_4, offset32x4f_4), scale32x4f_4); + + vec_res_value1_f = wrapper::vmul(temp32x4f_1, vec_res_value1_f); + vec_res_value2_f = wrapper::vmul(temp32x4f_2, vec_res_value2_f); + vec_res_value3_f = wrapper::vmul(temp32x4f_3, vec_res_value3_f); + vec_res_value4_f = wrapper::vmul(temp32x4f_4, vec_res_value4_f); + break; + } + case ReductionOperation::ARG_IDX_MIN: + { + auto temp_vec_res_value = wrapper::vmin(vec_elements, vec_res_value); + vec_res_idx = calculate_index_quantized(index_dim, temp_vec_res_value, vec_res_value, + vec_res_idx, op, axis); + vec_res_value = temp_vec_res_value; + break; + } + case ReductionOperation::ARG_IDX_MAX: + { + auto temp_vec_res_value = wrapper::vmax(vec_elements, vec_res_value); + vec_res_idx = calculate_index_quantized(index_dim, temp_vec_res_value, vec_res_value, + vec_res_idx, op, axis); + vec_res_value = temp_vec_res_value; + break; + } + case ReductionOperation::MIN: + { + vec_res_value = wrapper::vmin(vec_elements, vec_res_value); + break; + } + case ReductionOperation::MAX: + { + vec_res_value = wrapper::vmax(vec_elements, vec_res_value); + break; + } + default: + ARM_COMPUTE_ERROR("Not supported"); + } + } + + switch (op) + { + case ReductionOperation::ARG_IDX_MIN: + case ReductionOperation::ARG_IDX_MAX: + { + wrapper::vstore(reinterpret_cast<uint32_t *>(output.ptr() + 4 * x), vec_res_idx.val[0]); + wrapper::vstore(reinterpret_cast<uint32_t *>(output.ptr() + 4 * x) + 4, vec_res_idx.val[1]); + wrapper::vstore(reinterpret_cast<uint32_t *>(output.ptr() + 4 * x) + 8, vec_res_idx.val[2]); + wrapper::vstore(reinterpret_cast<uint32_t *>(output.ptr() + 4 * x) + 12, + vec_res_idx.val[3]); + break; + } + case ReductionOperation::MIN: + case ReductionOperation::MAX: + { + wrapper::vstore(reinterpret_cast<T *>(output.ptr() + x), vec_res_value); + break; + } + case ReductionOperation::SUM: + { + // Subtract offsets + auto offsets = vdupq_n_s32((in_info.dimension(axis) - 1) * iq_info.offset); + + auto 
vec_res_s_value1 = wrapper::vreinterpret(vec_res_value1); + auto vec_res_s_value2 = wrapper::vreinterpret(vec_res_value2); + auto vec_res_s_value3 = wrapper::vreinterpret(vec_res_value3); + auto vec_res_s_value4 = wrapper::vreinterpret(vec_res_value4); + + vec_res_s_value1 = wrapper::vsub(vec_res_s_value1, offsets); + vec_res_s_value2 = wrapper::vsub(vec_res_s_value2, offsets); + vec_res_s_value3 = wrapper::vsub(vec_res_s_value3, offsets); + vec_res_s_value4 = wrapper::vsub(vec_res_s_value4, offsets); + + const auto temp16x8t_1 = + wrapper::vcombine(wrapper::vqmovn(vec_res_s_value1), wrapper::vqmovn(vec_res_s_value2)); + const auto temp16x8t_2 = + wrapper::vcombine(wrapper::vqmovn(vec_res_s_value3), wrapper::vqmovn(vec_res_s_value4)); + + combine_and_store<T>(temp16x8t_1, temp16x8t_2, output, x); + break; + } + case ReductionOperation::MEAN_SUM: + { + vec_res_value1_f = wrapper::vmla(vec_B, wrapper::vcvt<float>(vec_res_value1), vec_A); + vec_res_value2_f = wrapper::vmla(vec_B, wrapper::vcvt<float>(vec_res_value2), vec_A); + vec_res_value3_f = wrapper::vmla(vec_B, wrapper::vcvt<float>(vec_res_value3), vec_A); + vec_res_value4_f = wrapper::vmla(vec_B, wrapper::vcvt<float>(vec_res_value4), vec_A); + +#ifdef __aarch64__ + vec_res_value1 = wrapper::vcvta<PromotedType>(vec_res_value1_f); + vec_res_value2 = wrapper::vcvta<PromotedType>(vec_res_value2_f); + vec_res_value3 = wrapper::vcvta<PromotedType>(vec_res_value3_f); + vec_res_value4 = wrapper::vcvta<PromotedType>(vec_res_value4_f); +#else // defined(__aarch64__) + vec_res_value1 = wrapper::vcvt<PromotedType>(vec_res_value1_f); + vec_res_value2 = wrapper::vcvt<PromotedType>(vec_res_value2_f); + vec_res_value3 = wrapper::vcvt<PromotedType>(vec_res_value3_f); + vec_res_value4 = wrapper::vcvt<PromotedType>(vec_res_value4_f); +#endif // __aarch64__ + + const auto temp16x8t_1 = + wrapper::vcombine(wrapper::vqmovn(vec_res_value1), wrapper::vqmovn(vec_res_value2)); + const auto temp16x8t_2 = + 
wrapper::vcombine(wrapper::vqmovn(vec_res_value3), wrapper::vqmovn(vec_res_value4)); + auto res = wrapper::vcombine(wrapper::vqmovn(temp16x8t_1), wrapper::vqmovn(temp16x8t_2)); + + wrapper::vstore(reinterpret_cast<T *>(output.ptr() + x), res); + break; + } + case ReductionOperation::PROD: + { + const auto offset32x4f_4 = + wrapper::vdup_n(static_cast<float>(iq_info.offset), wrapper::traits::vector_128_tag{}); + const auto iscale32x4f_4 = vinvq_f32(vdupq_n_f32(iq_info.scale)); + + //re-quantize + vec_res_value1_f = + wrapper::vadd(wrapper::vmul(vec_res_value1_f, iscale32x4f_4), offset32x4f_4); + vec_res_value2_f = + wrapper::vadd(wrapper::vmul(vec_res_value2_f, iscale32x4f_4), offset32x4f_4); + vec_res_value3_f = + wrapper::vadd(wrapper::vmul(vec_res_value3_f, iscale32x4f_4), offset32x4f_4); + vec_res_value4_f = + wrapper::vadd(wrapper::vmul(vec_res_value4_f, iscale32x4f_4), offset32x4f_4); + + vec_res_value1 = wrapper::vcvt<T>(vec_res_value1_f); + vec_res_value2 = wrapper::vcvt<T>(vec_res_value2_f); + vec_res_value3 = wrapper::vcvt<T>(vec_res_value3_f); + vec_res_value4 = wrapper::vcvt<T>(vec_res_value4_f); + + const auto temp16x8t_1 = + wrapper::vcombine(wrapper::vqmovn(vec_res_value1), wrapper::vqmovn(vec_res_value2)); + const auto temp16x8t_2 = + wrapper::vcombine(wrapper::vqmovn(vec_res_value3), wrapper::vqmovn(vec_res_value4)); + auto res = wrapper::vcombine(wrapper::vqmovn(temp16x8t_1), wrapper::vqmovn(temp16x8t_2)); + + wrapper::vstore(reinterpret_cast<T *>(output.ptr() + x), res); + break; + } + default: + ARM_COMPUTE_ERROR("Not supported"); + } + } + + // Compute left-over elements + for (; x < window_end_x; ++x) + { + float res_value = 0.f; + int32_t res_value_q = 0; + + switch (op) + { + case ReductionOperation::ARG_IDX_MAX: + case ReductionOperation::ARG_IDX_MIN: + case ReductionOperation::MIN: + case ReductionOperation::MAX: + { + res_value = *(input_ptr + x); + break; + } + case ReductionOperation::PROD: + { + res_value = static_cast<T>(1.0f); + 
break; + } + default: + { + res_value = static_cast<T>(0.0f); + break; + } + } + uint32_t res_idx = 0; + + for (unsigned int dim = 0; dim < in_info.dimension(axis); ++dim) + { + const T *in_ptr = + reinterpret_cast<T *>(input.ptr() + x + in_info.strides_in_bytes()[axis] * dim); + switch (op) + { + case ReductionOperation::SUM: + { + res_value += *in_ptr; + break; + } + case ReductionOperation::MEAN_SUM: + { + res_value_q += *in_ptr; + break; + } + case ReductionOperation::SUM_SQUARE: + { + res_value += *in_ptr * *in_ptr; + break; + } + case ReductionOperation::PROD: + { + //de-quantize input + if (std::is_same<T, uint8_t>::value) + { + res_value *= dequantize_qasymm8(*in_ptr, iq_info); + } + else + { + res_value *= dequantize_qasymm8_signed(*in_ptr, iq_info); + } + break; + } + case ReductionOperation::ARG_IDX_MIN: + { + if (*in_ptr < res_value) + { + res_value = *in_ptr; + res_idx = dim; + } + break; + } + case ReductionOperation::ARG_IDX_MAX: + { + if (*in_ptr > res_value) + { + res_value = *in_ptr; + res_idx = dim; + } + break; + } + case ReductionOperation::MIN: + { + res_value = *in_ptr < res_value ? *in_ptr : res_value; + break; + } + case ReductionOperation::MAX: + { + res_value = *in_ptr > res_value ? 
*in_ptr : res_value; + break; + } + default: + ARM_COMPUTE_ERROR("Not supported"); + } + } + + switch (op) + { + case ReductionOperation::MEAN_SUM: + { + // Apply previously calculated coefficients (with rounding on aarch64) +#ifdef __aarch64__ + const int32_t res = + arm_compute::support::cpp11::round(A * (static_cast<float>(res_value_q)) + B); +#else // defined(__aarch64__) + const int32_t res = A * (static_cast<float>(res_value_q)) + B; +#endif // __aarch64__ + *reinterpret_cast<T *>(output.ptr() + x) = utils::cast::saturate_cast<T>(res); + break; + } + case ReductionOperation::SUM: + { + // Subtract accumulated offsets + res_value -= (in_info.dimension(axis) - 1) * iq_info.offset; + *reinterpret_cast<T *>(output.ptr() + x) = utils::cast::saturate_cast<T>(res_value); + break; + } + case ReductionOperation::PROD: + { + //re-quantize result + T res = 0; + if (std::is_same<T, uint8_t>::value) + { + res = quantize_qasymm8(res_value, iq_info); + } + else + { + res = quantize_qasymm8_signed(res_value, iq_info); + } + *(reinterpret_cast<T *>(output.ptr() + x)) = res; + break; + } + case ReductionOperation::ARG_IDX_MIN: + case ReductionOperation::ARG_IDX_MAX: + { + *(reinterpret_cast<uint32_t *>(output.ptr() + x * 4)) = res_idx; + break; + } + default: + *(reinterpret_cast<T *>(output.ptr() + x)) = res_value; + } + } + }, + input, output); + } +}; + +} // namespace arm_compute +#endif // ACL_SRC_CPU_KERNELS_REDUCTION_LAYER_GENERIC_NEON_IMPL_H diff --git a/src/cpu/kernels/reduction_layer/generic/neon/impl_fp16.h b/src/cpu/kernels/reduction_layer/generic/neon/impl_fp16.h new file mode 100644 index 0000000000..c7ca36d5e8 --- /dev/null +++ b/src/cpu/kernels/reduction_layer/generic/neon/impl_fp16.h @@ -0,0 +1,718 @@ +/* + * Copyright (c) 2024 Arm Limited. 
+ * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ */ +#ifndef ACL_SRC_CPU_KERNELS_REDUCTION_LAYER_GENERIC_NEON_IMPL_FP16_H +#define ACL_SRC_CPU_KERNELS_REDUCTION_LAYER_GENERIC_NEON_IMPL_FP16_H + +#include "arm_compute/core/Coordinates.h" +#include "arm_compute/core/Helpers.h" +#include "arm_compute/core/ITensor.h" +#include "arm_compute/core/TensorInfo.h" + +#include "src/core/NEON/NEMath.h" +#include "src/core/NEON/wrapper/wrapper.h" +#include "support/SaturateCast.h" + +#include <arm_neon.h> + +namespace arm_compute +{ +// Helper function that calls vqmovun/vqmvn, vcombine and vstore, allows templating of RedOpYZW_quantized +void combine_and_store(int16x8_t t1, int16x8_t t2, Iterator &output, int offset = 0) +{ + auto res = wrapper::vcombine(wrapper::vqmovn(t1), wrapper::vqmovn(t2)); + wrapper::vstore(reinterpret_cast<int8_t *>(output.ptr() + offset), res); +} + +#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +uint32x4x4_t +calculate_index(uint32_t idx, float16x8_t a, float16x8_t b, uint32x4x4_t c, ReductionOperation op, int axis) +{ + uint32x4x2_t mask{0}; + uint16x8_t mask_u16{0}; + if (op == ReductionOperation::ARG_IDX_MIN) + { + mask_u16 = wrapper::vcgt(b, a); + } + else + { + mask_u16 = wrapper::vclt(b, a); + } + mask.val[0] = wrapper::vmovl(wrapper::vgetlow(mask_u16)); + mask.val[1] = wrapper::vmovl(wrapper::vgethigh(mask_u16)); + uint32x4x2_t vec_idx = {{{idx + 0, idx + 1, idx + 2, idx + 3}, {idx + 4, idx + 5, idx + 6, idx + 7}}}; + if (axis != 0) + { + vec_idx.val[0] = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{}); + vec_idx.val[1] = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{}); + } + uint32x4x4_t res = {wrapper::vbsl(mask.val[0], vec_idx.val[0], c.val[0]), + wrapper::vbsl(mask.val[1], vec_idx.val[1], c.val[1]), 0, 0}; + + return res; +} + +// Helper function to calculate the minimum value of the input vector. All the elements in the output vector contain the min value. 
+inline float16x4_t calculate_min(float16x8_t in) +{ + auto pmin = wrapper::vpmin(wrapper::vgethigh(in), wrapper::vgetlow(in)); + pmin = wrapper::vpmin(pmin, pmin); + return wrapper::vpmin(pmin, pmin); +} +// Helper function to calculate the maximum value of the input vector. All the elements in the output vector contain the max value. +inline float16x4_t calculate_max(float16x8_t in) +{ + auto pmax = wrapper::vpmax(wrapper::vgethigh(in), wrapper::vgetlow(in)); + pmax = wrapper::vpmax(pmax, pmax); + return wrapper::vpmax(pmax, pmax); +} + +uint32_t calculate_vector_index(uint32x4x4_t vec_res_idx, float16x8_t vec_res_value, ReductionOperation op) +{ + uint32x4x2_t res_idx_mask{0}; + uint32x4_t mask_ones = vdupq_n_u32(0xFFFFFFFF); + uint16x8_t mask_u16; + if (op == ReductionOperation::ARG_IDX_MIN) + { + auto pmin = calculate_min(vec_res_value); + mask_u16 = wrapper::vceq(vec_res_value, wrapper::vcombine(pmin, pmin)); + } + else + { + auto pmax = calculate_max(vec_res_value); + mask_u16 = wrapper::vceq(vec_res_value, wrapper::vcombine(pmax, pmax)); + } + + // Widen vectors + auto wide_u32_1 = + wrapper::vorr(vshll_n_u16(wrapper::vgetlow(mask_u16), 8), wrapper::vmovl(wrapper::vgetlow(mask_u16))); + auto wide_u32_2 = + wrapper::vorr(vshll_n_u16(wrapper::vgethigh(mask_u16), 8), wrapper::vmovl(wrapper::vgethigh(mask_u16))); + res_idx_mask.val[0] = wrapper::vand(vec_res_idx.val[0], wide_u32_1); + res_idx_mask.val[1] = wrapper::vand(vec_res_idx.val[1], wide_u32_2); + res_idx_mask.val[0] = wrapper::vadd(res_idx_mask.val[0], mask_ones); + res_idx_mask.val[1] = wrapper::vadd(res_idx_mask.val[1], mask_ones); + + uint32_t res = 0xFFFFFFFF; + uint32_t iter = 0; + do + { + auto pmin = wrapper::vpmin(wrapper::vgethigh(res_idx_mask.val[iter]), wrapper::vgetlow(res_idx_mask.val[iter])); + pmin = wrapper::vpmin(pmin, pmin); + res = std::min(wrapper::vgetlane(pmin, 0), res); + iter++; + } while (iter < 2); + + return (res - 0xFFFFFFFF); +} +#endif // 
__ARM_FEATURE_FP16_VECTOR_ARITHMETIC + +template <class F> +class Reducer +{ +public: + static void reduceX(const Window &window, const ITensor *input, ITensor *output, F f, const ReductionOperation op) + { + // Set out window + Window out_window(window); + out_window.set(Window::DimX, Window::Dimension(0, 1, 1)); + + f(window, out_window, input, output, op); + } + static void reduceY(const Window &window, const ITensor *input, ITensor *output, F f, const ReductionOperation op) + { + // Set in window + Window in_window(window); + Window out_window(window); + + in_window.set(Window::DimY, Window::Dimension(0, 1, 1)); + out_window.set(Window::DimY, Window::Dimension(0, output->info()->dimension(1), output->info()->dimension(1))); + + f(in_window, out_window, input, output, 1, op); + } + static void reduceZ(const Window &window, const ITensor *input, ITensor *output, F f, const ReductionOperation op) + { + // Set in window + Window in_window(window); + Window out_window(window); + + in_window.set(Window::DimZ, Window::Dimension(0, 1, 1)); + out_window.set(Window::DimZ, Window::Dimension(0, output->info()->dimension(2), output->info()->dimension(2))); + + f(in_window, out_window, input, output, 2, op); + } + static void reduceW(const Window &window, const ITensor *input, ITensor *output, F f, const ReductionOperation op) + { + // Set in/out window + Window in_window(window); + Window out_window(window); + + in_window.set(3, Window::Dimension(0, 1, 1)); + out_window.set(3, Window::Dimension(0, 1, 1)); + + f(in_window, out_window, input, output, 3, op); + } +}; + +template <typename T, int S> +struct RedOpX +{ + /** SIMD vector tag type. 
*/ + using ExactTagType = typename wrapper::traits::neon_vector<T, S>::tag_type; + + inline void operator()( + const Window &in_window, Window &out_window, const ITensor *in, ITensor *out, const ReductionOperation op) + { + const size_t input_dim_0 = in->info()->dimension(0); + const int window_step_x = 16 / sizeof(T); + const auto window_start_x = static_cast<int>(in_window.x().start()); + const auto window_end_x = static_cast<int>(in_window.x().end()); + + Window in_win_no_pad = in_window; + in_win_no_pad.set(Window::DimX, Window::Dimension(0, 1, 1)); + + Iterator input(in, in_win_no_pad); + Iterator output(out, out_window); + + execute_window_loop( + in_win_no_pad, + [&](const Coordinates &) + { + const auto input_ptr = reinterpret_cast<const T *>(input.ptr()); + + auto init_res_value = static_cast<T>(0.f); + switch (op) + { + case ReductionOperation::ARG_IDX_MAX: + case ReductionOperation::ARG_IDX_MIN: + case ReductionOperation::MIN: + case ReductionOperation::MAX: + { + init_res_value = static_cast<T>(*input_ptr); + break; + } + case ReductionOperation::PROD: + { + init_res_value = static_cast<T>(1.f); + break; + } + default: + break; + } + auto vec_res_value = wrapper::vdup_n(init_res_value, ExactTagType{}); + uint32x4x4_t vec_res_idx{{0}}; + + // Compute window_step_x elements per iteration + int x = window_start_x; + for (; x <= (window_end_x - window_step_x); x += window_step_x) + { + const auto vec_elements = wrapper::vloadq(input_ptr + x); + switch (op) + { + case ReductionOperation::SUM_SQUARE: + vec_res_value = wrapper::vadd(wrapper::vmul(vec_elements, vec_elements), vec_res_value); + break; + case ReductionOperation::MEAN_SUM: + case ReductionOperation::SUM: + vec_res_value = wrapper::vadd(vec_elements, vec_res_value); + break; + case ReductionOperation::PROD: + vec_res_value = wrapper::vmul(vec_elements, vec_res_value); + break; + case ReductionOperation::ARG_IDX_MIN: + { + auto temp_vec_res_value = wrapper::vmin(vec_elements, vec_res_value); + 
vec_res_idx = calculate_index(x, temp_vec_res_value, vec_res_value, vec_res_idx, op, 0); + vec_res_value = temp_vec_res_value; + break; + } + case ReductionOperation::ARG_IDX_MAX: + { + auto temp_vec_res_value = wrapper::vmax(vec_elements, vec_res_value); + vec_res_idx = calculate_index(x, temp_vec_res_value, vec_res_value, vec_res_idx, op, 0); + vec_res_value = temp_vec_res_value; + break; + } + case ReductionOperation::MIN: + { + vec_res_value = wrapper::vmin(vec_elements, vec_res_value); + break; + } + case ReductionOperation::MAX: + { + vec_res_value = wrapper::vmax(vec_elements, vec_res_value); + break; + } + default: + ARM_COMPUTE_ERROR("Not supported"); + } + } + + switch (op) + { + case ReductionOperation::SUM: + case ReductionOperation::MEAN_SUM: + case ReductionOperation::SUM_SQUARE: + { +#ifdef ARM_COMPUTE_DEBUG_ENABLED + auto res = static_cast<T>(0.f); + for (int i = 0; i < S; ++i) + { + res += wrapper::vgetlane(vec_res_value, i); + } +#else // ARM_COMPUTE_DEBUG_ENABLED + auto carry_res = + wrapper::vpadd(wrapper::vgethigh(vec_res_value), wrapper::vgetlow(vec_res_value)); + for (int i = 0; i < S / 4; ++i) + { + carry_res = wrapper::vpadd(carry_res, carry_res); + } + auto res = wrapper::vgetlane(carry_res, 0); +#endif // ARM_COMPUTE_DEBUG_ENABLED + if (op == ReductionOperation::SUM_SQUARE) + { + // Compute left-over elements + for (; x < window_end_x; ++x) + { + res += (*(input_ptr + x)) * (*(input_ptr + x)); + } + } + else + { + // Compute left-over elements + for (; x < window_end_x; ++x) + { + res += *(input_ptr + x); + } + } + + if (op == ReductionOperation::MEAN_SUM) + { + res /= input_dim_0; + } + + *(reinterpret_cast<T *>(output.ptr())) = res; + break; + } + case ReductionOperation::PROD: + { + auto carry_res = + wrapper::vmul(wrapper::vgethigh(vec_res_value), wrapper::vgetlow(vec_res_value)); + T res = 1; + for (int i = 0; i < S / 2; ++i) + { + res *= wrapper::vgetlane(carry_res, i); + } + + // Compute left-over elements + for (; x < 
window_end_x; ++x) + { + res *= *(input_ptr + x); + } + + *(reinterpret_cast<T *>(output.ptr())) = res; + break; + } + case ReductionOperation::ARG_IDX_MIN: + { + auto idx = calculate_vector_index(vec_res_idx, vec_res_value, op); + auto res = static_cast<T>(wrapper::vgetlane(calculate_min(vec_res_value), 0)); + + // Compute left-over elements + for (; x < window_end_x; ++x) + { + if (*(input_ptr + x) < res) + { + idx = x; + res = *(input_ptr + x); + } + } + *(reinterpret_cast<uint32_t *>(output.ptr())) = idx; + break; + } + case ReductionOperation::ARG_IDX_MAX: + { + auto idx = calculate_vector_index(vec_res_idx, vec_res_value, op); + auto res = static_cast<T>(wrapper::vgetlane(calculate_max(vec_res_value), 0)); + + // Compute left-over elements + for (; x < window_end_x; ++x) + { + if (*(input_ptr + x) > res) + { + idx = x; + res = *(input_ptr + x); + } + } + *(reinterpret_cast<uint32_t *>(output.ptr())) = idx; + break; + } + case ReductionOperation::MIN: + { + auto res = static_cast<T>(wrapper::vgetlane(calculate_min(vec_res_value), 0)); + + // Compute left-over elements + for (; x < window_end_x; ++x) + { + res = *(input_ptr + x) < res ? *(input_ptr + x) : res; + } + *(reinterpret_cast<T *>(output.ptr())) = res; + break; + } + case ReductionOperation::MAX: + { + auto res = static_cast<T>(wrapper::vgetlane(calculate_max(vec_res_value), 0)); + + // Compute left-over elements + for (; x < window_end_x; ++x) + { + res = *(input_ptr + x) > res ? *(input_ptr + x) : res; + } + *(reinterpret_cast<T *>(output.ptr())) = res; + break; + } + default: + ARM_COMPUTE_ERROR("Not supported"); + } + }, + input, output); + } +}; + +template <typename T, int S> +struct RedOpYZW +{ + /** SIMD vector tag type. 
*/ + using ExactTagType = typename wrapper::traits::neon_vector<T, S>::tag_type; + using neon_vector = typename wrapper::traits::neon_vector<T, S>::type; + + inline void operator()(const Window &in_window, + Window &out_window, + const ITensor *in, + ITensor *out, + int axis, + const ReductionOperation op) + { + const TensorInfo in_info = *(in->info()); + const int window_step_x = 16 / sizeof(T); + const auto window_start_x_tmp = static_cast<int>(in_window.x().start()); + const auto window_end_x_tmp = static_cast<int>(in_window.x().end()); + // As it split over x-axis, need to set the correct spiltted window start and end. + const auto window_start_x = static_cast<int>(0); + const auto window_end_x = static_cast<int>(in_window.shape().x()); + + Window in_win_no_pad = in_window; + in_win_no_pad.set(Window::DimX, Window::Dimension(window_start_x_tmp, window_end_x_tmp, in_window.shape().x())); + Window out_win_no_pad = out_window; + out_win_no_pad.set(Window::DimX, + Window::Dimension(window_start_x_tmp, window_end_x_tmp, out_window.shape().x())); + + Iterator input(in, in_win_no_pad); + Iterator output(out, out_win_no_pad); + + execute_window_loop( + in_win_no_pad, + [&](const Coordinates &) + { + const auto input_ptr = reinterpret_cast<T *>(input.ptr()); + + // Compute window_step_x elements per iteration + int x = window_start_x; + for (; x <= (window_end_x - window_step_x); x += window_step_x) + { + neon_vector vec_res_value = {0}; + switch (op) + { + case ReductionOperation::ARG_IDX_MAX: + case ReductionOperation::ARG_IDX_MIN: + case ReductionOperation::MIN: + case ReductionOperation::MAX: + { + vec_res_value = wrapper::vloadq(input_ptr + x); + break; + } + case ReductionOperation::PROD: + { + vec_res_value = wrapper::vdup_n(static_cast<T>(1.f), ExactTagType{}); + break; + } + default: + { + vec_res_value = wrapper::vdup_n(static_cast<T>(0.f), ExactTagType{}); + break; + } + } + uint32x4x4_t vec_res_idx{{0}}; + + for (unsigned int dim = 0; dim < 
in_info.dimension(axis); ++dim) + { + const T *in_ptr = + reinterpret_cast<T *>(input.ptr() + x * sizeof(T) + in_info.strides_in_bytes()[axis] * dim); + const auto vec_elements = wrapper::vloadq(in_ptr); + switch (op) + { + case ReductionOperation::SUM: + case ReductionOperation::MEAN_SUM: + vec_res_value = wrapper::vadd(vec_elements, vec_res_value); + break; + case ReductionOperation::SUM_SQUARE: + vec_res_value = wrapper::vadd(wrapper::vmul(vec_elements, vec_elements), vec_res_value); + break; + case ReductionOperation::PROD: + vec_res_value = wrapper::vmul(vec_elements, vec_res_value); + break; + case ReductionOperation::ARG_IDX_MIN: + { + auto temp_vec_res_value = wrapper::vmin(vec_elements, vec_res_value); + vec_res_idx = + calculate_index(dim, temp_vec_res_value, vec_res_value, vec_res_idx, op, axis); + vec_res_value = temp_vec_res_value; + break; + } + case ReductionOperation::ARG_IDX_MAX: + { + auto temp_vec_res_value = wrapper::vmax(vec_elements, vec_res_value); + vec_res_idx = + calculate_index(dim, temp_vec_res_value, vec_res_value, vec_res_idx, op, axis); + vec_res_value = temp_vec_res_value; + break; + } + case ReductionOperation::MIN: + { + vec_res_value = wrapper::vmin(vec_elements, vec_res_value); + break; + } + case ReductionOperation::MAX: + { + vec_res_value = wrapper::vmax(vec_elements, vec_res_value); + break; + } + default: + ARM_COMPUTE_ERROR("Not supported"); + } + } + + if (op == ReductionOperation::MEAN_SUM) + { + auto vec_width_inv = + wrapper::vinv(wrapper::vdup_n(static_cast<T>(in_info.dimension(axis)), ExactTagType{})); + vec_res_value = wrapper::vmul(vec_res_value, vec_width_inv); + } + + if (op == ReductionOperation::ARG_IDX_MIN || op == ReductionOperation::ARG_IDX_MAX) + { + wrapper::vstore(reinterpret_cast<uint32_t *>(output.ptr()) + x, vec_res_idx.val[0]); +#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + if (std::is_same<T, float16_t>::value) + { + wrapper::vstore(reinterpret_cast<uint32_t *>(output.ptr()) + x + 4, 
vec_res_idx.val[1]); + } +#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + } + else + { + wrapper::vstore(reinterpret_cast<T *>(output.ptr() + x * sizeof(T)), vec_res_value); + } + } + + // Compute left-over elements + for (; x < window_end_x; ++x) + { + auto res_value = 0.f; + switch (op) + { + case ReductionOperation::ARG_IDX_MAX: + case ReductionOperation::ARG_IDX_MIN: + case ReductionOperation::MIN: + case ReductionOperation::MAX: + { + res_value = *(input_ptr + x); + break; + } + case ReductionOperation::PROD: + { + res_value = static_cast<T>(1.f); + break; + } + default: + { + res_value = static_cast<T>(0.f); + break; + } + } + + uint32_t res_idx = 0; + for (unsigned int dim = 0; dim < in_info.dimension(axis); ++dim) + { + const T *in_ptr = + reinterpret_cast<T *>(input.ptr() + x * sizeof(T) + in_info.strides_in_bytes()[axis] * dim); + + switch (op) + { + case ReductionOperation::SUM: + case ReductionOperation::MEAN_SUM: + res_value += *in_ptr; + break; + case ReductionOperation::SUM_SQUARE: + res_value += *in_ptr * *in_ptr; + break; + case ReductionOperation::PROD: + res_value *= *in_ptr; + break; + case ReductionOperation::ARG_IDX_MIN: + { + if (*in_ptr < res_value) + { + res_value = *in_ptr; + res_idx = dim; + } + break; + } + case ReductionOperation::ARG_IDX_MAX: + { + if (*in_ptr > res_value) + { + res_value = *in_ptr; + res_idx = dim; + } + break; + } + case ReductionOperation::MIN: + { + res_value = *in_ptr < res_value ? *in_ptr : res_value; + break; + } + case ReductionOperation::MAX: + { + res_value = *in_ptr > res_value ? 
*in_ptr : res_value; + break; + } + default: + ARM_COMPUTE_ERROR("Not supported"); + } + } + + if (op == ReductionOperation::MEAN_SUM) + { + res_value /= in_info.dimension(axis); + } + + if (op == ReductionOperation::ARG_IDX_MIN || op == ReductionOperation::ARG_IDX_MAX) + { + *(reinterpret_cast<uint32_t *>(output.ptr()) + x) = res_idx; + } + else + { + *(reinterpret_cast<T *>(output.ptr() + x * sizeof(T))) = res_value; + } + } + }, + input, output); + } +}; + +template <typename T, int S, int axis, ReductionOperation op> +struct RedOpYZW_complex +{ + /** SIMD vector tag type. */ + using ExactTagType = typename wrapper::traits::neon_vector<T, S>::tag_type; + using neon_vector = typename wrapper::traits::neon_vector<T, S>::type; + + inline void operator()( + const Window &in_window, Window &out_window, const ITensor *in, ITensor *out, int, const ReductionOperation) + { + ARM_COMPUTE_ERROR_ON(axis != 2); + ARM_COMPUTE_ERROR_ON(op != ReductionOperation::SUM); + + const TensorInfo in_info = *(in->info()); + const size_t stride_z = in_info.strides_in_bytes()[axis]; + const int window_step_x = 16 / sizeof(T); + const auto window_start_x_tmp = static_cast<int>(in_window.x().start()); + const auto window_end_x_tmp = static_cast<int>(in_window.x().end()); + // As it split over x-axis, need to set the correct spiltted window start and end. 
+ const auto window_start_x = static_cast<int>(0); + const auto window_end_x = static_cast<int>(in_window.shape().x()); + + Window in_win_no_pad = in_window; + in_win_no_pad.set(Window::DimX, Window::Dimension(window_start_x_tmp, window_end_x_tmp, in_window.shape().x())); + Window out_win_no_pad = out_window; + out_win_no_pad.set(Window::DimX, + Window::Dimension(window_start_x_tmp, window_end_x_tmp, out_window.shape().x())); + + Iterator input(in, in_win_no_pad); + Iterator output(out, out_win_no_pad); + + execute_window_loop( + in_win_no_pad, + [&](const Coordinates &) + { + // Compute window_step_x elements per iteration + int x = window_start_x; + for (; x <= (window_end_x - window_step_x); x += window_step_x) + { + neon_vector vec_res_value_0 = {0}; + neon_vector vec_res_value_1 = {0}; + + vec_res_value_0 = wrapper::vdup_n(static_cast<T>(0.f), ExactTagType{}); + vec_res_value_1 = wrapper::vdup_n(static_cast<T>(0.f), ExactTagType{}); + + T *out_ptr = reinterpret_cast<T *>(output.ptr() + 2 * x * sizeof(T)); + for (unsigned int dim = 0; dim < in_info.dimension(axis); ++dim) + { + T *in_ptr_0 = reinterpret_cast<T *>(input.ptr() + 2 * x * sizeof(T) + stride_z * dim); + T *in_ptr_1 = reinterpret_cast<T *>(input.ptr() + 2 * x * sizeof(T) + 16 + stride_z * dim); + + const auto vec_elements_0 = wrapper::vloadq(in_ptr_0); + const auto vec_elements_1 = wrapper::vloadq(in_ptr_1); + + vec_res_value_0 = wrapper::vadd(vec_elements_0, vec_res_value_0); + vec_res_value_1 = wrapper::vadd(vec_elements_1, vec_res_value_1); + } + + wrapper::vstore(out_ptr, vec_res_value_0); + wrapper::vstore(out_ptr + 4, vec_res_value_1); + } + + // Compute left-over elements + for (; x < window_end_x; ++x) + { + auto res_value_0 = 0.f; + auto res_value_1 = 0.f; + + T *out_ptr = reinterpret_cast<T *>(output.ptr() + 2 * x * sizeof(T)); + for (unsigned int dim = 0; dim < in_info.dimension(axis); ++dim) + { + T *in_ptr = reinterpret_cast<T *>(input.ptr() + 2 * x * sizeof(T) + stride_z * dim); + 
res_value_0 += *in_ptr; + res_value_1 += *(in_ptr + 1); + } + *out_ptr = res_value_0; + *(out_ptr + 1) = res_value_1; + } + }, + input, output); + } +}; + +} // namespace arm_compute +#endif // ACL_SRC_CPU_KERNELS_REDUCTION_LAYER_GENERIC_NEON_IMPL_FP16_H diff --git a/src/cpu/kernels/reduction_layer/generic/neon/integer.cpp b/src/cpu/kernels/reduction_layer/generic/neon/integer.cpp new file mode 100644 index 0000000000..ad66b456ac --- /dev/null +++ b/src/cpu/kernels/reduction_layer/generic/neon/integer.cpp @@ -0,0 +1,62 @@ +/* + * Copyright (c) 2024 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ */ + +#include "src/cpu/kernels/reduction_layer/generic/neon/impl.h" + +namespace arm_compute +{ +namespace cpu +{ +void reduce_RedOpX_reduceX_S32_4(const Window &window, + const ITensor *input, + ITensor *output, + const ReductionOperation op) +{ + return Reducer<RedOpX<int32_t, 4>>::reduceX(window, input, output, RedOpX<int32_t, 4>(), op); +} + +void reduce_RedOpYZW_reduceY_S32_4(const Window &window, + const ITensor *input, + ITensor *output, + const ReductionOperation op) +{ + return Reducer<RedOpYZW<int32_t, 4>>::reduceY(window, input, output, RedOpYZW<int32_t, 4>(), op); +} +void reduce_RedOpYZW_reduceZ_S32_4(const Window &window, + const ITensor *input, + ITensor *output, + const ReductionOperation op) +{ + return Reducer<RedOpYZW<int32_t, 4>>::reduceZ(window, input, output, RedOpYZW<int32_t, 4>(), op); +} + +void reduce_RedOpYZW_reduceW_S32_4(const Window &window, + const ITensor *input, + ITensor *output, + const ReductionOperation op) +{ + return Reducer<RedOpYZW<int32_t, 4>>::reduceW(window, input, output, RedOpYZW<int32_t, 4>(), op); +} +} // namespace cpu +} // namespace arm_compute diff --git a/src/cpu/kernels/reduction_layer/generic/neon/list.h b/src/cpu/kernels/reduction_layer/generic/neon/list.h new file mode 100644 index 0000000000..947c28a130 --- /dev/null +++ b/src/cpu/kernels/reduction_layer/generic/neon/list.h @@ -0,0 +1,66 @@ +/* + * Copyright (c) 2024 Arm Limited. 
+ * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ */ +#ifndef ACL_SRC_CPU_KERNELS_REDUCTION_LAYER_GENERIC_NEON_LIST_H +#define ACL_SRC_CPU_KERNELS_REDUCTION_LAYER_GENERIC_NEON_LIST_H + +#include "arm_compute/core/Helpers.h" + +namespace arm_compute +{ +namespace cpu +{ + +#define DECLARE_REDUCTION_KERNEL(func_name) \ + void func_name(const Window &window, const ITensor *in, ITensor *out, const ReductionOperation op) + +DECLARE_REDUCTION_KERNEL(reduce_RedOpYZW_complex_reduceZ_float32_4_2_SUM); +DECLARE_REDUCTION_KERNEL(reduce_RedOpX_reduceX_float32_4); +DECLARE_REDUCTION_KERNEL(reduce_RedOpYZW_reduceY_float32_4); +DECLARE_REDUCTION_KERNEL(reduce_RedOpYZW_reduceZ_float32_4); +DECLARE_REDUCTION_KERNEL(reduce_RedOpYZW_reduceW_float32_4); + +DECLARE_REDUCTION_KERNEL(reduce_RedOpX_reduceX_float16_8); +DECLARE_REDUCTION_KERNEL(reduce_RedOpYZW_reduceY_float16_8); +DECLARE_REDUCTION_KERNEL(reduce_RedOpYZW_reduceZ_float16_8); +DECLARE_REDUCTION_KERNEL(reduce_RedOpYZW_reduceW_float16_8); + +DECLARE_REDUCTION_KERNEL(reduce_RedOpX_reduceX_S32_4); +DECLARE_REDUCTION_KERNEL(reduce_RedOpYZW_reduceY_S32_4); +DECLARE_REDUCTION_KERNEL(reduce_RedOpYZW_reduceZ_S32_4); +DECLARE_REDUCTION_KERNEL(reduce_RedOpYZW_reduceW_S32_4); + +DECLARE_REDUCTION_KERNEL(reduce_RedOpX_reduceX_qasymm8); +DECLARE_REDUCTION_KERNEL(reduce_RedOpYZW_reduceY_qasymm8); +DECLARE_REDUCTION_KERNEL(reduce_RedOpYZW_reduceZ_qasymm8); +DECLARE_REDUCTION_KERNEL(reduce_RedOpYZW_reduceW_qasymm8); + +DECLARE_REDUCTION_KERNEL(reduce_RedOpX_reduceX_qasymm8_signed); +DECLARE_REDUCTION_KERNEL(reduce_RedOpYZW_reduceY_qasymm8_signed); +DECLARE_REDUCTION_KERNEL(reduce_RedOpYZW_reduceZ_qasymm8_signed); +DECLARE_REDUCTION_KERNEL(reduce_RedOpYZW_reduceW_qasymm8_signed); + +#undef DECLARE_REDUCTION_KERNEL +} // namespace cpu +} // namespace arm_compute +#endif // ACL_SRC_CPU_KERNELS_REDUCTION_LAYER_GENERIC_NEON_LIST_H diff --git a/src/cpu/kernels/reduction_layer/generic/neon/qasymm8.cpp b/src/cpu/kernels/reduction_layer/generic/neon/qasymm8.cpp new file mode 100644 index 
0000000000..bc711c6855 --- /dev/null +++ b/src/cpu/kernels/reduction_layer/generic/neon/qasymm8.cpp @@ -0,0 +1,63 @@ +/* + * Copyright (c) 2024 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ */ + +#include "src/cpu/kernels/reduction_layer/generic/neon/impl.h" + +namespace arm_compute +{ +namespace cpu +{ +void reduce_RedOpX_reduceX_qasymm8(const Window &window, + const ITensor *input, + ITensor *output, + const ReductionOperation op) +{ + return Reducer<RedOpX_quantized<uint8_t>>::reduceX(window, input, output, RedOpX_quantized<uint8_t>(), op); +} + +void reduce_RedOpYZW_reduceY_qasymm8(const Window &window, + const ITensor *input, + ITensor *output, + const ReductionOperation op) +{ + return Reducer<RedOpYZW_quantized<uint8_t>>::reduceY(window, input, output, RedOpYZW_quantized<uint8_t>(), op); +} + +void reduce_RedOpYZW_reduceZ_qasymm8(const Window &window, + const ITensor *input, + ITensor *output, + const ReductionOperation op) +{ + return Reducer<RedOpYZW_quantized<uint8_t>>::reduceZ(window, input, output, RedOpYZW_quantized<uint8_t>(), op); +} + +void reduce_RedOpYZW_reduceW_qasymm8(const Window &window, + const ITensor *input, + ITensor *output, + const ReductionOperation op) +{ + return Reducer<RedOpYZW_quantized<uint8_t>>::reduceW(window, input, output, RedOpYZW_quantized<uint8_t>(), op); +} +} // namespace cpu +} // namespace arm_compute diff --git a/src/cpu/kernels/reduction_layer/generic/neon/qasymm8_signed.cpp b/src/cpu/kernels/reduction_layer/generic/neon/qasymm8_signed.cpp new file mode 100644 index 0000000000..10ac3d6715 --- /dev/null +++ b/src/cpu/kernels/reduction_layer/generic/neon/qasymm8_signed.cpp @@ -0,0 +1,63 @@ +/* + * Copyright (c) 2024 Arm Limited. 
+ * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ */ + +#include "src/cpu/kernels/reduction_layer/generic/neon/impl.h" + +namespace arm_compute +{ +namespace cpu +{ +void reduce_RedOpX_reduceX_qasymm8_signed(const Window &window, + const ITensor *input, + ITensor *output, + const ReductionOperation op) +{ + return Reducer<RedOpX_quantized<int8_t>>::reduceX(window, input, output, RedOpX_quantized<int8_t>(), op); +} + +void reduce_RedOpYZW_reduceY_qasymm8_signed(const Window &window, + const ITensor *input, + ITensor *output, + const ReductionOperation op) +{ + return Reducer<RedOpYZW_quantized<int8_t>>::reduceY(window, input, output, RedOpYZW_quantized<int8_t>(), op); +} + +void reduce_RedOpYZW_reduceZ_qasymm8_signed(const Window &window, + const ITensor *input, + ITensor *output, + const ReductionOperation op) +{ + return Reducer<RedOpYZW_quantized<int8_t>>::reduceZ(window, input, output, RedOpYZW_quantized<int8_t>(), op); +} + +void reduce_RedOpYZW_reduceW_qasymm8_signed(const Window &window, + const ITensor *input, + ITensor *output, + const ReductionOperation op) +{ + return Reducer<RedOpYZW_quantized<int8_t>>::reduceW(window, input, output, RedOpYZW_quantized<int8_t>(), op); +} +} // namespace cpu +} // namespace arm_compute diff --git a/src/cpu/kernels/softmax/generic/neon/fp16.cpp b/src/cpu/kernels/softmax/generic/neon/fp16.cpp index da62d2d614..425fcf7ac6 100644 --- a/src/cpu/kernels/softmax/generic/neon/fp16.cpp +++ b/src/cpu/kernels/softmax/generic/neon/fp16.cpp @@ -33,9 +33,15 @@ namespace cpu { template <bool IS_LOG> -void neon_fp16_softmax( - const ITensor *in, void *const tmp, ITensor *out, const float beta, int axis, const Window &window) +void neon_fp16_softmax(const ITensor *in, + void *const tmp, + ITensor *out, + const float beta, + int axis, + const Window &window, + const float *lut_ptr) { + ARM_COMPUTE_UNUSED(lut_ptr); if (axis == 0) { return neon_softmax_x_float<float16_t, IS_LOG>(in, tmp, out, beta, axis, window); @@ -46,10 +52,20 @@ void neon_fp16_softmax( } } -template void 
neon_fp16_softmax<true>( - const ITensor *in, void *const tmp, ITensor *out, const float beta, int axis, const Window &window); -template void neon_fp16_softmax<false>( - const ITensor *in, void *const tmp, ITensor *out, const float beta, int axis, const Window &window); +template void neon_fp16_softmax<true>(const ITensor *in, + void *const tmp, + ITensor *out, + const float beta, + int axis, + const Window &window, + const float *lut_ptr); +template void neon_fp16_softmax<false>(const ITensor *in, + void *const tmp, + ITensor *out, + const float beta, + int axis, + const Window &window, + const float *lut_ptr); } // namespace cpu } // namespace arm_compute diff --git a/src/cpu/kernels/softmax/generic/neon/fp32.cpp b/src/cpu/kernels/softmax/generic/neon/fp32.cpp index 0701620636..a64946eb74 100644 --- a/src/cpu/kernels/softmax/generic/neon/fp32.cpp +++ b/src/cpu/kernels/softmax/generic/neon/fp32.cpp @@ -31,9 +31,15 @@ namespace cpu { template <bool IS_LOG> -void neon_fp32_softmax( - const ITensor *in, void *const tmp, ITensor *out, const float beta, int axis, const Window &window) +void neon_fp32_softmax(const ITensor *in, + void *const tmp, + ITensor *out, + const float beta, + int axis, + const Window &window, + const float *lut_ptr) { + ARM_COMPUTE_UNUSED(lut_ptr); if (axis == 0) { return neon_softmax_x_float<float, IS_LOG>(in, tmp, out, beta, axis, window); @@ -44,10 +50,20 @@ void neon_fp32_softmax( } } -template void neon_fp32_softmax<true>( - const ITensor *in, void *const tmp, ITensor *out, const float beta, int axis, const Window &window); -template void neon_fp32_softmax<false>( - const ITensor *in, void *const tmp, ITensor *out, const float beta, int axis, const Window &window); +template void neon_fp32_softmax<true>(const ITensor *in, + void *const tmp, + ITensor *out, + const float beta, + int axis, + const Window &window, + const float *lut_ptr); +template void neon_fp32_softmax<false>(const ITensor *in, + void *const tmp, + ITensor *out, + const 
float beta, + int axis, + const Window &window, + const float *lut_ptr); } // namespace cpu } // namespace arm_compute diff --git a/src/cpu/kernels/softmax/generic/neon/qasymm8.cpp b/src/cpu/kernels/softmax/generic/neon/qasymm8.cpp index d39240bb38..369f9bb005 100644 --- a/src/cpu/kernels/softmax/generic/neon/qasymm8.cpp +++ b/src/cpu/kernels/softmax/generic/neon/qasymm8.cpp @@ -30,9 +30,15 @@ namespace arm_compute namespace cpu { template <bool IS_LOG> -void neon_qasymm8_softmax( - const ITensor *in, void *const tmp, ITensor *out, const float beta, int axis, const Window &window) +void neon_qasymm8_softmax(const ITensor *in, + void *const tmp, + ITensor *out, + const float beta, + int axis, + const Window &window, + const float *lut_ptr) { + ARM_COMPUTE_UNUSED(lut_ptr); if (axis == 0) { return neon_softmax_x_quantized<qasymm8_t, IS_LOG>(in, tmp, out, beta, axis, window); @@ -43,10 +49,20 @@ void neon_qasymm8_softmax( } } -template void neon_qasymm8_softmax<true>( - const ITensor *in, void *const tmp, ITensor *out, const float beta, int axis, const Window &window); -template void neon_qasymm8_softmax<false>( - const ITensor *in, void *const tmp, ITensor *out, const float beta, int axis, const Window &window); +template void neon_qasymm8_softmax<true>(const ITensor *in, + void *const tmp, + ITensor *out, + const float beta, + int axis, + const Window &window, + const float *lut_ptr); +template void neon_qasymm8_softmax<false>(const ITensor *in, + void *const tmp, + ITensor *out, + const float beta, + int axis, + const Window &window, + const float *lut_ptr); } // namespace cpu } // namespace arm_compute diff --git a/src/cpu/kernels/softmax/generic/neon/qasymm8_signed.cpp b/src/cpu/kernels/softmax/generic/neon/qasymm8_signed.cpp index 26fd5dbfa0..594ceb7654 100644 --- a/src/cpu/kernels/softmax/generic/neon/qasymm8_signed.cpp +++ b/src/cpu/kernels/softmax/generic/neon/qasymm8_signed.cpp @@ -30,9 +30,15 @@ namespace arm_compute namespace cpu { template <bool IS_LOG> 
-void neon_qasymm8_signed_softmax( - const ITensor *in, void *const tmp, ITensor *out, const float beta, int axis, const Window &window) +void neon_qasymm8_signed_softmax(const ITensor *in, + void *const tmp, + ITensor *out, + const float beta, + int axis, + const Window &window, + const float *lut_ptr) { + ARM_COMPUTE_UNUSED(lut_ptr); if (axis == 0) { return neon_softmax_x_quantized<qasymm8_signed_t, IS_LOG>(in, tmp, out, beta, axis, window); @@ -43,10 +49,20 @@ void neon_qasymm8_signed_softmax( } } -template void neon_qasymm8_signed_softmax<true>( - const ITensor *in, void *const tmp, ITensor *out, const float beta, int axis, const Window &window); -template void neon_qasymm8_signed_softmax<false>( - const ITensor *in, void *const tmp, ITensor *out, const float beta, int axis, const Window &window); +template void neon_qasymm8_signed_softmax<true>(const ITensor *in, + void *const tmp, + ITensor *out, + const float beta, + int axis, + const Window &window, + const float *lut_ptr); +template void neon_qasymm8_signed_softmax<false>(const ITensor *in, + void *const tmp, + ITensor *out, + const float beta, + int axis, + const Window &window, + const float *lut_ptr); } // namespace cpu } // namespace arm_compute diff --git a/src/cpu/kernels/softmax/generic/sme2/fp16.cpp b/src/cpu/kernels/softmax/generic/sme2/fp16.cpp index bcd34d1ca2..e70c9f4793 100644 --- a/src/cpu/kernels/softmax/generic/sme2/fp16.cpp +++ b/src/cpu/kernels/softmax/generic/sme2/fp16.cpp @@ -720,8 +720,15 @@ loop_3_end%=: ); } -void sme2_fp16_softmax(const ITensor *in, void *const, ITensor *out, const float beta, int axis, const Window &window) +void sme2_fp16_softmax(const ITensor *in, + void *const, + ITensor *out, + const float beta, + int axis, + const Window &window, + const float *lut_ptr) { + ARM_COMPUTE_UNUSED(lut_ptr); ARM_COMPUTE_UNUSED(axis); const auto *src_info = in->info(); diff --git a/src/cpu/kernels/softmax/generic/sme2/fp32.cpp b/src/cpu/kernels/softmax/generic/sme2/fp32.cpp index 
159039a320..5e29d51746 100644 --- a/src/cpu/kernels/softmax/generic/sme2/fp32.cpp +++ b/src/cpu/kernels/softmax/generic/sme2/fp32.cpp @@ -524,8 +524,15 @@ loop_3_end%=: ); } -void sme2_fp32_softmax(const ITensor *in, void *const, ITensor *out, const float beta, int axis, const Window &window) +void sme2_fp32_softmax(const ITensor *in, + void *const, + ITensor *out, + const float beta, + int axis, + const Window &window, + const float *lut_ptr) { + ARM_COMPUTE_UNUSED(lut_ptr); ARM_COMPUTE_UNUSED(axis); const auto *src_info = in->info(); diff --git a/src/cpu/kernels/softmax/generic/sme2/qasymm8.cpp b/src/cpu/kernels/softmax/generic/sme2/qasymm8.cpp new file mode 100644 index 0000000000..9feb669f7c --- /dev/null +++ b/src/cpu/kernels/softmax/generic/sme2/qasymm8.cpp @@ -0,0 +1,634 @@ +/* + * Copyright (c) 2023-2024 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ */ + +#ifdef ARM_COMPUTE_ENABLE_SME2 + +#include "arm_compute/core/ITensor.h" +#include "arm_compute/core/Window.h" + +namespace arm_compute +{ +namespace cpu +{ + +// SoftMax +// +// Steps: +// * Find max: max_value = max(src) +// * Regularize: dst[i] = exp(src[i] - max_value) +// sum_value = sum(dst) +// * Normalize: dst[i] = dst[i] / sum_value +void sme2_qasymm8_softmax_kernel_512VL( // + const uint8_t *src, + uint8_t *dst, + float beta, + const uintptr_t shape[4], + const uintptr_t src_strides[4], + const uintptr_t dst_strides[4], + const float *lut, + float *tmp) +{ + // Precondition: + // * src_strides[0] == sizeof(uint8_t) + // * dst_strides[0] == sizeof(uint8_t) + // * tmp_strides[0] == sizeof(float) + + __asm__ volatile( + R"( + .inst 0xd503477f // smstart + + // Registers + // + // * x1: Loop index + // * x2: LUT index + // * x13: temporary, body_length + // + // * x20: index_3 + // * x21: src_3 + // * x22: dst_3 + // * x23: index_2 + // * x24: src_2 + // * x25: dst_2 + // * x26: index_1 + // * x27: src_1 + // * x28: dst_1 + // * x29 tmp + // + // + // * p0: all-true + // * p1: predicate for QASYMM8 values + // * p2: predicate 0 for FP32 values (first quarter of expanded/unpacked p1) + // * p3: predicate 1 for FP32 values (second quarter of expanded/unpacked p1) + // * p4: predicate 2 for FP32 values (third quarter of expanded/unpacked p1) + // * p5: predicate 3 for FP32 values (fourth quarter of expanded/unpacked p1) + // * pn9: all-true for 32 bit values + // * pn8: all-true for 8-bit values + // + // * z0-z15 the 256 LUT values of exp(-scale*beta*x) for x in QASYMM8, stored as FP32 values + + // Prepares all constant values + + ptrue p0.b + .inst 0x25a07811 // ptrue pn9.s + .inst 0x25207810 // ptrue pn8.b + + // ---------------------------------------------------------------- x13: body_length = (length / vl) * vl + cntb x13, ALL, MUL #4 + udiv x9, %x[length], x13 + mul x13, x13, x9 + + // ================================================== + // 3D 
loop opening + // ================================================== + + mov x20, %x[shape_3] + mov x21, %x[src] + mov x22, %x[dst] + mov x19, %x[lut] + mov x29, %x[tmp] + + // Load the LUT to the register file. + mov x2, %x[lut] + .inst 0xa040c440 //ld1w { z0.s - z3.s }, pn9/z, [x2] + add x2, x2, #256 + .inst 0xa040c444 //ld1w { z4.s - z7.s }, pn9/z, [x2] + add x2, x2, #256 + .inst 0xa040c448 //ld1w { z8.s - z11.s }, pn9/z, [x2] + add x2, x2, #256 + .inst 0xa040c44c //ld1w { z12.s - z15.s }, pn9/z, [x2] + + +loop_3_start%=: + // for index_3 in shape_3 downto 1 + cmp x20, #0 + b.eq loop_3_end%= + sub x20, x20, #1 + + mov x23, %x[shape_2] + mov x24, x21 + mov x25, x22 + +loop_2_start%=: + // for index_2 in shape_2 downto 1 + cmp x23, #0 + b.eq loop_2_end%= + sub x23, x23, #1 + + mov x26, %x[shape_1] + mov x27, x24 + mov x28, x25 + +loop_1_start%=: + // for index_1 in shape_2 downto 1 + cmp x26, #0 + b.eq loop_1_end%= + sub x26, x26, #1 + + // ================================================== + // Step 1: Find max + // ================================================== + // z16-z19 = minimum QASYMM8 value (0) to allow for it to be used for comparison to find the max. + dup z16.b, #0 + dup z17.b, #0 + dup z18.b, #0 + dup z19.b, #0 + mov x1, #0 // x1: index +find_max_body_start%=: + cmp x1, x13 + b.eq find_max_body_end%= + .inst 0xa0018374 // ld1b { z20.b - z23.b }, pn8/z, [x27, x1] z20-z23: x + .inst 0xc134b811 // umax { z16.b - z19.b }, { z16.b - z19.b }, { z20.b - z23.b } z16-z19: max_value = max(max_value, x) + add x1, x1, #256 // Advance index by 256 bytes/integers: Z registers = 2048-bit data = 256 8-bit integers. + b find_max_body_start%= +find_max_body_end%=: + + // Loop for processing the leftover part. 
+find_max_leftover_start%=: + whilelo p1.b, x1, %x[length] + b.none find_max_leftover_end%= + + ld1b z30.b, p1/z, [x27, x1] // z30: x + umax z16.b, p1/m, z16.b, z30.b // z16: max_value = max(max_value, x) + + add x1, x1, #64 + + b find_max_leftover_start%= +find_max_leftover_end%=: + + .inst 0xc132b011 // umax { z16.b, z17.b }, { z16.b, z17.b }, { z18.b, z19.b } + umax z16.b, p0/m, z16.b, z17.b + umaxv b16, p0, z16.b // Reduction unsigned max operation to get maximum_value + dup z16.b, z16.b[0] + uunpklo z16.h, z16.b // Using unpack instructions to align the max value with the FP32 entries in the LUT for use in the TBX instruction + uunpklo z16.s, z16.h + + mov x1, #0 // reset index + dup z25.s, #0 + + mov x1, #0 + +regularize_start%=: + whilelo p1.b, x1, %x[length] + b.none regularize_end%= + + // p2-p5 are - together - the 32-bit version of p1, the instructions below unpack p1 into those four predicate registers to allow for the 32-bit loads below to be correctly predicated + punpklo p2.h, p1.b + punpkhi p4.h, p1.b + + punpkhi p3.h, p2.b + punpklo p2.h, p2.b + + punpkhi p5.h, p4.b + punpklo p4.h, p4.b + + ld1b z17.b, p1/z, [x27, x1] //z17: input data + + uunpklo z18.h, z17.b //Using unpack instructions to align the input QASYMM8 values with the FP32 entries in the LUT for use in the TBX instruction + uunpkhi z19.h, z17.b + + uunpklo z17.s, z18.h // z17 = low low input QASYMM8 values + uunpkhi z18.s, z18.h // z18 = low high input QASYMM8 values + + uunpkhi z20.s, z19.h // z20 = high high input QASYMM8 values + uunpklo z19.s, z19.h // z19 = high low input QASYMM8 values + + sub z17.s, z16.s, z17.s // z12: x = max_value - input_data + sub z18.s, z16.s, z18.s // z13: x = max_value - input_data + sub z19.s, z16.s, z19.s // z14: x = max_value - input_data + sub z20.s, z16.s, z20.s // z15: x = max_value - input_data + + tbx z21.s, z0.s, z17.s // Look-up entries 0-15 in the LUT. 
+ tbx z22.s, z0.s, z18.s + tbx z23.s, z0.s, z19.s + tbx z24.s, z0.s, z20.s + + sub z17.s, z17.s, #16 + sub z18.s, z18.s, #16 + sub z19.s, z19.s, #16 + sub z20.s, z20.s, #16 + + tbx z21.s, z1.s, z17.s // Look-up entries 16-31 in the LUT. + tbx z22.s, z1.s, z18.s + tbx z23.s, z1.s, z19.s + tbx z24.s, z1.s, z20.s + + sub z17.s, z17.s, #16 + sub z18.s, z18.s, #16 + sub z19.s, z19.s, #16 + sub z20.s, z20.s, #16 + + tbx z21.s, z2.s, z17.s // Look-up entries 32-47 in the LUT. + tbx z22.s, z2.s, z18.s + tbx z23.s, z2.s, z19.s + tbx z24.s, z2.s, z20.s + + sub z17.s, z17.s, #16 + sub z18.s, z18.s, #16 + sub z19.s, z19.s, #16 + sub z20.s, z20.s, #16 + + tbx z21.s, z3.s, z17.s // Look-up entries 48-63 in the LUT. + tbx z22.s, z3.s, z18.s + tbx z23.s, z3.s, z19.s + tbx z24.s, z3.s, z20.s + + sub z17.s, z17.s, #16 + sub z18.s, z18.s, #16 + sub z19.s, z19.s, #16 + sub z20.s, z20.s, #16 + + tbx z21.s, z4.s, z17.s // Look-up entries 64-79 in the LUT. + tbx z22.s, z4.s, z18.s + tbx z23.s, z4.s, z19.s + tbx z24.s, z4.s, z20.s + + sub z17.s, z17.s, #16 + sub z18.s, z18.s, #16 + sub z19.s, z19.s, #16 + sub z20.s, z20.s, #16 + + tbx z21.s, z5.s, z17.s // Look-up entries 80-95 in the LUT. + tbx z22.s, z5.s, z18.s + tbx z23.s, z5.s, z19.s + tbx z24.s, z5.s, z20.s + + sub z17.s, z17.s, #16 + sub z18.s, z18.s, #16 + sub z19.s, z19.s, #16 + sub z20.s, z20.s, #16 + + tbx z21.s, z6.s, z17.s // Look-up entries 96-111 in the LUT. + tbx z22.s, z6.s, z18.s + tbx z23.s, z6.s, z19.s + tbx z24.s, z6.s, z20.s + + sub z17.s, z17.s, #16 + sub z18.s, z18.s, #16 + sub z19.s, z19.s, #16 + sub z20.s, z20.s, #16 + + tbx z21.s, z7.s, z17.s // Look-up entries 112-127 in the LUT. + tbx z22.s, z7.s, z18.s + tbx z23.s, z7.s, z19.s + tbx z24.s, z7.s, z20.s + + sub z17.s, z17.s, #16 + sub z18.s, z18.s, #16 + sub z19.s, z19.s, #16 + sub z20.s, z20.s, #16 + + tbx z21.s, z8.s, z17.s // Look-up entries 128-143 in the LUT. 
+ tbx z22.s, z8.s, z18.s + tbx z23.s, z8.s, z19.s + tbx z24.s, z8.s, z20.s + + sub z17.s, z17.s, #16 + sub z18.s, z18.s, #16 + sub z19.s, z19.s, #16 + sub z20.s, z20.s, #16 + + tbx z21.s, z9.s, z17.s // Look-up entries 144-159 in the LUT. + tbx z22.s, z9.s, z18.s + tbx z23.s, z9.s, z19.s + tbx z24.s, z9.s, z20.s + + sub z17.s, z17.s, #16 + sub z18.s, z18.s, #16 + sub z19.s, z19.s, #16 + sub z20.s, z20.s, #16 + + tbx z21.s, z10.s, z17.s // Look-up entries 160-175 in the LUT. + tbx z22.s, z10.s, z18.s + tbx z23.s, z10.s, z19.s + tbx z24.s, z10.s, z20.s + + sub z17.s, z17.s, #16 + sub z18.s, z18.s, #16 + sub z19.s, z19.s, #16 + sub z20.s, z20.s, #16 + + tbx z21.s, z11.s, z17.s // Look-up entries 176-191 in the LUT. + tbx z22.s, z11.s, z18.s + tbx z23.s, z11.s, z19.s + tbx z24.s, z11.s, z20.s + + sub z17.s, z17.s, #16 + sub z18.s, z18.s, #16 + sub z19.s, z19.s, #16 + sub z20.s, z20.s, #16 + + tbx z21.s, z12.s, z17.s // Look-up entries 192-207 in the LUT. + tbx z22.s, z12.s, z18.s + tbx z23.s, z12.s, z19.s + tbx z24.s, z12.s, z20.s + + sub z17.s, z17.s, #16 + sub z18.s, z18.s, #16 + sub z19.s, z19.s, #16 + sub z20.s, z20.s, #16 + + tbx z21.s, z13.s, z17.s // Look-up entries 208-223 in the LUT. + tbx z22.s, z13.s, z18.s + tbx z23.s, z13.s, z19.s + tbx z24.s, z13.s, z20.s + + sub z17.s, z17.s, #16 + sub z18.s, z18.s, #16 + sub z19.s, z19.s, #16 + sub z20.s, z20.s, #16 + + tbx z21.s, z14.s, z17.s // Look-up entries 224-239 in the LUT. + tbx z22.s, z14.s, z18.s + tbx z23.s, z14.s, z19.s + tbx z24.s, z14.s, z20.s + + sub z17.s, z17.s, #16 + sub z18.s, z18.s, #16 + sub z19.s, z19.s, #16 + sub z20.s, z20.s, #16 + + tbx z21.s, z15.s, z17.s // Look-up entries 240-255 in the LUT. 
+ tbx z22.s, z15.s, z18.s + tbx z23.s, z15.s, z19.s + tbx z24.s, z15.s, z20.s + + + st1w z21.s, p2, [x29, x1, LSL #2]// z21 store exp(-scale*beta*x) into the tmp tensor + fadd z25.s, p2/m, z25.s, z21.s + add x1, x1, #16 + + st1w z22.s, p3, [x29, x1, LSL #2]// z22 store exp(-scale*beta*x) into the tmp tensor + fadd z25.s, p3/m, z25.s, z22.s + add x1, x1, #16 + + st1w z23.s, p4, [x29, x1, LSL #2]// z23 store exp(-scale*beta*x) into the tmp tensor + fadd z25.s, p4/m, z25.s, z23.s + add x1, x1, #16 + + st1w z24.s, p5, [x29, x1, LSL #2]// z24 store exp(-scale*beta*x) into the tmp tensor + fadd z25.s, p5/m, z25.s, z24.s + add x1, x1, #16 + + b regularize_start%= +regularize_end%=: + + mov w9, 0x0000 + movk w9, 0x4380, LSL #16 // Moving 256.f into w9 to scale - via multiplication (division by reciprocal) - the floating point [0,1] range of the results to the [0,255] integer range of QASYMM8 + dup z29.s, w9 + faddv s25, p0, z25.s + fdiv s25, s29, s25 + dup z25.s, z25.s[0] // z25: 256.f/sum. 256 is needed to get the full range and 1/sum is part of softmax. + + // ================================================== + // Step 3: Normalize + // ================================================== + mov x1, #0 +normalize_body_start%=: + cmp x1, x13 + b.eq normalize_body_end%= + + mov x2, x1 // Preserve the index into x2 for the final store to dst. + .inst 0xa001c7b0 // ld1w { z16.s - z19.s }, pn9/z, [x29, x1, lsl #2] + add x1, x1, #64 + .inst 0xa001c7b4 // ld1w { z20.s - z23.s }, pn9/z, [x29, x1, lsl #2] + add x1, x1, #64 + + // z16-z23: effectively divides exp(-scale*beta*x) by the sum of the exponentials for the current row and multiplies by 256. + fmul z16.s, z25.s, z16.s + fmul z17.s, z25.s, z17.s + fmul z18.s, z25.s, z18.s + fmul z19.s, z25.s, z19.s + fmul z20.s, z25.s, z20.s + fmul z21.s, z25.s, z21.s + fmul z22.s, z25.s, z22.s + fmul z23.s, z25.s, z23.s + + // z16-z23: convert the FP32 values from the tmp tensor to uint32. 
+ fcvtzu z16.s, p0/m, z16.s + fcvtzu z17.s, p0/m, z17.s + fcvtzu z18.s, p0/m, z18.s + fcvtzu z19.s, p0/m, z19.s + fcvtzu z20.s, p0/m, z20.s + fcvtzu z21.s, p0/m, z21.s + fcvtzu z22.s, p0/m, z22.s + fcvtzu z23.s, p0/m, z23.s + + // z16-z17: narrow the uint32 values into uint8 and saturate them. + .inst 0xc133e230 // uqcvt z16.b, { z16.s - z19.s } + .inst 0xc133e2b1 // uqcvt z17.b, { z20.s - z23.s } + + dup z20.s, z25.s[0] // Juggling the value to z20 as z25 will be overwritten by the load below + + .inst 0xa001c7b8 // ld1w { z24.s - z27.s }, pn9/z, [x29, x1, lsl #2] + add x1, x1, #64 + .inst 0xa001c7bc // ld1w { z28.s - z31.s }, pn9/z, [x29, x1, lsl #2] + add x1, x1, #64 + + // z24-z31: effectively divides exp(-scale*beta*x) by the sum of the exponentials for the current row and multiplies by 256. + fmul z24.s, z20.s, z24.s + fmul z25.s, z20.s, z25.s + fmul z26.s, z20.s, z26.s + fmul z27.s, z20.s, z27.s + fmul z28.s, z20.s, z28.s + fmul z29.s, z20.s, z29.s + fmul z30.s, z20.s, z30.s + fmul z31.s, z20.s, z31.s + + // z24-z31: convert the FP32 values from the tmp tensor to uint32. + fcvtzu z24.s, p0/m, z24.s + fcvtzu z25.s, p0/m, z25.s + fcvtzu z26.s, p0/m, z26.s + fcvtzu z27.s, p0/m, z27.s + fcvtzu z28.s, p0/m, z28.s + fcvtzu z29.s, p0/m, z29.s + fcvtzu z30.s, p0/m, z30.s + fcvtzu z31.s, p0/m, z31.s + + // z18-z19: narrow the uint32 values into uint8 and saturate them. + .inst 0xc133e332 // uqcvt z18.b, { z24.s - z27.s } + .inst 0xc133e3b3 // uqcvt z19.b, { z28.s - z31.s } + + .inst 0xa0228390 // st1b { z16.b - z19.b }, pn8, [x28, x2] + + dup z25.s, z20.s[0] // Juggling the value back to z25 as z20 will be overwritten by the next iteration or z25 will be used below. 
+ +b normalize_body_start%= +normalize_body_end%=: + +normalize_leftover_start%=: + whilelo p1.b, x1, %x[length] + b.none normalize_leftover_end%= + + // p2-p5 are - together - the 32-bit version of p1, the instructions below unpack p1 into those four predicate registers to allow for the 32-bit loads below to be correctly predicated + punpklo p2.h, p1.b + punpkhi p4.h, p1.b + + punpkhi p3.h, p2.b + punpklo p2.h, p2.b + + punpkhi p5.h, p4.b + punpklo p4.h, p4.b + + mov x2, x1 // Preserve the index into x2 for the final store to dst. + + // z20-z23: load exp(-scale*beta*x) from the tmp tensor + ld1w z20.s, p2/z, [x29, x1, LSL #2] + add x1, x1, #16 + + ld1w z21.s, p3/z, [x29, x1, LSL #2] + add x1, x1, #16 + + ld1w z22.s, p4/z, [x29, x1, LSL #2] + add x1, x1, #16 + + ld1w z23.s, p5/z, [x29, x1, LSL #2] + add x1, x1, #16 + + // z20-z23: effectively divides exp(-scale*beta*x) by the sum of the exponentials for the current row and multiplies by 256. + fmul z20.s, z25.s, z20.s + fmul z21.s, z25.s, z21.s + fmul z22.s, z25.s, z22.s + fmul z23.s, z25.s, z23.s + + // z20-23: convert the FP32 values from the tmp tensor to uint32. + fcvtzu z20.s, p0/m, z20.s + fcvtzu z21.s, p0/m, z21.s + fcvtzu z22.s, p0/m, z22.s + fcvtzu z23.s, p0/m, z23.s + + .inst 0xc133e2b3 // uqcvt z19.b, { z20.s - z23.s }, narrow the uint32 values into uint8 and saturate them into z19. 
+ + st1b z19.b, p1, [x28, x2] + + b normalize_leftover_start%= +normalize_leftover_end%=: + // ================================================== + // 3D loop closing + // ================================================== + add x27, x27, %x[src_stride_1] + add x28, x28, %x[dst_stride_1] + b loop_1_start%= +loop_1_end%=: + + add x24, x24, %x[src_stride_2] + add x25, x25, %x[dst_stride_2] + b loop_2_start%= +loop_2_end%=: + + add x21, x21, %x[src_stride_3] + add x22, x22, %x[dst_stride_3] + b loop_3_start%= +loop_3_end%=: + .inst 0xd503467f // smstop + )" + : + : [src] "r"(src), [tmp] "r"(tmp), [dst] "r"(dst), [beta] "r"(beta), [lut] "r"(lut), // + [shape_1] "r"(shape[1]), [shape_2] "r"(shape[2]), [shape_3] "r"(shape[3]), // + [src_stride_1] "r"(src_strides[1]), [src_stride_2] "r"(src_strides[2]), + [src_stride_3] "r"(src_strides[3]), // + [dst_stride_1] "r"(dst_strides[1]), [dst_stride_2] "r"(dst_strides[2]), + [dst_stride_3] "r"(dst_strides[3]), // + [length] "r"(shape[0]) // + : "cc", "memory", // + "p0", "p1", "p2", "p3", "p4", // + "x2", "x9", "x13", // + "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28", "x19", // + "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", // + "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", // + "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", // + "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31" // + ); +} + +void sme2_qasymm8_softmax_lut_512VL(const ITensor *in, + void *const tmp, + ITensor *out, + const float beta, + int axis, + const Window &window, + const float *lut_ptr) +{ + ARM_COMPUTE_UNUSED(axis); + + const auto *src_info = in->info(); + const auto *dst_info = out->info(); + + const auto &full_shape = dst_info->tensor_shape(); + const auto &src_strides = src_info->strides_in_bytes(); + const auto &dst_strides = dst_info->strides_in_bytes(); + Strides tmp_strides; + + tmp_strides[0] = src_strides[0] * 4; + tmp_strides[1] = src_strides[1] * 4; + tmp_strides[2] = src_strides[2] * 4; + 
tmp_strides[3] = src_strides[3] * 4; + + const uintptr_t k_shape[] = { + full_shape[0], + window.num_iterations(1), + window.num_iterations(2), + window.num_iterations(3), + }; + + const uintptr_t k_src_strides[] = { + src_strides[0], + src_strides[1], + src_strides[2], + src_strides[3], + }; + + const uintptr_t k_dst_strides[] = { + dst_strides[0], + dst_strides[1], + dst_strides[2], + dst_strides[3], + }; + + const uintptr_t k_src_offset = window[0].start() * src_strides[0] + // + window[1].start() * src_strides[1] + // + window[2].start() * src_strides[2] + // + window[3].start() * src_strides[3]; + + const uintptr_t k_dst_offset = window[0].start() * dst_strides[0] + // + window[1].start() * dst_strides[1] + // + window[2].start() * dst_strides[2] + // + window[3].start() * dst_strides[3]; + + const uintptr_t k_tmp_offset = window[0].start() * tmp_strides[0] + // + window[1].start() * tmp_strides[1] + // + window[2].start() * tmp_strides[2] + // + window[3].start() * tmp_strides[3]; + + const auto *k_src = reinterpret_cast<const uint8_t *>(in->buffer() + k_src_offset); + float *tmp_float_ptr = reinterpret_cast<float *>(tmp); + auto *k_tmp = reinterpret_cast<float *>(tmp_float_ptr + k_tmp_offset); + auto *k_dst = reinterpret_cast<uint8_t *>(out->buffer() + k_dst_offset); + + sme2_qasymm8_softmax_kernel_512VL(k_src, k_dst, beta, k_shape, k_src_strides, k_dst_strides, lut_ptr, k_tmp); +} + +} // namespace cpu +} // namespace arm_compute + +#endif // ARM_COMPUTE_ENABLE_SME2 diff --git a/src/cpu/kernels/softmax/generic/sme2/qasymm8_signed.cpp b/src/cpu/kernels/softmax/generic/sme2/qasymm8_signed.cpp new file mode 100644 index 0000000000..14c0f6c327 --- /dev/null +++ b/src/cpu/kernels/softmax/generic/sme2/qasymm8_signed.cpp @@ -0,0 +1,655 @@ +/* + * Copyright (c) 2023-2024 Arm Limited. 
+ * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#ifdef ARM_COMPUTE_ENABLE_SME2 + +#include "arm_compute/core/ITensor.h" +#include "arm_compute/core/Window.h" + +namespace arm_compute +{ +namespace cpu +{ + +// SoftMax +// +// Steps: +// * Find max: max_value = max(src) +// * Regularize: dst[i] = exp(src[i] - max_value) +// sum_value = sum(dst) +// * Normalize: dst[i] = dst[i] / sum_value +void sme2_qasymm8_signed_softmax_kernel_512VL( // + const int8_t *src, + int8_t *dst, + float beta, + const uintptr_t shape[4], + const uintptr_t src_strides[4], + const uintptr_t dst_strides[4], + const float *lut, + float *tmp) +{ + // Precondition: + // * src_strides[0] == sizeof(int8_t) + // * dst_strides[0] == sizeof(int8_t) + // * tmp_strides[0] == sizeof(float) + + __asm__ volatile( + R"( + .inst 0xd503477f // smstart + + // For register list explanation refer to qasymm8.cpp. 
+ + // Prepares all constant values + + ptrue p0.b + .inst 0x25a07811 // ptrue pn9.s + .inst 0x25207810 // ptrue pn8.b + + // ---------------------------------------------------------------- x13: body_length = (length / vl) * vl + cntb x13, ALL, MUL #4 + udiv x9, %x[length], x13 + mul x13, x13, x9 + + // ================================================== + // 3D loop opening + // ================================================== + + mov x20, %x[shape_3] + mov x21, %x[src] + mov x22, %x[dst] + mov x19, %x[lut] + mov x29, %x[tmp] + + // Load the LUT to the register file. + mov x2, %x[lut] + .inst 0xa040c440 //ld1w { z0.s - z3.s }, pn9/z, [x2] + add x2, x2, #256 + .inst 0xa040c444 //ld1w { z4.s - z7.s }, pn9/z, [x2] + add x2, x2, #256 + .inst 0xa040c448 //ld1w { z8.s - z11.s }, pn9/z, [x2] + add x2, x2, #256 + .inst 0xa040c44c //ld1w { z12.s - z15.s }, pn9/z, [x2] + + +loop_3_start%=: + // for index_3 in shape_3 downto 1 + cmp x20, #0 + b.eq loop_3_end%= + sub x20, x20, #1 + + mov x23, %x[shape_2] + mov x24, x21 + mov x25, x22 + +loop_2_start%=: + // for index_2 in shape_2 downto 1 + cmp x23, #0 + b.eq loop_2_end%= + sub x23, x23, #1 + + mov x26, %x[shape_1] + mov x27, x24 + mov x28, x25 + +loop_1_start%=: + // for index_1 in shape_2 downto 1 + cmp x26, #0 + b.eq loop_1_end%= + sub x26, x26, #1 + + // ================================================== + // Step 1: Find max + // ================================================== + // z16-z19 = minimum QASYMM8_SIGNED value (-128) to allow for it to be used for comparison to find the max. 
+ dup z16.b, #0x80 + dup z17.b, #0x80 + dup z18.b, #0x80 + dup z19.b, #0x80 + + mov x1, #0 // x1: index +find_max_body_start%=: + cmp x1, x13 + b.eq find_max_body_end%= + .inst 0xa0018374 // ld1b { z20.b - z23.b }, pn8/z, [x27, x1] z16-z19: x + .inst 0xc134b810 // smax { z16.b - z19.b }, { z16.b - z19.b }, { z20.b - z23.b } z16-z19: max_value = max(max_value, x) + add x1, x1, #256 // Advance index by 256 bytes/integers: Z registers = 2048-bit data = 256 8-bit integers. + b find_max_body_start%= +find_max_body_end%=: + + // Loop for processing the leftover part. +find_max_leftover_start%=: + whilelo p1.b, x1, %x[length] + b.none find_max_leftover_end%= + + ld1b z30.b, p1/z, [x27, x1] // z30: x + smax z16.b, p1/m, z16.b, z30.b // z16: max_value = max(max_value, x) + + add x1, x1, #64 + + b find_max_leftover_start%= +find_max_leftover_end%=: + .inst 0xc132b010 // smax { z16.b, z17.b }, { z16.b, z17.b }, { z18.b, z19.b } + smax z16.b, p0/m, z16.b, z17.b + smaxv b16, p0, z16.b // Reduction signed max operation to get maximum_value + mov z16.b, b16 // z16: duplicated max_value for current row + + sunpklo z16.h, z16.b // Using unpack instructions to align the max value with the FP32 entries in the LUT for use in the TBX instruction + sunpklo z16.s, z16.h + + mov x1, #0 // reset index + dup z25.s, #0 + + +regularize_start%=: + whilelo p1.b, x1, %x[length] + b.none regularize_end%= + + mov w9, 0xFF80 + movk w9, 0xFFFF, LSL #16 // Moving -127.f into w9 to set the registers below to the minimum QASYMM8_SIGNED value + dup z17.s, w9 + dup z18.s, w9 + dup z19.s, w9 + dup z20.s, w9 + + dup z21.s, #0x0 + dup z22.s, #0x0 + dup z23.s, #0x0 + dup z24.s, #0x0 + + // p2-p5 are - together - the 32-bit version of p1, the instructions below unpack p1 into those four predicate registers to allow for the 32-bit loads below to be correctly predicated + punpklo p2.h, p1.b + punpkhi p4.h, p1.b + + punpkhi p3.h, p2.b + punpklo p2.h, p2.b + + punpkhi p5.h, p4.b + punpklo p4.h, p4.b + + ld1b 
z17.b, p1/z, [x27, x1] //z17: input data + + sunpklo z18.h, z17.b // Using unpack instructions to align the input QASYMM8_SIGNED values with the FP32 entries in the LUT for use in the TBX instruction + sunpkhi z19.h, z17.b // + + sunpklo z17.s, z18.h // z17 = low low input QASYMM8_SIGNED values + sunpkhi z18.s, z18.h // z18 = low high input QASYMM8_SIGNED values + + sunpkhi z20.s, z19.h // z20 = high high input QASYMM8_SIGNED values + sunpklo z19.s, z19.h // z19 = high low input QASYMM8_SIGNED values + + sub z17.s, z16.s, z17.s // z12: x = max_value - input_data + sub z18.s, z16.s, z18.s // z13: x = max_value - input_data + sub z19.s, z16.s, z19.s // z14: x = max_value - input_data + sub z20.s, z16.s, z20.s // z15: x = max_value - input_data + + add z17.s, z17.s, #128 + add z18.s, z18.s, #128 + add z19.s, z19.s, #128 + add z20.s, z20.s, #128 + + tbx z21.s, z0.s, z17.s // Look-up entries 0-15 in the LUT. + tbx z22.s, z0.s, z18.s + tbx z23.s, z0.s, z19.s + tbx z24.s, z0.s, z20.s + + sub z17.s, z17.s, #16 + sub z18.s, z18.s, #16 + sub z19.s, z19.s, #16 + sub z20.s, z20.s, #16 + + tbx z21.s, z1.s, z17.s // Look-up entries 16-31 in the LUT. + tbx z22.s, z1.s, z18.s + tbx z23.s, z1.s, z19.s + tbx z24.s, z1.s, z20.s + + sub z17.s, z17.s, #16 + sub z18.s, z18.s, #16 + sub z19.s, z19.s, #16 + sub z20.s, z20.s, #16 + + tbx z21.s, z2.s, z17.s // Look-up entries 32-47 in the LUT. + tbx z22.s, z2.s, z18.s + tbx z23.s, z2.s, z19.s + tbx z24.s, z2.s, z20.s + + sub z17.s, z17.s, #16 + sub z18.s, z18.s, #16 + sub z19.s, z19.s, #16 + sub z20.s, z20.s, #16 + + tbx z21.s, z3.s, z17.s // Look-up entries 48-63 in the LUT. + tbx z22.s, z3.s, z18.s + tbx z23.s, z3.s, z19.s + tbx z24.s, z3.s, z20.s + + sub z17.s, z17.s, #16 + sub z18.s, z18.s, #16 + sub z19.s, z19.s, #16 + sub z20.s, z20.s, #16 + + tbx z21.s, z4.s, z17.s // Look-up entries 64-79 in the LUT. 
+ tbx z22.s, z4.s, z18.s + tbx z23.s, z4.s, z19.s + tbx z24.s, z4.s, z20.s + + sub z17.s, z17.s, #16 + sub z18.s, z18.s, #16 + sub z19.s, z19.s, #16 + sub z20.s, z20.s, #16 + + tbx z21.s, z5.s, z17.s // Look-up entries 80-95 in the LUT. + tbx z22.s, z5.s, z18.s + tbx z23.s, z5.s, z19.s + tbx z24.s, z5.s, z20.s + + sub z17.s, z17.s, #16 + sub z18.s, z18.s, #16 + sub z19.s, z19.s, #16 + sub z20.s, z20.s, #16 + + tbx z21.s, z6.s, z17.s // Look-up entries 96-111 in the LUT. + tbx z22.s, z6.s, z18.s + tbx z23.s, z6.s, z19.s + tbx z24.s, z6.s, z20.s + + sub z17.s, z17.s, #16 + sub z18.s, z18.s, #16 + sub z19.s, z19.s, #16 + sub z20.s, z20.s, #16 + + tbx z21.s, z7.s, z17.s // Look-up entries 112-127 in the LUT. + tbx z22.s, z7.s, z18.s + tbx z23.s, z7.s, z19.s + tbx z24.s, z7.s, z20.s + + sub z17.s, z17.s, #16 + sub z18.s, z18.s, #16 + sub z19.s, z19.s, #16 + sub z20.s, z20.s, #16 + + tbx z21.s, z8.s, z17.s // Look-up entries 128-143 in the LUT. + tbx z22.s, z8.s, z18.s + tbx z23.s, z8.s, z19.s + tbx z24.s, z8.s, z20.s + + sub z17.s, z17.s, #16 + sub z18.s, z18.s, #16 + sub z19.s, z19.s, #16 + sub z20.s, z20.s, #16 + + tbx z21.s, z9.s, z17.s // Look-up entries 144-159 in the LUT. + tbx z22.s, z9.s, z18.s + tbx z23.s, z9.s, z19.s + tbx z24.s, z9.s, z20.s + + sub z17.s, z17.s, #16 + sub z18.s, z18.s, #16 + sub z19.s, z19.s, #16 + sub z20.s, z20.s, #16 + + tbx z21.s, z10.s, z17.s // Look-up entries 160-175 in the LUT. + tbx z22.s, z10.s, z18.s + tbx z23.s, z10.s, z19.s + tbx z24.s, z10.s, z20.s + + sub z17.s, z17.s, #16 + sub z18.s, z18.s, #16 + sub z19.s, z19.s, #16 + sub z20.s, z20.s, #16 + + tbx z21.s, z11.s, z17.s // Look-up entries 176-191 in the LUT. + tbx z22.s, z11.s, z18.s + tbx z23.s, z11.s, z19.s + tbx z24.s, z11.s, z20.s + + sub z17.s, z17.s, #16 + sub z18.s, z18.s, #16 + sub z19.s, z19.s, #16 + sub z20.s, z20.s, #16 + + tbx z21.s, z12.s, z17.s // Look-up entries 192-207 in the LUT. 
+ tbx z22.s, z12.s, z18.s + tbx z23.s, z12.s, z19.s + tbx z24.s, z12.s, z20.s + + sub z17.s, z17.s, #16 + sub z18.s, z18.s, #16 + sub z19.s, z19.s, #16 + sub z20.s, z20.s, #16 + + tbx z21.s, z13.s, z17.s // Look-up entries 208-223 in the LUT. + tbx z22.s, z13.s, z18.s + tbx z23.s, z13.s, z19.s + tbx z24.s, z13.s, z20.s + + sub z17.s, z17.s, #16 + sub z18.s, z18.s, #16 + sub z19.s, z19.s, #16 + sub z20.s, z20.s, #16 + + tbx z21.s, z14.s, z17.s // Look-up entries 224-239 in the LUT. + tbx z22.s, z14.s, z18.s + tbx z23.s, z14.s, z19.s + tbx z24.s, z14.s, z20.s + + sub z17.s, z17.s, #16 + sub z18.s, z18.s, #16 + sub z19.s, z19.s, #16 + sub z20.s, z20.s, #16 + + tbx z21.s, z15.s, z17.s // Look-up entries 240-255 in the LUT. + tbx z22.s, z15.s, z18.s + tbx z23.s, z15.s, z19.s + tbx z24.s, z15.s, z20.s + + + st1w z21.s, p2, [x29, x1, LSL #2]// z21 store exp(-scale*beta*x) into the tmp tensor + fadd z25.s, p2/m, z25.s, z21.s + add x1, x1, #16 + + st1w z22.s, p3, [x29, x1, LSL #2]// z22 store exp(-scale*beta*x) into the tmp tensor + fadd z25.s, p3/m, z25.s, z22.s + add x1, x1, #16 + + st1w z23.s, p4, [x29, x1, LSL #2]// z23 store exp(-scale*beta*x) into the tmp tensor + fadd z25.s, p4/m, z25.s, z23.s + add x1, x1, #16 + + st1w z24.s, p5, [x29, x1, LSL #2]// z24 store exp(-scale*beta*x) into the tmp tensor + fadd z25.s, p5/m, z25.s, z24.s + add x1, x1, #16 + + b regularize_start%= +regularize_end%=: + + mov w9, 0x0000 + movk w9, 0x4380, LSL #16 // Moving 256.f into w9 to scale - via multiplication (division by reciprocal) - the floating point [0,1] range of the results to the [-128, 127] integer range of QASYMM8_SIGNED + mov w10, 0x0000 + movk w10, 0x4300, LSL #16 // Moving 128.f into w10 for the subtraction to move the results - via subtraction - from the [0,255] range to the [-128,127] range + dup z29.s, w9 + dup z30.s, w10 + faddv s25, p0, z25.s + fdiv s25, s29, s25 + dup z25.s, z25.s[0] // z25: 256.f/sum. 256 is needed to get the full range and 1/sum is part of softmax. 
+ + // ================================================== + // Step 3: Normalize + // ================================================== + mov x1, #0 +normalize_body_start%=: + cmp x1, x13 + b.eq normalize_body_end%= + + mov x2, x1 // Preserve the index into x2 for the final store to dst. + .inst 0xa001c7b0 // ld1w { z16.s - z19.s }, pn9/z, [x29, x1, lsl #2] + add x1, x1, #64 + .inst 0xa001c7b4 // ld1w { z20.s - z23.s }, pn9/z, [x29, x1, lsl #2] + add x1, x1, #64 + + // z16-z23: effectively divides exp(-scale*beta*x) by the sum of the exponentials for the current row and multiplies by 256. + fmul z16.s, z25.s, z16.s + fmul z17.s, z25.s, z17.s + fmul z18.s, z25.s, z18.s + fmul z19.s, z25.s, z19.s + fmul z20.s, z25.s, z20.s + fmul z21.s, z25.s, z21.s + fmul z22.s, z25.s, z22.s + fmul z23.s, z25.s, z23.s + + // z16-z23: subtract 128.f. + fsub z16.s, z16.s, z30.s // Subtract 128.f + fsub z17.s, z17.s, z30.s // Subtract 128.f + fsub z18.s, z18.s, z30.s // Subtract 128.f + fsub z19.s, z19.s, z30.s // Subtract 128.f + fsub z20.s, z20.s, z30.s // Subtract 128.f + fsub z21.s, z21.s, z30.s // Subtract 128.f + fsub z22.s, z22.s, z30.s // Subtract 128.f + fsub z23.s, z23.s, z30.s // Subtract 128.f + + // z16-z23: convert the FP32 values from the tmp tensor to int32. + fcvtzs z16.s, p0/m, z16.s + fcvtzs z17.s, p0/m, z17.s + fcvtzs z18.s, p0/m, z18.s + fcvtzs z19.s, p0/m, z19.s + fcvtzs z20.s, p0/m, z20.s + fcvtzs z21.s, p0/m, z21.s + fcvtzs z22.s, p0/m, z22.s + fcvtzs z23.s, p0/m, z23.s + + // z16-z17: narrow the int32 values into int8 and saturate them. + .inst 0xc133e210 // sqcvt z16.b, { z16.s - z19.s } + .inst 0xc133e291 // sqcvt z17.b, { z20.s - z23.s } + + // Juggling the value to z20 (resp. 21) as z25 (resp. z30) will be overwritten by the load below. 
+ dup z20.s, z25.s[0] + dup z21.s, z30.s[0] + + .inst 0xa001c7b8 // ld1w { z24.s - z27.s }, pn9/z, [x29, x1, lsl #2] + add x1, x1, #64 + .inst 0xa001c7bc // ld1w { z28.s - z31.s }, pn9/z, [x29, x1, lsl #2] + add x1, x1, #64 + + // z24-z31: effectively divides exp(-scale*beta*x) by the sum of the exponentials for the current row and multiplies by 256. + fmul z24.s, z20.s, z24.s + fmul z25.s, z20.s, z25.s + fmul z26.s, z20.s, z26.s + fmul z27.s, z20.s, z27.s + fmul z28.s, z20.s, z28.s + fmul z29.s, z20.s, z29.s + fmul z30.s, z20.s, z30.s + fmul z31.s, z20.s, z31.s + + // z24-z31: subtract 128.f. + fsub z24.s, z24.s, z21.s + fsub z25.s, z25.s, z21.s + fsub z26.s, z26.s, z21.s + fsub z27.s, z27.s, z21.s + fsub z28.s, z28.s, z21.s + fsub z29.s, z29.s, z21.s + fsub z30.s, z30.s, z21.s + fsub z31.s, z31.s, z21.s + + // z24-z31: convert the FP32 values from the tmp tensor to int32. + fcvtzs z24.s, p0/m, z24.s + fcvtzs z25.s, p0/m, z25.s + fcvtzs z26.s, p0/m, z26.s + fcvtzs z27.s, p0/m, z27.s + fcvtzs z28.s, p0/m, z28.s + fcvtzs z29.s, p0/m, z29.s + fcvtzs z30.s, p0/m, z30.s + fcvtzs z31.s, p0/m, z31.s + + // z18-z19: narrow the int32 values into int8 and saturate them. + .inst 0xc133e312 // sqcvt z18.b, { z24.s - z27.s } + .inst 0xc133e393 // sqcvt z19.b, { z28.s - z31.s } + + .inst 0xa0228390 // st1b { z16.b - z19.b }, pn8, [x28, x2] + + // Juggling the values back to z25 (resp. z30) as z20 (resp. z21) will be overwritten by the next iteration or z25 (resp. z30) will be used below. 
+ dup z25.s, z20.s[0] + dup z30.s, z21.s[0] +b normalize_body_start%= +normalize_body_end%=: +normalize_leftover_start%=: + whilelo p1.b, x1, %x[length] + b.none normalize_leftover_end%= + + // p2-p5 are - together - the 32-bit version of p1, the instructions below unpack p1 into those four predicate registers to allow for the 32-bit loads below to be correctly predicated + punpklo p2.h, p1.b + punpkhi p4.h, p1.b + + punpkhi p3.h, p2.b + punpklo p2.h, p2.b + + punpkhi p5.h, p4.b + punpklo p4.h, p4.b + + mov x2, x1 // Preserve the index into x2 for the final store to dst. + + // z20-z23: load exp(-scale*beta*x) from the tmp tensor + ld1w z20.s, p2/z, [x29, x1, LSL #2] + add x1, x1, #16 + + ld1w z21.s, p3/z, [x29, x1, LSL #2] + add x1, x1, #16 + + ld1w z22.s, p4/z, [x29, x1, LSL #2] + add x1, x1, #16 + + ld1w z23.s, p5/z, [x29, x1, LSL #2] + add x1, x1, #16 + + // z20-z23: effectively divides exp(-scale*beta*x) by the sum of the exponentials for the current row and multiplies by 256. + fmul z20.s, z25.s, z20.s + fmul z21.s, z25.s, z21.s + fmul z22.s, z25.s, z22.s + fmul z23.s, z25.s, z23.s + + //z20-z23: Subtract 128.f. + fsub z20.s, z20.s, z30.s + fsub z21.s, z21.s, z30.s + fsub z22.s, z22.s, z30.s + fsub z23.s, z23.s, z30.s + + // z20-23: convert the FP32 values from the tmp tensor to int32. + fcvtzs z20.s, p0/m, z20.s + fcvtzs z21.s, p0/m, z21.s + fcvtzs z22.s, p0/m, z22.s + fcvtzs z23.s, p0/m, z23.s + + .inst 0xc133e293 // sqcvt z19.b, { z20.s - z23.s }, narrow the int32 values into int8 and saturate them into z19. 
+ + st1b z19.b, p1, [x28, x2] + + b normalize_leftover_start%= +normalize_leftover_end%=: + // ================================================== + // 3D loop closing + // ================================================== + add x27, x27, %x[src_stride_1] + add x28, x28, %x[dst_stride_1] + b loop_1_start%= +loop_1_end%=: + + add x24, x24, %x[src_stride_2] + add x25, x25, %x[dst_stride_2] + b loop_2_start%= +loop_2_end%=: + + add x21, x21, %x[src_stride_3] + add x22, x22, %x[dst_stride_3] + b loop_3_start%= +loop_3_end%=: + .inst 0xd503467f // smstop + )" + : + : [src] "r"(src), [tmp] "r"(tmp), [dst] "r"(dst), [beta] "r"(beta), [lut] "r"(lut), // + [shape_1] "r"(shape[1]), [shape_2] "r"(shape[2]), [shape_3] "r"(shape[3]), // + [src_stride_1] "r"(src_strides[1]), [src_stride_2] "r"(src_strides[2]), + [src_stride_3] "r"(src_strides[3]), // + [dst_stride_1] "r"(dst_strides[1]), [dst_stride_2] "r"(dst_strides[2]), + [dst_stride_3] "r"(dst_strides[3]), // + [length] "r"(shape[0]) // + : "cc", "memory", // + "p0", "p1", "p2", "p3", "p4", // + "x2", "x9", "x13", // + "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28", "x19", // + "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", // + "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", // + "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", // + "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31" // + ); +} + +void sme2_qasymm8_signed_softmax_lut_512VL(const ITensor *in, + void *const tmp, + ITensor *out, + const float beta, + int axis, + const Window &window, + const float *lut_ptr) +{ + ARM_COMPUTE_UNUSED(axis); + + const auto *src_info = in->info(); + const auto *dst_info = out->info(); + + const auto &full_shape = dst_info->tensor_shape(); + const auto &src_strides = src_info->strides_in_bytes(); + const auto &dst_strides = dst_info->strides_in_bytes(); + Strides tmp_strides; + + tmp_strides[0] = src_strides[0] * 4; + tmp_strides[1] = src_strides[1] * 4; + tmp_strides[2] = src_strides[2] * 4; + 
tmp_strides[3] = src_strides[3] * 4; + + const uintptr_t k_shape[] = { + full_shape[0], + window.num_iterations(1), + window.num_iterations(2), + window.num_iterations(3), + }; + + const uintptr_t k_src_strides[] = { + src_strides[0], + src_strides[1], + src_strides[2], + src_strides[3], + }; + + const uintptr_t k_dst_strides[] = { + dst_strides[0], + dst_strides[1], + dst_strides[2], + dst_strides[3], + }; + + const uintptr_t k_src_offset = window[0].start() * src_strides[0] + // + window[1].start() * src_strides[1] + // + window[2].start() * src_strides[2] + // + window[3].start() * src_strides[3]; + + const uintptr_t k_dst_offset = window[0].start() * dst_strides[0] + // + window[1].start() * dst_strides[1] + // + window[2].start() * dst_strides[2] + // + window[3].start() * dst_strides[3]; + + const uintptr_t k_tmp_offset = window[0].start() * tmp_strides[0] + // + window[1].start() * tmp_strides[1] + // + window[2].start() * tmp_strides[2] + // + window[3].start() * tmp_strides[3]; + + const auto *k_src = reinterpret_cast<const int8_t *>(in->buffer() + k_src_offset); + float *tmp_float_ptr = reinterpret_cast<float *>(tmp); + auto *k_tmp = reinterpret_cast<float *>(tmp_float_ptr + k_tmp_offset); + auto *k_dst = reinterpret_cast<int8_t *>(out->buffer() + k_dst_offset); + + sme2_qasymm8_signed_softmax_kernel_512VL(k_src, k_dst, beta, k_shape, k_src_strides, k_dst_strides, lut_ptr, k_tmp); +} + +} // namespace cpu +} // namespace arm_compute + +#endif // ARM_COMPUTE_ENABLE_SME2 diff --git a/src/cpu/kernels/softmax/list.h b/src/cpu/kernels/softmax/list.h index 1bb8ed50f0..7bbb265022 100644 --- a/src/cpu/kernels/softmax/list.h +++ b/src/cpu/kernels/softmax/list.h @@ -28,9 +28,10 @@ namespace arm_compute { namespace cpu { -#define DECLARE_SOFTMAX_KERNEL(func_name) \ - template <bool IS_LOG> \ - void func_name(const ITensor *in, void *const tmp, ITensor *out, const float beta, int axis, const Window &window) +#define DECLARE_SOFTMAX_KERNEL(func_name) \ + template 
<bool IS_LOG> \ + void func_name(const ITensor *in, void *const tmp, ITensor *out, const float beta, int axis, const Window &window, \ + const float *lut_ptr) DECLARE_SOFTMAX_KERNEL(neon_fp32_softmax); DECLARE_SOFTMAX_KERNEL(neon_fp16_softmax); @@ -39,11 +40,37 @@ DECLARE_SOFTMAX_KERNEL(neon_qasymm8_signed_softmax); #ifdef ARM_COMPUTE_ENABLE_SME2 -void sme2_fp32_softmax( - const ITensor *in, void *const tmp, ITensor *out, const float beta, int axis, const Window &window); +void sme2_fp32_softmax(const ITensor *in, + void *const tmp, + ITensor *out, + const float beta, + int axis, + const Window &window, + const float *lut_ptr); -void sme2_fp16_softmax( - const ITensor *in, void *const tmp, ITensor *out, const float beta, int axis, const Window &window); +void sme2_fp16_softmax(const ITensor *in, + void *const tmp, + ITensor *out, + const float beta, + int axis, + const Window &window, + const float *lut_ptr); + +void sme2_qasymm8_softmax_lut_512VL(const ITensor *in, + void *const tmp, + ITensor *out, + const float beta, + int axis, + const Window &window, + const float *lut_ptr); + +void sme2_qasymm8_signed_softmax_lut_512VL(const ITensor *in, + void *const tmp, + ITensor *out, + const float beta, + int axis, + const Window &window, + const float *lut_ptr); #endif // ARM_COMPUTE_ENABLE_SME2 diff --git a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp index 7d85885654..a4c856bb8f 100644 --- a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp +++ b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp @@ -945,6 +945,7 @@ Status CpuGemmAssemblyDispatch::has_opt_impl(arm_compute::WeightFormat &expected } break; #endif /* __aarch64__ */ + #if defined(ARM_COMPUTE_ENABLE_BF16) case DataType::BFLOAT16: { @@ -963,13 +964,14 @@ Status CpuGemmAssemblyDispatch::has_opt_impl(arm_compute::WeightFormat &expected break; } #endif /* defined(ARM_COMPUTE_ENABLE_BF16) */ -#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC 
+ +#if defined(ENABLE_FP16_KERNELS) case DataType::F16: ARM_COMPUTE_RETURN_ERROR_ON_MSG( !(arm_gemm::has_opt_gemm<float16_t, float16_t, arm_gemm::Nothing>(arm_gemm_expected_wf, args, {})), "We could not find an optimized kernel for F16 input and F16 output"); break; -#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */ +#endif /* ENABLE_FP16_KERNELS */ default: ARM_COMPUTE_RETURN_ERROR_ON_MSG(true, "Usupported type. Could not find a kernel"); break; @@ -1102,11 +1104,11 @@ void CpuGemmAssemblyDispatch::configure( } break; #endif /* defined(ARM_COMPUTE_ENABLE_BF16) */ -#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#ifdef ENABLE_FP16_KERNELS case DataType::F16: create_arm_gemm<float16_t, float16_t>(_arm_gemm, a, b, c, d, act, info); break; -#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */ +#endif /* ENABLE_FP16_KERNELS */ default: break; } diff --git a/src/gpu/cl/kernels/ClScatterKernel.cpp b/src/gpu/cl/kernels/ClScatterKernel.cpp index 21c0253f91..19adc1ef34 100644 --- a/src/gpu/cl/kernels/ClScatterKernel.cpp +++ b/src/gpu/cl/kernels/ClScatterKernel.cpp @@ -66,21 +66,44 @@ Status ClScatterKernel::validate(const ITensorInfo *updates, const int32_t upt_dims = upt_shape.num_dimensions(); const int32_t dst_dims = dst_shape.num_dimensions(); const int32_t ind_dims = ind_shape.num_dimensions(); + const int32_t data_dim = upt_dims - (ind_dims - 1); // Number of batch dims is the number of indices dims - 1 const int32_t index_len = ind_shape[0]; + bool unsupported_padding_config = + (dst_dims == index_len) && index_len > 1 && (dst->has_padding() || updates->has_padding()); + ARM_COMPUTE_RETURN_ERROR_ON_MSG(unsupported_padding_config, "Padding is not supported with these shapes."); ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(updates, dst); ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_NOT_IN(indices, DataType::S32); ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_NOT_IN(dst, DataType::F32, DataType::F16, DataType::S32, DataType::S16, DataType::S8, DataType::U32, DataType::U16, 
DataType::U8); - ARM_COMPUTE_RETURN_ERROR_ON_MSG(ind_dims > 2, "Only 2D indices tensors are currently supported."); + + // Check data dims in update tensor and output tensor are equal + for (int32_t i = 0; i < data_dim; i++) + { + ARM_COMPUTE_RETURN_ERROR_ON_MSG(upt_shape[i] != dst_shape[i], + "Data dims should be same size in both updates and ouput tensor."); + } + + // Check if batch dims in indices and updates tensor are equal. + for (int32_t i = 0; i < ind_dims - 1; i++) + { + ARM_COMPUTE_RETURN_ERROR_ON_MSG(upt_shape[data_dim + i] != ind_shape[i + 1], + "Batch dimensions should be the same in updates and indices tensor."); + } + + ARM_COMPUTE_RETURN_ERROR_ON_MSG(ind_shape[1] != upt_shape[data_dim], + "Height of indices tensor should match size of highest dimension in updates tensor " + "(Excluding batch dimension)"); + ARM_COMPUTE_RETURN_ERROR_ON_MSG( - ind_shape[1] != upt_shape[upt_dims - 1], - "Height of indices tensor should match size of highest dimension in updates tensor."); - ARM_COMPUTE_RETURN_ERROR_ON_MSG(upt_dims > dst_dims, "Update tensor cannot have more dims than output tensor."); + data_dim >= dst_dims, "Update tensor cannot have more dims than output tensor. 
(Excluding batch dimensions)"); + ARM_COMPUTE_RETURN_ERROR_ON(index_len != dst_dims - data_dim); + ARM_COMPUTE_RETURN_ERROR_ON_MSG((ind_dims < 2), "Shape of Indices tensor must be at least 2D"); ARM_COMPUTE_RETURN_ERROR_ON_MSG(index_len > max_index_length, "Maximum supported index length is 5!"); - ARM_COMPUTE_RETURN_ERROR_ON(index_len != dst_dims - upt_dims + 1); + ARM_COMPUTE_RETURN_ERROR_ON_MSG(index_len > dst_dims && dst_dims != 1, + "Index length should be smaller than or equal to number of output dims"); return Status{}; } @@ -95,34 +118,40 @@ void ClScatterKernel::configure(const ClCompileContext &compile_context, ARM_COMPUTE_LOG_PARAMS(updates, indices, dst, info); const TensorShape &dst_shape = dst->tensor_shape(); + const int index_len = indices->dimension(0); - const bool is_scalar_block = updates->num_dimensions() == 1; - const int n0 = adjust_vec_size(16 / updates->element_size(), is_scalar_block ? 1 : updates->dimension(0)); + // Check for single element data block + const bool is_scalar_block = (dst->num_dimensions() == static_cast<uint32_t>(index_len)); + const int n0 = adjust_vec_size(16 / updates->element_size(), is_scalar_block ? 
1 : updates->dimension(0)); const int partial_n0 = updates->dimension(0) % n0; // The GWS will be 2D [x, y] // x-dimension refers to the x coordinate of the dst tensor // y-dimension refers to the collapsed y-coordinate of the data part of the dst tensor - Window win = calculate_max_window(dst_shape, Steps(n0)); - const int index_len = indices->dimension(0); + Window win; - // Collapse the dimensions corresponding to indices in the execution window - for (int i = 0; i < index_len; ++i) + if (!is_scalar_block) { - win.set(dst->num_dimensions() - (i + 1), Window::Dimension(0, 1, 1)); - } + win = calculate_max_window(dst_shape, Steps(n0)); + + // Collapse the dimensions corresponding to indices in the execution window + for (int i = 0; i < index_len; ++i) + { + win.set(dst->num_dimensions() - (i + 1), Window::Dimension(0, 1, 1)); + } - win = win.collapse(win, 1); + win = win.collapse(win, 1); + } // Set build options CLBuildOptions build_opts; build_opts.add_option("-DDATA_TYPE=" + get_cl_type_from_data_type(dst->data_type())); build_opts.add_option_if(is_data_type_float(dst->data_type()), "-DIS_FLOAT"); - const int num_dims = dst->num_dimensions(); - - build_opts.add_option("-DNUM_INDICES=" + support::cpp11::to_string(indices->dimension(1))); + const int num_dims = dst->num_dimensions(); + TensorShape ind_collapsed = indices->tensor_shape().collapsed_from(1); + build_opts.add_option("-DNUM_INDICES=" + support::cpp11::to_string(ind_collapsed[1])); build_opts.add_option("-DINDEX_LENGTH=" + support::cpp11::to_string(index_len)); // We provide 5 variables to use in a constant array @@ -185,13 +214,23 @@ void ClScatterKernel::run_op(ITensorPack &tensors, const Window &window, cl::Com utils::cast::polymorphic_downcast<const ICLTensor *>(tensors.get_const_tensor(TensorType::ACL_SRC_1)); auto dst = utils::cast::polymorphic_downcast<ICLTensor *>(tensors.get_tensor(TensorType::ACL_DST)); - const ITensorInfo *dst_info = dst->info(); - const int num_dims = 
dst_info->num_dimensions(); + const ITensorInfo *dst_info = dst->info(); + const ITensorInfo *upd_info = updates->info(); + const int num_dims = dst_info->num_dimensions(); + const int ind_dims = indices->info()->num_dimensions(); + const int index_len = indices->info()->dimension(0); - const int index_len = indices->info()->dimension(0); + bool unsupported_padding_config = + num_dims == index_len && index_len > 1 && (dst_info->has_padding() || upd_info->has_padding()); + if (unsupported_padding_config) + { + ARM_COMPUTE_ERROR("Unsupported Configuration! Padding not supported with these shapes."); + } // calculate m-dimensional data block strides in updates and destination tensors - const int upt_block_stride = updates->info()->strides_in_bytes()[updates->info()->num_dimensions() - 1]; + const int upt_block_stride = + updates->info()->strides_in_bytes()[updates->info()->num_dimensions() - (ind_dims - 1)]; + const int out_block_stride = dst_info->strides_in_bytes()[num_dims - index_len]; unsigned int idx = 0; |