From 298b2c0526615fc1f0242c2792fe2c51a4f0c44a Mon Sep 17 00:00:00 2001 From: Yair Schwarzbaum Date: Tue, 1 Feb 2022 08:55:56 +0200 Subject: Decouple castKernel Resolves: COMPMID-4625 Signed-off-by: Yair Schwarzbaum Change-Id: I3c30f007804b179e5e2b439f421fbd4e57fb02e1 Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/7149 Comments-Addressed: Arm Jenkins Tested-by: Arm Jenkins Reviewed-by: Giorgio Arena --- src/core/common/Registrars.h | 8 +- src/cpu/kernels/CpuCastKernel.cpp | 408 +++++-------------------- src/cpu/kernels/CpuCastKernel.h | 12 + src/cpu/kernels/CpuKernelSelectionTypes.h | 8 + src/cpu/kernels/cast/generic/neon/bfloat16.cpp | 144 +++++++++ src/cpu/kernels/cast/generic/neon/fp16.cpp | 396 ++++++++++++++++++++++++ src/cpu/kernels/cast/list.h | 44 +++ 7 files changed, 690 insertions(+), 330 deletions(-) create mode 100644 src/cpu/kernels/cast/generic/neon/bfloat16.cpp create mode 100644 src/cpu/kernels/cast/generic/neon/fp16.cpp create mode 100644 src/cpu/kernels/cast/list.h (limited to 'src') diff --git a/src/core/common/Registrars.h b/src/core/common/Registrars.h index c7fbf7f831..cc76de2be5 100644 --- a/src/core/common/Registrars.h +++ b/src/core/common/Registrars.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2021 Arm Limited. + * Copyright (c) 2020-2022 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -167,4 +167,10 @@ #define REGISTER_INTEGER_SVE2(func_name) nullptr #endif /* defined(ENABLE_INTEGER_KERNELS) */ +#if defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) || defined(ARM_COMPUTE_FORCE_BF16) +#define REGISTER_BF16_NEON(func_name) &(func_name) +#else /* !(defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) || defined(ARM_COMPUTE_FORCE_BF16))*/ +#define REGISTER_BF16_NEON(func_name) nullptr +#endif /* defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) || defined(ARM_COMPUTE_FORCE_BF16)*/ + #endif /* SRC_CORE_COMMON_REGISTRARS_H */ diff --git a/src/cpu/kernels/CpuCastKernel.cpp b/src/cpu/kernels/CpuCastKernel.cpp index db76df9076..e1314e61da 100644 --- a/src/cpu/kernels/CpuCastKernel.cpp +++ b/src/cpu/kernels/CpuCastKernel.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2016-2021 Arm Limited. + * Copyright (c) 2016-2022 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -32,10 +32,13 @@ #include "src/core/NEON/NEFixedPoint.h" #include "src/core/NEON/NEMath.h" #include "src/core/NEON/wrapper/wrapper.h" +#include "src/core/common/Registrars.h" #include "src/core/helpers/AutoConfiguration.h" #include "src/core/helpers/WindowHelpers.h" #include "support/SaturateCast.h" +#include "src/cpu/kernels/cast/list.h" + namespace arm_compute { namespace cpu @@ -44,6 +47,50 @@ namespace kernels { namespace { +static const std::vector available_kernels = +{ + { + "neon_qs8_cast", + [](const CastDataTypeISASelectorData & data) { return data.src_dt == DataType::QASYMM8_SIGNED && data.dst_dt == DataType::F16 && data.isa.fp16; }, + REGISTER_FP16_NEON(arm_compute::cpu::neon_qasymm8_signed_to_fp16_cast) + }, + { + "neon_qu8_cast", + [](const CastDataTypeISASelectorData & data) { return data.src_dt == DataType::QASYMM8 && data.dst_dt == DataType::F16 && data.isa.fp16; }, + REGISTER_FP16_NEON(arm_compute::cpu::neon_u8_to_fp16_cast) + }, + { + "neon_u8_cast", + [](const CastDataTypeISASelectorData & data) { return data.src_dt == DataType::U8 && data.dst_dt == DataType::F16 && data.isa.fp16; }, + REGISTER_FP16_NEON(arm_compute::cpu::neon_u8_to_fp16_cast) + }, + { + "neon_fp16_cast", + [](const CastDataTypeISASelectorData & data) { return data.src_dt == DataType::F16 && data.isa.fp16; }, + REGISTER_FP16_NEON(arm_compute::cpu::neon_fp16_to_other_dt_cast) + }, + { + "neon_fp32_to_fp16_cast", + [](const CastDataTypeISASelectorData & data) { return data.src_dt == DataType::F32 && data.dst_dt == DataType::F16 && data.isa.fp16; }, + REGISTER_FP16_NEON(arm_compute::cpu::neon_fp32_to_fp16_cast) + }, + { + "neon_fp32_to_bf16_cast", + [](const CastDataTypeISASelectorData & data) { return data.src_dt == DataType::F32 && data.dst_dt == DataType::BFLOAT16 && data.isa.bf16; }, + REGISTER_BF16_NEON(arm_compute::cpu::neon_fp32_to_bfloat16_cast) + }, + { + "neon_s32_cast", + [](const CastDataTypeISASelectorData & data) { return data.src_dt == DataType::S32 && data.dst_dt == DataType::F16 && data.isa.fp16; }, + REGISTER_FP16_NEON(arm_compute::cpu::neon_s32_to_fp16_cast) + }, + { + "neon_bf16_cast", + [](const CastDataTypeISASelectorData & data) { return data.src_dt == DataType::BFLOAT16 && data.dst_dt == DataType::F32 && data.isa.bf16; }, + REGISTER_BF16_NEON(arm_compute::cpu::neon_bfloat16_to_fp32_cast) + }, +}; + Status validate_arguments(const ITensorInfo *src, const ITensorInfo *dst, ConvertPolicy policy) { ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(src); @@ -151,6 +198,9 @@ void CpuCastKernel::run_op(ITensorPack &tensors, const Window &window, const Thr Iterator src(_src, win); Iterator dst(_dst, win); + /*ukernel runs only when using fp16/bfloat16, so we validate it isn't a nullptr only before using it */ + const auto *uk = CpuCastKernel::get_implementation(CastDataTypeISASelectorData{ _src->info()->data_type(), _dst->info()->data_type(), CPUInfo::get().get_isa() }); + switch(_src->info()->data_type()) { case DataType::QASYMM8_SIGNED: @@ -262,42 +312,13 @@ void CpuCastKernel::run_op(ITensorPack &tensors, const Window &window, const Thr src, dst); break; } -#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC case DataType::F16: { /* Up-conversion QASYMM8_SIGNED -> F16 */ - execute_window_loop(win, [&](const Coordinates &) - { - const auto src_ptr = reinterpret_cast(src.ptr()); - const auto dst_ptr = reinterpret_cast(dst.ptr()); - int x = window_start_x; - - for(; x <= (window_end_x - window_step_x); x += window_step_x) - { - const int8x16_t texels_s8 = vld1q_s8(src_ptr + x); - - const int16x8x2_t texels = - { - { - vmovl_s8(vget_low_s8(texels_s8)), - vmovl_s8(vget_high_s8(texels_s8)) - } - }; - vst1q_f16(dst_ptr + x, vcvtq_f16_s16(texels.val[0])); - vst1q_f16(dst_ptr + x + 8, vcvtq_f16_s16(texels.val[1])); - } - - // Compute left-over elements - for(; x < window_end_x; ++x) - { - *(dst_ptr + x) = static_cast(*(src_ptr + x)); - } - }, - src, dst); + ARM_COMPUTE_ERROR_ON(uk->ukernel == nullptr); + uk->ukernel(_src, _dst, info, _policy, window); break; } -#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC - default: ARM_COMPUTE_ERROR("dst data type not supported"); } @@ -414,41 +435,13 @@ void CpuCastKernel::run_op(ITensorPack &tensors, const Window &window, const Thr src, dst); break; } -#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC case DataType::F16: { - /* Up-conversion U8 -> F16 */ - execute_window_loop(win, [&](const Coordinates &) - { - const auto src_ptr = reinterpret_cast(src.ptr()); - const auto dst_ptr = reinterpret_cast(dst.ptr()); - - int x = window_start_x; - for(; x <= (window_end_x - window_step_x); x += window_step_x) - { - const uint8x16_t texels_u8 = vld1q_u8(src_ptr + x); - - const int16x8x2_t texels = - { - { - vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(texels_u8))), - vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(texels_u8))) - } - }; - vst1q_f16(dst_ptr + x, vcvtq_f16_s16(texels.val[0])); - vst1q_f16(dst_ptr + x + 8, vcvtq_f16_s16(texels.val[1])); - } - - // Compute left-over elements - for(; x < window_end_x; ++x) - { - *(dst_ptr + x) = static_cast(*(src_ptr + x)); - } - }, - src, dst); + /* Up-conversion U8 -> FP16 */ + ARM_COMPUTE_ERROR_ON(uk->ukernel == nullptr); + uk->ukernel(_src, _dst, info, _policy, window); break; } -#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC case DataType::U16: { /* Up-conversion U8 -> U16 */ @@ -668,6 +661,7 @@ void CpuCastKernel::run_op(ITensorPack &tensors, const Window &window, const Thr } break; } + case DataType::U16: { switch(_dst->info()->data_type()) @@ -775,258 +769,37 @@ void CpuCastKernel::run_op(ITensorPack &tensors, const Window &window, const Thr } break; } -#if defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) || defined(ARM_COMPUTE_FORCE_BF16) case DataType::BFLOAT16: - switch(_dst->info()->data_type()) - { - case DataType::F32: - { - /* Up-conversion BFLOAT16 -> F32 */ - execute_window_loop(win, [&](const Coordinates &) - { - const auto src_ptr = reinterpret_cast(src.ptr()); - const auto dst_ptr = reinterpret_cast(dst.ptr()); - - int x = window_start_x; - for(; x <= (window_end_x - window_step_x); x += window_step_x) - { - const uint16x8x2_t texels = - { - { - vld1q_u16(reinterpret_cast(src.ptr())), - vld1q_u16(reinterpret_cast(src.ptr()) + 8) - } - }; - - vst1q_f32(reinterpret_cast(dst.ptr()), - vreinterpretq_f32_u32(vshlq_n_u32(vmovl_u16(vget_low_u16(texels.val[0])), 16))); - vst1q_f32(reinterpret_cast(dst.ptr()) + 4, - vreinterpretq_f32_u32(vshlq_n_u32(vmovl_u16(vget_high_u16(texels.val[0])), 16))); - vst1q_f32(reinterpret_cast(dst.ptr()) + 8, - vreinterpretq_f32_u32(vshlq_n_u32(vmovl_u16(vget_low_u16(texels.val[1])), 16))); - vst1q_f32(reinterpret_cast(dst.ptr()) + 12, - vreinterpretq_f32_u32(vshlq_n_u32(vmovl_u16(vget_high_u16(texels.val[1])), 16))); - } - - for(; x < window_end_x; ++x) - { - *(dst_ptr + x) = float(*(src_ptr + x)); - } - }, - src, dst); - break; - } - default: - ARM_COMPUTE_ERROR("dst data type unsupported"); - } + { + /* Up-conversion BFLOAT16 -> F32 */ + ARM_COMPUTE_ERROR_ON(uk->ukernel == nullptr); + uk->ukernel(_src, _dst, info, _policy, window); break; -#endif /* defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) || defined(ARM_COMPUTE_FORCE_BF16) */ -#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + } case DataType::F16: - switch(_dst->info()->data_type()) - { - case DataType::QASYMM8_SIGNED: - { - /* Down-conversion F16 -> QASYMM8_SIGNED (Always saturating) */ - execute_window_loop(win, [&](const Coordinates &) - { - const auto src_ptr = reinterpret_cast(src.ptr()); - const auto dst_ptr = reinterpret_cast(dst.ptr()); - - int x = window_start_x; - for(; x <= (window_end_x - window_step_x); x += window_step_x) - { - const float16x8x2_t texels = - { - { - vld1q_f16(src_ptr + x), - vld1q_f16(src_ptr + x + 8), - } - }; - - vst1q_s8(dst_ptr + x, vcombine_s8(vqmovn_s16(vcvtq_s16_f16(texels.val[0])), vqmovn_s16(vcvtq_s16_f16(texels.val[1])))); - } - - // Compute left-over elements - for(; x < window_end_x; ++x) - { - *(dst_ptr + x) = utils::cast::saturate_cast(*(src_ptr + x)); - } - }, - src, dst); - break; - } - case DataType::QASYMM8: - case DataType::U8: - { - /* Down-conversion F16 -> QASYMM8/U8 (Always saturating) */ - execute_window_loop(win, [&](const Coordinates &) - { - const auto src_ptr = reinterpret_cast(src.ptr()); - const auto dst_ptr = reinterpret_cast(dst.ptr()); - - int x = window_start_x; - for(; x <= (window_end_x - window_step_x); x += window_step_x) - { - const float16x8x2_t texels = - { - { - vld1q_f16(src_ptr + x), - vld1q_f16(src_ptr + x + 8), - } - }; - - vst1q_u8(dst_ptr + x, vcombine_u8(vqmovun_s16(vcvtq_s16_f16(texels.val[0])), vqmovun_s16(vcvtq_s16_f16(texels.val[1])))); - } - - // Compute left-over elements - for(; x < window_end_x; ++x) - { - *(dst_ptr + x) = utils::cast::saturate_cast(*(src_ptr + x)); - } - - }, - src, dst); - break; - } - case DataType::F32: - { - /* Up-conversion F16 -> F32 */ - execute_window_loop(win, [&](const Coordinates &) - { - const auto src_ptr = reinterpret_cast(src.ptr()); - const auto dst_ptr = reinterpret_cast(dst.ptr()); - - int x = window_start_x; - for(; x <= (window_end_x - window_step_x); x += window_step_x) - { - const float16x8x2_t texels = - { - { - vld1q_f16(src_ptr + x), - vld1q_f16(src_ptr + x + 8) - } - }; - vst1q_f32(dst_ptr + x, vcvt_f32_f16(vget_low_f16(texels.val[0]))); - vst1q_f32(dst_ptr + x + 4, vcvt_f32_f16(vget_high_f16(texels.val[0]))); - vst1q_f32(dst_ptr + x + 8, vcvt_f32_f16(vget_low_f16(texels.val[1]))); - vst1q_f32(dst_ptr + x + 12, vcvt_f32_f16(vget_high_f16(texels.val[1]))); - } - - // Compute left-over elements - for(; x < window_end_x; ++x) - { - *(dst_ptr + x) = static_cast(*(src_ptr + x)); - } - }, - src, dst); - break; - } - case DataType::S32: - { - /* Up-conversion F16 -> S32 */ - execute_window_loop(win, [&](const Coordinates &) - { - const auto src_ptr = reinterpret_cast(src.ptr()); - const auto dst_ptr = reinterpret_cast(dst.ptr()); - - int x = window_start_x; - for(; x <= (window_end_x - window_step_x); x += window_step_x) - { - const float16x8x2_t texels = - { - { - vld1q_f16(src_ptr + x), - vld1q_f16(src_ptr + x + 8) - } - }; - - vst1q_s32(dst_ptr + x, vcvtq_s32_f32(vcvt_f32_f16(vget_low_f16(texels.val[0])))); - vst1q_s32(dst_ptr + x + 4, vcvtq_s32_f32(vcvt_f32_f16(vget_high_f16(texels.val[0])))); - vst1q_s32(dst_ptr + x + 8, vcvtq_s32_f32(vcvt_f32_f16(vget_low_f16(texels.val[1])))); - vst1q_s32(dst_ptr + x + 12, vcvtq_s32_f32(vcvt_f32_f16(vget_high_f16(texels.val[1])))); - } - - // Compute left-over elements - for(; x < window_end_x; ++x) - { - *(dst_ptr + x) = static_cast(*(src_ptr + x)); - } - }, - src, dst); - break; - } - default: - ARM_COMPUTE_ERROR("dst data type not supported"); - } + { + /* conversion F16 -> any data type */ + ARM_COMPUTE_ERROR_ON(uk->ukernel == nullptr); + uk->ukernel(_src, _dst, info, _policy, window); break; -#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */ + } case DataType::F32: switch(_dst->info()->data_type()) { -#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC case DataType::F16: { /* Down-conversion F32 -> F16 */ - execute_window_loop(win, [&](const Coordinates &) - { - const auto src_ptr = reinterpret_cast(src.ptr()); - const auto dst_ptr = reinterpret_cast(dst.ptr()); - - int x = window_start_x; - for(; x <= (window_end_x - window_step_x); x += window_step_x) - { - const float32x4x4_t texels = - { - { - vld1q_f32(src_ptr + x), - vld1q_f32(src_ptr + x + 4), - vld1q_f32(src_ptr + x + 8), - vld1q_f32(src_ptr + x + 12) - } - }; - - vst1q_f16(dst_ptr + x, vcombine_f16(vcvt_f16_f32(texels.val[0]), vcvt_f16_f32(texels.val[1]))); - vst1q_f16(dst_ptr + x + 8, vcombine_f16(vcvt_f16_f32(texels.val[2]), vcvt_f16_f32(texels.val[3]))); - } - - // Compute left-over elements - for(; x < window_end_x; ++x) - { - *(dst_ptr + x) = static_cast(*(src_ptr + x)); - } - }, - src, dst); + ARM_COMPUTE_ERROR_ON(uk->ukernel == nullptr); + uk->ukernel(_src, _dst, info, _policy, window); break; } -#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */ -#if defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) || defined(ARM_COMPUTE_FORCE_BF16) case DataType::BFLOAT16: { /* Down-conversion F32 -> BFLOAT16 */ - execute_window_loop(win, [&](const Coordinates &) - { - const auto src_ptr = reinterpret_cast(src.ptr()); - const auto dst_ptr = reinterpret_cast(dst.ptr()); - - int x = window_start_x; - for(; x <= (window_end_x - window_step_x); x += window_step_x) - { - wrapper::vcvt_bf16_f32(reinterpret_cast(src.ptr()), - reinterpret_cast(dst.ptr())); - wrapper::vcvt_bf16_f32(reinterpret_cast(src.ptr()) + 8, - reinterpret_cast(dst.ptr()) + 8); - } - - for(; x < window_end_x; ++x) - { - *(dst_ptr + x) = *(src_ptr + x); - } - }, - src, dst); + ARM_COMPUTE_ERROR_ON(uk->ukernel == nullptr); + uk->ukernel(_src, _dst, info, _policy, window); break; } -#endif /* defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) || defined(ARM_COMPUTE_FORCE_BF16) */ case DataType::S32: { /* Conversion F32 -> S32 */ @@ -1140,42 +913,13 @@ void CpuCastKernel::run_op(ITensorPack &tensors, const Window &window, const Thr case DataType::S32: switch(_dst->info()->data_type()) { -#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC case DataType::F16: { /* Down-conversion S32 -> F16 */ - execute_window_loop(win, [&](const Coordinates &) - { - const auto src_ptr = reinterpret_cast(src.ptr()); - const auto dst_ptr = reinterpret_cast(dst.ptr()); - - int x = window_start_x; - for(; x <= (window_end_x - window_step_x); x += window_step_x) - { - const float32x4x4_t texels = - { - { - vcvtq_f32_s32(vld1q_s32(src_ptr + x)), - vcvtq_f32_s32(vld1q_s32(src_ptr + x + 4)), - vcvtq_f32_s32(vld1q_s32(src_ptr + x + 8)), - vcvtq_f32_s32(vld1q_s32(src_ptr + x + 12)) - } - }; - - vst1q_f16(dst_ptr + x, vcombine_f16(vcvt_f16_f32(texels.val[0]), vcvt_f16_f32(texels.val[1]))); - vst1q_f16(dst_ptr + x + 8, vcombine_f16(vcvt_f16_f32(texels.val[2]), vcvt_f16_f32(texels.val[3]))); - } - - // Compute left-over elements - for(; x < window_end_x; ++x) - { - *(dst_ptr + x) = static_cast(*(src_ptr + x)); - } - }, - src, dst); + ARM_COMPUTE_ERROR_ON(uk->ukernel == nullptr); + uk->ukernel(_src, _dst, info, _policy, window); break; } -#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */ case DataType::F32: { /* Conversion S32 -> F32 */ @@ -1362,6 +1106,12 @@ const char *CpuCastKernel::name() const { return "CpuCastKernel.cpp"; } + +const std::vector &CpuCastKernel::get_available_kernels() +{ + return available_kernels; +} + } // namespace kernels } // namespace cpu } // namespace arm_compute diff --git a/src/cpu/kernels/CpuCastKernel.h b/src/cpu/kernels/CpuCastKernel.h index 7679178fa1..95d46fad23 100644 --- a/src/cpu/kernels/CpuCastKernel.h +++ b/src/cpu/kernels/CpuCastKernel.h @@ -39,6 +39,9 @@ namespace kernels */ class CpuCastKernel : public ICpuKernel { +private: + using CastKernelPtr = std::add_pointer::type; + public: CpuCastKernel() = default; ARM_COMPUTE_DISALLOW_COPY_ALLOW_MOVE(CpuCastKernel); @@ -73,6 +76,15 @@ public: void run_op(ITensorPack &tensors, const Window &window, const ThreadInfo &info) override; const char *name() const override; + struct CastKernel + { + const char *name; + const CastDataTypeISASelectorDataPtr is_selected; + CastKernelPtr ukernel; + }; + + static const std::vector &get_available_kernels(); + private: ConvertPolicy _policy{ ConvertPolicy::SATURATE }; }; diff --git a/src/cpu/kernels/CpuKernelSelectionTypes.h b/src/cpu/kernels/CpuKernelSelectionTypes.h index 8c5a39ad49..afcf014ad2 100644 --- a/src/cpu/kernels/CpuKernelSelectionTypes.h +++ b/src/cpu/kernels/CpuKernelSelectionTypes.h @@ -47,6 +47,13 @@ struct DataTypeDataLayoutISASelectorData const cpuinfo::CpuIsaInfo &isa; }; +struct CastDataTypeISASelectorData +{ + DataType src_dt; + DataType dst_dt; + const cpuinfo::CpuIsaInfo &isa; +}; + struct PoolDataTypeISASelectorData { DataType dt; @@ -74,6 +81,7 @@ using DataTypeDataLayoutSelectorPtr = std::add_pointer::type; using ElementwiseDataTypeISASelectorPtr = std::add_pointer::type; using DepthwiseConv2dNativeDataTypeISASelectorPtr = std::add_pointer::type; +using CastDataTypeISASelectorDataPtr = std::add_pointer::type; } // namespace kernels } // namespace cpu diff --git a/src/cpu/kernels/cast/generic/neon/bfloat16.cpp b/src/cpu/kernels/cast/generic/neon/bfloat16.cpp new file mode 100644 index 0000000000..b15584b0aa --- /dev/null +++ b/src/cpu/kernels/cast/generic/neon/bfloat16.cpp @@ -0,0 +1,144 @@ +/* + * Copyright (c) 2016-2022 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#if defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) || defined(ARM_COMPUTE_FORCE_BF16) + +#include "arm_compute/core/TensorInfo.h" +#include "src/cpu/kernels/CpuCastKernel.h" +#include "src/cpu/kernels/cast/list.h" +#include "support/SaturateCast.h" + +namespace arm_compute +{ +namespace cpu +{ +void neon_fp32_to_bfloat16_cast(const ITensor *_src, ITensor *_dst, const ThreadInfo &info, ConvertPolicy _policy, const Window &window) +{ + ARM_COMPUTE_UNUSED(info); + ARM_COMPUTE_UNUSED(_policy); + + const auto window_start_x = static_cast(window.x().start()); + const auto window_end_x = static_cast(window.x().end()); + const int window_step_x = 16; + + ARM_COMPUTE_ERROR_ON_NULLPTR(_src, _dst); + ARM_COMPUTE_ERROR_ON(_src == _dst); + + ARM_COMPUTE_ERROR_ON_NULLPTR(_src, _dst); + + Window win{ window }; + win.set(Window::DimX, Window::Dimension(0, 1, 1)); + + Iterator src(_src, win); + Iterator dst(_dst, win); + + /* Down-conversion F32 -> BFLOAT16 */ + execute_window_loop(win, [&](const Coordinates &) + { + const auto src_ptr = reinterpret_cast(src.ptr()); + const auto dst_ptr = reinterpret_cast(dst.ptr()); + + int x = window_start_x; + for(; x <= (window_end_x - window_step_x); x += window_step_x) + { + wrapper::vcvt_bf16_f32(reinterpret_cast(src.ptr()), + reinterpret_cast(dst.ptr())); + wrapper::vcvt_bf16_f32(reinterpret_cast(src.ptr()) + 8, + reinterpret_cast(dst.ptr()) + 8); + } + + for(; x < window_end_x; ++x) + { + *(dst_ptr + x) = *(src_ptr + x); + } + }, + src, dst); +} + +void neon_bfloat16_to_fp32_cast(const ITensor *_src, ITensor *_dst, const ThreadInfo &info, ConvertPolicy _policy, const Window &window) +{ + ARM_COMPUTE_UNUSED(info); + ARM_COMPUTE_UNUSED(_policy); + + const auto window_start_x = static_cast(window.x().start()); + const auto window_end_x = static_cast(window.x().end()); + const int window_step_x = 16; + + ARM_COMPUTE_ERROR_ON_NULLPTR(_src, _dst); + ARM_COMPUTE_ERROR_ON(_src == _dst); + + ARM_COMPUTE_ERROR_ON_NULLPTR(_src, _dst); + + Window win{ window }; + win.set(Window::DimX, Window::Dimension(0, 1, 1)); + + Iterator src(_src, win); + Iterator dst(_dst, win); + switch(_dst->info()->data_type()) + { + case DataType::F32: + { + /* Up-conversion BFLOAT16 -> F32 */ + execute_window_loop(win, [&](const Coordinates &) + { + const auto src_ptr = reinterpret_cast(src.ptr()); + const auto dst_ptr = reinterpret_cast(dst.ptr()); + + int x = window_start_x; + for(; x <= (window_end_x - window_step_x); x += window_step_x) + { + const uint16x8x2_t texels = + { + { + vld1q_u16(reinterpret_cast(src.ptr())), + vld1q_u16(reinterpret_cast(src.ptr()) + 8) + } + }; + + vst1q_f32(reinterpret_cast(dst.ptr()), + vreinterpretq_f32_u32(vshlq_n_u32(vmovl_u16(vget_low_u16(texels.val[0])), 16))); + vst1q_f32(reinterpret_cast(dst.ptr()) + 4, + vreinterpretq_f32_u32(vshlq_n_u32(vmovl_u16(vget_high_u16(texels.val[0])), 16))); + vst1q_f32(reinterpret_cast(dst.ptr()) + 8, + vreinterpretq_f32_u32(vshlq_n_u32(vmovl_u16(vget_low_u16(texels.val[1])), 16))); + vst1q_f32(reinterpret_cast(dst.ptr()) + 12, + vreinterpretq_f32_u32(vshlq_n_u32(vmovl_u16(vget_high_u16(texels.val[1])), 16))); + } + + for(; x < window_end_x; ++x) + { + *(dst_ptr + x) = float(*(src_ptr + x)); + } + }, + src, dst); + break; + } + default: + ARM_COMPUTE_ERROR("dst data type unsupported"); + } +} + +} // namespace cpu +} // namespace arm_compute + +#endif /* defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) || defined(ARM_COMPUTE_FORCE_BF16) */ diff --git a/src/cpu/kernels/cast/generic/neon/fp16.cpp b/src/cpu/kernels/cast/generic/neon/fp16.cpp new file mode 100644 index 0000000000..d2c66923cc --- /dev/null +++ b/src/cpu/kernels/cast/generic/neon/fp16.cpp @@ -0,0 +1,396 @@ +/* + * Copyright (c) 2016-2022 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && defined(ENABLE_FP16_KERNELS) + +#include "arm_compute/core/TensorInfo.h" +#include "src/cpu/kernels/CpuCastKernel.h" +#include "src/cpu/kernels/cast/list.h" +#include "support/SaturateCast.h" + +namespace arm_compute +{ +namespace cpu +{ +void neon_qasymm8_signed_to_fp16_cast(const ITensor *_src, ITensor *_dst, const ThreadInfo &info, ConvertPolicy _policy, const Window &window) +{ + ARM_COMPUTE_UNUSED(info); + ARM_COMPUTE_UNUSED(_policy); + + const auto window_start_x = static_cast(window.x().start()); + const auto window_end_x = static_cast(window.x().end()); + const int window_step_x = 16; + + ARM_COMPUTE_ERROR_ON_NULLPTR(_src, _dst); + ARM_COMPUTE_ERROR_ON(_src == _dst); + + ARM_COMPUTE_ERROR_ON_NULLPTR(_src, _dst); + + Window win{ window }; + win.set(Window::DimX, Window::Dimension(0, 1, 1)); + + Iterator src(_src, win); + Iterator dst(_dst, win); + execute_window_loop(win, [&](const Coordinates &) + { + const auto src_ptr = reinterpret_cast(src.ptr()); + const auto dst_ptr = reinterpret_cast(dst.ptr()); + int x = window_start_x; + + for(; x <= (window_end_x - window_step_x); x += window_step_x) + { + const int8x16_t texels_s8 = vld1q_s8(src_ptr + x); + + const int16x8x2_t texels = + { + { + vmovl_s8(vget_low_s8(texels_s8)), + vmovl_s8(vget_high_s8(texels_s8)) + } + }; + vst1q_f16(dst_ptr + x, vcvtq_f16_s16(texels.val[0])); + vst1q_f16(dst_ptr + x + 8, vcvtq_f16_s16(texels.val[1])); + } + + // Compute left-over elements + for(; x < window_end_x; ++x) + { + *(dst_ptr + x) = static_cast(*(src_ptr + x)); + } + }, + src, dst); +} + +void neon_s32_to_fp16_cast(const ITensor *_src, ITensor *_dst, const ThreadInfo &info, ConvertPolicy _policy, const Window &window) +{ + ARM_COMPUTE_UNUSED(info); + ARM_COMPUTE_UNUSED(_policy); + + const auto window_start_x = static_cast(window.x().start()); + const auto window_end_x = static_cast(window.x().end()); + const int window_step_x = 16; + + ARM_COMPUTE_ERROR_ON_NULLPTR(_src, _dst); + ARM_COMPUTE_ERROR_ON(_src == _dst); + + ARM_COMPUTE_ERROR_ON_NULLPTR(_src, _dst); + + Window win{ window }; + win.set(Window::DimX, Window::Dimension(0, 1, 1)); + + Iterator src(_src, win); + Iterator dst(_dst, win); + + execute_window_loop(win, [&](const Coordinates &) + { + const auto src_ptr = reinterpret_cast(src.ptr()); + const auto dst_ptr = reinterpret_cast(dst.ptr()); + + int x = window_start_x; + for(; x <= (window_end_x - window_step_x); x += window_step_x) + { + const float32x4x4_t texels = + { + { + vcvtq_f32_s32(vld1q_s32(src_ptr + x)), + vcvtq_f32_s32(vld1q_s32(src_ptr + x + 4)), + vcvtq_f32_s32(vld1q_s32(src_ptr + x + 8)), + vcvtq_f32_s32(vld1q_s32(src_ptr + x + 12)) + } + }; + + vst1q_f16(dst_ptr + x, vcombine_f16(vcvt_f16_f32(texels.val[0]), vcvt_f16_f32(texels.val[1]))); + vst1q_f16(dst_ptr + x + 8, vcombine_f16(vcvt_f16_f32(texels.val[2]), vcvt_f16_f32(texels.val[3]))); + } + + // Compute left-over elements + for(; x < window_end_x; ++x) + { + *(dst_ptr + x) = static_cast(*(src_ptr + x)); + } + }, + src, dst); +} + +void neon_fp32_to_fp16_cast(const ITensor *_src, ITensor *_dst, const ThreadInfo &info, ConvertPolicy _policy, const Window &window) +{ + ARM_COMPUTE_UNUSED(info); + ARM_COMPUTE_UNUSED(_policy); + + const auto window_start_x = static_cast(window.x().start()); + const auto window_end_x = static_cast(window.x().end()); + const int window_step_x = 16; + + ARM_COMPUTE_ERROR_ON_NULLPTR(_src, _dst); + ARM_COMPUTE_ERROR_ON(_src == _dst); + + ARM_COMPUTE_ERROR_ON_NULLPTR(_src, _dst); + + Window win{ window }; + win.set(Window::DimX, Window::Dimension(0, 1, 1)); + + Iterator src(_src, win); + Iterator dst(_dst, win); + + execute_window_loop(win, [&](const Coordinates &) + { + const auto src_ptr = reinterpret_cast(src.ptr()); + const auto dst_ptr = reinterpret_cast(dst.ptr()); + + int x = window_start_x; + for(; x <= (window_end_x - window_step_x); x += window_step_x) + { + const float32x4x4_t texels = + { + { + vld1q_f32(src_ptr + x), + vld1q_f32(src_ptr + x + 4), + vld1q_f32(src_ptr + x + 8), + vld1q_f32(src_ptr + x + 12) + } + }; + + vst1q_f16(dst_ptr + x, vcombine_f16(vcvt_f16_f32(texels.val[0]), vcvt_f16_f32(texels.val[1]))); + vst1q_f16(dst_ptr + x + 8, vcombine_f16(vcvt_f16_f32(texels.val[2]), vcvt_f16_f32(texels.val[3]))); + } + + // Compute left-over elements + for(; x < window_end_x; ++x) + { + *(dst_ptr + x) = static_cast(*(src_ptr + x)); + } + }, + src, dst); +} + +void neon_fp16_to_other_dt_cast(const ITensor *_src, ITensor *_dst, const ThreadInfo &info, ConvertPolicy _policy, const Window &window) +{ + ARM_COMPUTE_UNUSED(info); + ARM_COMPUTE_UNUSED(_policy); + + const auto window_start_x = static_cast(window.x().start()); + const auto window_end_x = static_cast(window.x().end()); + const int window_step_x = 16; + + ARM_COMPUTE_ERROR_ON_NULLPTR(_src, _dst); + ARM_COMPUTE_ERROR_ON(_src == _dst); + + ARM_COMPUTE_ERROR_ON_NULLPTR(_src, _dst); + + Window win{ window }; + win.set(Window::DimX, Window::Dimension(0, 1, 1)); + + Iterator src(_src, win); + Iterator dst(_dst, win); + switch(_dst->info()->data_type()) + { + case DataType::QASYMM8_SIGNED: + { + /* Down-conversion F16 -> QASYMM8_SIGNED (Always saturating) */ + execute_window_loop(win, [&](const Coordinates &) + { + const auto src_ptr = reinterpret_cast(src.ptr()); + const auto dst_ptr = reinterpret_cast(dst.ptr()); + + int x = window_start_x; + for(; x <= (window_end_x - window_step_x); x += window_step_x) + { + const float16x8x2_t texels = + { + { + vld1q_f16(src_ptr + x), + vld1q_f16(src_ptr + x + 8), + } + }; + + vst1q_s8(dst_ptr + x, vcombine_s8(vqmovn_s16(vcvtq_s16_f16(texels.val[0])), vqmovn_s16(vcvtq_s16_f16(texels.val[1])))); + } + + // Compute left-over elements + for(; x < window_end_x; ++x) + { + *(dst_ptr + x) = utils::cast::saturate_cast(*(src_ptr + x)); + } + }, + src, dst); + break; + } + case DataType::QASYMM8: + case DataType::U8: + { + /* Down-conversion F16 -> QASYMM8/U8 (Always saturating) */ + execute_window_loop(win, [&](const Coordinates &) + { + const auto src_ptr = reinterpret_cast(src.ptr()); + const auto dst_ptr = reinterpret_cast(dst.ptr()); + + int x = window_start_x; + for(; x <= (window_end_x - window_step_x); x += window_step_x) + { + const float16x8x2_t texels = + { + { + vld1q_f16(src_ptr + x), + vld1q_f16(src_ptr + x + 8), + } + }; + + vst1q_u8(dst_ptr + x, vcombine_u8(vqmovun_s16(vcvtq_s16_f16(texels.val[0])), vqmovun_s16(vcvtq_s16_f16(texels.val[1])))); + } + + // Compute left-over elements + for(; x < window_end_x; ++x) + { + *(dst_ptr + x) = utils::cast::saturate_cast(*(src_ptr + x)); + } + + }, + src, dst); + break; + } + case DataType::F32: + { + /* Up-conversion F16 -> F32 */ + execute_window_loop(win, [&](const Coordinates &) + { + const auto src_ptr = reinterpret_cast(src.ptr()); + const auto dst_ptr = reinterpret_cast(dst.ptr()); + + int x = window_start_x; + for(; x <= (window_end_x - window_step_x); x += window_step_x) + { + const float16x8x2_t texels = + { + { + vld1q_f16(src_ptr + x), + vld1q_f16(src_ptr + x + 8) + } + }; + vst1q_f32(dst_ptr + x, vcvt_f32_f16(vget_low_f16(texels.val[0]))); + vst1q_f32(dst_ptr + x + 4, vcvt_f32_f16(vget_high_f16(texels.val[0]))); + vst1q_f32(dst_ptr + x + 8, vcvt_f32_f16(vget_low_f16(texels.val[1]))); + vst1q_f32(dst_ptr + x + 12, vcvt_f32_f16(vget_high_f16(texels.val[1]))); + } + + // Compute left-over elements + for(; x < window_end_x; ++x) + { + *(dst_ptr + x) = static_cast(*(src_ptr + x)); + } + }, + src, dst); + break; + } + case DataType::S32: + { + /* Up-conversion F16 -> S32 */ + execute_window_loop(win, [&](const Coordinates &) + { + const auto src_ptr = reinterpret_cast(src.ptr()); + const auto dst_ptr = reinterpret_cast(dst.ptr()); + + int x = window_start_x; + for(; x <= (window_end_x - window_step_x); x += window_step_x) + { + const float16x8x2_t texels = + { + { + vld1q_f16(src_ptr + x), + vld1q_f16(src_ptr + x + 8) + } + }; + + vst1q_s32(dst_ptr + x, vcvtq_s32_f32(vcvt_f32_f16(vget_low_f16(texels.val[0])))); + vst1q_s32(dst_ptr + x + 4, vcvtq_s32_f32(vcvt_f32_f16(vget_high_f16(texels.val[0])))); + vst1q_s32(dst_ptr + x + 8, vcvtq_s32_f32(vcvt_f32_f16(vget_low_f16(texels.val[1])))); + vst1q_s32(dst_ptr + x + 12, vcvtq_s32_f32(vcvt_f32_f16(vget_high_f16(texels.val[1])))); + } + + // Compute left-over elements + for(; x < window_end_x; ++x) + { + *(dst_ptr + x) = static_cast(*(src_ptr + x)); + } + }, + src, dst); + break; + } + default: + ARM_COMPUTE_ERROR("dst data type not supported"); + } +} + +void neon_u8_to_fp16_cast(const ITensor *_src, ITensor *_dst, const ThreadInfo &info, ConvertPolicy _policy, const Window &window) +{ + ARM_COMPUTE_UNUSED(info); + ARM_COMPUTE_UNUSED(_policy); + + const auto window_start_x = static_cast(window.x().start()); + const auto window_end_x = static_cast(window.x().end()); + const int window_step_x = 16; + + ARM_COMPUTE_ERROR_ON_NULLPTR(_src, _dst); + ARM_COMPUTE_ERROR_ON(_src == _dst); + + ARM_COMPUTE_ERROR_ON_NULLPTR(_src, _dst); + + Window win{ window }; + win.set(Window::DimX, Window::Dimension(0, 1, 1)); + + Iterator src(_src, win); + Iterator dst(_dst, win); + /* Up-conversion U8 -> F16 */ + execute_window_loop(win, [&](const Coordinates &) + { + const auto src_ptr = reinterpret_cast(src.ptr()); + const auto dst_ptr = reinterpret_cast(dst.ptr()); + + int x = window_start_x; + for(; x <= (window_end_x - window_step_x); x += window_step_x) + { + const uint8x16_t texels_u8 = vld1q_u8(src_ptr + x); + + const int16x8x2_t texels = + { + { + vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(texels_u8))), + vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(texels_u8))) + } + }; + vst1q_f16(dst_ptr + x, vcvtq_f16_s16(texels.val[0])); + vst1q_f16(dst_ptr + x + 8, vcvtq_f16_s16(texels.val[1])); + } + + // Compute left-over elements + for(; x < window_end_x; ++x) + { + *(dst_ptr + x) = static_cast(*(src_ptr + x)); + } + }, + src, dst); + return; +} + +} // namespace cpu +} // namespace arm_compute +#endif /* #if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && defined(ENABLE_FP16_KERNELS) */ diff --git a/src/cpu/kernels/cast/list.h b/src/cpu/kernels/cast/list.h new file mode 100644 index 0000000000..ffd82d5bf3 --- /dev/null +++ b/src/cpu/kernels/cast/list.h @@ -0,0 +1,44 @@ +/* + * Copyright (c) 2022 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#ifndef SRC_CORE_NEON_KERNELS_CAST_LIST_H +#define SRC_CORE_NEON_KERNELS_CAST_LIST_H +namespace arm_compute +{ +namespace cpu +{ +#define DECLARE_CAST_KERNEL(func_name) \ + void func_name(const ITensor *_src, ITensor *_dst, const ThreadInfo &tensor, ConvertPolicy _policy, const Window &window) + +DECLARE_CAST_KERNEL(neon_fp32_to_fp16_cast); +DECLARE_CAST_KERNEL(neon_u8_to_fp16_cast); +DECLARE_CAST_KERNEL(neon_fp16_to_other_dt_cast); +DECLARE_CAST_KERNEL(neon_s32_to_fp16_cast); +DECLARE_CAST_KERNEL(neon_qasymm8_signed_to_fp16_cast); +DECLARE_CAST_KERNEL(neon_fp32_to_bfloat16_cast); +DECLARE_CAST_KERNEL(neon_bfloat16_to_fp32_cast); + +#undef DECLARE_CAST_KERNEL +} // namespace cpu +} // namespace arm_compute +#endif //SRC_CORE_NEON_KERNELS_CAST_LIST_H \ No newline at end of file -- cgit v1.2.1