aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorYair Schwarzbaum <yair.schwarzbaum@arm.com>2022-02-01 08:55:56 +0200
committerYair Schwarzbaum <yair.schwarzbaum@arm.com>2022-03-03 10:16:39 +0000
commit298b2c0526615fc1f0242c2792fe2c51a4f0c44a (patch)
treee47e5986e805e29fed4afca59c76e5375076cff2
parent918a9fb4aa4be23ca4261c241e9e52acc42f9bb3 (diff)
downloadComputeLibrary-298b2c0526615fc1f0242c2792fe2c51a4f0c44a.tar.gz
Decouple castKernel
Resolves: COMPMID-4625 Signed-off-by: Yair Schwarzbaum <yair.schwarzbaum@arm.com> Change-Id: I3c30f007804b179e5e2b439f421fbd4e57fb02e1 Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/7149 Comments-Addressed: Arm Jenkins <bsgcomp@arm.com> Tested-by: Arm Jenkins <bsgcomp@arm.com> Reviewed-by: Giorgio Arena <giorgio.arena@arm.com>
-rw-r--r--Android.bp2
-rw-r--r--arm_compute/core/Utils.h3
-rw-r--r--filelist.json8
-rw-r--r--src/core/common/Registrars.h8
-rw-r--r--src/cpu/kernels/CpuCastKernel.cpp408
-rw-r--r--src/cpu/kernels/CpuCastKernel.h12
-rw-r--r--src/cpu/kernels/CpuKernelSelectionTypes.h8
-rw-r--r--src/cpu/kernels/cast/generic/neon/bfloat16.cpp144
-rw-r--r--src/cpu/kernels/cast/generic/neon/fp16.cpp396
-rw-r--r--src/cpu/kernels/cast/list.h44
-rw-r--r--tests/validation/NEON/Cast.cpp72
11 files changed, 771 insertions, 334 deletions
diff --git a/Android.bp b/Android.bp
index a279fdf5bb..340aeeed23 100644
--- a/Android.bp
+++ b/Android.bp
@@ -439,6 +439,8 @@ cc_library_static {
"src/cpu/kernels/boundingboxtransform/generic/neon/fp32.cpp",
"src/cpu/kernels/boundingboxtransform/generic/neon/impl.cpp",
"src/cpu/kernels/boundingboxtransform/generic/neon/qsymm16.cpp",
+ "src/cpu/kernels/cast/generic/neon/bfloat16.cpp",
+ "src/cpu/kernels/cast/generic/neon/fp16.cpp",
"src/cpu/kernels/crop/generic/neon/fp16.cpp",
"src/cpu/kernels/crop/generic/neon/fp32.cpp",
"src/cpu/kernels/crop/generic/neon/impl.cpp",
diff --git a/arm_compute/core/Utils.h b/arm_compute/core/Utils.h
index fd9a0ee708..2d774770ae 100644
--- a/arm_compute/core/Utils.h
+++ b/arm_compute/core/Utils.h
@@ -1241,6 +1241,9 @@ inline std::string cpu_impl_dt(const DataType &data_type)
case DataType::QSYMM8_PER_CHANNEL:
ret = "qp8";
break;
+ case DataType::BFLOAT16:
+ ret = "bf16";
+ break;
default:
ARM_COMPUTE_ERROR("Unsupported.");
}
diff --git a/filelist.json b/filelist.json
index 3bdc00aeef..81b28f7f4b 100644
--- a/filelist.json
+++ b/filelist.json
@@ -969,8 +969,12 @@
"common": [
"src/cpu/operators/CpuCast.cpp",
"src/cpu/kernels/CpuCastKernel.cpp",
- "src/runtime/NEON/functions/NECast.cpp"
- ]
+ "src/runtime/NEON/functions/NECast.cpp",
+ "src/cpu/kernels/cast/generic/neon/bfloat16.cpp"
+ ],
+ "neon":{
+ "fp16":["src/cpu/kernels/cast/generic/neon/fp16.cpp"]
+ }
}
},
"ChannelShuffle": {
diff --git a/src/core/common/Registrars.h b/src/core/common/Registrars.h
index c7fbf7f831..cc76de2be5 100644
--- a/src/core/common/Registrars.h
+++ b/src/core/common/Registrars.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2020-2021 Arm Limited.
+ * Copyright (c) 2020-2022 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -167,4 +167,10 @@
#define REGISTER_INTEGER_SVE2(func_name) nullptr
#endif /* defined(ENABLE_INTEGER_KERNELS) */
+#if defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) || defined(ARM_COMPUTE_FORCE_BF16)
+#define REGISTER_BF16_NEON(func_name) &(func_name)
+#else /* !(defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) || defined(ARM_COMPUTE_FORCE_BF16))*/
+#define REGISTER_BF16_NEON(func_name) nullptr
+#endif /* defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) || defined(ARM_COMPUTE_FORCE_BF16)*/
+
#endif /* SRC_CORE_COMMON_REGISTRARS_H */
diff --git a/src/cpu/kernels/CpuCastKernel.cpp b/src/cpu/kernels/CpuCastKernel.cpp
index db76df9076..e1314e61da 100644
--- a/src/cpu/kernels/CpuCastKernel.cpp
+++ b/src/cpu/kernels/CpuCastKernel.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2016-2021 Arm Limited.
+ * Copyright (c) 2016-2022 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -32,10 +32,13 @@
#include "src/core/NEON/NEFixedPoint.h"
#include "src/core/NEON/NEMath.h"
#include "src/core/NEON/wrapper/wrapper.h"
+#include "src/core/common/Registrars.h"
#include "src/core/helpers/AutoConfiguration.h"
#include "src/core/helpers/WindowHelpers.h"
#include "support/SaturateCast.h"
+#include "src/cpu/kernels/cast/list.h"
+
namespace arm_compute
{
namespace cpu
@@ -44,6 +47,50 @@ namespace kernels
{
namespace
{
+static const std::vector<CpuCastKernel::CastKernel> available_kernels =
+{
+ {
+ "neon_qs8_cast",
+ [](const CastDataTypeISASelectorData & data) { return data.src_dt == DataType::QASYMM8_SIGNED && data.dst_dt == DataType::F16 && data.isa.fp16; },
+ REGISTER_FP16_NEON(arm_compute::cpu::neon_qasymm8_signed_to_fp16_cast)
+ },
+ {
+ "neon_qu8_cast",
+ [](const CastDataTypeISASelectorData & data) { return data.src_dt == DataType::QASYMM8 && data.dst_dt == DataType::F16 && data.isa.fp16; },
+ REGISTER_FP16_NEON(arm_compute::cpu::neon_u8_to_fp16_cast)
+ },
+ {
+ "neon_u8_cast",
+ [](const CastDataTypeISASelectorData & data) { return data.src_dt == DataType::U8 && data.dst_dt == DataType::F16 && data.isa.fp16; },
+ REGISTER_FP16_NEON(arm_compute::cpu::neon_u8_to_fp16_cast)
+ },
+ {
+ "neon_fp16_cast",
+ [](const CastDataTypeISASelectorData & data) { return data.src_dt == DataType::F16 && data.isa.fp16; },
+ REGISTER_FP16_NEON(arm_compute::cpu::neon_fp16_to_other_dt_cast)
+ },
+ {
+ "neon_fp32_to_fp16_cast",
+ [](const CastDataTypeISASelectorData & data) { return data.src_dt == DataType::F32 && data.dst_dt == DataType::F16 && data.isa.fp16; },
+ REGISTER_FP16_NEON(arm_compute::cpu::neon_fp32_to_fp16_cast)
+ },
+ {
+ "neon_fp32_to_bf16_cast",
+ [](const CastDataTypeISASelectorData & data) { return data.src_dt == DataType::F32 && data.dst_dt == DataType::BFLOAT16 && data.isa.bf16; },
+ REGISTER_BF16_NEON(arm_compute::cpu::neon_fp32_to_bfloat16_cast)
+ },
+ {
+ "neon_s32_cast",
+ [](const CastDataTypeISASelectorData & data) { return data.src_dt == DataType::S32 && data.dst_dt == DataType::F16 && data.isa.fp16; },
+ REGISTER_FP16_NEON(arm_compute::cpu::neon_s32_to_fp16_cast)
+ },
+ {
+ "neon_bf16_cast",
+ [](const CastDataTypeISASelectorData & data) { return data.src_dt == DataType::BFLOAT16 && data.dst_dt == DataType::F32 && data.isa.bf16; },
+ REGISTER_BF16_NEON(arm_compute::cpu::neon_bfloat16_to_fp32_cast)
+ },
+};
+
Status validate_arguments(const ITensorInfo *src, const ITensorInfo *dst, ConvertPolicy policy)
{
ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(src);
@@ -151,6 +198,9 @@ void CpuCastKernel::run_op(ITensorPack &tensors, const Window &window, const Thr
Iterator src(_src, win);
Iterator dst(_dst, win);
+ /*ukernel runs only when using fp16/bfloat16, so we validate it isn't a nullptr only before using it */
+ const auto *uk = CpuCastKernel::get_implementation(CastDataTypeISASelectorData{ _src->info()->data_type(), _dst->info()->data_type(), CPUInfo::get().get_isa() });
+
switch(_src->info()->data_type())
{
case DataType::QASYMM8_SIGNED:
@@ -262,42 +312,13 @@ void CpuCastKernel::run_op(ITensorPack &tensors, const Window &window, const Thr
src, dst);
break;
}
-#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
case DataType::F16:
{
/* Up-conversion QASYMM8_SIGNED -> F16 */
- execute_window_loop(win, [&](const Coordinates &)
- {
- const auto src_ptr = reinterpret_cast<const int8_t *>(src.ptr());
- const auto dst_ptr = reinterpret_cast<float16_t *>(dst.ptr());
- int x = window_start_x;
-
- for(; x <= (window_end_x - window_step_x); x += window_step_x)
- {
- const int8x16_t texels_s8 = vld1q_s8(src_ptr + x);
-
- const int16x8x2_t texels =
- {
- {
- vmovl_s8(vget_low_s8(texels_s8)),
- vmovl_s8(vget_high_s8(texels_s8))
- }
- };
- vst1q_f16(dst_ptr + x, vcvtq_f16_s16(texels.val[0]));
- vst1q_f16(dst_ptr + x + 8, vcvtq_f16_s16(texels.val[1]));
- }
-
- // Compute left-over elements
- for(; x < window_end_x; ++x)
- {
- *(dst_ptr + x) = static_cast<float16_t>(*(src_ptr + x));
- }
- },
- src, dst);
+ ARM_COMPUTE_ERROR_ON(uk->ukernel == nullptr);
+ uk->ukernel(_src, _dst, info, _policy, window);
break;
}
-#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
-
default:
ARM_COMPUTE_ERROR("dst data type not supported");
}
@@ -414,41 +435,13 @@ void CpuCastKernel::run_op(ITensorPack &tensors, const Window &window, const Thr
src, dst);
break;
}
-#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
case DataType::F16:
{
- /* Up-conversion U8 -> F16 */
- execute_window_loop(win, [&](const Coordinates &)
- {
- const auto src_ptr = reinterpret_cast<const uint8_t *>(src.ptr());
- const auto dst_ptr = reinterpret_cast<float16_t *>(dst.ptr());
-
- int x = window_start_x;
- for(; x <= (window_end_x - window_step_x); x += window_step_x)
- {
- const uint8x16_t texels_u8 = vld1q_u8(src_ptr + x);
-
- const int16x8x2_t texels =
- {
- {
- vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(texels_u8))),
- vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(texels_u8)))
- }
- };
- vst1q_f16(dst_ptr + x, vcvtq_f16_s16(texels.val[0]));
- vst1q_f16(dst_ptr + x + 8, vcvtq_f16_s16(texels.val[1]));
- }
-
- // Compute left-over elements
- for(; x < window_end_x; ++x)
- {
- *(dst_ptr + x) = static_cast<float16_t>(*(src_ptr + x));
- }
- },
- src, dst);
+ /* Up-conversion U8 -> FP16 */
+ ARM_COMPUTE_ERROR_ON(uk->ukernel == nullptr);
+ uk->ukernel(_src, _dst, info, _policy, window);
break;
}
-#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
case DataType::U16:
{
/* Up-conversion U8 -> U16 */
@@ -668,6 +661,7 @@ void CpuCastKernel::run_op(ITensorPack &tensors, const Window &window, const Thr
}
break;
}
+
case DataType::U16:
{
switch(_dst->info()->data_type())
@@ -775,258 +769,37 @@ void CpuCastKernel::run_op(ITensorPack &tensors, const Window &window, const Thr
}
break;
}
-#if defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) || defined(ARM_COMPUTE_FORCE_BF16)
case DataType::BFLOAT16:
- switch(_dst->info()->data_type())
- {
- case DataType::F32:
- {
- /* Up-conversion BFLOAT16 -> F32 */
- execute_window_loop(win, [&](const Coordinates &)
- {
- const auto src_ptr = reinterpret_cast<const bfloat16 *>(src.ptr());
- const auto dst_ptr = reinterpret_cast<float *>(dst.ptr());
-
- int x = window_start_x;
- for(; x <= (window_end_x - window_step_x); x += window_step_x)
- {
- const uint16x8x2_t texels =
- {
- {
- vld1q_u16(reinterpret_cast<uint16_t *>(src.ptr())),
- vld1q_u16(reinterpret_cast<uint16_t *>(src.ptr()) + 8)
- }
- };
-
- vst1q_f32(reinterpret_cast<float *>(dst.ptr()),
- vreinterpretq_f32_u32(vshlq_n_u32(vmovl_u16(vget_low_u16(texels.val[0])), 16)));
- vst1q_f32(reinterpret_cast<float *>(dst.ptr()) + 4,
- vreinterpretq_f32_u32(vshlq_n_u32(vmovl_u16(vget_high_u16(texels.val[0])), 16)));
- vst1q_f32(reinterpret_cast<float *>(dst.ptr()) + 8,
- vreinterpretq_f32_u32(vshlq_n_u32(vmovl_u16(vget_low_u16(texels.val[1])), 16)));
- vst1q_f32(reinterpret_cast<float *>(dst.ptr()) + 12,
- vreinterpretq_f32_u32(vshlq_n_u32(vmovl_u16(vget_high_u16(texels.val[1])), 16)));
- }
-
- for(; x < window_end_x; ++x)
- {
- *(dst_ptr + x) = float(*(src_ptr + x));
- }
- },
- src, dst);
- break;
- }
- default:
- ARM_COMPUTE_ERROR("dst data type unsupported");
- }
+ {
+ /* Up-conversion BFLOAT16 -> F32 */
+ ARM_COMPUTE_ERROR_ON(uk->ukernel == nullptr);
+ uk->ukernel(_src, _dst, info, _policy, window);
break;
-#endif /* defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) || defined(ARM_COMPUTE_FORCE_BF16) */
-#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+ }
case DataType::F16:
- switch(_dst->info()->data_type())
- {
- case DataType::QASYMM8_SIGNED:
- {
- /* Down-conversion F16 -> QASYMM8_SIGNED (Always saturating) */
- execute_window_loop(win, [&](const Coordinates &)
- {
- const auto src_ptr = reinterpret_cast<const float16_t *>(src.ptr());
- const auto dst_ptr = reinterpret_cast<int8_t *>(dst.ptr());
-
- int x = window_start_x;
- for(; x <= (window_end_x - window_step_x); x += window_step_x)
- {
- const float16x8x2_t texels =
- {
- {
- vld1q_f16(src_ptr + x),
- vld1q_f16(src_ptr + x + 8),
- }
- };
-
- vst1q_s8(dst_ptr + x, vcombine_s8(vqmovn_s16(vcvtq_s16_f16(texels.val[0])), vqmovn_s16(vcvtq_s16_f16(texels.val[1]))));
- }
-
- // Compute left-over elements
- for(; x < window_end_x; ++x)
- {
- *(dst_ptr + x) = utils::cast::saturate_cast<int8_t>(*(src_ptr + x));
- }
- },
- src, dst);
- break;
- }
- case DataType::QASYMM8:
- case DataType::U8:
- {
- /* Down-conversion F16 -> QASYMM8/U8 (Always saturating) */
- execute_window_loop(win, [&](const Coordinates &)
- {
- const auto src_ptr = reinterpret_cast<const float16_t *>(src.ptr());
- const auto dst_ptr = reinterpret_cast<uint8_t *>(dst.ptr());
-
- int x = window_start_x;
- for(; x <= (window_end_x - window_step_x); x += window_step_x)
- {
- const float16x8x2_t texels =
- {
- {
- vld1q_f16(src_ptr + x),
- vld1q_f16(src_ptr + x + 8),
- }
- };
-
- vst1q_u8(dst_ptr + x, vcombine_u8(vqmovun_s16(vcvtq_s16_f16(texels.val[0])), vqmovun_s16(vcvtq_s16_f16(texels.val[1]))));
- }
-
- // Compute left-over elements
- for(; x < window_end_x; ++x)
- {
- *(dst_ptr + x) = utils::cast::saturate_cast<uint8_t>(*(src_ptr + x));
- }
-
- },
- src, dst);
- break;
- }
- case DataType::F32:
- {
- /* Up-conversion F16 -> F32 */
- execute_window_loop(win, [&](const Coordinates &)
- {
- const auto src_ptr = reinterpret_cast<const float16_t *>(src.ptr());
- const auto dst_ptr = reinterpret_cast<float *>(dst.ptr());
-
- int x = window_start_x;
- for(; x <= (window_end_x - window_step_x); x += window_step_x)
- {
- const float16x8x2_t texels =
- {
- {
- vld1q_f16(src_ptr + x),
- vld1q_f16(src_ptr + x + 8)
- }
- };
- vst1q_f32(dst_ptr + x, vcvt_f32_f16(vget_low_f16(texels.val[0])));
- vst1q_f32(dst_ptr + x + 4, vcvt_f32_f16(vget_high_f16(texels.val[0])));
- vst1q_f32(dst_ptr + x + 8, vcvt_f32_f16(vget_low_f16(texels.val[1])));
- vst1q_f32(dst_ptr + x + 12, vcvt_f32_f16(vget_high_f16(texels.val[1])));
- }
-
- // Compute left-over elements
- for(; x < window_end_x; ++x)
- {
- *(dst_ptr + x) = static_cast<float>(*(src_ptr + x));
- }
- },
- src, dst);
- break;
- }
- case DataType::S32:
- {
- /* Up-conversion F16 -> S32 */
- execute_window_loop(win, [&](const Coordinates &)
- {
- const auto src_ptr = reinterpret_cast<const float16_t *>(src.ptr());
- const auto dst_ptr = reinterpret_cast<int32_t *>(dst.ptr());
-
- int x = window_start_x;
- for(; x <= (window_end_x - window_step_x); x += window_step_x)
- {
- const float16x8x2_t texels =
- {
- {
- vld1q_f16(src_ptr + x),
- vld1q_f16(src_ptr + x + 8)
- }
- };
-
- vst1q_s32(dst_ptr + x, vcvtq_s32_f32(vcvt_f32_f16(vget_low_f16(texels.val[0]))));
- vst1q_s32(dst_ptr + x + 4, vcvtq_s32_f32(vcvt_f32_f16(vget_high_f16(texels.val[0]))));
- vst1q_s32(dst_ptr + x + 8, vcvtq_s32_f32(vcvt_f32_f16(vget_low_f16(texels.val[1]))));
- vst1q_s32(dst_ptr + x + 12, vcvtq_s32_f32(vcvt_f32_f16(vget_high_f16(texels.val[1]))));
- }
-
- // Compute left-over elements
- for(; x < window_end_x; ++x)
- {
- *(dst_ptr + x) = static_cast<int32_t>(*(src_ptr + x));
- }
- },
- src, dst);
- break;
- }
- default:
- ARM_COMPUTE_ERROR("dst data type not supported");
- }
+ {
+ /* conversion F16 -> any data type */
+ ARM_COMPUTE_ERROR_ON(uk->ukernel == nullptr);
+ uk->ukernel(_src, _dst, info, _policy, window);
break;
-#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
+ }
case DataType::F32:
switch(_dst->info()->data_type())
{
-#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
case DataType::F16:
{
/* Down-conversion F32 -> F16 */
- execute_window_loop(win, [&](const Coordinates &)
- {
- const auto src_ptr = reinterpret_cast<const float *>(src.ptr());
- const auto dst_ptr = reinterpret_cast<float16_t *>(dst.ptr());
-
- int x = window_start_x;
- for(; x <= (window_end_x - window_step_x); x += window_step_x)
- {
- const float32x4x4_t texels =
- {
- {
- vld1q_f32(src_ptr + x),
- vld1q_f32(src_ptr + x + 4),
- vld1q_f32(src_ptr + x + 8),
- vld1q_f32(src_ptr + x + 12)
- }
- };
-
- vst1q_f16(dst_ptr + x, vcombine_f16(vcvt_f16_f32(texels.val[0]), vcvt_f16_f32(texels.val[1])));
- vst1q_f16(dst_ptr + x + 8, vcombine_f16(vcvt_f16_f32(texels.val[2]), vcvt_f16_f32(texels.val[3])));
- }
-
- // Compute left-over elements
- for(; x < window_end_x; ++x)
- {
- *(dst_ptr + x) = static_cast<float16_t>(*(src_ptr + x));
- }
- },
- src, dst);
+ ARM_COMPUTE_ERROR_ON(uk->ukernel == nullptr);
+ uk->ukernel(_src, _dst, info, _policy, window);
break;
}
-#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
-#if defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) || defined(ARM_COMPUTE_FORCE_BF16)
case DataType::BFLOAT16:
{
/* Down-conversion F32 -> BFLOAT16 */
- execute_window_loop(win, [&](const Coordinates &)
- {
- const auto src_ptr = reinterpret_cast<const float *>(src.ptr());
- const auto dst_ptr = reinterpret_cast<bfloat16 *>(dst.ptr());
-
- int x = window_start_x;
- for(; x <= (window_end_x - window_step_x); x += window_step_x)
- {
- wrapper::vcvt_bf16_f32(reinterpret_cast<float *>(src.ptr()),
- reinterpret_cast<uint16_t *>(dst.ptr()));
- wrapper::vcvt_bf16_f32(reinterpret_cast<float *>(src.ptr()) + 8,
- reinterpret_cast<uint16_t *>(dst.ptr()) + 8);
- }
-
- for(; x < window_end_x; ++x)
- {
- *(dst_ptr + x) = *(src_ptr + x);
- }
- },
- src, dst);
+ ARM_COMPUTE_ERROR_ON(uk->ukernel == nullptr);
+ uk->ukernel(_src, _dst, info, _policy, window);
break;
}
-#endif /* defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) || defined(ARM_COMPUTE_FORCE_BF16) */
case DataType::S32:
{
/* Conversion F32 -> S32 */
@@ -1140,42 +913,13 @@ void CpuCastKernel::run_op(ITensorPack &tensors, const Window &window, const Thr
case DataType::S32:
switch(_dst->info()->data_type())
{
-#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
case DataType::F16:
{
/* Down-conversion S32 -> F16 */
- execute_window_loop(win, [&](const Coordinates &)
- {
- const auto src_ptr = reinterpret_cast<const int32_t *>(src.ptr());
- const auto dst_ptr = reinterpret_cast<float16_t *>(dst.ptr());
-
- int x = window_start_x;
- for(; x <= (window_end_x - window_step_x); x += window_step_x)
- {
- const float32x4x4_t texels =
- {
- {
- vcvtq_f32_s32(vld1q_s32(src_ptr + x)),
- vcvtq_f32_s32(vld1q_s32(src_ptr + x + 4)),
- vcvtq_f32_s32(vld1q_s32(src_ptr + x + 8)),
- vcvtq_f32_s32(vld1q_s32(src_ptr + x + 12))
- }
- };
-
- vst1q_f16(dst_ptr + x, vcombine_f16(vcvt_f16_f32(texels.val[0]), vcvt_f16_f32(texels.val[1])));
- vst1q_f16(dst_ptr + x + 8, vcombine_f16(vcvt_f16_f32(texels.val[2]), vcvt_f16_f32(texels.val[3])));
- }
-
- // Compute left-over elements
- for(; x < window_end_x; ++x)
- {
- *(dst_ptr + x) = static_cast<float16_t>(*(src_ptr + x));
- }
- },
- src, dst);
+ ARM_COMPUTE_ERROR_ON(uk->ukernel == nullptr);
+ uk->ukernel(_src, _dst, info, _policy, window);
break;
}
-#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
case DataType::F32:
{
/* Conversion S32 -> F32 */
@@ -1362,6 +1106,12 @@ const char *CpuCastKernel::name() const
{
return "CpuCastKernel.cpp";
}
+
+const std::vector<CpuCastKernel::CastKernel> &CpuCastKernel::get_available_kernels()
+{
+ return available_kernels;
+}
+
} // namespace kernels
} // namespace cpu
} // namespace arm_compute
diff --git a/src/cpu/kernels/CpuCastKernel.h b/src/cpu/kernels/CpuCastKernel.h
index 7679178fa1..95d46fad23 100644
--- a/src/cpu/kernels/CpuCastKernel.h
+++ b/src/cpu/kernels/CpuCastKernel.h
@@ -39,6 +39,9 @@ namespace kernels
*/
class CpuCastKernel : public ICpuKernel<CpuCastKernel>
{
+private:
+ using CastKernelPtr = std::add_pointer<void(const ITensor *, ITensor *, const ThreadInfo &, ConvertPolicy, const Window &)>::type;
+
public:
CpuCastKernel() = default;
ARM_COMPUTE_DISALLOW_COPY_ALLOW_MOVE(CpuCastKernel);
@@ -73,6 +76,15 @@ public:
void run_op(ITensorPack &tensors, const Window &window, const ThreadInfo &info) override;
const char *name() const override;
+ struct CastKernel
+ {
+ const char *name;
+ const CastDataTypeISASelectorDataPtr is_selected;
+ CastKernelPtr ukernel;
+ };
+
+ static const std::vector<CastKernel> &get_available_kernels();
+
private:
ConvertPolicy _policy{ ConvertPolicy::SATURATE };
};
diff --git a/src/cpu/kernels/CpuKernelSelectionTypes.h b/src/cpu/kernels/CpuKernelSelectionTypes.h
index 8c5a39ad49..afcf014ad2 100644
--- a/src/cpu/kernels/CpuKernelSelectionTypes.h
+++ b/src/cpu/kernels/CpuKernelSelectionTypes.h
@@ -47,6 +47,13 @@ struct DataTypeDataLayoutISASelectorData
const cpuinfo::CpuIsaInfo &isa;
};
+struct CastDataTypeISASelectorData
+{
+ DataType src_dt;
+ DataType dst_dt;
+ const cpuinfo::CpuIsaInfo &isa;
+};
+
struct PoolDataTypeISASelectorData
{
DataType dt;
@@ -74,6 +81,7 @@ using DataTypeDataLayoutSelectorPtr = std::add_pointer<bool(const
using PoolDataTypeISASelectorPtr = std::add_pointer<bool(const PoolDataTypeISASelectorData &data)>::type;
using ElementwiseDataTypeISASelectorPtr = std::add_pointer<bool(const ElementwiseDataTypeISASelectorData &data)>::type;
using DepthwiseConv2dNativeDataTypeISASelectorPtr = std::add_pointer<bool(const DepthwiseConv2dNativeDataTypeISASelectorData &data)>::type;
+using CastDataTypeISASelectorDataPtr = std::add_pointer<bool(const CastDataTypeISASelectorData &data)>::type;
} // namespace kernels
} // namespace cpu
diff --git a/src/cpu/kernels/cast/generic/neon/bfloat16.cpp b/src/cpu/kernels/cast/generic/neon/bfloat16.cpp
new file mode 100644
index 0000000000..b15584b0aa
--- /dev/null
+++ b/src/cpu/kernels/cast/generic/neon/bfloat16.cpp
@@ -0,0 +1,144 @@
+/*
+ * Copyright (c) 2016-2022 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#if defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) || defined(ARM_COMPUTE_FORCE_BF16)
+
+#include "arm_compute/core/TensorInfo.h"
+#include "src/cpu/kernels/CpuCastKernel.h"
+#include "src/cpu/kernels/cast/list.h"
+#include "support/SaturateCast.h"
+
+namespace arm_compute
+{
+namespace cpu
+{
+void neon_fp32_to_bfloat16_cast(const ITensor *_src, ITensor *_dst, const ThreadInfo &info, ConvertPolicy _policy, const Window &window)
+{
+ ARM_COMPUTE_UNUSED(info);
+ ARM_COMPUTE_UNUSED(_policy);
+
+ const auto window_start_x = static_cast<int>(window.x().start());
+ const auto window_end_x = static_cast<int>(window.x().end());
+ const int window_step_x = 16;
+
+ ARM_COMPUTE_ERROR_ON_NULLPTR(_src, _dst);
+ ARM_COMPUTE_ERROR_ON(_src == _dst);
+
+ ARM_COMPUTE_ERROR_ON_NULLPTR(_src, _dst);
+
+ Window win{ window };
+ win.set(Window::DimX, Window::Dimension(0, 1, 1));
+
+ Iterator src(_src, win);
+ Iterator dst(_dst, win);
+
+ /* Down-conversion F32 -> BFLOAT16 */
+ execute_window_loop(win, [&](const Coordinates &)
+ {
+ const auto src_ptr = reinterpret_cast<const float *>(src.ptr());
+ const auto dst_ptr = reinterpret_cast<bfloat16 *>(dst.ptr());
+
+ int x = window_start_x;
+ for(; x <= (window_end_x - window_step_x); x += window_step_x)
+ {
+ wrapper::vcvt_bf16_f32(reinterpret_cast<float *>(src.ptr()),
+ reinterpret_cast<uint16_t *>(dst.ptr()));
+ wrapper::vcvt_bf16_f32(reinterpret_cast<float *>(src.ptr()) + 8,
+ reinterpret_cast<uint16_t *>(dst.ptr()) + 8);
+ }
+
+ for(; x < window_end_x; ++x)
+ {
+ *(dst_ptr + x) = *(src_ptr + x);
+ }
+ },
+ src, dst);
+}
+
+void neon_bfloat16_to_fp32_cast(const ITensor *_src, ITensor *_dst, const ThreadInfo &info, ConvertPolicy _policy, const Window &window)
+{
+ ARM_COMPUTE_UNUSED(info);
+ ARM_COMPUTE_UNUSED(_policy);
+
+ const auto window_start_x = static_cast<int>(window.x().start());
+ const auto window_end_x = static_cast<int>(window.x().end());
+ const int window_step_x = 16;
+
+ ARM_COMPUTE_ERROR_ON_NULLPTR(_src, _dst);
+ ARM_COMPUTE_ERROR_ON(_src == _dst);
+
+ ARM_COMPUTE_ERROR_ON_NULLPTR(_src, _dst);
+
+ Window win{ window };
+ win.set(Window::DimX, Window::Dimension(0, 1, 1));
+
+ Iterator src(_src, win);
+ Iterator dst(_dst, win);
+ switch(_dst->info()->data_type())
+ {
+ case DataType::F32:
+ {
+ /* Up-conversion BFLOAT16 -> F32 */
+ execute_window_loop(win, [&](const Coordinates &)
+ {
+ const auto src_ptr = reinterpret_cast<const bfloat16 *>(src.ptr());
+ const auto dst_ptr = reinterpret_cast<float *>(dst.ptr());
+
+ int x = window_start_x;
+ for(; x <= (window_end_x - window_step_x); x += window_step_x)
+ {
+ const uint16x8x2_t texels =
+ {
+ {
+ vld1q_u16(reinterpret_cast<uint16_t *>(src.ptr())),
+ vld1q_u16(reinterpret_cast<uint16_t *>(src.ptr()) + 8)
+ }
+ };
+
+ vst1q_f32(reinterpret_cast<float *>(dst.ptr()),
+ vreinterpretq_f32_u32(vshlq_n_u32(vmovl_u16(vget_low_u16(texels.val[0])), 16)));
+ vst1q_f32(reinterpret_cast<float *>(dst.ptr()) + 4,
+ vreinterpretq_f32_u32(vshlq_n_u32(vmovl_u16(vget_high_u16(texels.val[0])), 16)));
+ vst1q_f32(reinterpret_cast<float *>(dst.ptr()) + 8,
+ vreinterpretq_f32_u32(vshlq_n_u32(vmovl_u16(vget_low_u16(texels.val[1])), 16)));
+ vst1q_f32(reinterpret_cast<float *>(dst.ptr()) + 12,
+ vreinterpretq_f32_u32(vshlq_n_u32(vmovl_u16(vget_high_u16(texels.val[1])), 16)));
+ }
+
+ for(; x < window_end_x; ++x)
+ {
+ *(dst_ptr + x) = float(*(src_ptr + x));
+ }
+ },
+ src, dst);
+ break;
+ }
+ default:
+ ARM_COMPUTE_ERROR("dst data type unsupported");
+ }
+}
+
+} // namespace cpu
+} // namespace arm_compute
+
+#endif /* defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) || defined(ARM_COMPUTE_FORCE_BF16) */
diff --git a/src/cpu/kernels/cast/generic/neon/fp16.cpp b/src/cpu/kernels/cast/generic/neon/fp16.cpp
new file mode 100644
index 0000000000..d2c66923cc
--- /dev/null
+++ b/src/cpu/kernels/cast/generic/neon/fp16.cpp
@@ -0,0 +1,396 @@
+/*
+ * Copyright (c) 2016-2022 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && defined(ENABLE_FP16_KERNELS)
+
+#include "arm_compute/core/TensorInfo.h"
+#include "src/cpu/kernels/CpuCastKernel.h"
+#include "src/cpu/kernels/cast/list.h"
+#include "support/SaturateCast.h"
+
+namespace arm_compute
+{
+namespace cpu
+{
+void neon_qasymm8_signed_to_fp16_cast(const ITensor *_src, ITensor *_dst, const ThreadInfo &info, ConvertPolicy _policy, const Window &window)
+{
+ ARM_COMPUTE_UNUSED(info);
+ ARM_COMPUTE_UNUSED(_policy);
+
+ const auto window_start_x = static_cast<int>(window.x().start());
+ const auto window_end_x = static_cast<int>(window.x().end());
+ const int window_step_x = 16;
+
+ ARM_COMPUTE_ERROR_ON_NULLPTR(_src, _dst);
+ ARM_COMPUTE_ERROR_ON(_src == _dst);
+
+ ARM_COMPUTE_ERROR_ON_NULLPTR(_src, _dst);
+
+ Window win{ window };
+ win.set(Window::DimX, Window::Dimension(0, 1, 1));
+
+ Iterator src(_src, win);
+ Iterator dst(_dst, win);
+ execute_window_loop(win, [&](const Coordinates &)
+ {
+ const auto src_ptr = reinterpret_cast<const int8_t *>(src.ptr());
+ const auto dst_ptr = reinterpret_cast<float16_t *>(dst.ptr());
+ int x = window_start_x;
+
+ for(; x <= (window_end_x - window_step_x); x += window_step_x)
+ {
+ const int8x16_t texels_s8 = vld1q_s8(src_ptr + x);
+
+ const int16x8x2_t texels =
+ {
+ {
+ vmovl_s8(vget_low_s8(texels_s8)),
+ vmovl_s8(vget_high_s8(texels_s8))
+ }
+ };
+ vst1q_f16(dst_ptr + x, vcvtq_f16_s16(texels.val[0]));
+ vst1q_f16(dst_ptr + x + 8, vcvtq_f16_s16(texels.val[1]));
+ }
+
+ // Compute left-over elements
+ for(; x < window_end_x; ++x)
+ {
+ *(dst_ptr + x) = static_cast<float16_t>(*(src_ptr + x));
+ }
+ },
+ src, dst);
+}
+
+void neon_s32_to_fp16_cast(const ITensor *_src, ITensor *_dst, const ThreadInfo &info, ConvertPolicy _policy, const Window &window)
+{
+ ARM_COMPUTE_UNUSED(info);
+ ARM_COMPUTE_UNUSED(_policy);
+
+ const auto window_start_x = static_cast<int>(window.x().start());
+ const auto window_end_x = static_cast<int>(window.x().end());
+ const int window_step_x = 16;
+
+ ARM_COMPUTE_ERROR_ON_NULLPTR(_src, _dst);
+ ARM_COMPUTE_ERROR_ON(_src == _dst);
+
+ ARM_COMPUTE_ERROR_ON_NULLPTR(_src, _dst);
+
+ Window win{ window };
+ win.set(Window::DimX, Window::Dimension(0, 1, 1));
+
+ Iterator src(_src, win);
+ Iterator dst(_dst, win);
+
+ execute_window_loop(win, [&](const Coordinates &)
+ {
+ const auto src_ptr = reinterpret_cast<const int32_t *>(src.ptr());
+ const auto dst_ptr = reinterpret_cast<float16_t *>(dst.ptr());
+
+ int x = window_start_x;
+ for(; x <= (window_end_x - window_step_x); x += window_step_x)
+ {
+ const float32x4x4_t texels =
+ {
+ {
+ vcvtq_f32_s32(vld1q_s32(src_ptr + x)),
+ vcvtq_f32_s32(vld1q_s32(src_ptr + x + 4)),
+ vcvtq_f32_s32(vld1q_s32(src_ptr + x + 8)),
+ vcvtq_f32_s32(vld1q_s32(src_ptr + x + 12))
+ }
+ };
+
+ vst1q_f16(dst_ptr + x, vcombine_f16(vcvt_f16_f32(texels.val[0]), vcvt_f16_f32(texels.val[1])));
+ vst1q_f16(dst_ptr + x + 8, vcombine_f16(vcvt_f16_f32(texels.val[2]), vcvt_f16_f32(texels.val[3])));
+ }
+
+ // Compute left-over elements
+ for(; x < window_end_x; ++x)
+ {
+ *(dst_ptr + x) = static_cast<float16_t>(*(src_ptr + x));
+ }
+ },
+ src, dst);
+}
+
+void neon_fp32_to_fp16_cast(const ITensor *_src, ITensor *_dst, const ThreadInfo &info, ConvertPolicy _policy, const Window &window)
+{
+ ARM_COMPUTE_UNUSED(info);
+ ARM_COMPUTE_UNUSED(_policy);
+
+ const auto window_start_x = static_cast<int>(window.x().start());
+ const auto window_end_x = static_cast<int>(window.x().end());
+ const int window_step_x = 16;
+
+ ARM_COMPUTE_ERROR_ON_NULLPTR(_src, _dst);
+ ARM_COMPUTE_ERROR_ON(_src == _dst);
+
+ ARM_COMPUTE_ERROR_ON_NULLPTR(_src, _dst);
+
+ Window win{ window };
+ win.set(Window::DimX, Window::Dimension(0, 1, 1));
+
+ Iterator src(_src, win);
+ Iterator dst(_dst, win);
+
+ execute_window_loop(win, [&](const Coordinates &)
+ {
+ const auto src_ptr = reinterpret_cast<const float *>(src.ptr());
+ const auto dst_ptr = reinterpret_cast<float16_t *>(dst.ptr());
+
+ int x = window_start_x;
+ for(; x <= (window_end_x - window_step_x); x += window_step_x)
+ {
+ const float32x4x4_t texels =
+ {
+ {
+ vld1q_f32(src_ptr + x),
+ vld1q_f32(src_ptr + x + 4),
+ vld1q_f32(src_ptr + x + 8),
+ vld1q_f32(src_ptr + x + 12)
+ }
+ };
+
+ vst1q_f16(dst_ptr + x, vcombine_f16(vcvt_f16_f32(texels.val[0]), vcvt_f16_f32(texels.val[1])));
+ vst1q_f16(dst_ptr + x + 8, vcombine_f16(vcvt_f16_f32(texels.val[2]), vcvt_f16_f32(texels.val[3])));
+ }
+
+ // Compute left-over elements
+ for(; x < window_end_x; ++x)
+ {
+ *(dst_ptr + x) = static_cast<float16_t>(*(src_ptr + x));
+ }
+ },
+ src, dst);
+}
+
+void neon_fp16_to_other_dt_cast(const ITensor *_src, ITensor *_dst, const ThreadInfo &info, ConvertPolicy _policy, const Window &window)
+{
+ ARM_COMPUTE_UNUSED(info);
+ ARM_COMPUTE_UNUSED(_policy);
+
+ const auto window_start_x = static_cast<int>(window.x().start());
+ const auto window_end_x = static_cast<int>(window.x().end());
+ const int window_step_x = 16;
+
+ ARM_COMPUTE_ERROR_ON_NULLPTR(_src, _dst);
+ ARM_COMPUTE_ERROR_ON(_src == _dst);
+
+ ARM_COMPUTE_ERROR_ON_NULLPTR(_src, _dst);
+
+ Window win{ window };
+ win.set(Window::DimX, Window::Dimension(0, 1, 1));
+
+ Iterator src(_src, win);
+ Iterator dst(_dst, win);
+ switch(_dst->info()->data_type())
+ {
+ case DataType::QASYMM8_SIGNED:
+ {
+ /* Down-conversion F16 -> QASYMM8_SIGNED (Always saturating) */
+ execute_window_loop(win, [&](const Coordinates &)
+ {
+ const auto src_ptr = reinterpret_cast<const float16_t *>(src.ptr());
+ const auto dst_ptr = reinterpret_cast<int8_t *>(dst.ptr());
+
+ int x = window_start_x;
+ for(; x <= (window_end_x - window_step_x); x += window_step_x)
+ {
+ const float16x8x2_t texels =
+ {
+ {
+ vld1q_f16(src_ptr + x),
+ vld1q_f16(src_ptr + x + 8),
+ }
+ };
+
+ vst1q_s8(dst_ptr + x, vcombine_s8(vqmovn_s16(vcvtq_s16_f16(texels.val[0])), vqmovn_s16(vcvtq_s16_f16(texels.val[1]))));
+ }
+
+ // Compute left-over elements
+ for(; x < window_end_x; ++x)
+ {
+ *(dst_ptr + x) = utils::cast::saturate_cast<int8_t>(*(src_ptr + x));
+ }
+ },
+ src, dst);
+ break;
+ }
+ case DataType::QASYMM8:
+ case DataType::U8:
+ {
+ /* Down-conversion F16 -> QASYMM8/U8 (Always saturating) */
+ execute_window_loop(win, [&](const Coordinates &)
+ {
+ const auto src_ptr = reinterpret_cast<const float16_t *>(src.ptr());
+ const auto dst_ptr = reinterpret_cast<uint8_t *>(dst.ptr());
+
+ int x = window_start_x;
+ for(; x <= (window_end_x - window_step_x); x += window_step_x)
+ {
+ const float16x8x2_t texels =
+ {
+ {
+ vld1q_f16(src_ptr + x),
+ vld1q_f16(src_ptr + x + 8),
+ }
+ };
+
+ vst1q_u8(dst_ptr + x, vcombine_u8(vqmovun_s16(vcvtq_s16_f16(texels.val[0])), vqmovun_s16(vcvtq_s16_f16(texels.val[1]))));
+ }
+
+ // Compute left-over elements
+ for(; x < window_end_x; ++x)
+ {
+ *(dst_ptr + x) = utils::cast::saturate_cast<uint8_t>(*(src_ptr + x));
+ }
+
+ },
+ src, dst);
+ break;
+ }
+ case DataType::F32:
+ {
+ /* Up-conversion F16 -> F32 */
+ execute_window_loop(win, [&](const Coordinates &)
+ {
+ const auto src_ptr = reinterpret_cast<const float16_t *>(src.ptr());
+ const auto dst_ptr = reinterpret_cast<float *>(dst.ptr());
+
+ int x = window_start_x;
+ for(; x <= (window_end_x - window_step_x); x += window_step_x)
+ {
+ const float16x8x2_t texels =
+ {
+ {
+ vld1q_f16(src_ptr + x),
+ vld1q_f16(src_ptr + x + 8)
+ }
+ };
+ vst1q_f32(dst_ptr + x, vcvt_f32_f16(vget_low_f16(texels.val[0])));
+ vst1q_f32(dst_ptr + x + 4, vcvt_f32_f16(vget_high_f16(texels.val[0])));
+ vst1q_f32(dst_ptr + x + 8, vcvt_f32_f16(vget_low_f16(texels.val[1])));
+ vst1q_f32(dst_ptr + x + 12, vcvt_f32_f16(vget_high_f16(texels.val[1])));
+ }
+
+ // Compute left-over elements
+ for(; x < window_end_x; ++x)
+ {
+ *(dst_ptr + x) = static_cast<float>(*(src_ptr + x));
+ }
+ },
+ src, dst);
+ break;
+ }
+ case DataType::S32:
+ {
+ /* Up-conversion F16 -> S32 */
+ execute_window_loop(win, [&](const Coordinates &)
+ {
+ const auto src_ptr = reinterpret_cast<const float16_t *>(src.ptr());
+ const auto dst_ptr = reinterpret_cast<int32_t *>(dst.ptr());
+
+ int x = window_start_x;
+ for(; x <= (window_end_x - window_step_x); x += window_step_x)
+ {
+ const float16x8x2_t texels =
+ {
+ {
+ vld1q_f16(src_ptr + x),
+ vld1q_f16(src_ptr + x + 8)
+ }
+ };
+
+ vst1q_s32(dst_ptr + x, vcvtq_s32_f32(vcvt_f32_f16(vget_low_f16(texels.val[0]))));
+ vst1q_s32(dst_ptr + x + 4, vcvtq_s32_f32(vcvt_f32_f16(vget_high_f16(texels.val[0]))));
+ vst1q_s32(dst_ptr + x + 8, vcvtq_s32_f32(vcvt_f32_f16(vget_low_f16(texels.val[1]))));
+ vst1q_s32(dst_ptr + x + 12, vcvtq_s32_f32(vcvt_f32_f16(vget_high_f16(texels.val[1]))));
+ }
+
+ // Compute left-over elements
+ for(; x < window_end_x; ++x)
+ {
+ *(dst_ptr + x) = static_cast<int32_t>(*(src_ptr + x));
+ }
+ },
+ src, dst);
+ break;
+ }
+ default:
+ ARM_COMPUTE_ERROR("dst data type not supported");
+ }
+}
+
+void neon_u8_to_fp16_cast(const ITensor *_src, ITensor *_dst, const ThreadInfo &info, ConvertPolicy _policy, const Window &window)
+{
+ ARM_COMPUTE_UNUSED(info);
+ ARM_COMPUTE_UNUSED(_policy);
+
+ const auto window_start_x = static_cast<int>(window.x().start());
+ const auto window_end_x = static_cast<int>(window.x().end());
+ const int window_step_x = 16;
+
+ ARM_COMPUTE_ERROR_ON_NULLPTR(_src, _dst);
+ ARM_COMPUTE_ERROR_ON(_src == _dst);
+
+ ARM_COMPUTE_ERROR_ON_NULLPTR(_src, _dst);
+
+ Window win{ window };
+ win.set(Window::DimX, Window::Dimension(0, 1, 1));
+
+ Iterator src(_src, win);
+ Iterator dst(_dst, win);
+ /* Up-conversion U8 -> F16 */
+ execute_window_loop(win, [&](const Coordinates &)
+ {
+ const auto src_ptr = reinterpret_cast<const uint8_t *>(src.ptr());
+ const auto dst_ptr = reinterpret_cast<float16_t *>(dst.ptr());
+
+ int x = window_start_x;
+ for(; x <= (window_end_x - window_step_x); x += window_step_x)
+ {
+ const uint8x16_t texels_u8 = vld1q_u8(src_ptr + x);
+
+ const int16x8x2_t texels =
+ {
+ {
+ vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(texels_u8))),
+ vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(texels_u8)))
+ }
+ };
+ vst1q_f16(dst_ptr + x, vcvtq_f16_s16(texels.val[0]));
+ vst1q_f16(dst_ptr + x + 8, vcvtq_f16_s16(texels.val[1]));
+ }
+
+ // Compute left-over elements
+ for(; x < window_end_x; ++x)
+ {
+ *(dst_ptr + x) = static_cast<float16_t>(*(src_ptr + x));
+ }
+ },
+ src, dst);
+ return;
+}
+
+} // namespace cpu
+} // namespace arm_compute
+#endif /* #if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && defined(ENABLE_FP16_KERNELS) */
diff --git a/src/cpu/kernels/cast/list.h b/src/cpu/kernels/cast/list.h
new file mode 100644
index 0000000000..ffd82d5bf3
--- /dev/null
+++ b/src/cpu/kernels/cast/list.h
@@ -0,0 +1,44 @@
+/*
+ * Copyright (c) 2022 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#ifndef SRC_CORE_NEON_KERNELS_CAST_LIST_H
+#define SRC_CORE_NEON_KERNELS_CAST_LIST_H
+namespace arm_compute
+{
+namespace cpu
+{
+#define DECLARE_CAST_KERNEL(func_name) \
+ void func_name(const ITensor *_src, ITensor *_dst, const ThreadInfo &tensor, ConvertPolicy _policy, const Window &window)
+
+DECLARE_CAST_KERNEL(neon_fp32_to_fp16_cast);
+DECLARE_CAST_KERNEL(neon_u8_to_fp16_cast);
+DECLARE_CAST_KERNEL(neon_fp16_to_other_dt_cast);
+DECLARE_CAST_KERNEL(neon_s32_to_fp16_cast);
+DECLARE_CAST_KERNEL(neon_qasymm8_signed_to_fp16_cast);
+DECLARE_CAST_KERNEL(neon_fp32_to_bfloat16_cast);
+DECLARE_CAST_KERNEL(neon_bfloat16_to_fp32_cast);
+
+#undef DECLARE_CAST_KERNEL
+} // namespace cpu
+} // namespace arm_compute
+#endif //SRC_CORE_NEON_KERNELS_CAST_LIST_H \ No newline at end of file
diff --git a/tests/validation/NEON/Cast.cpp b/tests/validation/NEON/Cast.cpp
index db73bea9cb..3a77106a42 100644
--- a/tests/validation/NEON/Cast.cpp
+++ b/tests/validation/NEON/Cast.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2019-2021 Arm Limited.
+ * Copyright (c) 2019-2022 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -25,6 +25,8 @@
#include "arm_compute/runtime/NEON/functions/NECast.h"
#include "arm_compute/runtime/Tensor.h"
#include "arm_compute/runtime/TensorAllocator.h"
+#include "src/common/cpuinfo/CpuIsaInfo.h"
+#include "src/cpu/kernels/CpuCastKernel.h"
#include "tests/NEON/Accessor.h"
#include "tests/PaddingCalculator.h"
#include "tests/datasets/ConvertPolicyDataset.h"
@@ -34,7 +36,6 @@
#include "tests/framework/datasets/Datasets.h"
#include "tests/validation/Validation.h"
#include "tests/validation/fixtures/CastFixture.h"
-
namespace arm_compute
{
namespace test
@@ -187,6 +188,73 @@ CAST_SUITE(F32_to_F16, DataType::F32, DataType::F16, NECastToF16Fixture<float>,
CAST_SUITE(F32_to_S32, DataType::F32, DataType::S32, NECastToS32Fixture<float>, CastF32toS32Dataset, one_tolerance)
CAST_SUITE(F32_to_U8, DataType::F32, DataType::S32, NECastToS32Fixture<float>, CastF32toS32Dataset, one_tolerance)
+DATA_TEST_CASE(KernelSelectionDstFP16, framework::DatasetMode::ALL,
+ combine(framework::dataset::make("CpuExt", std::string("NEON")),
+ framework::dataset::make("DataType",
+{
+ DataType::F16,
+ DataType::U8,
+ DataType::S32,
+ DataType::QASYMM8,
+ DataType::QASYMM8_SIGNED,
+ DataType::BFLOAT16,
+})),
+cpu_ext, data_type)
+{
+ using namespace cpu::kernels;
+ const CpuCastKernel::CastKernel *selected_impl;
+
+ cpuinfo::CpuIsaInfo cpu_isa{};
+ cpu_isa.neon = (cpu_ext == "NEON");
+
+ cpu_isa.bf16 = (data_type == DataType::BFLOAT16);
+
+ /* bf16 cast is different from all the others being converted to fp32 and not to fp16 */
+ if(cpu_isa.bf16)
+ {
+ cpu_isa.fp16 = false;
+ selected_impl = CpuCastKernel::get_implementation(CastDataTypeISASelectorData{ data_type, DataType::F32, cpu_isa }, cpu::KernelSelectionType::Preferred);
+ }
+ else
+ {
+ cpu_isa.fp16 = true;
+ selected_impl = CpuCastKernel::get_implementation(CastDataTypeISASelectorData{ data_type, DataType::F16, cpu_isa }, cpu::KernelSelectionType::Preferred);
+ }
+
+ ARM_COMPUTE_ERROR_ON_NULLPTR(selected_impl);
+
+ std::string expected = lower_string(cpu_ext) + "_" + cpu_impl_dt(data_type) + "_cast";
+ std::string actual = selected_impl->name;
+
+ ARM_COMPUTE_EXPECT_EQUAL(expected, actual, framework::LogLevel::ERRORS);
+}
+
+DATA_TEST_CASE(KernelSelectionSrcFP32, framework::DatasetMode::ALL,
+ combine(framework::dataset::make("CpuExt", std::string("NEON")),
+ framework::dataset::make("DataType",
+{
+ DataType::F16,
+ DataType::BFLOAT16,
+})),
+cpu_ext, data_type)
+{
+ using namespace cpu::kernels;
+
+ cpuinfo::CpuIsaInfo cpu_isa{};
+ cpu_isa.neon = (cpu_ext == "NEON");
+ cpu_isa.fp16 = (data_type == DataType::F16);
+ cpu_isa.bf16 = (data_type == DataType::BFLOAT16);
+
+ const auto *selected_impl = CpuCastKernel::get_implementation(CastDataTypeISASelectorData{ DataType::F32, data_type, cpu_isa }, cpu::KernelSelectionType::Preferred);
+
+ ARM_COMPUTE_ERROR_ON_NULLPTR(selected_impl);
+
+ std::string expected = lower_string(cpu_ext) + "_fp32_to_" + cpu_impl_dt(data_type) + "_cast";
+ std::string actual = selected_impl->name;
+
+ ARM_COMPUTE_EXPECT_EQUAL(expected, actual, framework::LogLevel::ERRORS);
+}
+
TEST_SUITE_END() // Cast
TEST_SUITE_END() // Neon
} // namespace validation