diff options
Diffstat (limited to 'src/cpu/kernels/CpuScaleKernel.cpp')
-rw-r--r-- | src/cpu/kernels/CpuScaleKernel.cpp | 33 |
1 files changed, 24 insertions, 9 deletions
diff --git a/src/cpu/kernels/CpuScaleKernel.cpp b/src/cpu/kernels/CpuScaleKernel.cpp index e7386a385a..b8bb5ad18a 100644 --- a/src/cpu/kernels/CpuScaleKernel.cpp +++ b/src/cpu/kernels/CpuScaleKernel.cpp @@ -25,14 +25,9 @@ #include "arm_compute/core/Helpers.h" #include "arm_compute/core/Window.h" -#include "arm_compute/core/utils/misc/Utility.h" -#include "src/core/CPP/Validate.h" -#include "src/core/NEON/wrapper/wrapper.h" #include "src/core/common/Registrars.h" -#include "src/core/helpers/AutoConfiguration.h" #include "src/core/helpers/ScaleHelpers.h" #include "src/core/helpers/WindowHelpers.h" -#include "src/core/utils/ScaleUtils.h" #include "src/cpu/kernels/scale/neon/list.h" #include "src/cpu/kernels/scale/sve/list.h" #include "support/Rounding.h" @@ -68,22 +63,34 @@ static const std::vector<CpuScaleKernel::ScaleKernel> available_kernels = }, { "sve_qu8_scale", - [](const ScaleKernelDataTypeISASelectorData & data) { return data.dt == DataType::QASYMM8 && data.isa.sve; }, + [](const ScaleKernelDataTypeISASelectorData & data) + { + return data.dt == DataType::QASYMM8 && data.isa.sve && data.interpolation_policy != InterpolationPolicy::BILINEAR; + }, REGISTER_QASYMM8_SVE(arm_compute::cpu::qasymm8_sve_scale) }, { "sve_qs8_scale", - [](const ScaleKernelDataTypeISASelectorData & data) { return data.dt == DataType::QASYMM8_SIGNED && data.isa.sve; }, + [](const ScaleKernelDataTypeISASelectorData & data) + { + return data.dt == DataType::QASYMM8_SIGNED && data.isa.sve && data.interpolation_policy != InterpolationPolicy::BILINEAR; + }, REGISTER_QASYMM8_SIGNED_SVE(arm_compute::cpu::qasymm8_signed_sve_scale) }, { "sve_u8_scale", - [](const ScaleKernelDataTypeISASelectorData & data) { return data.dt == DataType::U8 && data.isa.sve; }, + [](const ScaleKernelDataTypeISASelectorData & data) + { + return data.dt == DataType::U8 && data.isa.sve && data.interpolation_policy != InterpolationPolicy::BILINEAR; + }, REGISTER_INTEGER_SVE(arm_compute::cpu::u8_sve_scale) }, { "sve_s16_scale", - [](const ScaleKernelDataTypeISASelectorData & data) { return data.dt == DataType::S16 && data.isa.sve; }, + [](const ScaleKernelDataTypeISASelectorData & data) + { + return data.dt == DataType::S16 && data.isa.sve && data.interpolation_policy != InterpolationPolicy::BILINEAR; + }, REGISTER_INTEGER_SVE(arm_compute::cpu::s16_sve_scale) }, { @@ -112,6 +119,11 @@ static const std::vector<CpuScaleKernel::ScaleKernel> available_kernels = REGISTER_INTEGER_NEON(arm_compute::cpu::u8_neon_scale) }, { + "neon_s8_scale", + [](const ScaleKernelDataTypeISASelectorData & data) { return data.dt == DataType::S8; }, + REGISTER_INTEGER_NEON(arm_compute::cpu::s8_neon_scale) + }, + { "neon_s16_scale", [](const ScaleKernelDataTypeISASelectorData & data) { return data.dt == DataType::S16; }, REGISTER_INTEGER_NEON(arm_compute::cpu::s16_neon_scale) @@ -140,6 +152,9 @@ Status validate_arguments(const ITensorInfo *src, const ITensorInfo *dx, const I ARM_COMPUTE_RETURN_ERROR_ON(output_width == 0); ARM_COMPUTE_RETURN_ERROR_ON(output_height == 0); + ARM_COMPUTE_RETURN_ERROR_ON((src->data_type() == DataType::S8) && (data_layout != DataLayout::NHWC || info.interpolation_policy != InterpolationPolicy::BILINEAR + || info.border_mode != BorderMode::REPLICATE)); + if(info.interpolation_policy == InterpolationPolicy::NEAREST_NEIGHBOR && offsets != nullptr) { ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(offsets, 1, DataType::S32); |