diff options
author | Yair Schwarzbaum <yair.schwarzbaum@arm.com> | 2022-02-01 08:55:56 +0200 |
---|---|---|
committer | Yair Schwarzbaum <yair.schwarzbaum@arm.com> | 2022-03-03 10:16:39 +0000 |
commit | 298b2c0526615fc1f0242c2792fe2c51a4f0c44a (patch) | |
tree | e47e5986e805e29fed4afca59c76e5375076cff2 | |
parent | 918a9fb4aa4be23ca4261c241e9e52acc42f9bb3 (diff) | |
download | ComputeLibrary-298b2c0526615fc1f0242c2792fe2c51a4f0c44a.tar.gz |
Decouple castKernel
Resolves: COMPMID-4625
Signed-off-by: Yair Schwarzbaum <yair.schwarzbaum@arm.com>
Change-Id: I3c30f007804b179e5e2b439f421fbd4e57fb02e1
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/7149
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Giorgio Arena <giorgio.arena@arm.com>
-rw-r--r-- | Android.bp | 2 | ||||
-rw-r--r-- | arm_compute/core/Utils.h | 3 | ||||
-rw-r--r-- | filelist.json | 8 | ||||
-rw-r--r-- | src/core/common/Registrars.h | 8 | ||||
-rw-r--r-- | src/cpu/kernels/CpuCastKernel.cpp | 408 | ||||
-rw-r--r-- | src/cpu/kernels/CpuCastKernel.h | 12 | ||||
-rw-r--r-- | src/cpu/kernels/CpuKernelSelectionTypes.h | 8 | ||||
-rw-r--r-- | src/cpu/kernels/cast/generic/neon/bfloat16.cpp | 144 | ||||
-rw-r--r-- | src/cpu/kernels/cast/generic/neon/fp16.cpp | 396 | ||||
-rw-r--r-- | src/cpu/kernels/cast/list.h | 44 | ||||
-rw-r--r-- | tests/validation/NEON/Cast.cpp | 72 |
11 files changed, 771 insertions, 334 deletions
diff --git a/Android.bp b/Android.bp index a279fdf5bb..340aeeed23 100644 --- a/Android.bp +++ b/Android.bp @@ -439,6 +439,8 @@ cc_library_static { "src/cpu/kernels/boundingboxtransform/generic/neon/fp32.cpp", "src/cpu/kernels/boundingboxtransform/generic/neon/impl.cpp", "src/cpu/kernels/boundingboxtransform/generic/neon/qsymm16.cpp", + "src/cpu/kernels/cast/generic/neon/bfloat16.cpp", + "src/cpu/kernels/cast/generic/neon/fp16.cpp", "src/cpu/kernels/crop/generic/neon/fp16.cpp", "src/cpu/kernels/crop/generic/neon/fp32.cpp", "src/cpu/kernels/crop/generic/neon/impl.cpp", diff --git a/arm_compute/core/Utils.h b/arm_compute/core/Utils.h index fd9a0ee708..2d774770ae 100644 --- a/arm_compute/core/Utils.h +++ b/arm_compute/core/Utils.h @@ -1241,6 +1241,9 @@ inline std::string cpu_impl_dt(const DataType &data_type) case DataType::QSYMM8_PER_CHANNEL: ret = "qp8"; break; + case DataType::BFLOAT16: + ret = "bf16"; + break; default: ARM_COMPUTE_ERROR("Unsupported."); } diff --git a/filelist.json b/filelist.json index 3bdc00aeef..81b28f7f4b 100644 --- a/filelist.json +++ b/filelist.json @@ -969,8 +969,12 @@ "common": [ "src/cpu/operators/CpuCast.cpp", "src/cpu/kernels/CpuCastKernel.cpp", - "src/runtime/NEON/functions/NECast.cpp" - ] + "src/runtime/NEON/functions/NECast.cpp", + "src/cpu/kernels/cast/generic/neon/bfloat16.cpp" + ], + "neon":{ + "fp16":["src/cpu/kernels/cast/generic/neon/fp16.cpp"] + } } }, "ChannelShuffle": { diff --git a/src/core/common/Registrars.h b/src/core/common/Registrars.h index c7fbf7f831..cc76de2be5 100644 --- a/src/core/common/Registrars.h +++ b/src/core/common/Registrars.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2021 Arm Limited. + * Copyright (c) 2020-2022 Arm Limited. 
* * SPDX-License-Identifier: MIT * @@ -167,4 +167,10 @@ #define REGISTER_INTEGER_SVE2(func_name) nullptr #endif /* defined(ENABLE_INTEGER_KERNELS) */ +#if defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) || defined(ARM_COMPUTE_FORCE_BF16) +#define REGISTER_BF16_NEON(func_name) &(func_name) +#else /* !(defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) || defined(ARM_COMPUTE_FORCE_BF16))*/ +#define REGISTER_BF16_NEON(func_name) nullptr +#endif /* defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) || defined(ARM_COMPUTE_FORCE_BF16)*/ + #endif /* SRC_CORE_COMMON_REGISTRARS_H */ diff --git a/src/cpu/kernels/CpuCastKernel.cpp b/src/cpu/kernels/CpuCastKernel.cpp index db76df9076..e1314e61da 100644 --- a/src/cpu/kernels/CpuCastKernel.cpp +++ b/src/cpu/kernels/CpuCastKernel.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2016-2021 Arm Limited. + * Copyright (c) 2016-2022 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -32,10 +32,13 @@ #include "src/core/NEON/NEFixedPoint.h" #include "src/core/NEON/NEMath.h" #include "src/core/NEON/wrapper/wrapper.h" +#include "src/core/common/Registrars.h" #include "src/core/helpers/AutoConfiguration.h" #include "src/core/helpers/WindowHelpers.h" #include "support/SaturateCast.h" +#include "src/cpu/kernels/cast/list.h" + namespace arm_compute { namespace cpu @@ -44,6 +47,50 @@ namespace kernels { namespace { +static const std::vector<CpuCastKernel::CastKernel> available_kernels = +{ + { + "neon_qs8_cast", + [](const CastDataTypeISASelectorData & data) { return data.src_dt == DataType::QASYMM8_SIGNED && data.dst_dt == DataType::F16 && data.isa.fp16; }, + REGISTER_FP16_NEON(arm_compute::cpu::neon_qasymm8_signed_to_fp16_cast) + }, + { + "neon_qu8_cast", + [](const CastDataTypeISASelectorData & data) { return data.src_dt == DataType::QASYMM8 && data.dst_dt == DataType::F16 && data.isa.fp16; }, + REGISTER_FP16_NEON(arm_compute::cpu::neon_u8_to_fp16_cast) + }, + { + "neon_u8_cast", + [](const CastDataTypeISASelectorData & data) { return data.src_dt == 
DataType::U8 && data.dst_dt == DataType::F16 && data.isa.fp16; }, + REGISTER_FP16_NEON(arm_compute::cpu::neon_u8_to_fp16_cast) + }, + { + "neon_fp16_cast", + [](const CastDataTypeISASelectorData & data) { return data.src_dt == DataType::F16 && data.isa.fp16; }, + REGISTER_FP16_NEON(arm_compute::cpu::neon_fp16_to_other_dt_cast) + }, + { + "neon_fp32_to_fp16_cast", + [](const CastDataTypeISASelectorData & data) { return data.src_dt == DataType::F32 && data.dst_dt == DataType::F16 && data.isa.fp16; }, + REGISTER_FP16_NEON(arm_compute::cpu::neon_fp32_to_fp16_cast) + }, + { + "neon_fp32_to_bf16_cast", + [](const CastDataTypeISASelectorData & data) { return data.src_dt == DataType::F32 && data.dst_dt == DataType::BFLOAT16 && data.isa.bf16; }, + REGISTER_BF16_NEON(arm_compute::cpu::neon_fp32_to_bfloat16_cast) + }, + { + "neon_s32_cast", + [](const CastDataTypeISASelectorData & data) { return data.src_dt == DataType::S32 && data.dst_dt == DataType::F16 && data.isa.fp16; }, + REGISTER_FP16_NEON(arm_compute::cpu::neon_s32_to_fp16_cast) + }, + { + "neon_bf16_cast", + [](const CastDataTypeISASelectorData & data) { return data.src_dt == DataType::BFLOAT16 && data.dst_dt == DataType::F32 && data.isa.bf16; }, + REGISTER_BF16_NEON(arm_compute::cpu::neon_bfloat16_to_fp32_cast) + }, +}; + Status validate_arguments(const ITensorInfo *src, const ITensorInfo *dst, ConvertPolicy policy) { ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(src); @@ -151,6 +198,9 @@ void CpuCastKernel::run_op(ITensorPack &tensors, const Window &window, const Thr Iterator src(_src, win); Iterator dst(_dst, win); + /*ukernel runs only when using fp16/bfloat16, so we validate it isn't a nullptr only before using it */ + const auto *uk = CpuCastKernel::get_implementation(CastDataTypeISASelectorData{ _src->info()->data_type(), _dst->info()->data_type(), CPUInfo::get().get_isa() }); + switch(_src->info()->data_type()) { case DataType::QASYMM8_SIGNED: @@ -262,42 +312,13 @@ void CpuCastKernel::run_op(ITensorPack 
&tensors, const Window &window, const Thr src, dst); break; } -#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC case DataType::F16: { /* Up-conversion QASYMM8_SIGNED -> F16 */ - execute_window_loop(win, [&](const Coordinates &) - { - const auto src_ptr = reinterpret_cast<const int8_t *>(src.ptr()); - const auto dst_ptr = reinterpret_cast<float16_t *>(dst.ptr()); - int x = window_start_x; - - for(; x <= (window_end_x - window_step_x); x += window_step_x) - { - const int8x16_t texels_s8 = vld1q_s8(src_ptr + x); - - const int16x8x2_t texels = - { - { - vmovl_s8(vget_low_s8(texels_s8)), - vmovl_s8(vget_high_s8(texels_s8)) - } - }; - vst1q_f16(dst_ptr + x, vcvtq_f16_s16(texels.val[0])); - vst1q_f16(dst_ptr + x + 8, vcvtq_f16_s16(texels.val[1])); - } - - // Compute left-over elements - for(; x < window_end_x; ++x) - { - *(dst_ptr + x) = static_cast<float16_t>(*(src_ptr + x)); - } - }, - src, dst); + ARM_COMPUTE_ERROR_ON(uk->ukernel == nullptr); + uk->ukernel(_src, _dst, info, _policy, window); break; } -#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC - default: ARM_COMPUTE_ERROR("dst data type not supported"); } @@ -414,41 +435,13 @@ void CpuCastKernel::run_op(ITensorPack &tensors, const Window &window, const Thr src, dst); break; } -#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC case DataType::F16: { - /* Up-conversion U8 -> F16 */ - execute_window_loop(win, [&](const Coordinates &) - { - const auto src_ptr = reinterpret_cast<const uint8_t *>(src.ptr()); - const auto dst_ptr = reinterpret_cast<float16_t *>(dst.ptr()); - - int x = window_start_x; - for(; x <= (window_end_x - window_step_x); x += window_step_x) - { - const uint8x16_t texels_u8 = vld1q_u8(src_ptr + x); - - const int16x8x2_t texels = - { - { - vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(texels_u8))), - vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(texels_u8))) - } - }; - vst1q_f16(dst_ptr + x, vcvtq_f16_s16(texels.val[0])); - vst1q_f16(dst_ptr + x + 8, vcvtq_f16_s16(texels.val[1])); - } - - // Compute left-over 
elements - for(; x < window_end_x; ++x) - { - *(dst_ptr + x) = static_cast<float16_t>(*(src_ptr + x)); - } - }, - src, dst); + /* Up-conversion U8 -> FP16 */ + ARM_COMPUTE_ERROR_ON(uk->ukernel == nullptr); + uk->ukernel(_src, _dst, info, _policy, window); break; } -#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC case DataType::U16: { /* Up-conversion U8 -> U16 */ @@ -668,6 +661,7 @@ void CpuCastKernel::run_op(ITensorPack &tensors, const Window &window, const Thr } break; } + case DataType::U16: { switch(_dst->info()->data_type()) @@ -775,258 +769,37 @@ void CpuCastKernel::run_op(ITensorPack &tensors, const Window &window, const Thr } break; } -#if defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) || defined(ARM_COMPUTE_FORCE_BF16) case DataType::BFLOAT16: - switch(_dst->info()->data_type()) - { - case DataType::F32: - { - /* Up-conversion BFLOAT16 -> F32 */ - execute_window_loop(win, [&](const Coordinates &) - { - const auto src_ptr = reinterpret_cast<const bfloat16 *>(src.ptr()); - const auto dst_ptr = reinterpret_cast<float *>(dst.ptr()); - - int x = window_start_x; - for(; x <= (window_end_x - window_step_x); x += window_step_x) - { - const uint16x8x2_t texels = - { - { - vld1q_u16(reinterpret_cast<uint16_t *>(src.ptr())), - vld1q_u16(reinterpret_cast<uint16_t *>(src.ptr()) + 8) - } - }; - - vst1q_f32(reinterpret_cast<float *>(dst.ptr()), - vreinterpretq_f32_u32(vshlq_n_u32(vmovl_u16(vget_low_u16(texels.val[0])), 16))); - vst1q_f32(reinterpret_cast<float *>(dst.ptr()) + 4, - vreinterpretq_f32_u32(vshlq_n_u32(vmovl_u16(vget_high_u16(texels.val[0])), 16))); - vst1q_f32(reinterpret_cast<float *>(dst.ptr()) + 8, - vreinterpretq_f32_u32(vshlq_n_u32(vmovl_u16(vget_low_u16(texels.val[1])), 16))); - vst1q_f32(reinterpret_cast<float *>(dst.ptr()) + 12, - vreinterpretq_f32_u32(vshlq_n_u32(vmovl_u16(vget_high_u16(texels.val[1])), 16))); - } - - for(; x < window_end_x; ++x) - { - *(dst_ptr + x) = float(*(src_ptr + x)); - } - }, - src, dst); - break; - } - default: - 
ARM_COMPUTE_ERROR("dst data type unsupported"); - } + { + /* Up-conversion BFLOAT16 -> F32 */ + ARM_COMPUTE_ERROR_ON(uk->ukernel == nullptr); + uk->ukernel(_src, _dst, info, _policy, window); break; -#endif /* defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) || defined(ARM_COMPUTE_FORCE_BF16) */ -#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + } case DataType::F16: - switch(_dst->info()->data_type()) - { - case DataType::QASYMM8_SIGNED: - { - /* Down-conversion F16 -> QASYMM8_SIGNED (Always saturating) */ - execute_window_loop(win, [&](const Coordinates &) - { - const auto src_ptr = reinterpret_cast<const float16_t *>(src.ptr()); - const auto dst_ptr = reinterpret_cast<int8_t *>(dst.ptr()); - - int x = window_start_x; - for(; x <= (window_end_x - window_step_x); x += window_step_x) - { - const float16x8x2_t texels = - { - { - vld1q_f16(src_ptr + x), - vld1q_f16(src_ptr + x + 8), - } - }; - - vst1q_s8(dst_ptr + x, vcombine_s8(vqmovn_s16(vcvtq_s16_f16(texels.val[0])), vqmovn_s16(vcvtq_s16_f16(texels.val[1])))); - } - - // Compute left-over elements - for(; x < window_end_x; ++x) - { - *(dst_ptr + x) = utils::cast::saturate_cast<int8_t>(*(src_ptr + x)); - } - }, - src, dst); - break; - } - case DataType::QASYMM8: - case DataType::U8: - { - /* Down-conversion F16 -> QASYMM8/U8 (Always saturating) */ - execute_window_loop(win, [&](const Coordinates &) - { - const auto src_ptr = reinterpret_cast<const float16_t *>(src.ptr()); - const auto dst_ptr = reinterpret_cast<uint8_t *>(dst.ptr()); - - int x = window_start_x; - for(; x <= (window_end_x - window_step_x); x += window_step_x) - { - const float16x8x2_t texels = - { - { - vld1q_f16(src_ptr + x), - vld1q_f16(src_ptr + x + 8), - } - }; - - vst1q_u8(dst_ptr + x, vcombine_u8(vqmovun_s16(vcvtq_s16_f16(texels.val[0])), vqmovun_s16(vcvtq_s16_f16(texels.val[1])))); - } - - // Compute left-over elements - for(; x < window_end_x; ++x) - { - *(dst_ptr + x) = utils::cast::saturate_cast<uint8_t>(*(src_ptr + x)); - } - - }, - src, dst); - 
break; - } - case DataType::F32: - { - /* Up-conversion F16 -> F32 */ - execute_window_loop(win, [&](const Coordinates &) - { - const auto src_ptr = reinterpret_cast<const float16_t *>(src.ptr()); - const auto dst_ptr = reinterpret_cast<float *>(dst.ptr()); - - int x = window_start_x; - for(; x <= (window_end_x - window_step_x); x += window_step_x) - { - const float16x8x2_t texels = - { - { - vld1q_f16(src_ptr + x), - vld1q_f16(src_ptr + x + 8) - } - }; - vst1q_f32(dst_ptr + x, vcvt_f32_f16(vget_low_f16(texels.val[0]))); - vst1q_f32(dst_ptr + x + 4, vcvt_f32_f16(vget_high_f16(texels.val[0]))); - vst1q_f32(dst_ptr + x + 8, vcvt_f32_f16(vget_low_f16(texels.val[1]))); - vst1q_f32(dst_ptr + x + 12, vcvt_f32_f16(vget_high_f16(texels.val[1]))); - } - - // Compute left-over elements - for(; x < window_end_x; ++x) - { - *(dst_ptr + x) = static_cast<float>(*(src_ptr + x)); - } - }, - src, dst); - break; - } - case DataType::S32: - { - /* Up-conversion F16 -> S32 */ - execute_window_loop(win, [&](const Coordinates &) - { - const auto src_ptr = reinterpret_cast<const float16_t *>(src.ptr()); - const auto dst_ptr = reinterpret_cast<int32_t *>(dst.ptr()); - - int x = window_start_x; - for(; x <= (window_end_x - window_step_x); x += window_step_x) - { - const float16x8x2_t texels = - { - { - vld1q_f16(src_ptr + x), - vld1q_f16(src_ptr + x + 8) - } - }; - - vst1q_s32(dst_ptr + x, vcvtq_s32_f32(vcvt_f32_f16(vget_low_f16(texels.val[0])))); - vst1q_s32(dst_ptr + x + 4, vcvtq_s32_f32(vcvt_f32_f16(vget_high_f16(texels.val[0])))); - vst1q_s32(dst_ptr + x + 8, vcvtq_s32_f32(vcvt_f32_f16(vget_low_f16(texels.val[1])))); - vst1q_s32(dst_ptr + x + 12, vcvtq_s32_f32(vcvt_f32_f16(vget_high_f16(texels.val[1])))); - } - - // Compute left-over elements - for(; x < window_end_x; ++x) - { - *(dst_ptr + x) = static_cast<int32_t>(*(src_ptr + x)); - } - }, - src, dst); - break; - } - default: - ARM_COMPUTE_ERROR("dst data type not supported"); - } + { + /* conversion F16 -> any data type */ + 
ARM_COMPUTE_ERROR_ON(uk->ukernel == nullptr); + uk->ukernel(_src, _dst, info, _policy, window); break; -#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */ + } case DataType::F32: switch(_dst->info()->data_type()) { -#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC case DataType::F16: { /* Down-conversion F32 -> F16 */ - execute_window_loop(win, [&](const Coordinates &) - { - const auto src_ptr = reinterpret_cast<const float *>(src.ptr()); - const auto dst_ptr = reinterpret_cast<float16_t *>(dst.ptr()); - - int x = window_start_x; - for(; x <= (window_end_x - window_step_x); x += window_step_x) - { - const float32x4x4_t texels = - { - { - vld1q_f32(src_ptr + x), - vld1q_f32(src_ptr + x + 4), - vld1q_f32(src_ptr + x + 8), - vld1q_f32(src_ptr + x + 12) - } - }; - - vst1q_f16(dst_ptr + x, vcombine_f16(vcvt_f16_f32(texels.val[0]), vcvt_f16_f32(texels.val[1]))); - vst1q_f16(dst_ptr + x + 8, vcombine_f16(vcvt_f16_f32(texels.val[2]), vcvt_f16_f32(texels.val[3]))); - } - - // Compute left-over elements - for(; x < window_end_x; ++x) - { - *(dst_ptr + x) = static_cast<float16_t>(*(src_ptr + x)); - } - }, - src, dst); + ARM_COMPUTE_ERROR_ON(uk->ukernel == nullptr); + uk->ukernel(_src, _dst, info, _policy, window); break; } -#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */ -#if defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) || defined(ARM_COMPUTE_FORCE_BF16) case DataType::BFLOAT16: { /* Down-conversion F32 -> BFLOAT16 */ - execute_window_loop(win, [&](const Coordinates &) - { - const auto src_ptr = reinterpret_cast<const float *>(src.ptr()); - const auto dst_ptr = reinterpret_cast<bfloat16 *>(dst.ptr()); - - int x = window_start_x; - for(; x <= (window_end_x - window_step_x); x += window_step_x) - { - wrapper::vcvt_bf16_f32(reinterpret_cast<float *>(src.ptr()), - reinterpret_cast<uint16_t *>(dst.ptr())); - wrapper::vcvt_bf16_f32(reinterpret_cast<float *>(src.ptr()) + 8, - reinterpret_cast<uint16_t *>(dst.ptr()) + 8); - } - - for(; x < window_end_x; ++x) - { - *(dst_ptr + x) = 
*(src_ptr + x); - } - }, - src, dst); + ARM_COMPUTE_ERROR_ON(uk->ukernel == nullptr); + uk->ukernel(_src, _dst, info, _policy, window); break; } -#endif /* defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) || defined(ARM_COMPUTE_FORCE_BF16) */ case DataType::S32: { /* Conversion F32 -> S32 */ @@ -1140,42 +913,13 @@ void CpuCastKernel::run_op(ITensorPack &tensors, const Window &window, const Thr case DataType::S32: switch(_dst->info()->data_type()) { -#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC case DataType::F16: { /* Down-conversion S32 -> F16 */ - execute_window_loop(win, [&](const Coordinates &) - { - const auto src_ptr = reinterpret_cast<const int32_t *>(src.ptr()); - const auto dst_ptr = reinterpret_cast<float16_t *>(dst.ptr()); - - int x = window_start_x; - for(; x <= (window_end_x - window_step_x); x += window_step_x) - { - const float32x4x4_t texels = - { - { - vcvtq_f32_s32(vld1q_s32(src_ptr + x)), - vcvtq_f32_s32(vld1q_s32(src_ptr + x + 4)), - vcvtq_f32_s32(vld1q_s32(src_ptr + x + 8)), - vcvtq_f32_s32(vld1q_s32(src_ptr + x + 12)) - } - }; - - vst1q_f16(dst_ptr + x, vcombine_f16(vcvt_f16_f32(texels.val[0]), vcvt_f16_f32(texels.val[1]))); - vst1q_f16(dst_ptr + x + 8, vcombine_f16(vcvt_f16_f32(texels.val[2]), vcvt_f16_f32(texels.val[3]))); - } - - // Compute left-over elements - for(; x < window_end_x; ++x) - { - *(dst_ptr + x) = static_cast<float16_t>(*(src_ptr + x)); - } - }, - src, dst); + ARM_COMPUTE_ERROR_ON(uk->ukernel == nullptr); + uk->ukernel(_src, _dst, info, _policy, window); break; } -#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */ case DataType::F32: { /* Conversion S32 -> F32 */ @@ -1362,6 +1106,12 @@ const char *CpuCastKernel::name() const { return "CpuCastKernel.cpp"; } + +const std::vector<CpuCastKernel::CastKernel> &CpuCastKernel::get_available_kernels() +{ + return available_kernels; +} + } // namespace kernels } // namespace cpu } // namespace arm_compute diff --git a/src/cpu/kernels/CpuCastKernel.h b/src/cpu/kernels/CpuCastKernel.h index 
7679178fa1..95d46fad23 100644 --- a/src/cpu/kernels/CpuCastKernel.h +++ b/src/cpu/kernels/CpuCastKernel.h @@ -39,6 +39,9 @@ namespace kernels */ class CpuCastKernel : public ICpuKernel<CpuCastKernel> { +private: + using CastKernelPtr = std::add_pointer<void(const ITensor *, ITensor *, const ThreadInfo &, ConvertPolicy, const Window &)>::type; + public: CpuCastKernel() = default; ARM_COMPUTE_DISALLOW_COPY_ALLOW_MOVE(CpuCastKernel); @@ -73,6 +76,15 @@ public: void run_op(ITensorPack &tensors, const Window &window, const ThreadInfo &info) override; const char *name() const override; + struct CastKernel + { + const char *name; + const CastDataTypeISASelectorDataPtr is_selected; + CastKernelPtr ukernel; + }; + + static const std::vector<CastKernel> &get_available_kernels(); + private: ConvertPolicy _policy{ ConvertPolicy::SATURATE }; }; diff --git a/src/cpu/kernels/CpuKernelSelectionTypes.h b/src/cpu/kernels/CpuKernelSelectionTypes.h index 8c5a39ad49..afcf014ad2 100644 --- a/src/cpu/kernels/CpuKernelSelectionTypes.h +++ b/src/cpu/kernels/CpuKernelSelectionTypes.h @@ -47,6 +47,13 @@ struct DataTypeDataLayoutISASelectorData const cpuinfo::CpuIsaInfo &isa; }; +struct CastDataTypeISASelectorData +{ + DataType src_dt; + DataType dst_dt; + const cpuinfo::CpuIsaInfo &isa; +}; + struct PoolDataTypeISASelectorData { DataType dt; @@ -74,6 +81,7 @@ using DataTypeDataLayoutSelectorPtr = std::add_pointer<bool(const using PoolDataTypeISASelectorPtr = std::add_pointer<bool(const PoolDataTypeISASelectorData &data)>::type; using ElementwiseDataTypeISASelectorPtr = std::add_pointer<bool(const ElementwiseDataTypeISASelectorData &data)>::type; using DepthwiseConv2dNativeDataTypeISASelectorPtr = std::add_pointer<bool(const DepthwiseConv2dNativeDataTypeISASelectorData &data)>::type; +using CastDataTypeISASelectorDataPtr = std::add_pointer<bool(const CastDataTypeISASelectorData &data)>::type; } // namespace kernels } // namespace cpu diff --git a/src/cpu/kernels/cast/generic/neon/bfloat16.cpp 
b/src/cpu/kernels/cast/generic/neon/bfloat16.cpp new file mode 100644 index 0000000000..b15584b0aa --- /dev/null +++ b/src/cpu/kernels/cast/generic/neon/bfloat16.cpp @@ -0,0 +1,144 @@ +/* + * Copyright (c) 2016-2022 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ */ +#if defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) || defined(ARM_COMPUTE_FORCE_BF16) + +#include "arm_compute/core/TensorInfo.h" +#include "src/cpu/kernels/CpuCastKernel.h" +#include "src/cpu/kernels/cast/list.h" +#include "support/SaturateCast.h" + +namespace arm_compute +{ +namespace cpu +{ +void neon_fp32_to_bfloat16_cast(const ITensor *_src, ITensor *_dst, const ThreadInfo &info, ConvertPolicy _policy, const Window &window) +{ + ARM_COMPUTE_UNUSED(info); + ARM_COMPUTE_UNUSED(_policy); + + const auto window_start_x = static_cast<int>(window.x().start()); + const auto window_end_x = static_cast<int>(window.x().end()); + const int window_step_x = 16; + + ARM_COMPUTE_ERROR_ON_NULLPTR(_src, _dst); + ARM_COMPUTE_ERROR_ON(_src == _dst); + + ARM_COMPUTE_ERROR_ON_NULLPTR(_src, _dst); + + Window win{ window }; + win.set(Window::DimX, Window::Dimension(0, 1, 1)); + + Iterator src(_src, win); + Iterator dst(_dst, win); + + /* Down-conversion F32 -> BFLOAT16 */ + execute_window_loop(win, [&](const Coordinates &) + { + const auto src_ptr = reinterpret_cast<const float *>(src.ptr()); + const auto dst_ptr = reinterpret_cast<bfloat16 *>(dst.ptr()); + + int x = window_start_x; + for(; x <= (window_end_x - window_step_x); x += window_step_x) + { + wrapper::vcvt_bf16_f32(reinterpret_cast<float *>(src.ptr()), + reinterpret_cast<uint16_t *>(dst.ptr())); + wrapper::vcvt_bf16_f32(reinterpret_cast<float *>(src.ptr()) + 8, + reinterpret_cast<uint16_t *>(dst.ptr()) + 8); + } + + for(; x < window_end_x; ++x) + { + *(dst_ptr + x) = *(src_ptr + x); + } + }, + src, dst); +} + +void neon_bfloat16_to_fp32_cast(const ITensor *_src, ITensor *_dst, const ThreadInfo &info, ConvertPolicy _policy, const Window &window) +{ + ARM_COMPUTE_UNUSED(info); + ARM_COMPUTE_UNUSED(_policy); + + const auto window_start_x = static_cast<int>(window.x().start()); + const auto window_end_x = static_cast<int>(window.x().end()); + const int window_step_x = 16; + + ARM_COMPUTE_ERROR_ON_NULLPTR(_src, _dst); + 
ARM_COMPUTE_ERROR_ON(_src == _dst); + + ARM_COMPUTE_ERROR_ON_NULLPTR(_src, _dst); + + Window win{ window }; + win.set(Window::DimX, Window::Dimension(0, 1, 1)); + + Iterator src(_src, win); + Iterator dst(_dst, win); + switch(_dst->info()->data_type()) + { + case DataType::F32: + { + /* Up-conversion BFLOAT16 -> F32 */ + execute_window_loop(win, [&](const Coordinates &) + { + const auto src_ptr = reinterpret_cast<const bfloat16 *>(src.ptr()); + const auto dst_ptr = reinterpret_cast<float *>(dst.ptr()); + + int x = window_start_x; + for(; x <= (window_end_x - window_step_x); x += window_step_x) + { + const uint16x8x2_t texels = + { + { + vld1q_u16(reinterpret_cast<uint16_t *>(src.ptr())), + vld1q_u16(reinterpret_cast<uint16_t *>(src.ptr()) + 8) + } + }; + + vst1q_f32(reinterpret_cast<float *>(dst.ptr()), + vreinterpretq_f32_u32(vshlq_n_u32(vmovl_u16(vget_low_u16(texels.val[0])), 16))); + vst1q_f32(reinterpret_cast<float *>(dst.ptr()) + 4, + vreinterpretq_f32_u32(vshlq_n_u32(vmovl_u16(vget_high_u16(texels.val[0])), 16))); + vst1q_f32(reinterpret_cast<float *>(dst.ptr()) + 8, + vreinterpretq_f32_u32(vshlq_n_u32(vmovl_u16(vget_low_u16(texels.val[1])), 16))); + vst1q_f32(reinterpret_cast<float *>(dst.ptr()) + 12, + vreinterpretq_f32_u32(vshlq_n_u32(vmovl_u16(vget_high_u16(texels.val[1])), 16))); + } + + for(; x < window_end_x; ++x) + { + *(dst_ptr + x) = float(*(src_ptr + x)); + } + }, + src, dst); + break; + } + default: + ARM_COMPUTE_ERROR("dst data type unsupported"); + } +} + +} // namespace cpu +} // namespace arm_compute + +#endif /* defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) || defined(ARM_COMPUTE_FORCE_BF16) */ diff --git a/src/cpu/kernels/cast/generic/neon/fp16.cpp b/src/cpu/kernels/cast/generic/neon/fp16.cpp new file mode 100644 index 0000000000..d2c66923cc --- /dev/null +++ b/src/cpu/kernels/cast/generic/neon/fp16.cpp @@ -0,0 +1,396 @@ +/* + * Copyright (c) 2016-2022 Arm Limited. 
+ * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ */
+#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && defined(ENABLE_FP16_KERNELS)
+
+#include "arm_compute/core/TensorInfo.h"
+#include "src/cpu/kernels/CpuCastKernel.h"
+#include "src/cpu/kernels/cast/list.h"
+#include "support/SaturateCast.h"
+
+namespace arm_compute
+{
+namespace cpu
+{
+namespace
+{
+/** Number of elements every kernel in this file converts per vectorised iteration. */
+constexpr int fp16_cast_step_x = 16;
+
+/** Row-wise driver shared by all the cast kernels in this file.
+ *
+ * Asserts that both tensors are non-null and distinct, replaces the X dimension of
+ * @p window with Dimension(0, 1, 1) so that the loop body is entered once per row,
+ * and then walks [window.x().start(), window.x().end()) itself: first in chunks of
+ * fp16_cast_step_x elements through @p convert_block, then one element at a time
+ * through @p convert_one for whatever is left over.
+ *
+ * @tparam SrcT Element type the source buffer is read as.
+ * @tparam DstT Element type the destination buffer is written as.
+ *
+ * @param[in]  _src          Source tensor.
+ * @param[out] _dst          Destination tensor. Must not be @p _src.
+ * @param[in]  window        Region to execute on.
+ * @param[in]  convert_block Callable (const SrcT *, DstT *) converting fp16_cast_step_x consecutive elements.
+ * @param[in]  convert_one   Callable (const SrcT *, DstT *) converting a single element.
+ */
+template <typename SrcT, typename DstT, typename BlockFn, typename ElementFn>
+void run_fp16_cast_rows(const ITensor *_src, ITensor *_dst, const Window &window, BlockFn &&convert_block, ElementFn &&convert_one)
+{
+    ARM_COMPUTE_ERROR_ON_NULLPTR(_src, _dst);
+    ARM_COMPUTE_ERROR_ON(_src == _dst);
+
+    const auto window_start_x = static_cast<int>(window.x().start());
+    const auto window_end_x   = static_cast<int>(window.x().end());
+
+    // The X range is traversed manually inside the loop body, so collapse it in the execution window
+    Window win{ window };
+    win.set(Window::DimX, Window::Dimension(0, 1, 1));
+
+    Iterator src(_src, win);
+    Iterator dst(_dst, win);
+
+    execute_window_loop(win, [&](const Coordinates &)
+    {
+        const auto src_ptr = reinterpret_cast<const SrcT *>(src.ptr());
+        const auto dst_ptr = reinterpret_cast<DstT *>(dst.ptr());
+
+        int x = window_start_x;
+        for(; x <= (window_end_x - fp16_cast_step_x); x += fp16_cast_step_x)
+        {
+            convert_block(src_ptr + x, dst_ptr + x);
+        }
+
+        // Compute left-over elements
+        for(; x < window_end_x; ++x)
+        {
+            convert_one(src_ptr + x, dst_ptr + x);
+        }
+    },
+    src, dst);
+}
+} // namespace
+
+/** Cast a QASYMM8_SIGNED tensor to F16.
+ *
+ * The source is read as raw int8_t values, widened to 16 bit and converted to half
+ * precision; no quantization information is applied here.
+ *
+ * @param[in]  _src    Source tensor, read as int8_t.
+ * @param[out] _dst    Destination tensor, written as float16_t. Must not be @p _src.
+ * @param[in]  info    Unused.
+ * @param[in]  _policy Unused.
+ * @param[in]  window  Region to execute on.
+ */
+void neon_qasymm8_signed_to_fp16_cast(const ITensor *_src, ITensor *_dst, const ThreadInfo &info, ConvertPolicy _policy, const Window &window)
+{
+    ARM_COMPUTE_UNUSED(info);
+    ARM_COMPUTE_UNUSED(_policy);
+
+    const auto convert_block = [](const int8_t *in, float16_t *out)
+    {
+        const int8x16_t texels_s8 = vld1q_s8(in);
+
+        const int16x8x2_t texels =
+        {
+            {
+                vmovl_s8(vget_low_s8(texels_s8)),
+                vmovl_s8(vget_high_s8(texels_s8))
+            }
+        };
+        vst1q_f16(out, vcvtq_f16_s16(texels.val[0]));
+        vst1q_f16(out + 8, vcvtq_f16_s16(texels.val[1]));
+    };
+    const auto convert_one = [](const int8_t *in, float16_t *out)
+    {
+        *out = static_cast<float16_t>(*in);
+    };
+
+    run_fp16_cast_rows<int8_t, float16_t>(_src, _dst, window, convert_block, convert_one);
+}
+
+/** Cast an S32 tensor to F16.
+ *
+ * The vectorised path converts S32 to F32 first and then narrows F32 to F16.
+ *
+ * @param[in]  _src    Source tensor, read as int32_t.
+ * @param[out] _dst    Destination tensor, written as float16_t. Must not be @p _src.
+ * @param[in]  info    Unused.
+ * @param[in]  _policy Unused.
+ * @param[in]  window  Region to execute on.
+ */
+void neon_s32_to_fp16_cast(const ITensor *_src, ITensor *_dst, const ThreadInfo &info, ConvertPolicy _policy, const Window &window)
+{
+    ARM_COMPUTE_UNUSED(info);
+    ARM_COMPUTE_UNUSED(_policy);
+
+    const auto convert_block = [](const int32_t *in, float16_t *out)
+    {
+        const float32x4x4_t texels =
+        {
+            {
+                vcvtq_f32_s32(vld1q_s32(in)),
+                vcvtq_f32_s32(vld1q_s32(in + 4)),
+                vcvtq_f32_s32(vld1q_s32(in + 8)),
+                vcvtq_f32_s32(vld1q_s32(in + 12))
+            }
+        };
+
+        vst1q_f16(out, vcombine_f16(vcvt_f16_f32(texels.val[0]), vcvt_f16_f32(texels.val[1])));
+        vst1q_f16(out + 8, vcombine_f16(vcvt_f16_f32(texels.val[2]), vcvt_f16_f32(texels.val[3])));
+    };
+    const auto convert_one = [](const int32_t *in, float16_t *out)
+    {
+        *out = static_cast<float16_t>(*in);
+    };
+
+    run_fp16_cast_rows<int32_t, float16_t>(_src, _dst, window, convert_block, convert_one);
+}
+
+/** Cast an F32 tensor to F16.
+ *
+ * @param[in]  _src    Source tensor, read as float.
+ * @param[out] _dst    Destination tensor, written as float16_t. Must not be @p _src.
+ * @param[in]  info    Unused.
+ * @param[in]  _policy Unused.
+ * @param[in]  window  Region to execute on.
+ */
+void neon_fp32_to_fp16_cast(const ITensor *_src, ITensor *_dst, const ThreadInfo &info, ConvertPolicy _policy, const Window &window)
+{
+    ARM_COMPUTE_UNUSED(info);
+    ARM_COMPUTE_UNUSED(_policy);
+
+    const auto convert_block = [](const float *in, float16_t *out)
+    {
+        const float32x4x4_t texels =
+        {
+            {
+                vld1q_f32(in),
+                vld1q_f32(in + 4),
+                vld1q_f32(in + 8),
+                vld1q_f32(in + 12)
+            }
+        };
+
+        vst1q_f16(out, vcombine_f16(vcvt_f16_f32(texels.val[0]), vcvt_f16_f32(texels.val[1])));
+        vst1q_f16(out + 8, vcombine_f16(vcvt_f16_f32(texels.val[2]), vcvt_f16_f32(texels.val[3])));
+    };
+    const auto convert_one = [](const float *in, float16_t *out)
+    {
+        *out = static_cast<float16_t>(*in);
+    };
+
+    run_fp16_cast_rows<float, float16_t>(_src, _dst, window, convert_block, convert_one);
+}
+
+/** Cast an F16 tensor to the data type of the destination tensor.
+ *
+ * Supported destination types are QASYMM8_SIGNED, QASYMM8, U8, F32 and S32; any
+ * other destination type raises "dst data type not supported". The conversions to
+ * the 8-bit types always saturate, regardless of @p _policy.
+ *
+ * @param[in]  _src    Source tensor, read as float16_t.
+ * @param[out] _dst    Destination tensor; its data type selects the conversion. Must not be @p _src.
+ * @param[in]  info    Unused.
+ * @param[in]  _policy Unused.
+ * @param[in]  window  Region to execute on.
+ */
+void neon_fp16_to_other_dt_cast(const ITensor *_src, ITensor *_dst, const ThreadInfo &info, ConvertPolicy _policy, const Window &window)
+{
+    ARM_COMPUTE_UNUSED(info);
+    ARM_COMPUTE_UNUSED(_policy);
+
+    // _dst is dereferenced right below to pick the conversion, so check it before the shared driver does
+    ARM_COMPUTE_ERROR_ON_NULLPTR(_src, _dst);
+
+    switch(_dst->info()->data_type())
+    {
+        case DataType::QASYMM8_SIGNED:
+        {
+            /* Down-conversion F16 -> QASYMM8_SIGNED (Always saturating) */
+            const auto convert_block = [](const float16_t *in, int8_t *out)
+            {
+                const float16x8x2_t texels =
+                {
+                    {
+                        vld1q_f16(in),
+                        vld1q_f16(in + 8),
+                    }
+                };
+
+                vst1q_s8(out, vcombine_s8(vqmovn_s16(vcvtq_s16_f16(texels.val[0])), vqmovn_s16(vcvtq_s16_f16(texels.val[1]))));
+            };
+            const auto convert_one = [](const float16_t *in, int8_t *out)
+            {
+                *out = utils::cast::saturate_cast<int8_t>(*in);
+            };
+
+            run_fp16_cast_rows<float16_t, int8_t>(_src, _dst, window, convert_block, convert_one);
+            break;
+        }
+        case DataType::QASYMM8:
+        case DataType::U8:
+        {
+            /* Down-conversion F16 -> QASYMM8/U8 (Always saturating) */
+            const auto convert_block = [](const float16_t *in, uint8_t *out)
+            {
+                const float16x8x2_t texels =
+                {
+                    {
+                        vld1q_f16(in),
+                        vld1q_f16(in + 8),
+                    }
+                };
+
+                vst1q_u8(out, vcombine_u8(vqmovun_s16(vcvtq_s16_f16(texels.val[0])), vqmovun_s16(vcvtq_s16_f16(texels.val[1]))));
+            };
+            const auto convert_one = [](const float16_t *in, uint8_t *out)
+            {
+                *out = utils::cast::saturate_cast<uint8_t>(*in);
+            };
+
+            run_fp16_cast_rows<float16_t, uint8_t>(_src, _dst, window, convert_block, convert_one);
+            break;
+        }
+        case DataType::F32:
+        {
+            /* Up-conversion F16 -> F32 */
+            const auto convert_block = [](const float16_t *in, float *out)
+            {
+                const float16x8x2_t texels =
+                {
+                    {
+                        vld1q_f16(in),
+                        vld1q_f16(in + 8)
+                    }
+                };
+                vst1q_f32(out, vcvt_f32_f16(vget_low_f16(texels.val[0])));
+                vst1q_f32(out + 4, vcvt_f32_f16(vget_high_f16(texels.val[0])));
+                vst1q_f32(out + 8, vcvt_f32_f16(vget_low_f16(texels.val[1])));
+                vst1q_f32(out + 12, vcvt_f32_f16(vget_high_f16(texels.val[1])));
+            };
+            const auto convert_one = [](const float16_t *in, float *out)
+            {
+                *out = static_cast<float>(*in);
+            };
+
+            run_fp16_cast_rows<float16_t, float>(_src, _dst, window, convert_block, convert_one);
+            break;
+        }
+        case DataType::S32:
+        {
+            /* Up-conversion F16 -> S32 */
+            const auto convert_block = [](const float16_t *in, int32_t *out)
+            {
+                const float16x8x2_t texels =
+                {
+                    {
+                        vld1q_f16(in),
+                        vld1q_f16(in + 8)
+                    }
+                };
+
+                vst1q_s32(out, vcvtq_s32_f32(vcvt_f32_f16(vget_low_f16(texels.val[0]))));
+                vst1q_s32(out + 4, vcvtq_s32_f32(vcvt_f32_f16(vget_high_f16(texels.val[0]))));
+                vst1q_s32(out + 8, vcvtq_s32_f32(vcvt_f32_f16(vget_low_f16(texels.val[1]))));
+                vst1q_s32(out + 12, vcvtq_s32_f32(vcvt_f32_f16(vget_high_f16(texels.val[1]))));
+            };
+            const auto convert_one = [](const float16_t *in, int32_t *out)
+            {
+                *out = static_cast<int32_t>(*in);
+            };
+
+            run_fp16_cast_rows<float16_t, int32_t>(_src, _dst, window, convert_block, convert_one);
+            break;
+        }
+        default:
+            ARM_COMPUTE_ERROR("dst data type not supported");
+    }
+}
+
+/** Cast a U8 tensor to F16.
+ *
+ * @param[in]  _src    Source tensor, read as uint8_t.
+ * @param[out] _dst    Destination tensor, written as float16_t. Must not be @p _src.
+ * @param[in]  info    Unused.
+ * @param[in]  _policy Unused.
+ * @param[in]  window  Region to execute on.
+ */
+void neon_u8_to_fp16_cast(const ITensor *_src, ITensor *_dst, const ThreadInfo &info, ConvertPolicy _policy, const Window &window)
+{
+    ARM_COMPUTE_UNUSED(info);
+    ARM_COMPUTE_UNUSED(_policy);
+
+    /* Up-conversion U8 -> F16 */
+    const auto convert_block = [](const uint8_t *in, float16_t *out)
+    {
+        const uint8x16_t texels_u8 = vld1q_u8(in);
+
+        // Values widened from 8 bit are at most 255, so viewing them as signed 16 bit does not change them
+        const int16x8x2_t texels =
+        {
+            {
+                vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(texels_u8))),
+                vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(texels_u8)))
+            }
+        };
+        vst1q_f16(out, vcvtq_f16_s16(texels.val[0]));
+        vst1q_f16(out + 8, vcvtq_f16_s16(texels.val[1]));
+    };
+    const auto convert_one = [](const uint8_t *in, float16_t *out)
+    {
+        *out = static_cast<float16_t>(*in);
+    };
+
+    run_fp16_cast_rows<uint8_t, float16_t>(_src, _dst, window, convert_block, convert_one);
+}
+
+} // namespace cpu
+} // namespace arm_compute
+#endif /* #if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && defined(ENABLE_FP16_KERNELS) */
diff --git a/src/cpu/kernels/cast/list.h b/src/cpu/kernels/cast/list.h
new file mode 100644
index 0000000000..ffd82d5bf3
--- /dev/null
+++ b/src/cpu/kernels/cast/list.h
@@ -0,0 +1,44 @@
+/*
+ * Copyright (c) 2022 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#ifndef SRC_CORE_NEON_KERNELS_CAST_LIST_H
+#define SRC_CORE_NEON_KERNELS_CAST_LIST_H
+namespace arm_compute
+{
+namespace cpu
+{
+/** Declare a cast kernel entry point named @p func_name.
+ *
+ * Every cast kernel shares this signature: source tensor, destination tensor,
+ * thread information, conversion policy and execution window.
+ *
+ * NOTE(review): this header names ITensor, ThreadInfo, ConvertPolicy and Window without
+ * declaring them, so it presumably has to be included after the headers that do - confirm.
+ */
+#define DECLARE_CAST_KERNEL(func_name) \
+    void func_name(const ITensor *_src, ITensor *_dst, const ThreadInfo &info, ConvertPolicy _policy, const Window &window)
+
+DECLARE_CAST_KERNEL(neon_fp32_to_fp16_cast);
+DECLARE_CAST_KERNEL(neon_u8_to_fp16_cast);
+DECLARE_CAST_KERNEL(neon_fp16_to_other_dt_cast);
+DECLARE_CAST_KERNEL(neon_s32_to_fp16_cast);
+DECLARE_CAST_KERNEL(neon_qasymm8_signed_to_fp16_cast);
+DECLARE_CAST_KERNEL(neon_fp32_to_bfloat16_cast);
+DECLARE_CAST_KERNEL(neon_bfloat16_to_fp32_cast);
+
+#undef DECLARE_CAST_KERNEL
+} // namespace cpu
+} // namespace arm_compute
+#endif //SRC_CORE_NEON_KERNELS_CAST_LIST_H
\ No newline at end of file
diff --git a/tests/validation/NEON/Cast.cpp b/tests/validation/NEON/Cast.cpp
index db73bea9cb..3a77106a42 100644
--- a/tests/validation/NEON/Cast.cpp
+++ b/tests/validation/NEON/Cast.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2019-2021 Arm Limited.
+ * Copyright (c) 2019-2022 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -25,6 +25,8 @@
 #include "arm_compute/runtime/NEON/functions/NECast.h"
 #include "arm_compute/runtime/Tensor.h"
 #include "arm_compute/runtime/TensorAllocator.h"
+#include "src/common/cpuinfo/CpuIsaInfo.h"
+#include "src/cpu/kernels/CpuCastKernel.h"
 #include "tests/NEON/Accessor.h"
 #include "tests/PaddingCalculator.h"
 #include "tests/datasets/ConvertPolicyDataset.h"
@@ -34,7 +36,6 @@
 #include "tests/framework/datasets/Datasets.h"
 #include "tests/validation/Validation.h"
 #include "tests/validation/fixtures/CastFixture.h"
-
 namespace arm_compute
 {
 namespace test
@@ -187,6 +188,73 @@ CAST_SUITE(F32_to_F16, DataType::F32, DataType::F16, NECastToF16Fixture<float>,
 CAST_SUITE(F32_to_S32, DataType::F32, DataType::S32, NECastToS32Fixture<float>, CastF32toS32Dataset, one_tolerance)
 CAST_SUITE(F32_to_U8, DataType::F32, DataType::S32, NECastToS32Fixture<float>, CastF32toS32Dataset, one_tolerance)
 
+DATA_TEST_CASE(KernelSelectionDstFP16, framework::DatasetMode::ALL, // Kernel-selection check: one lookup per source data type listed below
+               combine(framework::dataset::make("CpuExt", std::string("NEON")),
+                       framework::dataset::make("DataType",
+{
+    DataType::F16,
+    DataType::U8,
+    DataType::S32,
+    DataType::QASYMM8,
+    DataType::QASYMM8_SIGNED,
+    DataType::BFLOAT16, // looked up with an F32 destination instead of F16, see the branch below
+})),
+cpu_ext, data_type)
+{
+    using namespace cpu::kernels;
+    const CpuCastKernel::CastKernel *selected_impl; // assigned in exactly one of the two branches below
+
+    cpuinfo::CpuIsaInfo cpu_isa{};
+    cpu_isa.neon = (cpu_ext == "NEON");
+
+    cpu_isa.bf16 = (data_type == DataType::BFLOAT16); // bf16 support is only advertised for the BFLOAT16 source
+
+    /* BFLOAT16 is the exception among the source types: it is looked up with an F32 destination (and fp16 disabled) rather than an F16 destination */
+    if(cpu_isa.bf16)
+    {
+        cpu_isa.fp16  = false;
+        selected_impl = CpuCastKernel::get_implementation(CastDataTypeISASelectorData{ data_type, DataType::F32, cpu_isa }, cpu::KernelSelectionType::Preferred);
+    }
+    else
+    {
+        cpu_isa.fp16  = true;
+        selected_impl = CpuCastKernel::get_implementation(CastDataTypeISASelectorData{ data_type, DataType::F16, cpu_isa }, cpu::KernelSelectionType::Preferred);
+    }
+
+    ARM_COMPUTE_ERROR_ON_NULLPTR(selected_impl); // NOTE(review): selected_impl is dereferenced below; confirm this check is active in the build the tests run in
+
+    std::string expected = lower_string(cpu_ext) + "_" + cpu_impl_dt(data_type) + "_cast"; // expected kernel name is built from the extension and the source type only
+    std::string actual   = selected_impl->name;
+
+    ARM_COMPUTE_EXPECT_EQUAL(expected, actual, framework::LogLevel::ERRORS);
+}
+
+DATA_TEST_CASE(KernelSelectionSrcFP32, framework::DatasetMode::ALL, // Kernel-selection check for an F32 source and each destination type listed below
+               combine(framework::dataset::make("CpuExt", std::string("NEON")),
+                       framework::dataset::make("DataType",
+{
+    DataType::F16,
+    DataType::BFLOAT16,
+})),
+cpu_ext, data_type)
+{
+    using namespace cpu::kernels;
+
+    cpuinfo::CpuIsaInfo cpu_isa{};
+    cpu_isa.neon = (cpu_ext == "NEON");
+    cpu_isa.fp16 = (data_type == DataType::F16); // only the flag matching the destination type is set
+    cpu_isa.bf16 = (data_type == DataType::BFLOAT16);
+
+    const auto *selected_impl = CpuCastKernel::get_implementation(CastDataTypeISASelectorData{ DataType::F32, data_type, cpu_isa }, cpu::KernelSelectionType::Preferred);
+
+    ARM_COMPUTE_ERROR_ON_NULLPTR(selected_impl); // NOTE(review): selected_impl is dereferenced below; confirm this check is active in the build the tests run in
+
+    std::string expected = lower_string(cpu_ext) + "_fp32_to_" + cpu_impl_dt(data_type) + "_cast"; // expected kernel name is built from the extension and the destination type
+    std::string actual   = selected_impl->name;
+
+    ARM_COMPUTE_EXPECT_EQUAL(expected, actual, framework::LogLevel::ERRORS);
+}
+
 TEST_SUITE_END() // Cast
 TEST_SUITE_END() // Neon
 } // namespace validation
|