about | summary | refs | log | tree | commit | diff
path: root/src/cpu/kernels/CpuCastKernel.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/cpu/kernels/CpuCastKernel.cpp')
-rw-r--r-- src/cpu/kernels/CpuCastKernel.cpp 43
1 files changed, 7 insertions, 36 deletions
diff --git a/src/cpu/kernels/CpuCastKernel.cpp b/src/cpu/kernels/CpuCastKernel.cpp
index d478328d07..764a1ec71c 100644
--- a/src/cpu/kernels/CpuCastKernel.cpp
+++ b/src/cpu/kernels/CpuCastKernel.cpp
@@ -75,46 +75,34 @@ static const std::vector<CpuCastKernel::CastKernel> available_kernels =
REGISTER_FP16_NEON(arm_compute::cpu::neon_fp32_to_fp16_cast)
},
{
- "neon_fp32_to_bf16_cast",
- [](const CastDataTypeISASelectorData & data) { return data.src_dt == DataType::F32 && data.dst_dt == DataType::BFLOAT16 && data.isa.bf16; },
- REGISTER_BF16_NEON(arm_compute::cpu::neon_fp32_to_bfloat16_cast)
- },
- {
"neon_s32_cast",
[](const CastDataTypeISASelectorData & data) { return data.src_dt == DataType::S32 && data.dst_dt == DataType::F16 && data.isa.fp16; },
REGISTER_FP16_NEON(arm_compute::cpu::neon_s32_to_fp16_cast)
},
- {
- "neon_bf16_cast",
- [](const CastDataTypeISASelectorData & data) { return data.src_dt == DataType::BFLOAT16 && data.dst_dt == DataType::F32 && data.isa.bf16; },
- REGISTER_BF16_NEON(arm_compute::cpu::neon_bfloat16_to_fp32_cast)
- },
};
Status validate_arguments(const ITensorInfo *src, const ITensorInfo *dst, ConvertPolicy policy)
{
ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(src);
ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(dst);
- ARM_COMPUTE_RETURN_ERROR_ON_CPU_BF16_UNSUPPORTED(src);
- ARM_COMPUTE_RETURN_ERROR_ON_CPU_BF16_UNSUPPORTED(dst);
ARM_COMPUTE_UNUSED(policy);
ARM_COMPUTE_RETURN_ERROR_ON(src == dst);
#ifdef __aarch64__
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(src, 1, DataType::QASYMM8_SIGNED, DataType::QASYMM8, DataType::U8,
- DataType::S16, DataType::U16, DataType::BFLOAT16, DataType::F16,
+ DataType::S16, DataType::U16, DataType::F16,
DataType::F32, DataType::S32, DataType::S64, DataType::U64);
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(dst, 1, DataType::QASYMM8_SIGNED, DataType::QASYMM8, DataType::U8,
- DataType::S16, DataType::U16, DataType::BFLOAT16, DataType::F16,
+ DataType::S16, DataType::U16, DataType::F16,
DataType::U32, DataType::S32, DataType::F32, DataType::S64);
#else // __aarch64__
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(src, 1, DataType::QASYMM8_SIGNED, DataType::QASYMM8, DataType::U8,
- DataType::S16, DataType::U16, DataType::BFLOAT16, DataType::F16,
+ DataType::S16, DataType::U16, DataType::F16,
DataType::F32, DataType::S32);
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(dst, 1, DataType::QASYMM8_SIGNED, DataType::QASYMM8, DataType::U8,
- DataType::S16, DataType::U16, DataType::BFLOAT16, DataType::F16,
+ DataType::S16, DataType::U16, DataType::F16,
DataType::U32, DataType::S32, DataType::F32);
#endif // __aarch64__
@@ -136,18 +124,15 @@ Status validate_arguments(const ITensorInfo *src, const ITensorInfo *dst, Conver
ARM_COMPUTE_RETURN_ERROR_ON_MSG(src->data_type() == DataType::S16 && (dst->data_type() != DataType::QASYMM8_SIGNED && dst->data_type() != DataType::U8 && dst->data_type() != DataType::S32),
"Only data_types supported [in] S16 -> [out] U8, S32");
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(src->data_type() == DataType::BFLOAT16 && dst->data_type() != DataType::F32,
- "Only data_types supported [in] BFLOAT16 -> [out] F32");
-
ARM_COMPUTE_RETURN_ERROR_ON_MSG(src->data_type() == DataType::F16 && (dst->data_type() != DataType::QASYMM8_SIGNED && dst->data_type() != DataType::QASYMM8
&& dst->data_type() != DataType::U8
&& dst->data_type() != DataType::F32 && dst->data_type() != DataType::S32),
"Only data_types supported [in] F16 -> [out] QASYMM8, F32, S32, U8");
ARM_COMPUTE_RETURN_ERROR_ON_MSG(src->data_type() == DataType::F32 && (dst->data_type() != DataType::QASYMM8_SIGNED && dst->data_type() != DataType::QASYMM8
- && dst->data_type() != DataType::F16 && dst->data_type() != DataType::BFLOAT16
+ && dst->data_type() != DataType::F16
&& dst->data_type() != DataType::S32 && dst->data_type() != DataType::U8),
- "Only data_types supported [in] F32 -> [out] QASYMM8, BFLOAT16, F16, S32, U8");
+ "Only data_types supported [in] F32 -> [out] QASYMM8, F16, S32, U8");
ARM_COMPUTE_RETURN_ERROR_ON_MSG(src->data_type() == DataType::S32 && (dst->data_type() != DataType::QASYMM8_SIGNED && dst->data_type() != DataType::QASYMM8
&& dst->data_type() != DataType::F16
@@ -346,7 +331,7 @@ void CpuCastKernel::run_op(ITensorPack &tensors, const Window &window, const Thr
Iterator src(_src, win);
Iterator dst(_dst, win);
- /*ukernel runs only when using fp16/bfloat16, so we validate it isn't a nullptr only before using it */
+ /*ukernel runs only when using fp16, so we validate it isn't a nullptr only before using it */
const auto *uk = CpuCastKernel::get_implementation(CastDataTypeISASelectorData{ _src->info()->data_type(), _dst->info()->data_type(), CPUInfo::get().get_isa() });
switch(_src->info()->data_type())
@@ -948,13 +933,6 @@ void CpuCastKernel::run_op(ITensorPack &tensors, const Window &window, const Thr
}
break;
}
- case DataType::BFLOAT16:
- {
- /* Up-conversion BFLOAT16 -> F32 */
- ARM_COMPUTE_ERROR_ON(uk->ukernel == nullptr);
- uk->ukernel(_src, _dst, info, _policy, window);
- break;
- }
case DataType::F16:
{
/* conversion F16 -> any data type */
@@ -972,13 +950,6 @@ void CpuCastKernel::run_op(ITensorPack &tensors, const Window &window, const Thr
uk->ukernel(_src, _dst, info, _policy, window);
break;
}
- case DataType::BFLOAT16:
- {
- /* Down-conversion F32 -> BFLOAT16 */
- ARM_COMPUTE_ERROR_ON(uk->ukernel == nullptr);
- uk->ukernel(_src, _dst, info, _policy, window);
- break;
- }
case DataType::S32:
{
/* Conversion F32 -> S32 */