From 4a1c91767142f76e92bf4575564d7e54fcd0ebf4 Mon Sep 17 00:00:00 2001
From: Pablo Marquez Tello
Date: Tue, 18 Jul 2023 14:51:24 +0100
Subject: Add support for input S64/U64 in CpuCastKernel

* The kernel now supports the following conversions:
  S64 -> F32
  U64 -> F32
* Resolves MLCE-1089

Change-Id: I277cf58b78d919fde25947520d2056e1412c7f82
Signed-off-by: Pablo Marquez Tello
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/9935
Tested-by: Arm Jenkins
Reviewed-by: Viet-Hoa Do
Comments-Addressed: Arm Jenkins
Benchmark: Arm Jenkins
---
 src/cpu/kernels/CpuCastKernel.cpp | 153 +++++++++++++++++++++++++++++++++++++-
 src/cpu/kernels/CpuCastKernel.h   |   3 +-
 2 files changed, 153 insertions(+), 3 deletions(-)

(limited to 'src/cpu/kernels')

diff --git a/src/cpu/kernels/CpuCastKernel.cpp b/src/cpu/kernels/CpuCastKernel.cpp
index 15a9ddcab4..641dea40dc 100644
--- a/src/cpu/kernels/CpuCastKernel.cpp
+++ b/src/cpu/kernels/CpuCastKernel.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2016-2022 Arm Limited.
+ * Copyright (c) 2016-2023 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -99,9 +99,16 @@ Status validate_arguments(const ITensorInfo *src, const ITensorInfo *dst, Conver
     ARM_COMPUTE_RETURN_ERROR_ON_CPU_BF16_UNSUPPORTED(dst);
     ARM_COMPUTE_UNUSED(policy);
     ARM_COMPUTE_RETURN_ERROR_ON(src == dst);
+#ifdef __aarch64__
+    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(src, 1, DataType::QASYMM8_SIGNED, DataType::QASYMM8, DataType::U8,
+                                                         DataType::S16, DataType::U16, DataType::BFLOAT16, DataType::F16,
+                                                         DataType::F32, DataType::S32, DataType::S64, DataType::U64);
+#else // __aarch64__
     ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(src, 1, DataType::QASYMM8_SIGNED, DataType::QASYMM8, DataType::U8,
                                                          DataType::S16, DataType::U16, DataType::BFLOAT16, DataType::F16,
                                                          DataType::F32, DataType::S32);
+#endif // __aarch64__
+
     ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(dst, 1, DataType::QASYMM8_SIGNED, DataType::QASYMM8, DataType::U8,
                                                          DataType::S16, DataType::U16, DataType::BFLOAT16, DataType::F16,
                                                          DataType::U32, DataType::S32, DataType::F32);
@@ -141,6 +148,13 @@ Status validate_arguments(const ITensorInfo *src, const ITensorInfo *dst, Conver
                                     && dst->data_type() != DataType::F16
                                     && dst->data_type() != DataType::F32 && dst->data_type() != DataType::U8),
                                     "Only data_types supported [in] S32 -> [out] QASYMM8, F16, F32, U8");
+#ifdef __aarch64__
+    ARM_COMPUTE_RETURN_ERROR_ON_MSG(src->data_type() == DataType::S64 && dst->data_type() != DataType::F32,
+                                    "Only data_types supported [in] S64 -> [out] F32");
+
+    ARM_COMPUTE_RETURN_ERROR_ON_MSG(src->data_type() == DataType::U64 && dst->data_type() != DataType::F32,
+                                    "Only data_types supported [in] U64 -> [out] F32");
+#endif // __aarch64__
 
     // Validate in case of configured dst
     if(dst->total_size() > 0)
@@ -174,6 +188,111 @@ Status CpuCastKernel::validate(const ITensorInfo *src, const ITensorInfo *dst, C
     ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(src, dst, policy));
     return Status{};
 }
+#ifdef __aarch64__
+namespace
+{
+template <typename T1, typename T2>
+inline void internal_neon_convert(const T1 *src_ptr, T2 *dst_ptr)
+{
+    ARM_COMPUTE_UNUSED(src_ptr);
+    ARM_COMPUTE_UNUSED(dst_ptr);
+}
+
+template <>
+inline void internal_neon_convert<int64_t, float>(const int64_t *src_ptr, float *dst_ptr)
+{
+    const float64x2x4_t texels0 =
+    {
+        {
+            vcvtq_f64_s64(vld1q_s64(src_ptr)),
+            vcvtq_f64_s64(vld1q_s64(src_ptr + 2)),
+            vcvtq_f64_s64(vld1q_s64(src_ptr + 4)),
+            vcvtq_f64_s64(vld1q_s64(src_ptr + 6))
+        }
+    };
+    const float64x2x4_t texels1 =
+    {
+        {
+            vcvtq_f64_s64(vld1q_s64(src_ptr + 8)),
+            vcvtq_f64_s64(vld1q_s64(src_ptr + 10)),
+            vcvtq_f64_s64(vld1q_s64(src_ptr + 12)),
+            vcvtq_f64_s64(vld1q_s64(src_ptr + 14))
+        }
+    };
+    const float32x4x4_t texels =
+    {
+        {
+            vcombine_f32(vcvt_f32_f64(texels0.val[0]), vcvt_f32_f64(texels0.val[1])),
+            vcombine_f32(vcvt_f32_f64(texels0.val[2]), vcvt_f32_f64(texels0.val[3])),
+            vcombine_f32(vcvt_f32_f64(texels1.val[0]), vcvt_f32_f64(texels1.val[1])),
+            vcombine_f32(vcvt_f32_f64(texels1.val[2]), vcvt_f32_f64(texels1.val[3]))
+        }
+    };
+    vst1q_f32(dst_ptr, texels.val[0]);
+    vst1q_f32(dst_ptr + 4, texels.val[1]);
+    vst1q_f32(dst_ptr + 8, texels.val[2]);
+    vst1q_f32(dst_ptr + 12, texels.val[3]);
+}
+
+template <>
+inline void internal_neon_convert<uint64_t, float>(const uint64_t *src_ptr, float *dst_ptr)
+{
+    const float64x2x4_t texels0 =
+    {
+        {
+            vcvtq_f64_u64(vld1q_u64(src_ptr)),
+            vcvtq_f64_u64(vld1q_u64(src_ptr + 2)),
+            vcvtq_f64_u64(vld1q_u64(src_ptr + 4)),
+            vcvtq_f64_u64(vld1q_u64(src_ptr + 6))
+        }
+    };
+    const float64x2x4_t texels1 =
+    {
+        {
+            vcvtq_f64_u64(vld1q_u64(src_ptr + 8)),
+            vcvtq_f64_u64(vld1q_u64(src_ptr + 10)),
+            vcvtq_f64_u64(vld1q_u64(src_ptr + 12)),
+            vcvtq_f64_u64(vld1q_u64(src_ptr + 14))
+        }
+    };
+
+    const float32x4x4_t texels =
+    {
+        {
+            vcombine_f32(vcvt_f32_f64(texels0.val[0]), vcvt_f32_f64(texels0.val[1])),
+            vcombine_f32(vcvt_f32_f64(texels0.val[2]), vcvt_f32_f64(texels0.val[3])),
+            vcombine_f32(vcvt_f32_f64(texels1.val[0]), vcvt_f32_f64(texels1.val[1])),
+            vcombine_f32(vcvt_f32_f64(texels1.val[2]), vcvt_f32_f64(texels1.val[3]))
+        }
+    };
+
+    vst1q_f32(dst_ptr, texels.val[0]);
+    vst1q_f32(dst_ptr + 4, texels.val[1]);
+    vst1q_f32(dst_ptr + 8, texels.val[2]);
+    vst1q_f32(dst_ptr + 12, texels.val[3]);
+}
+
+template <typename T1, typename T2>
+inline void convert64(Iterator &src, Iterator &dst, const Window &win, int window_start_x, int window_end_x, int window_step_x)
+{
+    execute_window_loop(win, [&](const Coordinates &)
+    {
+        const auto src_ptr = reinterpret_cast<const T1 *>(src.ptr());
+        const auto dst_ptr = reinterpret_cast<T2 *>(dst.ptr());
+        int x = window_start_x;
+        for(; x <= (window_end_x - window_step_x); x += window_step_x)
+        {
+            internal_neon_convert<T1, T2>(src_ptr + x, dst_ptr + x);
+        }
+        for(; x < window_end_x; ++x)
+        {
+            *(dst_ptr + x) = static_cast<T2>(*(src_ptr + x));
+        }
+    },
+    src, dst);
+}
+} // namespace
+#endif // __aarch64__
 
 void CpuCastKernel::run_op(ITensorPack &tensors, const Window &window, const ThreadInfo &info)
 {
@@ -203,6 +322,37 @@ void CpuCastKernel::run_op(ITensorPack &tensors, const Window &window, const Thr
 
     switch(_src->info()->data_type())
     {
+#ifdef __aarch64__
+        case DataType::U64:
+        {
+            switch(_dst->info()->data_type())
+            {
+                case DataType::F32:
+                {
+                    convert64<uint64_t, float>(src, dst, win, window_start_x, window_end_x, window_step_x);
+                    break;
+                }
+                default:
+                    ARM_COMPUTE_ERROR("dst data type not supported");
+            }
+            break;
+        }
+        case DataType::S64:
+        {
+            switch(_dst->info()->data_type())
+            {
+                case DataType::F32:
+                {
+                    convert64<int64_t, float>(src, dst, win, window_start_x, window_end_x, window_step_x);
+                    break;
+                }
+                default:
+                    ARM_COMPUTE_ERROR("dst data type not supported");
+            }
+            break;
+        }
+#endif // __aarch64__
+
         case DataType::QASYMM8_SIGNED:
         {
             switch(_dst->info()->data_type())
@@ -909,7 +1059,6 @@ void CpuCastKernel::run_op(ITensorPack &tensors, const Window &window, const Thr
                     ARM_COMPUTE_ERROR("dst data type not supported");
             }
             break;
-
         case DataType::S32:
             switch(_dst->info()->data_type())
             {
diff --git a/src/cpu/kernels/CpuCastKernel.h b/src/cpu/kernels/CpuCastKernel.h
index de4ace2140..76237368d8 100644
--- a/src/cpu/kernels/CpuCastKernel.h
+++ b/src/cpu/kernels/CpuCastKernel.h
@@ -57,9 +57,10 @@ public:
      * - BFLOAT16 -> F32
      * - F16      -> QASYMM8_SIGNED, QASYMM8, F32, S32, U8
      * - S32      -> QASYMM8_SIGNED, QASYMM8, F16, F32, U8
+     * - S64      -> F32
      * - F32      -> QASYMM8_SIGNED, QASYMM8, BFLOAT16, F16, S32, U8
      *
-     * @param[in]  src    The src tensor to convert. Data types supported: QASYMM8_SIGNED/QASYMM8/U8/U16/S16/BFLOAT16/F16/F32.
+     * @param[in]  src    The src tensor to convert. Data types supported: QASYMM8_SIGNED/QASYMM8/U8/U16/S16/S32/S64/BFLOAT16/F16/F32.
      * @param[out] dst    The dst tensor. Data types supported: QASYMM8_SIGNED/QASYMM8/U8/U16/S16/U32/S32/BFLOAT16/F16/F32.
      * @param[in]  policy Conversion policy.
      *
-- 
cgit v1.2.1
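
For reference, a minimal sketch of how the new path can be exercised through the
kernel's static validate() documented above. This is illustrative and not part of
the patch: the tensor shape, the standalone main() and the internal include path
"src/cpu/kernels/CpuCastKernel.h" are assumptions, and the S64/U64 paths are only
compiled in on an __aarch64__ build.

#include "arm_compute/core/Error.h"
#include "arm_compute/core/TensorInfo.h"
#include "arm_compute/core/TensorShape.h"
#include "arm_compute/core/Types.h"
#include "src/cpu/kernels/CpuCastKernel.h"

using namespace arm_compute;

int main()
{
    // A 1-D tensor of 16 signed 64-bit values and a matching F32 destination.
    const TensorInfo src_info(TensorShape(16U), 1, DataType::S64);
    const TensorInfo dst_info(TensorShape(16U), 1, DataType::F32);

    // validate_arguments() in the patch rejects any S64/U64 source paired with
    // a destination other than F32; the ConvertPolicy is accepted but unused
    // on this path (see ARM_COMPUTE_UNUSED(policy) above).
    const Status status =
        cpu::kernels::CpuCastKernel::validate(&src_info, &dst_info, ConvertPolicy::SATURATE);

    return status.error_code() == ErrorCode::OK ? 0 : 1;
}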