diff options
Diffstat (limited to 'src/cpu/kernels/CpuScaleKernel.cpp')
-rw-r--r-- | src/cpu/kernels/CpuScaleKernel.cpp | 34 |
1 files changed, 20 insertions, 14 deletions
diff --git a/src/cpu/kernels/CpuScaleKernel.cpp b/src/cpu/kernels/CpuScaleKernel.cpp index e230dfa938..c9e858fc02 100644 --- a/src/cpu/kernels/CpuScaleKernel.cpp +++ b/src/cpu/kernels/CpuScaleKernel.cpp @@ -52,62 +52,68 @@ static const std::vector<CpuScaleKernel::ScaleKernel> available_kernels = { { "sve_fp16_scale", - [](const DataTypeISASelectorData & data) { return data.dt == DataType::F16 && data.isa.sve && data.isa.fp16; }, + [](const ScaleKernelDataTypeISASelectorData & data) + { + return data.dt == DataType::F16 && data.isa.sve && data.isa.fp16 && data.interpolation_policy != InterpolationPolicy::BILINEAR; + }, REGISTER_FP16_SVE(arm_compute::cpu::fp16_sve_scale) }, { "sve_fp32_scale", - [](const DataTypeISASelectorData & data) { return data.dt == DataType::F32 && data.isa.sve; }, + [](const ScaleKernelDataTypeISASelectorData & data) + { + return data.dt == DataType::F32 && data.isa.sve && data.interpolation_policy != InterpolationPolicy::BILINEAR; + }, REGISTER_FP32_SVE(arm_compute::cpu::fp32_sve_scale) }, { "sve_qu8_scale", - [](const DataTypeISASelectorData & data) { return data.dt == DataType::QASYMM8 && data.isa.sve; }, + [](const ScaleKernelDataTypeISASelectorData & data) { return data.dt == DataType::QASYMM8 && data.isa.sve; }, REGISTER_QASYMM8_SVE(arm_compute::cpu::qasymm8_sve_scale) }, { "sve_qs8_scale", - [](const DataTypeISASelectorData & data) { return data.dt == DataType::QASYMM8_SIGNED && data.isa.sve; }, + [](const ScaleKernelDataTypeISASelectorData & data) { return data.dt == DataType::QASYMM8_SIGNED && data.isa.sve; }, REGISTER_QASYMM8_SIGNED_SVE(arm_compute::cpu::qasymm8_signed_sve_scale) }, { "sve_u8_scale", - [](const DataTypeISASelectorData & data) { return data.dt == DataType::U8 && data.isa.sve; }, + [](const ScaleKernelDataTypeISASelectorData & data) { return data.dt == DataType::U8 && data.isa.sve; }, REGISTER_INTEGER_SVE(arm_compute::cpu::u8_sve_scale) }, { "sve_s16_scale", - [](const DataTypeISASelectorData & data) { return data.dt == DataType::S16 && data.isa.sve; }, + [](const ScaleKernelDataTypeISASelectorData & data) { return data.dt == DataType::S16 && data.isa.sve; }, REGISTER_INTEGER_SVE(arm_compute::cpu::s16_sve_scale) }, { "neon_fp16_scale", - [](const DataTypeISASelectorData & data) { return data.dt == DataType::F16 && data.isa.fp16; }, + [](const ScaleKernelDataTypeISASelectorData & data) { return data.dt == DataType::F16 && data.isa.fp16; }, REGISTER_FP16_NEON(arm_compute::cpu::common_neon_scale<float16_t>) }, { "neon_fp32_scale", - [](const DataTypeISASelectorData & data) { return data.dt == DataType::F32; }, + [](const ScaleKernelDataTypeISASelectorData & data) { return data.dt == DataType::F32; }, REGISTER_FP32_NEON(arm_compute::cpu::common_neon_scale<float>) }, { "neon_qu8_scale", - [](const DataTypeISASelectorData & data) { return data.dt == DataType::QASYMM8; }, + [](const ScaleKernelDataTypeISASelectorData & data) { return data.dt == DataType::QASYMM8; }, REGISTER_QASYMM8_NEON(arm_compute::cpu::qasymm8_neon_scale) }, { "neon_qs8_scale", - [](const DataTypeISASelectorData & data) { return data.dt == DataType::QASYMM8_SIGNED; }, + [](const ScaleKernelDataTypeISASelectorData & data) { return data.dt == DataType::QASYMM8_SIGNED; }, REGISTER_QASYMM8_SIGNED_NEON(arm_compute::cpu::qasymm8_signed_neon_scale) }, { "neon_u8_scale", - [](const DataTypeISASelectorData & data) { return data.dt == DataType::U8; }, + [](const ScaleKernelDataTypeISASelectorData & data) { return data.dt == DataType::U8; }, REGISTER_INTEGER_NEON(arm_compute::cpu::u8_neon_scale) }, { "neon_s16_scale", - [](const DataTypeISASelectorData & data) { return data.dt == DataType::S16; }, + [](const ScaleKernelDataTypeISASelectorData & data) { return data.dt == DataType::S16; }, REGISTER_INTEGER_NEON(arm_compute::cpu::s16_neon_scale) }, }; @@ -115,7 +121,7 @@ static const std::vector<CpuScaleKernel::ScaleKernel> available_kernels = Status validate_arguments(const ITensorInfo *src, const ITensorInfo *dx, const ITensorInfo *dy, const ITensorInfo *offsets, ITensorInfo *dst, const ScaleKernelInfo &info) { - const auto *uk = CpuScaleKernel::get_implementation(DataTypeISASelectorData{ src->data_type(), CPUInfo::get().get_isa() }); + const auto *uk = CpuScaleKernel::get_implementation(ScaleKernelDataTypeISASelectorData{ src->data_type(), CPUInfo::get().get_isa(), info.interpolation_policy }); ARM_COMPUTE_RETURN_ERROR_ON(uk == nullptr || uk->ukernel == nullptr); @@ -174,7 +180,7 @@ void CpuScaleKernel::configure(const ITensorInfo *src, const ITensorInfo *dx, co dst, info)); - const auto *uk = CpuScaleKernel::get_implementation(DataTypeISASelectorData{ src->data_type(), CPUInfo::get().get_isa() }); + const auto *uk = CpuScaleKernel::get_implementation(ScaleKernelDataTypeISASelectorData{ src->data_type(), CPUInfo::get().get_isa(), info.interpolation_policy }); ARM_COMPUTE_ERROR_ON_NULLPTR(uk); _run_method = uk->ukernel; |