author     Dana Zlotnik <dana.zlotnik@arm.com>   2022-01-17 09:54:26 +0200
committer  Dana Zlotnik <dana.zlotnik@arm.com>   2022-02-14 12:49:53 +0000
commit     6a2df886f32dcf7af4258163b0652f0fab07ecc5 (patch)
tree       4ad16670d54d29de96df7cc5b582d52a6012255a
parent     69854ba71f91f86c2a1c8a2301e91dcd93030561 (diff)
download   ComputeLibrary-6a2df886f32dcf7af4258163b0652f0fab07ecc5.tar.gz
Add kernel selection UT for submitted kernels
* Softmax kernel
* Elementwise unary kernel
* Elementwise binary kernel

This change requires some refactoring in the kernel .cpp and .h files.

Resolves COMPMID-5043

Change-Id: I58979b023ec31d759690847b3f85fc4baefbbf98
Signed-off-by: Dana Zlotnik <dana.zlotnik@arm.com>
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/7033
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Giorgio Arena <giorgio.arena@arm.com>
-rw-r--r--  filelist.json                                          27
-rw-r--r--  src/cpu/kernels/CpuElementwiseKernel.cpp              431
-rw-r--r--  src/cpu/kernels/CpuElementwiseKernel.h                 51
-rw-r--r--  src/cpu/kernels/CpuElementwiseUnaryKernel.cpp          30
-rw-r--r--  src/cpu/kernels/CpuKernelSelectionTypes.h              12
-rw-r--r--  src/cpu/kernels/CpuSoftmaxKernel.cpp                  139
-rw-r--r--  src/cpu/kernels/CpuSoftmaxKernel.h                     10
-rw-r--r--  tests/validation/NEON/ElementwiseKernelSelection.cpp  158
-rw-r--r--  tests/validation/NEON/SoftmaxLayer.cpp                 73
9 files changed, 620 insertions, 311 deletions
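For context before the diffs: the new unit tests resolve micro-kernels through the static selection tables this patch introduces. Below is a minimal sketch of that query path, mirroring what the added UTs do; it assumes the types added in this patch (ElementwiseDataTypeISASelectorData, CpuArithmeticKernel::get_implementation), and the main() scaffolding is illustrative, not part of the change.

// Illustrative sketch only -- mirrors the query made by the new UTs.
#include "arm_compute/core/Types.h"
#include "src/common/cpuinfo/CpuIsaInfo.h"
#include "src/cpu/kernels/CpuElementwiseKernel.h"

#include <cstdio>

int main()
{
    using namespace arm_compute;
    using namespace arm_compute::cpu::kernels;

    // Pretend the target core reports NEON, SVE and SVE2 support.
    cpuinfo::CpuIsaInfo isa{};
    isa.neon = true;
    isa.sve  = true;
    isa.sve2 = true;

    // Ask the arithmetic selection table which kernel would run QASYMM8 ADD.
    const auto *impl = CpuArithmeticKernel::get_implementation(
        ElementwiseDataTypeISASelectorData{ DataType::QASYMM8, isa, static_cast<int>(ArithmeticOperation::ADD) });

    // SVE2 entries come first in the table, so "sve2_qu8_arithmetic" is expected here.
    if(impl != nullptr)
    {
        std::printf("selected kernel: %s\n", impl->name);
    }
    return 0;
}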
diff --git a/filelist.json b/filelist.json
index ba19321a50..bba3d568a6 100644
--- a/filelist.json
+++ b/filelist.json
@@ -1262,20 +1262,23 @@
"common": [
"src/cpu/operators/CpuElementwise.cpp",
"src/cpu/kernels/CpuElementwiseKernel.cpp",
- "src/runtime/NEON/functions/NEElementwiseOperations.cpp",
- "src/cpu/kernels/elementwise_binary/generic/neon/qasymm8.cpp",
- "src/cpu/kernels/elementwise_binary/generic/neon/qasymm8_signed.cpp"
+ "src/runtime/NEON/functions/NEElementwiseOperations.cpp"
],
"neon":{
"fp32": ["src/cpu/kernels/elementwise_binary/generic/neon/fp32.cpp"],
"fp16": ["src/cpu/kernels/elementwise_binary/generic/neon/fp16.cpp"],
- "integer": ["src/cpu/kernels/elementwise_binary/generic/neon/integer.cpp"]
+ "integer": ["src/cpu/kernels/elementwise_binary/generic/neon/integer.cpp"],
+ "qasymm8": ["src/cpu/kernels/elementwise_binary/generic/neon/qasymm8.cpp"],
+ "qasymm8_signed": ["src/cpu/kernels/elementwise_binary/generic/neon/qasymm8_signed.cpp"]
},
"sve": {
"common": ["src/cpu/kernels/elementwise_binary/generic/sve/impl.cpp" ],
"integer": ["src/cpu/kernels/elementwise_binary/generic/sve/integer.cpp"],
"fp32": ["src/cpu/kernels/elementwise_binary/generic/sve/fp32.cpp"],
- "fp16": ["src/cpu/kernels/elementwise_binary/generic/sve/fp16.cpp"],
+ "fp16": ["src/cpu/kernels/elementwise_binary/generic/sve/fp16.cpp"]
+
+ },
+ "sve2":{
"qasymm8": ["src/cpu/kernels/elementwise_binary/generic/sve2/qasymm8.cpp"],
"qasymm8_signed": ["src/cpu/kernels/elementwise_binary/generic/sve2/qasymm8_signed.cpp"]
}
@@ -1899,16 +1902,20 @@
],
"neon":{
"fp32": ["src/cpu/kernels/softmax/generic/neon/fp32.cpp"],
- "fp16": ["src/cpu/kernels/softmax/generic/neon/fp16.cpp"],
- "qasymm8": ["src/cpu/kernels/softmax/generic/neon/qasymm8.cpp"],
- "qasymm8_signed": ["src/cpu/kernels/softmax/generic/neon/qasymm8_signed.cpp"]
+ "fp16": ["src/cpu/kernels/softmax/generic/neon/fp16.cpp"],
+ "qasymm8":[ "src/cpu/kernels/softmax/generic/neon/qasymm8.cpp"],
+ "qasymm8_signed":["src/cpu/kernels/softmax/generic/neon/qasymm8_signed.cpp"]
},
"sve": {
"common": [ "src/cpu/kernels/softmax/generic/sve/impl.cpp" ],
"fp32": ["src/cpu/kernels/softmax/generic/sve/fp32.cpp"],
"fp16": ["src/cpu/kernels/softmax/generic/sve/fp16.cpp"],
- "qasymm8": ["src/cpu/kernels/softmax/generic/sve/qasymm8.cpp" ,"src/cpu/kernels/softmax/generic/sve2/qasymm8.cpp" ],
- "qasymm8_signed": ["src/cpu/kernels/softmax/generic/sve/qasymm8_signed.cpp", "src/cpu/kernels/softmax/generic/sve2/qasymm8_signed.cpp"]
+ "qasymm8": ["src/cpu/kernels/softmax/generic/sve/qasymm8.cpp" ],
+ "qasymm8_signed": ["src/cpu/kernels/softmax/generic/sve/qasymm8_signed.cpp"]
+ },
+ "sve2":{
+ "qasymm8":[ "src/cpu/kernels/softmax/generic/sve2/qasymm8.cpp"],
+ "qasymm8_signed":["src/cpu/kernels/softmax/generic/sve2/qasymm8_signed.cpp"]
}
}
},
diff --git a/src/cpu/kernels/CpuElementwiseKernel.cpp b/src/cpu/kernels/CpuElementwiseKernel.cpp
index 53179ae95f..4b285fc2be 100644
--- a/src/cpu/kernels/CpuElementwiseKernel.cpp
+++ b/src/cpu/kernels/CpuElementwiseKernel.cpp
@@ -40,214 +40,255 @@ namespace kernels
{
namespace
{
-struct ElementwiseSelectorData
+template <ArithmeticOperation op>
+const std::vector<CpuElementwiseKernel<CpuArithmeticKernel>::ElementwiseKernel> available_kernels_arithmetic =
{
- DataType dt;
- const CPUInfo &ci;
-};
-
-using ElementwiseSelector = std::add_pointer<bool(const ElementwiseSelectorData &)>::type;
-using UKernelType = CpuElementwiseKernel::ElementwiseFunction;
-struct ElementwiseKernel
-{
- const char *name;
- const ElementwiseSelector is_selected;
- UKernelType *ukernel;
-};
-
-template <ArithmeticOperation op>
-CpuElementwiseKernel::UKernelInfo configure_arithm_func(const ITensorInfo *src0, const ITensorInfo *src1, ITensorInfo *dst)
-{
- ARM_COMPUTE_UNUSED(src1, dst);
- static ElementwiseKernel kernels[] =
{
-#if defined(ARM_COMPUTE_ENABLE_SVE)
- {
- "sve_fp32_elementwise",
- [](const ElementwiseSelectorData & data) { return data.dt == DataType::F32 && data.ci.has_sve(); },
- REGISTER_FP32_SVE((arm_compute::cpu::sve_fp32_elementwise_binary<op>))
- },
+ "sve2_qu8_arithmetic",
+ [](const ElementwiseDataTypeISASelectorData & data)
{
- "sve_s32_elementwise",
- [](const ElementwiseSelectorData & data) { return data.dt == DataType::S32 && data.ci.has_sve(); },
- REGISTER_INTEGER_SVE((arm_compute::cpu::sve_s32_elementwise_binary<op>))
+ return data.dt == DataType::QASYMM8 && data.isa.sve2 && static_cast<ArithmeticOperation>(data.op) == op;
},
+ REGISTER_QASYMM8_SVE2(sve2_qasymm8_elementwise_binary<op>)
+ },
+ {
+ "sve2_qs8_arithmetic",
+ [](const ElementwiseDataTypeISASelectorData & data)
{
- "sve_s16_elementwise",
- [](const ElementwiseSelectorData & data) { return data.dt == DataType::S16 && data.ci.has_sve(); },
- REGISTER_INTEGER_SVE((arm_compute::cpu::sve_s16_elementwise_binary<op>))
+ return data.dt == DataType::QASYMM8_SIGNED && data.isa.sve2 && static_cast<ArithmeticOperation>(data.op) == op;
},
+ REGISTER_QASYMM8_SIGNED_SVE2(sve2_qasymm8_signed_elementwise_binary<op>)
+ },
+ {
+ "sve_fp32_arithmetic",
+ [](const ElementwiseDataTypeISASelectorData & data)
{
- "sve_fp16_elementwise",
- [](const ElementwiseSelectorData & data) { return data.dt == DataType::F16 && data.ci.has_sve(); },
- REGISTER_FP16_SVE((arm_compute::cpu::sve_fp16_elementwise_binary<op>))
+ return data.dt == DataType::F32 && data.isa.sve && static_cast<ArithmeticOperation>(data.op) == op;
},
-#endif /* defined(ARM_COMPUTE_ENABLE_SVE) */
-#if defined(ARM_COMPUTE_ENABLE_NEON)
+ REGISTER_FP32_SVE(sve_fp32_elementwise_binary<op>)
+ },
+ {
+ "sve_s32_arithmetic",
+ [](const ElementwiseDataTypeISASelectorData & data)
{
- "neon_fp32_elementwise",
-
- [](const ElementwiseSelectorData & data) { return data.dt == DataType::F32; },
- REGISTER_FP32_NEON((arm_compute::cpu::neon_fp32_elementwise_binary<op>))
+ return data.dt == DataType::S32 && data.isa.sve && static_cast<ArithmeticOperation>(data.op) == op;
},
+ REGISTER_INTEGER_SVE(sve_s32_elementwise_binary<op>)
+ },
+ {
+ "sve_s16_arithmetic",
+ [](const ElementwiseDataTypeISASelectorData & data)
{
- "neon_s32_elementwise",
- [](const ElementwiseSelectorData & data) { return data.dt == DataType::S32; },
- REGISTER_INTEGER_NEON((arm_compute::cpu::neon_s32_elementwise_binary<op>))
+ return data.dt == DataType::S16 && data.isa.sve && static_cast<ArithmeticOperation>(data.op) == op;
},
-#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
+ REGISTER_INTEGER_SVE(sve_s16_elementwise_binary<op>)
+ },
+ {
+ "sve_fp16_arithmetic",
+ [](const ElementwiseDataTypeISASelectorData & data)
{
- "neon_fp16_elementwise",
- [](const ElementwiseSelectorData & data) { return data.dt == DataType::F16 && data.ci.has_fp16(); },
- REGISTER_FP16_NEON((arm_compute::cpu::neon_fp16_elementwise_binary<op>))
+ return data.dt == DataType::F16 && data.isa.sve && data.isa.fp16 && static_cast<ArithmeticOperation>(data.op) == op;
},
-#endif /* defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) */
+ REGISTER_FP16_SVE(sve_fp16_elementwise_binary<op>)
+ },
+ {
+ "neon_fp32_arithmetic",
+
+ [](const ElementwiseDataTypeISASelectorData & data)
{
- "neon_s16_elementwise",
- [](const ElementwiseSelectorData & data) { return data.dt == DataType::S16; },
- REGISTER_INTEGER_NEON((arm_compute::cpu::neon_s16_elementwise_binary<op>))
+ return data.dt == DataType::F32 && static_cast<ArithmeticOperation>(data.op) == op;
},
-#endif /* defined(ARM_COMPUTE_ENABLE_NEON) */
-#if defined(ARM_COMPUTE_ENABLE_SVE2)
+ REGISTER_FP32_NEON(neon_fp32_elementwise_binary<op>)
+ },
+ {
+ "neon_s32_arithmetic",
+ [](const ElementwiseDataTypeISASelectorData & data)
{
- "sve2_qu8_elementwise",
- [](const ElementwiseSelectorData & data) { return data.dt == DataType::QASYMM8 && data.ci.has_sve2(); },
- REGISTER_QASYMM8_SVE2((arm_compute::cpu::sve2_qasymm8_elementwise_binary<op>))
+ return data.dt == DataType::S32 && static_cast<ArithmeticOperation>(data.op) == op;
},
+ REGISTER_INTEGER_NEON(neon_s32_elementwise_binary<op>)
+ },
+ {
+ "neon_fp16_arithmetic",
+ [](const ElementwiseDataTypeISASelectorData & data)
{
- "sve2_qs8_elementwise",
- [](const ElementwiseSelectorData & data) { return data.dt == DataType::QASYMM8_SIGNED && data.ci.has_sve2(); },
- REGISTER_QASYMM8_SIGNED_SVE2((arm_compute::cpu::sve2_qasymm8_signed_elementwise_binary<op>))
+ return data.dt == DataType::F16 && data.isa.fp16 && static_cast<ArithmeticOperation>(data.op) == op;
},
-#endif /* defined(ARM_COMPUTE_ENABLE_SVE2) */
-#if defined(ARM_COMPUTE_ENABLE_NEON) || defined(ARM_COMPUTE_ENABLE_SVE)
+ REGISTER_FP16_NEON(neon_fp16_elementwise_binary<op>)
+ },
+ {
+ "neon_s16_arithmetic",
+ [](const ElementwiseDataTypeISASelectorData & data)
{
- "neon_qu8_elementwise",
- [](const ElementwiseSelectorData & data) { return data.dt == DataType::QASYMM8; },
- REGISTER_QASYMM8_NEON((arm_compute::cpu::neon_qasymm8_elementwise_binary<op>))
+ return data.dt == DataType::S16 && static_cast<ArithmeticOperation>(data.op) == op;
},
+ REGISTER_INTEGER_NEON(neon_s16_elementwise_binary<op>)
+ },
+ {
+ "neon_qu8_arithmetic",
+ [](const ElementwiseDataTypeISASelectorData & data)
{
- "neon_qs8_elementwise",
- [](const ElementwiseSelectorData & data) { return data.dt == DataType::QASYMM8_SIGNED; },
- REGISTER_QASYMM8_SIGNED_NEON((arm_compute::cpu::neon_qasymm8_signed_elementwise_binary<op>))
+ return data.dt == DataType::QASYMM8 && static_cast<ArithmeticOperation>(data.op) == op;
},
-#endif /* defined(ARM_COMPUTE_ENABLE_NEON) || defined(ARM_COMPUTE_ENABLE_SVE) */
- };
-
- for(const auto &uk : kernels)
+ REGISTER_QASYMM8_NEON(neon_qasymm8_elementwise_binary<op>)
+ },
{
- if(uk.is_selected({ src0->data_type(), CPUInfo::get() }))
+ "neon_qs8_arithmetic",
+ [](const ElementwiseDataTypeISASelectorData & data)
{
- return { uk.name, uk.ukernel };
- }
- }
-
- return { "", nullptr };
-}
-
-template <ComparisonOperation op>
-CpuElementwiseKernel::UKernelInfo configure_comp_func(const ITensorInfo *src0, const ITensorInfo *src1, ITensorInfo *dst)
+ return data.dt == DataType::QASYMM8_SIGNED && static_cast<ArithmeticOperation>(data.op) == op;
+ },
+ REGISTER_QASYMM8_SIGNED_NEON(neon_qasymm8_signed_elementwise_binary<op>)
+ },
+};
+template <ComparisonOperation op>
+const std::vector<CpuElementwiseKernel<CpuComparisonKernel>::ElementwiseKernel> available_kernels_comparison =
{
- ARM_COMPUTE_UNUSED(src1, dst);
- static ElementwiseKernel kernels[] =
{
-#if defined(ARM_COMPUTE_ENABLE_SVE)
+ "sve2_qu8_comparison",
+ [](const ElementwiseDataTypeISASelectorData & data)
{
- "sve_u8_comparison",
- [](const ElementwiseSelectorData & data) { return data.dt == DataType::U8 && data.ci.has_sve(); },
- REGISTER_INTEGER_SVE(arm_compute::cpu::sve_u8_comparison_elementwise_binary<op>)
+ return data.dt == DataType::QASYMM8 && data.isa.sve2 && static_cast<ComparisonOperation>(data.op) == op;
},
+ REGISTER_QASYMM8_SVE2(sve2_qasymm8_comparison_elementwise_binary<op>)
+ },
+ {
+ "sve2_qs8_comparison",
+ [](const ElementwiseDataTypeISASelectorData & data)
{
- "sve_fp32_comparison",
- [](const ElementwiseSelectorData & data) { return data.dt == DataType::F32 && data.ci.has_sve(); },
- REGISTER_FP32_SVE(arm_compute::cpu::sve_fp32_comparison_elementwise_binary<op>)
+ return data.dt == DataType::QASYMM8_SIGNED && data.isa.sve2 && static_cast<ComparisonOperation>(data.op) == op;
},
+ REGISTER_QASYMM8_SIGNED_SVE2(sve2_qasymm8_signed_comparison_elementwise_binary<op>)
+ },
+ {
+ "sve_u8_comparison",
+ [](const ElementwiseDataTypeISASelectorData & data)
{
- "sve_s16_comparison",
- [](const ElementwiseSelectorData & data) { return data.dt == DataType::S16 && data.ci.has_sve(); },
- REGISTER_INTEGER_SVE(arm_compute::cpu::sve_s16_comparison_elementwise_binary<op>)
+ return data.dt == DataType::U8 && data.isa.sve && static_cast<ComparisonOperation>(data.op) == op;
},
+ REGISTER_INTEGER_SVE(sve_u8_comparison_elementwise_binary<op>)
+ },
+ {
+ "sve_fp32_comparison",
+ [](const ElementwiseDataTypeISASelectorData & data)
{
- "sve_s32_comparison",
- [](const ElementwiseSelectorData & data) { return data.dt == DataType::S32 && data.ci.has_sve(); },
- REGISTER_INTEGER_SVE(arm_compute::cpu::sve_s32_comparison_elementwise_binary<op>)
+ return data.dt == DataType::F32 && data.isa.sve && static_cast<ComparisonOperation>(data.op) == op;
},
-#endif /* defined(ARM_COMPUTE_ENABLE_SVE) */
-#if defined(ARM_COMPUTE_ENABLE_NEON)
+ REGISTER_FP32_SVE(sve_fp32_comparison_elementwise_binary<op>)
+ },
+ {
+ "sve_s16_comparison",
+ [](const ElementwiseDataTypeISASelectorData & data)
{
- "neon_u8_comparison",
- [](const ElementwiseSelectorData & data) { return data.dt == DataType::U8; },
- REGISTER_INTEGER_NEON(arm_compute::cpu::neon_u8_comparison_elementwise_binary<op>)
+ return data.dt == DataType::S16 && data.isa.sve && static_cast<ComparisonOperation>(data.op) == op;
},
+ REGISTER_INTEGER_SVE(sve_s16_comparison_elementwise_binary<op>)
+ },
+ {
+ "sve_s32_comparison",
+ [](const ElementwiseDataTypeISASelectorData & data)
{
- "neon_fp32_comparison",
- [](const ElementwiseSelectorData & data) { return data.dt == DataType::F32; },
- REGISTER_FP32_NEON(arm_compute::cpu::neon_fp32_comparison_elementwise_binary<op>)
+ return data.dt == DataType::S32 && data.isa.sve && static_cast<ComparisonOperation>(data.op) == op;
},
+ REGISTER_INTEGER_SVE(sve_s32_comparison_elementwise_binary<op>)
+ },
+ {
+ "sve_fp16_comparison",
+ [](const ElementwiseDataTypeISASelectorData & data)
{
- "neon_s16_comparison",
- [](const ElementwiseSelectorData & data) { return data.dt == DataType::S16; },
- REGISTER_INTEGER_NEON(arm_compute::cpu::neon_s16_comparison_elementwise_binary<op>)
+ return data.dt == DataType::F16 && data.isa.sve && data.isa.fp16 && static_cast<ComparisonOperation>(data.op) == op;
},
+ REGISTER_FP16_SVE(sve_fp16_comparison_elementwise_binary<op>)
+ },
+ {
+ "neon_u8_comparison",
+ [](const ElementwiseDataTypeISASelectorData & data)
{
- "neon_s32_comparison",
- [](const ElementwiseSelectorData & data) { return data.dt == DataType::S32; },
- REGISTER_INTEGER_NEON(arm_compute::cpu::neon_s32_comparison_elementwise_binary<op>)
+ return data.dt == DataType::U8 && static_cast<ComparisonOperation>(data.op) == op;
},
-#endif /* defined(ARM_COMPUTE_ENABLE_NEON) */
-#if defined(ARM_COMPUTE_ENABLE_SVE2)
+ REGISTER_INTEGER_NEON(neon_u8_comparison_elementwise_binary<op>)
+ },
+ {
+ "neon_fp32_comparison",
+ [](const ElementwiseDataTypeISASelectorData & data)
{
- "sve2_qu8_comparison",
- [](const ElementwiseSelectorData & data) { return data.dt == DataType::QASYMM8 && data.ci.has_sve2(); },
- REGISTER_QASYMM8_SVE2(arm_compute::cpu::sve2_qasymm8_comparison_elementwise_binary<op>)
+ return data.dt == DataType::F32 && static_cast<ComparisonOperation>(data.op) == op;
},
+ REGISTER_FP32_NEON(neon_fp32_comparison_elementwise_binary<op>)
+ },
+ {
+ "neon_s16_comparison",
+ [](const ElementwiseDataTypeISASelectorData & data)
{
- "sve2_qs8_comparison",
- [](const ElementwiseSelectorData & data) { return data.dt == DataType::QASYMM8_SIGNED && data.ci.has_sve2(); },
- REGISTER_QASYMM8_SIGNED_SVE2(arm_compute::cpu::sve2_qasymm8_signed_comparison_elementwise_binary<op>)
+ return data.dt == DataType::S16 && static_cast<ComparisonOperation>(data.op) == op;
},
-#endif /* defined(ARM_COMPUTE_ENABLE_SVE2) */
-#if defined(ARM_COMPUTE_ENABLE_NEON) || defined(ARM_COMPUTE_ENABLE_SVE)
+ REGISTER_INTEGER_NEON(neon_s16_comparison_elementwise_binary<op>)
+ },
+ {
+ "neon_s32_comparison",
+ [](const ElementwiseDataTypeISASelectorData & data)
{
- "neon_qu8_comparison",
- [](const ElementwiseSelectorData & data) { return data.dt == DataType::QASYMM8; },
- REGISTER_QASYMM8_NEON(arm_compute::cpu::neon_qasymm8_comparison_elementwise_binary<op>)
+ return data.dt == DataType::S32 && static_cast<ComparisonOperation>(data.op) == op;
},
+ REGISTER_INTEGER_NEON(neon_s32_comparison_elementwise_binary<op>)
+ },
+ {
+ "neon_qu8_comparison",
+ [](const ElementwiseDataTypeISASelectorData & data)
{
- "neon_qs8_comparison",
- [](const ElementwiseSelectorData & data) { return data.dt == DataType::QASYMM8_SIGNED; },
- REGISTER_QASYMM8_SIGNED_NEON(arm_compute::cpu::neon_qasymm8_signed_comparison_elementwise_binary<op>)
+ return data.dt == DataType::QASYMM8 && static_cast<ComparisonOperation>(data.op) == op;
},
-#endif /* defined(ARM_COMPUTE_ENABLE_NEON ||ARM_COMPUTE_ENABLE_SVE) */
-#if defined(ARM_COMPUTE_ENABLE_SVE)
+ REGISTER_QASYMM8_NEON(neon_qasymm8_comparison_elementwise_binary<op>)
+ },
+ {
+ "neon_qs8_comparison",
+ [](const ElementwiseDataTypeISASelectorData & data)
{
- "sve_fp16_comparison",
- [](const ElementwiseSelectorData & data) { return data.dt == DataType::F16 && data.ci.has_sve(); },
- REGISTER_FP16_SVE(arm_compute::cpu::sve_fp16_comparison_elementwise_binary<op>)
+ return data.dt == DataType::QASYMM8_SIGNED && static_cast<ComparisonOperation>(data.op) == op;
},
-#endif /* defined(ARM_COMPUTE_ENABLE_SVE) */
-#if defined(ARM_COMPUTE_ENABLE_NEON) && defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
+ REGISTER_QASYMM8_SIGNED_NEON(neon_qasymm8_signed_comparison_elementwise_binary<op>)
+ },
+ {
+ "neon_fp16_comparison",
+ [](const ElementwiseDataTypeISASelectorData & data)
{
- "neon_fp16_comparison",
- [](const ElementwiseSelectorData & data) { return data.dt == DataType::F16 && data.ci.has_fp16(); },
- REGISTER_FP16_NEON(arm_compute::cpu::neon_fp16_comparison_elementwise_binary<op>)
+ return data.dt == DataType::F16 && data.isa.fp16 && static_cast<ComparisonOperation>(data.op) == op;
},
-#endif /* defined(ARM_COMPUTE_ENABLE_NEON) && defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) */
- };
+ REGISTER_FP16_NEON(neon_fp16_comparison_elementwise_binary<op>)
+ },
+};
+} // namespace
- for(const auto &uk : kernels)
- {
- if(uk.is_selected({ src0->data_type(), CPUInfo::get() }))
- {
- return { uk.name, uk.ukernel };
- }
- }
+const std::vector<CpuElementwiseKernel<CpuArithmeticKernel>::ElementwiseKernel> &CpuArithmeticKernel::get_available_kernels()
+{
+ static std::vector<CpuElementwiseKernel<CpuArithmeticKernel>::ElementwiseKernel> available_kernels;
+ std::move(available_kernels_arithmetic<ArithmeticOperation::ADD>.begin(), available_kernels_arithmetic<ArithmeticOperation::ADD>.end(), std::back_inserter(available_kernels));
+ std::move(available_kernels_arithmetic<ArithmeticOperation::SUB>.begin(), available_kernels_arithmetic<ArithmeticOperation::SUB>.end(), std::back_inserter(available_kernels));
+ std::move(available_kernels_arithmetic<ArithmeticOperation::DIV>.begin(), available_kernels_arithmetic<ArithmeticOperation::DIV>.end(), std::back_inserter(available_kernels));
+ std::move(available_kernels_arithmetic<ArithmeticOperation::MIN>.begin(), available_kernels_arithmetic<ArithmeticOperation::MIN>.end(), std::back_inserter(available_kernels));
+ std::move(available_kernels_arithmetic<ArithmeticOperation::MAX>.begin(), available_kernels_arithmetic<ArithmeticOperation::MAX>.end(), std::back_inserter(available_kernels));
+ std::move(available_kernels_arithmetic<ArithmeticOperation::SQUARED_DIFF>.begin(), available_kernels_arithmetic<ArithmeticOperation::SQUARED_DIFF>.end(), std::back_inserter(available_kernels));
+ std::move(available_kernels_arithmetic<ArithmeticOperation::POWER>.begin(), available_kernels_arithmetic<ArithmeticOperation::POWER>.end(), std::back_inserter(available_kernels));
+ std::move(available_kernels_arithmetic<ArithmeticOperation::PRELU>.begin(), available_kernels_arithmetic<ArithmeticOperation::PRELU>.end(), std::back_inserter(available_kernels));
+
+ return available_kernels;
+}
- return { "", nullptr };
+const std::vector<CpuElementwiseKernel<CpuComparisonKernel>::ElementwiseKernel> &CpuComparisonKernel::get_available_kernels()
+{
+ static std::vector<CpuElementwiseKernel<CpuComparisonKernel>::ElementwiseKernel> available_kernels;
+ std::move(available_kernels_comparison<ComparisonOperation::Equal>.begin(), available_kernels_comparison<ComparisonOperation::Equal>.end(), std::back_inserter(available_kernels));
+ std::move(available_kernels_comparison<ComparisonOperation::NotEqual>.begin(), available_kernels_comparison<ComparisonOperation::NotEqual>.end(), std::back_inserter(available_kernels));
+ std::move(available_kernels_comparison<ComparisonOperation::Greater>.begin(), available_kernels_comparison<ComparisonOperation::Greater>.end(), std::back_inserter(available_kernels));
+ std::move(available_kernels_comparison<ComparisonOperation::GreaterEqual>.begin(), available_kernels_comparison<ComparisonOperation::GreaterEqual>.end(), std::back_inserter(available_kernels));
+ std::move(available_kernels_comparison<ComparisonOperation::Less>.begin(), available_kernels_comparison<ComparisonOperation::Less>.end(), std::back_inserter(available_kernels));
+ std::move(available_kernels_comparison<ComparisonOperation::LessEqual>.begin(), available_kernels_comparison<ComparisonOperation::LessEqual>.end(), std::back_inserter(available_kernels));
+
+ return available_kernels;
}
-} // namespace
-Status CpuElementwiseKernel::validate_arguments_common(const ITensorInfo &src0, const ITensorInfo &src1, const ITensorInfo &dst)
+template <class Derived>
+Status CpuElementwiseKernel<Derived>::validate_arguments_common(const ITensorInfo &src0, const ITensorInfo &src1, const ITensorInfo &dst)
{
ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(&src0);
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(&src0, &src1);
@@ -266,14 +307,38 @@ Status CpuElementwiseKernel::validate_arguments_common(const ITensorInfo &src0,
return Status{};
}
-void CpuElementwiseKernel::configure_common(const ITensorInfo *src0, const ITensorInfo *src1, ITensorInfo *dst)
+void CpuArithmeticKernel::configure_common(const ITensorInfo *src0, const ITensorInfo *src1, ITensorInfo *dst)
+{
+ ARM_COMPUTE_ERROR_ON_NULLPTR(src0, src1, dst);
+
+ const auto *uk = CpuArithmeticKernel::get_implementation(ElementwiseDataTypeISASelectorData{ src0->data_type(), CPUInfo::get().get_isa(), static_cast<int>(_op) });
+
+ ARM_COMPUTE_ERROR_ON(uk == nullptr || uk->ukernel == nullptr);
+
+ _run_method = uk->ukernel;
+ _name = std::string("CpuArithmeticKernel").append("/").append(uk->name);
+
+ // If any of shapes is dynamic, expect a configured window and dst at run-time.
+ if(src0->is_dynamic() || src1->is_dynamic())
+ {
+ return;
+ }
+
+ auto shape_and_window = compute_output_shape_and_window(src0->tensor_shape(), src1->tensor_shape());
+ auto_init_if_empty(*dst, shape_and_window.first, 1, src0->data_type());
+ ICpuKernel::configure(shape_and_window.second);
+}
+
+void CpuComparisonKernel::configure_common(const ITensorInfo *src0, const ITensorInfo *src1, ITensorInfo *dst)
{
ARM_COMPUTE_ERROR_ON_NULLPTR(src0, src1, dst);
- const auto uk = get_implementation(src0, src1, dst);
+ const auto *uk = CpuComparisonKernel::get_implementation(ElementwiseDataTypeISASelectorData{ src0->data_type(), CPUInfo::get().get_isa(), static_cast<int>(_op) });
- _run_method = uk.ukernel;
- _name = std::string("CpuElementwiseKernel").append("/").append(uk.name);
+ ARM_COMPUTE_ERROR_ON(uk == nullptr || uk->ukernel == nullptr);
+
+ _run_method = uk->ukernel;
+ _name = std::string("CpuComparisonKernel").append("/").append(uk->name);
// If any of shapes is dynamic, expect a configured window and dst at run-time.
if(src0->is_dynamic() || src1->is_dynamic())
@@ -286,7 +351,8 @@ void CpuElementwiseKernel::configure_common(const ITensorInfo *src0, const ITens
ICpuKernel::configure(shape_and_window.second);
}
-void CpuElementwiseKernel::run_op(ITensorPack &tensors, const Window &window, const ThreadInfo &info)
+template <class Derived>
+void CpuElementwiseKernel<Derived>::run_op(ITensorPack &tensors, const Window &window, const ThreadInfo &info)
{
ARM_COMPUTE_UNUSED(info);
ARM_COMPUTE_ERROR_ON(_run_method == nullptr);
@@ -297,18 +363,23 @@ void CpuElementwiseKernel::run_op(ITensorPack &tensors, const Window &window, co
_run_method(src0, src1, dst, window);
}
+template void CpuElementwiseKernel<CpuArithmeticKernel>::run_op(ITensorPack &tensors, const Window &window, const ThreadInfo &info);
+template void CpuElementwiseKernel<CpuComparisonKernel>::run_op(ITensorPack &tensors, const Window &window, const ThreadInfo &info);
-const char *CpuElementwiseKernel::name() const
+template <class Derived>
+const char *CpuElementwiseKernel<Derived>::name() const
{
return _name.c_str();
}
+template const char *CpuElementwiseKernel<CpuArithmeticKernel>::name() const;
+template const char *CpuElementwiseKernel<CpuComparisonKernel>::name() const;
/** Arithmetic operators (min, max, squared_diff) */
void CpuArithmeticKernel::configure(ArithmeticOperation op, const ITensorInfo *src0, const ITensorInfo *src1, ITensorInfo *dst)
{
ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(*src0, *src1, *dst));
_op = op;
- configure_common(src0, src1, dst);
+ CpuArithmeticKernel::configure_common(src0, src1, dst);
}
Status CpuArithmeticKernel::validate_arguments(const ITensorInfo &src0, const ITensorInfo &src1, const ITensorInfo &dst)
@@ -330,35 +401,13 @@ Status CpuArithmeticKernel::validate(ArithmeticOperation op, const ITensorInfo *
return Status{};
}
-CpuElementwiseKernel::UKernelInfo CpuArithmeticKernel::get_implementation(const ITensorInfo *src0, const ITensorInfo *src1, ITensorInfo *dst)
-{
- switch(_op)
- {
- case ArithmeticOperation::MAX:
- return configure_arithm_func<ArithmeticOperation::MAX>(src0, src1, dst);
- case ArithmeticOperation::MIN:
- return configure_arithm_func<ArithmeticOperation::MIN>(src0, src1, dst);
- case ArithmeticOperation::SQUARED_DIFF:
- return configure_arithm_func<ArithmeticOperation::SQUARED_DIFF>(src0, src1, dst);
- case ArithmeticOperation::PRELU:
- return configure_arithm_func<ArithmeticOperation::PRELU>(src0, src1, dst);
- case ArithmeticOperation::DIV:
- return configure_arithm_func<ArithmeticOperation::DIV>(src0, src1, dst);
- case ArithmeticOperation::POWER:
- return configure_arithm_func<ArithmeticOperation::POWER>(src0, src1, dst);
- default:
- ARM_COMPUTE_ERROR("NOT_SUPPORTED!");
- }
- return { "", nullptr };
-}
-
/** The division operator */
void CpuDivisionKernel::configure(const ITensorInfo *src0, const ITensorInfo *src1, ITensorInfo *dst)
{
ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(*src0, *src1, *dst));
_op = ArithmeticOperation::DIV;
- configure_common(src0, src1, dst);
+ CpuArithmeticKernel::configure_common(src0, src1, dst);
}
Status CpuDivisionKernel::validate_arguments(const ITensorInfo &src0, const ITensorInfo &src1, const ITensorInfo &dst)
@@ -379,7 +428,7 @@ void CpuPowerKernel::configure(const ITensorInfo *src0, const ITensorInfo *src1,
{
ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(*src0, *src1, *dst));
_op = ArithmeticOperation::POWER;
- configure_common(src0, src1, dst);
+ CpuArithmeticKernel::configure_common(src0, src1, dst);
}
Status CpuPowerKernel::validate_arguments(const ITensorInfo &src0, const ITensorInfo &src1, const ITensorInfo &dst)
@@ -400,7 +449,7 @@ void CpuComparisonKernel::configure(ComparisonOperation op, const ITensorInfo *s
{
ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(*src0, *src1, *dst));
_op = op;
- configure_common(src0, src1, dst);
+ CpuComparisonKernel::configure_common(src0, src1, dst);
}
Status CpuComparisonKernel::validate_arguments(const ITensorInfo &src0, const ITensorInfo &src1, const ITensorInfo &dst)
@@ -421,28 +470,6 @@ Status CpuComparisonKernel::validate(ComparisonOperation op, const ITensorInfo *
ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(*src0, *src1, *dst));
return Status{};
}
-
-CpuElementwiseKernel::UKernelInfo CpuComparisonKernel::get_implementation(const ITensorInfo *src0, const ITensorInfo *src1, ITensorInfo *dst)
-{
- switch(_op)
- {
- case ComparisonOperation::Equal:
- return configure_comp_func<ComparisonOperation::Equal>(src0, src1, dst);
- case ComparisonOperation::NotEqual:
- return configure_comp_func<ComparisonOperation::NotEqual>(src0, src1, dst);
- case ComparisonOperation::Greater:
- return configure_comp_func<ComparisonOperation::Greater>(src0, src1, dst);
- case ComparisonOperation::GreaterEqual:
- return configure_comp_func<ComparisonOperation::GreaterEqual>(src0, src1, dst);
- case ComparisonOperation::Less:
- return configure_comp_func<ComparisonOperation::Less>(src0, src1, dst);
- case ComparisonOperation::LessEqual:
- return configure_comp_func<ComparisonOperation::LessEqual>(src0, src1, dst);
- default:
- ARM_COMPUTE_ERROR("NOT_SUPPORTED!");
- }
- return { "", nullptr };
-}
} // namespace kernels
} // namespace cpu
} // namespace arm_compute
diff --git a/src/cpu/kernels/CpuElementwiseKernel.h b/src/cpu/kernels/CpuElementwiseKernel.h
index 8cd5d58a96..2785e0a44c 100644
--- a/src/cpu/kernels/CpuElementwiseKernel.h
+++ b/src/cpu/kernels/CpuElementwiseKernel.h
@@ -39,23 +39,29 @@ namespace kernels
* @f[ dst(x,y) = OP(src0(x,y), src1(x,y))@f]
*
*/
-class CpuElementwiseKernel : public ICpuKernel<CpuElementwiseKernel>
+template <class Derived>
+class CpuElementwiseKernel : public ICpuKernel<Derived>
{
+private:
+ using ElementwiseKernelPtr = std::add_pointer<void(const ITensor *, const ITensor *, ITensor *, const Window &)>::type;
+
public:
CpuElementwiseKernel() = default;
ARM_COMPUTE_DISALLOW_COPY_ALLOW_MOVE(CpuElementwiseKernel);
using ElementwiseFunction = void(const ITensor *, const ITensor *, ITensor *, const Window &);
- struct UKernelInfo
- {
- std::string name;
- std::function<ElementwiseFunction> ukernel;
- };
-
// Inherited methods overridden:
void run_op(ITensorPack &tensors, const Window &window, const ThreadInfo &info) override;
+
const char *name() const override;
+ struct ElementwiseKernel
+ {
+ const char *name;
+ const ElementwiseDataTypeISASelectorPtr is_selected;
+ ElementwiseKernelPtr ukernel;
+ };
+
protected:
/** Validate the argument passed to the kernel
*
@@ -65,27 +71,12 @@ protected:
*/
static Status validate_arguments_common(const ITensorInfo &src0, const ITensorInfo &src1, const ITensorInfo &dst);
- /** Commmon configure function for element-wise operators with no additional options (e.g. Min, Max, SquaredDiff)
- *
- */
- void configure_common(const ITensorInfo *src0, const ITensorInfo *src1, ITensorInfo *dst);
-
- /** Function to get the micro kernel implementation
- *
- * @param[in] src0 First input tensor information
- * @param[in] src1 Second input tensor information
- * @param[in] dst Output tensor information
- *
- * @return the function instance for the micro kernel
- */
- virtual UKernelInfo get_implementation(const ITensorInfo *src0, const ITensorInfo *src1, ITensorInfo *dst) = 0;
-
protected:
std::function<ElementwiseFunction> _run_method{ nullptr };
std::string _name{};
};
-class CpuArithmeticKernel : public CpuElementwiseKernel
+class CpuArithmeticKernel : public CpuElementwiseKernel<CpuArithmeticKernel>
{
public:
CpuArithmeticKernel() = default;
@@ -107,7 +98,12 @@ public:
*/
static Status validate(ArithmeticOperation op, const ITensorInfo *src0, const ITensorInfo *src1, const ITensorInfo *dst);
+ static const std::vector<CpuElementwiseKernel<CpuArithmeticKernel>::ElementwiseKernel> &get_available_kernels();
+
protected:
+ /** Common configure function for element-wise operators with no additional options (e.g. Min, Max, SquaredDiff)
+ */
+ void configure_common(const ITensorInfo *src0, const ITensorInfo *src1, ITensorInfo *dst);
// Inherited methods overridden:
static Status validate_arguments(const ITensorInfo &src0, const ITensorInfo &src1, const ITensorInfo &dst);
@@ -122,7 +118,6 @@ private:
*
* @return the function instance for the micro kernel
*/
- UKernelInfo get_implementation(const ITensorInfo *src0, const ITensorInfo *src1, ITensorInfo *dst) override;
};
class CpuDivisionKernel : public CpuArithmeticKernel
@@ -177,7 +172,7 @@ protected:
static Status validate_arguments(const ITensorInfo &src0, const ITensorInfo &src1, const ITensorInfo &dst);
};
-class CpuComparisonKernel : public CpuElementwiseKernel
+class CpuComparisonKernel : public CpuElementwiseKernel<CpuComparisonKernel>
{
public:
CpuComparisonKernel() = default;
@@ -199,7 +194,12 @@ public:
*/
static Status validate(ComparisonOperation op, const ITensorInfo *src0, const ITensorInfo *src1, const ITensorInfo *dst);
+ static const std::vector<CpuElementwiseKernel<CpuComparisonKernel>::ElementwiseKernel> &get_available_kernels();
+
protected:
+ /** Common configure function for element-wise operators with no additional options (e.g. Min, Max, SquaredDiff)
+ */
+ void configure_common(const ITensorInfo *src0, const ITensorInfo *src1, ITensorInfo *dst);
// Inherited methods overridden:
static Status validate_arguments(const ITensorInfo &src0, const ITensorInfo &src1, const ITensorInfo &dst);
@@ -212,7 +212,6 @@ private:
*
* @return the function instance for the micro kernel
*/
- UKernelInfo get_implementation(const ITensorInfo *src0, const ITensorInfo *src1, ITensorInfo *dst) override;
ComparisonOperation _op{};
};
diff --git a/src/cpu/kernels/CpuElementwiseUnaryKernel.cpp b/src/cpu/kernels/CpuElementwiseUnaryKernel.cpp
index e8211fe93e..335de78aca 100644
--- a/src/cpu/kernels/CpuElementwiseUnaryKernel.cpp
+++ b/src/cpu/kernels/CpuElementwiseUnaryKernel.cpp
@@ -44,12 +44,11 @@ namespace
{
static const std::vector<CpuElementwiseUnaryKernel::ElementwiseUnaryKernel> available_kernels =
{
-#if defined(ARM_COMPUTE_ENABLE_SVE)
{
"sve_fp32_elementwise_unary",
[](const DataTypeISASelectorData & data)
{
- return data.dt == DataType::F32 && data.isa.sve;
+ return (data.dt == DataType::F32 && data.isa.sve);
},
REGISTER_FP32_SVE(sve_fp32_elementwise_unary)
},
@@ -57,35 +56,42 @@ static const std::vector<CpuElementwiseUnaryKernel::ElementwiseUnaryKernel> avai
"sve_fp16_elementwise_unary",
[](const DataTypeISASelectorData & data)
{
- return (data.dt == DataType::F16) && data.isa.sve;
+ return (data.dt == DataType::F16 && data.isa.sve && data.isa.fp16);
},
REGISTER_FP16_SVE(sve_fp16_elementwise_unary),
},
{
"sve_s32_elementwise_unary",
- [](const DataTypeISASelectorData & data) { return data.dt == DataType::S32 && data.isa.sve; },
+ [](const DataTypeISASelectorData & data)
+ {
+ return (data.dt == DataType::S32 && data.isa.sve);
+ },
REGISTER_INTEGER_SVE(sve_s32_elementwise_unary),
},
-#endif // defined(ARM_COMPUTE_ENABLE_SVE)
-#if defined(ARM_COMPUTE_ENABLE_NEON)
{
"neon_fp32_elementwise_unary",
- [](const DataTypeISASelectorData & data) { return data.dt == DataType::F32; },
+ [](const DataTypeISASelectorData & data)
+ {
+ return data.dt == DataType::F32;
+ },
REGISTER_FP32_NEON(neon_fp32_elementwise_unary),
},
-#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
{
"neon_fp16_elementwise_unary",
- [](const DataTypeISASelectorData & data) { return data.dt == DataType::F16 && data.isa.fp16; },
+ [](const DataTypeISASelectorData & data)
+ {
+ return data.dt == DataType::F16 && data.isa.fp16;
+ },
REGISTER_FP16_NEON(neon_fp16_elementwise_unary),
},
-#endif // defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
{
"neon_s32_elementwise_unary",
- [](const DataTypeISASelectorData & data) { return data.dt == DataType::S32; },
+ [](const DataTypeISASelectorData & data)
+ {
+ return data.dt == DataType::S32;
+ },
REGISTER_INTEGER_NEON(neon_s32_elementwise_unary),
},
-#endif // defined(ARM_COMPUTE_ENABLE_NEON)
};
} // namespace
diff --git a/src/cpu/kernels/CpuKernelSelectionTypes.h b/src/cpu/kernels/CpuKernelSelectionTypes.h
index 60dcea4a35..60bbd5933c 100644
--- a/src/cpu/kernels/CpuKernelSelectionTypes.h
+++ b/src/cpu/kernels/CpuKernelSelectionTypes.h
@@ -49,9 +49,17 @@ struct PoolDataTypeISASelectorData
cpuinfo::CpuIsaInfo isa;
};
+struct ElementwiseDataTypeISASelectorData
+{
+ DataType dt;
+ cpuinfo::CpuIsaInfo isa;
+ int op;
+};
+
// Selector pointer types
-using DataTypeISASelectorPtr = std::add_pointer<bool(const DataTypeISASelectorData &data)>::type;
-using PoolDataTypeISASelectorPtr = std::add_pointer<bool(const PoolDataTypeISASelectorData &data)>::type;
+using DataTypeISASelectorPtr = std::add_pointer<bool(const DataTypeISASelectorData &data)>::type;
+using PoolDataTypeISASelectorPtr = std::add_pointer<bool(const PoolDataTypeISASelectorData &data)>::type;
+using ElementwiseDataTypeISASelectorPtr = std::add_pointer<bool(const ElementwiseDataTypeISASelectorData &data)>::type;
} // namespace kernels
} // namespace cpu
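
Side note on the selector aliases just added: each table entry's is_selected predicate is a capture-less lambda, which decays to these plain function-pointer types. A standalone sketch, assuming only the declarations in this header plus arm_compute/core/Types.h; the variable name is illustrative:

#include "arm_compute/core/Types.h"
#include "src/cpu/kernels/CpuKernelSelectionTypes.h"

using namespace arm_compute;
using namespace arm_compute::cpu::kernels;

// Capture-less lambda: decays to the plain function pointer the tables store.
const ElementwiseDataTypeISASelectorPtr is_sve2_qu8_add =
    [](const ElementwiseDataTypeISASelectorData &data)
{
    return data.dt == DataType::QASYMM8 && data.isa.sve2
           && static_cast<ArithmeticOperation>(data.op) == ArithmeticOperation::ADD;
};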
diff --git a/src/cpu/kernels/CpuSoftmaxKernel.cpp b/src/cpu/kernels/CpuSoftmaxKernel.cpp
index 6766b10120..93cce785bd 100644
--- a/src/cpu/kernels/CpuSoftmaxKernel.cpp
+++ b/src/cpu/kernels/CpuSoftmaxKernel.cpp
@@ -22,6 +22,7 @@
* SOFTWARE.
*/
#include "src/cpu/kernels/CpuSoftmaxKernel.h"
+
#include "arm_compute/core/Error.h"
#include "arm_compute/core/Helpers.h"
#include "arm_compute/core/ITensor.h"
@@ -29,10 +30,12 @@
#include "arm_compute/core/Validate.h"
#include "arm_compute/core/Window.h"
#include "src/core/CPP/Validate.h"
-#include "src/core/common/Registrars.h"
#include "src/core/helpers/AutoConfiguration.h"
#include "src/core/helpers/WindowHelpers.h"
+
+#include "src/core/common/Registrars.h"
#include "src/cpu/kernels/softmax/list.h"
+
namespace arm_compute
{
namespace cpu
@@ -44,57 +47,53 @@ namespace
/* Softmax Logits 1D Max - identifying the max value of 1D Logits */
static const std::vector<CpuLogits1DMaxKernel::SoftmaxLogits1DMaxKernel> available_kernels_max_logits =
{
-#if defined(ARM_COMPUTE_ENABLE_SVE)
{
"sve_fp32_logits_1d_max",
[](const DataTypeISASelectorData & data) { return (data.dt == DataType::F32) && data.isa.sve; },
- REGISTER_FP32_SVE(arm_compute::cpu::sve_fp32_logits)
+ REGISTER_FP32_SVE(sve_fp32_logits)
},
{
"sve_fp16_logits_1d_max",
- [](const DataTypeISASelectorData & data) { return (data.dt == DataType::F16) && data.isa.sve; },
- REGISTER_FP16_SVE(arm_compute::cpu::sve_fp16_logits)
+ [](const DataTypeISASelectorData & data) { return (data.dt == DataType::F16) && data.isa.sve && data.isa.fp16; },
+ REGISTER_FP16_SVE(sve_fp16_logits)
},
{
"sve_qu8_logits_1d_max",
[](const DataTypeISASelectorData & data) { return (data.dt == DataType::QASYMM8) && data.isa.sve; },
- REGISTER_QASYMM8_SVE(arm_compute::cpu::sve_qasymm8_logits)
+ REGISTER_QASYMM8_SVE(sve_qasymm8_logits)
},
{
"sve_qs8_logits_1d_max",
[](const DataTypeISASelectorData & data) { return (data.dt == DataType::QASYMM8_SIGNED) && data.isa.sve; },
- REGISTER_QASYMM8_SIGNED_SVE(arm_compute::cpu::sve_qasymm8_signed_logits)
+ REGISTER_QASYMM8_SIGNED_SVE(sve_qasymm8_signed_logits)
},
-#endif /* defined(ARM_COMPUTE_ENABLE_SVE) */
-#if defined(ARM_COMPUTE_ENABLE_NEON)
{
"neon_fp32_logits_1d_max",
[](const DataTypeISASelectorData & data) { return (data.dt == DataType::F32); },
- REGISTER_FP32_NEON(arm_compute::cpu::neon_fp32_logits)
+ REGISTER_FP32_NEON(neon_fp32_logits)
},
-#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
{
"neon_fp16_logits_1d_max",
- [](const DataTypeISASelectorData & data) { return (data.dt == DataType::F16); },
- REGISTER_FP16_NEON(arm_compute::cpu::neon_fp16_logits)
+ [](const DataTypeISASelectorData & data) { return (data.dt == DataType::F16) && data.isa.fp16; },
+ REGISTER_FP16_NEON(neon_fp16_logits)
},
-#endif /* defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) */
{
"neon_qu8_logits_1d_max",
[](const DataTypeISASelectorData & data) { return (data.dt == DataType::QASYMM8); },
- REGISTER_QASYMM8_NEON(arm_compute::cpu::neon_qasymm8_logits)
+ REGISTER_QASYMM8_NEON(neon_qasymm8_logits)
},
{
"neon_qs8_logits_1d_max",
[](const DataTypeISASelectorData & data) { return (data.dt == DataType::QASYMM8_SIGNED); },
- REGISTER_QASYMM8_SIGNED_NEON(arm_compute::cpu::neon_qasymm8_singed_logits)
+ REGISTER_QASYMM8_SIGNED_NEON(neon_qasymm8_singed_logits)
},
-#endif /* defined(ARM_COMPUTE_ENABLE_NEON) */
};
+
Status validate_arguments_logits_1d_max(const ITensorInfo &input, const ITensorInfo &output)
{
ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(&input);
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&input, 1, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::F16, DataType::F32);
+
// Validate in case of configured output
if(output.total_size() != 0)
{
@@ -102,6 +101,7 @@ Status validate_arguments_logits_1d_max(const ITensorInfo &input, const ITensorI
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_QUANTIZATION_INFO(&input, &output);
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DIMENSIONS(output.tensor_shape(), TensorShape(input.tensor_shape()).set(0, 1));
}
+
return Status{};
}
} //namespace
@@ -109,37 +109,48 @@ const std::vector<CpuLogits1DMaxKernel::SoftmaxLogits1DMaxKernel> &CpuLogits1DMa
{
return available_kernels_max_logits;
}
+
void CpuLogits1DMaxKernel::configure(const ITensorInfo *src, ITensorInfo *dst)
{
ARM_COMPUTE_ERROR_ON_NULLPTR(src, dst);
ARM_COMPUTE_ERROR_THROW_ON(validate_arguments_logits_1d_max(*src, *dst));
+
// Softmax across the x dimension
const TensorShape output_shape = TensorShape(src->tensor_shape()).set(0, 1);
// Output auto initialization if not yet initialized
auto_init_if_empty(*dst, output_shape, 1, src->data_type(), src->quantization_info());
+
const auto *uk = get_implementation(DataTypeISASelectorData{ src->data_type(), CPUInfo::get().get_isa() });
- ARM_COMPUTE_ERROR_ON_NULLPTR(uk);
+ ARM_COMPUTE_ERROR_ON(uk == nullptr || uk->ukernel == nullptr);
+
_run_method = uk->ukernel;
_name = std::string("CpuLogits1DMaxKernel").append("/").append(uk->name);
- Window win = calculate_max_window(*src, Steps());
+
+ Window win = calculate_max_window(*src, Steps());
ICpuKernel::configure(win);
}
+
Status CpuLogits1DMaxKernel::validate(const ITensorInfo *src, const ITensorInfo *dst)
{
ARM_COMPUTE_ERROR_ON_NULLPTR(src, dst);
ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments_logits_1d_max(*src, *dst));
+
return Status{};
}
+
void CpuLogits1DMaxKernel::run_op(ITensorPack &tensors, const Window &window, const ThreadInfo &info)
{
ARM_COMPUTE_UNUSED(info);
ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(ICpuKernel::window(), window);
ARM_COMPUTE_ERROR_ON(_run_method == nullptr);
+
const auto src = tensors.get_const_tensor(TensorType::ACL_SRC);
auto dst = tensors.get_tensor(TensorType::ACL_DST);
+
_run_method(src, dst, window);
}
+
const char *CpuLogits1DMaxKernel::name() const
{
return _name.c_str();
@@ -149,46 +160,38 @@ const char *CpuLogits1DMaxKernel::name() const
template <bool IS_LOG>
static const std::vector<typename CpuLogits1DSoftmaxKernel<IS_LOG>::SoftmaxLogits1DKernel> available_kernels_logits =
{
-#if defined(ARM_COMPUTE_ENABLE_SVE)
+ {
+ "sve2_qu8_softmax_logits_1d",
+ [](const DataTypeISASelectorData & data) { return (data.dt == DataType::QASYMM8) && data.isa.sve2; },
+ REGISTER_QASYMM8_SVE2(sve2_qasymm8_softmax)
+ },
+ {
+ "sve2_qs8_softmax_logits_1d",
+ [](const DataTypeISASelectorData & data) { return (data.dt == DataType::QASYMM8_SIGNED) && data.isa.sve2; },
+ REGISTER_QASYMM8_SIGNED_SVE2(sve2_qasymm8_signed_softmax)
+ },
{
"sve_fp32_softmax_logits_1d",
[](const DataTypeISASelectorData & data) { return (data.dt == DataType::F32) && data.isa.sve; },
- REGISTER_FP32_SVE(arm_compute::cpu::sve_fp32_softmax)
+ REGISTER_FP32_SVE(sve_fp32_softmax)
},
{
"sve_fp16_softmax_logits_1d",
- [](const DataTypeISASelectorData & data) { return (data.dt == DataType::F16) && data.isa.sve; },
- REGISTER_FP16_SVE(arm_compute::cpu::sve_fp16_softmax)
+ [](const DataTypeISASelectorData & data) { return (data.dt == DataType::F16) && data.isa.sve && data.isa.fp16; },
+ REGISTER_FP16_SVE(sve_fp16_softmax)
},
-#endif /* defined(ARM_COMPUTE_ENABLE_SVE) */
-#if defined(ARM_COMPUTE_ENABLE_NEON)
+
{
"neon_fp32_softmax_logits_1d",
[](const DataTypeISASelectorData & data) { return (data.dt == DataType::F32); },
- REGISTER_FP32_NEON(arm_compute::cpu::neon_fp32_softmax)
+ REGISTER_FP32_NEON(neon_fp32_softmax)
},
-#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
{
"neon_fp16_softmax_logits_1d",
- [](const DataTypeISASelectorData & data) { return (data.dt == DataType::F16); },
- REGISTER_FP16_NEON(arm_compute::cpu::neon_fp16_softmax)
- },
-#endif /* defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) */
-#endif /* defined(ARM_COMPUTE_ENABLE_NEON) */
-#if defined(ARM_COMPUTE_ENABLE_SVE2)
- {
- "sve2_qu8_softmax_logits_1d",
- [](const DataTypeISASelectorData & data) { return (data.dt == DataType::QASYMM8) && data.isa.sve2; },
- REGISTER_QASYMM8_SVE2(arm_compute::cpu::sve2_qasymm8_softmax)
+ [](const DataTypeISASelectorData & data) { return (data.dt == DataType::F16) && data.isa.fp16; },
+ REGISTER_FP16_NEON(neon_fp16_softmax)
},
{
- "sve2_qs8_softmax_logits_1d",
- [](const DataTypeISASelectorData & data) { return (data.dt == DataType::QASYMM8_SIGNED) && data.isa.sve2; },
- REGISTER_QASYMM8_SIGNED_SVE2(arm_compute::cpu::sve2_qasymm8_signed_softmax)
- },
-#endif /* defined(ARM_COMPUTE_ENABLE_SVE2) */
-#if defined(ARM_COMPUTE_ENABLE_NEON)
- {
"neon_qu8_softmax_logits_1d",
[](const DataTypeISASelectorData & data) { return (data.dt == DataType::QASYMM8); },
REGISTER_QASYMM8_NEON(arm_compute::cpu::neon_qasymm8_softmax)
@@ -198,7 +201,6 @@ static const std::vector<typename CpuLogits1DSoftmaxKernel<IS_LOG>::SoftmaxLogit
[](const DataTypeISASelectorData & data) { return (data.dt == DataType::QASYMM8_SIGNED); },
REGISTER_QASYMM8_SIGNED_NEON(arm_compute::cpu::neon_qasymm8_signed_softmax)
},
-#endif //defined(ARM_COMPUTE_ENABLE_NEON)
};
namespace
{
@@ -209,11 +211,14 @@ Status validate_arguments_logits_softmax(const ITensorInfo &src, const ITensorIn
// Check input
ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(&src);
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&src, 1, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::F16, DataType::F32);
+
const bool is_quantized_asymmetric = is_data_type_quantized_asymmetric(src.data_type());
+
// Check max
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(&src, &max);
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DIMENSIONS(TensorShape(src.tensor_shape()).set(0, 1), max.tensor_shape());
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_QUANTIZATION_INFO(&src, &max);
+
// Check output if configured
if(dst.total_size() != 0)
{
@@ -222,6 +227,7 @@ Status validate_arguments_logits_softmax(const ITensorInfo &src, const ITensorIn
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(&src, &dst);
ARM_COMPUTE_RETURN_ERROR_ON(dst.quantization_info() != output_quantization);
}
+
// Check tmp if configured
if(tmp.total_size() != 0)
{
@@ -231,69 +237,90 @@ Status validate_arguments_logits_softmax(const ITensorInfo &src, const ITensorIn
// on the maximum number of threads that will run in parallel.
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(&src, &tmp);
}
+
return Status{};
}
} // namespace
-template <bool IS_LOG>
+
+template <bool IS_LOG>
const std::vector<typename CpuLogits1DSoftmaxKernel<IS_LOG>::SoftmaxLogits1DKernel> &CpuLogits1DSoftmaxKernel<IS_LOG>::get_available_kernels()
{
return available_kernels_logits<IS_LOG>;
}
+
template <bool IS_LOG>
void CpuLogits1DSoftmaxKernel<IS_LOG>::configure(const ITensorInfo *src, const ITensorInfo *max, ITensorInfo *dst, const float beta, ITensorInfo *tmp)
{
ARM_COMPUTE_ERROR_ON_NULLPTR(src, max, dst, tmp);
ARM_COMPUTE_ERROR_THROW_ON(validate_arguments_logits_softmax(*src, *max, *dst, beta, *tmp, IS_LOG));
+
// Configure kernel window
const bool is_quantized_asymmetric = is_data_type_quantized_asymmetric(src->data_type());
+
// Output auto initialization if not yet initialized
const QuantizationInfo output_quantization = is_quantized_asymmetric ? arm_compute::get_softmax_output_quantization_info(src->data_type(), IS_LOG) : dst->quantization_info();
auto_init_if_empty(*dst, TensorInfo(*src).set_quantization_info(output_quantization).reset_padding());
+
// Tmp auto initialization if not yet initialized
const DataType tmp_data_type = is_quantized_asymmetric ? DataType::F32 : src->data_type();
auto_init_if_empty(*tmp, TensorInfo(*src).set_data_type(tmp_data_type).reset_padding());
+
const auto *uk = CpuLogits1DSoftmaxKernel<IS_LOG>::get_implementation(DataTypeISASelectorData{ src->data_type(), CPUInfo::get().get_isa() });
- ARM_COMPUTE_ERROR_ON_NULLPTR(uk);
+ ARM_COMPUTE_ERROR_ON(uk == nullptr || uk->ukernel == nullptr);
+
std::string kernel_name = IS_LOG ? std::string("CpuLogits1DLogSoftmaxKernel") : std::string("CpuLogits1DSoftmaxKernel");
- _beta = beta;
- _run_method = uk->ukernel;
- _name = kernel_name.append("/").append(uk->name);
+
+ _beta = beta;
+ _run_method = uk->ukernel;
+ _name = kernel_name.append("/").append(uk->name);
+
// Configure kernel window
Window win = calculate_max_window(*max, Steps());
- ICPPKernel::configure(win);
+
+ ICpuKernel<CpuLogits1DSoftmaxKernel<IS_LOG>>::configure(win);
}
+
template <bool IS_LOG>
Status CpuLogits1DSoftmaxKernel<IS_LOG>::validate(const ITensorInfo *src, const ITensorInfo *max,
const ITensorInfo *dst, const float beta, const ITensorInfo *tmp)
{
ARM_COMPUTE_ERROR_ON_NULLPTR(src, max, dst, tmp);
ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments_logits_softmax(*src, *max, *dst, beta, *tmp, IS_LOG));
+
return Status{};
}
+
template <bool IS_LOG>
void CpuLogits1DSoftmaxKernel<IS_LOG>::run_op(ITensorPack &tensors, const Window &window, const ThreadInfo &info)
{
ARM_COMPUTE_UNUSED(info);
ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
- ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(ICPPKernel::window(), window);
+ ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(ICpuKernel<CpuLogits1DSoftmaxKernel<IS_LOG>>::window(), window);
ARM_COMPUTE_ERROR_ON(_run_method == nullptr);
- const auto src = tensors.get_const_tensor(TensorType::ACL_SRC_0);
- auto max = tensors.get_tensor(TensorType::ACL_SRC_1);
- auto dst = tensors.get_tensor(TensorType::ACL_DST_0);
- auto tmp = tensors.get_tensor(TensorType::ACL_DST_1);
+
+ const auto src = tensors.get_const_tensor(TensorType::ACL_SRC_0);
+ auto max = tensors.get_tensor(TensorType::ACL_SRC_1);
+ auto dst = tensors.get_tensor(TensorType::ACL_DST_0);
+ auto tmp = tensors.get_tensor(TensorType::ACL_DST_1);
+
const unsigned int num_elems_processed_per_iteration = src->info()->valid_region().shape.x();
const unsigned int tmp_size_for_thread = tmp->info()->element_size() * num_elems_processed_per_iteration;
+
ARM_COMPUTE_ERROR_ON(tmp->info()->total_size() < (info.num_threads * tmp_size_for_thread));
+
void *tmp_for_thread = tmp->buffer() + (info.thread_id * tmp_size_for_thread);
_run_method(src, max, tmp_for_thread, dst, _beta, IS_LOG, window);
}
+
template <bool IS_LOG>
const char *CpuLogits1DSoftmaxKernel<IS_LOG>::name() const
{
return _name.c_str();
}
+
template class CpuLogits1DSoftmaxKernel<true>;
template class CpuLogits1DSoftmaxKernel<false>;
+
} // namespace kernels
} // namespace cpu
} // namespace arm_compute
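
One detail worth calling out in the softmax run_op above: all threads share a single tmp tensor, and each thread takes a disjoint byte slice of it based on its thread id. A self-contained model of that slicing follows; the function name and sizes are hypothetical, not from the patch:

#include <cstddef>
#include <vector>

// Hypothetical standalone model of the per-thread scratch slicing used by
// CpuLogits1DSoftmaxKernel::run_op: one flat buffer, one disjoint slice per thread.
unsigned char *tmp_slice_for_thread(std::vector<unsigned char> &tmp,
                                    std::size_t element_size,
                                    std::size_t elems_per_iteration,
                                    int thread_id)
{
    const std::size_t tmp_size_for_thread = element_size * elems_per_iteration;
    // The caller must size tmp to at least num_threads * tmp_size_for_thread,
    // matching the ARM_COMPUTE_ERROR_ON check in run_op.
    return tmp.data() + static_cast<std::size_t>(thread_id) * tmp_size_for_thread;
}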
diff --git a/src/cpu/kernels/CpuSoftmaxKernel.h b/src/cpu/kernels/CpuSoftmaxKernel.h
index df7d3f7d9b..59f43bd1d2 100644
--- a/src/cpu/kernels/CpuSoftmaxKernel.h
+++ b/src/cpu/kernels/CpuSoftmaxKernel.h
@@ -23,8 +23,10 @@
*/
#ifndef ARM_COMPUTE_CPU_SOFTMAX_KERNEL_H
#define ARM_COMPUTE_CPU_SOFTMAX_KERNEL_H
+
#include "src/core/common/Macros.h"
#include "src/cpu/ICpuKernel.h"
+
namespace arm_compute
{
namespace cpu
@@ -53,21 +55,25 @@ public:
* @return a status
*/
static Status validate(const ITensorInfo *src, const ITensorInfo *dst);
+
// Inherited methods overridden:
void run_op(ITensorPack &tensors, const Window &window, const ThreadInfo &info) override;
const char *name() const override;
+
struct SoftmaxLogits1DMaxKernel
{
const char *name;
const DataTypeISASelectorPtr is_selected;
SoftmaxLogits1DMaxKernelPtr ukernel;
};
+
static const std::vector<SoftmaxLogits1DMaxKernel> &get_available_kernels();
private:
SoftmaxLogits1DMaxKernelPtr _run_method{ nullptr };
std::string _name{};
};
+
/** Interface for softmax computation for QASYMM8 with pre-computed max. */
template <bool IS_LOG = false>
class CpuLogits1DSoftmaxKernel : public ICpuKernel<CpuLogits1DSoftmaxKernel<IS_LOG>>
@@ -78,6 +84,7 @@ private:
public:
CpuLogits1DSoftmaxKernel() = default;
ARM_COMPUTE_DISALLOW_COPY_ALLOW_MOVE(CpuLogits1DSoftmaxKernel);
+
/** Set the input and output tensors.
*
* @param[in] src Source tensor info. Data types supported: QASYMM8/QASYMM8_SIGNED/F16/F32.
@@ -97,15 +104,18 @@ public:
*/
static Status validate(const ITensorInfo *src, const ITensorInfo *max,
const ITensorInfo *dst, const float beta, const ITensorInfo *tmp);
+
// Inherited methods overridden:
void run_op(ITensorPack &tensors, const Window &window, const ThreadInfo &info) override;
const char *name() const override;
+
struct SoftmaxLogits1DKernel
{
const char *name;
const DataTypeISASelectorPtr is_selected;
SoftmaxLogits1DKernelPtr ukernel;
};
+
static const std::vector<SoftmaxLogits1DKernel> &get_available_kernels();
private:
diff --git a/tests/validation/NEON/ElementwiseKernelSelection.cpp b/tests/validation/NEON/ElementwiseKernelSelection.cpp
new file mode 100644
index 0000000000..78adc747fd
--- /dev/null
+++ b/tests/validation/NEON/ElementwiseKernelSelection.cpp
@@ -0,0 +1,158 @@
+/*
+ * Copyright (c) 2022 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#include "arm_compute/core/Types.h"
+#include "src/common/cpuinfo/CpuIsaInfo.h"
+#include "src/cpu/kernels/CpuElementwiseKernel.h"
+#include "src/cpu/kernels/CpuElementwiseUnaryKernel.h"
+#include "tests/framework/Asserts.h"
+#include "tests/framework/Macros.h"
+#include "tests/framework/datasets/Datasets.h"
+namespace arm_compute
+{
+namespace test
+{
+namespace validation
+{
+TEST_SUITE(NEON)
+TEST_SUITE(KernelSelection)
+
+DATA_TEST_CASE(KernelSelection_elementwise_unary, framework::DatasetMode::ALL, concat(
+ combine(framework::dataset::make("CpuExt", std::string("NEON")),
+ framework::dataset::make("DataType", { DataType::F32,
+ DataType::F16,
+ DataType::S32
+ })),
+ combine(framework::dataset::make("CpuExt", std::string("SVE")),
+ framework::dataset::make("DataType", { DataType::F32,
+ DataType::F16,
+ DataType::S32
+ }))),
+ cpu_ext, data_type)
+{
+ using namespace cpu::kernels;
+
+ cpuinfo::CpuIsaInfo cpu_isa{};
+ cpu_isa.neon = (cpu_ext == "NEON");
+ cpu_isa.sve = (cpu_ext == "SVE");
+ cpu_isa.fp16 = (data_type == DataType::F16);
+
+ const auto *selected_impl = CpuElementwiseUnaryKernel::get_implementation(DataTypeISASelectorData{ data_type, cpu_isa }, cpu::KernelSelectionType::Preferred);
+
+ ARM_COMPUTE_ERROR_ON_NULLPTR(selected_impl);
+
+ std::string expected = lower_string(cpu_ext) + "_" + cpu_impl_dt(data_type) + "_elementwise_unary";
+ std::string actual = selected_impl->name;
+
+ ARM_COMPUTE_EXPECT_EQUAL(expected, actual, framework::LogLevel::ERRORS);
+}
+
+DATA_TEST_CASE(KernelSelection_elementwise_arithmetic, framework::DatasetMode::ALL, concat(concat(
+ combine(framework::dataset::make("CpuExt", std::string("NEON")),
+ framework::dataset::make("DataType", { DataType::F32,
+ DataType::F16,
+ DataType::S32,
+ DataType::S16,
+ DataType::QASYMM8,
+ DataType::QASYMM8_SIGNED
+ })),
+ combine(framework::dataset::make("CpuExt", std::string("SVE")),
+ framework::dataset::make("DataType", { DataType::F32,
+ DataType::F16,
+ DataType::S32,
+ DataType::S16
+ }))),
+ combine(framework::dataset::make("CpuExt", std::string("SVE2")),
+ framework::dataset::make("DataType", { DataType::QASYMM8,
+ DataType::QASYMM8_SIGNED
+ }))),
+ cpu_ext, data_type)
+{
+ using namespace cpu::kernels;
+
+ cpuinfo::CpuIsaInfo cpu_isa{};
+ cpu_isa.neon = (cpu_ext == "NEON");
+ cpu_isa.sve = (cpu_ext == "SVE");
+ cpu_isa.sve2 = (cpu_ext == "SVE2");
+ cpu_isa.fp16 = (data_type == DataType::F16);
+
+ const auto *selected_impl = CpuArithmeticKernel::get_implementation(
+ ElementwiseDataTypeISASelectorData{ data_type, cpu_isa, static_cast<int>(ArithmeticOperation::ADD) },
+ cpu::KernelSelectionType::Preferred);
+
+ ARM_COMPUTE_ERROR_ON_NULLPTR(selected_impl);
+
+ std::string expected = lower_string(cpu_ext) + "_" + cpu_impl_dt(data_type) + "_arithmetic";
+ std::string actual = selected_impl->name;
+
+ ARM_COMPUTE_EXPECT_EQUAL(expected, actual, framework::LogLevel::ERRORS);
+}
+
+DATA_TEST_CASE(KernelSelection_elementwise_comparison, framework::DatasetMode::ALL, concat(concat(
+ combine(framework::dataset::make("CpuExt", std::string("NEON")),
+ framework::dataset::make("DataType", { DataType::F32,
+ DataType::F16,
+ DataType::S32,
+ DataType::S16,
+ DataType::U8,
+ DataType::QASYMM8,
+ DataType::QASYMM8_SIGNED
+ })),
+ combine(framework::dataset::make("CpuExt", std::string("SVE")),
+ framework::dataset::make("DataType", { DataType::F32,
+ DataType::F16,
+ DataType::S32,
+ DataType::S16,
+ DataType::U8
+ }))),
+ combine(framework::dataset::make("CpuExt", std::string("SVE2")),
+ framework::dataset::make("DataType", { DataType::QASYMM8,
+ DataType::QASYMM8_SIGNED
+ }))),
+ cpu_ext, data_type)
+{
+ using namespace cpu::kernels;
+
+ cpuinfo::CpuIsaInfo cpu_isa{};
+ cpu_isa.neon = (cpu_ext == "NEON");
+ cpu_isa.sve = (cpu_ext == "SVE");
+ cpu_isa.sve2 = (cpu_ext == "SVE2");
+ cpu_isa.fp16 = (data_type == DataType::F16);
+
+ const auto *selected_impl = CpuComparisonKernel::get_implementation(
+ ElementwiseDataTypeISASelectorData{ data_type, cpu_isa, static_cast<int>(ComparisonOperation::Equal) },
+ cpu::KernelSelectionType::Preferred);
+
+ ARM_COMPUTE_ERROR_ON_NULLPTR(selected_impl);
+
+ std::string expected = lower_string(cpu_ext) + "_" + cpu_impl_dt(data_type) + "_comparison";
+ std::string actual = selected_impl->name;
+
+ ARM_COMPUTE_EXPECT_EQUAL(expected, actual, framework::LogLevel::ERRORS);
+}
+
+TEST_SUITE_END() // KernelSelection
+TEST_SUITE_END() // NEON
+} // namespace validation
+} // namespace test
+} // namespace arm_compute
diff --git a/tests/validation/NEON/SoftmaxLayer.cpp b/tests/validation/NEON/SoftmaxLayer.cpp
index 2a9e30604e..9084353743 100644
--- a/tests/validation/NEON/SoftmaxLayer.cpp
+++ b/tests/validation/NEON/SoftmaxLayer.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2020 Arm Limited.
+ * Copyright (c) 2017-2020, 2022 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -25,6 +25,8 @@
#include "arm_compute/runtime/NEON/functions/NESoftmaxLayer.h"
#include "arm_compute/runtime/Tensor.h"
#include "arm_compute/runtime/TensorAllocator.h"
+#include "src/common/cpuinfo/CpuIsaInfo.h"
+#include "src/cpu/kernels/CpuSoftmaxKernel.h"
#include "tests/NEON/Accessor.h"
#include "tests/PaddingCalculator.h"
#include "tests/datasets/ShapeDatasets.h"
@@ -33,7 +35,6 @@
#include "tests/framework/datasets/Datasets.h"
#include "tests/validation/Validation.h"
#include "tests/validation/fixtures/SoftmaxLayerFixture.h"
-
namespace arm_compute
{
namespace test
@@ -62,7 +63,6 @@ const auto CNNDataTypes = framework::dataset::make("DataType",
TEST_SUITE(NEON)
TEST_SUITE(SoftmaxLayer)
-
// *INDENT-OFF*
// clang-format off
DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip(zip(
@@ -121,6 +121,73 @@ DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip(zip(
template <typename T>
using NESoftmaxLayerFixture = SoftmaxValidationFixture<Tensor, Accessor, NESoftmaxLayer, T>;
+DATA_TEST_CASE(KernelSelection_max_logits, framework::DatasetMode::ALL, concat(
+ combine(framework::dataset::make("CpuExt", std::string("NEON")),
+ framework::dataset::make("DataType", { DataType::F32,
+ DataType::F16,
+ DataType::QASYMM8,
+ DataType::QASYMM8_SIGNED
+ })),
+ combine(framework::dataset::make("CpuExt", std::string("SVE")),
+ framework::dataset::make("DataType", { DataType::F32,
+ DataType::F16,
+ DataType::QASYMM8,
+ DataType::QASYMM8_SIGNED
+ }))),
+ cpu_ext, data_type)
+{
+ using namespace cpu::kernels;
+
+ cpuinfo::CpuIsaInfo cpu_isa{};
+ cpu_isa.neon = (cpu_ext == "NEON");
+ cpu_isa.sve = (cpu_ext == "SVE");
+ cpu_isa.fp16 = (data_type == DataType::F16);
+
+ const auto *selected_impl = CpuLogits1DMaxKernel::get_implementation(DataTypeISASelectorData{ data_type, cpu_isa }, cpu::KernelSelectionType::Preferred);
+
+ ARM_COMPUTE_ERROR_ON_NULLPTR(selected_impl);
+
+ std::string expected = lower_string(cpu_ext) + "_" + cpu_impl_dt(data_type) + "_logits_1d_max";
+ std::string actual = selected_impl->name;
+
+ ARM_COMPUTE_EXPECT_EQUAL(expected, actual, framework::LogLevel::ERRORS);
+}
+
+DATA_TEST_CASE(KernelSelection_logits, framework::DatasetMode::ALL, concat(concat(
+ combine(framework::dataset::make("CpuExt", std::string("NEON")),
+ framework::dataset::make("DataType", { DataType::F32,
+ DataType::F16,
+ DataType::QASYMM8,
+ DataType::QASYMM8_SIGNED
+ })),
+ combine(framework::dataset::make("CpuExt", std::string("SVE")),
+ framework::dataset::make("DataType", { DataType::F32,
+ DataType::F16
+ }))),
+ combine(framework::dataset::make("CpuExt", std::string("SVE2")),
+ framework::dataset::make("DataType", { DataType::QASYMM8,
+ DataType::QASYMM8_SIGNED
+ }))),
+ cpu_ext, data_type)
+{
+ using namespace cpu::kernels;
+
+ cpuinfo::CpuIsaInfo cpu_isa{};
+ cpu_isa.neon = (cpu_ext == "NEON");
+ cpu_isa.sve = (cpu_ext == "SVE");
+ cpu_isa.sve2 = (cpu_ext == "SVE2");
+ cpu_isa.fp16 = (data_type == DataType::F16);
+
+ const auto *selected_impl = CpuLogits1DSoftmaxKernel<false>::get_implementation(DataTypeISASelectorData{ data_type, cpu_isa }, cpu::KernelSelectionType::Preferred);
+
+ ARM_COMPUTE_ERROR_ON_NULLPTR(selected_impl);
+
+ std::string expected = lower_string(cpu_ext) + "_" + cpu_impl_dt(data_type) + "_softmax_logits_1d";
+ std::string actual = selected_impl->name;
+
+ ARM_COMPUTE_EXPECT_EQUAL(expected, actual, framework::LogLevel::ERRORS);
+}
+
TEST_SUITE(Float)
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
TEST_SUITE(FP16)