diff options
Diffstat (limited to 'src/cpu/kernels/CpuDirectConv3dKernel.cpp')
-rw-r--r-- | src/cpu/kernels/CpuDirectConv3dKernel.cpp | 28 |
1 files changed, 22 insertions, 6 deletions
diff --git a/src/cpu/kernels/CpuDirectConv3dKernel.cpp b/src/cpu/kernels/CpuDirectConv3dKernel.cpp index 595b5f1330..4f47787c93 100644 --- a/src/cpu/kernels/CpuDirectConv3dKernel.cpp +++ b/src/cpu/kernels/CpuDirectConv3dKernel.cpp @@ -76,6 +76,16 @@ static const DirectConv3dKernel available_kernels[] = "neon_fp32_directconv3d", [](const DirectConv3dSelectorData & data) { return data.dt == DataType::F32; }, REGISTER_FP32_NEON(arm_compute::cpu::directconv3d_float_neon_ndhwc<float>) + }, + { + "neon_qasymm8_directconv3d", + [](const DirectConv3dSelectorData & data) { return data.dt == DataType::QASYMM8; }, + REGISTER_QASYMM8_NEON(arm_compute::cpu::directconv3d_quantized_neon_ndhwc<uint8_t>) + }, + { + "neon_qasymm8_signed_directconv3d", + [](const DirectConv3dSelectorData & data) { return data.dt == DataType::QASYMM8_SIGNED; }, + REGISTER_QASYMM8_SIGNED_NEON(arm_compute::cpu::directconv3d_quantized_neon_ndhwc<int8_t>) } }; @@ -105,7 +115,7 @@ Status validate_arguments(const ITensorInfo *src0, const ITensorInfo *src1, cons ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(src0, src1, dst); ARM_COMPUTE_RETURN_ERROR_ON(src0->data_layout() != DataLayout::NDHWC); ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(src0); - ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(src0, 1, DataType::F16, DataType::F32); + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(src0, 1, DataType::F16, DataType::F32, DataType::QASYMM8, DataType::QASYMM8_SIGNED); ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(src0, src1); const DataLayout data_layout = src0->data_layout(); @@ -117,10 +127,16 @@ Status validate_arguments(const ITensorInfo *src0, const ITensorInfo *src1, cons if(src2 != nullptr) { - ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(src1, src2); - ARM_COMPUTE_RETURN_ERROR_ON_MSG(src2->dimension(0) != src1->dimension(0), - "biases size and number of output feature maps should match"); - ARM_COMPUTE_RETURN_ERROR_ON_MSG(src2->num_dimensions() > 1, "biases should be one dimensional"); + if(is_data_type_quantized(src0->data_type())) + { + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(src2, 1, DataType::S32); + } + else + { + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(src1, src2); + } + ARM_COMPUTE_RETURN_ERROR_ON_MSG(src2->dimension(0) != src1->dimension(0), "Biases size and number of dst feature maps should match"); + ARM_COMPUTE_RETURN_ERROR_ON_MSG(src2->num_dimensions() > 1, "Biases should be one dimensional"); } // Checks performed when output is configured @@ -136,7 +152,7 @@ Status validate_arguments(const ITensorInfo *src0, const ITensorInfo *src1, cons return Status{}; } -} +} // namespace void CpuDirectConv3dKernel::configure(const ITensorInfo *src0, const ITensorInfo *src1, const ITensorInfo *src2, ITensorInfo *dst, const Conv3dInfo &conv_info) { |