From 5dda2177800009b24e31550ed849b1ef3fca6167 Mon Sep 17 00:00:00 2001 From: Sheri Zhang Date: Fri, 15 Oct 2021 19:54:17 +0100 Subject: DirectConv3d support refine - Decouple data support of CpuDirectConv3dKernel - Update documentation for Conv3d Signed-off-by: Sheri Zhang Change-Id: I1d94aa28f821f45a1a3d39cc3335c8faeee89f0d Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/6453 Reviewed-by: Giorgio Arena Comments-Addressed: Arm Jenkins Tested-by: Arm Jenkins --- arm_compute/runtime/NEON/functions/NEConv3D.h | 1 - docs/user_guide/data_layout.dox | 3 +- docs/user_guide/operator_list.dox | 3 +- docs/user_guide/release_version_and_change_log.dox | 8 + src/core/NEON/NEMath.h | 15 + src/core/NEON/NEMath.inl | 27 +- src/cpu/kernels/CpuDirectConv2dKernel.cpp | 11 - src/cpu/kernels/CpuDirectConv3dKernel.cpp | 303 ++++++--------------- src/cpu/kernels/CpuDirectConv3dKernel.h | 24 +- src/cpu/kernels/conv3d/neon/list.h | 176 ++++++++++++ src/cpu/operators/CpuDirectConv3d.cpp | 18 +- src/cpu/operators/CpuDirectConv3d.h | 22 +- src/gpu/cl/kernels/ClDirectConv3dKernel.cpp | 66 ++--- src/gpu/cl/kernels/ClDirectConv3dKernel.h | 10 +- src/gpu/cl/operators/ClDirectConv3d.cpp | 10 +- src/gpu/cl/operators/ClDirectConv3d.h | 10 +- src/runtime/NEON/functions/NEConv3D.cpp | 5 +- 17 files changed, 400 insertions(+), 312 deletions(-) create mode 100644 src/cpu/kernels/conv3d/neon/list.h diff --git a/arm_compute/runtime/NEON/functions/NEConv3D.h b/arm_compute/runtime/NEON/functions/NEConv3D.h index 487d357fa1..2b3a45f0af 100644 --- a/arm_compute/runtime/NEON/functions/NEConv3D.h +++ b/arm_compute/runtime/NEON/functions/NEConv3D.h @@ -29,7 +29,6 @@ #include "arm_compute/core/ITensorInfo.h" #include "arm_compute/core/Types.h" #include "arm_compute/runtime/FunctionDescriptors.h" -#include "arm_compute/runtime/MemoryGroup.h" #include diff --git a/docs/user_guide/data_layout.dox b/docs/user_guide/data_layout.dox index 97d3ea6262..ae69bbf457 100644 --- 
a/docs/user_guide/data_layout.dox +++ b/docs/user_guide/data_layout.dox @@ -34,8 +34,9 @@ the right-most letter represents the fastest changing dimension: - NHWC: The native layout of Compute Library that delivers the best performance where channels are in the fastest changing dimension - NCHW: Legacy layout where width is in the fastest changing dimension +- NDHWC: New data layout for supporting 3D operators -, where N = batch, C = channel, H = height, W = width. +, where N = batch, C = channel, H = height, W = width, D = depth. */ } // namespace diff --git a/docs/user_guide/operator_list.dox b/docs/user_guide/operator_list.dox index ebc970d8c1..1d06a394a9 100644 --- a/docs/user_guide/operator_list.dox +++ b/docs/user_guide/operator_list.dox @@ -52,9 +52,10 @@ Compute Library supports the following data layouts (fast changing dimension fro
  • NHWC: The native layout of Compute Library that delivers the best performance where channels are in the fastest changing dimension
  • NCHW: Legacy layout where width is in the fastest changing dimension +
  • NDHWC: New data layout for supporting 3D operators
  • All: Agnostic to any specific data layout
-where N = batches, C = channels, H = height, W = width +where N = batches, C = channels, H = height, W = width, D = depth diff --git a/docs/user_guide/release_version_and_change_log.dox b/docs/user_guide/release_version_and_change_log.dox index 583cf4fb82..2470b45203 100644 --- a/docs/user_guide/release_version_and_change_log.dox +++ b/docs/user_guide/release_version_and_change_log.dox @@ -40,6 +40,14 @@ If there is more than one release in a month then an extra sequential number is @section S2_2_changelog Changelog +v21.11 Public major release + - Various bug fixes. + - Various optimizations. + - New OpenCL kernels / functions: + - @ref CLConv3D + - New Arm® Neon™ kernels / functions: + - @ref NEConv3D + v21.08 Public major release - Various bug fixes. - Various optimizations: diff --git a/src/core/NEON/NEMath.h b/src/core/NEON/NEMath.h index 13484c9c15..8118c4701f 100644 --- a/src/core/NEON/NEMath.h +++ b/src/core/NEON/NEMath.h @@ -239,6 +239,14 @@ float32x4_t vsinq_f32(float32x4_t val); */ float32x2_t vsin_f32(float32x2_t val); +/** Reduce a vector to a scalar by accumulating all lanes in the vector + * + * @param[in] v Vector to be reduced. + * + * @return the sum of all lanes of @p v. + */ +float vreduce(const float32x4_t &v); + #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC /** Calculate hyperbolic tangent. * @@ -319,6 +327,13 @@ float16x8_t vpowq_f16(float16x8_t val, float16x8_t n); */ float16x8_t vsinq_f16(float16x8_t val); +/** Reduce a vector to a scalar by accumulating all lanes in the vector + * + * @param[in] v Vector to be reduced. + * + * @return the sum of all lanes of @p v. 
+ */ +float16_t vreduce(const float16x8_t &v); #endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */ } // namespace arm_compute #include "src/core/NEON/NEMath.inl" diff --git a/src/core/NEON/NEMath.inl b/src/core/NEON/NEMath.inl index 5ac62badcc..05cf3013bc 100644 --- a/src/core/NEON/NEMath.inl +++ b/src/core/NEON/NEMath.inl @@ -193,7 +193,7 @@ inline float32x4_t vtanhq_f32(float32x4_t val) static const float32x4_t CONST_THR = vdupq_n_f32(5.e-3); static const float32x4_t CONST_1_3 = vdupq_n_f32(0.3333333f); - float32x4_t x = vminq_f32(vmaxq_f32(val, CONST_MIN_TANH), CONST_MAX_TANH); + float32x4_t x = vminq_f32(vmaxq_f32(val, CONST_MIN_TANH), CONST_MAX_TANH); // x * (1 - x^2/3) if |x| < 5.e-3 or (exp2x - 1) / (exp2x + 1) otherwise float32x4_t exp2x = vbslq_f32(vcgtq_f32(vabsq_f32(x), CONST_THR), vexpq_f32(vmulq_f32(CONST_2, x)), vmulq_f32(x, x)); float32x4_t num = vbslq_f32(vcgtq_f32(vabsq_f32(x), CONST_THR), vsubq_f32(exp2x, CONST_1), vmulq_f32(CONST_1_3, exp2x)); @@ -418,6 +418,18 @@ inline float32x4x4_t convert_int_to_float(const int8x1 return convert_int8x16_to_float32x4x4(in); } +inline float vreduce(const float32x4_t &v) +{ + const float32x2_t v0 = vget_high_f32(v); + const float32x2_t v1 = vget_low_f32(v); + const float32x2_t v_out = vadd_f32(v0, v1); + + const float a = vget_lane_f32(v_out, 0); + const float b = vget_lane_f32(v_out, 1); + + return a + b; +} + #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC /** Exponent polynomial coefficients */ /** Logarithm polynomial coefficients */ @@ -550,6 +562,19 @@ inline float16x4_t vsin_f16(float16x4_t val) return vcvt_f16_f32(vcombine_f32(res_low, res_high)); } +inline float16_t vreduce(const float16x8_t &v) +{ + const float16x4_t v0 = vget_high_f16(v); + const float16x4_t v1 = vget_low_f16(v); + const float16x4_t v_out = vadd_f16(v0, v1); + + const float16_t a = vget_lane_f16(v_out, 0); + const float16_t b = vget_lane_f16(v_out, 1); + const float16_t c = vget_lane_f16(v_out, 2); + const float16_t d = vget_lane_f16(v_out, 
3); + + return a + b + c + d; +} #endif /* DOXYGEN_SKIP_THIS */ #endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */ } // namespace arm_compute diff --git a/src/cpu/kernels/CpuDirectConv2dKernel.cpp b/src/cpu/kernels/CpuDirectConv2dKernel.cpp index db1b5f3c54..68de9803eb 100644 --- a/src/cpu/kernels/CpuDirectConv2dKernel.cpp +++ b/src/cpu/kernels/CpuDirectConv2dKernel.cpp @@ -711,17 +711,6 @@ public: } }; -float vreduce(const float32x4_t &v) -{ - auto v0 = wrapper::vgethigh(v); - auto v1 = wrapper::vgetlow(v); - auto v_out = wrapper::vadd(v0, v1); - - float a = wrapper::vgetlane(v_out, 0); - float b = wrapper::vgetlane(v_out, 1); - return a + b; -} - template inline void convolve_1x1(const Window &window, unsigned int num_elems_read_per_iteration, unsigned int num_elems_written_per_iteration, const ITensor *src, const ITensor *weights, ITensor *dst, const PadStrideInfo &conv_info) diff --git a/src/cpu/kernels/CpuDirectConv3dKernel.cpp b/src/cpu/kernels/CpuDirectConv3dKernel.cpp index fecdb2bcae..595b5f1330 100644 --- a/src/cpu/kernels/CpuDirectConv3dKernel.cpp +++ b/src/cpu/kernels/CpuDirectConv3dKernel.cpp @@ -23,9 +23,6 @@ */ #include "src/cpu/kernels/CpuDirectConv3dKernel.h" -#include "src/core/NEON/kernels/detail/NEDirectConvolutionDetail.h" -#include "src/core/NEON/wrapper/wrapper.h" - #include "arm_compute/core/Error.h" #include "arm_compute/core/Helpers.h" #include "arm_compute/core/IAccessWindow.h" @@ -35,8 +32,10 @@ #include "arm_compute/core/Validate.h" #include "arm_compute/core/utils/misc/ShapeCalculator.h" #include "src/core/CPP/Validate.h" +#include "src/core/NEON/wrapper/wrapper.h" +#include "src/core/common/Registrars.h" #include "src/core/helpers/AutoConfiguration.h" -#include "src/core/helpers/WindowHelpers.h" +#include "src/cpu/kernels/conv3d/neon/list.h" #include @@ -50,236 +49,126 @@ namespace kernels { namespace { -Status validate_arguments(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *dst, 
const Conv3dInfo &conv_info) +struct DirectConv3dSelectorData { - ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(src, weights, dst); - ARM_COMPUTE_RETURN_ERROR_ON(src->data_layout() != DataLayout::NDHWC); - ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(src); - ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(src, 1, DataType::F16, DataType::F32); - ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(src, weights); - - const DataLayout data_layout = src->data_layout(); - const int channel_idx = get_data_layout_dimension_index(data_layout, DataLayoutDimension::CHANNEL); - - // Weight layout is D, H, W, Cin, Cout - ARM_COMPUTE_RETURN_ERROR_ON(weights->num_dimensions() > 5); - ARM_COMPUTE_RETURN_ERROR_ON(weights->dimension(1) != src->dimension(channel_idx)); + DataType dt; + const CPUInfo &ci; +}; +using DirectConv3dSelectorPtr = std::add_pointer::type; +using DirectConv3dKernelPtr = std::add_pointer::type; +struct DirectConv3dKernel +{ + const char *name; + const DirectConv3dSelectorPtr is_selected; + DirectConv3dKernelPtr ukernel; +}; - if(biases != nullptr) +static const DirectConv3dKernel available_kernels[] = +{ +#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) { - ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(weights, biases); - ARM_COMPUTE_RETURN_ERROR_ON_MSG(biases->dimension(0) != weights->dimension(0), - "biases size and number of output feature maps should match"); - ARM_COMPUTE_RETURN_ERROR_ON_MSG(biases->num_dimensions() > 1, "biases should be one dimensional"); - } - - // Checks performed when output is configured - if(dst->total_size() != 0) + "neon_fp16_directconv3d", + [](const DirectConv3dSelectorData & data) { return data.dt == DataType::F16 && data.ci.has_fp16(); }, + REGISTER_FP16_NEON(arm_compute::cpu::directconv3d_float_neon_ndhwc) + }, +#endif /* defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) */ { - TensorShape output_shape = misc::shape_calculator::compute_conv3d_shape(src->tensor_shape(), weights->tensor_shape(), conv_info); - - DataType 
data_type = src->data_type(); - - ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DIMENSIONS(dst->tensor_shape(), output_shape); - ARM_COMPUTE_RETURN_ERROR_ON(dst->data_type() != data_type); + "neon_fp32_directconv3d", + [](const DirectConv3dSelectorData & data) { return data.dt == DataType::F32; }, + REGISTER_FP32_NEON(arm_compute::cpu::directconv3d_float_neon_ndhwc) } +}; - return Status{}; -} - -/** Reduce a vector to be a scalar by accumulating all lanes in the vector +/** Micro-kernel selector * - * @param[in] v Vector to be reduced. + * @param[in] data Selection data passed to help pick the appropriate micro-kernel * - * @return the wrapped-around number. + * @return A matching micro-kernel else nullptr */ -auto vreduce(const float32x4_t &v) +const DirectConv3dKernel *get_implementation(const DirectConv3dSelectorData &data) { - auto v0 = wrapper::vgethigh(v); - auto v1 = wrapper::vgetlow(v); - auto v_out = wrapper::vadd(v0, v1); - - float a = wrapper::vgetlane(v_out, 0); - float b = wrapper::vgetlane(v_out, 1); - return a + b; -} - -#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC -auto vreduce(const float16x8_t &v) -{ - auto v0 = wrapper::vgethigh(v); - auto v1 = wrapper::vgetlow(v); - auto v_out = wrapper::vadd(v0, v1); - - float16_t a = wrapper::vgetlane(v_out, 0); - float16_t b = wrapper::vgetlane(v_out, 1); - float16_t c = wrapper::vgetlane(v_out, 2); - float16_t d = wrapper::vgetlane(v_out, 3); - return a + b + c + d; -} -#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + for(const auto &uk : available_kernels) + { + if(uk.is_selected(data)) + { + return &uk; + } + } + return nullptr; } -template -void CpuDirectConv3dKernel::convolve_ndhwc(const Window &window, const ITensor *src, const ITensor *weights, const ITensor *biases, ITensor *dst) +Status validate_arguments(const ITensorInfo *src0, const ITensorInfo *src1, const ITensorInfo *src2, const ITensorInfo *dst, const Conv3dInfo &conv_info) { - using vtype = wrapper::traits::neon_bitvector; - using vector_type = 
typename vtype::type; - using tag_type = typename vtype::tag_type; - constexpr int num_elems_read_per_iteration = 16 / sizeof(T); - - // Scalar quantities (N D H W Cin) - const int element_size = src->info()->element_size(); - const int input_stride_w = src->info()->strides_in_bytes().y() / element_size; - const int input_stride_h = src->info()->strides_in_bytes().z() / element_size; - const int input_stride_d = src->info()->strides_in_bytes()[3] / element_size; - const int input_stride_n = src->info()->strides_in_bytes()[4] / element_size; - const int input_dim_w = src->info()->dimension(1); - const int input_dim_h = src->info()->dimension(2); - const int input_dim_d = src->info()->dimension(3); - - // Kernel info (D H W Cin Cout) - const unsigned int kernel_stride_w = weights->info()->strides_in_bytes()[2] / element_size; - const unsigned int kernel_stride_h = weights->info()->strides_in_bytes()[3] / element_size; - const unsigned int kernel_stride_d = weights->info()->strides_in_bytes()[4] / element_size; - const int kernel_dim_w = weights->info()->dimension(2); - const int kernel_dim_h = weights->info()->dimension(3); - const int kernel_dim_d = weights->info()->dimension(4); - - // Convolution padding and stride - const int conv_pad_top = _conv_info.padding.top; - const int conv_pad_left = _conv_info.padding.left; - const int conv_pad_front = _conv_info.padding.front; - const int conv_stride_w = _conv_info.stride.width; - const int conv_stride_h = _conv_info.stride.height; - const int conv_stride_d = _conv_info.stride.depth; + const auto *uk = get_implementation(DirectConv3dSelectorData{ src0->data_type(), CPUInfo::get() }); + ARM_COMPUTE_RETURN_ERROR_ON(uk == nullptr || uk->ukernel == nullptr); - // Setup input window for the output iterator - Window window_out = window; - window_out.set(Window::DimX, Window::Dimension(0, 1, 1)); + ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(src0, src1, dst); + ARM_COMPUTE_RETURN_ERROR_ON(src0->data_layout() != DataLayout::NDHWC); + 
ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(src0); + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(src0, 1, DataType::F16, DataType::F32); + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(src0, src1); - // Setup input window for the weights iterator - Window window_w = calculate_max_window(*weights->info(), Steps()); - window_w.set(Window::DimY, Window::Dimension(0, 1, 1)); - window_w.set(Window::DimZ, Window::Dimension(0, 1, 1)); - window_w.set(Window::DimW, Window::Dimension(0, 1, 1)); - window_w.set(4, Window::Dimension(0, 1, 1)); + const DataLayout data_layout = src0->data_layout(); + const int channel_idx = get_data_layout_dimension_index(data_layout, DataLayoutDimension::CHANNEL); - Iterator out(dst, window_out); - Iterator wei(weights, window_w); + // Weight layout is D, H, W, Cin, Cout + ARM_COMPUTE_RETURN_ERROR_ON(src1->num_dimensions() > 5); + ARM_COMPUTE_RETURN_ERROR_ON(src1->dimension(1) != src0->dimension(channel_idx)); - const T *biases_ptr = nullptr; - if(biases) + if(src2 != nullptr) { - biases_ptr = reinterpret_cast(biases->buffer() + biases->info()->offset_first_element_in_bytes()); + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(src1, src2); + ARM_COMPUTE_RETURN_ERROR_ON_MSG(src2->dimension(0) != src1->dimension(0), + "biases size and number of output feature maps should match"); + ARM_COMPUTE_RETURN_ERROR_ON_MSG(src2->num_dimensions() > 1, "biases should be one dimensional"); } - execute_window_loop(window_out, [&](const Coordinates & id) - { - // We are computing the theoretical input starting points - const int in_w_start_t = static_cast(id.y()) * conv_stride_w - conv_pad_left; - const int in_h_start_t = static_cast(id.z()) * conv_stride_h - conv_pad_top; - const int in_d_start_t = static_cast(id[3]) * conv_stride_d - conv_pad_front; - const int in_w_end_t = in_w_start_t + kernel_dim_w; - const int in_h_end_t = in_h_start_t + kernel_dim_h; - const int in_d_end_t = in_d_start_t + kernel_dim_d; - // We are computing the valid 
initial and ending input points by checking the borders - const int in_w_start = std::max(in_w_start_t, 0); - const int in_h_start = std::max(in_h_start_t, 0); - const int in_d_start = std::max(in_d_start_t, 0); - const int in_w_end = std::min(in_w_end_t, input_dim_w); - const int in_h_end = std::min(in_h_end_t, input_dim_h); - const int in_d_end = std::min(in_d_end_t, input_dim_d); + // Checks performed when output is configured + if(dst->total_size() != 0) + { + TensorShape output_shape = misc::shape_calculator::compute_conv3d_shape(src0->tensor_shape(), src1->tensor_shape(), conv_info); - // We use the input points to select the valid weight points to use - const int wei_w_start = in_w_start - in_w_start_t; - const int wei_h_start = in_h_start - in_h_start_t; - const int wei_d_start = in_d_start - in_d_start_t; - const int wei_w_end = kernel_dim_w - (in_w_end_t - in_w_end); - const int wei_h_end = kernel_dim_h - (in_h_end_t - in_h_end); - const int wei_d_end = kernel_dim_d - (in_d_end_t - in_d_end); + DataType data_type = src0->data_type(); - const int index_c_out_end = weights->info()->dimension(0); - const int index_c_in_end = weights->info()->dimension(1); - const T *const in_ptr_start = reinterpret_cast(src->buffer() + src->info()->offset_first_element_in_bytes()) + id[4] * input_stride_n; + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DIMENSIONS(dst->tensor_shape(), output_shape); + ARM_COMPUTE_RETURN_ERROR_ON(dst->data_type() != data_type); + } - execute_window_loop(window_w, [&](const Coordinates & id_w) - { - /* - * This is the loop in the weights, and it goes along OFM (output feature map) - */ - const auto weights_ptr_start = reinterpret_cast(wei.ptr()); - T out_temp = static_cast(0); - T *out_ptr = reinterpret_cast(out.ptr()); - for(int index_wei_d = wei_d_start, index_in_d = in_d_start; index_wei_d < wei_d_end; ++index_wei_d, ++index_in_d) - { - const auto in_ptr_d = in_ptr_start + index_in_d * input_stride_d; - const auto weights_ptr_d = 
weights_ptr_start + index_wei_d * kernel_stride_d; - for(int index_wei_h = wei_h_start, index_in_h = in_h_start; index_wei_h < wei_h_end; ++index_wei_h, ++index_in_h) - { - const T *const in_ptr_row = in_ptr_d + index_in_h * input_stride_h; - const T *const weights_ptr_row = weights_ptr_d + index_wei_h * kernel_stride_h; - for(int index_wei_w = wei_w_start, index_in_w = in_w_start; index_wei_w < wei_w_end; ++index_wei_w, ++index_in_w) - { - const T *in_ptr_mover = in_ptr_row + index_in_w * input_stride_w; - const T *weights_ptr_mover = weights_ptr_row + index_wei_w * kernel_stride_w; - int index_c_in = 0; - vector_type out_temp_vec = wrapper::vdup_n(static_cast(0), tag_type()); - vector_type w_vec = wrapper::vdup_n(static_cast(0), tag_type()); - for(; index_c_in <= index_c_in_end - num_elems_read_per_iteration; - index_c_in += num_elems_read_per_iteration, in_ptr_mover += num_elems_read_per_iteration) - { - const auto src_vec = wrapper::vloadq(in_ptr_mover); - //Load Cin weights - for(unsigned int k = 0; k < num_elems_read_per_iteration; ++k, weights_ptr_mover += index_c_out_end) - { - w_vec = wrapper::vsetlane(*weights_ptr_mover, w_vec, k); - } - out_temp_vec = wrapper::vmla(out_temp_vec, w_vec, src_vec); - } - out_temp += vreduce(out_temp_vec); - for(; index_c_in < index_c_in_end; ++index_c_in, ++in_ptr_mover, weights_ptr_mover += index_c_out_end) - { - const auto src_val = *(in_ptr_mover); - const auto w_val = *(weights_ptr_mover); - out_temp += src_val * w_val; - } - } - } - } - *(reinterpret_cast(out_ptr + id_w[0])) = (biases) ? 
out_temp + biases_ptr[id_w[0]] : out_temp; - }, - wei); - }, - out); + return Status{}; +} } -void CpuDirectConv3dKernel::configure(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, ITensorInfo *dst, const Conv3dInfo &conv_info) +void CpuDirectConv3dKernel::configure(const ITensorInfo *src0, const ITensorInfo *src1, const ITensorInfo *src2, ITensorInfo *dst, const Conv3dInfo &conv_info) { - ARM_COMPUTE_UNUSED(biases); - ARM_COMPUTE_ERROR_ON_NULLPTR(src, weights, dst); + ARM_COMPUTE_UNUSED(src2); + ARM_COMPUTE_ERROR_ON_NULLPTR(src0, src1, dst); - _conv_info = conv_info; + const auto *uk = get_implementation(DirectConv3dSelectorData{ src0->data_type(), CPUInfo::get() }); + ARM_COMPUTE_ERROR_ON_NULLPTR(uk); + + _conv_info = conv_info; + _run_method = uk->ukernel; + _name = std::string("CpuDirectConv3dKernel").append("/").append(uk->name); // Get convolved dimensions - TensorShape output_shape = misc::shape_calculator::compute_conv3d_shape(src->tensor_shape(), weights->tensor_shape(), conv_info); + TensorShape output_shape = misc::shape_calculator::compute_conv3d_shape(src0->tensor_shape(), src1->tensor_shape(), conv_info); - DataType data_type = src->data_type(); + DataType data_type = src0->data_type(); // Output auto inizialitation if not yet initialized auto_init_if_empty(*dst, output_shape, 1, data_type); // Perform validation step - ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(src, weights, biases, dst, conv_info)); + ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(src0, src1, src2, dst, conv_info)); // Configure kernel window Window win = calculate_max_window(*dst, Steps()); ICpuKernel::configure(win); } -Status CpuDirectConv3dKernel::validate(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *dst, const Conv3dInfo &conv_info) +Status CpuDirectConv3dKernel::validate(const ITensorInfo *src0, const ITensorInfo *src1, const ITensorInfo *src2, const ITensorInfo *dst, const Conv3dInfo 
&conv_info) { - ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(src, weights, biases, dst, conv_info)); + ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(src0, src1, src2, dst, conv_info)); return Status{}; } @@ -289,35 +178,19 @@ void CpuDirectConv3dKernel::run_op(ITensorPack &tensors, const Window &window, c ARM_COMPUTE_UNUSED(info); ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this); ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(ICpuKernel::window(), window); + ARM_COMPUTE_ERROR_ON(_run_method == nullptr); - auto src = tensors.get_const_tensor(TensorType::ACL_SRC_0); - auto weights = tensors.get_const_tensor(TensorType::ACL_SRC_1); - auto biases = tensors.get_const_tensor(TensorType::ACL_SRC_2); - auto dst = tensors.get_tensor(TensorType::ACL_DST); + auto src0 = tensors.get_const_tensor(TensorType::ACL_SRC_0); + auto src1 = tensors.get_const_tensor(TensorType::ACL_SRC_1); + auto src2 = tensors.get_const_tensor(TensorType::ACL_SRC_2); + auto dst = tensors.get_tensor(TensorType::ACL_DST); - switch(src->info()->data_type()) - { - case DataType::F32: - { - convolve_ndhwc(window, src, weights, biases, dst); - break; - } -#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC - case DataType::F16: - { - convolve_ndhwc(window, src, weights, biases, dst); - break; - } -#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC - default: - ARM_COMPUTE_ERROR("Data type not supported"); - break; - } + _run_method(src0, src1, src2, dst, _conv_info, window); } const char *CpuDirectConv3dKernel::name() const { - return "CpuDirectConv3dKernel"; + return _name.c_str(); } } // namespace kernels } // namespace cpu diff --git a/src/cpu/kernels/CpuDirectConv3dKernel.h b/src/cpu/kernels/CpuDirectConv3dKernel.h index c7dcb0fb5e..fc64e8518b 100644 --- a/src/cpu/kernels/CpuDirectConv3dKernel.h +++ b/src/cpu/kernels/CpuDirectConv3dKernel.h @@ -39,10 +39,7 @@ class CpuDirectConv3dKernel : public ICpuKernel public: CpuDirectConv3dKernel() = default; ARM_COMPUTE_DISALLOW_COPY_ALLOW_MOVE(CpuDirectConv3dKernel); - /** Set the 
src, weights, and dst tensor info. - * - * Valid data layouts: - * - NDHWC + /** Set the src, weights, biases and dst tensor info. * * Valid data type configurations: * |src0 |src1 |src2 |dst | @@ -50,34 +47,35 @@ public: * |F16 |F16 |F16 |F16 | * |F32 |F32 |F32 |F32 | * - * @param[in, out] src Input tensor info. - * @param[in] weights Set of kernels to convolve the input volume. + * @param[in, out] src0 Input tensor info. + * @param[in] src1 Set of kernels to convolve the input volume. * The 2nd dimension must be the same as the input's volume 1st dimension. - * @param[in] biases Set of biases. Can be nullptr. + * @param[in] src2 Set of biases. Can be nullptr. * @param[out] dst Output tensor info. * The 1st dimensions must be equal to the 1st dimension of the @p kernels tensor. * @param[in] conv_info Contains padding, stride, acitvation information. * */ - void configure(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, ITensorInfo *dst, const Conv3dInfo &conv_info); + void configure(const ITensorInfo *src0, const ITensorInfo *src1, const ITensorInfo *src2, ITensorInfo *dst, const Conv3dInfo &conv_info); /** Static function to check if given info will lead to a valid configuration * * Similar to CpuDirectConv3dKernel::configure() * * @return a status */ - static Status validate(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *dst, const Conv3dInfo &conv_info); + static Status validate(const ITensorInfo *src0, const ITensorInfo *src1, const ITensorInfo *src2, const ITensorInfo *dst, const Conv3dInfo &conv_info); // Inherited methods overridden: void run_op(ITensorPack &tensors, const Window &window, const ThreadInfo &info) override; const char *name() const override; private: - /* Template function for convolution NDHWC */ - template - void convolve_ndhwc(const Window &window, const ITensor *src, const ITensor *weights, const ITensor *biases, ITensor *dst); + /* Template function for 
convolution 3d NDHWC */ + using DirectConv3dKernelPtr = std::add_pointer::type; - Conv3dInfo _conv_info{}; + Conv3dInfo _conv_info{}; + DirectConv3dKernelPtr _run_method{ nullptr }; + std::string _name{}; }; } // namespace kernels } // namespace cpu diff --git a/src/cpu/kernels/conv3d/neon/list.h b/src/cpu/kernels/conv3d/neon/list.h new file mode 100644 index 0000000000..b24785a48f --- /dev/null +++ b/src/cpu/kernels/conv3d/neon/list.h @@ -0,0 +1,176 @@ +/* + * Copyright (c) 2021 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ */ +#ifndef SRC_CORE_NEON_KERNELS_CONV3D_LIST_H +#define SRC_CORE_NEON_KERNELS_CONV3D_LIST_H + +#include "arm_compute/core/Types.h" +#include "arm_compute/core/utils/misc/Traits.h" +#include "arm_compute/runtime/FunctionDescriptors.h" +#include "src/core/NEON/wrapper/wrapper.h" +#include "src/core/helpers/WindowHelpers.h" + +namespace arm_compute +{ +namespace cpu +{ +template +void directconv3d_float_neon_ndhwc(const ITensor *src0, const ITensor *src1, const ITensor *src2, ITensor *dst, const Conv3dInfo &conv_info, const Window &window) +{ + const ITensor *src = src0; + const ITensor *weights = src1; + const ITensor *biases = src2; + + using vtype = wrapper::traits::neon_bitvector; + using vector_type = typename vtype::type; + using tag_type = typename vtype::tag_type; + constexpr int num_elems_read_per_iteration = 16 / sizeof(T); + + // Scalar quantities (N D H W Cin) + const int element_size = src->info()->element_size(); + const int input_stride_w = src->info()->strides_in_bytes().y() / element_size; + const int input_stride_h = src->info()->strides_in_bytes().z() / element_size; + const int input_stride_d = src->info()->strides_in_bytes()[3] / element_size; + const int input_stride_n = src->info()->strides_in_bytes()[4] / element_size; + const int input_dim_w = src->info()->dimension(1); + const int input_dim_h = src->info()->dimension(2); + const int input_dim_d = src->info()->dimension(3); + + // Kernel info (D H W Cin Cout) + const unsigned int kernel_stride_w = weights->info()->strides_in_bytes()[2] / element_size; + const unsigned int kernel_stride_h = weights->info()->strides_in_bytes()[3] / element_size; + const unsigned int kernel_stride_d = weights->info()->strides_in_bytes()[4] / element_size; + const int kernel_dim_w = weights->info()->dimension(2); + const int kernel_dim_h = weights->info()->dimension(3); + const int kernel_dim_d = weights->info()->dimension(4); + + // Convolution padding and stride + const int conv_pad_top = 
conv_info.padding.top; + const int conv_pad_left = conv_info.padding.left; + const int conv_pad_front = conv_info.padding.front; + const int conv_stride_w = conv_info.stride.width; + const int conv_stride_h = conv_info.stride.height; + const int conv_stride_d = conv_info.stride.depth; + + // Setup input window for the output iterator + Window window_out = window; + window_out.set(Window::DimX, Window::Dimension(0, 1, 1)); + + // Setup input window for the weights iterator + Window window_w = calculate_max_window(*weights->info(), Steps()); + window_w.set(Window::DimY, Window::Dimension(0, 1, 1)); + window_w.set(Window::DimZ, Window::Dimension(0, 1, 1)); + window_w.set(Window::DimW, Window::Dimension(0, 1, 1)); + window_w.set(4, Window::Dimension(0, 1, 1)); + + Iterator out(dst, window_out); + Iterator wei(weights, window_w); + + const T *biases_ptr = nullptr; + if(biases != nullptr) + { + biases_ptr = reinterpret_cast(biases->buffer() + biases->info()->offset_first_element_in_bytes()); + } + execute_window_loop(window_out, [&](const Coordinates & id) + { + // We are computing the theoretical input starting points + const int in_w_start_t = static_cast(id.y()) * conv_stride_w - conv_pad_left; + const int in_h_start_t = static_cast(id.z()) * conv_stride_h - conv_pad_top; + const int in_d_start_t = static_cast(id[3]) * conv_stride_d - conv_pad_front; + const int in_w_end_t = in_w_start_t + kernel_dim_w; + const int in_h_end_t = in_h_start_t + kernel_dim_h; + const int in_d_end_t = in_d_start_t + kernel_dim_d; + + // We are computing the valid initial and ending input points by checking the borders + const int in_w_start = std::max(in_w_start_t, 0); + const int in_h_start = std::max(in_h_start_t, 0); + const int in_d_start = std::max(in_d_start_t, 0); + const int in_w_end = std::min(in_w_end_t, input_dim_w); + const int in_h_end = std::min(in_h_end_t, input_dim_h); + const int in_d_end = std::min(in_d_end_t, input_dim_d); + + // We use the input points to select the 
valid weight points to use + const int wei_w_start = in_w_start - in_w_start_t; + const int wei_h_start = in_h_start - in_h_start_t; + const int wei_d_start = in_d_start - in_d_start_t; + const int wei_w_end = kernel_dim_w - (in_w_end_t - in_w_end); + const int wei_h_end = kernel_dim_h - (in_h_end_t - in_h_end); + const int wei_d_end = kernel_dim_d - (in_d_end_t - in_d_end); + + const int index_c_out_end = weights->info()->dimension(0); + const int index_c_in_end = weights->info()->dimension(1); + const T *const in_ptr_start = reinterpret_cast(src->buffer() + src->info()->offset_first_element_in_bytes()) + id[4] * input_stride_n; + + execute_window_loop(window_w, [&](const Coordinates & id_w) + { + /* + * This is the loop in the weights, and it goes along OFM (output feature map) + */ + const auto weights_ptr_start = reinterpret_cast(wei.ptr()); + T out_temp = static_cast(0); + T *out_ptr = reinterpret_cast(out.ptr()); + for(int index_wei_d = wei_d_start, index_in_d = in_d_start; index_wei_d < wei_d_end; ++index_wei_d, ++index_in_d) + { + const auto in_ptr_d = in_ptr_start + index_in_d * input_stride_d; + const auto weights_ptr_d = weights_ptr_start + index_wei_d * kernel_stride_d; + for(int index_wei_h = wei_h_start, index_in_h = in_h_start; index_wei_h < wei_h_end; ++index_wei_h, ++index_in_h) + { + const T *const in_ptr_row = in_ptr_d + index_in_h * input_stride_h; + const T *const weights_ptr_row = weights_ptr_d + index_wei_h * kernel_stride_h; + for(int index_wei_w = wei_w_start, index_in_w = in_w_start; index_wei_w < wei_w_end; ++index_wei_w, ++index_in_w) + { + const T *in_ptr_mover = in_ptr_row + index_in_w * input_stride_w; + const T *weights_ptr_mover = weights_ptr_row + index_wei_w * kernel_stride_w; + int index_c_in = 0; + vector_type out_temp_vec = wrapper::vdup_n(static_cast(0), tag_type()); + vector_type w_vec = wrapper::vdup_n(static_cast(0), tag_type()); + for(; index_c_in <= index_c_in_end - num_elems_read_per_iteration; + index_c_in += 
num_elems_read_per_iteration, in_ptr_mover += num_elems_read_per_iteration) + { + const auto src_vec = wrapper::vloadq(in_ptr_mover); + //Load Cin weights + for(unsigned int k = 0; k < num_elems_read_per_iteration; ++k, weights_ptr_mover += index_c_out_end) + { + w_vec = wrapper::vsetlane(*weights_ptr_mover, w_vec, k); + } + out_temp_vec = wrapper::vmla(out_temp_vec, w_vec, src_vec); + } + out_temp += vreduce(out_temp_vec); + for(; index_c_in < index_c_in_end; ++index_c_in, ++in_ptr_mover, weights_ptr_mover += index_c_out_end) + { + const auto src_val = *(in_ptr_mover); + const auto w_val = *(weights_ptr_mover); + out_temp += src_val * w_val; + } + } + } + } + *(reinterpret_cast(out_ptr + id_w[0])) = (biases_ptr != nullptr) ? out_temp + biases_ptr[id_w[0]] : out_temp; + }, + wei); + }, + out); +} +} // namespace cpu +} // namespace arm_compute +#endif // SRC_CORE_NEON_KERNELS_CONV3D_LIST_H \ No newline at end of file diff --git a/src/cpu/operators/CpuDirectConv3d.cpp b/src/cpu/operators/CpuDirectConv3d.cpp index 3827910d37..aa74e420a6 100644 --- a/src/cpu/operators/CpuDirectConv3d.cpp +++ b/src/cpu/operators/CpuDirectConv3d.cpp @@ -40,10 +40,10 @@ CpuDirectConv3d::CpuDirectConv3d(std::shared_ptr memory_manager) { } -void CpuDirectConv3d::configure(ITensorInfo *src, ITensorInfo *weights, const ITensorInfo *biases, ITensorInfo *dst, const Conv3dInfo conv_info) +void CpuDirectConv3d::configure(ITensorInfo *src0, ITensorInfo *src1, const ITensorInfo *src2, ITensorInfo *dst, const Conv3dInfo conv_info) { - ARM_COMPUTE_LOG_PARAMS(src, weights, biases, dst, conv_info); - ARM_COMPUTE_ERROR_ON(src->data_layout() != DataLayout::NDHWC); + ARM_COMPUTE_LOG_PARAMS(src0, src1, src2, dst, conv_info); + ARM_COMPUTE_ERROR_ON(src0->data_layout() != DataLayout::NDHWC); _conv_kernel = std::make_unique(); @@ -55,7 +55,7 @@ void CpuDirectConv3d::configure(ITensorInfo *src, ITensorInfo *weights, const IT _dim_split = Window::DimY; - _conv_kernel->configure(src, weights, biases, dst, 
conv_info); + _conv_kernel->configure(src0, src1, src2, dst, conv_info); //Configure Activation Layer _is_activationlayer_enabled = conv_info.act_info.enabled(); @@ -66,16 +66,12 @@ void CpuDirectConv3d::configure(ITensorInfo *src, ITensorInfo *weights, const IT } } -Status CpuDirectConv3d::validate(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *dst, const Conv3dInfo conv_info) +Status CpuDirectConv3d::validate(const ITensorInfo *src0, const ITensorInfo *src1, const ITensorInfo *src2, const ITensorInfo *dst, const Conv3dInfo conv_info) { - ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(src, weights, dst); - - // output might not be initialized since it can be an intermediate tensor of another layer - DataType data_type = src->data_type(); - TensorInfo accumulator(dst->clone()->set_is_resizable(true).reset_padding().set_data_type(data_type)); + ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(src0, src1, dst); // Validate Convolution kernel - ARM_COMPUTE_RETURN_ON_ERROR(kernels::CpuDirectConv3dKernel::validate(src, weights, biases, &accumulator, conv_info)); + ARM_COMPUTE_RETURN_ON_ERROR(kernels::CpuDirectConv3dKernel::validate(src0, src1, src2, dst, conv_info)); if(conv_info.act_info.enabled()) { diff --git a/src/cpu/operators/CpuDirectConv3d.h b/src/cpu/operators/CpuDirectConv3d.h index ad04dee0fa..f7c3099be0 100644 --- a/src/cpu/operators/CpuDirectConv3d.h +++ b/src/cpu/operators/CpuDirectConv3d.h @@ -57,23 +57,31 @@ public: ~CpuDirectConv3d(); /** Set the input, weights, biases and output tensor info. * - * @param[in, out] src Input tensor info. - * @param[in] weights Set of kernels to convolve the input volume. - * The 2nd dimension must be the same as the input's volume 1st dimension. - * Data type supported: Same as @p src. - * @param[in] biases Set of biases. Can be nullptr. Data type supported: Same as @p src. 
+ * Valid data layouts: + * - NDHWC + * + * Valid data type configurations: + * |src0 |src1 |src2 |dst | + * |:--------------|:------------------|:------|:--------------| + * |F16 |F16 |F16 |F16 | + * |F32 |F32 |F32 |F32 | + * + * @param[in, out] src0 Input tensor info. + * @param[in] src1 Set of kernels to convolve the input volume. + * The 2nd dimension must be the same as the src0's volume 1st dimension. + * @param[in] src2 Set of biases. Can be nullptr. * @param[out] dst Output tensor info. * The 1st dimensions must be equal to the 1st dimension of the @p kernels tensor. * @param[in] conv_info Contains padding, stride, acitvation information. */ - void configure(ITensorInfo *src, ITensorInfo *weights, const ITensorInfo *biases, ITensorInfo *dst, const Conv3dInfo conv_info); + void configure(ITensorInfo *src0, ITensorInfo *src1, const ITensorInfo *src2, ITensorInfo *dst, const Conv3dInfo conv_info); /** Static function to check if given info will lead to a valid configuration * * Similar to CpuDirectConv3d::configure() * * @return a status */ - static Status validate(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *dst, const Conv3dInfo conv_info); + static Status validate(const ITensorInfo *src0, const ITensorInfo *src1, const ITensorInfo *src2, const ITensorInfo *dst, const Conv3dInfo conv_info); // Inherited methods overridden: void run(ITensorPack &tensors) override; diff --git a/src/gpu/cl/kernels/ClDirectConv3dKernel.cpp b/src/gpu/cl/kernels/ClDirectConv3dKernel.cpp index 1c4326b494..88e73dc72a 100644 --- a/src/gpu/cl/kernels/ClDirectConv3dKernel.cpp +++ b/src/gpu/cl/kernels/ClDirectConv3dKernel.cpp @@ -37,36 +37,36 @@ namespace kernels { namespace { -Status validate_arguments(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *dst, const Conv3dInfo &conv3d_info) +Status validate_arguments(const ITensorInfo *src0, const ITensorInfo *src1, const ITensorInfo *src2, 
const ITensorInfo *dst, const Conv3dInfo &conv3d_info) { - ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_LAYOUT(src, weights, dst); - ARM_COMPUTE_RETURN_ERROR_ON_MSG(src->data_layout() != DataLayout::NDHWC, "Only NDHWC layout supported"); + ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_LAYOUT(src0, src1, dst); + ARM_COMPUTE_RETURN_ERROR_ON_MSG(src0->data_layout() != DataLayout::NDHWC, "Only NDHWC layout supported"); ARM_COMPUTE_RETURN_ERROR_ON_MSG(conv3d_info.act_info.enabled(), "Fused activation not supported"); - ARM_COMPUTE_RETURN_ERROR_ON_F16_UNSUPPORTED(src); - ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(src, 1, DataType::F16, DataType::F32); - ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(src, weights); + ARM_COMPUTE_RETURN_ERROR_ON_F16_UNSUPPORTED(src0); + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(src0, 1, DataType::F16, DataType::F32); + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(src0, src1); - ARM_COMPUTE_RETURN_ERROR_ON_MSG(weights->dimension(1) != src->dimension(0), "Weights feature map dimension should match the respective src's one"); - ARM_COMPUTE_RETURN_ERROR_ON_MSG(weights->num_dimensions() > 5, "Weights can be at most 5 dimensional"); + ARM_COMPUTE_RETURN_ERROR_ON_MSG(src1->dimension(1) != src0->dimension(0), "Weights feature map dimension should match the respective src's one"); + ARM_COMPUTE_RETURN_ERROR_ON_MSG(src1->num_dimensions() > 5, "Weights can be at most 5 dimensional"); - ARM_COMPUTE_RETURN_ERROR_ON(weights->dimension(2) > (src->dimension(1) + conv3d_info.padding.left + conv3d_info.padding.right)); - ARM_COMPUTE_RETURN_ERROR_ON(weights->dimension(3) > (src->dimension(2) + conv3d_info.padding.top + conv3d_info.padding.bottom)); - ARM_COMPUTE_RETURN_ERROR_ON(weights->dimension(4) > (src->dimension(3) + conv3d_info.padding.front + conv3d_info.padding.back)); + ARM_COMPUTE_RETURN_ERROR_ON(src1->dimension(2) > (src0->dimension(1) + conv3d_info.padding.left + conv3d_info.padding.right)); + 
ARM_COMPUTE_RETURN_ERROR_ON(src1->dimension(3) > (src0->dimension(2) + conv3d_info.padding.top + conv3d_info.padding.bottom)); + ARM_COMPUTE_RETURN_ERROR_ON(src1->dimension(4) > (src0->dimension(3) + conv3d_info.padding.front + conv3d_info.padding.back)); - if(biases != nullptr) + if(src2 != nullptr) { - ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(weights, biases); - ARM_COMPUTE_RETURN_ERROR_ON_MSG(biases->dimension(0) != weights->dimension(0), "Biases size and number of dst feature maps should match"); - ARM_COMPUTE_RETURN_ERROR_ON_MSG(biases->num_dimensions() > 1, "Biases should be one dimensional"); + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(src1, src2); + ARM_COMPUTE_RETURN_ERROR_ON_MSG(src2->dimension(0) != src1->dimension(0), "Biases size and number of dst feature maps should match"); + ARM_COMPUTE_RETURN_ERROR_ON_MSG(src2->num_dimensions() > 1, "Biases should be one dimensional"); } // Checks performed when dst is configured if(dst->total_size() != 0) { - ARM_COMPUTE_RETURN_ERROR_ON_MSG(dst->dimension(0) != weights->dimension(0), "Weights and dst OFMs should match"); - ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DIMENSIONS(dst->tensor_shape(), misc::shape_calculator::compute_conv3d_shape(src->tensor_shape(), weights->tensor_shape(), conv3d_info)); - ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(src, dst); + ARM_COMPUTE_RETURN_ERROR_ON_MSG(dst->dimension(0) != src1->dimension(0), "Weights and dst OFMs should match"); + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DIMENSIONS(dst->tensor_shape(), misc::shape_calculator::compute_conv3d_shape(src0->tensor_shape(), src1->tensor_shape(), conv3d_info)); + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(src0, dst); } return Status{}; @@ -78,27 +78,27 @@ ClDirectConv3dKernel::ClDirectConv3dKernel() _type = CLKernelType::DIRECT; } -void ClDirectConv3dKernel::configure(const CLCompileContext &compile_context, const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, ITensorInfo *dst, 
+void ClDirectConv3dKernel::configure(const CLCompileContext &compile_context, const ITensorInfo *src0, const ITensorInfo *src1, const ITensorInfo *src2, ITensorInfo *dst, const Conv3dInfo &conv3d_info) { - ARM_COMPUTE_ERROR_ON_NULLPTR(src, weights, dst); + ARM_COMPUTE_ERROR_ON_NULLPTR(src0, src1, dst); // Perform validation - ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(src, weights, biases, dst, conv3d_info)); + ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(src0, src1, src2, dst, conv3d_info)); // Create window and update padding - const DataType data_type = src->data_type(); - const size_t src_width = src->dimension(1); - const size_t src_height = src->dimension(2); - const size_t src_depth = src->dimension(3); - const size_t src_channels = src->dimension(0); + const DataType data_type = src0->data_type(); + const size_t src_width = src0->dimension(1); + const size_t src_height = src0->dimension(2); + const size_t src_depth = src0->dimension(3); + const size_t src_channels = src0->dimension(0); const size_t dst_width = dst->dimension(1); const size_t dst_height = dst->dimension(2); const size_t dst_depth = dst->dimension(3); const size_t dst_channels = dst->dimension(0); - const size_t weights_width = weights->dimension(2); - const size_t weights_height = weights->dimension(3); - const size_t weights_depth = weights->dimension(4); + const size_t weights_width = src1->dimension(2); + const size_t weights_height = src1->dimension(3); + const size_t weights_depth = src1->dimension(4); const size_t pad_left = conv3d_info.padding.left; const size_t pad_top = conv3d_info.padding.top; const size_t pad_front = conv3d_info.padding.front; @@ -108,7 +108,7 @@ void ClDirectConv3dKernel::configure(const CLCompileContext &compile_context, co const size_t n0 = std::min(dst->dimension(0), static_cast(4u)); const size_t m0 = (dst->tensor_shape()[0] > 16) ? ((data_type == DataType::F32) ? 
2U : 4U) : 1U; - const size_t k0 = adjust_vec_size(8u, src->dimension(0)); + const size_t k0 = adjust_vec_size(8u, src0->dimension(0)); const size_t partial_store_n0 = dst->dimension(0) % n0; CLBuildOptions build_options; @@ -136,7 +136,7 @@ void ClDirectConv3dKernel::configure(const CLCompileContext &compile_context, co build_options.add_option("-DM0=" + support::cpp11::to_string(m0)); build_options.add_option("-DK0=" + support::cpp11::to_string(k0)); build_options.add_option("-DPARTIAL_N0=" + support::cpp11::to_string(partial_store_n0)); - build_options.add_option_if(biases != nullptr, std::string("-DHAS_BIAS")); + build_options.add_option_if(src2 != nullptr, std::string("-DHAS_BIAS")); std::string kernel_name = "direct_convolution3d_ndhwc"; _kernel = create_kernel(compile_context, kernel_name, build_options.options()); @@ -169,9 +169,9 @@ void ClDirectConv3dKernel::configure(const CLCompileContext &compile_context, co _config_id += support::cpp11::to_string(dst_channels); } -Status ClDirectConv3dKernel::validate(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *dst, const Conv3dInfo &conv3d_info) +Status ClDirectConv3dKernel::validate(const ITensorInfo *src0, const ITensorInfo *src1, const ITensorInfo *src2, const ITensorInfo *dst, const Conv3dInfo &conv3d_info) { - ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(src, weights, biases, dst, conv3d_info)); + ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(src0, src1, src2, dst, conv3d_info)); return Status{}; } diff --git a/src/gpu/cl/kernels/ClDirectConv3dKernel.h b/src/gpu/cl/kernels/ClDirectConv3dKernel.h index 9ac8f0d7b3..485c900826 100644 --- a/src/gpu/cl/kernels/ClDirectConv3dKernel.h +++ b/src/gpu/cl/kernels/ClDirectConv3dKernel.h @@ -61,21 +61,21 @@ public: * |F32 |F32 |F32 |F32 | * * @param[in] compile_context The compile context to be used. - * @param[in] src Source tensor. 
4 lower dimensions represent a single src [IFM, width, height, depth], + * @param[in] src0 Source tensor. 4 lower dimensions represent a single src [IFM, width, height, depth], * while every optional dimension from 5 and above represent a batch of srcs. - * @param[in] weights Weights tensor. Weights are 5D tensor with dimensions [OFM, IFM, kernel_w, kernel_h, kernel_d]. - * @param[in] biases Biases tensor. Shared biases supported. Biases are 1D tensor with dimensions [OFM]. + * @param[in] src1 Weights tensor. Weights are 5D tensor with dimensions [OFM, IFM, kernel_w, kernel_h, kernel_d]. + * @param[in] src2 Biases tensor. Shared biases supported. Biases are 1D tensor with dimensions [OFM]. * @param[out] dst Destination tensor. 4 lower dimensions represent a single dst [OFM, width, height, depth], while the rest represent batch of dsts. * @param[in] conv3d_info Contains strides, padding, rounding, activation, dilation and fast math information. Activation and fast math are currently unused. 
*/ - void configure(const CLCompileContext &compile_context, const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, ITensorInfo *dst, const Conv3dInfo &conv3d_info); + void configure(const CLCompileContext &compile_context, const ITensorInfo *src0, const ITensorInfo *src1, const ITensorInfo *src2, ITensorInfo *dst, const Conv3dInfo &conv3d_info); /** Static function to check if given info will lead to a valid configuration * * Similar to ClDirectConv3dKernel::configure() * * @return a status */ - static Status validate(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *dst, const Conv3dInfo &conv3d_info); + static Status validate(const ITensorInfo *src0, const ITensorInfo *src1, const ITensorInfo *src2, const ITensorInfo *dst, const Conv3dInfo &conv3d_info); // Inherited methods overridden: void run_op(ITensorPack &tensors, const Window &window, cl::CommandQueue &queue) override; diff --git a/src/gpu/cl/operators/ClDirectConv3d.cpp b/src/gpu/cl/operators/ClDirectConv3d.cpp index d10165814b..5d37f07f31 100644 --- a/src/gpu/cl/operators/ClDirectConv3d.cpp +++ b/src/gpu/cl/operators/ClDirectConv3d.cpp @@ -30,19 +30,19 @@ namespace arm_compute { namespace opencl { -void ClDirectConv3d::configure(const CLCompileContext &compile_context, const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, ITensorInfo *dst, const Conv3dInfo &conv3d_info) +void ClDirectConv3d::configure(const CLCompileContext &compile_context, const ITensorInfo *src0, const ITensorInfo *src1, const ITensorInfo *src2, ITensorInfo *dst, const Conv3dInfo &conv3d_info) { - ARM_COMPUTE_ERROR_ON_NULLPTR(src); + ARM_COMPUTE_ERROR_ON_NULLPTR(src0); // Configure direct convolution 3d kernel auto k = std::make_unique(); - k->configure(compile_context, src, weights, biases, dst, conv3d_info); + k->configure(compile_context, src0, src1, src2, dst, conv3d_info); _direct_conv3d_kernel = std::move(k); } -Status 
ClDirectConv3d::validate(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *dst, const Conv3dInfo &conv3d_info) +Status ClDirectConv3d::validate(const ITensorInfo *src0, const ITensorInfo *src1, const ITensorInfo *src2, const ITensorInfo *dst, const Conv3dInfo &conv3d_info) { - ARM_COMPUTE_RETURN_ON_ERROR(kernels::ClDirectConv3dKernel::validate(src, weights, biases, dst, conv3d_info)); + ARM_COMPUTE_RETURN_ON_ERROR(kernels::ClDirectConv3dKernel::validate(src0, src1, src2, dst, conv3d_info)); return Status{}; } diff --git a/src/gpu/cl/operators/ClDirectConv3d.h b/src/gpu/cl/operators/ClDirectConv3d.h index ce9135b812..d8ffefc450 100644 --- a/src/gpu/cl/operators/ClDirectConv3d.h +++ b/src/gpu/cl/operators/ClDirectConv3d.h @@ -57,15 +57,15 @@ public: * |F32 |F32 |F32 |F32 | * * @param[in] compile_context The compile context to be used. - * @param[in] src Source tensor. 4 lower dimensions represent a single src [IFM, width, height, depth], + * @param[in] src0 Source tensor. 4 lower dimensions represent a single src [IFM, width, height, depth], * while every optional dimension from 5 and above represent a batch of srcs. - * @param[in] weights Weights tensor. Weights are 5D tensor with dimensions [OFM, IFM, kernel_w, kernel_h, kernel_d]. - * @param[in] biases Biases tensor. Shared biases supported. Biases are 1D tensor with dimensions [OFM]. + * @param[in] src1 Weights tensor. Weights are 5D tensor with dimensions [OFM, IFM, kernel_w, kernel_h, kernel_d]. + * @param[in] src2 Biases tensor. Shared biases supported. Biases are 1D tensor with dimensions [OFM]. * @param[out] dst Destination tensor. 4 lower dimensions represent a single dst [OFM, width, height, depth], while the rest represent batch of dsts. * @param[in] conv3d_info Contains strides, padding, rounding, activation, dilation and fast math information. Activation and fast math are currently unused. 
* */ - void configure(const CLCompileContext &compile_context, const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, ITensorInfo *dst, const Conv3dInfo &conv3d_info); + void configure(const CLCompileContext &compile_context, const ITensorInfo *src0, const ITensorInfo *src1, const ITensorInfo *src2, ITensorInfo *dst, const Conv3dInfo &conv3d_info); /** Static function to check if given info will lead to a valid configuration * @@ -73,7 +73,7 @@ public: * * @return a status */ - static Status validate(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *dst, const Conv3dInfo &conv3d_info); + static Status validate(const ITensorInfo *src0, const ITensorInfo *src1, const ITensorInfo *src2, const ITensorInfo *dst, const Conv3dInfo &conv3d_info); // Inherited method overridden void run(ITensorPack &tensors) override; diff --git a/src/runtime/NEON/functions/NEConv3D.cpp b/src/runtime/NEON/functions/NEConv3D.cpp index b5e2e2a843..3bb66c44b0 100644 --- a/src/runtime/NEON/functions/NEConv3D.cpp +++ b/src/runtime/NEON/functions/NEConv3D.cpp @@ -27,7 +27,6 @@ #include "arm_compute/core/Utils.h" #include "arm_compute/core/Validate.h" #include "src/common/utils/Log.h" -#include "src/core/helpers/MemoryHelpers.h" #include "src/cpu/operators/CpuDirectConv3d.h" namespace arm_compute @@ -58,7 +57,7 @@ void NEConv3D::configure(ITensor *input, const ITensor *weights, const ITensor * f->configure(input->info(), weights->info(), ((biases != nullptr) ? biases->info() : nullptr), output->info(), conv_info); _impl->op = std::move(f); - if(_impl->op) + if(_impl->op != nullptr) { _impl->run_pack = { { ACL_SRC_0, input }, { ACL_SRC_1, weights }, { ACL_SRC_2, biases }, { ACL_DST, output } }; } @@ -73,7 +72,7 @@ Status NEConv3D::validate(const ITensorInfo *input, const ITensorInfo *weights, void NEConv3D::run() { - if(_impl->op) + if(_impl->op != nullptr) { _impl->op->run(_impl->run_pack); } -- cgit v1.2.1