aboutsummaryrefslogtreecommitdiff
path: root/src/cpu/kernels/CpuDirectConv3dKernel.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/cpu/kernels/CpuDirectConv3dKernel.cpp')
-rw-r--r--src/cpu/kernels/CpuDirectConv3dKernel.cpp28
1 files changed, 22 insertions, 6 deletions
diff --git a/src/cpu/kernels/CpuDirectConv3dKernel.cpp b/src/cpu/kernels/CpuDirectConv3dKernel.cpp
index 595b5f1330..4f47787c93 100644
--- a/src/cpu/kernels/CpuDirectConv3dKernel.cpp
+++ b/src/cpu/kernels/CpuDirectConv3dKernel.cpp
@@ -76,6 +76,16 @@ static const DirectConv3dKernel available_kernels[] =
"neon_fp32_directconv3d",
[](const DirectConv3dSelectorData & data) { return data.dt == DataType::F32; },
REGISTER_FP32_NEON(arm_compute::cpu::directconv3d_float_neon_ndhwc<float>)
+ },
+ {
+ "neon_qasymm8_directconv3d",
+ [](const DirectConv3dSelectorData & data) { return data.dt == DataType::QASYMM8; },
+ REGISTER_QASYMM8_NEON(arm_compute::cpu::directconv3d_quantized_neon_ndhwc<uint8_t>)
+ },
+ {
+ "neon_qasymm8_signed_directconv3d",
+ [](const DirectConv3dSelectorData & data) { return data.dt == DataType::QASYMM8_SIGNED; },
+ REGISTER_QASYMM8_SIGNED_NEON(arm_compute::cpu::directconv3d_quantized_neon_ndhwc<int8_t>)
}
};
@@ -105,7 +115,7 @@ Status validate_arguments(const ITensorInfo *src0, const ITensorInfo *src1, cons
ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(src0, src1, dst);
ARM_COMPUTE_RETURN_ERROR_ON(src0->data_layout() != DataLayout::NDHWC);
ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(src0);
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(src0, 1, DataType::F16, DataType::F32);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(src0, 1, DataType::F16, DataType::F32, DataType::QASYMM8, DataType::QASYMM8_SIGNED);
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(src0, src1);
const DataLayout data_layout = src0->data_layout();
@@ -117,10 +127,16 @@ Status validate_arguments(const ITensorInfo *src0, const ITensorInfo *src1, cons
if(src2 != nullptr)
{
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(src1, src2);
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(src2->dimension(0) != src1->dimension(0),
- "biases size and number of output feature maps should match");
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(src2->num_dimensions() > 1, "biases should be one dimensional");
+ if(is_data_type_quantized(src0->data_type()))
+ {
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(src2, 1, DataType::S32);
+ }
+ else
+ {
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(src1, src2);
+ }
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(src2->dimension(0) != src1->dimension(0), "Biases size and number of dst feature maps should match");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(src2->num_dimensions() > 1, "Biases should be one dimensional");
}
// Checks performed when output is configured
@@ -136,7 +152,7 @@ Status validate_arguments(const ITensorInfo *src0, const ITensorInfo *src1, cons
return Status{};
}
-}
+} // namespace
void CpuDirectConv3dKernel::configure(const ITensorInfo *src0, const ITensorInfo *src1, const ITensorInfo *src2, ITensorInfo *dst, const Conv3dInfo &conv_info)
{