author     Sheri Zhang <sheri.zhang@arm.com>   2021-10-15 19:54:17 +0100
committer  Sheri Zhang <sheri.zhang@arm.com>   2021-10-18 17:36:47 +0000
commit     5dda2177800009b24e31550ed849b1ef3fca6167 (patch)
tree       dfce69d52db6111d6751f4ee4add6ab172a3290d
parent     c9cecc0e565e7b4978cecc92e03e6c93bb8d0cb9 (diff)
download   ComputeLibrary-5dda2177800009b24e31550ed849b1ef3fca6167.tar.gz
DirectConv3d support refine
- Decouple data support of CpuDirectConv3dKernel
- Update documentation for Conv3d

Signed-off-by: Sheri Zhang <sheri.zhang@arm.com>
Change-Id: I1d94aa28f821f45a1a3d39cc3335c8faeee89f0d
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/6453
Reviewed-by: Giorgio Arena <giorgio.arena@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
-rw-r--r--  arm_compute/runtime/NEON/functions/NEConv3D.h        |   1
-rw-r--r--  docs/user_guide/data_layout.dox                       |   3
-rw-r--r--  docs/user_guide/operator_list.dox                     |   3
-rw-r--r--  docs/user_guide/release_version_and_change_log.dox    |   8
-rw-r--r--  src/core/NEON/NEMath.h                                |  15
-rw-r--r--  src/core/NEON/NEMath.inl                              |  27
-rw-r--r--  src/cpu/kernels/CpuDirectConv2dKernel.cpp             |  11
-rw-r--r--  src/cpu/kernels/CpuDirectConv3dKernel.cpp             | 303
-rw-r--r--  src/cpu/kernels/CpuDirectConv3dKernel.h               |  24
-rw-r--r--  src/cpu/kernels/conv3d/neon/list.h                    | 176
-rw-r--r--  src/cpu/operators/CpuDirectConv3d.cpp                 |  18
-rw-r--r--  src/cpu/operators/CpuDirectConv3d.h                   |  22
-rw-r--r--  src/gpu/cl/kernels/ClDirectConv3dKernel.cpp           |  66
-rw-r--r--  src/gpu/cl/kernels/ClDirectConv3dKernel.h             |  10
-rw-r--r--  src/gpu/cl/operators/ClDirectConv3d.cpp               |  10
-rw-r--r--  src/gpu/cl/operators/ClDirectConv3d.h                 |  10
-rw-r--r--  src/runtime/NEON/functions/NEConv3D.cpp               |   5
17 files changed, 400 insertions, 312 deletions
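
In short, the decoupling moves the templated NDHWC convolution body out of CpuDirectConv3dKernel and into src/cpu/kernels/conv3d/neon/list.h, and the kernel now resolves a micro-kernel function pointer from a small selector table at configure time. The following is a minimal, simplified sketch of that selector-table pattern; the struct and function names here are illustrative stand-ins, not the library's real types (the real table lives in CpuDirectConv3dKernel.cpp and keys on DataType and CPUInfo):

// Simplified sketch of the selector-table dispatch introduced by this patch.
#include <string>
#include <vector>

struct SelectorData
{
    bool is_fp16;      // stand-in for DataType == F16
    bool cpu_has_fp16; // stand-in for CPUInfo::has_fp16()
};

using KernelFn = void (*)(); // real signature: (src, weights, biases, dst, conv_info, window)

struct KernelEntry
{
    std::string name;
    bool (*is_selected)(const SelectorData &);
    KernelFn    ukernel;
};

inline const KernelEntry *get_implementation(const std::vector<KernelEntry> &table, const SelectorData &data)
{
    for(const auto &entry : table)
    {
        if(entry.is_selected(data))
        {
            return &entry; // first match wins
        }
    }
    return nullptr; // validate() turns a null result into an error status
}

configure() stores the selected ukernel in _run_method and appends its name to the kernel name, which is why run_op() below no longer switches on the data type.
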
diff --git a/arm_compute/runtime/NEON/functions/NEConv3D.h b/arm_compute/runtime/NEON/functions/NEConv3D.h
index 487d357fa1..2b3a45f0af 100644
--- a/arm_compute/runtime/NEON/functions/NEConv3D.h
+++ b/arm_compute/runtime/NEON/functions/NEConv3D.h
@@ -29,7 +29,6 @@
#include "arm_compute/core/ITensorInfo.h"
#include "arm_compute/core/Types.h"
#include "arm_compute/runtime/FunctionDescriptors.h"
-#include "arm_compute/runtime/MemoryGroup.h"
#include <memory>
diff --git a/docs/user_guide/data_layout.dox b/docs/user_guide/data_layout.dox
index 97d3ea6262..ae69bbf457 100644
--- a/docs/user_guide/data_layout.dox
+++ b/docs/user_guide/data_layout.dox
@@ -34,8 +34,9 @@ the right-most letter represents the fastest changing dimension:
- NHWC: The native layout of Compute Library that delivers the best performance where channels are in the fastest changing dimension
- NCHW: Legacy layout where width is in the fastest changing dimension
+- NDHWC: New data layout for supporting 3D operators
-, where N = batch, C = channel, H = height, W = width.
+, where N = batch, C = channel, H = height, W = width, D = depth.
*/
} // namespace
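
To illustrate the new layout: shapes are given fastest-changing dimension first, so an NDHWC Conv3d input is described as (C, W, H, D, N). A hedged sketch using the public TensorInfo API (the set_data_layout call is the part specific to this change; the shape values are arbitrary):

#include "arm_compute/core/TensorInfo.h"
#include "arm_compute/core/TensorShape.h"
#include "arm_compute/core/Types.h"

using namespace arm_compute;

TensorInfo make_ndhwc_input_info()
{
    // NDHWC: fastest-changing dimension first -> C, W, H, D, N.
    TensorInfo info(TensorShape(16U /*C*/, 32U /*W*/, 32U /*H*/, 8U /*D*/, 1U /*N*/), 1, DataType::F32);
    info.set_data_layout(DataLayout::NDHWC);
    return info;
}
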
diff --git a/docs/user_guide/operator_list.dox b/docs/user_guide/operator_list.dox
index ebc970d8c1..1d06a394a9 100644
--- a/docs/user_guide/operator_list.dox
+++ b/docs/user_guide/operator_list.dox
@@ -52,9 +52,10 @@ Compute Library supports the following data layouts (fast changing dimension fro
<ul>
<li>NHWC: The native layout of Compute Library that delivers the best performance where channels are in the fastest changing dimension
<li>NCHW: Legacy layout where width is in the fastest changing dimension
+ <li>NDHWC: New data layout for supporting 3D operators
<li>All: Agnostic to any specific data layout
</ul>
-where N = batches, C = channels, H = height, W = width
+where N = batches, C = channels, H = height, W = width, D = depth
<table>
<caption id="multi_row"></caption>
diff --git a/docs/user_guide/release_version_and_change_log.dox b/docs/user_guide/release_version_and_change_log.dox
index 583cf4fb82..2470b45203 100644
--- a/docs/user_guide/release_version_and_change_log.dox
+++ b/docs/user_guide/release_version_and_change_log.dox
@@ -40,6 +40,14 @@ If there is more than one release in a month then an extra sequential number is
@section S2_2_changelog Changelog
+v21.11 Public major release
+ - Various bug fixes.
+ - Various optimizations.
+ - New OpenCL kernels / functions:
+ - @ref CLConv3D
+ - New Arm® Neon™ kernels / functions:
+ - @ref NEConv3D
+
v21.08 Public major release
- Various bug fixes.
- Various optimizations:
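
For reference, the new NEConv3D function added in this release follows the usual runtime-function pattern. A hedged usage sketch, with tensor setup elided and assuming Conv3dInfo default-constructs to unit stride, no padding and no fused activation:

#include "arm_compute/runtime/NEON/functions/NEConv3D.h"
#include "arm_compute/runtime/Tensor.h"

using namespace arm_compute;

void run_conv3d_example(Tensor &src, Tensor &weights, Tensor &biases, Tensor &dst)
{
    // src/weights/biases/dst are assumed to be initialised with NDHWC infos and allocated.
    Conv3dInfo conv_info{}; // assumed defaults: stride (1,1,1), no padding, no activation

    NEConv3D conv3d;
    conv3d.configure(&src, &weights, &biases, &dst, conv_info);
    conv3d.run();
}
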
diff --git a/src/core/NEON/NEMath.h b/src/core/NEON/NEMath.h
index 13484c9c15..8118c4701f 100644
--- a/src/core/NEON/NEMath.h
+++ b/src/core/NEON/NEMath.h
@@ -239,6 +239,14 @@ float32x4_t vsinq_f32(float32x4_t val);
*/
float32x2_t vsin_f32(float32x2_t val);
+/** Reduce a vector to be a scalar by accumulating all lanes in the vector
+ *
+ * @param[in] v Vector to be reduced.
+ *
+ * @return The accumulated value (sum of all lanes).
+ */
+float vreduce(const float32x4_t &v);
+
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
/** Calculate hyperbolic tangent.
*
@@ -319,6 +327,13 @@ float16x8_t vpowq_f16(float16x8_t val, float16x8_t n);
*/
float16x8_t vsinq_f16(float16x8_t val);
+/** Reduce a vector to be a scalar by accumulating all lanes in the vector
+ *
+ * @param[in] v Vector to be reduced.
+ *
+ * @return The accumulated value (sum of all lanes).
+ */
+float16_t vreduce(const float16x8_t &v);
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
} // namespace arm_compute
#include "src/core/NEON/NEMath.inl"
diff --git a/src/core/NEON/NEMath.inl b/src/core/NEON/NEMath.inl
index 5ac62badcc..05cf3013bc 100644
--- a/src/core/NEON/NEMath.inl
+++ b/src/core/NEON/NEMath.inl
@@ -193,7 +193,7 @@ inline float32x4_t vtanhq_f32(float32x4_t val)
static const float32x4_t CONST_THR = vdupq_n_f32(5.e-3);
static const float32x4_t CONST_1_3 = vdupq_n_f32(0.3333333f);
- float32x4_t x = vminq_f32(vmaxq_f32(val, CONST_MIN_TANH), CONST_MAX_TANH);
+ float32x4_t x = vminq_f32(vmaxq_f32(val, CONST_MIN_TANH), CONST_MAX_TANH);
// x * (1 - x^2/3) if |x| < 5.e-3 or (exp2x - 1) / (exp2x + 1) otherwise
float32x4_t exp2x = vbslq_f32(vcgtq_f32(vabsq_f32(x), CONST_THR), vexpq_f32(vmulq_f32(CONST_2, x)), vmulq_f32(x, x));
float32x4_t num = vbslq_f32(vcgtq_f32(vabsq_f32(x), CONST_THR), vsubq_f32(exp2x, CONST_1), vmulq_f32(CONST_1_3, exp2x));
@@ -418,6 +418,18 @@ inline float32x4x4_t convert_int_to_float<float32x4x4_t, int8x16_t>(const int8x1
return convert_int8x16_to_float32x4x4(in);
}
+inline float vreduce(const float32x4_t &v)
+{
+ const float32x2_t v0 = vget_high_f32(v);
+ const float32x2_t v1 = vget_low_f32(v);
+ const float32x2_t v_out = vadd_f32(v0, v1);
+
+ const float a = vget_lane_f32(v_out, 0);
+ const float b = vget_lane_f32(v_out, 1);
+
+ return a + b;
+}
+
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
/** Exponent polynomial coefficients */
/** Logarithm polynomial coefficients */
@@ -550,6 +562,19 @@ inline float16x4_t vsin_f16(float16x4_t val)
return vcvt_f16_f32(vcombine_f32(res_low, res_high));
}
+inline float16_t vreduce(const float16x8_t &v)
+{
+ const float16x4_t v0 = vget_high_f16(v);
+ const float16x4_t v1 = vget_low_f16(v);
+ const float16x4_t v_out = vadd_f16(v0, v1);
+
+ const float16_t a = vget_lane_f16(v_out, 0);
+ const float16_t b = vget_lane_f16(v_out, 1);
+ const float16_t c = vget_lane_f16(v_out, 2);
+ const float16_t d = vget_lane_f16(v_out, 3);
+
+ return a + b + c + d;
+}
#endif /* DOXYGEN_SKIP_THIS */
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
} // namespace arm_compute
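
The new vreduce helpers are plain horizontal adds over the vector lanes (the F32 variant pairs the high and low halves, then sums the two remaining lanes). A small usage sketch:

#include <arm_neon.h>
#include "src/core/NEON/NEMath.h"

float reduce_example()
{
    const float       data[4] = { 1.f, 2.f, 3.f, 4.f };
    const float32x4_t v       = vld1q_f32(data);
    return arm_compute::vreduce(v); // 1 + 2 + 3 + 4 = 10
}
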
diff --git a/src/cpu/kernels/CpuDirectConv2dKernel.cpp b/src/cpu/kernels/CpuDirectConv2dKernel.cpp
index db1b5f3c54..68de9803eb 100644
--- a/src/cpu/kernels/CpuDirectConv2dKernel.cpp
+++ b/src/cpu/kernels/CpuDirectConv2dKernel.cpp
@@ -711,17 +711,6 @@ public:
}
};
-float vreduce(const float32x4_t &v)
-{
- auto v0 = wrapper::vgethigh(v);
- auto v1 = wrapper::vgetlow(v);
- auto v_out = wrapper::vadd(v0, v1);
-
- float a = wrapper::vgetlane(v_out, 0);
- float b = wrapper::vgetlane(v_out, 1);
- return a + b;
-}
-
template <typename T1, typename T2>
inline void convolve_1x1(const Window &window, unsigned int num_elems_read_per_iteration, unsigned int num_elems_written_per_iteration,
const ITensor *src, const ITensor *weights, ITensor *dst, const PadStrideInfo &conv_info)
diff --git a/src/cpu/kernels/CpuDirectConv3dKernel.cpp b/src/cpu/kernels/CpuDirectConv3dKernel.cpp
index fecdb2bcae..595b5f1330 100644
--- a/src/cpu/kernels/CpuDirectConv3dKernel.cpp
+++ b/src/cpu/kernels/CpuDirectConv3dKernel.cpp
@@ -23,9 +23,6 @@
*/
#include "src/cpu/kernels/CpuDirectConv3dKernel.h"
-#include "src/core/NEON/kernels/detail/NEDirectConvolutionDetail.h"
-#include "src/core/NEON/wrapper/wrapper.h"
-
#include "arm_compute/core/Error.h"
#include "arm_compute/core/Helpers.h"
#include "arm_compute/core/IAccessWindow.h"
@@ -35,8 +32,10 @@
#include "arm_compute/core/Validate.h"
#include "arm_compute/core/utils/misc/ShapeCalculator.h"
#include "src/core/CPP/Validate.h"
+#include "src/core/NEON/wrapper/wrapper.h"
+#include "src/core/common/Registrars.h"
#include "src/core/helpers/AutoConfiguration.h"
-#include "src/core/helpers/WindowHelpers.h"
+#include "src/cpu/kernels/conv3d/neon/list.h"
#include <algorithm>
@@ -50,236 +49,126 @@ namespace kernels
{
namespace
{
-Status validate_arguments(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *dst, const Conv3dInfo &conv_info)
+struct DirectConv3dSelectorData
{
- ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(src, weights, dst);
- ARM_COMPUTE_RETURN_ERROR_ON(src->data_layout() != DataLayout::NDHWC);
- ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(src);
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(src, 1, DataType::F16, DataType::F32);
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(src, weights);
-
- const DataLayout data_layout = src->data_layout();
- const int channel_idx = get_data_layout_dimension_index(data_layout, DataLayoutDimension::CHANNEL);
-
- // Weight layout is D, H, W, Cin, Cout
- ARM_COMPUTE_RETURN_ERROR_ON(weights->num_dimensions() > 5);
- ARM_COMPUTE_RETURN_ERROR_ON(weights->dimension(1) != src->dimension(channel_idx));
+ DataType dt;
+ const CPUInfo &ci;
+};
+using DirectConv3dSelectorPtr = std::add_pointer<bool(const DirectConv3dSelectorData &data)>::type;
+using DirectConv3dKernelPtr = std::add_pointer<void(const ITensor *, const ITensor *, const ITensor *, ITensor *, const Conv3dInfo &, const Window &)>::type;
+struct DirectConv3dKernel
+{
+ const char *name;
+ const DirectConv3dSelectorPtr is_selected;
+ DirectConv3dKernelPtr ukernel;
+};
- if(biases != nullptr)
+static const DirectConv3dKernel available_kernels[] =
+{
+#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
{
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(weights, biases);
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(biases->dimension(0) != weights->dimension(0),
- "biases size and number of output feature maps should match");
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(biases->num_dimensions() > 1, "biases should be one dimensional");
- }
-
- // Checks performed when output is configured
- if(dst->total_size() != 0)
+ "neon_fp16_directconv3d",
+ [](const DirectConv3dSelectorData & data) { return data.dt == DataType::F16 && data.ci.has_fp16(); },
+ REGISTER_FP16_NEON(arm_compute::cpu::directconv3d_float_neon_ndhwc<float16_t>)
+ },
+#endif /* !defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) */
{
- TensorShape output_shape = misc::shape_calculator::compute_conv3d_shape(src->tensor_shape(), weights->tensor_shape(), conv_info);
-
- DataType data_type = src->data_type();
-
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DIMENSIONS(dst->tensor_shape(), output_shape);
- ARM_COMPUTE_RETURN_ERROR_ON(dst->data_type() != data_type);
+ "neon_fp32_directconv3d",
+ [](const DirectConv3dSelectorData & data) { return data.dt == DataType::F32; },
+ REGISTER_FP32_NEON(arm_compute::cpu::directconv3d_float_neon_ndhwc<float>)
}
+};
- return Status{};
-}
-
-/** Reduce a vector to be a scalar by accumulating all lanes in the vector
+/** Micro-kernel selector
*
- * @param[in] v Vector to be reduced.
+ * @param[in] data Selection data passed to help pick the appropriate micro-kernel
*
- * @return the wrapped-around number.
+ * @return A matching micro-kernel else nullptr
*/
-auto vreduce(const float32x4_t &v)
+const DirectConv3dKernel *get_implementation(const DirectConv3dSelectorData &data)
{
- auto v0 = wrapper::vgethigh(v);
- auto v1 = wrapper::vgetlow(v);
- auto v_out = wrapper::vadd(v0, v1);
-
- float a = wrapper::vgetlane(v_out, 0);
- float b = wrapper::vgetlane(v_out, 1);
- return a + b;
-}
-
-#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
-auto vreduce(const float16x8_t &v)
-{
- auto v0 = wrapper::vgethigh(v);
- auto v1 = wrapper::vgetlow(v);
- auto v_out = wrapper::vadd(v0, v1);
-
- float16_t a = wrapper::vgetlane(v_out, 0);
- float16_t b = wrapper::vgetlane(v_out, 1);
- float16_t c = wrapper::vgetlane(v_out, 2);
- float16_t d = wrapper::vgetlane(v_out, 3);
- return a + b + c + d;
-}
-#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+ for(const auto &uk : available_kernels)
+ {
+ if(uk.is_selected(data))
+ {
+ return &uk;
+ }
+ }
+ return nullptr;
}
-template <typename T>
-void CpuDirectConv3dKernel::convolve_ndhwc(const Window &window, const ITensor *src, const ITensor *weights, const ITensor *biases, ITensor *dst)
+Status validate_arguments(const ITensorInfo *src0, const ITensorInfo *src1, const ITensorInfo *src2, const ITensorInfo *dst, const Conv3dInfo &conv_info)
{
- using vtype = wrapper::traits::neon_bitvector<T, wrapper::traits::BitWidth::W128>;
- using vector_type = typename vtype::type;
- using tag_type = typename vtype::tag_type;
- constexpr int num_elems_read_per_iteration = 16 / sizeof(T);
-
- // Scalar quantities (N D H W Cin)
- const int element_size = src->info()->element_size();
- const int input_stride_w = src->info()->strides_in_bytes().y() / element_size;
- const int input_stride_h = src->info()->strides_in_bytes().z() / element_size;
- const int input_stride_d = src->info()->strides_in_bytes()[3] / element_size;
- const int input_stride_n = src->info()->strides_in_bytes()[4] / element_size;
- const int input_dim_w = src->info()->dimension(1);
- const int input_dim_h = src->info()->dimension(2);
- const int input_dim_d = src->info()->dimension(3);
-
- // Kernel info (D H W Cin Cout)
- const unsigned int kernel_stride_w = weights->info()->strides_in_bytes()[2] / element_size;
- const unsigned int kernel_stride_h = weights->info()->strides_in_bytes()[3] / element_size;
- const unsigned int kernel_stride_d = weights->info()->strides_in_bytes()[4] / element_size;
- const int kernel_dim_w = weights->info()->dimension(2);
- const int kernel_dim_h = weights->info()->dimension(3);
- const int kernel_dim_d = weights->info()->dimension(4);
-
- // Convolution padding and stride
- const int conv_pad_top = _conv_info.padding.top;
- const int conv_pad_left = _conv_info.padding.left;
- const int conv_pad_front = _conv_info.padding.front;
- const int conv_stride_w = _conv_info.stride.width;
- const int conv_stride_h = _conv_info.stride.height;
- const int conv_stride_d = _conv_info.stride.depth;
+ const auto *uk = get_implementation(DirectConv3dSelectorData{ src0->data_type(), CPUInfo::get() });
+ ARM_COMPUTE_RETURN_ERROR_ON(uk == nullptr || uk->ukernel == nullptr);
- // Setup input window for the output iterator
- Window window_out = window;
- window_out.set(Window::DimX, Window::Dimension(0, 1, 1));
+ ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(src0, src1, dst);
+ ARM_COMPUTE_RETURN_ERROR_ON(src0->data_layout() != DataLayout::NDHWC);
+ ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(src0);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(src0, 1, DataType::F16, DataType::F32);
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(src0, src1);
- // Setup input window for the weights iterator
- Window window_w = calculate_max_window(*weights->info(), Steps());
- window_w.set(Window::DimY, Window::Dimension(0, 1, 1));
- window_w.set(Window::DimZ, Window::Dimension(0, 1, 1));
- window_w.set(Window::DimW, Window::Dimension(0, 1, 1));
- window_w.set(4, Window::Dimension(0, 1, 1));
+ const DataLayout data_layout = src0->data_layout();
+ const int channel_idx = get_data_layout_dimension_index(data_layout, DataLayoutDimension::CHANNEL);
- Iterator out(dst, window_out);
- Iterator wei(weights, window_w);
+ // Weight layout is D, H, W, Cin, Cout
+ ARM_COMPUTE_RETURN_ERROR_ON(src1->num_dimensions() > 5);
+ ARM_COMPUTE_RETURN_ERROR_ON(src1->dimension(1) != src0->dimension(channel_idx));
- const T *biases_ptr = nullptr;
- if(biases)
+ if(src2 != nullptr)
{
- biases_ptr = reinterpret_cast<T *>(biases->buffer() + biases->info()->offset_first_element_in_bytes());
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(src1, src2);
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(src2->dimension(0) != src1->dimension(0),
+ "biases size and number of output feature maps should match");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(src2->num_dimensions() > 1, "biases should be one dimensional");
}
- execute_window_loop(window_out, [&](const Coordinates & id)
- {
- // We are computing the theoretical input starting points
- const int in_w_start_t = static_cast<int>(id.y()) * conv_stride_w - conv_pad_left;
- const int in_h_start_t = static_cast<int>(id.z()) * conv_stride_h - conv_pad_top;
- const int in_d_start_t = static_cast<int>(id[3]) * conv_stride_d - conv_pad_front;
- const int in_w_end_t = in_w_start_t + kernel_dim_w;
- const int in_h_end_t = in_h_start_t + kernel_dim_h;
- const int in_d_end_t = in_d_start_t + kernel_dim_d;
- // We are computing the valid initial and ending input points by checking the borders
- const int in_w_start = std::max(in_w_start_t, 0);
- const int in_h_start = std::max(in_h_start_t, 0);
- const int in_d_start = std::max(in_d_start_t, 0);
- const int in_w_end = std::min(in_w_end_t, input_dim_w);
- const int in_h_end = std::min(in_h_end_t, input_dim_h);
- const int in_d_end = std::min(in_d_end_t, input_dim_d);
+ // Checks performed when output is configured
+ if(dst->total_size() != 0)
+ {
+ TensorShape output_shape = misc::shape_calculator::compute_conv3d_shape(src0->tensor_shape(), src1->tensor_shape(), conv_info);
- // We use the input points to select the valid weight points to use
- const int wei_w_start = in_w_start - in_w_start_t;
- const int wei_h_start = in_h_start - in_h_start_t;
- const int wei_d_start = in_d_start - in_d_start_t;
- const int wei_w_end = kernel_dim_w - (in_w_end_t - in_w_end);
- const int wei_h_end = kernel_dim_h - (in_h_end_t - in_h_end);
- const int wei_d_end = kernel_dim_d - (in_d_end_t - in_d_end);
+ DataType data_type = src0->data_type();
- const int index_c_out_end = weights->info()->dimension(0);
- const int index_c_in_end = weights->info()->dimension(1);
- const T *const in_ptr_start = reinterpret_cast<const T *>(src->buffer() + src->info()->offset_first_element_in_bytes()) + id[4] * input_stride_n;
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DIMENSIONS(dst->tensor_shape(), output_shape);
+ ARM_COMPUTE_RETURN_ERROR_ON(dst->data_type() != data_type);
+ }
- execute_window_loop(window_w, [&](const Coordinates & id_w)
- {
- /*
- * This is the loop in the weights, and it goes along OFM (output feature map)
- */
- const auto weights_ptr_start = reinterpret_cast<const T *>(wei.ptr());
- T out_temp = static_cast<T>(0);
- T *out_ptr = reinterpret_cast<T *>(out.ptr());
- for(int index_wei_d = wei_d_start, index_in_d = in_d_start; index_wei_d < wei_d_end; ++index_wei_d, ++index_in_d)
- {
- const auto in_ptr_d = in_ptr_start + index_in_d * input_stride_d;
- const auto weights_ptr_d = weights_ptr_start + index_wei_d * kernel_stride_d;
- for(int index_wei_h = wei_h_start, index_in_h = in_h_start; index_wei_h < wei_h_end; ++index_wei_h, ++index_in_h)
- {
- const T *const in_ptr_row = in_ptr_d + index_in_h * input_stride_h;
- const T *const weights_ptr_row = weights_ptr_d + index_wei_h * kernel_stride_h;
- for(int index_wei_w = wei_w_start, index_in_w = in_w_start; index_wei_w < wei_w_end; ++index_wei_w, ++index_in_w)
- {
- const T *in_ptr_mover = in_ptr_row + index_in_w * input_stride_w;
- const T *weights_ptr_mover = weights_ptr_row + index_wei_w * kernel_stride_w;
- int index_c_in = 0;
- vector_type out_temp_vec = wrapper::vdup_n(static_cast<T>(0), tag_type());
- vector_type w_vec = wrapper::vdup_n(static_cast<T>(0), tag_type());
- for(; index_c_in <= index_c_in_end - num_elems_read_per_iteration;
- index_c_in += num_elems_read_per_iteration, in_ptr_mover += num_elems_read_per_iteration)
- {
- const auto src_vec = wrapper::vloadq(in_ptr_mover);
- //Load Cin weights
- for(unsigned int k = 0; k < num_elems_read_per_iteration; ++k, weights_ptr_mover += index_c_out_end)
- {
- w_vec = wrapper::vsetlane(*weights_ptr_mover, w_vec, k);
- }
- out_temp_vec = wrapper::vmla(out_temp_vec, w_vec, src_vec);
- }
- out_temp += vreduce(out_temp_vec);
- for(; index_c_in < index_c_in_end; ++index_c_in, ++in_ptr_mover, weights_ptr_mover += index_c_out_end)
- {
- const auto src_val = *(in_ptr_mover);
- const auto w_val = *(weights_ptr_mover);
- out_temp += src_val * w_val;
- }
- }
- }
- }
- *(reinterpret_cast<T *>(out_ptr + id_w[0])) = (biases) ? out_temp + biases_ptr[id_w[0]] : out_temp;
- },
- wei);
- },
- out);
+ return Status{};
+}
}
-void CpuDirectConv3dKernel::configure(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, ITensorInfo *dst, const Conv3dInfo &conv_info)
+void CpuDirectConv3dKernel::configure(const ITensorInfo *src0, const ITensorInfo *src1, const ITensorInfo *src2, ITensorInfo *dst, const Conv3dInfo &conv_info)
{
- ARM_COMPUTE_UNUSED(biases);
- ARM_COMPUTE_ERROR_ON_NULLPTR(src, weights, dst);
+ ARM_COMPUTE_UNUSED(src2);
+ ARM_COMPUTE_ERROR_ON_NULLPTR(src0, src1, dst);
- _conv_info = conv_info;
+ const auto *uk = get_implementation(DirectConv3dSelectorData{ src0->data_type(), CPUInfo::get() });
+ ARM_COMPUTE_ERROR_ON_NULLPTR(uk);
+
+ _conv_info = conv_info;
+ _run_method = uk->ukernel;
+ _name = std::string("CpuDirectConv3dKernel").append("/").append(uk->name);
// Get convolved dimensions
- TensorShape output_shape = misc::shape_calculator::compute_conv3d_shape(src->tensor_shape(), weights->tensor_shape(), conv_info);
+ TensorShape output_shape = misc::shape_calculator::compute_conv3d_shape(src0->tensor_shape(), src1->tensor_shape(), conv_info);
- DataType data_type = src->data_type();
+ DataType data_type = src0->data_type();
// Output auto inizialitation if not yet initialized
auto_init_if_empty(*dst, output_shape, 1, data_type);
// Perform validation step
- ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(src, weights, biases, dst, conv_info));
+ ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(src0, src1, src2, dst, conv_info));
// Configure kernel window
Window win = calculate_max_window(*dst, Steps());
ICpuKernel::configure(win);
}
-Status CpuDirectConv3dKernel::validate(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *dst, const Conv3dInfo &conv_info)
+Status CpuDirectConv3dKernel::validate(const ITensorInfo *src0, const ITensorInfo *src1, const ITensorInfo *src2, const ITensorInfo *dst, const Conv3dInfo &conv_info)
{
- ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(src, weights, biases, dst, conv_info));
+ ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(src0, src1, src2, dst, conv_info));
return Status{};
}
@@ -289,35 +178,19 @@ void CpuDirectConv3dKernel::run_op(ITensorPack &tensors, const Window &window, c
ARM_COMPUTE_UNUSED(info);
ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(ICpuKernel::window(), window);
+ ARM_COMPUTE_ERROR_ON(_run_method == nullptr);
- auto src = tensors.get_const_tensor(TensorType::ACL_SRC_0);
- auto weights = tensors.get_const_tensor(TensorType::ACL_SRC_1);
- auto biases = tensors.get_const_tensor(TensorType::ACL_SRC_2);
- auto dst = tensors.get_tensor(TensorType::ACL_DST);
+ auto src0 = tensors.get_const_tensor(TensorType::ACL_SRC_0);
+ auto src1 = tensors.get_const_tensor(TensorType::ACL_SRC_1);
+ auto src2 = tensors.get_const_tensor(TensorType::ACL_SRC_2);
+ auto dst = tensors.get_tensor(TensorType::ACL_DST);
- switch(src->info()->data_type())
- {
- case DataType::F32:
- {
- convolve_ndhwc<float>(window, src, weights, biases, dst);
- break;
- }
-#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
- case DataType::F16:
- {
- convolve_ndhwc<float16_t>(window, src, weights, biases, dst);
- break;
- }
-#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
- default:
- ARM_COMPUTE_ERROR("Data type not supported");
- break;
- }
+ _run_method(src0, src1, src2, dst, _conv_info, window);
}
const char *CpuDirectConv3dKernel::name() const
{
- return "CpuDirectConv3dKernel";
+ return _name.c_str();
}
} // namespace kernels
} // namespace cpu
diff --git a/src/cpu/kernels/CpuDirectConv3dKernel.h b/src/cpu/kernels/CpuDirectConv3dKernel.h
index c7dcb0fb5e..fc64e8518b 100644
--- a/src/cpu/kernels/CpuDirectConv3dKernel.h
+++ b/src/cpu/kernels/CpuDirectConv3dKernel.h
@@ -39,10 +39,7 @@ class CpuDirectConv3dKernel : public ICpuKernel
public:
CpuDirectConv3dKernel() = default;
ARM_COMPUTE_DISALLOW_COPY_ALLOW_MOVE(CpuDirectConv3dKernel);
- /** Set the src, weights, and dst tensor info.
- *
- * Valid data layouts:
- * - NDHWC
+ /** Set the src, weights, biases and dst tensor info.
*
* Valid data type configurations:
* |src0 |src1 |src2 |dst |
@@ -50,34 +47,35 @@ public:
* |F16 |F16 |F16 |F16 |
* |F32 |F32 |F32 |F32 |
*
- * @param[in, out] src Input tensor info.
- * @param[in] weights Set of kernels to convolve the input volume.
+ * @param[in, out] src0 Input tensor info.
+ * @param[in] src1 Set of kernels to convolve the input volume.
* The 2nd dimension must be the same as the input's volume 1st dimension.
- * @param[in] biases Set of biases. Can be nullptr.
+ * @param[in] src2 Set of biases. Can be nullptr.
* @param[out] dst Output tensor info.
* The 1st dimensions must be equal to the 1st dimension of the @p kernels tensor.
* @param[in] conv_info Contains padding, stride, acitvation information.
*
*/
- void configure(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, ITensorInfo *dst, const Conv3dInfo &conv_info);
+ void configure(const ITensorInfo *src0, const ITensorInfo *src1, const ITensorInfo *src2, ITensorInfo *dst, const Conv3dInfo &conv_info);
/** Static function to check if given info will lead to a valid configuration
*
* Similar to CpuDirectConv3dKernel::configure()
*
* @return a status
*/
- static Status validate(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *dst, const Conv3dInfo &conv_info);
+ static Status validate(const ITensorInfo *src0, const ITensorInfo *src1, const ITensorInfo *src2, const ITensorInfo *dst, const Conv3dInfo &conv_info);
// Inherited methods overridden:
void run_op(ITensorPack &tensors, const Window &window, const ThreadInfo &info) override;
const char *name() const override;
private:
- /* Template function for convolution NDHWC */
- template <typename T>
- void convolve_ndhwc(const Window &window, const ITensor *src, const ITensor *weights, const ITensor *biases, ITensor *dst);
+ /* Template function for convolution 3d NDHWC */
+ using DirectConv3dKernelPtr = std::add_pointer<void(const ITensor *, const ITensor *, const ITensor *, ITensor *, const Conv3dInfo &, const Window &)>::type;
- Conv3dInfo _conv_info{};
+ Conv3dInfo _conv_info{};
+ DirectConv3dKernelPtr _run_method{ nullptr };
+ std::string _name{};
};
} // namespace kernels
} // namespace cpu
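
A hedged sketch of what the refactored kernel expects from its callers (internal API; shapes are fastest-changing dimension first, weights laid out as [Cout, Cin, kW, kH, kD], and all shape values here are arbitrary; Conv3dInfo is assumed default-constructible):

#include "arm_compute/core/TensorInfo.h"
#include "arm_compute/core/TensorShape.h"
#include "arm_compute/core/Types.h"
#include "arm_compute/runtime/FunctionDescriptors.h"
#include "src/cpu/kernels/CpuDirectConv3dKernel.h"

using namespace arm_compute;

Status validate_example()
{
    TensorInfo src(TensorShape(16U /*Cin*/, 32U /*W*/, 32U /*H*/, 8U /*D*/, 1U /*N*/), 1, DataType::F32);
    src.set_data_layout(DataLayout::NDHWC);

    const TensorInfo wei(TensorShape(32U /*Cout*/, 16U /*Cin*/, 3U /*kW*/, 3U /*kH*/, 3U /*kD*/), 1, DataType::F32);
    const TensorInfo bia(TensorShape(32U /*Cout*/), 1, DataType::F32);
    TensorInfo       dst; // empty: output checks are skipped and the shape is auto-initialised later

    return cpu::kernels::CpuDirectConv3dKernel::validate(&src, &wei, &bia, &dst, Conv3dInfo{});
}
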
diff --git a/src/cpu/kernels/conv3d/neon/list.h b/src/cpu/kernels/conv3d/neon/list.h
new file mode 100644
index 0000000000..b24785a48f
--- /dev/null
+++ b/src/cpu/kernels/conv3d/neon/list.h
@@ -0,0 +1,176 @@
+/*
+ * Copyright (c) 2021 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#ifndef SRC_CORE_NEON_KERNELS_CONV3D_LIST_H
+#define SRC_CORE_NEON_KERNELS_CONV3D_LIST_H
+
+#include "arm_compute/core/Types.h"
+#include "arm_compute/core/utils/misc/Traits.h"
+#include "arm_compute/runtime/FunctionDescriptors.h"
+#include "src/core/NEON/wrapper/wrapper.h"
+#include "src/core/helpers/WindowHelpers.h"
+
+namespace arm_compute
+{
+namespace cpu
+{
+template <typename T>
+void directconv3d_float_neon_ndhwc(const ITensor *src0, const ITensor *src1, const ITensor *src2, ITensor *dst, const Conv3dInfo &conv_info, const Window &window)
+{
+ const ITensor *src = src0;
+ const ITensor *weights = src1;
+ const ITensor *biases = src2;
+
+ using vtype = wrapper::traits::neon_bitvector<T, wrapper::traits::BitWidth::W128>;
+ using vector_type = typename vtype::type;
+ using tag_type = typename vtype::tag_type;
+ constexpr int num_elems_read_per_iteration = 16 / sizeof(T);
+
+ // Scalar quantities (N D H W Cin)
+ const int element_size = src->info()->element_size();
+ const int input_stride_w = src->info()->strides_in_bytes().y() / element_size;
+ const int input_stride_h = src->info()->strides_in_bytes().z() / element_size;
+ const int input_stride_d = src->info()->strides_in_bytes()[3] / element_size;
+ const int input_stride_n = src->info()->strides_in_bytes()[4] / element_size;
+ const int input_dim_w = src->info()->dimension(1);
+ const int input_dim_h = src->info()->dimension(2);
+ const int input_dim_d = src->info()->dimension(3);
+
+ // Kernel info (D H W Cin Cout)
+ const unsigned int kernel_stride_w = weights->info()->strides_in_bytes()[2] / element_size;
+ const unsigned int kernel_stride_h = weights->info()->strides_in_bytes()[3] / element_size;
+ const unsigned int kernel_stride_d = weights->info()->strides_in_bytes()[4] / element_size;
+ const int kernel_dim_w = weights->info()->dimension(2);
+ const int kernel_dim_h = weights->info()->dimension(3);
+ const int kernel_dim_d = weights->info()->dimension(4);
+
+ // Convolution padding and stride
+ const int conv_pad_top = conv_info.padding.top;
+ const int conv_pad_left = conv_info.padding.left;
+ const int conv_pad_front = conv_info.padding.front;
+ const int conv_stride_w = conv_info.stride.width;
+ const int conv_stride_h = conv_info.stride.height;
+ const int conv_stride_d = conv_info.stride.depth;
+
+ // Setup input window for the output iterator
+ Window window_out = window;
+ window_out.set(Window::DimX, Window::Dimension(0, 1, 1));
+
+ // Setup input window for the weights iterator
+ Window window_w = calculate_max_window(*weights->info(), Steps());
+ window_w.set(Window::DimY, Window::Dimension(0, 1, 1));
+ window_w.set(Window::DimZ, Window::Dimension(0, 1, 1));
+ window_w.set(Window::DimW, Window::Dimension(0, 1, 1));
+ window_w.set(4, Window::Dimension(0, 1, 1));
+
+ Iterator out(dst, window_out);
+ Iterator wei(weights, window_w);
+
+ const T *biases_ptr = nullptr;
+ if(biases != nullptr)
+ {
+ biases_ptr = reinterpret_cast<T *>(biases->buffer() + biases->info()->offset_first_element_in_bytes());
+ }
+ execute_window_loop(window_out, [&](const Coordinates & id)
+ {
+ // We are computing the theoretical input starting points
+ const int in_w_start_t = static_cast<int>(id.y()) * conv_stride_w - conv_pad_left;
+ const int in_h_start_t = static_cast<int>(id.z()) * conv_stride_h - conv_pad_top;
+ const int in_d_start_t = static_cast<int>(id[3]) * conv_stride_d - conv_pad_front;
+ const int in_w_end_t = in_w_start_t + kernel_dim_w;
+ const int in_h_end_t = in_h_start_t + kernel_dim_h;
+ const int in_d_end_t = in_d_start_t + kernel_dim_d;
+
+ // We are computing the valid initial and ending input points by checking the borders
+ const int in_w_start = std::max(in_w_start_t, 0);
+ const int in_h_start = std::max(in_h_start_t, 0);
+ const int in_d_start = std::max(in_d_start_t, 0);
+ const int in_w_end = std::min(in_w_end_t, input_dim_w);
+ const int in_h_end = std::min(in_h_end_t, input_dim_h);
+ const int in_d_end = std::min(in_d_end_t, input_dim_d);
+
+ // We use the input points to select the valid weight points to use
+ const int wei_w_start = in_w_start - in_w_start_t;
+ const int wei_h_start = in_h_start - in_h_start_t;
+ const int wei_d_start = in_d_start - in_d_start_t;
+ const int wei_w_end = kernel_dim_w - (in_w_end_t - in_w_end);
+ const int wei_h_end = kernel_dim_h - (in_h_end_t - in_h_end);
+ const int wei_d_end = kernel_dim_d - (in_d_end_t - in_d_end);
+
+ const int index_c_out_end = weights->info()->dimension(0);
+ const int index_c_in_end = weights->info()->dimension(1);
+ const T *const in_ptr_start = reinterpret_cast<const T *>(src->buffer() + src->info()->offset_first_element_in_bytes()) + id[4] * input_stride_n;
+
+ execute_window_loop(window_w, [&](const Coordinates & id_w)
+ {
+ /*
+ * This is the loop in the weights, and it goes along OFM (output feature map)
+ */
+ const auto weights_ptr_start = reinterpret_cast<const T *>(wei.ptr());
+ T out_temp = static_cast<T>(0);
+ T *out_ptr = reinterpret_cast<T *>(out.ptr());
+ for(int index_wei_d = wei_d_start, index_in_d = in_d_start; index_wei_d < wei_d_end; ++index_wei_d, ++index_in_d)
+ {
+ const auto in_ptr_d = in_ptr_start + index_in_d * input_stride_d;
+ const auto weights_ptr_d = weights_ptr_start + index_wei_d * kernel_stride_d;
+ for(int index_wei_h = wei_h_start, index_in_h = in_h_start; index_wei_h < wei_h_end; ++index_wei_h, ++index_in_h)
+ {
+ const T *const in_ptr_row = in_ptr_d + index_in_h * input_stride_h;
+ const T *const weights_ptr_row = weights_ptr_d + index_wei_h * kernel_stride_h;
+ for(int index_wei_w = wei_w_start, index_in_w = in_w_start; index_wei_w < wei_w_end; ++index_wei_w, ++index_in_w)
+ {
+ const T *in_ptr_mover = in_ptr_row + index_in_w * input_stride_w;
+ const T *weights_ptr_mover = weights_ptr_row + index_wei_w * kernel_stride_w;
+ int index_c_in = 0;
+ vector_type out_temp_vec = wrapper::vdup_n(static_cast<T>(0), tag_type());
+ vector_type w_vec = wrapper::vdup_n(static_cast<T>(0), tag_type());
+ for(; index_c_in <= index_c_in_end - num_elems_read_per_iteration;
+ index_c_in += num_elems_read_per_iteration, in_ptr_mover += num_elems_read_per_iteration)
+ {
+ const auto src_vec = wrapper::vloadq(in_ptr_mover);
+ //Load Cin weights
+ for(unsigned int k = 0; k < num_elems_read_per_iteration; ++k, weights_ptr_mover += index_c_out_end)
+ {
+ w_vec = wrapper::vsetlane(*weights_ptr_mover, w_vec, k);
+ }
+ out_temp_vec = wrapper::vmla(out_temp_vec, w_vec, src_vec);
+ }
+ out_temp += vreduce(out_temp_vec);
+ for(; index_c_in < index_c_in_end; ++index_c_in, ++in_ptr_mover, weights_ptr_mover += index_c_out_end)
+ {
+ const auto src_val = *(in_ptr_mover);
+ const auto w_val = *(weights_ptr_mover);
+ out_temp += src_val * w_val;
+ }
+ }
+ }
+ }
+ *(reinterpret_cast<T *>(out_ptr + id_w[0])) = (biases_ptr != nullptr) ? out_temp + biases_ptr[id_w[0]] : out_temp;
+ },
+ wei);
+ },
+ out);
+}
+} // namespace cpu
+} // namespace arm_compute
+#endif // SRC_CORE_NEON_KERNELS_CONV3D_LIST_H
\ No newline at end of file
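
The inner Cin loop of directconv3d_float_neon_ndhwc above follows a standard structure: accumulate full vectors, horizontally reduce, then finish the remaining channels with scalar code (in the kernel the weights are gathered lane by lane because they are strided by Cout). A standalone sketch of that structure with plain NEON intrinsics, contiguous inputs and illustrative names:

#include <arm_neon.h>

float dot_product_main_plus_tail(const float *a, const float *b, int n)
{
    float32x4_t acc_v = vdupq_n_f32(0.f);
    int         i     = 0;
    for(; i <= n - 4; i += 4) // vector main loop: 4 channels per iteration
    {
        acc_v = vmlaq_f32(acc_v, vld1q_f32(a + i), vld1q_f32(b + i));
    }
    // Horizontal reduction of the accumulator (the role of vreduce() in the kernel).
    float acc = vgetq_lane_f32(acc_v, 0) + vgetq_lane_f32(acc_v, 1) +
                vgetq_lane_f32(acc_v, 2) + vgetq_lane_f32(acc_v, 3);
    for(; i < n; ++i) // scalar tail for the leftover channels
    {
        acc += a[i] * b[i];
    }
    return acc;
}
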
diff --git a/src/cpu/operators/CpuDirectConv3d.cpp b/src/cpu/operators/CpuDirectConv3d.cpp
index 3827910d37..aa74e420a6 100644
--- a/src/cpu/operators/CpuDirectConv3d.cpp
+++ b/src/cpu/operators/CpuDirectConv3d.cpp
@@ -40,10 +40,10 @@ CpuDirectConv3d::CpuDirectConv3d(std::shared_ptr<IMemoryManager> memory_manager)
{
}
-void CpuDirectConv3d::configure(ITensorInfo *src, ITensorInfo *weights, const ITensorInfo *biases, ITensorInfo *dst, const Conv3dInfo conv_info)
+void CpuDirectConv3d::configure(ITensorInfo *src0, ITensorInfo *src1, const ITensorInfo *src2, ITensorInfo *dst, const Conv3dInfo conv_info)
{
- ARM_COMPUTE_LOG_PARAMS(src, weights, biases, dst, conv_info);
- ARM_COMPUTE_ERROR_ON(src->data_layout() != DataLayout::NDHWC);
+ ARM_COMPUTE_LOG_PARAMS(src0, src1, src2, dst, conv_info);
+ ARM_COMPUTE_ERROR_ON(src0->data_layout() != DataLayout::NDHWC);
_conv_kernel = std::make_unique<kernels::CpuDirectConv3dKernel>();
@@ -55,7 +55,7 @@ void CpuDirectConv3d::configure(ITensorInfo *src, ITensorInfo *weights, const IT
_dim_split = Window::DimY;
- _conv_kernel->configure(src, weights, biases, dst, conv_info);
+ _conv_kernel->configure(src0, src1, src2, dst, conv_info);
//Configure Activation Layer
_is_activationlayer_enabled = conv_info.act_info.enabled();
@@ -66,16 +66,12 @@ void CpuDirectConv3d::configure(ITensorInfo *src, ITensorInfo *weights, const IT
}
}
-Status CpuDirectConv3d::validate(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *dst, const Conv3dInfo conv_info)
+Status CpuDirectConv3d::validate(const ITensorInfo *src0, const ITensorInfo *src1, const ITensorInfo *src2, const ITensorInfo *dst, const Conv3dInfo conv_info)
{
- ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(src, weights, dst);
-
- // output might not be initialized since it can be an intermediate tensor of another layer
- DataType data_type = src->data_type();
- TensorInfo accumulator(dst->clone()->set_is_resizable(true).reset_padding().set_data_type(data_type));
+ ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(src0, src1, dst);
// Validate Convolution kernel
- ARM_COMPUTE_RETURN_ON_ERROR(kernels::CpuDirectConv3dKernel::validate(src, weights, biases, &accumulator, conv_info));
+ ARM_COMPUTE_RETURN_ON_ERROR(kernels::CpuDirectConv3dKernel::validate(src0, src1, src2, dst, conv_info));
if(conv_info.act_info.enabled())
{
diff --git a/src/cpu/operators/CpuDirectConv3d.h b/src/cpu/operators/CpuDirectConv3d.h
index ad04dee0fa..f7c3099be0 100644
--- a/src/cpu/operators/CpuDirectConv3d.h
+++ b/src/cpu/operators/CpuDirectConv3d.h
@@ -57,23 +57,31 @@ public:
~CpuDirectConv3d();
/** Set the input, weights, biases and output tensor info.
*
- * @param[in, out] src Input tensor info.
- * @param[in] weights Set of kernels to convolve the input volume.
- * The 2nd dimension must be the same as the input's volume 1st dimension.
- * Data type supported: Same as @p src.
- * @param[in] biases Set of biases. Can be nullptr. Data type supported: Same as @p src.
+ * Valid data layouts:
+ * - NDHWC
+ *
+ * Valid data type configurations:
+ * |src0 |src1 |src2 |dst |
+ * |:--------------|:------------------|:------|:--------------|
+ * |F16 |F16 |F16 |F16 |
+ * |F32 |F32 |F32 |F32 |
+ *
+ * @param[in, out] src0 Input tensor info.
+ * @param[in] src1 Set of kernels to convolve the input volume.
+ * The 2nd dimension must be the same as the src0's volume 1st dimension.
+ * @param[in] src2 Set of biases. Can be nullptr.
* @param[out] dst Output tensor info.
* The 1st dimensions must be equal to the 1st dimension of the @p kernels tensor.
* @param[in] conv_info Contains padding, stride, acitvation information.
*/
- void configure(ITensorInfo *src, ITensorInfo *weights, const ITensorInfo *biases, ITensorInfo *dst, const Conv3dInfo conv_info);
+ void configure(ITensorInfo *src0, ITensorInfo *src1, const ITensorInfo *src2, ITensorInfo *dst, const Conv3dInfo conv_info);
/** Static function to check if given info will lead to a valid configuration
*
* Similar to CpuDirectConv3d::configure()
*
* @return a status
*/
- static Status validate(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *dst, const Conv3dInfo conv_info);
+ static Status validate(const ITensorInfo *src0, const ITensorInfo *src1, const ITensorInfo *src2, const ITensorInfo *dst, const Conv3dInfo conv_info);
// Inherited methods overridden:
void run(ITensorPack &tensors) override;
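
At run time the operator receives its tensors through an ITensorPack, the same pack NEConv3D builds internally (ACL_SRC_0 = input, ACL_SRC_1 = weights, ACL_SRC_2 = biases, ACL_DST = output). A hedged sketch of that calling convention:

#include "arm_compute/core/ITensor.h"
#include "arm_compute/core/ITensorPack.h"
#include "arm_compute/core/Types.h"
#include "src/cpu/operators/CpuDirectConv3d.h"

using namespace arm_compute;

void run_cpu_direct_conv3d(cpu::CpuDirectConv3d &conv3d, const ITensor *src, const ITensor *weights,
                           const ITensor *biases, ITensor *dst)
{
    // conv3d is assumed to have been configured with the matching ITensorInfo objects beforehand.
    ITensorPack pack;
    pack.add_const_tensor(TensorType::ACL_SRC_0, src);
    pack.add_const_tensor(TensorType::ACL_SRC_1, weights);
    pack.add_const_tensor(TensorType::ACL_SRC_2, biases); // may be nullptr when no bias is used
    pack.add_tensor(TensorType::ACL_DST, dst);
    conv3d.run(pack);
}
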
diff --git a/src/gpu/cl/kernels/ClDirectConv3dKernel.cpp b/src/gpu/cl/kernels/ClDirectConv3dKernel.cpp
index 1c4326b494..88e73dc72a 100644
--- a/src/gpu/cl/kernels/ClDirectConv3dKernel.cpp
+++ b/src/gpu/cl/kernels/ClDirectConv3dKernel.cpp
@@ -37,36 +37,36 @@ namespace kernels
{
namespace
{
-Status validate_arguments(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *dst, const Conv3dInfo &conv3d_info)
+Status validate_arguments(const ITensorInfo *src0, const ITensorInfo *src1, const ITensorInfo *src2, const ITensorInfo *dst, const Conv3dInfo &conv3d_info)
{
- ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_LAYOUT(src, weights, dst);
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(src->data_layout() != DataLayout::NDHWC, "Only NDHWC layout supported");
+ ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_LAYOUT(src0, src1, dst);
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(src0->data_layout() != DataLayout::NDHWC, "Only NDHWC layout supported");
ARM_COMPUTE_RETURN_ERROR_ON_MSG(conv3d_info.act_info.enabled(), "Fused activation not supported");
- ARM_COMPUTE_RETURN_ERROR_ON_F16_UNSUPPORTED(src);
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(src, 1, DataType::F16, DataType::F32);
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(src, weights);
+ ARM_COMPUTE_RETURN_ERROR_ON_F16_UNSUPPORTED(src0);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(src0, 1, DataType::F16, DataType::F32);
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(src0, src1);
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(weights->dimension(1) != src->dimension(0), "Weights feature map dimension should match the respective src's one");
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(weights->num_dimensions() > 5, "Weights can be at most 5 dimensional");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(src1->dimension(1) != src0->dimension(0), "Weights feature map dimension should match the respective src's one");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(src1->num_dimensions() > 5, "Weights can be at most 5 dimensional");
- ARM_COMPUTE_RETURN_ERROR_ON(weights->dimension(2) > (src->dimension(1) + conv3d_info.padding.left + conv3d_info.padding.right));
- ARM_COMPUTE_RETURN_ERROR_ON(weights->dimension(3) > (src->dimension(2) + conv3d_info.padding.top + conv3d_info.padding.bottom));
- ARM_COMPUTE_RETURN_ERROR_ON(weights->dimension(4) > (src->dimension(3) + conv3d_info.padding.front + conv3d_info.padding.back));
+ ARM_COMPUTE_RETURN_ERROR_ON(src1->dimension(2) > (src0->dimension(1) + conv3d_info.padding.left + conv3d_info.padding.right));
+ ARM_COMPUTE_RETURN_ERROR_ON(src1->dimension(3) > (src0->dimension(2) + conv3d_info.padding.top + conv3d_info.padding.bottom));
+ ARM_COMPUTE_RETURN_ERROR_ON(src1->dimension(4) > (src0->dimension(3) + conv3d_info.padding.front + conv3d_info.padding.back));
- if(biases != nullptr)
+ if(src2 != nullptr)
{
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(weights, biases);
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(biases->dimension(0) != weights->dimension(0), "Biases size and number of dst feature maps should match");
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(biases->num_dimensions() > 1, "Biases should be one dimensional");
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(src1, src2);
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(src2->dimension(0) != src1->dimension(0), "Biases size and number of dst feature maps should match");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(src2->num_dimensions() > 1, "Biases should be one dimensional");
}
// Checks performed when dst is configured
if(dst->total_size() != 0)
{
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(dst->dimension(0) != weights->dimension(0), "Weights and dst OFMs should match");
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DIMENSIONS(dst->tensor_shape(), misc::shape_calculator::compute_conv3d_shape(src->tensor_shape(), weights->tensor_shape(), conv3d_info));
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(src, dst);
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(dst->dimension(0) != src1->dimension(0), "Weights and dst OFMs should match");
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DIMENSIONS(dst->tensor_shape(), misc::shape_calculator::compute_conv3d_shape(src0->tensor_shape(), src1->tensor_shape(), conv3d_info));
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(src0, dst);
}
return Status{};
@@ -78,27 +78,27 @@ ClDirectConv3dKernel::ClDirectConv3dKernel()
_type = CLKernelType::DIRECT;
}
-void ClDirectConv3dKernel::configure(const CLCompileContext &compile_context, const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, ITensorInfo *dst,
+void ClDirectConv3dKernel::configure(const CLCompileContext &compile_context, const ITensorInfo *src0, const ITensorInfo *src1, const ITensorInfo *src2, ITensorInfo *dst,
const Conv3dInfo &conv3d_info)
{
- ARM_COMPUTE_ERROR_ON_NULLPTR(src, weights, dst);
+ ARM_COMPUTE_ERROR_ON_NULLPTR(src0, src1, dst);
// Perform validation
- ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(src, weights, biases, dst, conv3d_info));
+ ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(src0, src1, src2, dst, conv3d_info));
// Create window and update padding
- const DataType data_type = src->data_type();
- const size_t src_width = src->dimension(1);
- const size_t src_height = src->dimension(2);
- const size_t src_depth = src->dimension(3);
- const size_t src_channels = src->dimension(0);
+ const DataType data_type = src0->data_type();
+ const size_t src_width = src0->dimension(1);
+ const size_t src_height = src0->dimension(2);
+ const size_t src_depth = src0->dimension(3);
+ const size_t src_channels = src0->dimension(0);
const size_t dst_width = dst->dimension(1);
const size_t dst_height = dst->dimension(2);
const size_t dst_depth = dst->dimension(3);
const size_t dst_channels = dst->dimension(0);
- const size_t weights_width = weights->dimension(2);
- const size_t weights_height = weights->dimension(3);
- const size_t weights_depth = weights->dimension(4);
+ const size_t weights_width = src1->dimension(2);
+ const size_t weights_height = src1->dimension(3);
+ const size_t weights_depth = src1->dimension(4);
const size_t pad_left = conv3d_info.padding.left;
const size_t pad_top = conv3d_info.padding.top;
const size_t pad_front = conv3d_info.padding.front;
@@ -108,7 +108,7 @@ void ClDirectConv3dKernel::configure(const CLCompileContext &compile_context, co
const size_t n0 = std::min(dst->dimension(0), static_cast<size_t>(4u));
const size_t m0 = (dst->tensor_shape()[0] > 16) ? ((data_type == DataType::F32) ? 2U : 4U) : 1U;
- const size_t k0 = adjust_vec_size(8u, src->dimension(0));
+ const size_t k0 = adjust_vec_size(8u, src0->dimension(0));
const size_t partial_store_n0 = dst->dimension(0) % n0;
CLBuildOptions build_options;
@@ -136,7 +136,7 @@ void ClDirectConv3dKernel::configure(const CLCompileContext &compile_context, co
build_options.add_option("-DM0=" + support::cpp11::to_string(m0));
build_options.add_option("-DK0=" + support::cpp11::to_string(k0));
build_options.add_option("-DPARTIAL_N0=" + support::cpp11::to_string(partial_store_n0));
- build_options.add_option_if(biases != nullptr, std::string("-DHAS_BIAS"));
+ build_options.add_option_if(src2 != nullptr, std::string("-DHAS_BIAS"));
std::string kernel_name = "direct_convolution3d_ndhwc";
_kernel = create_kernel(compile_context, kernel_name, build_options.options());
@@ -169,9 +169,9 @@ void ClDirectConv3dKernel::configure(const CLCompileContext &compile_context, co
_config_id += support::cpp11::to_string(dst_channels);
}
-Status ClDirectConv3dKernel::validate(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *dst, const Conv3dInfo &conv3d_info)
+Status ClDirectConv3dKernel::validate(const ITensorInfo *src0, const ITensorInfo *src1, const ITensorInfo *src2, const ITensorInfo *dst, const Conv3dInfo &conv3d_info)
{
- ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(src, weights, biases, dst, conv3d_info));
+ ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(src0, src1, src2, dst, conv3d_info));
return Status{};
}
diff --git a/src/gpu/cl/kernels/ClDirectConv3dKernel.h b/src/gpu/cl/kernels/ClDirectConv3dKernel.h
index 9ac8f0d7b3..485c900826 100644
--- a/src/gpu/cl/kernels/ClDirectConv3dKernel.h
+++ b/src/gpu/cl/kernels/ClDirectConv3dKernel.h
@@ -61,21 +61,21 @@ public:
* |F32 |F32 |F32 |F32 |
*
* @param[in] compile_context The compile context to be used.
- * @param[in] src Source tensor. 4 lower dimensions represent a single src [IFM, width, height, depth],
+ * @param[in] src0 Source tensor. 4 lower dimensions represent a single src [IFM, width, height, depth],
* while every optional dimension from 5 and above represent a batch of srcs.
- * @param[in] weights Weights tensor. Weights are 5D tensor with dimensions [OFM, IFM, kernel_w, kernel_h, kernel_d].
- * @param[in] biases Biases tensor. Shared biases supported. Biases are 1D tensor with dimensions [OFM].
+ * @param[in] src1 Weights tensor. Weights are 5D tensor with dimensions [OFM, IFM, kernel_w, kernel_h, kernel_d].
+ * @param[in] src2 Biases tensor. Shared biases supported. Biases are 1D tensor with dimensions [OFM].
* @param[out] dst Destination tensor. 4 lower dimensions represent a single dst [OFM, width, height, depth], while the rest represent batch of dsts.
* @param[in] conv3d_info Contains strides, padding, rounding, activation, dilation and fast math information. Activation and fast math are currently unused.
*/
- void configure(const CLCompileContext &compile_context, const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, ITensorInfo *dst, const Conv3dInfo &conv3d_info);
+ void configure(const CLCompileContext &compile_context, const ITensorInfo *src0, const ITensorInfo *src1, const ITensorInfo *src2, ITensorInfo *dst, const Conv3dInfo &conv3d_info);
/** Static function to check if given info will lead to a valid configuration
*
* Similar to ClDirectConv3dKernel::configure()
*
* @return a status
*/
- static Status validate(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *dst, const Conv3dInfo &conv3d_info);
+ static Status validate(const ITensorInfo *src0, const ITensorInfo *src1, const ITensorInfo *src2, const ITensorInfo *dst, const Conv3dInfo &conv3d_info);
// Inherited methods overridden:
void run_op(ITensorPack &tensors, const Window &window, cl::CommandQueue &queue) override;
diff --git a/src/gpu/cl/operators/ClDirectConv3d.cpp b/src/gpu/cl/operators/ClDirectConv3d.cpp
index d10165814b..5d37f07f31 100644
--- a/src/gpu/cl/operators/ClDirectConv3d.cpp
+++ b/src/gpu/cl/operators/ClDirectConv3d.cpp
@@ -30,19 +30,19 @@ namespace arm_compute
{
namespace opencl
{
-void ClDirectConv3d::configure(const CLCompileContext &compile_context, const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, ITensorInfo *dst, const Conv3dInfo &conv3d_info)
+void ClDirectConv3d::configure(const CLCompileContext &compile_context, const ITensorInfo *src0, const ITensorInfo *src1, const ITensorInfo *src2, ITensorInfo *dst, const Conv3dInfo &conv3d_info)
{
- ARM_COMPUTE_ERROR_ON_NULLPTR(src);
+ ARM_COMPUTE_ERROR_ON_NULLPTR(src0);
// Configure direct convolution 3d kernel
auto k = std::make_unique<kernels::ClDirectConv3dKernel>();
- k->configure(compile_context, src, weights, biases, dst, conv3d_info);
+ k->configure(compile_context, src0, src1, src2, dst, conv3d_info);
_direct_conv3d_kernel = std::move(k);
}
-Status ClDirectConv3d::validate(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *dst, const Conv3dInfo &conv3d_info)
+Status ClDirectConv3d::validate(const ITensorInfo *src0, const ITensorInfo *src1, const ITensorInfo *src2, const ITensorInfo *dst, const Conv3dInfo &conv3d_info)
{
- ARM_COMPUTE_RETURN_ON_ERROR(kernels::ClDirectConv3dKernel::validate(src, weights, biases, dst, conv3d_info));
+ ARM_COMPUTE_RETURN_ON_ERROR(kernels::ClDirectConv3dKernel::validate(src0, src1, src2, dst, conv3d_info));
return Status{};
}
diff --git a/src/gpu/cl/operators/ClDirectConv3d.h b/src/gpu/cl/operators/ClDirectConv3d.h
index ce9135b812..d8ffefc450 100644
--- a/src/gpu/cl/operators/ClDirectConv3d.h
+++ b/src/gpu/cl/operators/ClDirectConv3d.h
@@ -57,15 +57,15 @@ public:
* |F32 |F32 |F32 |F32 |
*
* @param[in] compile_context The compile context to be used.
- * @param[in] src Source tensor. 4 lower dimensions represent a single src [IFM, width, height, depth],
+ * @param[in] src0 Source tensor. 4 lower dimensions represent a single src [IFM, width, height, depth],
* while every optional dimension from 5 and above represent a batch of srcs.
- * @param[in] weights Weights tensor. Weights are 5D tensor with dimensions [OFM, IFM, kernel_w, kernel_h, kernel_d].
- * @param[in] biases Biases tensor. Shared biases supported. Biases are 1D tensor with dimensions [OFM].
+ * @param[in] src1 Weights tensor. Weights are 5D tensor with dimensions [OFM, IFM, kernel_w, kernel_h, kernel_d].
+ * @param[in] src2 Biases tensor. Shared biases supported. Biases are 1D tensor with dimensions [OFM].
* @param[out] dst Destination tensor. 4 lower dimensions represent a single dst [OFM, width, height, depth], while the rest represent batch of dsts.
* @param[in] conv3d_info Contains strides, padding, rounding, activation, dilation and fast math information. Activation and fast math are currently unused.
*
*/
- void configure(const CLCompileContext &compile_context, const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, ITensorInfo *dst, const Conv3dInfo &conv3d_info);
+ void configure(const CLCompileContext &compile_context, const ITensorInfo *src0, const ITensorInfo *src1, const ITensorInfo *src2, ITensorInfo *dst, const Conv3dInfo &conv3d_info);
/** Static function to check if given info will lead to a valid configuration
*
@@ -73,7 +73,7 @@ public:
*
* @return a status
*/
- static Status validate(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *dst, const Conv3dInfo &conv3d_info);
+ static Status validate(const ITensorInfo *src0, const ITensorInfo *src1, const ITensorInfo *src2, const ITensorInfo *dst, const Conv3dInfo &conv3d_info);
// Inherited method overridden
void run(ITensorPack &tensors) override;
diff --git a/src/runtime/NEON/functions/NEConv3D.cpp b/src/runtime/NEON/functions/NEConv3D.cpp
index b5e2e2a843..3bb66c44b0 100644
--- a/src/runtime/NEON/functions/NEConv3D.cpp
+++ b/src/runtime/NEON/functions/NEConv3D.cpp
@@ -27,7 +27,6 @@
#include "arm_compute/core/Utils.h"
#include "arm_compute/core/Validate.h"
#include "src/common/utils/Log.h"
-#include "src/core/helpers/MemoryHelpers.h"
#include "src/cpu/operators/CpuDirectConv3d.h"
namespace arm_compute
@@ -58,7 +57,7 @@ void NEConv3D::configure(ITensor *input, const ITensor *weights, const ITensor *
f->configure(input->info(), weights->info(), ((biases != nullptr) ? biases->info() : nullptr), output->info(), conv_info);
_impl->op = std::move(f);
- if(_impl->op)
+ if(_impl->op != nullptr)
{
_impl->run_pack = { { ACL_SRC_0, input }, { ACL_SRC_1, weights }, { ACL_SRC_2, biases }, { ACL_DST, output } };
}
@@ -73,7 +72,7 @@ Status NEConv3D::validate(const ITensorInfo *input, const ITensorInfo *weights,
void NEConv3D::run()
{
- if(_impl->op)
+ if(_impl->op != nullptr)
{
_impl->op->run(_impl->run_pack);
}