diff options
Diffstat (limited to 'src/cpu/kernels/CpuSubKernel.cpp')
-rw-r--r-- | src/cpu/kernels/CpuSubKernel.cpp | 45 |
1 files changed, 31 insertions, 14 deletions
diff --git a/src/cpu/kernels/CpuSubKernel.cpp b/src/cpu/kernels/CpuSubKernel.cpp index 37a087f115..875d613dca 100644 --- a/src/cpu/kernels/CpuSubKernel.cpp +++ b/src/cpu/kernels/CpuSubKernel.cpp @@ -29,14 +29,15 @@ #include "src/core/common/Registrars.h" #include "src/core/helpers/AutoConfiguration.h" #include "src/core/helpers/WindowHelpers.h" +#include "src/cpu/kernels/add/generic/neon/impl.h" #include "src/cpu/kernels/sub/neon/list.h" #if defined(ENABLE_FP32_KERNELS) namespace { - static constexpr size_t default_mws_N1_fp32_neon = 24385; - static constexpr size_t default_mws_V1_fp32_neon = 40520; -} +static constexpr size_t default_mws_N1_fp32_neon = 24385; +static constexpr size_t default_mws_V1_fp32_neon = 40520; +} // namespace #endif /* ENABLE_FP32_KERNELS */ namespace arm_compute @@ -47,46 +48,59 @@ namespace kernels { namespace { +using CpuSubKernelDataTypeISASelectorData = CpuAddKernelDataTypeISASelectorData; +using CpuSubKernelDataTypeISASelectorDataPtr = CpuAddKernelDataTypeISASelectorDataPtr; + static const std::vector<CpuSubKernel::SubKernel> available_kernels = { { "neon_fp32_sub", - [](const DataTypeISASelectorData & data) { return (data.dt == DataType::F32); }, + [](const CpuSubKernelDataTypeISASelectorData & data) { return (data.dt == DataType::F32); }, REGISTER_FP32_NEON(arm_compute::cpu::sub_same_neon<float>) }, { "neon_fp16_sub", - [](const DataTypeISASelectorData & data) { return (data.dt == DataType::F16) && data.isa.fp16; }, + [](const CpuSubKernelDataTypeISASelectorData & data) { return (data.dt == DataType::F16) && data.isa.fp16; }, REGISTER_FP16_NEON(arm_compute::cpu::sub_same_neon<float16_t>) }, { "neon_u8_sub", - [](const DataTypeISASelectorData & data) { return (data.dt == DataType::U8); }, + [](const CpuSubKernelDataTypeISASelectorData & data) { return (data.dt == DataType::U8); }, REGISTER_INTEGER_NEON(arm_compute::cpu::sub_same_neon<uint8_t>) }, { "neon_s16_sub", - [](const DataTypeISASelectorData & data) { return (data.dt == DataType::S16); }, + [](const CpuSubKernelDataTypeISASelectorData & data) { return (data.dt == DataType::S16); }, REGISTER_INTEGER_NEON(arm_compute::cpu::sub_same_neon<int16_t>) }, { "neon_s32_sub", - [](const DataTypeISASelectorData & data) { return (data.dt == DataType::S32); }, + [](const CpuSubKernelDataTypeISASelectorData & data) { return (data.dt == DataType::S32); }, REGISTER_INTEGER_NEON(arm_compute::cpu::sub_same_neon<int32_t>) }, { + "neon_qu8_sub_fixedpoint", + [](const CpuSubKernelDataTypeISASelectorData & data) { return ((data.dt == DataType::QASYMM8) && data.can_use_fixedpoint); }, + REGISTER_QASYMM8_NEON(arm_compute::cpu::sub_qasymm8_neon_fixedpoint) + }, + { + "neon_qs8_sub_fixedpoint", + [](const CpuSubKernelDataTypeISASelectorData & data) { return ((data.dt == DataType::QASYMM8_SIGNED) && data.can_use_fixedpoint); }, + REGISTER_QASYMM8_SIGNED_NEON(arm_compute::cpu::sub_qasymm8_signed_neon_fixedpoint) + }, + { "neon_qu8_sub", - [](const DataTypeISASelectorData & data) { return (data.dt == DataType::QASYMM8); }, + [](const CpuSubKernelDataTypeISASelectorData & data) { return (data.dt == DataType::QASYMM8); }, REGISTER_QASYMM8_NEON(arm_compute::cpu::sub_qasymm8_neon) }, { "neon_qs8_sub", - [](const DataTypeISASelectorData & data) { return (data.dt == DataType::QASYMM8_SIGNED); }, + [](const CpuSubKernelDataTypeISASelectorData & data) { return (data.dt == DataType::QASYMM8_SIGNED); }, REGISTER_QASYMM8_SIGNED_NEON(arm_compute::cpu::sub_qasymm8_signed_neon) }, { "neon_qs16_sub", - [](const DataTypeISASelectorData & data) { return (data.dt == DataType::QSYMM16); }, + [](const CpuSubKernelDataTypeISASelectorData & data) { return (data.dt == DataType::QSYMM16); }, REGISTER_QSYMM16_NEON(arm_compute::cpu::sub_qsymm16_neon) }, }; @@ -99,7 +113,8 @@ inline Status validate_arguments(const ITensorInfo &src0, const ITensorInfo &src DataType::F32); ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(&src0, &src1); - const auto *uk = CpuSubKernel::get_implementation(DataTypeISASelectorData{ src0.data_type(), CPUInfo::get().get_isa() }); + const auto can_use_fixedpoint = sub_q8_neon_fixedpoint_possible(&src0, &src1, &dst); + const auto uk = CpuSubKernel::get_implementation<CpuSubKernelDataTypeISASelectorData>(CpuSubKernelDataTypeISASelectorData{ src0.data_type(), CPUInfo::get().get_isa(), can_use_fixedpoint }); ARM_COMPUTE_RETURN_ERROR_ON(uk == nullptr || uk->ukernel == nullptr); @@ -131,7 +146,9 @@ void CpuSubKernel::configure(const ITensorInfo *src0, const ITensorInfo *src1, I set_shape_if_empty(*dst, out_shape); set_data_type_if_unknown(*dst, src0->data_type()); - const auto *uk = CpuSubKernel::get_implementation(DataTypeISASelectorData{ src0->data_type(), CPUInfo::get().get_isa() }); + const auto can_use_fixedpoint = sub_q8_neon_fixedpoint_possible(src0, src1, dst); + const auto uk = CpuSubKernel::get_implementation<CpuSubKernelDataTypeISASelectorData>(CpuSubKernelDataTypeISASelectorData{ src0->data_type(), CPUInfo::get().get_isa(), can_use_fixedpoint }); + ARM_COMPUTE_ERROR_ON_NULLPTR(uk); _policy = policy; @@ -180,7 +197,7 @@ size_t CpuSubKernel::get_mws(const CPUInfo &platform, size_t thread_count) const return std::max(static_cast<size_t>(1), mws); } } -#else /* ENABLE_FP32_KERNELS */ +#else /* ENABLE_FP32_KERNELS */ ARM_COMPUTE_UNUSED(platform); #endif /* ENABLE_FP32_KERNELS */ return ICPPKernel::default_mws; |