author    Pablo Tello <pablo.tello@arm.com>          2017-08-10 15:10:40 +0100
committer Anthony Barbier <anthony.barbier@arm.com>  2018-11-02 16:35:24 +0000
commit    06da39df202f0ee8eae83c4ff5588c426a0d5fd3 (patch)
tree      b0f9454e2558224a2e687729e699ed3d4cc9d5bb /src/core/NEON/kernels/NEDirectConvolutionLayerKernel.cpp
parent    d60a6b9d7977c6bd63ff7c523bed84d42363898b (diff)
download  ComputeLibrary-06da39df202f0ee8eae83c4ff5588c426a0d5fd3.tar.gz
COMPMID-345: Added support for 5x5 kernels in NEDirectConvolution
Change-Id: I25cd8f057566b59ce40e2acf14714e83a286ae4e
Reviewed-on: http://mpd-gerrit.cambridge.arm.com/83791
Reviewed-by: Gian Marco Iodice <gianmarco.iodice@arm.com>
Tested-by: Kaizen <jeremy.johnson+kaizengerrit@arm.com>
Diffstat (limited to 'src/core/NEON/kernels/NEDirectConvolutionLayerKernel.cpp')
-rw-r--r--  src/core/NEON/kernels/NEDirectConvolutionLayerKernel.cpp | 358
1 file changed, 339 insertions(+), 19 deletions(-)
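
For reference, the direct convolution implemented here computes each output element as the dot product of a 5x5 input patch with the kernel taps, accumulated over all input planes. A minimal scalar sketch of the single-plane case (illustrative only, not part of the patch; names are hypothetical):

    // Scalar reference for one plane of a 5x5 direct convolution, no padding.
    // The NEON code in the patch fully unrolls the two inner loops and
    // additionally accumulates across input planes.
    void direct_conv5x5_ref(const float *in, int in_w, int in_h,
                            const float *k /* 5x5 row-major taps */,
                            float *out, int stride)
    {
        const int out_w = (in_w - 5) / stride + 1;
        const int out_h = (in_h - 5) / stride + 1;
        for(int oy = 0; oy < out_h; ++oy)
        {
            for(int ox = 0; ox < out_w; ++ox)
            {
                float acc = 0.f;
                for(int ky = 0; ky < 5; ++ky)
                {
                    for(int kx = 0; kx < 5; ++kx)
                    {
                        acc += in[(oy * stride + ky) * in_w + ox * stride + kx] * k[ky * 5 + kx];
                    }
                }
                out[oy * out_w + ox] = acc;
            }
        }
    }

For stride 1 the patch loads 12 consecutive input floats per row and slides the five broadcast taps across them with vextq_f32, producing eight outputs per iteration; strides 2 and 3 reuse that result and compact the valid lanes.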
diff --git a/src/core/NEON/kernels/NEDirectConvolutionLayerKernel.cpp b/src/core/NEON/kernels/NEDirectConvolutionLayerKernel.cpp
index 3a102edd10..3dd07fcdbe 100644
--- a/src/core/NEON/kernels/NEDirectConvolutionLayerKernel.cpp
+++ b/src/core/NEON/kernels/NEDirectConvolutionLayerKernel.cpp
@@ -508,6 +508,159 @@ inline qint8x8x3_t load_matrix_row(const qint8_t *ptr)
}
template <unsigned int stridex>
+float32x4x2_t convolve_5x5(const float *in_0, const float *in_1, const float *in_2, const float *in_3, const float *in_4,
+ const float *m0, const float *m1, const float *m2, const float *m3, const float *m4, int fixed_point_position);
+
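+// Broadcast the five taps of one kernel row: load_matrix_hi duplicates taps
+// 0-2 and load_matrix_lo taps 3-4 across all four lanes (vld1q_dup_f32).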
+inline float32x4x3_t load_matrix_hi(const float *const m0, const float *const m1, const float *const m2)
+{
+ const float32x4x3_t m00 =
+ {
+ {
+ vld1q_dup_f32(m0),
+ vld1q_dup_f32(m1),
+ vld1q_dup_f32(m2)
+ }
+ };
+ return m00;
+}
+
+inline float32x4x2_t load_matrix_lo(const float *const m3, const float *const m4)
+{
+ const float32x4x2_t m00 =
+ {
+ {
+ vld1q_dup_f32(m3),
+ vld1q_dup_f32(m4)
+ }
+ };
+ return m00;
+}
+
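+// Load 12 consecutive input floats: a 5-tap row sliding over 12 inputs
+// yields the 8 outputs produced per iteration at stride 1.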
+inline float32x4x3_t load_input(const float *const in)
+{
+ const float32x4x3_t vin =
+ {
+ {
+ vld1q_f32(in),
+ vld1q_f32(in + 4),
+ vld1q_f32(in + 8)
+ }
+ };
+ return vin;
+}
+
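+// Stride-1 5x5 row convolution: the five broadcast taps are slid across the
+// 12 loaded inputs with vextq_f32 and fused into out.val[0..1] (8 outputs)
+// via multiply-accumulate (vmlaq_f32).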
+template <>
+inline float32x4x2_t convolve_5x5<1>(const float *in_0, const float *in_1, const float *in_2, const float *in_3, const float *in_4,
+ const float *m0, const float *m1, const float *m2, const float *m3, const float *m4, int fixed_point_position)
+{
+ ARM_COMPUTE_UNUSED(fixed_point_position);
+ const float32x4x3_t vin0 = load_input(in_0);
+ const float32x4x3_t vin1 = load_input(in_1);
+ const float32x4x3_t vin2 = load_input(in_2);
+ const float32x4x3_t vin3 = load_input(in_3);
+ const float32x4x3_t vin4 = load_input(in_4);
+ const float32x4x3_t m00 = load_matrix_hi(m0, m0 + 1, m0 + 2);
+ const float32x4x2_t m01 = load_matrix_lo(m0 + 3, m0 + 4);
+ const float32x4x3_t m10 = load_matrix_hi(m1, m1 + 1, m1 + 2);
+ const float32x4x2_t m11 = load_matrix_lo(m1 + 3, m1 + 4);
+ const float32x4x3_t m20 = load_matrix_hi(m2, m2 + 1, m2 + 2);
+ const float32x4x2_t m21 = load_matrix_lo(m2 + 3, m2 + 4);
+ const float32x4x3_t m30 = load_matrix_hi(m3, m3 + 1, m3 + 2);
+ const float32x4x2_t m31 = load_matrix_lo(m3 + 3, m3 + 4);
+ const float32x4x3_t m40 = load_matrix_hi(m4, m4 + 1, m4 + 2);
+ const float32x4x2_t m41 = load_matrix_lo(m4 + 3, m4 + 4);
+
+ float32x4x2_t out =
+ {
+ {
+ vmulq_f32(vin0.val[0], m00.val[0]),
+ vmulq_f32(vin0.val[1], m00.val[0])
+ }
+ };
+
+ out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin0.val[0], vin0.val[1], 1), m00.val[1]);
+ out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin0.val[0], vin0.val[1], 2), m00.val[2]);
+ out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin0.val[0], vin0.val[1], 3), m01.val[0]);
+ out.val[0] = vmlaq_f32(out.val[0], vin0.val[1], m01.val[1]);
+
+ out.val[0] = vmlaq_f32(out.val[0], vin1.val[0], m10.val[0]);
+ out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin1.val[0], vin1.val[1], 1), m10.val[1]);
+ out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin1.val[0], vin1.val[1], 2), m10.val[2]);
+ out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin1.val[0], vin1.val[1], 3), m11.val[0]);
+ out.val[0] = vmlaq_f32(out.val[0], vin1.val[1], m11.val[1]);
+
+ out.val[0] = vmlaq_f32(out.val[0], vin2.val[0], m20.val[0]);
+ out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin2.val[0], vin2.val[1], 1), m20.val[1]);
+ out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin2.val[0], vin2.val[1], 2), m20.val[2]);
+ out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin2.val[0], vin2.val[1], 3), m21.val[0]);
+ out.val[0] = vmlaq_f32(out.val[0], vin2.val[1], m21.val[1]);
+
+ out.val[0] = vmlaq_f32(out.val[0], vin3.val[0], m30.val[0]);
+ out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin3.val[0], vin3.val[1], 1), m30.val[1]);
+ out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin3.val[0], vin3.val[1], 2), m30.val[2]);
+ out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin3.val[0], vin3.val[1], 3), m31.val[0]);
+ out.val[0] = vmlaq_f32(out.val[0], vin3.val[1], m31.val[1]);
+
+ out.val[0] = vmlaq_f32(out.val[0], vin4.val[0], m40.val[0]);
+ out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin4.val[0], vin4.val[1], 1), m40.val[1]);
+ out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin4.val[0], vin4.val[1], 2), m40.val[2]);
+ out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin4.val[0], vin4.val[1], 3), m41.val[0]);
+ out.val[0] = vmlaq_f32(out.val[0], vin4.val[1], m41.val[1]);
+
+ out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin0.val[1], vin0.val[2], 1), m00.val[1]);
+ out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin0.val[1], vin0.val[2], 2), m00.val[2]);
+ out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin0.val[1], vin0.val[2], 3), m01.val[0]);
+ out.val[1] = vmlaq_f32(out.val[1], vin0.val[2], m01.val[1]);
+
+ out.val[1] = vmlaq_f32(out.val[1], vin1.val[1], m10.val[0]);
+ out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin1.val[1], vin1.val[2], 1), m10.val[1]);
+ out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin1.val[1], vin1.val[2], 2), m10.val[2]);
+ out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin1.val[1], vin1.val[2], 3), m11.val[0]);
+ out.val[1] = vmlaq_f32(out.val[1], vin1.val[2], m11.val[1]);
+
+ out.val[1] = vmlaq_f32(out.val[1], vin2.val[1], m20.val[0]);
+ out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin2.val[1], vin2.val[2], 1), m20.val[1]);
+ out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin2.val[1], vin2.val[2], 2), m20.val[2]);
+ out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin2.val[1], vin2.val[2], 3), m21.val[0]);
+ out.val[1] = vmlaq_f32(out.val[1], vin2.val[2], m21.val[1]);
+
+ out.val[1] = vmlaq_f32(out.val[1], vin3.val[1], m30.val[0]);
+ out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin3.val[1], vin3.val[2], 1), m30.val[1]);
+ out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin3.val[1], vin3.val[2], 2), m30.val[2]);
+ out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin3.val[1], vin3.val[2], 3), m31.val[0]);
+ out.val[1] = vmlaq_f32(out.val[1], vin3.val[2], m31.val[1]);
+
+ out.val[1] = vmlaq_f32(out.val[1], vin4.val[1], m40.val[0]);
+ out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin4.val[1], vin4.val[2], 1), m40.val[1]);
+ out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin4.val[1], vin4.val[2], 2), m40.val[2]);
+ out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin4.val[1], vin4.val[2], 3), m41.val[0]);
+ out.val[1] = vmlaq_f32(out.val[1], vin4.val[2], m41.val[1]);
+
+ return out;
+}
+
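+// Stride 2: compute the stride-1 result, then compact into out.val[0] the
+// four outputs whose windows start at columns 0, 2, 4 and 6;
+// store_results<2> writes these four values.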
+template <>
+inline float32x4x2_t convolve_5x5<2>(const float *in_0, const float *in_1, const float *in_2, const float *in_3, const float *in_4,
+ const float *m0, const float *m1, const float *m2, const float *m3, const float *m4, int fixed_point_position)
+{
+ ARM_COMPUTE_UNUSED(fixed_point_position);
+ float32x4x2_t out = convolve_5x5<1>(in_0, in_1, in_2, in_3, in_4, m0, m1, m2, m3, m4, fixed_point_position);
+ out.val[0] = vsetq_lane_f32(vgetq_lane_f32(out.val[0], 2), out.val[0], 1);
+ out.val[0] = vsetq_lane_f32(vgetq_lane_f32(out.val[1], 0), out.val[0], 2);
+ out.val[0] = vsetq_lane_f32(vgetq_lane_f32(out.val[1], 2), out.val[0], 3);
+ return out;
+}
+
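+// Stride 3: only two outputs (windows starting at columns 0 and 3) are
+// stored per iteration, so move lane 3 of the stride-1 result next to lane 0.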
+template <>
+inline float32x4x2_t convolve_5x5<3>(const float *in_0, const float *in_1, const float *in_2, const float *in_3, const float *in_4,
+ const float *m0, const float *m1, const float *m2, const float *m3, const float *m4, int fixed_point_position)
+{
+ float32x4x2_t out = convolve_5x5<1>(in_0, in_1, in_2, in_3, in_4, m0, m1, m2, m3, m4, fixed_point_position);
+ out.val[0] = vsetq_lane_f32(vgetq_lane_f32(out.val[0], 3), out.val[0], 1);
+ return out;
+}
+
+template <unsigned int stridex>
float32x4x2_t convolve_3x3(const float *in_top, const float *in_mid, const float *in_low, const float32x4x3_t &m0, const float32x4x3_t &m1, const float32x4x3_t &m2, int fixed_point_position);
template <>
@@ -548,17 +701,22 @@ inline float32x4x2_t convolve_3x3<1>(const float *in_top, const float *in_mid, c
};
out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vtop.val[0], vtop.val[1], 1), m0.val[1]);
out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vtop.val[0], vtop.val[1], 2), m0.val[2]);
+
out.val[0] = vmlaq_f32(out.val[0], vmid.val[0], m1.val[0]);
out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vmid.val[0], vmid.val[1], 1), m1.val[1]);
out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vmid.val[0], vmid.val[1], 2), m1.val[2]);
+
out.val[0] = vmlaq_f32(out.val[0], vlow.val[0], m2.val[0]);
out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vlow.val[0], vlow.val[1], 1), m2.val[1]);
out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vlow.val[0], vlow.val[1], 2), m2.val[2]);
+
out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vtop.val[1], vtop.val[2], 1), m0.val[1]);
out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vtop.val[1], vtop.val[2], 2), m0.val[2]);
+
out.val[1] = vmlaq_f32(out.val[1], vmid.val[1], m1.val[0]);
out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vmid.val[1], vmid.val[2], 1), m1.val[1]);
out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vmid.val[1], vmid.val[2], 2), m1.val[2]);
+
out.val[1] = vmlaq_f32(out.val[1], vlow.val[1], m2.val[0]);
out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vlow.val[1], vlow.val[2], 1), m2.val[1]);
out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vlow.val[1], vlow.val[2], 2), m2.val[2]);
@@ -841,7 +999,6 @@ public:
1) Convolve plane 0 with kernel 0 and initialize the corresponding output plane with these values.
2) Convolve the remaining planes and accumulate the results in the output's plane which has been initialized in step 1.
*/
-
for(int oz = 0; oz < num_planes_z; ++oz)
{
const int zoffset = id.z() + oz;
@@ -871,17 +1028,19 @@ public:
// Step 2
for(int p = 1; p < kernel_depth; ++p)
{
- const auto ptr_k_r0 = reinterpret_cast<const T1 *>(k_ptr + p * kernel_stride_z + zoffset * kernel_stride_w + 0 * kernel_stride_y + 0 * kernel_stride_x);
- const auto ptr_k_r1 = reinterpret_cast<const T1 *>(k_ptr + p * kernel_stride_z + zoffset * kernel_stride_w + 1 * kernel_stride_y + 0 * kernel_stride_x);
- const auto ptr_k_r2 = reinterpret_cast<const T1 *>(k_ptr + p * kernel_stride_z + zoffset * kernel_stride_w + 2 * kernel_stride_y + 0 * kernel_stride_x);
- const auto vk_r0 = load_matrix_row(ptr_k_r0);
- const auto vk_r1 = load_matrix_row(ptr_k_r1);
- const auto vk_r2 = load_matrix_row(ptr_k_r2);
+ const uint8_t *ptr_k_base = k_ptr + p * kernel_stride_z + zoffset * kernel_stride_w;
+ const uint8_t *input_base = input_ptr + p * input_stride_z;
+ const auto ptr_k_r0 = reinterpret_cast<const T1 *>(ptr_k_base);
+ const auto ptr_k_r1 = reinterpret_cast<const T1 *>(ptr_k_base + kernel_stride_y);
+ const auto ptr_k_r2 = reinterpret_cast<const T1 *>(ptr_k_base + kernel_stride_y * 2);
+ const auto vk_r0 = load_matrix_row(ptr_k_r0);
+ const auto vk_r1 = load_matrix_row(ptr_k_r1);
+ const auto vk_r2 = load_matrix_row(ptr_k_r2);
for(ih = 0, oh = 0; oh < output_h; ++oh, ih += conv_stride_y)
{
- auto in_top = reinterpret_cast<const T1 *>(input_ptr + p * input_stride_z + (ih + 0) * input_stride_y);
- auto in_mid = reinterpret_cast<const T1 *>(input_ptr + p * input_stride_z + (ih + 1) * input_stride_y);
- auto in_low = reinterpret_cast<const T1 *>(input_ptr + p * input_stride_z + (ih + 2) * input_stride_y);
+ auto in_top = reinterpret_cast<const T1 *>(input_base + (ih + 0) * input_stride_y);
+ auto in_mid = reinterpret_cast<const T1 *>(input_base + (ih + 1) * input_stride_y);
+ auto in_low = reinterpret_cast<const T1 *>(input_base + (ih + 2) * input_stride_y);
auto p_out = reinterpret_cast<T2 *>(p_out_base + oh * output_stride_y);
for(int ow = 0; ow < output_w; ow += num_elems_written_per_iteration,
in_top += delta_input, in_mid += delta_input, in_low += delta_input, p_out += num_elems_written_per_iteration)
@@ -897,6 +1056,118 @@ public:
}
};
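+// 5x5 analogue of convolver_3x3 above: plane 0 of the kernel initializes
+// each output plane (Step 1), and the remaining kernel_depth - 1 planes
+// accumulate into it (Step 2).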
+template <typename T1, typename T2, unsigned int stridex>
+class convolver_5x5
+{
+public:
+ static void convolve(const Window &window, unsigned int num_elems_read_per_iteration, unsigned int num_elems_written_per_iteration,
+ const ITensor *input, const ITensor *weights, ITensor *output, const PadStrideInfo &conv_info)
+ {
+ ARM_COMPUTE_UNUSED(num_elems_read_per_iteration);
+ const int input_stride_x = input->info()->strides_in_bytes().x();
+ const int input_stride_y = input->info()->strides_in_bytes().y();
+ const int input_stride_z = input->info()->strides_in_bytes().z();
+ const int output_stride_y = output->info()->strides_in_bytes().y();
+ const int output_stride_z = output->info()->strides_in_bytes().z();
+ const int kernel_stride_x = weights->info()->strides_in_bytes().x();
+ const int kernel_stride_y = weights->info()->strides_in_bytes().y();
+ const int kernel_stride_z = weights->info()->strides_in_bytes().z();
+ const int kernel_stride_w = weights->info()->strides_in_bytes()[3];
+ const int output_w = output->info()->dimension(0);
+ const int output_h = output->info()->dimension(1);
+ const int num_planes_z = window.z().end() - window.z().start();
+ const int delta_input = get_input_num_elems_processed<stridex>(num_elems_written_per_iteration);
+ const int kernel_depth = weights->info()->dimension(Window::DimZ);
+ const unsigned int conv_stride_y = std::get<1>(conv_info.stride());
+ const unsigned int conv_pad_x = std::get<0>(conv_info.pad());
+ const unsigned int conv_pad_y = std::get<1>(conv_info.pad());
+ const int fixed_point_position = input->info()->fixed_point_position();
+
+ // setup output window for the iterator
+ Window window_out = window;
+ window_out.set(Window::DimX, Window::Dimension(0, output->info()->dimension(Window::DimX), output->info()->dimension(Window::DimX)));
+ window_out.set(Window::DimY, Window::Dimension(0, output->info()->dimension(Window::DimY), output->info()->dimension(Window::DimY)));
+ window_out.set(Window::DimZ, Window::Dimension(window.z().start(), window.z().end(), num_planes_z));
+
+ // setup input window for the iterator
+ Window window_in = window;
+ // we just want execute_window_loop to iterate over the higher dimensions (>3), so we set the first 3 dimensions to 0
+ window_in.set(Window::DimX, Window::Dimension(0, 0, 0));
+ window_in.set(Window::DimY, Window::Dimension(0, 0, 0));
+ window_in.set(Window::DimZ, Window::Dimension(0, 0, 0));
+
+ Window window_k = calculate_max_window(*weights->info(), Steps(1u));
+
+ Iterator out(output, window_out);
+ Iterator in(input, window_in);
+ Iterator k(weights, window_k);
+
+ const uint8_t *k_ptr = k.ptr();
+
+ execute_window_loop(window_out, [&](const Coordinates & id)
+ {
+ const uint8_t *input_ptr = in.ptr() - conv_pad_x * input_stride_x - conv_pad_y * input_stride_y;
+ uint8_t *out_ptr = out.ptr();
+ int ih = 0;
+ int oh = 0;
+ for(int oz = 0; oz < num_planes_z; ++oz)
+ {
+ const int zoffset = id.z() + oz;
+ uint8_t *p_out_base = out_ptr + oz * output_stride_z;
+ // Step 1
+ {
+ const auto ptr_k_r0 = reinterpret_cast<const T1 *>(k_ptr + 0 * kernel_stride_z + zoffset * kernel_stride_w + 0 * kernel_stride_y + 0 * kernel_stride_x);
+ const auto ptr_k_r1 = reinterpret_cast<const T1 *>(k_ptr + 0 * kernel_stride_z + zoffset * kernel_stride_w + 1 * kernel_stride_y + 0 * kernel_stride_x);
+ const auto ptr_k_r2 = reinterpret_cast<const T1 *>(k_ptr + 0 * kernel_stride_z + zoffset * kernel_stride_w + 2 * kernel_stride_y + 0 * kernel_stride_x);
+ const auto ptr_k_r3 = reinterpret_cast<const T1 *>(k_ptr + 0 * kernel_stride_z + zoffset * kernel_stride_w + 3 * kernel_stride_y + 0 * kernel_stride_x);
+ const auto ptr_k_r4 = reinterpret_cast<const T1 *>(k_ptr + 0 * kernel_stride_z + zoffset * kernel_stride_w + 4 * kernel_stride_y + 0 * kernel_stride_x);
+ for(ih = 0, oh = 0; oh < output_h; ++oh, ih += conv_stride_y)
+ {
+ auto in_0 = reinterpret_cast<const T1 *>(input_ptr + 0 * input_stride_z + (ih + 0) * input_stride_y);
+ auto in_1 = reinterpret_cast<const T1 *>(input_ptr + 0 * input_stride_z + (ih + 1) * input_stride_y);
+ auto in_2 = reinterpret_cast<const T1 *>(input_ptr + 0 * input_stride_z + (ih + 2) * input_stride_y);
+ auto in_3 = reinterpret_cast<const T1 *>(input_ptr + 0 * input_stride_z + (ih + 3) * input_stride_y);
+ auto in_4 = reinterpret_cast<const T1 *>(input_ptr + 0 * input_stride_z + (ih + 4) * input_stride_y);
+ auto p_out = reinterpret_cast<T2 *>(p_out_base + oh * output_stride_y);
+ for(int ow = 0; ow < output_w; ow += num_elems_written_per_iteration,
+ in_0 += delta_input, in_1 += delta_input, in_2 += delta_input, in_3 += delta_input, in_4 += delta_input, p_out += num_elems_written_per_iteration)
+ {
+ auto vres = convolve_5x5<stridex>(in_0, in_1, in_2, in_3, in_4, ptr_k_r0, ptr_k_r1, ptr_k_r2, ptr_k_r3, ptr_k_r4, fixed_point_position);
+ store_results<stridex>(p_out, vres);
+ }
+ }
+ }
+ // Step 2
+ for(int p = 1; p < kernel_depth; ++p)
+ {
+ const auto ptr_k_r0 = reinterpret_cast<const T1 *>(k_ptr + p * kernel_stride_z + zoffset * kernel_stride_w + 0 * kernel_stride_y + 0 * kernel_stride_x);
+ const auto ptr_k_r1 = reinterpret_cast<const T1 *>(k_ptr + p * kernel_stride_z + zoffset * kernel_stride_w + 1 * kernel_stride_y + 0 * kernel_stride_x);
+ const auto ptr_k_r2 = reinterpret_cast<const T1 *>(k_ptr + p * kernel_stride_z + zoffset * kernel_stride_w + 2 * kernel_stride_y + 0 * kernel_stride_x);
+ const auto ptr_k_r3 = reinterpret_cast<const T1 *>(k_ptr + p * kernel_stride_z + zoffset * kernel_stride_w + 3 * kernel_stride_y + 0 * kernel_stride_x);
+ const auto ptr_k_r4 = reinterpret_cast<const T1 *>(k_ptr + p * kernel_stride_z + zoffset * kernel_stride_w + 4 * kernel_stride_y + 0 * kernel_stride_x);
+
+ for(ih = 0, oh = 0; oh < output_h; ++oh, ih += conv_stride_y)
+ {
+ auto in_0 = reinterpret_cast<const T1 *>(input_ptr + p * input_stride_z + (ih + 0) * input_stride_y);
+ auto in_1 = reinterpret_cast<const T1 *>(input_ptr + p * input_stride_z + (ih + 1) * input_stride_y);
+ auto in_2 = reinterpret_cast<const T1 *>(input_ptr + p * input_stride_z + (ih + 2) * input_stride_y);
+ auto in_3 = reinterpret_cast<const T1 *>(input_ptr + p * input_stride_z + (ih + 3) * input_stride_y);
+ auto in_4 = reinterpret_cast<const T1 *>(input_ptr + p * input_stride_z + (ih + 4) * input_stride_y);
+ auto p_out = reinterpret_cast<T2 *>(p_out_base + oh * output_stride_y);
+ for(int ow = 0; ow < output_w; ow += num_elems_written_per_iteration,
+ in_0 += delta_input, in_1 += delta_input, in_2 += delta_input, in_3 += delta_input, in_4 += delta_input, p_out += num_elems_written_per_iteration)
+ {
+ auto vres = convolve_5x5<stridex>(in_0, in_1, in_2, in_3, in_4, ptr_k_r0, ptr_k_r1, ptr_k_r2, ptr_k_r3, ptr_k_r4, fixed_point_position);
+ accumulate_results<stridex>(p_out, vres);
+ }
+ }
+ }
+ }
+ },
+ in, out);
+ }
+};
+
template <typename T1, typename T2>
inline void convolve_1x1(const Window &window, unsigned int num_elems_read_per_iteration, unsigned int num_elems_written_per_iteration,
const ITensor *input, const ITensor *weights, ITensor *output, const PadStrideInfo &conv_info)
@@ -938,6 +1209,28 @@ inline void convolve_3x3(const Window &window, unsigned int num_elems_read_per_i
ARM_COMPUTE_ERROR("Not implemented");
}
}
+
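+// Dispatch on the horizontal stride: each convolver_5x5 instantiation is
+// compiled for a fixed stride of 1, 2 or 3.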
+template <typename T1, typename T2>
+inline void convolve_5x5(const Window &window, unsigned int num_elems_read_per_iteration, unsigned int num_elems_written_per_iteration,
+ const ITensor *input, const ITensor *weights, ITensor *output, const PadStrideInfo &conv_info)
+{
+ const unsigned int conv_stride_x = std::get<0>(conv_info.stride());
+ switch(conv_stride_x)
+ {
+ case 1:
+ convolver_5x5<T1, T2, 1>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, input, weights, output, conv_info);
+ break;
+ case 2:
+ convolver_5x5<T1, T2, 2>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, input, weights, output, conv_info);
+ break;
+ case 3:
+ convolver_5x5<T1, T2, 3>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, input, weights, output, conv_info);
+ break;
+ default:
+ ARM_COMPUTE_ERROR("Not implemented");
+ }
+}
+
} // namespace
NEDirectConvolutionLayerKernel::NEDirectConvolutionLayerKernel()
@@ -958,6 +1251,9 @@ void NEDirectConvolutionLayerKernel::configure(const ITensor *input, const ITens
"Pad > 0 not supported for 1x1 weights");
ARM_COMPUTE_ERROR_ON_MSG(weights->info()->dimension(0) == 3 && (std::get<0>(conv_info.pad()) > 1 || std::get<1>(conv_info.pad()) > 1),
"Pad > 1 not supported for 3x3 weights");
+ ARM_COMPUTE_ERROR_ON_MSG(weights->info()->dimension(0) == 5 && (std::get<0>(conv_info.pad()) > 2 || std::get<1>(conv_info.pad()) > 2),
+ "Pad > 2 not supported for 5x5 weights");
+
ARM_COMPUTE_ERROR_ON_MSG(std::get<0>(conv_info.stride()) > 3, "Strides larger than 3 not supported.");
ARM_COMPUTE_ERROR_ON(weights->info()->dimension(2) != input->info()->dimension(2));
ARM_COMPUTE_ERROR_ON(weights->info()->dimension(0) != weights->info()->dimension(1));
@@ -1032,16 +1328,25 @@ void NEDirectConvolutionLayerKernel::configure(const ITensor *input, const ITens
break;
}
case 3:
+ case 5:
{
- if(input->info()->data_type() == DataType::F32)
- {
- _num_elems_read_per_iteration = 12;
- _num_elems_written_per_iteration = 16 >> conv_stride_x;
- }
- else
+ switch(input->info()->data_type())
{
- _num_elems_read_per_iteration = 24;
- _num_elems_written_per_iteration = 32 >> conv_stride_x;
+ case DataType::F32:
+ _num_elems_read_per_iteration = 12;
+ _num_elems_written_per_iteration = 16 >> conv_stride_x;
+ break;
+#ifdef ARM_COMPUTE_ENABLE_FP16
+ case DataType::F16:
+#endif /* ARM_COMPUTE_ENABLE_FP16 */
+ case DataType::QS8:
+ case DataType::QS16:
+ _num_elems_read_per_iteration = 24;
+ _num_elems_written_per_iteration = 32 >> conv_stride_x;
+ break;
+ default:
+ ARM_COMPUTE_ERROR("Data type not supported.");
+ break;
}
// Calculate right and bottom border
@@ -1060,6 +1365,7 @@ void NEDirectConvolutionLayerKernel::configure(const ITensor *input, const ITens
AccessWindowHorizontal output_access(output->info(), 0, _num_elems_written_per_iteration);
update_window_and_padding(win, input_access, weights_access, output_access);
output_access.set_valid_region(win, ValidRegion(Coordinates(), output->info()->tensor_shape()));
+
break;
}
default:
@@ -1127,9 +1433,23 @@ void NEDirectConvolutionLayerKernel::run(const Window &window)
}
break;
}
+ case 5:
+ {
+ switch(_input->info()->data_type())
+ {
+ case DataType::F32:
+ convolve_5x5<float, float>(window, _num_elems_read_per_iteration, _num_elems_written_per_iteration, _input, _weights, _output, _conv_info);
+ break;
+ default:
+ ARM_COMPUTE_ERROR("Data type not supported");
+ break;
+ }
+ break;
+ }
+
default:
{
- ARM_COMPUTE_ERROR("Only kernel sizes 1x1 and 3x3 are supported.");
+ ARM_COMPUTE_ERROR("Only kernel sizes 1x1, 3x3 and 5x5 are supported.");
break;
}
}
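
For context, a minimal driver of the new 5x5 path might look as follows (a hypothetical, untested sketch: the tensor shapes and the single-threaded run() call are assumptions, not part of this patch):

    #include "arm_compute/core/NEON/kernels/NEDirectConvolutionLayerKernel.h"
    #include "arm_compute/runtime/Tensor.h"

    using namespace arm_compute;

    int main()
    {
        Tensor src{}, weights{}, dst{};
        // 32x32x3 F32 input, one 5x5x3 kernel, 28x28 output (stride 1, pad 0).
        src.allocator()->init(TensorInfo(TensorShape(32U, 32U, 3U), 1, DataType::F32));
        weights.allocator()->init(TensorInfo(TensorShape(5U, 5U, 3U), 1, DataType::F32));
        dst.allocator()->init(TensorInfo(TensorShape(28U, 28U, 1U), 1, DataType::F32));

        NEDirectConvolutionLayerKernel conv5x5;
        conv5x5.configure(&src, &weights, &dst, PadStrideInfo(1, 1, 0, 0));

        // configure() may extend the tensors' padding, so allocate afterwards.
        src.allocator()->allocate();
        weights.allocator()->allocate();
        dst.allocator()->allocate();

        // ... fill src and weights ..., then execute over the kernel's window.
        conv5x5.run(conv5x5.window());
        return 0;
    }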