diff options
Diffstat (limited to 'src/core/cpu/kernels/CpuScaleKernel.cpp')
-rw-r--r-- | src/core/cpu/kernels/CpuScaleKernel.cpp | 29 |
1 files changed, 15 insertions, 14 deletions
diff --git a/src/core/cpu/kernels/CpuScaleKernel.cpp b/src/core/cpu/kernels/CpuScaleKernel.cpp index 29475fa63f..a072dbd896 100644 --- a/src/core/cpu/kernels/CpuScaleKernel.cpp +++ b/src/core/cpu/kernels/CpuScaleKernel.cpp @@ -50,7 +50,8 @@ namespace { struct ScaleSelectorData { - DataType dt; + DataType dt; + const CPUInfo &ci; }; using ScaleSelectorPtr = std::add_pointer<bool(const ScaleSelectorData &data)>::type; using ScaleKernelPtr = std::add_pointer<void(const ITensor *, ITensor *, const ITensor *, const ITensor *, const ITensor *, @@ -64,43 +65,43 @@ struct ScaleKernel static const ScaleKernel available_kernels[] = { -#if defined(ENABLE_SVE) +#if defined(ARM_COMPUTE_ENABLE_SVE) { "fp16_sve_scale", - [](const ScaleSelectorData & data) { return data.dt == DataType::F16; }, + [](const ScaleSelectorData & data) { return data.dt == DataType::F16 && data.ci.has_sve(); }, REGISTER_FP16_SVE(arm_compute::cpu::fp16_sve_scale) }, { "f32_sve_scale", - [](const ScaleSelectorData & data) { return data.dt == DataType::F32; }, + [](const ScaleSelectorData & data) { return data.dt == DataType::F32 && data.ci.has_sve(); }, REGISTER_FP32_SVE(arm_compute::cpu::fp32_sve_scale) }, { "qasymm8_sve_scale", - [](const ScaleSelectorData & data) { return data.dt == DataType::QASYMM8; }, + [](const ScaleSelectorData & data) { return data.dt == DataType::QASYMM8 && data.ci.has_sve(); }, REGISTER_QASYMM8_SVE(arm_compute::cpu::qasymm8_sve_scale) }, { "qasymm8_signed_sve_scale", - [](const ScaleSelectorData & data) { return data.dt == DataType::QASYMM8_SIGNED; }, + [](const ScaleSelectorData & data) { return data.dt == DataType::QASYMM8_SIGNED && data.ci.has_sve(); }, REGISTER_QASYMM8_SIGNED_SVE(arm_compute::cpu::qasymm8_signed_sve_scale) }, { "u8_sve_scale", - [](const ScaleSelectorData & data) { return data.dt == DataType::U8; }, + [](const ScaleSelectorData & data) { return data.dt == DataType::U8 && data.ci.has_sve(); }, REGISTER_INTEGER_SVE(arm_compute::cpu::u8_sve_scale) }, { "s16_sve_scale", - [](const ScaleSelectorData & data) { return data.dt == DataType::S16; }, + [](const ScaleSelectorData & data) { return data.dt == DataType::S16 && data.ci.has_sve(); }, REGISTER_INTEGER_SVE(arm_compute::cpu::s16_sve_scale) }, -#endif /* defined(ENABLE_SVE) */ -#if defined(ENABLE_NEON) +#endif /* defined(ARM_COMPUTE_ENABLE_SVE) */ +#if defined(ARM_COMPUTE_ENABLE_NEON) #if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) { "common_neon_scale", - [](const ScaleSelectorData & data) { return data.dt == DataType::F16; }, + [](const ScaleSelectorData & data) { return data.dt == DataType::F16 && data.ci.has_fp16(); }, REGISTER_FP16_NEON(arm_compute::cpu::common_neon_scale<float16_t>) }, #endif /* !defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) */ @@ -129,7 +130,7 @@ static const ScaleKernel available_kernels[] = [](const ScaleSelectorData & data) { return data.dt == DataType::S16; }, REGISTER_INTEGER_NEON(arm_compute::cpu::common_neon_scale<int16_t>) }, -#endif /* defined(ENABLE_NEON) */ +#endif /* defined(ARM_COMPUTE_ENABLE_NEON) */ }; /** Micro-kernel selector @@ -153,7 +154,7 @@ const ScaleKernel *get_implementation(const ScaleSelectorData &data) Status validate_arguments(const ITensorInfo *src, const ITensorInfo *dx, const ITensorInfo *dy, const ITensorInfo *offsets, ITensorInfo *dst, const ScaleKernelInfo &info) { - const auto *uk = get_implementation(ScaleSelectorData{ src->data_type() }); + const auto *uk = get_implementation(ScaleSelectorData{ src->data_type(), CPUInfo::get() }); ARM_COMPUTE_RETURN_ERROR_ON(uk == nullptr || uk->ukernel == nullptr); ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(dst); @@ -607,7 +608,7 @@ void CpuScaleKernel::run_op(ITensorPack &tensors, const Window &window, const Th } else { - const auto *uk = get_implementation(ScaleSelectorData{ src->info()->data_type() }); + const auto *uk = get_implementation(ScaleSelectorData{ src->info()->data_type(), CPUInfo::get() }); uk->ukernel(src, dst, offsets, dx, dy, _policy, _border_mode, _constant_border_value, _sampling_offset, _align_corners, window); } } |