diff options
Diffstat (limited to 'src/core/cpu/kernels/CpuSubKernel.cpp')
-rw-r--r-- | src/core/cpu/kernels/CpuSubKernel.cpp | 87 |
1 files changed, 19 insertions, 68 deletions
diff --git a/src/core/cpu/kernels/CpuSubKernel.cpp b/src/core/cpu/kernels/CpuSubKernel.cpp index 098a324377..fa7a55805e 100644 --- a/src/core/cpu/kernels/CpuSubKernel.cpp +++ b/src/core/cpu/kernels/CpuSubKernel.cpp @@ -41,9 +41,7 @@ namespace { struct SubSelectorData { - DataType dt1; - DataType dt2; - DataType dt3; + DataType dt; }; using SubSelectorPtr = std::add_pointer<bool(const SubSelectorData &data)>::type; @@ -60,59 +58,44 @@ static const SubKernel available_kernels[] = { { "neon_fp32_sub", - [](const SubSelectorData & data) { return ((data.dt1 == data.dt2) && (data.dt1 == DataType::F32)); }, + [](const SubSelectorData & data) { return (data.dt == DataType::F32); }, REGISTER_FP32_NEON(arm_compute::cpu::sub_same_neon<float>) }, #if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) { "neon_fp16_sub", - [](const SubSelectorData & data) { return ((data.dt1 == data.dt2) && (data.dt1 == DataType::F16)); }, + [](const SubSelectorData & data) { return (data.dt == DataType::F16); }, REGISTER_FP16_NEON(arm_compute::cpu::sub_same_neon<float16_t>) }, #endif /* defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) */ { "neon_u8_sub", - [](const SubSelectorData & data) { return ((data.dt1 == data.dt2) && (data.dt1 == data.dt3) && (data.dt1 == DataType::U8)); }, + [](const SubSelectorData & data) { return (data.dt == DataType::U8); }, REGISTER_INTEGER_NEON(arm_compute::cpu::sub_same_neon<uint8_t>) }, { "neon_s16_sub", - [](const SubSelectorData & data) { return ((data.dt1 == data.dt2) && (data.dt1 == data.dt3) && (data.dt1 == DataType::S16)); }, + [](const SubSelectorData & data) { return (data.dt == DataType::S16); }, REGISTER_INTEGER_NEON(arm_compute::cpu::sub_same_neon<int16_t>) }, { "neon_s32_sub", - [](const SubSelectorData & data) { return ((data.dt1 == data.dt2) && (data.dt1 == data.dt3) && (data.dt1 == DataType::S32)); }, + [](const SubSelectorData & data) { return (data.dt == DataType::S32); }, REGISTER_INTEGER_NEON(arm_compute::cpu::sub_same_neon<int32_t>) }, { - "neon_u8_s16_s16_sub", - [](const SubSelectorData & data) { return ((data.dt1 == DataType::U8) && (data.dt2 == DataType::S16)); }, - REGISTER_INTEGER_NEON(arm_compute::cpu::sub_u8_s16_s16_neon) - }, - { - "neon_s16_u8_s16_sub", - [](const SubSelectorData & data) { return ((data.dt1 == DataType::S16) && (data.dt2 == DataType::U8)); }, - REGISTER_INTEGER_NEON(arm_compute::cpu::sub_s16_u8_s16_neon) - }, - { - "neon_u8_u8_s16_sub", - [](const SubSelectorData & data) { return ((data.dt1 == data.dt2) && (data.dt3 == DataType::S16)); }, - REGISTER_INTEGER_NEON(arm_compute::cpu::sub_u8_u8_s16_neon) - }, - { "neon_qu8_sub", - [](const SubSelectorData & data) { return ((data.dt1 == data.dt2) && (data.dt1 == DataType::QASYMM8)); }, + [](const SubSelectorData & data) { return (data.dt == DataType::QASYMM8); }, REGISTER_QASYMM8_NEON(arm_compute::cpu::sub_qasymm8_neon) }, { "neon_qs8_sub", - [](const SubSelectorData & data) { return ((data.dt1 == data.dt2) && (data.dt1 == DataType::QASYMM8_SIGNED)); }, + [](const SubSelectorData & data) { return (data.dt == DataType::QASYMM8_SIGNED); }, REGISTER_QASYMM8_SIGNED_NEON(arm_compute::cpu::sub_qasymm8_signed_neon) }, { - "neon_s16_sub", - [](const SubSelectorData & data) { return ((data.dt1 == data.dt2) && (data.dt1 == DataType::QSYMM16)); }, + "neon_qs16_sub", + [](const SubSelectorData & data) { return (data.dt == DataType::QSYMM16); }, REGISTER_QSYMM16_NEON(arm_compute::cpu::sub_qsymm16_neon) }, }; @@ -123,11 +106,11 @@ static const SubKernel available_kernels[] = * * @return A matching micro-kernel else nullptr */ -const SubKernel *get_implementation(DataType dt1, DataType dt2, DataType dt3) +const SubKernel *get_implementation(DataType dt) { for(const auto &uk : available_kernels) { - if(uk.is_selected({ dt1, dt2, dt3 })) + if(uk.is_selected({ dt })) { return &uk; } @@ -141,54 +124,21 @@ inline Status validate_arguments(const ITensorInfo &src0, const ITensorInfo &src ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(&src0); ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&src0, 1, DataType::U8, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::QSYMM16, DataType::S16, DataType::S32, DataType::F16, DataType::F32); - ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&src1, 1, DataType::U8, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::QSYMM16, DataType::S16, DataType::S32, DataType::F16, - DataType::F32); - ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&dst, 1, DataType::U8, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::QSYMM16, DataType::S16, DataType::S32, DataType::F16, - DataType::F32); + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(&src0, &src1); - const auto *uk = get_implementation(src0.data_type(), src1.data_type(), dst.data_type()); + const auto *uk = get_implementation(src0.data_type()); ARM_COMPUTE_RETURN_ERROR_ON(uk == nullptr || uk->ukernel == nullptr); const TensorShape out_shape = TensorShape::broadcast_shape(src0.tensor_shape(), src1.tensor_shape()); ARM_COMPUTE_RETURN_ERROR_ON_MSG(out_shape.total_size() == 0, "Inputs are not broadcast compatible"); - ARM_COMPUTE_RETURN_ERROR_ON_MSG( - !(src0.data_type() == DataType::U8 && src1.data_type() == DataType::U8) - && !(src0.data_type() == DataType::QASYMM8 && src1.data_type() == DataType::QASYMM8) - && !(src0.data_type() == DataType::QASYMM8_SIGNED && src1.data_type() == DataType::QASYMM8_SIGNED) - && !(src0.data_type() == DataType::QSYMM16 && src1.data_type() == DataType::QSYMM16) - && !(src0.data_type() == DataType::U8 && src1.data_type() == DataType::U8) - && !(src0.data_type() == DataType::U8 && src1.data_type() == DataType::S16) - && !(src0.data_type() == DataType::S16 && src1.data_type() == DataType::U8) - && !(src0.data_type() == DataType::S16 && src1.data_type() == DataType::S16) - && !(src0.data_type() == DataType::S32 && src1.data_type() == DataType::S32) - && !(src0.data_type() == DataType::F32 && src1.data_type() == DataType::F32) - && !(src0.data_type() == DataType::F16 && src1.data_type() == DataType::F16), - "You called subtract with the wrong image formats"); - - ARM_COMPUTE_RETURN_ERROR_ON_MSG( - (src0.data_type() == DataType::QASYMM8_SIGNED && src1.data_type() == DataType::QASYMM8_SIGNED && policy == ConvertPolicy::WRAP) - || (src0.data_type() == DataType::QASYMM8 && src1.data_type() == DataType::QASYMM8 && policy == ConvertPolicy::WRAP) - || (src0.data_type() == DataType::QSYMM16 && src1.data_type() == DataType::QSYMM16 && policy == ConvertPolicy::WRAP), - "Convert policy cannot be WRAP if datatype is quantized"); + ARM_COMPUTE_RETURN_ERROR_ON_MSG(is_data_type_quantized(src0.data_type()) && (policy == ConvertPolicy::WRAP), + "Convert policy cannot be WRAP if datatype is quantized"); // Validate in case of configured dst if(dst.total_size() > 0) { - ARM_COMPUTE_RETURN_ERROR_ON_MSG( - !(src0.data_type() == DataType::U8 && src1.data_type() == DataType::U8 && dst.data_type() == DataType::U8) - && !(src0.data_type() == DataType::QASYMM8 && src1.data_type() == DataType::QASYMM8 && dst.data_type() == DataType::QASYMM8) - && !(src0.data_type() == DataType::QASYMM8_SIGNED && src1.data_type() == DataType::QASYMM8_SIGNED && dst.data_type() == DataType::QASYMM8_SIGNED) - && !(src0.data_type() == DataType::QSYMM16 && src1.data_type() == DataType::QSYMM16 && dst.data_type() == DataType::QSYMM16) - && !(src0.data_type() == DataType::U8 && src1.data_type() == DataType::U8 && dst.data_type() == DataType::S16) - && !(src0.data_type() == DataType::U8 && src1.data_type() == DataType::S16 && dst.data_type() == DataType::S16) - && !(src0.data_type() == DataType::S16 && src1.data_type() == DataType::U8 && dst.data_type() == DataType::S16) - && !(src0.data_type() == DataType::S16 && src1.data_type() == DataType::S16 && dst.data_type() == DataType::S16) - && !(src0.data_type() == DataType::S32 && src1.data_type() == DataType::S32 && dst.data_type() == DataType::S32) - && !(src0.data_type() == DataType::F32 && src1.data_type() == DataType::F32 && dst.data_type() == DataType::F32) - && !(src0.data_type() == DataType::F16 && src1.data_type() == DataType::F16 && dst.data_type() == DataType::F16), - "You called subtract with the wrong image formats"); - + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(&src0, &dst); ARM_COMPUTE_RETURN_ERROR_ON_MSG(detail::have_different_dimensions(out_shape, dst.tensor_shape(), 0), "Wrong shape for dst"); } @@ -205,8 +155,9 @@ void CpuSubKernel::configure(const ITensorInfo *src0, const ITensorInfo *src1, I // Auto initialize dst if not initialized set_shape_if_empty(*dst, out_shape); + set_data_type_if_unknown(*dst, src0->data_type()); - const auto *uk = get_implementation(src0->data_type(), src1->data_type(), dst->data_type()); + const auto *uk = get_implementation(src0->data_type()); ARM_COMPUTE_ERROR_ON_NULLPTR(uk); _policy = policy; |