about summary refs log tree commit diff
path: root/src/cpu/kernels/CpuScaleKernel.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/cpu/kernels/CpuScaleKernel.cpp')
-rw-r--r-- src/cpu/kernels/CpuScaleKernel.cpp 34
1 files changed, 20 insertions, 14 deletions
diff --git a/src/cpu/kernels/CpuScaleKernel.cpp b/src/cpu/kernels/CpuScaleKernel.cpp
index e230dfa938..c9e858fc02 100644
--- a/src/cpu/kernels/CpuScaleKernel.cpp
+++ b/src/cpu/kernels/CpuScaleKernel.cpp
@@ -52,62 +52,68 @@ static const std::vector<CpuScaleKernel::ScaleKernel> available_kernels =
{
{
"sve_fp16_scale",
- [](const DataTypeISASelectorData & data) { return data.dt == DataType::F16 && data.isa.sve && data.isa.fp16; },
+ [](const ScaleKernelDataTypeISASelectorData & data)
+ {
+ return data.dt == DataType::F16 && data.isa.sve && data.isa.fp16 && data.interpolation_policy != InterpolationPolicy::BILINEAR;
+ },
REGISTER_FP16_SVE(arm_compute::cpu::fp16_sve_scale)
},
{
"sve_fp32_scale",
- [](const DataTypeISASelectorData & data) { return data.dt == DataType::F32 && data.isa.sve; },
+ [](const ScaleKernelDataTypeISASelectorData & data)
+ {
+ return data.dt == DataType::F32 && data.isa.sve && data.interpolation_policy != InterpolationPolicy::BILINEAR;
+ },
REGISTER_FP32_SVE(arm_compute::cpu::fp32_sve_scale)
},
{
"sve_qu8_scale",
- [](const DataTypeISASelectorData & data) { return data.dt == DataType::QASYMM8 && data.isa.sve; },
+ [](const ScaleKernelDataTypeISASelectorData & data) { return data.dt == DataType::QASYMM8 && data.isa.sve; },
REGISTER_QASYMM8_SVE(arm_compute::cpu::qasymm8_sve_scale)
},
{
"sve_qs8_scale",
- [](const DataTypeISASelectorData & data) { return data.dt == DataType::QASYMM8_SIGNED && data.isa.sve; },
+ [](const ScaleKernelDataTypeISASelectorData & data) { return data.dt == DataType::QASYMM8_SIGNED && data.isa.sve; },
REGISTER_QASYMM8_SIGNED_SVE(arm_compute::cpu::qasymm8_signed_sve_scale)
},
{
"sve_u8_scale",
- [](const DataTypeISASelectorData & data) { return data.dt == DataType::U8 && data.isa.sve; },
+ [](const ScaleKernelDataTypeISASelectorData & data) { return data.dt == DataType::U8 && data.isa.sve; },
REGISTER_INTEGER_SVE(arm_compute::cpu::u8_sve_scale)
},
{
"sve_s16_scale",
- [](const DataTypeISASelectorData & data) { return data.dt == DataType::S16 && data.isa.sve; },
+ [](const ScaleKernelDataTypeISASelectorData & data) { return data.dt == DataType::S16 && data.isa.sve; },
REGISTER_INTEGER_SVE(arm_compute::cpu::s16_sve_scale)
},
{
"neon_fp16_scale",
- [](const DataTypeISASelectorData & data) { return data.dt == DataType::F16 && data.isa.fp16; },
+ [](const ScaleKernelDataTypeISASelectorData & data) { return data.dt == DataType::F16 && data.isa.fp16; },
REGISTER_FP16_NEON(arm_compute::cpu::common_neon_scale<float16_t>)
},
{
"neon_fp32_scale",
- [](const DataTypeISASelectorData & data) { return data.dt == DataType::F32; },
+ [](const ScaleKernelDataTypeISASelectorData & data) { return data.dt == DataType::F32; },
REGISTER_FP32_NEON(arm_compute::cpu::common_neon_scale<float>)
},
{
"neon_qu8_scale",
- [](const DataTypeISASelectorData & data) { return data.dt == DataType::QASYMM8; },
+ [](const ScaleKernelDataTypeISASelectorData & data) { return data.dt == DataType::QASYMM8; },
REGISTER_QASYMM8_NEON(arm_compute::cpu::qasymm8_neon_scale)
},
{
"neon_qs8_scale",
- [](const DataTypeISASelectorData & data) { return data.dt == DataType::QASYMM8_SIGNED; },
+ [](const ScaleKernelDataTypeISASelectorData & data) { return data.dt == DataType::QASYMM8_SIGNED; },
REGISTER_QASYMM8_SIGNED_NEON(arm_compute::cpu::qasymm8_signed_neon_scale)
},
{
"neon_u8_scale",
- [](const DataTypeISASelectorData & data) { return data.dt == DataType::U8; },
+ [](const ScaleKernelDataTypeISASelectorData & data) { return data.dt == DataType::U8; },
REGISTER_INTEGER_NEON(arm_compute::cpu::u8_neon_scale)
},
{
"neon_s16_scale",
- [](const DataTypeISASelectorData & data) { return data.dt == DataType::S16; },
+ [](const ScaleKernelDataTypeISASelectorData & data) { return data.dt == DataType::S16; },
REGISTER_INTEGER_NEON(arm_compute::cpu::s16_neon_scale)
},
};
@@ -115,7 +121,7 @@ static const std::vector<CpuScaleKernel::ScaleKernel> available_kernels =
Status validate_arguments(const ITensorInfo *src, const ITensorInfo *dx, const ITensorInfo *dy,
const ITensorInfo *offsets, ITensorInfo *dst, const ScaleKernelInfo &info)
{
- const auto *uk = CpuScaleKernel::get_implementation(DataTypeISASelectorData{ src->data_type(), CPUInfo::get().get_isa() });
+ const auto *uk = CpuScaleKernel::get_implementation(ScaleKernelDataTypeISASelectorData{ src->data_type(), CPUInfo::get().get_isa(), info.interpolation_policy });
ARM_COMPUTE_RETURN_ERROR_ON(uk == nullptr || uk->ukernel == nullptr);
@@ -174,7 +180,7 @@ void CpuScaleKernel::configure(const ITensorInfo *src, const ITensorInfo *dx, co
dst,
info));
- const auto *uk = CpuScaleKernel::get_implementation(DataTypeISASelectorData{ src->data_type(), CPUInfo::get().get_isa() });
+ const auto *uk = CpuScaleKernel::get_implementation(ScaleKernelDataTypeISASelectorData{ src->data_type(), CPUInfo::get().get_isa(), info.interpolation_policy });
ARM_COMPUTE_ERROR_ON_NULLPTR(uk);
_run_method = uk->ukernel;