aboutsummaryrefslogtreecommitdiff
path: root/src/cpu/kernels/CpuPool3dKernel.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/cpu/kernels/CpuPool3dKernel.cpp')
-rw-r--r--src/cpu/kernels/CpuPool3dKernel.cpp17
1 files changed, 15 insertions, 2 deletions
diff --git a/src/cpu/kernels/CpuPool3dKernel.cpp b/src/cpu/kernels/CpuPool3dKernel.cpp
index 3321967d2f..1305f7c5e8 100644
--- a/src/cpu/kernels/CpuPool3dKernel.cpp
+++ b/src/cpu/kernels/CpuPool3dKernel.cpp
@@ -44,11 +44,20 @@ using namespace misc::shape_calculator;
static const std::vector<CpuPool3dKernel::Pooling3dKernel> available_kernels =
{
{
+ "neon_qu8_ndhwc_poolMxNxD",
+ [](const DataTypeISASelectorData & data) { return (data.dt == DataType::QASYMM8); },
+ REGISTER_QASYMM8_NEON(arm_compute::cpu::neon_q8_pool3d)
+ },
+ {
+ "neon_qs8_ndhwc_poolMxNxD",
+ [](const DataTypeISASelectorData & data) { return (data.dt == DataType::QASYMM8_SIGNED); },
+ REGISTER_QASYMM8_SIGNED_NEON(arm_compute::cpu::neon_q8_signed_pool3d)
+ },
+ {
"neon_fp16_ndhwc_poolMxNxD",
[](const DataTypeISASelectorData & data) { return (data.dt == DataType::F16 && data.isa.fp16); },
REGISTER_FP16_NEON(arm_compute::cpu::neon_fp16_pool3d)
},
-
{
"neon_fp32_ndhwc_poolMxNxD",
[](const DataTypeISASelectorData & data) { return (data.dt == DataType::F32); },
@@ -61,7 +70,11 @@ Status validate_arguments(const ITensorInfo *src, const ITensorInfo *dst, const
ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(src, dst);
ARM_COMPUTE_RETURN_ERROR_ON_MSG(src->data_layout() != DataLayout::NDHWC, "Only NDHWC layout supported");
ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(src);
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(src, 1, DataType::F16, DataType::F32);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(src, 1, DataType::F16, DataType::F32, DataType::QASYMM8, DataType::QASYMM8_SIGNED);
+
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG((!is_data_type_float(src->data_type())) && (!pool_info.exclude_padding
+ && (pool_info.pool_type == PoolingType::AVG)),
+ "Exclude padding is unsupported for non-float types for Avg op");
const auto data_layout = src->data_layout();
const int idx_width = get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH);