diff options
Diffstat (limited to 'src/cpu/kernels/CpuScaleKernel.cpp')
-rw-r--r-- | src/cpu/kernels/CpuScaleKernel.cpp | 72 |
1 files changed, 23 insertions, 49 deletions
diff --git a/src/cpu/kernels/CpuScaleKernel.cpp b/src/cpu/kernels/CpuScaleKernel.cpp index 3063d8f682..60564a97dd 100644 --- a/src/cpu/kernels/CpuScaleKernel.cpp +++ b/src/cpu/kernels/CpuScaleKernel.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2016-2021 Arm Limited. + * Copyright (c) 2016-2022 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -48,52 +48,37 @@ namespace kernels { namespace { -struct ScaleSelectorData -{ - DataType dt; - const CPUInfo &ci; -}; -using ScaleSelectorPtr = std::add_pointer<bool(const ScaleSelectorData &data)>::type; -using ScaleKernelPtr = std::add_pointer<void(const ITensor *, ITensor *, const ITensor *, const ITensor *, const ITensor *, - InterpolationPolicy, BorderMode, PixelValue, float, bool, const Window &)>::type; -struct ScaleKernel -{ - const char *name; - const ScaleSelectorPtr is_selected; - ScaleKernelPtr ukernel; -}; - -static const ScaleKernel available_kernels[] = +static const std::vector<CpuScaleKernel::ScaleKernel> available_kernels = { #if defined(ARM_COMPUTE_ENABLE_SVE) { "sve_fp16_scale", - [](const ScaleSelectorData & data) { return data.dt == DataType::F16 && data.ci.has_sve(); }, + [](const DataTypeISASelectorData & data) { return data.dt == DataType::F16 && data.isa.sve; }, REGISTER_FP16_SVE(arm_compute::cpu::fp16_sve_scale) }, { "sve_fp32_scale", - [](const ScaleSelectorData & data) { return data.dt == DataType::F32 && data.ci.has_sve(); }, + [](const DataTypeISASelectorData & data) { return data.dt == DataType::F32 && data.isa.sve; }, REGISTER_FP32_SVE(arm_compute::cpu::fp32_sve_scale) }, { "sve_qu8_scale", - [](const ScaleSelectorData & data) { return data.dt == DataType::QASYMM8 && data.ci.has_sve(); }, + [](const DataTypeISASelectorData & data) { return data.dt == DataType::QASYMM8 && data.isa.sve; }, REGISTER_QASYMM8_SVE(arm_compute::cpu::qasymm8_sve_scale) }, { "sve_qs8_scale", - [](const ScaleSelectorData & data) { return data.dt == DataType::QASYMM8_SIGNED && data.ci.has_sve(); }, + [](const DataTypeISASelectorData & data) { return data.dt == DataType::QASYMM8_SIGNED && data.isa.sve; }, REGISTER_QASYMM8_SIGNED_SVE(arm_compute::cpu::qasymm8_signed_sve_scale) }, { "sve_u8_scale", - [](const ScaleSelectorData & data) { return data.dt == DataType::U8 && data.ci.has_sve(); }, + [](const DataTypeISASelectorData & data) { return data.dt == DataType::U8 && data.isa.sve; }, REGISTER_INTEGER_SVE(arm_compute::cpu::u8_sve_scale) }, { "sve_s16_scale", - [](const ScaleSelectorData & data) { return data.dt == DataType::S16 && data.ci.has_sve(); }, + [](const DataTypeISASelectorData & data) { return data.dt == DataType::S16 && data.isa.sve; }, REGISTER_INTEGER_SVE(arm_compute::cpu::s16_sve_scale) }, #endif /* defined(ARM_COMPUTE_ENABLE_SVE) */ @@ -101,60 +86,43 @@ static const ScaleKernel available_kernels[] = #if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) { "neon_fp16_scale", - [](const ScaleSelectorData & data) { return data.dt == DataType::F16 && data.ci.has_fp16(); }, + [](const DataTypeISASelectorData & data) { return data.dt == DataType::F16 && data.isa.fp16; }, REGISTER_FP16_NEON(arm_compute::cpu::common_neon_scale<float16_t>) }, #endif /* !defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) */ { "neon_fp32_scale", - [](const ScaleSelectorData & data) { return data.dt == DataType::F32; }, + [](const DataTypeISASelectorData & data) { return data.dt == DataType::F32; }, REGISTER_FP32_NEON(arm_compute::cpu::common_neon_scale<float>) }, { "neon_qu8_scale", - [](const ScaleSelectorData & data) { return data.dt == DataType::QASYMM8; }, + [](const DataTypeISASelectorData & data) { return data.dt == DataType::QASYMM8; }, REGISTER_QASYMM8_NEON(arm_compute::cpu::qasymm8_neon_scale) }, { "neon_qs8_scale", - [](const ScaleSelectorData & data) { return data.dt == DataType::QASYMM8_SIGNED; }, + [](const DataTypeISASelectorData & data) { return data.dt == DataType::QASYMM8_SIGNED; }, REGISTER_QASYMM8_SIGNED_NEON(arm_compute::cpu::qasymm8_signed_neon_scale) }, { "neon_u8_scale", - [](const ScaleSelectorData & data) { return data.dt == DataType::U8; }, + [](const DataTypeISASelectorData & data) { return data.dt == DataType::U8; }, REGISTER_INTEGER_NEON(arm_compute::cpu::u8_neon_scale) }, { "neon_s16_scale", - [](const ScaleSelectorData & data) { return data.dt == DataType::S16; }, + [](const DataTypeISASelectorData & data) { return data.dt == DataType::S16; }, REGISTER_INTEGER_NEON(arm_compute::cpu::s16_neon_scale) }, #endif /* defined(ARM_COMPUTE_ENABLE_NEON) */ }; -/** Micro-kernel selector - * - * @param[in] data Selection data passed to help pick the appropriate micro-kernel - * - * @return A matching micro-kernel else nullptr - */ -const ScaleKernel *get_implementation(const ScaleSelectorData &data) -{ - for(const auto &uk : available_kernels) - { - if(uk.is_selected(data)) - { - return &uk; - } - } - return nullptr; -} - Status validate_arguments(const ITensorInfo *src, const ITensorInfo *dx, const ITensorInfo *dy, const ITensorInfo *offsets, ITensorInfo *dst, const ScaleKernelInfo &info) { - const auto *uk = get_implementation(ScaleSelectorData{ src->data_type(), CPUInfo::get() }); + const auto *uk = CpuScaleKernel::get_implementation(DataTypeISASelectorData{ src->data_type(), CPUInfo::get().get_isa() }); + ARM_COMPUTE_RETURN_ERROR_ON(uk == nullptr || uk->ukernel == nullptr); ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(dst); @@ -212,7 +180,7 @@ void CpuScaleKernel::configure(const ITensorInfo *src, const ITensorInfo *dx, co dst, info)); - const auto *uk = get_implementation(ScaleSelectorData{ src->data_type(), CPUInfo::get() }); + const auto *uk = CpuScaleKernel::get_implementation(DataTypeISASelectorData{ src->data_type(), CPUInfo::get().get_isa() }); ARM_COMPUTE_ERROR_ON_NULLPTR(uk); _run_method = uk->ukernel; @@ -618,6 +586,12 @@ const char *CpuScaleKernel::name() const { return _name.c_str(); } + +const std::vector<CpuScaleKernel::ScaleKernel> &CpuScaleKernel::get_available_kernels() +{ + return available_kernels; +} + } // namespace kernels } // namespace cpu } // namespace arm_compute |