about summary refs log tree commit diff
path: root/src/cpu/kernels/CpuCastKernel.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/cpu/kernels/CpuCastKernel.cpp')
-rw-r--r--  src/cpu/kernels/CpuCastKernel.cpp  408
1 files changed, 79 insertions, 329 deletions
diff --git a/src/cpu/kernels/CpuCastKernel.cpp b/src/cpu/kernels/CpuCastKernel.cpp
index db76df9076..e1314e61da 100644
--- a/src/cpu/kernels/CpuCastKernel.cpp
+++ b/src/cpu/kernels/CpuCastKernel.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2016-2021 Arm Limited.
+ * Copyright (c) 2016-2022 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -32,10 +32,13 @@
#include "src/core/NEON/NEFixedPoint.h"
#include "src/core/NEON/NEMath.h"
#include "src/core/NEON/wrapper/wrapper.h"
+#include "src/core/common/Registrars.h"
#include "src/core/helpers/AutoConfiguration.h"
#include "src/core/helpers/WindowHelpers.h"
#include "support/SaturateCast.h"
+#include "src/cpu/kernels/cast/list.h"
+
namespace arm_compute
{
namespace cpu
@@ -44,6 +47,50 @@ namespace kernels
{
namespace
{
+static const std::vector<CpuCastKernel::CastKernel> available_kernels =
+{
+ {
+ "neon_qs8_cast",
+ [](const CastDataTypeISASelectorData & data) { return data.src_dt == DataType::QASYMM8_SIGNED && data.dst_dt == DataType::F16 && data.isa.fp16; },
+ REGISTER_FP16_NEON(arm_compute::cpu::neon_qasymm8_signed_to_fp16_cast)
+ },
+ {
+ "neon_qu8_cast",
+ [](const CastDataTypeISASelectorData & data) { return data.src_dt == DataType::QASYMM8 && data.dst_dt == DataType::F16 && data.isa.fp16; },
+ REGISTER_FP16_NEON(arm_compute::cpu::neon_u8_to_fp16_cast)
+ },
+ {
+ "neon_u8_cast",
+ [](const CastDataTypeISASelectorData & data) { return data.src_dt == DataType::U8 && data.dst_dt == DataType::F16 && data.isa.fp16; },
+ REGISTER_FP16_NEON(arm_compute::cpu::neon_u8_to_fp16_cast)
+ },
+ {
+ "neon_fp16_cast",
+ [](const CastDataTypeISASelectorData & data) { return data.src_dt == DataType::F16 && data.isa.fp16; },
+ REGISTER_FP16_NEON(arm_compute::cpu::neon_fp16_to_other_dt_cast)
+ },
+ {
+ "neon_fp32_to_fp16_cast",
+ [](const CastDataTypeISASelectorData & data) { return data.src_dt == DataType::F32 && data.dst_dt == DataType::F16 && data.isa.fp16; },
+ REGISTER_FP16_NEON(arm_compute::cpu::neon_fp32_to_fp16_cast)
+ },
+ {
+ "neon_fp32_to_bf16_cast",
+ [](const CastDataTypeISASelectorData & data) { return data.src_dt == DataType::F32 && data.dst_dt == DataType::BFLOAT16 && data.isa.bf16; },
+ REGISTER_BF16_NEON(arm_compute::cpu::neon_fp32_to_bfloat16_cast)
+ },
+ {
+ "neon_s32_cast",
+ [](const CastDataTypeISASelectorData & data) { return data.src_dt == DataType::S32 && data.dst_dt == DataType::F16 && data.isa.fp16; },
+ REGISTER_FP16_NEON(arm_compute::cpu::neon_s32_to_fp16_cast)
+ },
+ {
+ "neon_bf16_cast",
+ [](const CastDataTypeISASelectorData & data) { return data.src_dt == DataType::BFLOAT16 && data.dst_dt == DataType::F32 && data.isa.bf16; },
+ REGISTER_BF16_NEON(arm_compute::cpu::neon_bfloat16_to_fp32_cast)
+ },
+};
+
Status validate_arguments(const ITensorInfo *src, const ITensorInfo *dst, ConvertPolicy policy)
{
ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(src);
@@ -151,6 +198,9 @@ void CpuCastKernel::run_op(ITensorPack &tensors, const Window &window, const Thr
Iterator src(_src, win);
Iterator dst(_dst, win);
+ /*ukernel runs only when using fp16/bfloat16, so we validate it isn't a nullptr only before using it */
+ const auto *uk = CpuCastKernel::get_implementation(CastDataTypeISASelectorData{ _src->info()->data_type(), _dst->info()->data_type(), CPUInfo::get().get_isa() });
+
switch(_src->info()->data_type())
{
case DataType::QASYMM8_SIGNED:
@@ -262,42 +312,13 @@ void CpuCastKernel::run_op(ITensorPack &tensors, const Window &window, const Thr
src, dst);
break;
}
-#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
case DataType::F16:
{
/* Up-conversion QASYMM8_SIGNED -> F16 */
- execute_window_loop(win, [&](const Coordinates &)
- {
- const auto src_ptr = reinterpret_cast<const int8_t *>(src.ptr());
- const auto dst_ptr = reinterpret_cast<float16_t *>(dst.ptr());
- int x = window_start_x;
-
- for(; x <= (window_end_x - window_step_x); x += window_step_x)
- {
- const int8x16_t texels_s8 = vld1q_s8(src_ptr + x);
-
- const int16x8x2_t texels =
- {
- {
- vmovl_s8(vget_low_s8(texels_s8)),
- vmovl_s8(vget_high_s8(texels_s8))
- }
- };
- vst1q_f16(dst_ptr + x, vcvtq_f16_s16(texels.val[0]));
- vst1q_f16(dst_ptr + x + 8, vcvtq_f16_s16(texels.val[1]));
- }
-
- // Compute left-over elements
- for(; x < window_end_x; ++x)
- {
- *(dst_ptr + x) = static_cast<float16_t>(*(src_ptr + x));
- }
- },
- src, dst);
+ ARM_COMPUTE_ERROR_ON(uk->ukernel == nullptr);
+ uk->ukernel(_src, _dst, info, _policy, window);
break;
}
-#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
-
default:
ARM_COMPUTE_ERROR("dst data type not supported");
}
@@ -414,41 +435,13 @@ void CpuCastKernel::run_op(ITensorPack &tensors, const Window &window, const Thr
src, dst);
break;
}
-#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
case DataType::F16:
{
- /* Up-conversion U8 -> F16 */
- execute_window_loop(win, [&](const Coordinates &)
- {
- const auto src_ptr = reinterpret_cast<const uint8_t *>(src.ptr());
- const auto dst_ptr = reinterpret_cast<float16_t *>(dst.ptr());
-
- int x = window_start_x;
- for(; x <= (window_end_x - window_step_x); x += window_step_x)
- {
- const uint8x16_t texels_u8 = vld1q_u8(src_ptr + x);
-
- const int16x8x2_t texels =
- {
- {
- vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(texels_u8))),
- vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(texels_u8)))
- }
- };
- vst1q_f16(dst_ptr + x, vcvtq_f16_s16(texels.val[0]));
- vst1q_f16(dst_ptr + x + 8, vcvtq_f16_s16(texels.val[1]));
- }
-
- // Compute left-over elements
- for(; x < window_end_x; ++x)
- {
- *(dst_ptr + x) = static_cast<float16_t>(*(src_ptr + x));
- }
- },
- src, dst);
+ /* Up-conversion U8 -> FP16 */
+ ARM_COMPUTE_ERROR_ON(uk->ukernel == nullptr);
+ uk->ukernel(_src, _dst, info, _policy, window);
break;
}
-#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
case DataType::U16:
{
/* Up-conversion U8 -> U16 */
@@ -668,6 +661,7 @@ void CpuCastKernel::run_op(ITensorPack &tensors, const Window &window, const Thr
}
break;
}
+
case DataType::U16:
{
switch(_dst->info()->data_type())
@@ -775,258 +769,37 @@ void CpuCastKernel::run_op(ITensorPack &tensors, const Window &window, const Thr
}
break;
}
-#if defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) || defined(ARM_COMPUTE_FORCE_BF16)
case DataType::BFLOAT16:
- switch(_dst->info()->data_type())
- {
- case DataType::F32:
- {
- /* Up-conversion BFLOAT16 -> F32 */
- execute_window_loop(win, [&](const Coordinates &)
- {
- const auto src_ptr = reinterpret_cast<const bfloat16 *>(src.ptr());
- const auto dst_ptr = reinterpret_cast<float *>(dst.ptr());
-
- int x = window_start_x;
- for(; x <= (window_end_x - window_step_x); x += window_step_x)
- {
- const uint16x8x2_t texels =
- {
- {
- vld1q_u16(reinterpret_cast<uint16_t *>(src.ptr())),
- vld1q_u16(reinterpret_cast<uint16_t *>(src.ptr()) + 8)
- }
- };
-
- vst1q_f32(reinterpret_cast<float *>(dst.ptr()),
- vreinterpretq_f32_u32(vshlq_n_u32(vmovl_u16(vget_low_u16(texels.val[0])), 16)));
- vst1q_f32(reinterpret_cast<float *>(dst.ptr()) + 4,
- vreinterpretq_f32_u32(vshlq_n_u32(vmovl_u16(vget_high_u16(texels.val[0])), 16)));
- vst1q_f32(reinterpret_cast<float *>(dst.ptr()) + 8,
- vreinterpretq_f32_u32(vshlq_n_u32(vmovl_u16(vget_low_u16(texels.val[1])), 16)));
- vst1q_f32(reinterpret_cast<float *>(dst.ptr()) + 12,
- vreinterpretq_f32_u32(vshlq_n_u32(vmovl_u16(vget_high_u16(texels.val[1])), 16)));
- }
-
- for(; x < window_end_x; ++x)
- {
- *(dst_ptr + x) = float(*(src_ptr + x));
- }
- },
- src, dst);
- break;
- }
- default:
- ARM_COMPUTE_ERROR("dst data type unsupported");
- }
+ {
+ /* Up-conversion BFLOAT16 -> F32 */
+ ARM_COMPUTE_ERROR_ON(uk->ukernel == nullptr);
+ uk->ukernel(_src, _dst, info, _policy, window);
break;
-#endif /* defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) || defined(ARM_COMPUTE_FORCE_BF16) */
-#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+ }
case DataType::F16:
- switch(_dst->info()->data_type())
- {
- case DataType::QASYMM8_SIGNED:
- {
- /* Down-conversion F16 -> QASYMM8_SIGNED (Always saturating) */
- execute_window_loop(win, [&](const Coordinates &)
- {
- const auto src_ptr = reinterpret_cast<const float16_t *>(src.ptr());
- const auto dst_ptr = reinterpret_cast<int8_t *>(dst.ptr());
-
- int x = window_start_x;
- for(; x <= (window_end_x - window_step_x); x += window_step_x)
- {
- const float16x8x2_t texels =
- {
- {
- vld1q_f16(src_ptr + x),
- vld1q_f16(src_ptr + x + 8),
- }
- };
-
- vst1q_s8(dst_ptr + x, vcombine_s8(vqmovn_s16(vcvtq_s16_f16(texels.val[0])), vqmovn_s16(vcvtq_s16_f16(texels.val[1]))));
- }
-
- // Compute left-over elements
- for(; x < window_end_x; ++x)
- {
- *(dst_ptr + x) = utils::cast::saturate_cast<int8_t>(*(src_ptr + x));
- }
- },
- src, dst);
- break;
- }
- case DataType::QASYMM8:
- case DataType::U8:
- {
- /* Down-conversion F16 -> QASYMM8/U8 (Always saturating) */
- execute_window_loop(win, [&](const Coordinates &)
- {
- const auto src_ptr = reinterpret_cast<const float16_t *>(src.ptr());
- const auto dst_ptr = reinterpret_cast<uint8_t *>(dst.ptr());
-
- int x = window_start_x;
- for(; x <= (window_end_x - window_step_x); x += window_step_x)
- {
- const float16x8x2_t texels =
- {
- {
- vld1q_f16(src_ptr + x),
- vld1q_f16(src_ptr + x + 8),
- }
- };
-
- vst1q_u8(dst_ptr + x, vcombine_u8(vqmovun_s16(vcvtq_s16_f16(texels.val[0])), vqmovun_s16(vcvtq_s16_f16(texels.val[1]))));
- }
-
- // Compute left-over elements
- for(; x < window_end_x; ++x)
- {
- *(dst_ptr + x) = utils::cast::saturate_cast<uint8_t>(*(src_ptr + x));
- }
-
- },
- src, dst);
- break;
- }
- case DataType::F32:
- {
- /* Up-conversion F16 -> F32 */
- execute_window_loop(win, [&](const Coordinates &)
- {
- const auto src_ptr = reinterpret_cast<const float16_t *>(src.ptr());
- const auto dst_ptr = reinterpret_cast<float *>(dst.ptr());
-
- int x = window_start_x;
- for(; x <= (window_end_x - window_step_x); x += window_step_x)
- {
- const float16x8x2_t texels =
- {
- {
- vld1q_f16(src_ptr + x),
- vld1q_f16(src_ptr + x + 8)
- }
- };
- vst1q_f32(dst_ptr + x, vcvt_f32_f16(vget_low_f16(texels.val[0])));
- vst1q_f32(dst_ptr + x + 4, vcvt_f32_f16(vget_high_f16(texels.val[0])));
- vst1q_f32(dst_ptr + x + 8, vcvt_f32_f16(vget_low_f16(texels.val[1])));
- vst1q_f32(dst_ptr + x + 12, vcvt_f32_f16(vget_high_f16(texels.val[1])));
- }
-
- // Compute left-over elements
- for(; x < window_end_x; ++x)
- {
- *(dst_ptr + x) = static_cast<float>(*(src_ptr + x));
- }
- },
- src, dst);
- break;
- }
- case DataType::S32:
- {
- /* Up-conversion F16 -> S32 */
- execute_window_loop(win, [&](const Coordinates &)
- {
- const auto src_ptr = reinterpret_cast<const float16_t *>(src.ptr());
- const auto dst_ptr = reinterpret_cast<int32_t *>(dst.ptr());
-
- int x = window_start_x;
- for(; x <= (window_end_x - window_step_x); x += window_step_x)
- {
- const float16x8x2_t texels =
- {
- {
- vld1q_f16(src_ptr + x),
- vld1q_f16(src_ptr + x + 8)
- }
- };
-
- vst1q_s32(dst_ptr + x, vcvtq_s32_f32(vcvt_f32_f16(vget_low_f16(texels.val[0]))));
- vst1q_s32(dst_ptr + x + 4, vcvtq_s32_f32(vcvt_f32_f16(vget_high_f16(texels.val[0]))));
- vst1q_s32(dst_ptr + x + 8, vcvtq_s32_f32(vcvt_f32_f16(vget_low_f16(texels.val[1]))));
- vst1q_s32(dst_ptr + x + 12, vcvtq_s32_f32(vcvt_f32_f16(vget_high_f16(texels.val[1]))));
- }
-
- // Compute left-over elements
- for(; x < window_end_x; ++x)
- {
- *(dst_ptr + x) = static_cast<int32_t>(*(src_ptr + x));
- }
- },
- src, dst);
- break;
- }
- default:
- ARM_COMPUTE_ERROR("dst data type not supported");
- }
+ {
+ /* conversion F16 -> any data type */
+ ARM_COMPUTE_ERROR_ON(uk->ukernel == nullptr);
+ uk->ukernel(_src, _dst, info, _policy, window);
break;
-#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
+ }
case DataType::F32:
switch(_dst->info()->data_type())
{
-#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
case DataType::F16:
{
/* Down-conversion F32 -> F16 */
- execute_window_loop(win, [&](const Coordinates &)
- {
- const auto src_ptr = reinterpret_cast<const float *>(src.ptr());
- const auto dst_ptr = reinterpret_cast<float16_t *>(dst.ptr());
-
- int x = window_start_x;
- for(; x <= (window_end_x - window_step_x); x += window_step_x)
- {
- const float32x4x4_t texels =
- {
- {
- vld1q_f32(src_ptr + x),
- vld1q_f32(src_ptr + x + 4),
- vld1q_f32(src_ptr + x + 8),
- vld1q_f32(src_ptr + x + 12)
- }
- };
-
- vst1q_f16(dst_ptr + x, vcombine_f16(vcvt_f16_f32(texels.val[0]), vcvt_f16_f32(texels.val[1])));
- vst1q_f16(dst_ptr + x + 8, vcombine_f16(vcvt_f16_f32(texels.val[2]), vcvt_f16_f32(texels.val[3])));
- }
-
- // Compute left-over elements
- for(; x < window_end_x; ++x)
- {
- *(dst_ptr + x) = static_cast<float16_t>(*(src_ptr + x));
- }
- },
- src, dst);
+ ARM_COMPUTE_ERROR_ON(uk->ukernel == nullptr);
+ uk->ukernel(_src, _dst, info, _policy, window);
break;
}
-#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
-#if defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) || defined(ARM_COMPUTE_FORCE_BF16)
case DataType::BFLOAT16:
{
/* Down-conversion F32 -> BFLOAT16 */
- execute_window_loop(win, [&](const Coordinates &)
- {
- const auto src_ptr = reinterpret_cast<const float *>(src.ptr());
- const auto dst_ptr = reinterpret_cast<bfloat16 *>(dst.ptr());
-
- int x = window_start_x;
- for(; x <= (window_end_x - window_step_x); x += window_step_x)
- {
- wrapper::vcvt_bf16_f32(reinterpret_cast<float *>(src.ptr()),
- reinterpret_cast<uint16_t *>(dst.ptr()));
- wrapper::vcvt_bf16_f32(reinterpret_cast<float *>(src.ptr()) + 8,
- reinterpret_cast<uint16_t *>(dst.ptr()) + 8);
- }
-
- for(; x < window_end_x; ++x)
- {
- *(dst_ptr + x) = *(src_ptr + x);
- }
- },
- src, dst);
+ ARM_COMPUTE_ERROR_ON(uk->ukernel == nullptr);
+ uk->ukernel(_src, _dst, info, _policy, window);
break;
}
-#endif /* defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) || defined(ARM_COMPUTE_FORCE_BF16) */
case DataType::S32:
{
/* Conversion F32 -> S32 */
@@ -1140,42 +913,13 @@ void CpuCastKernel::run_op(ITensorPack &tensors, const Window &window, const Thr
case DataType::S32:
switch(_dst->info()->data_type())
{
-#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
case DataType::F16:
{
/* Down-conversion S32 -> F16 */
- execute_window_loop(win, [&](const Coordinates &)
- {
- const auto src_ptr = reinterpret_cast<const int32_t *>(src.ptr());
- const auto dst_ptr = reinterpret_cast<float16_t *>(dst.ptr());
-
- int x = window_start_x;
- for(; x <= (window_end_x - window_step_x); x += window_step_x)
- {
- const float32x4x4_t texels =
- {
- {
- vcvtq_f32_s32(vld1q_s32(src_ptr + x)),
- vcvtq_f32_s32(vld1q_s32(src_ptr + x + 4)),
- vcvtq_f32_s32(vld1q_s32(src_ptr + x + 8)),
- vcvtq_f32_s32(vld1q_s32(src_ptr + x + 12))
- }
- };
-
- vst1q_f16(dst_ptr + x, vcombine_f16(vcvt_f16_f32(texels.val[0]), vcvt_f16_f32(texels.val[1])));
- vst1q_f16(dst_ptr + x + 8, vcombine_f16(vcvt_f16_f32(texels.val[2]), vcvt_f16_f32(texels.val[3])));
- }
-
- // Compute left-over elements
- for(; x < window_end_x; ++x)
- {
- *(dst_ptr + x) = static_cast<float16_t>(*(src_ptr + x));
- }
- },
- src, dst);
+ ARM_COMPUTE_ERROR_ON(uk->ukernel == nullptr);
+ uk->ukernel(_src, _dst, info, _policy, window);
break;
}
-#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
case DataType::F32:
{
/* Conversion S32 -> F32 */
@@ -1362,6 +1106,12 @@ const char *CpuCastKernel::name() const
{
return "CpuCastKernel.cpp";
}
+
+const std::vector<CpuCastKernel::CastKernel> &CpuCastKernel::get_available_kernels()
+{
+ return available_kernels;
+}
+
} // namespace kernels
} // namespace cpu
} // namespace arm_compute