diff options
Diffstat (limited to 'src/cpu/kernels/CpuCastKernel.cpp')
-rw-r--r-- | src/cpu/kernels/CpuCastKernel.cpp | 43 |
1 files changed, 7 insertions, 36 deletions
diff --git a/src/cpu/kernels/CpuCastKernel.cpp b/src/cpu/kernels/CpuCastKernel.cpp index d478328d07..764a1ec71c 100644 --- a/src/cpu/kernels/CpuCastKernel.cpp +++ b/src/cpu/kernels/CpuCastKernel.cpp @@ -75,46 +75,34 @@ static const std::vector<CpuCastKernel::CastKernel> available_kernels = REGISTER_FP16_NEON(arm_compute::cpu::neon_fp32_to_fp16_cast) }, { - "neon_fp32_to_bf16_cast", - [](const CastDataTypeISASelectorData & data) { return data.src_dt == DataType::F32 && data.dst_dt == DataType::BFLOAT16 && data.isa.bf16; }, - REGISTER_BF16_NEON(arm_compute::cpu::neon_fp32_to_bfloat16_cast) - }, - { "neon_s32_cast", [](const CastDataTypeISASelectorData & data) { return data.src_dt == DataType::S32 && data.dst_dt == DataType::F16 && data.isa.fp16; }, REGISTER_FP16_NEON(arm_compute::cpu::neon_s32_to_fp16_cast) }, - { - "neon_bf16_cast", - [](const CastDataTypeISASelectorData & data) { return data.src_dt == DataType::BFLOAT16 && data.dst_dt == DataType::F32 && data.isa.bf16; }, - REGISTER_BF16_NEON(arm_compute::cpu::neon_bfloat16_to_fp32_cast) - }, }; Status validate_arguments(const ITensorInfo *src, const ITensorInfo *dst, ConvertPolicy policy) { ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(src); ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(dst); - ARM_COMPUTE_RETURN_ERROR_ON_CPU_BF16_UNSUPPORTED(src); - ARM_COMPUTE_RETURN_ERROR_ON_CPU_BF16_UNSUPPORTED(dst); ARM_COMPUTE_UNUSED(policy); ARM_COMPUTE_RETURN_ERROR_ON(src == dst); #ifdef __aarch64__ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(src, 1, DataType::QASYMM8_SIGNED, DataType::QASYMM8, DataType::U8, - DataType::S16, DataType::U16, DataType::BFLOAT16, DataType::F16, + DataType::S16, DataType::U16, DataType::F16, DataType::F32, DataType::S32, DataType::S64, DataType::U64); ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(dst, 1, DataType::QASYMM8_SIGNED, DataType::QASYMM8, DataType::U8, - DataType::S16, DataType::U16, DataType::BFLOAT16, DataType::F16, + DataType::S16, DataType::U16, DataType::F16, DataType::U32, DataType::S32, DataType::F32, DataType::S64); #else // __aarch64__ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(src, 1, DataType::QASYMM8_SIGNED, DataType::QASYMM8, DataType::U8, - DataType::S16, DataType::U16, DataType::BFLOAT16, DataType::F16, + DataType::S16, DataType::U16, DataType::F16, DataType::F32, DataType::S32); ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(dst, 1, DataType::QASYMM8_SIGNED, DataType::QASYMM8, DataType::U8, - DataType::S16, DataType::U16, DataType::BFLOAT16, DataType::F16, + DataType::S16, DataType::U16, DataType::F16, DataType::U32, DataType::S32, DataType::F32); #endif // __aarch64__ @@ -136,18 +124,15 @@ Status validate_arguments(const ITensorInfo *src, const ITensorInfo *dst, Conver ARM_COMPUTE_RETURN_ERROR_ON_MSG(src->data_type() == DataType::S16 && (dst->data_type() != DataType::QASYMM8_SIGNED && dst->data_type() != DataType::U8 && dst->data_type() != DataType::S32), "Only data_types supported [in] S16 -> [out] U8, S32"); - ARM_COMPUTE_RETURN_ERROR_ON_MSG(src->data_type() == DataType::BFLOAT16 && dst->data_type() != DataType::F32, - "Only data_types supported [in] BFLOAT16 -> [out] F32"); - ARM_COMPUTE_RETURN_ERROR_ON_MSG(src->data_type() == DataType::F16 && (dst->data_type() != DataType::QASYMM8_SIGNED && dst->data_type() != DataType::QASYMM8 && dst->data_type() != DataType::U8 && dst->data_type() != DataType::F32 && dst->data_type() != DataType::S32), "Only data_types supported [in] F16 -> [out] QASYMM8, F32, S32, U8"); ARM_COMPUTE_RETURN_ERROR_ON_MSG(src->data_type() == DataType::F32 && (dst->data_type() != DataType::QASYMM8_SIGNED && dst->data_type() != DataType::QASYMM8 - && dst->data_type() != DataType::F16 && dst->data_type() != DataType::BFLOAT16 + && dst->data_type() != DataType::F16 && dst->data_type() != DataType::S32 && dst->data_type() != DataType::U8), - "Only data_types supported [in] F32 -> [out] QASYMM8, BFLOAT16, F16, S32, U8"); + "Only data_types supported [in] F32 -> [out] QASYMM8, F16, S32, U8"); ARM_COMPUTE_RETURN_ERROR_ON_MSG(src->data_type() == DataType::S32 && (dst->data_type() != DataType::QASYMM8_SIGNED && dst->data_type() != DataType::QASYMM8 && dst->data_type() != DataType::F16 @@ -346,7 +331,7 @@ void CpuCastKernel::run_op(ITensorPack &tensors, const Window &window, const Thr Iterator src(_src, win); Iterator dst(_dst, win); - /*ukernel runs only when using fp16/bfloat16, so we validate it isn't a nullptr only before using it */ + /*ukernel runs only when using fp16, so we validate it isn't a nullptr only before using it */ const auto *uk = CpuCastKernel::get_implementation(CastDataTypeISASelectorData{ _src->info()->data_type(), _dst->info()->data_type(), CPUInfo::get().get_isa() }); switch(_src->info()->data_type()) @@ -948,13 +933,6 @@ void CpuCastKernel::run_op(ITensorPack &tensors, const Window &window, const Thr } break; } - case DataType::BFLOAT16: - { - /* Up-conversion BFLOAT16 -> F32 */ - ARM_COMPUTE_ERROR_ON(uk->ukernel == nullptr); - uk->ukernel(_src, _dst, info, _policy, window); - break; - } case DataType::F16: { /* conversion F16 -> any data type */ @@ -972,13 +950,6 @@ void CpuCastKernel::run_op(ITensorPack &tensors, const Window &window, const Thr uk->ukernel(_src, _dst, info, _policy, window); break; } - case DataType::BFLOAT16: - { - /* Down-conversion F32 -> BFLOAT16 */ - ARM_COMPUTE_ERROR_ON(uk->ukernel == nullptr); - uk->ukernel(_src, _dst, info, _policy, window); - break; - } case DataType::S32: { /* Conversion F32 -> S32 */ |