diff options
Diffstat (limited to 'src/core/cpu/kernels/CpuAddKernel.cpp')
-rw-r--r-- | src/core/cpu/kernels/CpuAddKernel.cpp | 120 |
1 files changed, 81 insertions, 39 deletions
diff --git a/src/core/cpu/kernels/CpuAddKernel.cpp b/src/core/cpu/kernels/CpuAddKernel.cpp index 7afdceae38..8d74b4027b 100644 --- a/src/core/cpu/kernels/CpuAddKernel.cpp +++ b/src/core/cpu/kernels/CpuAddKernel.cpp @@ -45,9 +45,15 @@ namespace { struct AddSelectorData { - DataType dt1; - DataType dt2; - DataType dt3; + /* Data types for all ITensorInfos: + dt1 -> src0 + dt2 -> src1 + dt3 -> dst + */ + DataType dt1; + DataType dt2; + DataType dt3; + const CPUInfo &ci; }; using AddSelectorPtr = std::add_pointer<bool(const AddSelectorData &data)>::type; @@ -61,49 +67,99 @@ struct AddKernel static const AddKernel available_kernels[] = { -#if defined(ENABLE_SVE) +#if defined(ARM_COMPUTE_ENABLE_SVE2) + { + "add_qasymm8_sve", + [](const AddSelectorData & data) + { + return ((data.dt1 == data.dt2) && (data.dt1 == DataType::QASYMM8)) && data.ci.has_sve(); + }, + REGISTER_QASYMM8_SVE(arm_compute::cpu::add_qasymm8_sve) + }, + { + "add_qasymm8_signed_sve", + [](const AddSelectorData & data) + { + return ((data.dt1 == data.dt2) && (data.dt1 == DataType::QASYMM8_SIGNED)) && data.ci.has_sve(); + }, + REGISTER_QASYMM8_SIGNED_SVE(arm_compute::cpu::add_qasymm8_signed_sve) + }, + { + "add_qsymm16_sve", + [](const AddSelectorData & data) + { + return ((data.dt1 == data.dt2) && (data.dt1 == DataType::QSYMM16)) && data.ci.has_sve(); + }, + REGISTER_QSYMM16_SVE(arm_compute::cpu::add_qsymm16_sve) + }, +#endif /* !defined(ARM_COMPUTE_ENABLE_SVE2) */ +#if defined(ARM_COMPUTE_ENABLE_SVE) { "add_same_sve", - [](const AddSelectorData & data) { return ((data.dt1 == data.dt2) && (data.dt1 == DataType::F32)); }, + [](const AddSelectorData & data) + { + return ((data.dt1 == data.dt2) && (data.dt1 == DataType::F32)) && data.ci.has_sve(); + }, REGISTER_FP32_SVE(arm_compute::cpu::add_same_sve<float>) }, { "add_same_sve", - [](const AddSelectorData & data) { return ((data.dt1 == data.dt2) && (data.dt1 == DataType::F16)); }, + [](const AddSelectorData & data) + { + return ((data.dt1 == data.dt2) && (data.dt1 == DataType::F16)) && data.ci.has_sve(); + }, REGISTER_FP16_SVE(arm_compute::cpu::add_same_sve<float16_t>) }, { "add_same_sve", - [](const AddSelectorData & data) { return ((data.dt1 == data.dt2) && (data.dt1 == data.dt3) && (data.dt1 == DataType::U8)); }, + [](const AddSelectorData & data) + { + return ((data.dt1 == data.dt2) && (data.dt1 == data.dt3) && (data.dt1 == DataType::U8)) && data.ci.has_sve(); + }, REGISTER_INTEGER_SVE(arm_compute::cpu::add_same_sve<uint8_t>) }, { "add_same_sve", - [](const AddSelectorData & data) { return ((data.dt1 == data.dt2) && (data.dt1 == data.dt3) && (data.dt1 == DataType::S16)); }, + [](const AddSelectorData & data) + { + return ((data.dt1 == data.dt2) && (data.dt1 == data.dt3) && (data.dt1 == DataType::S16)) && data.ci.has_sve(); + }, REGISTER_INTEGER_SVE(arm_compute::cpu::add_same_sve<int16_t>) }, { "add_same_sve", - [](const AddSelectorData & data) { return ((data.dt1 == data.dt2) && (data.dt1 == data.dt3) && (data.dt1 == DataType::S32)); }, + [](const AddSelectorData & data) + { + return ((data.dt1 == data.dt2) && (data.dt1 == data.dt3) && (data.dt1 == DataType::S32)) && data.ci.has_sve(); + }, REGISTER_INTEGER_SVE(arm_compute::cpu::add_same_sve<int32_t>) }, { "add_u8_s16_s16_sve", - [](const AddSelectorData & data) { return ((data.dt1 == DataType::U8) && (data.dt2 == DataType::S16)); }, + [](const AddSelectorData & data) + { + return ((data.dt1 == DataType::U8) && (data.dt2 == DataType::S16)) && data.ci.has_sve(); + }, REGISTER_INTEGER_SVE(arm_compute::cpu::add_u8_s16_s16_sve) }, { "add_s16_u8_s16_sve", - [](const AddSelectorData & data) { return ((data.dt1 == DataType::S16) && (data.dt2 == DataType::U8)); }, + [](const AddSelectorData & data) + { + return ((data.dt1 == DataType::S16) && (data.dt2 == DataType::U8)) && data.ci.has_sve(); + }, REGISTER_INTEGER_SVE(arm_compute::cpu::add_s16_u8_s16_sve) }, { "add_u8_u8_s16_sve", - [](const AddSelectorData & data) { return ((data.dt1 == data.dt2) && (data.dt3 == DataType::S16)); }, + [](const AddSelectorData & data) + { + return ((data.dt1 == data.dt2) && (data.dt3 == DataType::S16)) && data.ci.has_sve(); + }, REGISTER_INTEGER_SVE(arm_compute::cpu::add_u8_u8_s16_sve) }, -#endif /* defined(ENABLE_SVE) */ -#if defined(ENABLE_NEON) +#endif /* defined(ARM_COMPUTE_ENABLE_SVE) */ +#if defined(ARM_COMPUTE_ENABLE_NEON) { "add_same_neon", [](const AddSelectorData & data) { return ((data.dt1 == data.dt2) && (data.dt1 == DataType::F32)); }, @@ -112,7 +168,10 @@ static const AddKernel available_kernels[] = #if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) { "add_same_neon", - [](const AddSelectorData & data) { return ((data.dt1 == data.dt2) && (data.dt1 == DataType::F16)); }, + [](const AddSelectorData & data) + { + return ((data.dt1 == data.dt2) && (data.dt1 == DataType::F16)) && data.ci.has_fp16(); + }, REGISTER_FP16_NEON(arm_compute::cpu::add_same_neon<float16_t>) }, #endif /* defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) */ @@ -146,24 +205,8 @@ static const AddKernel available_kernels[] = [](const AddSelectorData & data) { return ((data.dt1 == data.dt2) && (data.dt3 == DataType::S16)); }, REGISTER_INTEGER_NEON(arm_compute::cpu::add_u8_u8_s16_neon) }, -#endif /* defined(ENABLE_NEON) */ -#if defined(__ARM_FEATURE_SVE2) - { - "add_qasymm8_sve", - [](const AddSelectorData & data) { return ((data.dt1 == data.dt2) && (data.dt1 == DataType::QASYMM8)); }, - REGISTER_QASYMM8_SVE(arm_compute::cpu::add_qasymm8_sve) - }, - { - "add_qasymm8_signed_sve", - [](const AddSelectorData & data) { return ((data.dt1 == data.dt2) && (data.dt1 == DataType::QASYMM8_SIGNED)); }, - REGISTER_QASYMM8_SIGNED_SVE(arm_compute::cpu::add_qasymm8_signed_sve) - }, - { - "add_qsymm16_sve", - [](const AddSelectorData & data) { return ((data.dt1 == data.dt2) && (data.dt1 == DataType::QSYMM16)); }, - REGISTER_QSYMM16_SVE(arm_compute::cpu::add_qsymm16_sve) - }, -#else /* !defined(__ARM_FEATURE_SVE2) */ +#endif /* defined(ARM_COMPUTE_ENABLE_NEON) */ +#if defined(ARM_COMPUTE_ENABLE_NEON) || defined(ARM_COMPUTE_ENABLE_SVE) { "add_qasymm8_neon", [](const AddSelectorData & data) { return ((data.dt1 == data.dt2) && (data.dt1 == DataType::QASYMM8)); }, @@ -179,8 +222,7 @@ static const AddKernel available_kernels[] = [](const AddSelectorData & data) { return ((data.dt1 == data.dt2) && (data.dt1 == DataType::QSYMM16)); }, REGISTER_QSYMM16_NEON(arm_compute::cpu::add_qsymm16_neon) }, -#endif /* defined(ENABLE_NEON) */ - +#endif /* defined(ARM_COMPUTE_ENABLE_NEON) || defined(ARM_COMPUTE_ENABLE_SVE) */ }; /** Micro-kernel selector @@ -189,11 +231,11 @@ static const AddKernel available_kernels[] = * * @return A matching micro-kernel else nullptr */ -const AddKernel *get_implementation(DataType dt1, DataType dt2, DataType dt3) +const AddKernel *get_implementation(const CPUInfo &cpuinfo, DataType dt1, DataType dt2, DataType dt3) { for(const auto &uk : available_kernels) { - if(uk.is_selected({ dt1, dt2, dt3 })) + if(uk.is_selected({ dt1, dt2, dt3, cpuinfo })) { return &uk; } @@ -241,7 +283,7 @@ Status validate_arguments(const ITensorInfo &src0, const ITensorInfo &src1, cons "Wrong shape for dst"); } - const auto *uk = get_implementation(src0.data_type(), src1.data_type(), dst.data_type()); + const auto *uk = get_implementation(CPUInfo::get(), src0.data_type(), src1.data_type(), dst.data_type()); ARM_COMPUTE_RETURN_ERROR_ON(uk == nullptr || uk->ukernel == nullptr); return Status{}; @@ -327,7 +369,7 @@ void CpuAddKernel::run_op(ITensorPack &tensors, const Window &window, const Thre const ITensor *src1 = tensors.get_const_tensor(TensorType::ACL_SRC_1); ITensor *dst = tensors.get_tensor(TensorType::ACL_DST); - const auto *uk = get_implementation(src0->info()->data_type(), src1->info()->data_type(), dst->info()->data_type()); + const auto *uk = get_implementation(CPUInfo::get(), src0->info()->data_type(), src1->info()->data_type(), dst->info()->data_type()); ARM_COMPUTE_ERROR_ON(uk == nullptr || uk->ukernel == nullptr); uk->ukernel(src0, src1, dst, _policy, window); |