diff options
Diffstat (limited to 'src/cpu/kernels/CpuCastKernel.cpp')
-rw-r--r-- | src/cpu/kernels/CpuCastKernel.cpp | 408 |
1 files changed, 79 insertions, 329 deletions
diff --git a/src/cpu/kernels/CpuCastKernel.cpp b/src/cpu/kernels/CpuCastKernel.cpp index db76df9076..e1314e61da 100644 --- a/src/cpu/kernels/CpuCastKernel.cpp +++ b/src/cpu/kernels/CpuCastKernel.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2016-2021 Arm Limited. + * Copyright (c) 2016-2022 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -32,10 +32,13 @@ #include "src/core/NEON/NEFixedPoint.h" #include "src/core/NEON/NEMath.h" #include "src/core/NEON/wrapper/wrapper.h" +#include "src/core/common/Registrars.h" #include "src/core/helpers/AutoConfiguration.h" #include "src/core/helpers/WindowHelpers.h" #include "support/SaturateCast.h" +#include "src/cpu/kernels/cast/list.h" + namespace arm_compute { namespace cpu @@ -44,6 +47,50 @@ namespace kernels { namespace { +static const std::vector<CpuCastKernel::CastKernel> available_kernels = +{ + { + "neon_qs8_cast", + [](const CastDataTypeISASelectorData & data) { return data.src_dt == DataType::QASYMM8_SIGNED && data.dst_dt == DataType::F16 && data.isa.fp16; }, + REGISTER_FP16_NEON(arm_compute::cpu::neon_qasymm8_signed_to_fp16_cast) + }, + { + "neon_qu8_cast", + [](const CastDataTypeISASelectorData & data) { return data.src_dt == DataType::QASYMM8 && data.dst_dt == DataType::F16 && data.isa.fp16; }, + REGISTER_FP16_NEON(arm_compute::cpu::neon_u8_to_fp16_cast) + }, + { + "neon_u8_cast", + [](const CastDataTypeISASelectorData & data) { return data.src_dt == DataType::U8 && data.dst_dt == DataType::F16 && data.isa.fp16; }, + REGISTER_FP16_NEON(arm_compute::cpu::neon_u8_to_fp16_cast) + }, + { + "neon_fp16_cast", + [](const CastDataTypeISASelectorData & data) { return data.src_dt == DataType::F16 && data.isa.fp16; }, + REGISTER_FP16_NEON(arm_compute::cpu::neon_fp16_to_other_dt_cast) + }, + { + "neon_fp32_to_fp16_cast", + [](const CastDataTypeISASelectorData & data) { return data.src_dt == DataType::F32 && data.dst_dt == DataType::F16 && data.isa.fp16; }, + REGISTER_FP16_NEON(arm_compute::cpu::neon_fp32_to_fp16_cast) + }, + { + "neon_fp32_to_bf16_cast", + [](const CastDataTypeISASelectorData & data) { return data.src_dt == DataType::F32 && data.dst_dt == DataType::BFLOAT16 && data.isa.bf16; }, + REGISTER_BF16_NEON(arm_compute::cpu::neon_fp32_to_bfloat16_cast) + }, + { + "neon_s32_cast", + [](const CastDataTypeISASelectorData & data) { return data.src_dt == DataType::S32 && data.dst_dt == DataType::F16 && data.isa.fp16; }, + REGISTER_FP16_NEON(arm_compute::cpu::neon_s32_to_fp16_cast) + }, + { + "neon_bf16_cast", + [](const CastDataTypeISASelectorData & data) { return data.src_dt == DataType::BFLOAT16 && data.dst_dt == DataType::F32 && data.isa.bf16; }, + REGISTER_BF16_NEON(arm_compute::cpu::neon_bfloat16_to_fp32_cast) + }, +}; + Status validate_arguments(const ITensorInfo *src, const ITensorInfo *dst, ConvertPolicy policy) { ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(src); @@ -151,6 +198,9 @@ void CpuCastKernel::run_op(ITensorPack &tensors, const Window &window, const Thr Iterator src(_src, win); Iterator dst(_dst, win); + /*ukernel runs only when using fp16/bfloat16, so we validate it isn't a nullptr only before using it */ + const auto *uk = CpuCastKernel::get_implementation(CastDataTypeISASelectorData{ _src->info()->data_type(), _dst->info()->data_type(), CPUInfo::get().get_isa() }); + switch(_src->info()->data_type()) { case DataType::QASYMM8_SIGNED: @@ -262,42 +312,13 @@ void CpuCastKernel::run_op(ITensorPack &tensors, const Window &window, const Thr src, dst); break; } -#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC case DataType::F16: { /* Up-conversion QASYMM8_SIGNED -> F16 */ - execute_window_loop(win, [&](const Coordinates &) - { - const auto src_ptr = reinterpret_cast<const int8_t *>(src.ptr()); - const auto dst_ptr = reinterpret_cast<float16_t *>(dst.ptr()); - int x = window_start_x; - - for(; x <= (window_end_x - window_step_x); x += window_step_x) - { - const int8x16_t texels_s8 = vld1q_s8(src_ptr + x); - - const int16x8x2_t texels = - { - { - vmovl_s8(vget_low_s8(texels_s8)), - vmovl_s8(vget_high_s8(texels_s8)) - } - }; - vst1q_f16(dst_ptr + x, vcvtq_f16_s16(texels.val[0])); - vst1q_f16(dst_ptr + x + 8, vcvtq_f16_s16(texels.val[1])); - } - - // Compute left-over elements - for(; x < window_end_x; ++x) - { - *(dst_ptr + x) = static_cast<float16_t>(*(src_ptr + x)); - } - }, - src, dst); + ARM_COMPUTE_ERROR_ON(uk->ukernel == nullptr); + uk->ukernel(_src, _dst, info, _policy, window); break; } -#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC - default: ARM_COMPUTE_ERROR("dst data type not supported"); } @@ -414,41 +435,13 @@ void CpuCastKernel::run_op(ITensorPack &tensors, const Window &window, const Thr src, dst); break; } -#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC case DataType::F16: { - /* Up-conversion U8 -> F16 */ - execute_window_loop(win, [&](const Coordinates &) - { - const auto src_ptr = reinterpret_cast<const uint8_t *>(src.ptr()); - const auto dst_ptr = reinterpret_cast<float16_t *>(dst.ptr()); - - int x = window_start_x; - for(; x <= (window_end_x - window_step_x); x += window_step_x) - { - const uint8x16_t texels_u8 = vld1q_u8(src_ptr + x); - - const int16x8x2_t texels = - { - { - vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(texels_u8))), - vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(texels_u8))) - } - }; - vst1q_f16(dst_ptr + x, vcvtq_f16_s16(texels.val[0])); - vst1q_f16(dst_ptr + x + 8, vcvtq_f16_s16(texels.val[1])); - } - - // Compute left-over elements - for(; x < window_end_x; ++x) - { - *(dst_ptr + x) = static_cast<float16_t>(*(src_ptr + x)); - } - }, - src, dst); + /* Up-conversion U8 -> FP16 */ + ARM_COMPUTE_ERROR_ON(uk->ukernel == nullptr); + uk->ukernel(_src, _dst, info, _policy, window); break; } -#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC case DataType::U16: { /* Up-conversion U8 -> U16 */ @@ -668,6 +661,7 @@ void CpuCastKernel::run_op(ITensorPack &tensors, const Window &window, const Thr } break; } + case DataType::U16: { switch(_dst->info()->data_type()) @@ -775,258 +769,37 @@ void CpuCastKernel::run_op(ITensorPack &tensors, const Window &window, const Thr } break; } -#if defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) || defined(ARM_COMPUTE_FORCE_BF16) case DataType::BFLOAT16: - switch(_dst->info()->data_type()) - { - case DataType::F32: - { - /* Up-conversion BFLOAT16 -> F32 */ - execute_window_loop(win, [&](const Coordinates &) - { - const auto src_ptr = reinterpret_cast<const bfloat16 *>(src.ptr()); - const auto dst_ptr = reinterpret_cast<float *>(dst.ptr()); - - int x = window_start_x; - for(; x <= (window_end_x - window_step_x); x += window_step_x) - { - const uint16x8x2_t texels = - { - { - vld1q_u16(reinterpret_cast<uint16_t *>(src.ptr())), - vld1q_u16(reinterpret_cast<uint16_t *>(src.ptr()) + 8) - } - }; - - vst1q_f32(reinterpret_cast<float *>(dst.ptr()), - vreinterpretq_f32_u32(vshlq_n_u32(vmovl_u16(vget_low_u16(texels.val[0])), 16))); - vst1q_f32(reinterpret_cast<float *>(dst.ptr()) + 4, - vreinterpretq_f32_u32(vshlq_n_u32(vmovl_u16(vget_high_u16(texels.val[0])), 16))); - vst1q_f32(reinterpret_cast<float *>(dst.ptr()) + 8, - vreinterpretq_f32_u32(vshlq_n_u32(vmovl_u16(vget_low_u16(texels.val[1])), 16))); - vst1q_f32(reinterpret_cast<float *>(dst.ptr()) + 12, - vreinterpretq_f32_u32(vshlq_n_u32(vmovl_u16(vget_high_u16(texels.val[1])), 16))); - } - - for(; x < window_end_x; ++x) - { - *(dst_ptr + x) = float(*(src_ptr + x)); - } - }, - src, dst); - break; - } - default: - ARM_COMPUTE_ERROR("dst data type unsupported"); - } + { + /* Up-conversion BFLOAT16 -> F32 */ + ARM_COMPUTE_ERROR_ON(uk->ukernel == nullptr); + uk->ukernel(_src, _dst, info, _policy, window); break; -#endif /* defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) || defined(ARM_COMPUTE_FORCE_BF16) */ -#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + } case DataType::F16: - switch(_dst->info()->data_type()) - { - case DataType::QASYMM8_SIGNED: - { - /* Down-conversion F16 -> QASYMM8_SIGNED (Always saturating) */ - execute_window_loop(win, [&](const Coordinates &) - { - const auto src_ptr = reinterpret_cast<const float16_t *>(src.ptr()); - const auto dst_ptr = reinterpret_cast<int8_t *>(dst.ptr()); - - int x = window_start_x; - for(; x <= (window_end_x - window_step_x); x += window_step_x) - { - const float16x8x2_t texels = - { - { - vld1q_f16(src_ptr + x), - vld1q_f16(src_ptr + x + 8), - } - }; - - vst1q_s8(dst_ptr + x, vcombine_s8(vqmovn_s16(vcvtq_s16_f16(texels.val[0])), vqmovn_s16(vcvtq_s16_f16(texels.val[1])))); - } - - // Compute left-over elements - for(; x < window_end_x; ++x) - { - *(dst_ptr + x) = utils::cast::saturate_cast<int8_t>(*(src_ptr + x)); - } - }, - src, dst); - break; - } - case DataType::QASYMM8: - case DataType::U8: - { - /* Down-conversion F16 -> QASYMM8/U8 (Always saturating) */ - execute_window_loop(win, [&](const Coordinates &) - { - const auto src_ptr = reinterpret_cast<const float16_t *>(src.ptr()); - const auto dst_ptr = reinterpret_cast<uint8_t *>(dst.ptr()); - - int x = window_start_x; - for(; x <= (window_end_x - window_step_x); x += window_step_x) - { - const float16x8x2_t texels = - { - { - vld1q_f16(src_ptr + x), - vld1q_f16(src_ptr + x + 8), - } - }; - - vst1q_u8(dst_ptr + x, vcombine_u8(vqmovun_s16(vcvtq_s16_f16(texels.val[0])), vqmovun_s16(vcvtq_s16_f16(texels.val[1])))); - } - - // Compute left-over elements - for(; x < window_end_x; ++x) - { - *(dst_ptr + x) = utils::cast::saturate_cast<uint8_t>(*(src_ptr + x)); - } - - }, - src, dst); - break; - } - case DataType::F32: - { - /* Up-conversion F16 -> F32 */ - execute_window_loop(win, [&](const Coordinates &) - { - const auto src_ptr = reinterpret_cast<const float16_t *>(src.ptr()); - const auto dst_ptr = reinterpret_cast<float *>(dst.ptr()); - - int x = window_start_x; - for(; x <= (window_end_x - window_step_x); x += window_step_x) - { - const float16x8x2_t texels = - { - { - vld1q_f16(src_ptr + x), - vld1q_f16(src_ptr + x + 8) - } - }; - vst1q_f32(dst_ptr + x, vcvt_f32_f16(vget_low_f16(texels.val[0]))); - vst1q_f32(dst_ptr + x + 4, vcvt_f32_f16(vget_high_f16(texels.val[0]))); - vst1q_f32(dst_ptr + x + 8, vcvt_f32_f16(vget_low_f16(texels.val[1]))); - vst1q_f32(dst_ptr + x + 12, vcvt_f32_f16(vget_high_f16(texels.val[1]))); - } - - // Compute left-over elements - for(; x < window_end_x; ++x) - { - *(dst_ptr + x) = static_cast<float>(*(src_ptr + x)); - } - }, - src, dst); - break; - } - case DataType::S32: - { - /* Up-conversion F16 -> S32 */ - execute_window_loop(win, [&](const Coordinates &) - { - const auto src_ptr = reinterpret_cast<const float16_t *>(src.ptr()); - const auto dst_ptr = reinterpret_cast<int32_t *>(dst.ptr()); - - int x = window_start_x; - for(; x <= (window_end_x - window_step_x); x += window_step_x) - { - const float16x8x2_t texels = - { - { - vld1q_f16(src_ptr + x), - vld1q_f16(src_ptr + x + 8) - } - }; - - vst1q_s32(dst_ptr + x, vcvtq_s32_f32(vcvt_f32_f16(vget_low_f16(texels.val[0])))); - vst1q_s32(dst_ptr + x + 4, vcvtq_s32_f32(vcvt_f32_f16(vget_high_f16(texels.val[0])))); - vst1q_s32(dst_ptr + x + 8, vcvtq_s32_f32(vcvt_f32_f16(vget_low_f16(texels.val[1])))); - vst1q_s32(dst_ptr + x + 12, vcvtq_s32_f32(vcvt_f32_f16(vget_high_f16(texels.val[1])))); - } - - // Compute left-over elements - for(; x < window_end_x; ++x) - { - *(dst_ptr + x) = static_cast<int32_t>(*(src_ptr + x)); - } - }, - src, dst); - break; - } - default: - ARM_COMPUTE_ERROR("dst data type not supported"); - } + { + /* conversion F16 -> any data type */ + ARM_COMPUTE_ERROR_ON(uk->ukernel == nullptr); + uk->ukernel(_src, _dst, info, _policy, window); break; -#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */ + } case DataType::F32: switch(_dst->info()->data_type()) { -#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC case DataType::F16: { /* Down-conversion F32 -> F16 */ - execute_window_loop(win, [&](const Coordinates &) - { - const auto src_ptr = reinterpret_cast<const float *>(src.ptr()); - const auto dst_ptr = reinterpret_cast<float16_t *>(dst.ptr()); - - int x = window_start_x; - for(; x <= (window_end_x - window_step_x); x += window_step_x) - { - const float32x4x4_t texels = - { - { - vld1q_f32(src_ptr + x), - vld1q_f32(src_ptr + x + 4), - vld1q_f32(src_ptr + x + 8), - vld1q_f32(src_ptr + x + 12) - } - }; - - vst1q_f16(dst_ptr + x, vcombine_f16(vcvt_f16_f32(texels.val[0]), vcvt_f16_f32(texels.val[1]))); - vst1q_f16(dst_ptr + x + 8, vcombine_f16(vcvt_f16_f32(texels.val[2]), vcvt_f16_f32(texels.val[3]))); - } - - // Compute left-over elements - for(; x < window_end_x; ++x) - { - *(dst_ptr + x) = static_cast<float16_t>(*(src_ptr + x)); - } - }, - src, dst); + ARM_COMPUTE_ERROR_ON(uk->ukernel == nullptr); + uk->ukernel(_src, _dst, info, _policy, window); break; } -#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */ -#if defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) || defined(ARM_COMPUTE_FORCE_BF16) case DataType::BFLOAT16: { /* Down-conversion F32 -> BFLOAT16 */ - execute_window_loop(win, [&](const Coordinates &) - { - const auto src_ptr = reinterpret_cast<const float *>(src.ptr()); - const auto dst_ptr = reinterpret_cast<bfloat16 *>(dst.ptr()); - - int x = window_start_x; - for(; x <= (window_end_x - window_step_x); x += window_step_x) - { - wrapper::vcvt_bf16_f32(reinterpret_cast<float *>(src.ptr()), - reinterpret_cast<uint16_t *>(dst.ptr())); - wrapper::vcvt_bf16_f32(reinterpret_cast<float *>(src.ptr()) + 8, - reinterpret_cast<uint16_t *>(dst.ptr()) + 8); - } - - for(; x < window_end_x; ++x) - { - *(dst_ptr + x) = *(src_ptr + x); - } - }, - src, dst); + ARM_COMPUTE_ERROR_ON(uk->ukernel == nullptr); + uk->ukernel(_src, _dst, info, _policy, window); break; } -#endif /* defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) || defined(ARM_COMPUTE_FORCE_BF16) */ case DataType::S32: { /* Conversion F32 -> S32 */ @@ -1140,42 +913,13 @@ void CpuCastKernel::run_op(ITensorPack &tensors, const Window &window, const Thr case DataType::S32: switch(_dst->info()->data_type()) { -#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC case DataType::F16: { /* Down-conversion S32 -> F16 */ - execute_window_loop(win, [&](const Coordinates &) - { - const auto src_ptr = reinterpret_cast<const int32_t *>(src.ptr()); - const auto dst_ptr = reinterpret_cast<float16_t *>(dst.ptr()); - - int x = window_start_x; - for(; x <= (window_end_x - window_step_x); x += window_step_x) - { - const float32x4x4_t texels = - { - { - vcvtq_f32_s32(vld1q_s32(src_ptr + x)), - vcvtq_f32_s32(vld1q_s32(src_ptr + x + 4)), - vcvtq_f32_s32(vld1q_s32(src_ptr + x + 8)), - vcvtq_f32_s32(vld1q_s32(src_ptr + x + 12)) - } - }; - - vst1q_f16(dst_ptr + x, vcombine_f16(vcvt_f16_f32(texels.val[0]), vcvt_f16_f32(texels.val[1]))); - vst1q_f16(dst_ptr + x + 8, vcombine_f16(vcvt_f16_f32(texels.val[2]), vcvt_f16_f32(texels.val[3]))); - } - - // Compute left-over elements - for(; x < window_end_x; ++x) - { - *(dst_ptr + x) = static_cast<float16_t>(*(src_ptr + x)); - } - }, - src, dst); + ARM_COMPUTE_ERROR_ON(uk->ukernel == nullptr); + uk->ukernel(_src, _dst, info, _policy, window); break; } -#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */ case DataType::F32: { /* Conversion S32 -> F32 */ @@ -1362,6 +1106,12 @@ const char *CpuCastKernel::name() const { return "CpuCastKernel.cpp"; } + +const std::vector<CpuCastKernel::CastKernel> &CpuCastKernel::get_available_kernels() +{ + return available_kernels; +} + } // namespace kernels } // namespace cpu } // namespace arm_compute |