aboutsummaryrefslogtreecommitdiff
path: root/src/cpu/kernels/CpuCastKernel.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/cpu/kernels/CpuCastKernel.cpp')
-rw-r--r-- src/cpu/kernels/CpuCastKernel.cpp 153
1 files changed, 151 insertions, 2 deletions
diff --git a/src/cpu/kernels/CpuCastKernel.cpp b/src/cpu/kernels/CpuCastKernel.cpp
index 15a9ddcab4..641dea40dc 100644
--- a/src/cpu/kernels/CpuCastKernel.cpp
+++ b/src/cpu/kernels/CpuCastKernel.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2016-2022 Arm Limited.
+ * Copyright (c) 2016-2023 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -99,9 +99,16 @@ Status validate_arguments(const ITensorInfo *src, const ITensorInfo *dst, Conver
ARM_COMPUTE_RETURN_ERROR_ON_CPU_BF16_UNSUPPORTED(dst);
ARM_COMPUTE_UNUSED(policy);
ARM_COMPUTE_RETURN_ERROR_ON(src == dst);
+#ifdef __aarch64__
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(src, 1, DataType::QASYMM8_SIGNED, DataType::QASYMM8, DataType::U8,
+ DataType::S16, DataType::U16, DataType::BFLOAT16, DataType::F16,
+ DataType::F32, DataType::S32, DataType::S64, DataType::U64);
+#else // __aarch64__
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(src, 1, DataType::QASYMM8_SIGNED, DataType::QASYMM8, DataType::U8,
DataType::S16, DataType::U16, DataType::BFLOAT16, DataType::F16,
DataType::F32, DataType::S32);
+#endif // __aarch64__
+
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(dst, 1, DataType::QASYMM8_SIGNED, DataType::QASYMM8, DataType::U8,
DataType::S16, DataType::U16, DataType::BFLOAT16, DataType::F16,
DataType::U32, DataType::S32, DataType::F32);
@@ -141,6 +148,13 @@ Status validate_arguments(const ITensorInfo *src, const ITensorInfo *dst, Conver
&& dst->data_type() != DataType::F16
&& dst->data_type() != DataType::F32 && dst->data_type() != DataType::U8),
"Only data_types supported [in] S32 -> [out] QASYMM8, F16, F32, U8");
+#ifdef __aarch64__
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(src->data_type() == DataType::S64 && dst->data_type() != DataType::F32,
+ "Only data_types supported [in] S64 -> [out] F32");
+
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(src->data_type() == DataType::U64 && dst->data_type() != DataType::F32,
+ "Only data_types supported [in] U64 -> [out] F32");
+#endif // __aarch64__
// Validate in case of configured dst
if(dst->total_size() > 0)
@@ -174,6 +188,111 @@ Status CpuCastKernel::validate(const ITensorInfo *src, const ITensorInfo *dst, C
ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(src, dst, policy));
return Status{};
}
+#ifdef __aarch64__
+namespace
+{
+/** Convert one vector-loop step of elements from @p T1 to @p T2 using Neon intrinsics.
+ *
+ * Only explicit specializations of this template provide an implementation. The primary
+ * template is deleted so that requesting a type pair without a specialization is rejected
+ * at compile time, rather than compiling to a no-op that leaves the destination unwritten.
+ *
+ * @param[in]  src_ptr Pointer to the first source element of the step.
+ * @param[out] dst_ptr Pointer to the first destination element of the step.
+ */
+template <typename T1, typename T2>
+inline void internal_neon_convert(const T1 *src_ptr, T2 *dst_ptr) = delete;
+
+/** Convert 16 consecutive signed 64-bit integers to single-precision floats.
+ *
+ * Every value is first converted to double precision and then narrowed to single precision.
+ * All 16 source elements are read before the first destination element is written.
+ *
+ * @param[in]  src_ptr Pointer to 16 readable int64_t elements.
+ * @param[out] dst_ptr Pointer to 16 writable float elements.
+ */
+template <>
+inline void internal_neon_convert<int64_t, float>(const int64_t *src_ptr, float *dst_ptr)
+{
+    constexpr int num_out_vectors = 4; // 4 output vectors x 4 float lanes = 16 elements
+    constexpr int lanes_per_out   = 4;
+
+    // Load and convert everything up front; stores happen only once all loads are done.
+    float32x4_t converted[num_out_vectors];
+    for(int v = 0; v < num_out_vectors; ++v)
+    {
+        const int64_t    *in   = src_ptr + v * lanes_per_out;
+        const float64x2_t low  = vcvtq_f64_s64(vld1q_s64(in));
+        const float64x2_t high = vcvtq_f64_s64(vld1q_s64(in + 2));
+        converted[v]           = vcombine_f32(vcvt_f32_f64(low), vcvt_f32_f64(high));
+    }
+
+    for(int v = 0; v < num_out_vectors; ++v)
+    {
+        vst1q_f32(dst_ptr + v * lanes_per_out, converted[v]);
+    }
+}
+
+/** Convert 16 consecutive unsigned 64-bit integers to single-precision floats.
+ *
+ * Every value is first converted to double precision and then narrowed to single precision.
+ * All 16 source elements are read before the first destination element is written.
+ *
+ * @param[in]  src_ptr Pointer to 16 readable uint64_t elements.
+ * @param[out] dst_ptr Pointer to 16 writable float elements.
+ */
+template <>
+inline void internal_neon_convert<uint64_t, float>(const uint64_t *src_ptr, float *dst_ptr)
+{
+    constexpr int num_out_vectors = 4; // 4 output vectors x 4 float lanes = 16 elements
+    constexpr int lanes_per_out   = 4;
+
+    // Load and convert everything up front; stores happen only once all loads are done.
+    float32x4_t converted[num_out_vectors];
+    for(int v = 0; v < num_out_vectors; ++v)
+    {
+        const uint64_t   *in   = src_ptr + v * lanes_per_out;
+        const float64x2_t low  = vcvtq_f64_u64(vld1q_u64(in));
+        const float64x2_t high = vcvtq_f64_u64(vld1q_u64(in + 2));
+        converted[v]           = vcombine_f32(vcvt_f32_f64(low), vcvt_f32_f64(high));
+    }
+
+    for(int v = 0; v < num_out_vectors; ++v)
+    {
+        vst1q_f32(dst_ptr + v * lanes_per_out, converted[v]);
+    }
+}
+
+/** Cast the elements of a window from @p T1 to @p T2.
+ *
+ * For every position visited by execute_window_loop, whole steps of @p window_step_x elements
+ * are handed to internal_neon_convert<T1, T2>; the elements left over before @p window_end_x
+ * are cast one at a time with static_cast.
+ *
+ * @note NOTE(review): the visible internal_neon_convert specializations always process 16
+ *       elements per call, so @p window_step_x is presumably expected to be 16 - confirm
+ *       against the caller.
+ * @note NOTE(review): the visible specializations convert through double precision while the
+ *       leftover path casts directly, so the two paths may round very large magnitudes
+ *       differently - confirm whether that matters.
+ *
+ * @param[in]     src            Iterator over the source tensor.
+ * @param[in,out] dst            Iterator over the destination tensor.
+ * @param[in]     win            Window passed to execute_window_loop.
+ * @param[in]     window_start_x Index of the first element to process along x.
+ * @param[in]     window_end_x   Index one past the last element to process along x.
+ * @param[in]     window_step_x  Number of elements consumed per vector step.
+ */
+template <typename T1, typename T2>
+inline void convert64(Iterator &src, Iterator &dst, const Window &win, int window_start_x, int window_end_x, int window_step_x)
+{
+    const auto process_row = [&](const Coordinates &)
+    {
+        const T1 *const in  = reinterpret_cast<const T1 *>(src.ptr());
+        T2 *const       out = reinterpret_cast<T2 *>(dst.ptr());
+
+        // Vector part: advance in whole steps while a full step still fits.
+        int col = window_start_x;
+        while(col <= (window_end_x - window_step_x))
+        {
+            internal_neon_convert<T1, T2>(in + col, out + col);
+            col += window_step_x;
+        }
+
+        // Leftover part: scalar cast for the remaining elements.
+        while(col < window_end_x)
+        {
+            out[col] = static_cast<T2>(in[col]);
+            ++col;
+        }
+    };
+
+    execute_window_loop(win, process_row, src, dst);
+}
+} // namespace
+#endif // __aarch64__
void CpuCastKernel::run_op(ITensorPack &tensors, const Window &window, const ThreadInfo &info)
{
@@ -203,6 +322,37 @@ void CpuCastKernel::run_op(ITensorPack &tensors, const Window &window, const Thr
switch(_src->info()->data_type())
{
+#ifdef __aarch64__
+ case DataType::U64:
+ {
+ switch(_dst->info()->data_type())
+ {
+ case DataType::F32:
+ {
+ convert64<uint64_t, float>(src, dst, win, window_start_x, window_end_x, window_step_x);
+ break;
+ }
+ default:
+ ARM_COMPUTE_ERROR("dst data type not supported");
+ }
+ break;
+ }
+ case DataType::S64:
+ {
+ switch(_dst->info()->data_type())
+ {
+ case DataType::F32:
+ {
+ convert64<int64_t, float>(src, dst, win, window_start_x, window_end_x, window_step_x);
+ break;
+ }
+ default:
+ ARM_COMPUTE_ERROR("dst data type not supported");
+ }
+ break;
+ }
+#endif // __aarch64__
+
case DataType::QASYMM8_SIGNED:
{
switch(_dst->info()->data_type())
@@ -909,7 +1059,6 @@ void CpuCastKernel::run_op(ITensorPack &tensors, const Window &window, const Thr
ARM_COMPUTE_ERROR("dst data type not supported");
}
break;
-
case DataType::S32:
switch(_dst->info()->data_type())
{