diff options
Diffstat (limited to 'src/cpu/kernels/CpuPool3dKernel.cpp')
-rw-r--r-- | src/cpu/kernels/CpuPool3dKernel.cpp | 17 |
1 files changed, 15 insertions, 2 deletions
diff --git a/src/cpu/kernels/CpuPool3dKernel.cpp b/src/cpu/kernels/CpuPool3dKernel.cpp index 3321967d2f..1305f7c5e8 100644 --- a/src/cpu/kernels/CpuPool3dKernel.cpp +++ b/src/cpu/kernels/CpuPool3dKernel.cpp @@ -44,11 +44,20 @@ using namespace misc::shape_calculator; static const std::vector<CpuPool3dKernel::Pooling3dKernel> available_kernels = { { + "neon_qu8_ndhwc_poolMxNxD", + [](const DataTypeISASelectorData & data) { return (data.dt == DataType::QASYMM8); }, + REGISTER_QASYMM8_NEON(arm_compute::cpu::neon_q8_pool3d) + }, + { + "neon_qs8_ndhwc_poolMxNxD", + [](const DataTypeISASelectorData & data) { return (data.dt == DataType::QASYMM8_SIGNED); }, + REGISTER_QASYMM8_SIGNED_NEON(arm_compute::cpu::neon_q8_signed_pool3d) + }, + { "neon_fp16_ndhwc_poolMxNxD", [](const DataTypeISASelectorData & data) { return (data.dt == DataType::F16 && data.isa.fp16); }, REGISTER_FP16_NEON(arm_compute::cpu::neon_fp16_pool3d) }, - { "neon_fp32_ndhwc_poolMxNxD", [](const DataTypeISASelectorData & data) { return (data.dt == DataType::F32); }, @@ -61,7 +70,11 @@ Status validate_arguments(const ITensorInfo *src, const ITensorInfo *dst, const ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(src, dst); ARM_COMPUTE_RETURN_ERROR_ON_MSG(src->data_layout() != DataLayout::NDHWC, "Only NDHWC layout supported"); ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(src); - ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(src, 1, DataType::F16, DataType::F32); + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(src, 1, DataType::F16, DataType::F32, DataType::QASYMM8, DataType::QASYMM8_SIGNED); + + ARM_COMPUTE_RETURN_ERROR_ON_MSG((!is_data_type_float(src->data_type())) && (!pool_info.exclude_padding + && (pool_info.pool_type == PoolingType::AVG)), + "Exclude padding is unsupported for non-float types for Avg op"); const auto data_layout = src->data_layout(); const int idx_width = get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH); |