aboutsummaryrefslogtreecommitdiff
path: root/src/core/cpu/kernels/CpuAddKernel.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/cpu/kernels/CpuAddKernel.cpp')
-rw-r--r--src/core/cpu/kernels/CpuAddKernel.cpp120
1 files changed, 81 insertions, 39 deletions
diff --git a/src/core/cpu/kernels/CpuAddKernel.cpp b/src/core/cpu/kernels/CpuAddKernel.cpp
index 7afdceae38..8d74b4027b 100644
--- a/src/core/cpu/kernels/CpuAddKernel.cpp
+++ b/src/core/cpu/kernels/CpuAddKernel.cpp
@@ -45,9 +45,15 @@ namespace
{
struct AddSelectorData
{
- DataType dt1;
- DataType dt2;
- DataType dt3;
+ /* Data types for all ITensorInfos:
+ dt1 -> src0
+ dt2 -> src1
+ dt3 -> dst
+ */
+ DataType dt1;
+ DataType dt2;
+ DataType dt3;
+ const CPUInfo &ci;
};
using AddSelectorPtr = std::add_pointer<bool(const AddSelectorData &data)>::type;
@@ -61,49 +67,99 @@ struct AddKernel
static const AddKernel available_kernels[] =
{
-#if defined(ENABLE_SVE)
+#if defined(ARM_COMPUTE_ENABLE_SVE2)
+ {
+ "add_qasymm8_sve",
+ [](const AddSelectorData & data)
+ {
+ return ((data.dt1 == data.dt2) && (data.dt1 == DataType::QASYMM8)) && data.ci.has_sve();
+ },
+ REGISTER_QASYMM8_SVE(arm_compute::cpu::add_qasymm8_sve)
+ },
+ {
+ "add_qasymm8_signed_sve",
+ [](const AddSelectorData & data)
+ {
+ return ((data.dt1 == data.dt2) && (data.dt1 == DataType::QASYMM8_SIGNED)) && data.ci.has_sve();
+ },
+ REGISTER_QASYMM8_SIGNED_SVE(arm_compute::cpu::add_qasymm8_signed_sve)
+ },
+ {
+ "add_qsymm16_sve",
+ [](const AddSelectorData & data)
+ {
+ return ((data.dt1 == data.dt2) && (data.dt1 == DataType::QSYMM16)) && data.ci.has_sve();
+ },
+ REGISTER_QSYMM16_SVE(arm_compute::cpu::add_qsymm16_sve)
+ },
+#endif /* !defined(ARM_COMPUTE_ENABLE_SVE2) */
+#if defined(ARM_COMPUTE_ENABLE_SVE)
{
"add_same_sve",
- [](const AddSelectorData & data) { return ((data.dt1 == data.dt2) && (data.dt1 == DataType::F32)); },
+ [](const AddSelectorData & data)
+ {
+ return ((data.dt1 == data.dt2) && (data.dt1 == DataType::F32)) && data.ci.has_sve();
+ },
REGISTER_FP32_SVE(arm_compute::cpu::add_same_sve<float>)
},
{
"add_same_sve",
- [](const AddSelectorData & data) { return ((data.dt1 == data.dt2) && (data.dt1 == DataType::F16)); },
+ [](const AddSelectorData & data)
+ {
+ return ((data.dt1 == data.dt2) && (data.dt1 == DataType::F16)) && data.ci.has_sve();
+ },
REGISTER_FP16_SVE(arm_compute::cpu::add_same_sve<float16_t>)
},
{
"add_same_sve",
- [](const AddSelectorData & data) { return ((data.dt1 == data.dt2) && (data.dt1 == data.dt3) && (data.dt1 == DataType::U8)); },
+ [](const AddSelectorData & data)
+ {
+ return ((data.dt1 == data.dt2) && (data.dt1 == data.dt3) && (data.dt1 == DataType::U8)) && data.ci.has_sve();
+ },
REGISTER_INTEGER_SVE(arm_compute::cpu::add_same_sve<uint8_t>)
},
{
"add_same_sve",
- [](const AddSelectorData & data) { return ((data.dt1 == data.dt2) && (data.dt1 == data.dt3) && (data.dt1 == DataType::S16)); },
+ [](const AddSelectorData & data)
+ {
+ return ((data.dt1 == data.dt2) && (data.dt1 == data.dt3) && (data.dt1 == DataType::S16)) && data.ci.has_sve();
+ },
REGISTER_INTEGER_SVE(arm_compute::cpu::add_same_sve<int16_t>)
},
{
"add_same_sve",
- [](const AddSelectorData & data) { return ((data.dt1 == data.dt2) && (data.dt1 == data.dt3) && (data.dt1 == DataType::S32)); },
+ [](const AddSelectorData & data)
+ {
+ return ((data.dt1 == data.dt2) && (data.dt1 == data.dt3) && (data.dt1 == DataType::S32)) && data.ci.has_sve();
+ },
REGISTER_INTEGER_SVE(arm_compute::cpu::add_same_sve<int32_t>)
},
{
"add_u8_s16_s16_sve",
- [](const AddSelectorData & data) { return ((data.dt1 == DataType::U8) && (data.dt2 == DataType::S16)); },
+ [](const AddSelectorData & data)
+ {
+ return ((data.dt1 == DataType::U8) && (data.dt2 == DataType::S16)) && data.ci.has_sve();
+ },
REGISTER_INTEGER_SVE(arm_compute::cpu::add_u8_s16_s16_sve)
},
{
"add_s16_u8_s16_sve",
- [](const AddSelectorData & data) { return ((data.dt1 == DataType::S16) && (data.dt2 == DataType::U8)); },
+ [](const AddSelectorData & data)
+ {
+ return ((data.dt1 == DataType::S16) && (data.dt2 == DataType::U8)) && data.ci.has_sve();
+ },
REGISTER_INTEGER_SVE(arm_compute::cpu::add_s16_u8_s16_sve)
},
{
"add_u8_u8_s16_sve",
- [](const AddSelectorData & data) { return ((data.dt1 == data.dt2) && (data.dt3 == DataType::S16)); },
+ [](const AddSelectorData & data)
+ {
+ return ((data.dt1 == data.dt2) && (data.dt3 == DataType::S16)) && data.ci.has_sve();
+ },
REGISTER_INTEGER_SVE(arm_compute::cpu::add_u8_u8_s16_sve)
},
-#endif /* defined(ENABLE_SVE) */
-#if defined(ENABLE_NEON)
+#endif /* defined(ARM_COMPUTE_ENABLE_SVE) */
+#if defined(ARM_COMPUTE_ENABLE_NEON)
{
"add_same_neon",
[](const AddSelectorData & data) { return ((data.dt1 == data.dt2) && (data.dt1 == DataType::F32)); },
@@ -112,7 +168,10 @@ static const AddKernel available_kernels[] =
#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
{
"add_same_neon",
- [](const AddSelectorData & data) { return ((data.dt1 == data.dt2) && (data.dt1 == DataType::F16)); },
+ [](const AddSelectorData & data)
+ {
+ return ((data.dt1 == data.dt2) && (data.dt1 == DataType::F16)) && data.ci.has_fp16();
+ },
REGISTER_FP16_NEON(arm_compute::cpu::add_same_neon<float16_t>)
},
#endif /* defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) */
@@ -146,24 +205,8 @@ static const AddKernel available_kernels[] =
[](const AddSelectorData & data) { return ((data.dt1 == data.dt2) && (data.dt3 == DataType::S16)); },
REGISTER_INTEGER_NEON(arm_compute::cpu::add_u8_u8_s16_neon)
},
-#endif /* defined(ENABLE_NEON) */
-#if defined(__ARM_FEATURE_SVE2)
- {
- "add_qasymm8_sve",
- [](const AddSelectorData & data) { return ((data.dt1 == data.dt2) && (data.dt1 == DataType::QASYMM8)); },
- REGISTER_QASYMM8_SVE(arm_compute::cpu::add_qasymm8_sve)
- },
- {
- "add_qasymm8_signed_sve",
- [](const AddSelectorData & data) { return ((data.dt1 == data.dt2) && (data.dt1 == DataType::QASYMM8_SIGNED)); },
- REGISTER_QASYMM8_SIGNED_SVE(arm_compute::cpu::add_qasymm8_signed_sve)
- },
- {
- "add_qsymm16_sve",
- [](const AddSelectorData & data) { return ((data.dt1 == data.dt2) && (data.dt1 == DataType::QSYMM16)); },
- REGISTER_QSYMM16_SVE(arm_compute::cpu::add_qsymm16_sve)
- },
-#else /* !defined(__ARM_FEATURE_SVE2) */
+#endif /* defined(ARM_COMPUTE_ENABLE_NEON) */
+#if defined(ARM_COMPUTE_ENABLE_NEON) || defined(ARM_COMPUTE_ENABLE_SVE)
{
"add_qasymm8_neon",
[](const AddSelectorData & data) { return ((data.dt1 == data.dt2) && (data.dt1 == DataType::QASYMM8)); },
@@ -179,8 +222,7 @@ static const AddKernel available_kernels[] =
[](const AddSelectorData & data) { return ((data.dt1 == data.dt2) && (data.dt1 == DataType::QSYMM16)); },
REGISTER_QSYMM16_NEON(arm_compute::cpu::add_qsymm16_neon)
},
-#endif /* defined(ENABLE_NEON) */
-
+#endif /* defined(ARM_COMPUTE_ENABLE_NEON) || defined(ARM_COMPUTE_ENABLE_SVE) */
};
/** Micro-kernel selector
@@ -189,11 +231,11 @@ static const AddKernel available_kernels[] =
*
* @return A matching micro-kernel else nullptr
*/
-const AddKernel *get_implementation(DataType dt1, DataType dt2, DataType dt3)
+const AddKernel *get_implementation(const CPUInfo &cpuinfo, DataType dt1, DataType dt2, DataType dt3)
{
for(const auto &uk : available_kernels)
{
- if(uk.is_selected({ dt1, dt2, dt3 }))
+ if(uk.is_selected({ dt1, dt2, dt3, cpuinfo }))
{
return &uk;
}
@@ -241,7 +283,7 @@ Status validate_arguments(const ITensorInfo &src0, const ITensorInfo &src1, cons
"Wrong shape for dst");
}
- const auto *uk = get_implementation(src0.data_type(), src1.data_type(), dst.data_type());
+ const auto *uk = get_implementation(CPUInfo::get(), src0.data_type(), src1.data_type(), dst.data_type());
ARM_COMPUTE_RETURN_ERROR_ON(uk == nullptr || uk->ukernel == nullptr);
return Status{};
@@ -327,7 +369,7 @@ void CpuAddKernel::run_op(ITensorPack &tensors, const Window &window, const Thre
const ITensor *src1 = tensors.get_const_tensor(TensorType::ACL_SRC_1);
ITensor *dst = tensors.get_tensor(TensorType::ACL_DST);
- const auto *uk = get_implementation(src0->info()->data_type(), src1->info()->data_type(), dst->info()->data_type());
+ const auto *uk = get_implementation(CPUInfo::get(), src0->info()->data_type(), src1->info()->data_type(), dst->info()->data_type());
ARM_COMPUTE_ERROR_ON(uk == nullptr || uk->ukernel == nullptr);
uk->ukernel(src0, src1, dst, _policy, window);