From da816752cad76c8e1b367e8e9c648994a1af599a Mon Sep 17 00:00:00 2001
From: Georgios Pinitas
Date: Fri, 2 Jul 2021 09:22:14 +0100
Subject: Remove redundant implementations of Add/Sub operators

Allows only implementations where inputs/output are of the same data type
and removes legacy Computer Vision ones.

Signed-off-by: Georgios Pinitas
Change-Id: Ia2b3d23a04236aab682f0c36a1110a30f7c06d1c
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/5900
Tested-by: Arm Jenkins
Reviewed-by: Michele Di Giorgio
Comments-Addressed: Arm Jenkins
---
 src/core/cpu/kernels/CpuAddKernel.cpp | 141 ++++++----------------------------
 1 file changed, 25 insertions(+), 116 deletions(-)

(limited to 'src/core/cpu/kernels/CpuAddKernel.cpp')

diff --git a/src/core/cpu/kernels/CpuAddKernel.cpp b/src/core/cpu/kernels/CpuAddKernel.cpp
index 12766037a7..61b7b19443 100644
--- a/src/core/cpu/kernels/CpuAddKernel.cpp
+++ b/src/core/cpu/kernels/CpuAddKernel.cpp
@@ -45,14 +45,7 @@ namespace
 {
 struct AddSelectorData
 {
-    /* Data types for all ITensorInfos:
-       dt1 -> src0
-       dt2 -> src1
-       dt3 -> dst
-    */
-    DataType       dt1;
-    DataType       dt2;
-    DataType       dt3;
+    DataType       dt;
     const CPUInfo &ci;
 };

@@ -72,7 +65,7 @@ static const AddKernel available_kernels[] =
         "sve2_qu8_add",
         [](const AddSelectorData & data)
         {
-            return ((data.dt1 == data.dt2) && (data.dt1 == DataType::QASYMM8)) && data.ci.has_sve();
+            return (data.dt == DataType::QASYMM8) && data.ci.has_sve();
         },
         REGISTER_QASYMM8_SVE(arm_compute::cpu::add_qasymm8_sve)
     },
@@ -80,7 +73,7 @@
         "sve2_qs8_add",
         [](const AddSelectorData & data)
         {
-            return ((data.dt1 == data.dt2) && (data.dt1 == DataType::QASYMM8_SIGNED)) && data.ci.has_sve();
+            return (data.dt == DataType::QASYMM8_SIGNED) && data.ci.has_sve();
         },
         REGISTER_QASYMM8_SIGNED_SVE(arm_compute::cpu::add_qasymm8_signed_sve)
     },
@@ -88,7 +81,7 @@
         "sve2_qs16_add",
         [](const AddSelectorData & data)
         {
-            return ((data.dt1 == data.dt2) && (data.dt1 == DataType::QSYMM16)) && data.ci.has_sve();
+            return (data.dt == DataType::QSYMM16) && data.ci.has_sve();
         },
         REGISTER_QSYMM16_SVE(arm_compute::cpu::add_qsymm16_sve)
     },
@@ -98,7 +91,7 @@ {
         "sve_fp32_add",
         [](const AddSelectorData & data)
         {
-            return ((data.dt1 == data.dt2) && (data.dt1 == DataType::F32)) && data.ci.has_sve();
+            return (data.dt == DataType::F32) && data.ci.has_sve();
         },
         REGISTER_FP32_SVE(arm_compute::cpu::add_same_sve)
     },
@@ -106,7 +99,7 @@ {
         "sve_fp16_add",
         [](const AddSelectorData & data)
         {
-            return ((data.dt1 == data.dt2) && (data.dt1 == DataType::F16)) && data.ci.has_sve();
+            return (data.dt == DataType::F16) && data.ci.has_sve();
         },
         REGISTER_FP16_SVE(arm_compute::cpu::add_same_sve)
     },
@@ -114,7 +107,7 @@ {
         "sve_u8_add",
         [](const AddSelectorData & data)
         {
-            return ((data.dt1 == data.dt2) && (data.dt1 == data.dt3) && (data.dt1 == DataType::U8)) && data.ci.has_sve();
+            return (data.dt == DataType::U8) && data.ci.has_sve();
         },
         REGISTER_INTEGER_SVE(arm_compute::cpu::add_same_sve)
     },
@@ -122,7 +115,7 @@ {
         "sve_s16_add",
         [](const AddSelectorData & data)
         {
-            return ((data.dt1 == data.dt2) && (data.dt1 == data.dt3) && (data.dt1 == DataType::S16)) && data.ci.has_sve();
+            return (data.dt == DataType::S16) && data.ci.has_sve();
         },
         REGISTER_INTEGER_SVE(arm_compute::cpu::add_same_sve)
     },
@@ -130,39 +123,15 @@ static const AddKernel available_kernels[] =
         "sve_s32_add",
         [](const AddSelectorData & data)
         {
-            return ((data.dt1 == data.dt2) && (data.dt1 == data.dt3) && (data.dt1 == DataType::S32)) && data.ci.has_sve();
+            return (data.dt == DataType::S32) && data.ci.has_sve();
         },
         REGISTER_INTEGER_SVE(arm_compute::cpu::add_same_sve)
     },
-    {
-        "sve_u8_s16_s16_add",
-        [](const AddSelectorData & data)
-        {
-            return ((data.dt1 == DataType::U8) && (data.dt2 == DataType::S16)) && data.ci.has_sve();
-        },
-        REGISTER_INTEGER_SVE(arm_compute::cpu::add_u8_s16_s16_sve)
-    },
-    {
-        "sve_s16_u8_s16_add",
-        [](const AddSelectorData & data)
-        {
-            return ((data.dt1 == DataType::S16) && (data.dt2 == DataType::U8)) && data.ci.has_sve();
-        },
-        REGISTER_INTEGER_SVE(arm_compute::cpu::add_s16_u8_s16_sve)
-    },
-    {
-        "sve_u8_u8_s16_add",
-        [](const AddSelectorData & data)
-        {
-            return ((data.dt1 == data.dt2) && (data.dt3 == DataType::S16)) && data.ci.has_sve();
-        },
-        REGISTER_INTEGER_SVE(arm_compute::cpu::add_u8_u8_s16_sve)
-    },
 #endif /* defined(ARM_COMPUTE_ENABLE_SVE) */
 #if defined(ARM_COMPUTE_ENABLE_NEON)
     {
         "neon_fp32_add",
-        [](const AddSelectorData & data) { return ((data.dt1 == data.dt2) && (data.dt1 == DataType::F32)); },
+        [](const AddSelectorData & data) { return (data.dt == DataType::F32); },
         REGISTER_FP32_NEON(arm_compute::cpu::add_same_neon)
     },
 #if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
@@ -170,56 +139,41 @@
         "neon_fp16_add",
         [](const AddSelectorData & data)
         {
-            return ((data.dt1 == data.dt2) && (data.dt1 == DataType::F16)) && data.ci.has_fp16();
+            return (data.dt == DataType::F16) && data.ci.has_fp16();
         },
         REGISTER_FP16_NEON(arm_compute::cpu::add_same_neon)
     },
 #endif /* defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) */
     {
         "neon_u8_add",
-        [](const AddSelectorData & data) { return ((data.dt1 == data.dt2) && (data.dt1 == data.dt3) && (data.dt1 == DataType::U8)); },
+        [](const AddSelectorData & data) { return (data.dt == DataType::U8); },
         REGISTER_INTEGER_NEON(arm_compute::cpu::add_same_neon)
     },
     {
         "neon_s16_add",
-        [](const AddSelectorData & data) { return ((data.dt1 == data.dt2) && (data.dt1 == data.dt3) && (data.dt1 == DataType::S16)); },
+        [](const AddSelectorData & data) { return (data.dt == DataType::S16); },
         REGISTER_INTEGER_NEON(arm_compute::cpu::add_same_neon)
     },
     {
         "neon_s32_add",
-        [](const AddSelectorData & data) { return ((data.dt1 == data.dt2) && (data.dt1 == data.dt3) && (data.dt1 == DataType::S32)); },
+        [](const AddSelectorData & data) { return (data.dt == DataType::S32); },
         REGISTER_INTEGER_NEON(arm_compute::cpu::add_same_neon)
     },
-    {
-        "neon_u8_s16_s16_add",
-        [](const AddSelectorData & data) { return ((data.dt1 == DataType::U8) && (data.dt2 == DataType::S16)); },
-        REGISTER_INTEGER_NEON(arm_compute::cpu::add_u8_s16_s16_neon)
-    },
-    {
-        "neon_s16_u8_s16_add",
-        [](const AddSelectorData & data) { return ((data.dt1 == DataType::S16) && (data.dt2 == DataType::U8)); },
-        REGISTER_INTEGER_NEON(arm_compute::cpu::add_s16_u8_s16_neon)
-    },
-    {
-        "neon_u8_u8_s16_add",
-        [](const AddSelectorData & data) { return ((data.dt1 == data.dt2) && (data.dt3 == DataType::S16)); },
-        REGISTER_INTEGER_NEON(arm_compute::cpu::add_u8_u8_s16_neon)
-    },
 #endif /* defined(ARM_COMPUTE_ENABLE_NEON) */
 #if defined(ARM_COMPUTE_ENABLE_NEON) || defined(ARM_COMPUTE_ENABLE_SVE)
     {
         "neon_qu8_add",
-        [](const AddSelectorData & data) { return ((data.dt1 == data.dt2) && (data.dt1 == DataType::QASYMM8)); },
+        [](const AddSelectorData & data) { return (data.dt == DataType::QASYMM8); },
         REGISTER_QASYMM8_NEON(arm_compute::cpu::add_qasymm8_neon)
     },
     {
         "neon_qs8_add",
-        [](const AddSelectorData & data) { return ((data.dt1 == data.dt2) && (data.dt1 == DataType::QASYMM8_SIGNED)); },
+        [](const AddSelectorData & data) { return (data.dt == DataType::QASYMM8_SIGNED); },
         REGISTER_QASYMM8_SIGNED_NEON(arm_compute::cpu::add_qasymm8_signed_neon)
     },
     {
         "neon_qs16_add",
-        [](const AddSelectorData & data) { return ((data.dt1 == data.dt2) && (data.dt1 == DataType::QSYMM16)); },
+        [](const AddSelectorData & data) { return (data.dt == DataType::QSYMM16); },
         REGISTER_QSYMM16_NEON(arm_compute::cpu::add_qsymm16_neon)
     },
 #endif /* defined(ARM_COMPUTE_ENABLE_NEON) || defined(ARM_COMPUTE_ENABLE_SVE) */
@@ -231,11 +185,11 @@ static const AddKernel available_kernels[] =
  *
  * @return A matching micro-kernel else nullptr
  */
-const AddKernel *get_implementation(const CPUInfo &cpuinfo, DataType dt1, DataType dt2, DataType dt3)
+const AddKernel *get_implementation(const CPUInfo &cpuinfo, DataType dt)
 {
     for(const auto &uk : available_kernels)
     {
-        if(uk.is_selected({ dt1, dt2, dt3, cpuinfo }))
+        if(uk.is_selected({ dt, cpuinfo }))
         {
             return &uk;
         }
@@ -251,9 +205,7 @@ Status validate_arguments(const ITensorInfo &src0, const ITensorInfo &src1, cons
     ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&src0, 1, DataType::U8, DataType::QASYMM8, DataType::QASYMM8_SIGNED,
                                                          DataType::S16, DataType::QSYMM16, DataType::F16,
                                                          DataType::S32, DataType::F32);
-    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&src1, 1, DataType::U8, DataType::QASYMM8, DataType::QASYMM8_SIGNED,
-                                                         DataType::S16, DataType::QSYMM16, DataType::F16,
-                                                         DataType::S32, DataType::F32);
+    ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(&src0, &src1);

     const TensorShape out_shape = TensorShape::broadcast_shape(src0.tensor_shape(), src1.tensor_shape());

@@ -265,25 +217,12 @@ Status validate_arguments(const ITensorInfo &src0, const ITensorInfo &src1, cons
     // Validate in case of configured dst
     if(dst.total_size() > 0)
     {
-        ARM_COMPUTE_RETURN_ERROR_ON_MSG(
-            !(src0.data_type() == DataType::U8 && src1.data_type() == DataType::U8 && dst.data_type() == DataType::U8)
-            && !(src0.data_type() == DataType::U8 && src1.data_type() == DataType::U8 && dst.data_type() == DataType::S16)
-            && !(src0.data_type() == DataType::U8 && src1.data_type() == DataType::S16 && dst.data_type() == DataType::S16)
-            && !(src0.data_type() == DataType::S16 && src1.data_type() == DataType::U8 && dst.data_type() == DataType::S16)
-            && !(src0.data_type() == DataType::S16 && src1.data_type() == DataType::S16 && dst.data_type() == DataType::S16)
-            && !(src0.data_type() == DataType::S32 && src1.data_type() == DataType::S32 && dst.data_type() == DataType::S32)
-            && !(src0.data_type() == DataType::F32 && src1.data_type() == DataType::F32 && dst.data_type() == DataType::F32)
-            && !(src0.data_type() == DataType::F16 && src1.data_type() == DataType::F16 && dst.data_type() == DataType::F16)
-            && !(src0.data_type() == DataType::QASYMM8 && src1.data_type() == DataType::QASYMM8 && dst.data_type() == DataType::QASYMM8)
-            && !(src0.data_type() == DataType::QASYMM8_SIGNED && src1.data_type() == DataType::QASYMM8_SIGNED && dst.data_type() == DataType::QASYMM8_SIGNED)
-            && !(src0.data_type() == DataType::QSYMM16 && src1.data_type() == DataType::QSYMM16 && dst.data_type() == DataType::QSYMM16),
-            "You called addition with the wrong image formats");
-
+        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(&src0, &dst);
         ARM_COMPUTE_RETURN_ERROR_ON_MSG(detail::have_different_dimensions(out_shape,
                                         dst.tensor_shape(), 0), "Wrong shape for dst");
     }

-    const auto *uk = get_implementation(CPUInfo::get(), src0.data_type(), src1.data_type(), dst.data_type());
+    const auto *uk = get_implementation(CPUInfo::get(), src0.data_type());
     ARM_COMPUTE_RETURN_ERROR_ON(uk == nullptr || uk->ukernel == nullptr);

     return Status{};
@@ -294,38 +233,8 @@ std::pair<Status, Window> validate_and_configure_window(const ITensorInfo &src0,
     const TensorShape &out_shape = TensorShape::broadcast_shape(src0.tensor_shape(), src1.tensor_shape());

     // Auto initialize dst if not initialized
-    {
-        set_shape_if_empty(dst, out_shape);
-
-        if(src0.data_type() == DataType::S16 || src1.data_type() == DataType::S16)
-        {
-            set_format_if_unknown(dst, Format::S16);
-        }
-        if(src0.data_type() == DataType::S32 || src1.data_type() == DataType::S32)
-        {
-            set_format_if_unknown(dst, Format::S32);
-        }
-        else if(src0.data_type() == DataType::F16 || src1.data_type() == DataType::F16)
-        {
-            set_format_if_unknown(dst, Format::F16);
-        }
-        else if(src0.data_type() == DataType::F32 || src1.data_type() == DataType::F32)
-        {
-            set_format_if_unknown(dst, Format::F32);
-        }
-        else if(src0.data_type() == DataType::QASYMM8 || src1.data_type() == DataType::QASYMM8)
-        {
-            set_data_type_if_unknown(dst, DataType::QASYMM8);
-        }
-        else if(src0.data_type() == DataType::QASYMM8_SIGNED || src1.data_type() == DataType::QASYMM8_SIGNED)
-        {
-            set_data_type_if_unknown(dst, DataType::QASYMM8_SIGNED);
-        }
-        else if(src0.data_type() == DataType::QSYMM16 || src1.data_type() == DataType::QSYMM16)
-        {
-            set_data_type_if_unknown(dst, DataType::QSYMM16);
-        }
-    }
+    set_shape_if_empty(dst, out_shape);
+    set_data_type_if_unknown(dst, src0.data_type());

     Window win = calculate_max_window(out_shape, Steps());

@@ -339,7 +248,7 @@ void CpuAddKernel::configure(const ITensorInfo *src0, const ITensorInfo *src1, I
     ARM_COMPUTE_ERROR_ON_NULLPTR(src0, src1, dst);
     ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(*src0, *src1, *dst, policy));

-    const auto uk = get_implementation(CPUInfo::get(), src0->data_type(), src1->data_type(), dst->data_type());
+    const auto uk = get_implementation(CPUInfo::get(), src0->data_type());
    ARM_COMPUTE_ERROR_ON_NULLPTR(uk);

     _policy = policy;
--
cgit v1.2.1
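
Below is a small, self-contained sketch of the selection scheme this patch converges on. It is not code from the library: DataType, CpuFlags, AddKernelFn and the stub kernels are simplified stand-ins for arm_compute's DataType, CPUInfo, the micro-kernel function type and the REGISTER_* macros. It only illustrates the idea that, once src0, src1 and dst are required to share one data type, each table entry needs a single data-type check plus a CPU-feature check, and get_implementation() returns the first entry whose predicate matches.

#include <cstdio>

// Simplified stand-ins, not arm_compute's real types.
enum class DataType
{
    U8,
    S16,
    S32,
    F16,
    F32,
    QASYMM8
};

struct CpuFlags // stand-in for CPUInfo
{
    bool has_sve;
};

struct AddSelectorData // after the patch: one data type shared by src0, src1 and dst
{
    DataType        dt;
    const CpuFlags &ci;
};

using AddKernelFn = void (*)(); // placeholder for the real micro-kernel signature

struct AddKernel
{
    const char *name;
    bool (*is_selected)(const AddSelectorData &);
    AddKernelFn ukernel;
};

void add_fp32_stub() { std::puts("running fp32 add"); }
void add_u8_stub()   { std::puts("running u8 add"); }

// Priority-ordered table: the first entry whose predicate accepts the data wins.
static const AddKernel available_kernels[] =
{
    { "sve_fp32_add", [](const AddSelectorData &d) { return (d.dt == DataType::F32) && d.ci.has_sve; }, add_fp32_stub },
    { "neon_fp32_add", [](const AddSelectorData &d) { return d.dt == DataType::F32; }, add_fp32_stub },
    { "neon_u8_add", [](const AddSelectorData &d) { return d.dt == DataType::U8; }, add_u8_stub },
};

const AddKernel *get_implementation(const CpuFlags &ci, DataType dt)
{
    for(const auto &uk : available_kernels)
    {
        if(uk.is_selected({ dt, ci }))
        {
            return &uk;
        }
    }
    return nullptr; // the caller reports an error when nothing matches
}

int main()
{
    const CpuFlags ci{ false }; // pretend SVE is unavailable on this target
    if(const AddKernel *uk = get_implementation(ci, DataType::F32))
    {
        std::printf("selected: %s\n", uk->name); // falls through to "neon_fp32_add"
        uk->ukernel();
    }
    return 0;
}

Built as a plain C++11 (or later) program, this prints "selected: neon_fp32_add" because the SVE flag is false, mirroring how the real selector walks the priority-ordered table until a predicate matches.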