aboutsummaryrefslogtreecommitdiff
path: root/src/cpu/kernels/CpuDirectConv3dKernel.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/cpu/kernels/CpuDirectConv3dKernel.cpp')
-rw-r--r--src/cpu/kernels/CpuDirectConv3dKernel.cpp87
1 files changed, 47 insertions, 40 deletions
diff --git a/src/cpu/kernels/CpuDirectConv3dKernel.cpp b/src/cpu/kernels/CpuDirectConv3dKernel.cpp
index 22c60cd994..b5b2aed1ba 100644
--- a/src/cpu/kernels/CpuDirectConv3dKernel.cpp
+++ b/src/cpu/kernels/CpuDirectConv3dKernel.cpp
@@ -29,12 +29,13 @@
#include "arm_compute/core/ITensor.h"
#include "arm_compute/core/Types.h"
#include "arm_compute/core/Utils.h"
-#include "arm_compute/core/Validate.h"
#include "arm_compute/core/utils/misc/ShapeCalculator.h"
-#include "src/core/CPP/Validate.h"
-#include "src/core/NEON/wrapper/wrapper.h"
+#include "arm_compute/core/Validate.h"
+
#include "src/core/common/Registrars.h"
+#include "src/core/CPP/Validate.h"
#include "src/core/helpers/AutoConfiguration.h"
+#include "src/core/NEON/wrapper/wrapper.h"
#include "src/cpu/kernels/conv3d/neon/list.h"
#include <algorithm>
@@ -49,43 +50,37 @@ namespace kernels
{
namespace
{
-static const std::vector<CpuDirectConv3dKernel::DirectConv3dKernel> available_kernels =
-{
+static const std::vector<CpuDirectConv3dKernel::DirectConv3dKernel> available_kernels = {
#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
- {
- "neon_fp16_directconv3d",
- [](const DataTypeISASelectorData & data) { return data.dt == DataType::F16 && data.isa.fp16; },
- REGISTER_FP16_NEON(arm_compute::cpu::directconv3d_float_neon_ndhwc<float16_t>)
- },
+ {"neon_fp16_directconv3d",
+ [](const DataTypeISASelectorData &data) { return data.dt == DataType::F16 && data.isa.fp16; },
+ REGISTER_FP16_NEON(arm_compute::cpu::directconv3d_float_neon_ndhwc<float16_t>)},
#endif /* !defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) */
- {
- "neon_fp32_directconv3d",
- [](const DataTypeISASelectorData & data) { return data.dt == DataType::F32; },
- REGISTER_FP32_NEON(arm_compute::cpu::directconv3d_float_neon_ndhwc<float>)
- },
- {
- "neon_qasymm8_directconv3d",
- [](const DataTypeISASelectorData & data) { return data.dt == DataType::QASYMM8; },
- REGISTER_QASYMM8_NEON(arm_compute::cpu::directconv3d_quantized_neon_ndhwc<uint8_t>)
- },
- {
- "neon_qasymm8_signed_directconv3d",
- [](const DataTypeISASelectorData & data) { return data.dt == DataType::QASYMM8_SIGNED; },
- REGISTER_QASYMM8_SIGNED_NEON(arm_compute::cpu::directconv3d_quantized_neon_ndhwc<int8_t>)
- }
-};
-
-Status validate_arguments(const ITensorInfo *src0, const ITensorInfo *src1, const ITensorInfo *src2, const ITensorInfo *dst, const Conv3dInfo &conv_info)
+ {"neon_fp32_directconv3d", [](const DataTypeISASelectorData &data) { return data.dt == DataType::F32; },
+ REGISTER_FP32_NEON(arm_compute::cpu::directconv3d_float_neon_ndhwc<float>)},
+ {"neon_qasymm8_directconv3d", [](const DataTypeISASelectorData &data) { return data.dt == DataType::QASYMM8; },
+ REGISTER_QASYMM8_NEON(arm_compute::cpu::directconv3d_quantized_neon_ndhwc<uint8_t>)},
+ {"neon_qasymm8_signed_directconv3d",
+ [](const DataTypeISASelectorData &data) { return data.dt == DataType::QASYMM8_SIGNED; },
+ REGISTER_QASYMM8_SIGNED_NEON(arm_compute::cpu::directconv3d_quantized_neon_ndhwc<int8_t>)}};
+
+Status validate_arguments(const ITensorInfo *src0,
+ const ITensorInfo *src1,
+ const ITensorInfo *src2,
+ const ITensorInfo *dst,
+ const Conv3dInfo &conv_info)
{
ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(src0, src1, dst);
ARM_COMPUTE_RETURN_ERROR_ON(src0->data_layout() != DataLayout::NDHWC);
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_LAYOUT(src0, src1, dst);
ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(src0);
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(src0, 1, DataType::F16, DataType::F32, DataType::QASYMM8, DataType::QASYMM8_SIGNED);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(src0, 1, DataType::F16, DataType::F32, DataType::QASYMM8,
+ DataType::QASYMM8_SIGNED);
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(src0, src1);
ARM_COMPUTE_RETURN_ERROR_ON(conv_info.dilation != Size3D(1U, 1U, 1U));
- const auto *uk = CpuDirectConv3dKernel::get_implementation(DataTypeISASelectorData{ src0->data_type(), CPUInfo::get().get_isa() });
+ const auto *uk =
+ CpuDirectConv3dKernel::get_implementation(DataTypeISASelectorData{src0->data_type(), CPUInfo::get().get_isa()});
ARM_COMPUTE_RETURN_ERROR_ON(uk == nullptr || uk->ukernel == nullptr);
@@ -96,9 +91,9 @@ Status validate_arguments(const ITensorInfo *src0, const ITensorInfo *src1, cons
ARM_COMPUTE_RETURN_ERROR_ON(src1->num_dimensions() > 5);
ARM_COMPUTE_RETURN_ERROR_ON(src1->dimension(1) != src0->dimension(channel_idx));
- if(src2 != nullptr)
+ if (src2 != nullptr)
{
- if(is_data_type_quantized(src0->data_type()))
+ if (is_data_type_quantized(src0->data_type()))
{
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(src2, 1, DataType::S32);
}
@@ -106,14 +101,16 @@ Status validate_arguments(const ITensorInfo *src0, const ITensorInfo *src1, cons
{
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(src1, src2);
}
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(src2->dimension(0) != src1->dimension(0), "Biases size and number of dst feature maps should match");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(src2->dimension(0) != src1->dimension(0),
+ "Biases size and number of dst feature maps should match");
ARM_COMPUTE_RETURN_ERROR_ON_MSG(src2->num_dimensions() > 1, "Biases should be one dimensional");
}
// Checks performed when output is configured
- if(dst->total_size() != 0)
+ if (dst->total_size() != 0)
{
- TensorShape output_shape = misc::shape_calculator::compute_conv3d_shape(src0->tensor_shape(), src1->tensor_shape(), conv_info);
+ TensorShape output_shape =
+ misc::shape_calculator::compute_conv3d_shape(src0->tensor_shape(), src1->tensor_shape(), conv_info);
DataType data_type = src0->data_type();
@@ -125,12 +122,17 @@ Status validate_arguments(const ITensorInfo *src0, const ITensorInfo *src1, cons
}
} // namespace
-void CpuDirectConv3dKernel::configure(const ITensorInfo *src0, const ITensorInfo *src1, const ITensorInfo *src2, ITensorInfo *dst, const Conv3dInfo &conv_info)
+void CpuDirectConv3dKernel::configure(const ITensorInfo *src0,
+ const ITensorInfo *src1,
+ const ITensorInfo *src2,
+ ITensorInfo *dst,
+ const Conv3dInfo &conv_info)
{
ARM_COMPUTE_UNUSED(src2);
ARM_COMPUTE_ERROR_ON_NULLPTR(src0, src1, dst);
- const auto *uk = CpuDirectConv3dKernel::get_implementation(DataTypeISASelectorData{ src0->data_type(), CPUInfo::get().get_isa() });
+ const auto *uk =
+ CpuDirectConv3dKernel::get_implementation(DataTypeISASelectorData{src0->data_type(), CPUInfo::get().get_isa()});
ARM_COMPUTE_ERROR_ON_NULLPTR(uk);
@@ -139,7 +141,8 @@ void CpuDirectConv3dKernel::configure(const ITensorInfo *src0, const ITensorInfo
_name = std::string("CpuDirectConv3dKernel").append("/").append(uk->name);
// Get convolved dimensions
- TensorShape output_shape = misc::shape_calculator::compute_conv3d_shape(src0->tensor_shape(), src1->tensor_shape(), conv_info);
+ TensorShape output_shape =
+ misc::shape_calculator::compute_conv3d_shape(src0->tensor_shape(), src1->tensor_shape(), conv_info);
DataType data_type = src0->data_type();
@@ -154,7 +157,11 @@ void CpuDirectConv3dKernel::configure(const ITensorInfo *src0, const ITensorInfo
ICpuKernel::configure(win);
}
-Status CpuDirectConv3dKernel::validate(const ITensorInfo *src0, const ITensorInfo *src1, const ITensorInfo *src2, const ITensorInfo *dst, const Conv3dInfo &conv_info)
+Status CpuDirectConv3dKernel::validate(const ITensorInfo *src0,
+ const ITensorInfo *src1,
+ const ITensorInfo *src2,
+ const ITensorInfo *dst,
+ const Conv3dInfo &conv_info)
{
ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(src0, src1, src2, dst, conv_info));
@@ -188,4 +195,4 @@ const std::vector<CpuDirectConv3dKernel::DirectConv3dKernel> &CpuDirectConv3dKer
} // namespace kernels
} // namespace cpu
-} // namespace arm_compute \ No newline at end of file
+} // namespace arm_compute