diff options
Diffstat (limited to 'src/cpu/kernels/CpuSubKernel.cpp')
-rw-r--r-- | src/cpu/kernels/CpuSubKernel.cpp | 64 |
1 files changed, 19 insertions, 45 deletions
diff --git a/src/cpu/kernels/CpuSubKernel.cpp b/src/cpu/kernels/CpuSubKernel.cpp index ec65f12dfc..c12feb4331 100644 --- a/src/cpu/kernels/CpuSubKernel.cpp +++ b/src/cpu/kernels/CpuSubKernel.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021 Arm Limited. + * Copyright (c) 2021-2022 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -39,85 +39,52 @@ namespace kernels { namespace { -struct SubSelectorData -{ - DataType dt; -}; - -using SubSelectorPtr = std::add_pointer<bool(const SubSelectorData &data)>::type; -using SubKernelPtr = std::add_pointer<void(const ITensor *, const ITensor *, ITensor *, const ConvertPolicy &, const Window &)>::type; - -struct SubKernel -{ - const char *name; - const SubSelectorPtr is_selected; - SubKernelPtr ukernel; -}; - -static const SubKernel available_kernels[] = +static const std::vector<CpuSubKernel::SubKernel> available_kernels = { { "neon_fp32_sub", - [](const SubSelectorData & data) { return (data.dt == DataType::F32); }, + [](const DataTypeISASelectorData & data) { return (data.dt == DataType::F32); }, REGISTER_FP32_NEON(arm_compute::cpu::sub_same_neon<float>) }, #if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) { "neon_fp16_sub", - [](const SubSelectorData & data) { return (data.dt == DataType::F16); }, + [](const DataTypeISASelectorData & data) { return (data.dt == DataType::F16) && data.isa.fp16; }, REGISTER_FP16_NEON(arm_compute::cpu::sub_same_neon<float16_t>) }, #endif /* defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) */ { "neon_u8_sub", - [](const SubSelectorData & data) { return (data.dt == DataType::U8); }, + [](const DataTypeISASelectorData & data) { return (data.dt == DataType::U8); }, REGISTER_INTEGER_NEON(arm_compute::cpu::sub_same_neon<uint8_t>) }, { "neon_s16_sub", - [](const SubSelectorData & data) { return (data.dt == DataType::S16); }, + [](const DataTypeISASelectorData & data) { return (data.dt == DataType::S16); }, REGISTER_INTEGER_NEON(arm_compute::cpu::sub_same_neon<int16_t>) }, { "neon_s32_sub", - [](const SubSelectorData & data) { return (data.dt == DataType::S32); }, + [](const DataTypeISASelectorData & data) { return (data.dt == DataType::S32); }, REGISTER_INTEGER_NEON(arm_compute::cpu::sub_same_neon<int32_t>) }, { "neon_qu8_sub", - [](const SubSelectorData & data) { return (data.dt == DataType::QASYMM8); }, + [](const DataTypeISASelectorData & data) { return (data.dt == DataType::QASYMM8); }, REGISTER_QASYMM8_NEON(arm_compute::cpu::sub_qasymm8_neon) }, { "neon_qs8_sub", - [](const SubSelectorData & data) { return (data.dt == DataType::QASYMM8_SIGNED); }, + [](const DataTypeISASelectorData & data) { return (data.dt == DataType::QASYMM8_SIGNED); }, REGISTER_QASYMM8_SIGNED_NEON(arm_compute::cpu::sub_qasymm8_signed_neon) }, { "neon_qs16_sub", - [](const SubSelectorData & data) { return (data.dt == DataType::QSYMM16); }, + [](const DataTypeISASelectorData & data) { return (data.dt == DataType::QSYMM16); }, REGISTER_QSYMM16_NEON(arm_compute::cpu::sub_qsymm16_neon) }, }; -/** Micro-kernel selector - * - * @param[in] data Selection data passed to help pick the appropriate micro-kernel - * - * @return A matching micro-kernel else nullptr - */ -const SubKernel *get_implementation(DataType dt) -{ - for(const auto &uk : available_kernels) - { - if(uk.is_selected({ dt })) - { - return &uk; - } - } - return nullptr; -} - inline Status validate_arguments(const ITensorInfo &src0, const ITensorInfo &src1, const ITensorInfo &dst, ConvertPolicy policy) { ARM_COMPUTE_UNUSED(policy); @@ -126,7 +93,8 @@ inline Status validate_arguments(const ITensorInfo &src0, const ITensorInfo &src DataType::F32); ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(&src0, &src1); - const auto *uk = get_implementation(src0.data_type()); + const auto *uk = CpuSubKernel::get_implementation(DataTypeISASelectorData{ src0.data_type(), CPUInfo::get().get_isa() }); + ARM_COMPUTE_RETURN_ERROR_ON(uk == nullptr || uk->ukernel == nullptr); const TensorShape out_shape = TensorShape::broadcast_shape(src0.tensor_shape(), src1.tensor_shape()); @@ -157,7 +125,7 @@ void CpuSubKernel::configure(const ITensorInfo *src0, const ITensorInfo *src1, I set_shape_if_empty(*dst, out_shape); set_data_type_if_unknown(*dst, src0->data_type()); - const auto *uk = get_implementation(src0->data_type()); + const auto *uk = CpuSubKernel::get_implementation(DataTypeISASelectorData{ src0->data_type(), CPUInfo::get().get_isa() }); ARM_COMPUTE_ERROR_ON_NULLPTR(uk); _policy = policy; @@ -196,6 +164,12 @@ const char *CpuSubKernel::name() const { return _name.c_str(); } + +const std::vector<CpuSubKernel::SubKernel> &CpuSubKernel::get_available_kernels() +{ + return available_kernels; +} + } // namespace kernels } // namespace cpu } // namespace arm_compute |