aboutsummaryrefslogtreecommitdiff
path: root/src/cpu/kernels/CpuSubKernel.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/cpu/kernels/CpuSubKernel.cpp')
-rw-r--r--src/cpu/kernels/CpuSubKernel.cpp64
1 files changed, 19 insertions, 45 deletions
diff --git a/src/cpu/kernels/CpuSubKernel.cpp b/src/cpu/kernels/CpuSubKernel.cpp
index ec65f12dfc..c12feb4331 100644
--- a/src/cpu/kernels/CpuSubKernel.cpp
+++ b/src/cpu/kernels/CpuSubKernel.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2021 Arm Limited.
+ * Copyright (c) 2021-2022 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -39,85 +39,52 @@ namespace kernels
{
namespace
{
-struct SubSelectorData
-{
- DataType dt;
-};
-
-using SubSelectorPtr = std::add_pointer<bool(const SubSelectorData &data)>::type;
-using SubKernelPtr = std::add_pointer<void(const ITensor *, const ITensor *, ITensor *, const ConvertPolicy &, const Window &)>::type;
-
-struct SubKernel
-{
- const char *name;
- const SubSelectorPtr is_selected;
- SubKernelPtr ukernel;
-};
-
-static const SubKernel available_kernels[] =
+static const std::vector<CpuSubKernel::SubKernel> available_kernels =
{
{
"neon_fp32_sub",
- [](const SubSelectorData & data) { return (data.dt == DataType::F32); },
+ [](const DataTypeISASelectorData & data) { return (data.dt == DataType::F32); },
REGISTER_FP32_NEON(arm_compute::cpu::sub_same_neon<float>)
},
#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
{
"neon_fp16_sub",
- [](const SubSelectorData & data) { return (data.dt == DataType::F16); },
+ [](const DataTypeISASelectorData & data) { return (data.dt == DataType::F16) && data.isa.fp16; },
REGISTER_FP16_NEON(arm_compute::cpu::sub_same_neon<float16_t>)
},
#endif /* defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) */
{
"neon_u8_sub",
- [](const SubSelectorData & data) { return (data.dt == DataType::U8); },
+ [](const DataTypeISASelectorData & data) { return (data.dt == DataType::U8); },
REGISTER_INTEGER_NEON(arm_compute::cpu::sub_same_neon<uint8_t>)
},
{
"neon_s16_sub",
- [](const SubSelectorData & data) { return (data.dt == DataType::S16); },
+ [](const DataTypeISASelectorData & data) { return (data.dt == DataType::S16); },
REGISTER_INTEGER_NEON(arm_compute::cpu::sub_same_neon<int16_t>)
},
{
"neon_s32_sub",
- [](const SubSelectorData & data) { return (data.dt == DataType::S32); },
+ [](const DataTypeISASelectorData & data) { return (data.dt == DataType::S32); },
REGISTER_INTEGER_NEON(arm_compute::cpu::sub_same_neon<int32_t>)
},
{
"neon_qu8_sub",
- [](const SubSelectorData & data) { return (data.dt == DataType::QASYMM8); },
+ [](const DataTypeISASelectorData & data) { return (data.dt == DataType::QASYMM8); },
REGISTER_QASYMM8_NEON(arm_compute::cpu::sub_qasymm8_neon)
},
{
"neon_qs8_sub",
- [](const SubSelectorData & data) { return (data.dt == DataType::QASYMM8_SIGNED); },
+ [](const DataTypeISASelectorData & data) { return (data.dt == DataType::QASYMM8_SIGNED); },
REGISTER_QASYMM8_SIGNED_NEON(arm_compute::cpu::sub_qasymm8_signed_neon)
},
{
"neon_qs16_sub",
- [](const SubSelectorData & data) { return (data.dt == DataType::QSYMM16); },
+ [](const DataTypeISASelectorData & data) { return (data.dt == DataType::QSYMM16); },
REGISTER_QSYMM16_NEON(arm_compute::cpu::sub_qsymm16_neon)
},
};
-/** Micro-kernel selector
- *
- * @param[in] data Selection data passed to help pick the appropriate micro-kernel
- *
- * @return A matching micro-kernel else nullptr
- */
-const SubKernel *get_implementation(DataType dt)
-{
- for(const auto &uk : available_kernels)
- {
- if(uk.is_selected({ dt }))
- {
- return &uk;
- }
- }
- return nullptr;
-}
-
inline Status validate_arguments(const ITensorInfo &src0, const ITensorInfo &src1, const ITensorInfo &dst, ConvertPolicy policy)
{
ARM_COMPUTE_UNUSED(policy);
@@ -126,7 +93,8 @@ inline Status validate_arguments(const ITensorInfo &src0, const ITensorInfo &src
DataType::F32);
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(&src0, &src1);
- const auto *uk = get_implementation(src0.data_type());
+ const auto *uk = CpuSubKernel::get_implementation(DataTypeISASelectorData{ src0.data_type(), CPUInfo::get().get_isa() });
+
ARM_COMPUTE_RETURN_ERROR_ON(uk == nullptr || uk->ukernel == nullptr);
const TensorShape out_shape = TensorShape::broadcast_shape(src0.tensor_shape(), src1.tensor_shape());
@@ -157,7 +125,7 @@ void CpuSubKernel::configure(const ITensorInfo *src0, const ITensorInfo *src1, I
set_shape_if_empty(*dst, out_shape);
set_data_type_if_unknown(*dst, src0->data_type());
- const auto *uk = get_implementation(src0->data_type());
+ const auto *uk = CpuSubKernel::get_implementation(DataTypeISASelectorData{ src0->data_type(), CPUInfo::get().get_isa() });
ARM_COMPUTE_ERROR_ON_NULLPTR(uk);
_policy = policy;
@@ -196,6 +164,12 @@ const char *CpuSubKernel::name() const
{
return _name.c_str();
}
+
+const std::vector<CpuSubKernel::SubKernel> &CpuSubKernel::get_available_kernels()
+{
+ return available_kernels;
+}
+
} // namespace kernels
} // namespace cpu
} // namespace arm_compute