aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorGian Marco Iodice <gianmarco.iodice@arm.com>2017-10-09 15:05:40 +0100
committerAnthony Barbier <anthony.barbier@arm.com>2018-11-02 16:35:24 +0000
commitab18212dd287cc0ec9b7c1a2c72455fe75ebd13d (patch)
treef802205d85785da671ddd1949ba61b9dc36a3035 /src
parented194b1fbec6627896c5c12f74460b9142b98f7d (diff)
downloadComputeLibrary-ab18212dd287cc0ec9b7c1a2c72455fe75ebd13d.tar.gz
COMPMID-616 - Optimizing GEMMLowp on NEON intrinsics
Change-Id: Ibbeff5d37249b6e8fc34ad496035a1511c9da5a3 Reviewed-on: http://mpd-gerrit.cambridge.arm.com/94072 Tested-by: Kaizen <jeremy.johnson+kaizengerrit@arm.com> Reviewed-by: Pablo Tello <pablo.tello@arm.com>
Diffstat (limited to 'src')
-rw-r--r--src/core/NEON/kernels/NEGEMMLowpFinalizeKernel.cpp509
-rw-r--r--src/core/NEON/kernels/NEGEMMLowpMatrixMultiplyKernel.cpp395
-rw-r--r--src/core/NEON/kernels/NEGEMMLowpReductionKernel.cpp431
-rw-r--r--src/runtime/NEON/functions/NEGEMMLowp.cpp160
-rw-r--r--src/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.cpp163
5 files changed, 1281 insertions, 377 deletions
diff --git a/src/core/NEON/kernels/NEGEMMLowpFinalizeKernel.cpp b/src/core/NEON/kernels/NEGEMMLowpFinalizeKernel.cpp
new file mode 100644
index 0000000000..400c6d9d8c
--- /dev/null
+++ b/src/core/NEON/kernels/NEGEMMLowpFinalizeKernel.cpp
@@ -0,0 +1,509 @@
+/*
+ * Copyright (c) 2017 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#include "arm_compute/core/NEON/kernels/NEGEMMLowpFinalizeKernel.h"
+
+#include "arm_compute/core/AccessWindowStatic.h"
+#include "arm_compute/core/Error.h"
+#include "arm_compute/core/Helpers.h"
+#include "arm_compute/core/ITensor.h"
+#include "arm_compute/core/TensorInfo.h"
+#include "arm_compute/core/Types.h"
+#include "arm_compute/core/Utils.h"
+#include "arm_compute/core/Validate.h"
+#include "arm_compute/core/Window.h"
+
+#include <arm_neon.h>
+#include <cstddef>
+#include <cstdint>
+
+using namespace arm_compute;
+
+namespace arm_compute
+{
+class Coordinates;
+} // namespace arm_compute
+
+template <bool add_a_offset, bool add_b_offset>
+void NEGEMMLowpFinalizeKernel::finalize(const Window &window)
+{
+ const int32x4_t c_offset_s32 = vdupq_n_s32(_c_offset);
+ const int32x4_t shift_s32 = vdupq_n_s32(-_shift);
+
+ Window collapsed_window = window.collapse_if_possible(IKernel::window(), Window::DimZ);
+
+ if(add_a_offset && add_b_offset) // true, true
+ {
+ // Set window for vector_sum_col
+ Window win_vector_sum_col(collapsed_window);
+ win_vector_sum_col.set(Window::DimY, Window::Dimension(0, 0, 0));
+ if(!_slide_vector_sum_col)
+ {
+ win_vector_sum_col.set(Window::DimZ, Window::Dimension(0, 0, 0));
+ }
+
+ // Set window for vector_sum_row
+ Window win_vector_sum_row(collapsed_window);
+ win_vector_sum_row.set(Window::DimX, Window::Dimension(0, 0, 0));
+ win_vector_sum_row.set(Window::DimY, Window::Dimension(0, 0, 0));
+
+ Iterator vector_sum_col(_vector_sum_col, win_vector_sum_col);
+ Iterator vector_sum_row(_vector_sum_row, win_vector_sum_row);
+ Iterator mm_result(_mm_result, window);
+ Iterator out(_output, window);
+
+ execute_window_loop(window, [&](const Coordinates & id)
+ {
+ // Compute the leftover term due to a_offset.
+ int32x4x4_t a_offset_term_s32 =
+ {
+ {
+ vld1q_s32(reinterpret_cast<const int32_t *>(vector_sum_col.ptr()) + 0),
+ vld1q_s32(reinterpret_cast<const int32_t *>(vector_sum_col.ptr()) + 4),
+ vld1q_s32(reinterpret_cast<const int32_t *>(vector_sum_col.ptr()) + 8),
+ vld1q_s32(reinterpret_cast<const int32_t *>(vector_sum_col.ptr()) + 12)
+ }
+ };
+
+ a_offset_term_s32.val[0] = vmulq_n_s32(a_offset_term_s32.val[0], _a_offset);
+ a_offset_term_s32.val[1] = vmulq_n_s32(a_offset_term_s32.val[1], _a_offset);
+ a_offset_term_s32.val[2] = vmulq_n_s32(a_offset_term_s32.val[2], _a_offset);
+ a_offset_term_s32.val[3] = vmulq_n_s32(a_offset_term_s32.val[3], _a_offset);
+
+ // Compute the leftover term due to b_offset.
+ int32x4_t b_offset_term_s32 = vld1q_dup_s32(reinterpret_cast<const int32_t *>(vector_sum_row.ptr()) + id.y());
+ b_offset_term_s32 = vmulq_n_s32(b_offset_term_s32, _b_offset);
+
+ // Add a_offset_term_s32 and b_offset_term_s32
+ int32x4x4_t offset_term_s32 =
+ {
+ {
+ vdupq_n_s32(_k_offset),
+ vdupq_n_s32(_k_offset),
+ vdupq_n_s32(_k_offset),
+ vdupq_n_s32(_k_offset)
+ }
+ };
+
+ offset_term_s32.val[0] = vaddq_s32(offset_term_s32.val[0], vaddq_s32(a_offset_term_s32.val[0], b_offset_term_s32));
+ offset_term_s32.val[1] = vaddq_s32(offset_term_s32.val[1], vaddq_s32(a_offset_term_s32.val[1], b_offset_term_s32));
+ offset_term_s32.val[2] = vaddq_s32(offset_term_s32.val[2], vaddq_s32(a_offset_term_s32.val[2], b_offset_term_s32));
+ offset_term_s32.val[3] = vaddq_s32(offset_term_s32.val[3], vaddq_s32(a_offset_term_s32.val[3], b_offset_term_s32));
+
+ // Add c_offset
+ offset_term_s32.val[0] = vaddq_s32(offset_term_s32.val[0], c_offset_s32);
+ offset_term_s32.val[1] = vaddq_s32(offset_term_s32.val[1], c_offset_s32);
+ offset_term_s32.val[2] = vaddq_s32(offset_term_s32.val[2], c_offset_s32);
+ offset_term_s32.val[3] = vaddq_s32(offset_term_s32.val[3], c_offset_s32);
+
+ int32x4x4_t in_s32 =
+ {
+ {
+ vld1q_s32(reinterpret_cast<const int32_t *>(mm_result.ptr()) + 0),
+ vld1q_s32(reinterpret_cast<const int32_t *>(mm_result.ptr()) + 4),
+ vld1q_s32(reinterpret_cast<const int32_t *>(mm_result.ptr()) + 8),
+ vld1q_s32(reinterpret_cast<const int32_t *>(mm_result.ptr()) + 12)
+ }
+ };
+
+ // Add the offset terms to GEMM's result
+ in_s32.val[0] = vaddq_s32(in_s32.val[0], offset_term_s32.val[0]);
+ in_s32.val[1] = vaddq_s32(in_s32.val[1], offset_term_s32.val[1]);
+ in_s32.val[2] = vaddq_s32(in_s32.val[2], offset_term_s32.val[2]);
+ in_s32.val[3] = vaddq_s32(in_s32.val[3], offset_term_s32.val[3]);
+
+ // Multiply by c_mult_int
+ in_s32.val[0] = vmulq_n_s32(in_s32.val[0], _c_mult_int);
+ in_s32.val[1] = vmulq_n_s32(in_s32.val[1], _c_mult_int);
+ in_s32.val[2] = vmulq_n_s32(in_s32.val[2], _c_mult_int);
+ in_s32.val[3] = vmulq_n_s32(in_s32.val[3], _c_mult_int);
+
+ // Shift final result (negative value shift right)
+ in_s32.val[0] = vshlq_s32(in_s32.val[0], shift_s32);
+ in_s32.val[1] = vshlq_s32(in_s32.val[1], shift_s32);
+ in_s32.val[2] = vshlq_s32(in_s32.val[2], shift_s32);
+ in_s32.val[3] = vshlq_s32(in_s32.val[3], shift_s32);
+
+ // Convert S32 to U16
+ const int16x8x2_t in_u16 =
+ {
+ {
+ vcombine_s16(vqmovn_s32(in_s32.val[0]), vqmovn_s32(in_s32.val[1])),
+ vcombine_s16(vqmovn_s32(in_s32.val[2]), vqmovn_s32(in_s32.val[3])),
+ }
+ };
+
+ // Convert U16 to U8
+ const uint8x16_t out_u8 = vcombine_u8(vqmovun_s16(in_u16.val[0]), vqmovun_s16(in_u16.val[1]));
+
+ vst1q_u8(out.ptr(), out_u8);
+ },
+ vector_sum_col, vector_sum_row, mm_result, out);
+ }
+ else if(!add_a_offset && add_b_offset) // false, true
+ {
+ // Set window for vector_sum_row
+ Window win_vector_sum_row(collapsed_window);
+ win_vector_sum_row.set(Window::DimX, Window::Dimension(0, 0, 0));
+ win_vector_sum_row.set(Window::DimY, Window::Dimension(0, 0, 0));
+
+ Iterator vector_sum_row(_vector_sum_row, win_vector_sum_row);
+ Iterator mm_result(_mm_result, window);
+ Iterator out(_output, window);
+
+ execute_window_loop(window, [&](const Coordinates & id)
+ {
+ // Compute the leftover term due to b_offset.
+ int32x4_t b_offset_term_s32 = vld1q_dup_s32(reinterpret_cast<const int32_t *>(vector_sum_row.ptr()) + id.y());
+ b_offset_term_s32 = vmulq_n_s32(b_offset_term_s32, _b_offset);
+
+ // Add b_offset_term_s32 and c_offset_term_s32
+ int32x4_t offset_term_s32 = vaddq_s32(b_offset_term_s32, c_offset_s32);
+
+ int32x4x4_t in_s32 =
+ {
+ {
+ vld1q_s32(reinterpret_cast<const int32_t *>(mm_result.ptr()) + 0),
+ vld1q_s32(reinterpret_cast<const int32_t *>(mm_result.ptr()) + 4),
+ vld1q_s32(reinterpret_cast<const int32_t *>(mm_result.ptr()) + 8),
+ vld1q_s32(reinterpret_cast<const int32_t *>(mm_result.ptr()) + 12)
+ }
+ };
+
+ // Add the offset terms to GEMM's result
+ in_s32.val[0] = vaddq_s32(in_s32.val[0], offset_term_s32);
+ in_s32.val[1] = vaddq_s32(in_s32.val[1], offset_term_s32);
+ in_s32.val[2] = vaddq_s32(in_s32.val[2], offset_term_s32);
+ in_s32.val[3] = vaddq_s32(in_s32.val[3], offset_term_s32);
+
+ // Multiply by c_mult_int
+ in_s32.val[0] = vmulq_n_s32(in_s32.val[0], _c_mult_int);
+ in_s32.val[1] = vmulq_n_s32(in_s32.val[1], _c_mult_int);
+ in_s32.val[2] = vmulq_n_s32(in_s32.val[2], _c_mult_int);
+ in_s32.val[3] = vmulq_n_s32(in_s32.val[3], _c_mult_int);
+
+ // Shift final result (negative value shift right)
+ in_s32.val[0] = vshlq_s32(in_s32.val[0], shift_s32);
+ in_s32.val[1] = vshlq_s32(in_s32.val[1], shift_s32);
+ in_s32.val[2] = vshlq_s32(in_s32.val[2], shift_s32);
+ in_s32.val[3] = vshlq_s32(in_s32.val[3], shift_s32);
+
+ // Convert S32 to U16
+ const int16x8x2_t in_u16 =
+ {
+ {
+ vcombine_s16(vqmovn_s32(in_s32.val[0]), vqmovn_s32(in_s32.val[1])),
+ vcombine_s16(vqmovn_s32(in_s32.val[2]), vqmovn_s32(in_s32.val[3])),
+ }
+ };
+
+ // Convert U16 to U8
+ const uint8x16_t out_u8 = vcombine_u8(vqmovun_s16(in_u16.val[0]), vqmovun_s16(in_u16.val[1]));
+
+ vst1q_u8(out.ptr(), out_u8);
+ },
+ vector_sum_row, mm_result, out);
+ }
+ else if(add_a_offset && !add_b_offset) // true, false
+ {
+ // Set window for vector_sum_col
+ Window win_vector_sum_col(collapsed_window);
+ win_vector_sum_col.set(Window::DimY, Window::Dimension(0, 0, 0));
+ if(!_slide_vector_sum_col)
+ {
+ win_vector_sum_col.set(Window::DimZ, Window::Dimension(0, 0, 0));
+ }
+
+ Iterator vector_sum_col(_vector_sum_col, win_vector_sum_col);
+ Iterator mm_result(_mm_result, window);
+ Iterator out(_output, window);
+
+ execute_window_loop(window, [&](const Coordinates & id)
+ {
+ // Compute the leftover term due to a_offset.
+ int32x4x4_t a_offset_term_s32 =
+ {
+ {
+ vld1q_s32(reinterpret_cast<const int32_t *>(vector_sum_col.ptr()) + 0),
+ vld1q_s32(reinterpret_cast<const int32_t *>(vector_sum_col.ptr()) + 4),
+ vld1q_s32(reinterpret_cast<const int32_t *>(vector_sum_col.ptr()) + 8),
+ vld1q_s32(reinterpret_cast<const int32_t *>(vector_sum_col.ptr()) + 12)
+ }
+ };
+
+ a_offset_term_s32.val[0] = vmulq_n_s32(a_offset_term_s32.val[0], _a_offset);
+ a_offset_term_s32.val[1] = vmulq_n_s32(a_offset_term_s32.val[1], _a_offset);
+ a_offset_term_s32.val[2] = vmulq_n_s32(a_offset_term_s32.val[2], _a_offset);
+ a_offset_term_s32.val[3] = vmulq_n_s32(a_offset_term_s32.val[3], _a_offset);
+
+ // Add a_offset_term_s32 and b_offset_term_s32
+ int32x4x4_t offset_term_s32 =
+ {
+ {
+ vaddq_s32(c_offset_s32, a_offset_term_s32.val[0]),
+ vaddq_s32(c_offset_s32, a_offset_term_s32.val[1]),
+ vaddq_s32(c_offset_s32, a_offset_term_s32.val[2]),
+ vaddq_s32(c_offset_s32, a_offset_term_s32.val[3])
+ }
+ };
+
+ int32x4x4_t in_s32 =
+ {
+ {
+ vld1q_s32(reinterpret_cast<const int32_t *>(mm_result.ptr()) + 0),
+ vld1q_s32(reinterpret_cast<const int32_t *>(mm_result.ptr()) + 4),
+ vld1q_s32(reinterpret_cast<const int32_t *>(mm_result.ptr()) + 8),
+ vld1q_s32(reinterpret_cast<const int32_t *>(mm_result.ptr()) + 12)
+ }
+ };
+
+ // Add the offset terms to GEMM's result
+ in_s32.val[0] = vaddq_s32(in_s32.val[0], offset_term_s32.val[0]);
+ in_s32.val[1] = vaddq_s32(in_s32.val[1], offset_term_s32.val[1]);
+ in_s32.val[2] = vaddq_s32(in_s32.val[2], offset_term_s32.val[2]);
+ in_s32.val[3] = vaddq_s32(in_s32.val[3], offset_term_s32.val[3]);
+
+ // Multiply by c_mult_int
+ in_s32.val[0] = vmulq_n_s32(in_s32.val[0], _c_mult_int);
+ in_s32.val[1] = vmulq_n_s32(in_s32.val[1], _c_mult_int);
+ in_s32.val[2] = vmulq_n_s32(in_s32.val[2], _c_mult_int);
+ in_s32.val[3] = vmulq_n_s32(in_s32.val[3], _c_mult_int);
+
+ // Shift final result (negative value shift right)
+ in_s32.val[0] = vshlq_s32(in_s32.val[0], shift_s32);
+ in_s32.val[1] = vshlq_s32(in_s32.val[1], shift_s32);
+ in_s32.val[2] = vshlq_s32(in_s32.val[2], shift_s32);
+ in_s32.val[3] = vshlq_s32(in_s32.val[3], shift_s32);
+
+ // Convert S32 to U16
+ const int16x8x2_t in_u16 =
+ {
+ {
+ vcombine_s16(vqmovn_s32(in_s32.val[0]), vqmovn_s32(in_s32.val[1])),
+ vcombine_s16(vqmovn_s32(in_s32.val[2]), vqmovn_s32(in_s32.val[3]))
+ }
+ };
+
+ // Convert U16 to U8
+ const uint8x16_t out_u8 = vcombine_u8(vqmovun_s16(in_u16.val[0]), vqmovun_s16(in_u16.val[1]));
+
+ vst1q_u8(out.ptr(), out_u8);
+ },
+ vector_sum_col, mm_result, out);
+ }
+ else // false, false
+ {
+ Iterator mm_result(_mm_result, window);
+ Iterator out(_output, window);
+
+ execute_window_loop(window, [&](const Coordinates & id)
+ {
+ int32x4x4_t in_s32 =
+ {
+ {
+ vld1q_s32(reinterpret_cast<const int32_t *>(mm_result.ptr()) + 0),
+ vld1q_s32(reinterpret_cast<const int32_t *>(mm_result.ptr()) + 4),
+ vld1q_s32(reinterpret_cast<const int32_t *>(mm_result.ptr()) + 8),
+ vld1q_s32(reinterpret_cast<const int32_t *>(mm_result.ptr()) + 12)
+ }
+ };
+
+ // Add the offset terms to GEMM's result
+ in_s32.val[0] = vaddq_s32(in_s32.val[0], c_offset_s32);
+ in_s32.val[1] = vaddq_s32(in_s32.val[1], c_offset_s32);
+ in_s32.val[2] = vaddq_s32(in_s32.val[2], c_offset_s32);
+ in_s32.val[3] = vaddq_s32(in_s32.val[3], c_offset_s32);
+
+ // Multiply by c_mult_int
+ in_s32.val[0] = vmulq_n_s32(in_s32.val[0], _c_mult_int);
+ in_s32.val[1] = vmulq_n_s32(in_s32.val[1], _c_mult_int);
+ in_s32.val[2] = vmulq_n_s32(in_s32.val[2], _c_mult_int);
+ in_s32.val[3] = vmulq_n_s32(in_s32.val[3], _c_mult_int);
+
+ // Shift final result (negative value shift right)
+ in_s32.val[0] = vshlq_s32(in_s32.val[0], shift_s32);
+ in_s32.val[1] = vshlq_s32(in_s32.val[1], shift_s32);
+ in_s32.val[2] = vshlq_s32(in_s32.val[2], shift_s32);
+ in_s32.val[3] = vshlq_s32(in_s32.val[3], shift_s32);
+
+ // Convert S32 to U16
+ const int16x8x2_t in_u16 =
+ {
+ {
+ vcombine_s16(vqmovn_s32(in_s32.val[0]), vqmovn_s32(in_s32.val[1])),
+ vcombine_s16(vqmovn_s32(in_s32.val[2]), vqmovn_s32(in_s32.val[3]))
+ }
+ };
+
+ // Convert U16 to U8
+ const uint8x16_t out_u8 = vcombine_u8(vqmovun_s16(in_u16.val[0]), vqmovun_s16(in_u16.val[1]));
+
+ vst1q_u8(out.ptr(), out_u8);
+ },
+ mm_result, out);
+ }
+}
+
+NEGEMMLowpFinalizeKernel::NEGEMMLowpFinalizeKernel()
+ : _func(nullptr), _vector_sum_col(nullptr), _vector_sum_row(nullptr), _mm_result(nullptr), _output(nullptr), _a_offset(0), _b_offset(0), _c_offset(0), _k_offset(0), _c_mult_int(0), _shift(0),
+ _slide_vector_sum_col(true)
+{
+}
+
+void NEGEMMLowpFinalizeKernel::configure(const ITensor *vector_sum_col, const ITensor *vector_sum_row, const ITensor *mm_result, ITensor *output, int32_t num_mtx_a_cols, int32_t a_offset,
+ int32_t b_offset,
+ int32_t c_offset, int32_t c_mult_int, int32_t shift)
+{
+ ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(mm_result, 1, DataType::S32);
+ ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::U8);
+
+ TensorShape mm_result_shape = mm_result->info()->tensor_shape();
+ TensorShape output_shape = output->info()->tensor_shape();
+
+ mm_result_shape.collapse(2);
+ output_shape.collapse(2);
+
+ ARM_COMPUTE_ERROR_ON_MSG(mm_result_shape[2] != output_shape[2], "mm_result tensor must have the same number of batches of output tensor");
+
+ // If a_offset == 0, vector_sum_col can be a nullptr
+ if(a_offset != 0)
+ {
+ ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(vector_sum_col, 1, DataType::S32);
+ ARM_COMPUTE_ERROR_ON(vector_sum_col->info()->dimension(0) != mm_result->info()->dimension(0));
+
+ TensorShape vector_sum_col_shape = vector_sum_col->info()->tensor_shape();
+ vector_sum_col_shape.collapse(1);
+
+ // Check if vector_sum_col_shape should be slidden or not
+ // Don't slide vector_sum_col_shape along the y dimension if vector_sum_col_shape has just 1 dimension and vector_sum_row_shape more than 1
+ // This scenario can happen when the the matrix multiplication is used to perform a convolution operation
+ _slide_vector_sum_col = vector_sum_col_shape[1] != 1;
+ }
+
+ // If b_offset == 0, vector_sum_row can be a nullptr
+ if(b_offset != 0)
+ {
+ ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(vector_sum_row, 1, DataType::S32);
+ ARM_COMPUTE_ERROR_ON(vector_sum_row->info()->dimension(0) != mm_result->info()->dimension(1));
+
+ TensorShape vector_sum_row_shape = vector_sum_row->info()->tensor_shape();
+ vector_sum_row_shape.collapse(1);
+
+ ARM_COMPUTE_ERROR_ON_MSG(vector_sum_row_shape[1] != output_shape[2], "mm_result tensor must have the same number of batches of output tensor");
+
+ if(a_offset != 0)
+ {
+ TensorShape vector_sum_col_shape = vector_sum_col->info()->tensor_shape();
+ vector_sum_col_shape.collapse(1);
+
+ ARM_COMPUTE_ERROR_ON_MSG(vector_sum_col_shape[1] != 1
+ && vector_sum_col_shape[1] != vector_sum_row_shape[1],
+ "vector_sum_col tensor must have the same number of batches of vector_sum_row_shape or the number of batches must be set to 1");
+ }
+ }
+
+ _vector_sum_col = vector_sum_col;
+ _vector_sum_row = vector_sum_row;
+ _mm_result = mm_result;
+ _output = output;
+ _a_offset = a_offset;
+ _b_offset = b_offset;
+ _k_offset = a_offset * b_offset * num_mtx_a_cols;
+ _c_offset = c_offset;
+ _c_mult_int = c_mult_int;
+ _shift = shift;
+
+ constexpr unsigned int num_elems_processed_per_iteration = 16;
+
+ // Configure kernel window
+ Window win = calculate_max_window(*output->info(), Steps(num_elems_processed_per_iteration));
+
+ AccessWindowHorizontal mm_result_access(mm_result->info(), 0, num_elems_processed_per_iteration);
+ AccessWindowHorizontal output_result_access(output->info(), 0, num_elems_processed_per_iteration);
+
+ // Accordingly with a_offset and b_offset, we can have 4 cases:
+ // a_offset != 0 && b_offset != 0
+ // a_offset = 0 && b_offset != 0
+ // a_offset != 0 && b_offset = 0
+ // a_offset = 0 && b_offset = 0
+ if(a_offset != 0 && b_offset != 0)
+ {
+ // Set the function to use
+ _func = &NEGEMMLowpFinalizeKernel::finalize<true, true>;
+
+ AccessWindowStatic vector_sum_row_access(vector_sum_row->info(), 0, 0, vector_sum_row->info()->dimension(0), 0);
+ AccessWindowHorizontal vector_sum_col_access(vector_sum_col->info(), 0, num_elems_processed_per_iteration);
+
+ update_window_and_padding(win,
+ vector_sum_col_access,
+ vector_sum_row_access,
+ mm_result_access,
+ output_result_access);
+ }
+ else if(a_offset == 0 && b_offset != 0)
+ {
+ // Set the function to use
+ _func = &NEGEMMLowpFinalizeKernel::finalize<false, true>;
+
+ AccessWindowStatic vector_sum_row_access(vector_sum_row->info(), 0, 0, vector_sum_row->info()->dimension(0), 0);
+
+ update_window_and_padding(win,
+ vector_sum_row_access,
+ mm_result_access,
+ output_result_access);
+ }
+ else if(a_offset != 0 && b_offset == 0)
+ {
+ // Set the function to use
+ _func = &NEGEMMLowpFinalizeKernel::finalize<true, false>;
+
+ AccessWindowHorizontal vector_sum_col_access(vector_sum_col->info(), 0, num_elems_processed_per_iteration);
+
+ update_window_and_padding(win,
+ vector_sum_col_access,
+ mm_result_access,
+ output_result_access);
+ }
+ else
+ {
+ // Set the function to use
+ _func = &NEGEMMLowpFinalizeKernel::finalize<false, false>;
+
+ update_window_and_padding(win,
+ mm_result_access,
+ output_result_access);
+ }
+
+ output_result_access.set_valid_region(win, ValidRegion(Coordinates(0, 0), output->info()->tensor_shape()));
+
+ INEKernel::configure(win);
+}
+
+void NEGEMMLowpFinalizeKernel::run(const Window &window, const ThreadInfo &info)
+{
+ ARM_COMPUTE_UNUSED(info);
+ ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
+ ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
+
+ (this->*_func)(window);
+}
diff --git a/src/core/NEON/kernels/NEGEMMLowpMatrixMultiplyKernel.cpp b/src/core/NEON/kernels/NEGEMMLowpMatrixMultiplyKernel.cpp
index cbba4461a2..3e614a8bfc 100644
--- a/src/core/NEON/kernels/NEGEMMLowpMatrixMultiplyKernel.cpp
+++ b/src/core/NEON/kernels/NEGEMMLowpMatrixMultiplyKernel.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2016, 2017 ARM Limited.
+ * Copyright (c) 2017 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -23,6 +23,7 @@
*/
#include "arm_compute/core/NEON/kernels/NEGEMMLowpMatrixMultiplyKernel.h"
+#include "arm_compute/core/AccessWindowStatic.h"
#include "arm_compute/core/Error.h"
#include "arm_compute/core/Helpers.h"
#include "arm_compute/core/ITensor.h"
@@ -45,35 +46,43 @@ class Coordinates;
} // namespace arm_compute
NEGEMMLowpMatrixMultiplyKernel::NEGEMMLowpMatrixMultiplyKernel()
- : _input0(nullptr), _input1(nullptr), _output(nullptr), _a_offset(0), _b_offset(0), _output_offset(0), _output_mult_int(0), _shift(0)
+ : _input0(nullptr), _input1(nullptr), _output(nullptr), _slide_matrix_b(true)
{
}
-void NEGEMMLowpMatrixMultiplyKernel::configure(const ITensor *input0, const ITensor *input1, ITensor *output,
- int32_t a_offset, int32_t b_offset, int32_t output_offset, int32_t output_mult_int, int32_t shift)
+void NEGEMMLowpMatrixMultiplyKernel::configure(const ITensor *input0, const ITensor *input1, ITensor *output)
{
ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input0, 1, DataType::U8);
- ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input1, 1, DataType::U8);
- ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::U8);
- ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input0, input1, output);
-
- _input0 = input0;
- _input1 = input1;
- _output = output;
- _a_offset = a_offset;
- _b_offset = b_offset;
- _output_offset = output_offset;
- _output_mult_int = output_mult_int;
- _shift = shift;
+ ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input0, input1);
+ ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::S32);
+
+ // Check if matrix B should be slidden or not
+ // Don't slide matrix B along the z dimension if matrix B has just 2 dimensions and matrix A more than 2
+ // This scenario can happen when the the matrix multiplication is used to perform a convolution operation
+ TensorShape in0_shape = input0->info()->tensor_shape();
+ TensorShape in1_shape = input1->info()->tensor_shape();
+ TensorShape out_shape = output->info()->tensor_shape();
+
+ in0_shape.collapse(2);
+ in1_shape.collapse(2);
+ out_shape.collapse(2);
+
+ ARM_COMPUTE_ERROR_ON_MSG(in0_shape[2] != out_shape[2], "Output tensor must have the same number of batches of input0 tensor");
+ ARM_COMPUTE_ERROR_ON_MSG(in1_shape[2] != 1 && in0_shape[2] != in1_shape[2], "Input1 tensor must have the same number of batches of input0 or the number of batches must be set to 1");
+
+ _input0 = input0;
+ _input1 = input1;
+ _output = output;
+ _slide_matrix_b = in1_shape[2] != 1;
constexpr unsigned int num_elems_processed_per_iteration_x = 16;
constexpr unsigned int num_elems_processed_per_iteration_y = 4;
Window win = calculate_max_window(*output->info(), Steps(num_elems_processed_per_iteration_x, num_elems_processed_per_iteration_y));
- AccessWindowRectangle output_access(output->info(), 0, 0, num_elems_processed_per_iteration_x, num_elems_processed_per_iteration_y);
- AccessWindowHorizontal in0_access(input0->info(), 0, num_elems_processed_per_iteration_x);
+ AccessWindowStatic in0_access(input0->info(), 0, 0, ceil_to_multiple(input0->info()->dimension(0), 8), input0->info()->dimension(1));
AccessWindowHorizontal in1_access(input1->info(), 0, num_elems_processed_per_iteration_x);
+ AccessWindowRectangle output_access(output->info(), 0, 0, num_elems_processed_per_iteration_x, num_elems_processed_per_iteration_y);
update_window_and_padding(win, in0_access, in1_access, output_access);
@@ -88,337 +97,145 @@ void NEGEMMLowpMatrixMultiplyKernel::run(const Window &window, const ThreadInfo
ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
const size_t in_b_stride = _input1->info()->strides_in_bytes()[1];
- const size_t out_stride = _output->info()->strides_in_bytes()[1];
+ const size_t out_stride = _output->info()->strides_in_bytes()[1] / _output->info()->element_size();
- /* Set step_x and step_y for matrix A. Scale by a factor of 4 the Y range as the input interleaved matrix A has 4 times less the rows of the output matrix */
+ // Set step_x and step_y for matrix A. Scale by a factor of 4 the Y range as the input interleaved matrix A has 4 times less the rows of the output matrix
Window win_a(window);
win_a.set(Window::DimX, Window::Dimension(0, 0, 0));
- win_a.set(Window::DimY, Window::Dimension(window.y().start() >> 2, window.y().end() >> 2, 1));
+ win_a.set(Window::DimY, Window::Dimension(window.y().start() / 4, window.y().end() / 4, 1));
- /* Set step_x and step_y for matrix B. Scale by a factor of 16 the X range as the input transposed matrix A has 16 times less the cols of the output matrix */
- Window win_b(window);
- win_b.set(Window::DimX, Window::Dimension(window.x().start() >> 4, window.x().end() >> 4, in_b_stride));
+ // Set step_x and step_y for matrix B. Scale by a factor of 16 the X range as the input transposed matrix A has 16 times less the columns of the output matrix
+ Window win_b;
+ // Don't slice matrix B along the z dimension if matrix B has just 2 dimensions and matrix A more than 2
+ // This scenario can happen when the the matrix multiplication is used to perform a convolution operation
+ if(_slide_matrix_b)
+ {
+ win_b = window;
+ }
+ win_b.set(Window::DimX, Window::Dimension(window.x().start() / 16, window.x().end() / 16, in_b_stride));
win_b.set(Window::DimY, Window::Dimension(0, 0, 0));
- /* The step x and step y for the output matrix has been already set using in configure() */
+ // The step x and step y for the output matrix has been already set using in configure()
Iterator ina(_input0, win_a);
Iterator inb(_input1, win_b);
Iterator out(_output, window);
- const int32x4_t voffset_a = vdupq_n_s32(_a_offset);
- const int32x4_t voffset_b = vdupq_n_s32(_b_offset);
- const int32x4_t vshiftr = vdupq_n_s32(-_shift);
-
const int width_b = _input1->info()->dimension(0);
// The implementation assumes that the matrix A and Matrix B have been reshaped respectively with NEGEMMInterleave4x4 and NEGEMMTranspose1xW
// The reshaping of the matrices helps to have a cache friendly implementation and helps to avoid the data re-arrangements needed for computing 16x4 elements per iteration
// All the values needed for computing a single 4x4 block will be read from consecutive memory positions
- execute_window_loop(window, [&](const Coordinates &)
+ execute_window_loop(window, [&](const Coordinates & id)
{
const uint8_t *mtx_a0 = ina.ptr();
const uint8_t *mtx_b0 = inb.ptr();
+ // Note: Since the input are all positives, we can use uint32_t
// Accumulators for the block 0
- int32x4x4_t c0 =
+ uint32x4x4_t c0 =
{
{
- vdupq_n_s32(_output_offset),
- vdupq_n_s32(_output_offset),
- vdupq_n_s32(_output_offset),
- vdupq_n_s32(_output_offset)
+ vdupq_n_u32(0),
+ vdupq_n_u32(0),
+ vdupq_n_u32(0),
+ vdupq_n_u32(0)
}
};
// Accumulators for the block 1
- int32x4x4_t c1 =
+ uint32x4x4_t c1 =
{
{
- vdupq_n_s32(_output_offset),
- vdupq_n_s32(_output_offset),
- vdupq_n_s32(_output_offset),
- vdupq_n_s32(_output_offset)
+ vdupq_n_u32(0),
+ vdupq_n_u32(0),
+ vdupq_n_u32(0),
+ vdupq_n_u32(0)
}
};
// Accumulators for the block 2
- int32x4x4_t c2 =
+ uint32x4x4_t c2 =
{
{
- vdupq_n_s32(_output_offset),
- vdupq_n_s32(_output_offset),
- vdupq_n_s32(_output_offset),
- vdupq_n_s32(_output_offset)
+ vdupq_n_u32(0),
+ vdupq_n_u32(0),
+ vdupq_n_u32(0),
+ vdupq_n_u32(0)
}
};
// Accumulators for the block 3
- int32x4x4_t c3 =
+ uint32x4x4_t c3 =
{
{
- vdupq_n_s32(_output_offset),
- vdupq_n_s32(_output_offset),
- vdupq_n_s32(_output_offset),
- vdupq_n_s32(_output_offset)
+ vdupq_n_u32(0),
+ vdupq_n_u32(0),
+ vdupq_n_u32(0),
+ vdupq_n_u32(0)
}
};
- int k = 0;
- // This for loop performs 4 accumulations per iteration
- for(; k <= (width_b - 64); k += 64, mtx_a0 += 16, mtx_b0 += 64)
+ for(int k = 0; k < width_b; k += 16, mtx_a0 += 4, mtx_b0 += 16)
{
- const uint8x8_t p00 = vld1_u8(mtx_a0 + 0);
- const uint8x8_t p01 = vld1_u8(mtx_a0 + 8);
- const uint8x8_t q00l = vld1_u8(mtx_b0 + 0);
- const uint8x8_t q00h = vld1_u8(mtx_b0 + 8);
- const uint8x8_t q01l = vld1_u8(mtx_b0 + 16);
- const uint8x8_t q01h = vld1_u8(mtx_b0 + 24);
- const uint8x8_t q02l = vld1_u8(mtx_b0 + 32);
- const uint8x8_t q02h = vld1_u8(mtx_b0 + 40);
- const uint8x8_t q03l = vld1_u8(mtx_b0 + 48);
- const uint8x8_t q03h = vld1_u8(mtx_b0 + 56);
-
- const int32x4_t ia0l = vaddw_s16(voffset_a, vreinterpret_s16_u16(vget_low_u16(vmovl_u8(p00))));
- const int32x4_t ia0h = vaddw_s16(voffset_a, vreinterpret_s16_u16(vget_high_u16(vmovl_u8(p00))));
- const int32x4_t ia1l = vaddw_s16(voffset_a, vreinterpret_s16_u16(vget_low_u16(vmovl_u8(p01))));
- const int32x4_t ia1h = vaddw_s16(voffset_a, vreinterpret_s16_u16(vget_high_u16(vmovl_u8(p01))));
-
- const int32x2x4_t ia0 =
- {
- {
- vget_low_s32(ia0l),
- vget_high_s32(ia0l),
- vget_low_s32(ia0h),
- vget_high_s32(ia0h)
- }
- };
-
- const int32x2x4_t ia1 =
- {
- {
- vget_low_s32(ia1l),
- vget_high_s32(ia1l),
- vget_low_s32(ia1h),
- vget_high_s32(ia1h)
- }
- };
-
- const int32x4x4_t ib0 =
- {
- {
- vaddw_s16(voffset_b, vreinterpret_s16_u16(vget_low_u16(vmovl_u8(q00l)))),
- vaddw_s16(voffset_b, vreinterpret_s16_u16(vget_high_u16(vmovl_u8(q00l)))),
- vaddw_s16(voffset_b, vreinterpret_s16_u16(vget_low_u16(vmovl_u8(q00h)))),
- vaddw_s16(voffset_b, vreinterpret_s16_u16(vget_high_u16(vmovl_u8(q00h))))
- }
- };
+ const uint8x8_t a00_u8 = vld1_u8(mtx_a0);
+ const uint8x16_t b00_u8 = vld1q_u8(mtx_b0);
- const int32x4x4_t ib1 =
- {
- {
- vaddw_s16(voffset_b, vreinterpret_s16_u16(vget_low_u16(vmovl_u8(q01l)))),
- vaddw_s16(voffset_b, vreinterpret_s16_u16(vget_high_u16(vmovl_u8(q01l)))),
- vaddw_s16(voffset_b, vreinterpret_s16_u16(vget_low_u16(vmovl_u8(q01h)))),
- vaddw_s16(voffset_b, vreinterpret_s16_u16(vget_high_u16(vmovl_u8(q01h))))
- }
- };
+ // Convert a00_u8 to uint16_t and get the lower part
+ const uint16x4_t a00_u16 = vget_low_u16(vmovl_u8(a00_u8));
- const int32x4x4_t ib2 =
+ // Convert b00_u8 to int16_t
+ const uint16x4x4_t b00_u16 =
{
{
- vaddw_s16(voffset_b, vreinterpret_s16_u16(vget_low_u16(vmovl_u8(q02l)))),
- vaddw_s16(voffset_b, vreinterpret_s16_u16(vget_high_u16(vmovl_u8(q02l)))),
- vaddw_s16(voffset_b, vreinterpret_s16_u16(vget_low_u16(vmovl_u8(q02h)))),
- vaddw_s16(voffset_b, vreinterpret_s16_u16(vget_high_u16(vmovl_u8(q02h))))
- }
- };
-
- const int32x4x4_t ib3 =
- {
- {
- vaddw_s16(voffset_b, vreinterpret_s16_u16(vget_low_u16(vmovl_u8(q03l)))),
- vaddw_s16(voffset_b, vreinterpret_s16_u16(vget_high_u16(vmovl_u8(q03l)))),
- vaddw_s16(voffset_b, vreinterpret_s16_u16(vget_low_u16(vmovl_u8(q03h)))),
- vaddw_s16(voffset_b, vreinterpret_s16_u16(vget_high_u16(vmovl_u8(q03h))))
- }
- };
-
- // 4x4 block 0 - Accumulation 0
- c0.val[0] = vmlaq_lane_s32(c0.val[0], ib0.val[0], ia0.val[0], 0);
- c0.val[1] = vmlaq_lane_s32(c0.val[1], ib0.val[0], ia0.val[0], 1);
- c0.val[2] = vmlaq_lane_s32(c0.val[2], ib0.val[0], ia0.val[1], 0);
- c0.val[3] = vmlaq_lane_s32(c0.val[3], ib0.val[0], ia0.val[1], 1);
- // 4x4 block 0 - Accumulation 1
- c0.val[0] = vmlaq_lane_s32(c0.val[0], ib1.val[0], ia0.val[2], 0);
- c0.val[1] = vmlaq_lane_s32(c0.val[1], ib1.val[0], ia0.val[2], 1);
- c0.val[2] = vmlaq_lane_s32(c0.val[2], ib1.val[0], ia0.val[3], 0);
- c0.val[3] = vmlaq_lane_s32(c0.val[3], ib1.val[0], ia0.val[3], 1);
- // 4x4 block 0 - Accumulation 2
- c0.val[0] = vmlaq_lane_s32(c0.val[0], ib2.val[0], ia1.val[0], 0);
- c0.val[1] = vmlaq_lane_s32(c0.val[1], ib2.val[0], ia1.val[0], 1);
- c0.val[2] = vmlaq_lane_s32(c0.val[2], ib2.val[0], ia1.val[1], 0);
- c0.val[3] = vmlaq_lane_s32(c0.val[3], ib2.val[0], ia1.val[1], 1);
- // 4x4 block 0 - Accumulation 3
- c0.val[0] = vmlaq_lane_s32(c0.val[0], ib3.val[0], ia1.val[2], 0);
- c0.val[1] = vmlaq_lane_s32(c0.val[1], ib3.val[0], ia1.val[2], 1);
- c0.val[2] = vmlaq_lane_s32(c0.val[2], ib3.val[0], ia1.val[3], 0);
- c0.val[3] = vmlaq_lane_s32(c0.val[3], ib3.val[0], ia1.val[3], 1);
-
- // 4x4 block 1 - Accumulation 0
- c1.val[0] = vmlaq_lane_s32(c1.val[0], ib0.val[1], ia0.val[0], 0);
- c1.val[1] = vmlaq_lane_s32(c1.val[1], ib0.val[1], ia0.val[0], 1);
- c1.val[2] = vmlaq_lane_s32(c1.val[2], ib0.val[1], ia0.val[1], 0);
- c1.val[3] = vmlaq_lane_s32(c1.val[3], ib0.val[1], ia0.val[1], 1);
- // 4x4 block 1 - Accumulation 1
- c1.val[0] = vmlaq_lane_s32(c1.val[0], ib1.val[1], ia0.val[2], 0);
- c1.val[1] = vmlaq_lane_s32(c1.val[1], ib1.val[1], ia0.val[2], 1);
- c1.val[2] = vmlaq_lane_s32(c1.val[2], ib1.val[1], ia0.val[3], 0);
- c1.val[3] = vmlaq_lane_s32(c1.val[3], ib1.val[1], ia0.val[3], 1);
- // 4x4 block 1 - Accumulation 2
- c1.val[0] = vmlaq_lane_s32(c1.val[0], ib2.val[1], ia1.val[0], 0);
- c1.val[1] = vmlaq_lane_s32(c1.val[1], ib2.val[1], ia1.val[0], 1);
- c1.val[2] = vmlaq_lane_s32(c1.val[2], ib2.val[1], ia1.val[1], 0);
- c1.val[3] = vmlaq_lane_s32(c1.val[3], ib2.val[1], ia1.val[1], 1);
- // 4x4 block 1 - Accumulation 3
- c1.val[0] = vmlaq_lane_s32(c1.val[0], ib3.val[1], ia1.val[2], 0);
- c1.val[1] = vmlaq_lane_s32(c1.val[1], ib3.val[1], ia1.val[2], 1);
- c1.val[2] = vmlaq_lane_s32(c1.val[2], ib3.val[1], ia1.val[3], 0);
- c1.val[3] = vmlaq_lane_s32(c1.val[3], ib3.val[1], ia1.val[3], 1);
-
- // 4x4 block 2 - Accumulation 0
- c2.val[0] = vmlaq_lane_s32(c2.val[0], ib0.val[2], ia0.val[0], 0);
- c2.val[1] = vmlaq_lane_s32(c2.val[1], ib0.val[2], ia0.val[0], 1);
- c2.val[2] = vmlaq_lane_s32(c2.val[2], ib0.val[2], ia0.val[1], 0);
- c2.val[3] = vmlaq_lane_s32(c2.val[3], ib0.val[2], ia0.val[1], 1);
- // 4x4 block 2 - Accumulation 1
- c2.val[0] = vmlaq_lane_s32(c2.val[0], ib1.val[2], ia0.val[2], 0);
- c2.val[1] = vmlaq_lane_s32(c2.val[1], ib1.val[2], ia0.val[2], 1);
- c2.val[2] = vmlaq_lane_s32(c2.val[2], ib1.val[2], ia0.val[3], 0);
- c2.val[3] = vmlaq_lane_s32(c2.val[3], ib1.val[2], ia0.val[3], 1);
- // 4x4 block 2 - Accumulation 2
- c2.val[0] = vmlaq_lane_s32(c2.val[0], ib2.val[2], ia1.val[0], 0);
- c2.val[1] = vmlaq_lane_s32(c2.val[1], ib2.val[2], ia1.val[0], 1);
- c2.val[2] = vmlaq_lane_s32(c2.val[2], ib2.val[2], ia1.val[1], 0);
- c2.val[3] = vmlaq_lane_s32(c2.val[3], ib2.val[2], ia1.val[1], 1);
- // 4x4 block 2 - Accumulation 3
- c2.val[0] = vmlaq_lane_s32(c2.val[0], ib3.val[2], ia1.val[2], 0);
- c2.val[1] = vmlaq_lane_s32(c2.val[1], ib3.val[2], ia1.val[2], 1);
- c2.val[2] = vmlaq_lane_s32(c2.val[2], ib3.val[2], ia1.val[3], 0);
- c2.val[3] = vmlaq_lane_s32(c2.val[3], ib3.val[2], ia1.val[3], 1);
-
- // 4x4 block 3 - Accumulation 0
- c3.val[0] = vmlaq_lane_s32(c3.val[0], ib0.val[3], ia0.val[0], 0);
- c3.val[1] = vmlaq_lane_s32(c3.val[1], ib0.val[3], ia0.val[0], 1);
- c3.val[2] = vmlaq_lane_s32(c3.val[2], ib0.val[3], ia0.val[1], 0);
- c3.val[3] = vmlaq_lane_s32(c3.val[3], ib0.val[3], ia0.val[1], 1);
- // 4x4 block 3 - Accumulation 1
- c3.val[0] = vmlaq_lane_s32(c3.val[0], ib1.val[3], ia0.val[2], 0);
- c3.val[1] = vmlaq_lane_s32(c3.val[1], ib1.val[3], ia0.val[2], 1);
- c3.val[2] = vmlaq_lane_s32(c3.val[2], ib1.val[3], ia0.val[3], 0);
- c3.val[3] = vmlaq_lane_s32(c3.val[3], ib1.val[3], ia0.val[3], 1);
- // 4x4 block 3 - Accumulation 2
- c3.val[0] = vmlaq_lane_s32(c3.val[0], ib2.val[3], ia1.val[0], 0);
- c3.val[1] = vmlaq_lane_s32(c3.val[1], ib2.val[3], ia1.val[0], 1);
- c3.val[2] = vmlaq_lane_s32(c3.val[2], ib2.val[3], ia1.val[1], 0);
- c3.val[3] = vmlaq_lane_s32(c3.val[3], ib2.val[3], ia1.val[1], 1);
- // 4x4 block 3 - Accumulation 3
- c3.val[0] = vmlaq_lane_s32(c3.val[0], ib3.val[3], ia1.val[2], 0);
- c3.val[1] = vmlaq_lane_s32(c3.val[1], ib3.val[3], ia1.val[2], 1);
- c3.val[2] = vmlaq_lane_s32(c3.val[2], ib3.val[3], ia1.val[3], 0);
- c3.val[3] = vmlaq_lane_s32(c3.val[3], ib3.val[3], ia1.val[3], 1);
- }
-
- // This for loop handles the left-over accumulations
- for(; k < width_b; k += 16, mtx_a0 += 4, mtx_b0 += 16)
- {
- const uint8x8_t p00 = vld1_u8(mtx_a0);
- const uint8x8_t q00l = vld1_u8(mtx_b0);
- const uint8x8_t q00h = vld1_u8(mtx_b0 + 8);
-
- const int32x4_t ia0 = vaddw_s16(voffset_a, vreinterpret_s16_u16(vget_low_u16(vmovl_u8(p00))));
-
- const int32x2x2_t ia =
- {
- {
- vget_low_s32(ia0),
- vget_high_s32(ia0)
- }
- };
-
- const int32x4x4_t ib0 =
- {
- {
- vaddw_s16(voffset_b, vreinterpret_s16_u16(vget_low_u16(vmovl_u8(q00l)))),
- vaddw_s16(voffset_b, vreinterpret_s16_u16(vget_high_u16(vmovl_u8(q00l)))),
- vaddw_s16(voffset_b, vreinterpret_s16_u16(vget_low_u16(vmovl_u8(q00h)))),
- vaddw_s16(voffset_b, vreinterpret_s16_u16(vget_high_u16(vmovl_u8(q00h))))
+ vget_low_u16(vmovl_u8(vget_low_u8(b00_u8))),
+ vget_high_u16(vmovl_u8(vget_low_u8(b00_u8))),
+ vget_low_u16(vmovl_u8(vget_high_u8(b00_u8))),
+ vget_high_u16(vmovl_u8(vget_high_u8(b00_u8)))
}
};
// 4x4 block 0
- c0.val[0] = vmlaq_lane_s32(c0.val[0], ib0.val[0], ia.val[0], 0);
- c0.val[1] = vmlaq_lane_s32(c0.val[1], ib0.val[0], ia.val[0], 1);
- c0.val[2] = vmlaq_lane_s32(c0.val[2], ib0.val[0], ia.val[1], 0);
- c0.val[3] = vmlaq_lane_s32(c0.val[3], ib0.val[0], ia.val[1], 1);
+ c0.val[0] = vmlal_lane_u16(c0.val[0], b00_u16.val[0], a00_u16, 0);
+ c0.val[1] = vmlal_lane_u16(c0.val[1], b00_u16.val[1], a00_u16, 0);
+ c0.val[2] = vmlal_lane_u16(c0.val[2], b00_u16.val[2], a00_u16, 0);
+ c0.val[3] = vmlal_lane_u16(c0.val[3], b00_u16.val[3], a00_u16, 0);
// 4x4 block 1
- c1.val[0] = vmlaq_lane_s32(c1.val[0], ib0.val[1], ia.val[0], 0);
- c1.val[1] = vmlaq_lane_s32(c1.val[1], ib0.val[1], ia.val[0], 1);
- c1.val[2] = vmlaq_lane_s32(c1.val[2], ib0.val[1], ia.val[1], 0);
- c1.val[3] = vmlaq_lane_s32(c1.val[3], ib0.val[1], ia.val[1], 1);
+ c1.val[0] = vmlal_lane_u16(c1.val[0], b00_u16.val[0], a00_u16, 1);
+ c1.val[1] = vmlal_lane_u16(c1.val[1], b00_u16.val[1], a00_u16, 1);
+ c1.val[2] = vmlal_lane_u16(c1.val[2], b00_u16.val[2], a00_u16, 1);
+ c1.val[3] = vmlal_lane_u16(c1.val[3], b00_u16.val[3], a00_u16, 1);
// 4x4 block 2
- c2.val[0] = vmlaq_lane_s32(c2.val[0], ib0.val[2], ia.val[0], 0);
- c2.val[1] = vmlaq_lane_s32(c2.val[1], ib0.val[2], ia.val[0], 1);
- c2.val[2] = vmlaq_lane_s32(c2.val[2], ib0.val[2], ia.val[1], 0);
- c2.val[3] = vmlaq_lane_s32(c2.val[3], ib0.val[2], ia.val[1], 1);
+ c2.val[0] = vmlal_lane_u16(c2.val[0], b00_u16.val[0], a00_u16, 2);
+ c2.val[1] = vmlal_lane_u16(c2.val[1], b00_u16.val[1], a00_u16, 2);
+ c2.val[2] = vmlal_lane_u16(c2.val[2], b00_u16.val[2], a00_u16, 2);
+ c2.val[3] = vmlal_lane_u16(c2.val[3], b00_u16.val[3], a00_u16, 2);
// 4x4 block 3
- c3.val[0] = vmlaq_lane_s32(c3.val[0], ib0.val[3], ia.val[0], 0);
- c3.val[1] = vmlaq_lane_s32(c3.val[1], ib0.val[3], ia.val[0], 1);
- c3.val[2] = vmlaq_lane_s32(c3.val[2], ib0.val[3], ia.val[1], 0);
- c3.val[3] = vmlaq_lane_s32(c3.val[3], ib0.val[3], ia.val[1], 1);
+ c3.val[0] = vmlal_lane_u16(c3.val[0], b00_u16.val[0], a00_u16, 3);
+ c3.val[1] = vmlal_lane_u16(c3.val[1], b00_u16.val[1], a00_u16, 3);
+ c3.val[2] = vmlal_lane_u16(c3.val[2], b00_u16.val[2], a00_u16, 3);
+ c3.val[3] = vmlal_lane_u16(c3.val[3], b00_u16.val[3], a00_u16, 3);
}
- c0.val[0] = vshlq_s32(vmulq_n_s32(c0.val[0], _output_mult_int), vshiftr);
- c0.val[1] = vshlq_s32(vmulq_n_s32(c0.val[1], _output_mult_int), vshiftr);
- c0.val[2] = vshlq_s32(vmulq_n_s32(c0.val[2], _output_mult_int), vshiftr);
- c0.val[3] = vshlq_s32(vmulq_n_s32(c0.val[3], _output_mult_int), vshiftr);
-
- c1.val[0] = vshlq_s32(vmulq_n_s32(c1.val[0], _output_mult_int), vshiftr);
- c1.val[1] = vshlq_s32(vmulq_n_s32(c1.val[1], _output_mult_int), vshiftr);
- c1.val[2] = vshlq_s32(vmulq_n_s32(c1.val[2], _output_mult_int), vshiftr);
- c1.val[3] = vshlq_s32(vmulq_n_s32(c1.val[3], _output_mult_int), vshiftr);
-
- c2.val[0] = vshlq_s32(vmulq_n_s32(c2.val[0], _output_mult_int), vshiftr);
- c2.val[1] = vshlq_s32(vmulq_n_s32(c2.val[1], _output_mult_int), vshiftr);
- c2.val[2] = vshlq_s32(vmulq_n_s32(c2.val[2], _output_mult_int), vshiftr);
- c2.val[3] = vshlq_s32(vmulq_n_s32(c2.val[3], _output_mult_int), vshiftr);
-
- c3.val[0] = vshlq_s32(vmulq_n_s32(c3.val[0], _output_mult_int), vshiftr);
- c3.val[1] = vshlq_s32(vmulq_n_s32(c3.val[1], _output_mult_int), vshiftr);
- c3.val[2] = vshlq_s32(vmulq_n_s32(c3.val[2], _output_mult_int), vshiftr);
- c3.val[3] = vshlq_s32(vmulq_n_s32(c3.val[3], _output_mult_int), vshiftr);
-
- const uint8x16x4_t r =
- {
- {
- vcombine_u8(vqmovun_s16(vcombine_s16(vqmovn_s32(c0.val[0]), vqmovn_s32(c1.val[0]))),
- vqmovun_s16(vcombine_s16(vqmovn_s32(c2.val[0]), vqmovn_s32(c3.val[0])))),
- vcombine_u8(vqmovun_s16(vcombine_s16(vqmovn_s32(c0.val[1]), vqmovn_s32(c1.val[1]))),
- vqmovun_s16(vcombine_s16(vqmovn_s32(c2.val[1]), vqmovn_s32(c3.val[1])))),
- vcombine_u8(vqmovun_s16(vcombine_s16(vqmovn_s32(c0.val[2]), vqmovn_s32(c1.val[2]))),
- vqmovun_s16(vcombine_s16(vqmovn_s32(c2.val[2]), vqmovn_s32(c3.val[2])))),
- vcombine_u8(vqmovun_s16(vcombine_s16(vqmovn_s32(c0.val[3]), vqmovn_s32(c1.val[3]))),
- vqmovun_s16(vcombine_s16(vqmovn_s32(c2.val[3]), vqmovn_s32(c3.val[3]))))
- }
- };
-
- uint8_t *const mtx_out = out.ptr();
- vst1q_u8(mtx_out + 0 * out_stride, r.val[0]);
- vst1q_u8(mtx_out + 1 * out_stride, r.val[1]);
- vst1q_u8(mtx_out + 2 * out_stride, r.val[2]);
- vst1q_u8(mtx_out + 3 * out_stride, r.val[3]);
+ auto mtx_out = reinterpret_cast<int32_t *>(out.ptr());
+ vst1q_s32(mtx_out + 0 * out_stride + 0, vreinterpretq_s32_u32(c0.val[0]));
+ vst1q_s32(mtx_out + 0 * out_stride + 4, vreinterpretq_s32_u32(c0.val[1]));
+ vst1q_s32(mtx_out + 0 * out_stride + 8, vreinterpretq_s32_u32(c0.val[2]));
+ vst1q_s32(mtx_out + 0 * out_stride + 12, vreinterpretq_s32_u32(c0.val[3]));
+ vst1q_s32(mtx_out + 1 * out_stride + 0, vreinterpretq_s32_u32(c1.val[0]));
+ vst1q_s32(mtx_out + 1 * out_stride + 4, vreinterpretq_s32_u32(c1.val[1]));
+ vst1q_s32(mtx_out + 1 * out_stride + 8, vreinterpretq_s32_u32(c1.val[2]));
+ vst1q_s32(mtx_out + 1 * out_stride + 12, vreinterpretq_s32_u32(c1.val[3]));
+ vst1q_s32(mtx_out + 2 * out_stride + 0, vreinterpretq_s32_u32(c2.val[0]));
+ vst1q_s32(mtx_out + 2 * out_stride + 4, vreinterpretq_s32_u32(c2.val[1]));
+ vst1q_s32(mtx_out + 2 * out_stride + 8, vreinterpretq_s32_u32(c2.val[2]));
+ vst1q_s32(mtx_out + 2 * out_stride + 12, vreinterpretq_s32_u32(c2.val[3]));
+ vst1q_s32(mtx_out + 3 * out_stride + 0, vreinterpretq_s32_u32(c3.val[0]));
+ vst1q_s32(mtx_out + 3 * out_stride + 4, vreinterpretq_s32_u32(c3.val[1]));
+ vst1q_s32(mtx_out + 3 * out_stride + 8, vreinterpretq_s32_u32(c3.val[2]));
+ vst1q_s32(mtx_out + 3 * out_stride + 12, vreinterpretq_s32_u32(c3.val[3]));
},
ina, inb, out);
}
diff --git a/src/core/NEON/kernels/NEGEMMLowpReductionKernel.cpp b/src/core/NEON/kernels/NEGEMMLowpReductionKernel.cpp
new file mode 100644
index 0000000000..3f841bbf59
--- /dev/null
+++ b/src/core/NEON/kernels/NEGEMMLowpReductionKernel.cpp
@@ -0,0 +1,431 @@
+/*
+ * Copyright (c) 2017 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#include "arm_compute/core/NEON/kernels/NEGEMMLowpReductionKernel.h"
+
+#include "arm_compute/core/AccessWindowStatic.h"
+#include "arm_compute/core/Error.h"
+#include "arm_compute/core/Helpers.h"
+#include "arm_compute/core/ITensor.h"
+#include "arm_compute/core/TensorInfo.h"
+#include "arm_compute/core/Types.h"
+#include "arm_compute/core/Utils.h"
+#include "arm_compute/core/Validate.h"
+#include "arm_compute/core/Window.h"
+
+#include <arm_neon.h>
+#include <cstddef>
+#include <cstdint>
+
+using namespace arm_compute;
+
+namespace arm_compute
+{
+class Coordinates;
+} // namespace arm_compute
+
+INEGEMMLowpReductionKernel::INEGEMMLowpReductionKernel()
+ : _input(), _output(), _k(0), _is_reshaped(false)
+{
+}
+
+void NEGEMMLowpMatrixAReductionKernel::configure(const ITensor *mtx_a_interleaved4x4, ITensor *vector_sum_row, int32_t num_mtx_a_cols, bool is_interleaved4x4)
+{
+ ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(mtx_a_interleaved4x4, 1, DataType::U8);
+ ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(vector_sum_row, 1, DataType::S32);
+
+ _input = mtx_a_interleaved4x4;
+ _output = vector_sum_row;
+ _k = num_mtx_a_cols;
+ _is_reshaped = is_interleaved4x4;
+
+ const unsigned int num_elems_processed_per_iteration = _is_reshaped ? 4 : 1;
+
+ // Configure kernel window
+ Window win = calculate_max_window(*_output->info(), Steps(num_elems_processed_per_iteration));
+
+ AccessWindowStatic input_access(_input->info(), 0, 0, ceil_to_multiple(_input->info()->dimension(0), 16), _input->info()->dimension(1));
+ AccessWindowHorizontal output_access(_output->info(), 0, num_elems_processed_per_iteration);
+
+ update_window_and_padding(win,
+ input_access,
+ output_access);
+
+ output_access.set_valid_region(win, ValidRegion(Coordinates(0, 0), _output->info()->tensor_shape()));
+
+ INEKernel::configure(win);
+}
+
+void NEGEMMLowpMatrixAReductionKernel::run(const Window &window, const ThreadInfo &info)
+{
+ ARM_COMPUTE_UNUSED(info);
+ ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
+ ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
+
+ Window collapsed_window = window.collapse_if_possible(IKernel::window(), Window::DimY);
+
+ Window win_input(collapsed_window);
+ win_input.set(Window::DimX, Window::Dimension(0, 0, 0));
+ win_input.set(Window::DimY, Window::Dimension(0, 0, 0));
+ win_input.set(Window::DimZ, Window::Dimension(0, 0, 0));
+
+ Iterator in(_input, win_input);
+ Iterator out(_output, collapsed_window);
+
+ if(_is_reshaped)
+ {
+ execute_window_loop(collapsed_window, [&](const Coordinates & id)
+ {
+ // Note: Since the input is unsigned char, we can safely use unsigned int for the accumulation
+ uint32x4_t sum_row = vdupq_n_u32(0);
+
+ const uint8_t *matrix_a = in.ptr() + (id.x() / 4) * _input->info()->strides_in_bytes()[1] + id.y() * _input->info()->strides_in_bytes()[2];
+
+#if __arm__
+ asm volatile("PLD [%0, #128*4]" ::"r"(matrix_a));
+#endif /* __arm__ */
+
+ int i = 0;
+ // This for loop performs 4 accumulations
+ for(; i <= (_k - 4); i += 4)
+ {
+ const uint8x16_t a0_u8 = vld1q_u8(matrix_a + i * 4);
+
+ // Convert U8 to U16
+ uint16x4x4_t a0_u16 =
+ {
+ {
+ vget_low_u16(vmovl_u8(vget_low_u8(a0_u8))),
+ vget_high_u16(vmovl_u8(vget_low_u8(a0_u8))),
+ vget_low_u16(vmovl_u8(vget_high_u8(a0_u8))),
+ vget_high_u16(vmovl_u8(vget_high_u8(a0_u8)))
+ }
+ };
+
+ // Accumulate to U16
+ a0_u16.val[0] = vadd_u16(a0_u16.val[0], a0_u16.val[1]);
+ a0_u16.val[0] = vadd_u16(a0_u16.val[0], a0_u16.val[2]);
+ a0_u16.val[0] = vadd_u16(a0_u16.val[0], a0_u16.val[3]);
+
+ // Accumulate to U32
+ sum_row = vaddw_u16(sum_row, a0_u16.val[0]);
+ }
+
+ // This for loop performs the leftover accumulations
+ for(; i < _k; ++i)
+ {
+ const uint8x8_t a0_u8 = vld1_u8(matrix_a + i * 4);
+
+ // Convert U8 to U16
+ const uint16x4_t a0_u16 = vget_low_u16(vmovl_u8(a0_u8));
+
+ // Accumulate to U32
+ sum_row = vaddw_u16(sum_row, a0_u16);
+ }
+
+ auto vector_sum_row = reinterpret_cast<int32_t *>(out.ptr());
+
+ vst1q_s32(vector_sum_row, vreinterpretq_s32_u32(sum_row));
+ },
+ in, out);
+ }
+ else // it is not reshaped
+ {
+ execute_window_loop(collapsed_window, [&](const Coordinates & id)
+ {
+ // Note: Since the input is unsigned char, we can safely use unsigned int for the accumulation
+ uint32x4_t sum_row_s32 = vdupq_n_u32(0);
+ unsigned int sum_row = 0;
+
+ const uint8_t *matrix_a = in.ptr() + id.x() * _input->info()->strides_in_bytes()[1] + +id.y() * _input->info()->strides_in_bytes()[2];
+
+#if __arm__
+ asm volatile("PLD [%0, #128*4]" ::"r"(matrix_a));
+#endif /* __arm__ */
+
+ int i = 0;
+ // This for loop performs 16 accumulations
+ for(; i <= (_k - 16); i += 16)
+ {
+ const uint8x16_t a0_u8 = vld1q_u8(matrix_a + i);
+
+ // Partial accumulations in U16
+ const uint16x8_t tmp_sum0 = vaddl_u8(vget_low_u8(a0_u8), vget_high_u8(a0_u8));
+
+ // Accumulate to U32
+ sum_row_s32 = vaddq_u32(sum_row_s32, vpaddlq_u16(tmp_sum0));
+ }
+
+ // This for loop performs the leftover accumulations
+ for(; i < _k; ++i)
+ {
+ sum_row += static_cast<unsigned int>(matrix_a[i]);
+ }
+
+#if defined(__aarch64__)
+ // Reduction operation available on 64 bit architectures only
+ sum_row += vaddvq_u32(sum_row_s32);
+#else // __aarch64__
+ uint32x2_t tmp = vpadd_u32(vget_high_u32(sum_row_s32), vget_low_u32(sum_row_s32));
+ tmp = vpadd_u32(tmp, tmp);
+
+ sum_row += vget_lane_u32(tmp, 0);
+#endif // __aarch64__
+
+ *(reinterpret_cast<int *>(out.ptr())) = static_cast<int>(sum_row);
+ },
+ in, out);
+ }
+}
+
+void NEGEMMLowpMatrixBReductionKernel::configure(const ITensor *mtx_b_transposed1xW, ITensor *vector_sum_col, int32_t num_mtx_b_rows, bool is_transposed1xW)
+{
+ ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(mtx_b_transposed1xW, 1, DataType::U8);
+ ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(vector_sum_col, 1, DataType::S32);
+
+ _input = mtx_b_transposed1xW;
+ _output = vector_sum_col;
+ _k = num_mtx_b_rows;
+ _is_reshaped = is_transposed1xW;
+
+ constexpr unsigned int num_elems_processed_per_iteration = 16;
+
+ // Configure kernel window
+ Window win = calculate_max_window(*vector_sum_col->info(), Steps(num_elems_processed_per_iteration));
+
+ AccessWindowStatic input_access(_input->info(), 0, 0, ceil_to_multiple(_input->info()->dimension(0), 16), _input->info()->dimension(1));
+ AccessWindowHorizontal output_access(_output->info(), 0, num_elems_processed_per_iteration);
+
+ update_window_and_padding(win,
+ input_access,
+ output_access);
+
+ output_access.set_valid_region(win, ValidRegion(Coordinates(0, 0), _output->info()->tensor_shape()));
+
+ INEKernel::configure(win);
+}
+
+void NEGEMMLowpMatrixBReductionKernel::run(const Window &window, const ThreadInfo &info)
+{
+ ARM_COMPUTE_UNUSED(info);
+ ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
+ ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
+
+ Window collapsed_window = window.collapse_if_possible(IKernel::window(), Window::DimY);
+
+ if(_is_reshaped)
+ {
+ Window win_input(collapsed_window);
+ win_input.set(Window::DimX, Window::Dimension(0, 0, 0));
+ win_input.set(Window::DimY, Window::Dimension(0, 0, 0));
+ win_input.set(Window::DimZ, Window::Dimension(0, 0, 0));
+
+ Iterator in(_input, win_input);
+ Iterator out(_output, collapsed_window);
+
+ execute_window_loop(collapsed_window, [&](const Coordinates & id)
+ {
+ // Note: Since the input is unsigned char, we can safely use unsigned int for the accumulation
+ uint32x4x4_t sum_col =
+ {
+ {
+ vdupq_n_u32(0),
+ vdupq_n_u32(0),
+ vdupq_n_u32(0),
+ vdupq_n_u32(0)
+ }
+ };
+
+ const uint8_t *matrix_b = in.ptr() + (id.x() / 16) * _input->info()->strides_in_bytes()[1] + id.y() * _input->info()->strides_in_bytes()[2];
+
+#if __arm__
+ asm volatile("PLD [%0, #128*4]" ::"r"(matrix_b));
+#endif /* __arm__ */
+
+ int i = 0;
+ for(; i < _k; ++i)
+ {
+ const uint8x16_t b0_u8 = vld1q_u8(matrix_b + i * 16);
+
+ // Convert U8 to U16
+ const uint16x8x2_t b0_u16 =
+ {
+ {
+ vmovl_u8(vget_low_u8(b0_u8)),
+ vmovl_u8(vget_high_u8(b0_u8))
+ }
+ };
+
+ // Accumulate to U32
+ sum_col =
+ {
+ {
+ vaddw_u16(sum_col.val[0], vget_low_u16(b0_u16.val[0])),
+ vaddw_u16(sum_col.val[1], vget_high_u16(b0_u16.val[0])),
+ vaddw_u16(sum_col.val[2], vget_low_u16(b0_u16.val[1])),
+ vaddw_u16(sum_col.val[3], vget_high_u16(b0_u16.val[1]))
+ }
+ };
+ }
+
+ auto vector_sum_col = reinterpret_cast<int32_t *>(out.ptr());
+
+ vst1q_s32(vector_sum_col + 0, vreinterpretq_s32_u32(sum_col.val[0]));
+ vst1q_s32(vector_sum_col + 4, vreinterpretq_s32_u32(sum_col.val[1]));
+ vst1q_s32(vector_sum_col + 8, vreinterpretq_s32_u32(sum_col.val[2]));
+ vst1q_s32(vector_sum_col + 12, vreinterpretq_s32_u32(sum_col.val[3]));
+ },
+ in, out);
+ }
+ else // it is not reshaped
+ {
+ const auto width_matrix_b = static_cast<int>(_input->info()->dimension(0));
+ const auto in_b_stride = static_cast<int>(_input->info()->strides_in_bytes()[1]);
+
+ // The implementation computes 16 elements per iteration
+ const int window_start_x = 16 * info.thread_id;
+ const int window_step_x = 16 * info.num_threads;
+ // Make sure (window_end_x - window_start_x) is a multiple of window_step_x
+ const int window_end_x = ceil_to_multiple(width_matrix_b - window_start_x, window_step_x) + window_start_x;
+
+ Window win_out(collapsed_window);
+ win_out.set(Window::DimX, Window::Dimension(window_start_x, window_end_x, window_step_x));
+
+ Window win_in(win_out);
+ win_in.set(Window::DimY, Window::Dimension(0, 0, 0));
+ win_in.set(Window::DimZ, Window::Dimension(0, 0, 0));
+
+ Iterator inb(_input, win_in);
+ Iterator out(_output, win_out);
+
+ execute_window_loop(win_out, [&](const Coordinates & id)
+ {
+ if(id.x() > width_matrix_b)
+ {
+ return;
+ }
+
+ // Note: Since the input is unsigned char, we can safely use unsigned int for the accumulation
+ uint32x4x4_t sum_col =
+ {
+ {
+ vdupq_n_u32(0),
+ vdupq_n_u32(0),
+ vdupq_n_u32(0),
+ vdupq_n_u32(0)
+ }
+ };
+
+ const uint8_t *matrix_b = inb.ptr() + id.y() * _input->info()->strides_in_bytes()[2];
+
+#if __arm__
+ asm volatile("PLD [%0, #128*4]" ::"r"(matrix_b));
+ asm volatile("PLD [%0, #128*4]" ::"r"(matrix_b + in_b_stride));
+#endif /* __arm__ */
+
+ int i = 0;
+ // This for loop performs 4 accumulations
+ for(; i <= (_k - 4); i += 4)
+ {
+ const uint8x16_t b0_u8 = vld1q_u8(matrix_b + 0 * in_b_stride);
+ const uint8x16_t b1_u8 = vld1q_u8(matrix_b + 1 * in_b_stride);
+ const uint8x16_t b2_u8 = vld1q_u8(matrix_b + 2 * in_b_stride);
+ const uint8x16_t b3_u8 = vld1q_u8(matrix_b + 3 * in_b_stride);
+
+#if __arm__
+ asm volatile("PLD [%0, #128*1]" ::"r"(matrix_b + 1 * in_b_stride));
+ asm volatile("PLD [%0, #128*1]" ::"r"(matrix_b + 2 * in_b_stride));
+ asm volatile("PLD [%0, #128*1]" ::"r"(matrix_b + 3 * in_b_stride));
+ asm volatile("PLD [%0, #128*1]" ::"r"(matrix_b + 4 * in_b_stride));
+#endif /* __arm__ */
+
+ // Partial accumulation in u16
+ uint16x8x2_t tmp_sum =
+ {
+ {
+ vdupq_n_u16(0),
+ vdupq_n_u16(0)
+ }
+ };
+
+ tmp_sum.val[0] = vaddw_u8(tmp_sum.val[0], vget_low_u8(b0_u8));
+ tmp_sum.val[0] = vaddw_u8(tmp_sum.val[0], vget_low_u8(b1_u8));
+ tmp_sum.val[0] = vaddw_u8(tmp_sum.val[0], vget_low_u8(b2_u8));
+ tmp_sum.val[0] = vaddw_u8(tmp_sum.val[0], vget_low_u8(b3_u8));
+ tmp_sum.val[1] = vaddw_u8(tmp_sum.val[1], vget_high_u8(b0_u8));
+ tmp_sum.val[1] = vaddw_u8(tmp_sum.val[1], vget_high_u8(b1_u8));
+ tmp_sum.val[1] = vaddw_u8(tmp_sum.val[1], vget_high_u8(b2_u8));
+ tmp_sum.val[1] = vaddw_u8(tmp_sum.val[1], vget_high_u8(b3_u8));
+
+ // Accumulate to U32
+ sum_col =
+ {
+ {
+ vaddw_u16(sum_col.val[0], vget_low_u16(tmp_sum.val[0])),
+ vaddw_u16(sum_col.val[1], vget_high_u16(tmp_sum.val[0])),
+ vaddw_u16(sum_col.val[2], vget_low_u16(tmp_sum.val[1])),
+ vaddw_u16(sum_col.val[3], vget_high_u16(tmp_sum.val[1]))
+ }
+ };
+
+ matrix_b += 4 * in_b_stride;
+ }
+
+ // This for loop perfoms the leftover accumulations
+ for(; i < _k; ++i)
+ {
+ const uint8x16_t b0_u8 = vld1q_u8(matrix_b + 0 * in_b_stride);
+
+ // Convert U8 to U16
+ const uint16x8x2_t b0_u16 =
+ {
+ {
+ vmovl_u8(vget_low_u8(b0_u8)),
+ vmovl_u8(vget_high_u8(b0_u8))
+ }
+ };
+
+ // Accumulate to U32
+ sum_col =
+ {
+ {
+ vaddw_u16(sum_col.val[0], vget_low_u16(b0_u16.val[0])),
+ vaddw_u16(sum_col.val[1], vget_high_u16(b0_u16.val[0])),
+ vaddw_u16(sum_col.val[2], vget_low_u16(b0_u16.val[1])),
+ vaddw_u16(sum_col.val[3], vget_high_u16(b0_u16.val[1]))
+ }
+ };
+
+ matrix_b += in_b_stride;
+ }
+
+ auto vector_sum_col = reinterpret_cast<int32_t *>(out.ptr());
+
+ vst1q_s32(vector_sum_col + 0, vreinterpretq_s32_u32(sum_col.val[0]));
+ vst1q_s32(vector_sum_col + 4, vreinterpretq_s32_u32(sum_col.val[1]));
+ vst1q_s32(vector_sum_col + 8, vreinterpretq_s32_u32(sum_col.val[2]));
+ vst1q_s32(vector_sum_col + 12, vreinterpretq_s32_u32(sum_col.val[3]));
+ },
+ inb, out);
+ }
+} \ No newline at end of file
diff --git a/src/runtime/NEON/functions/NEGEMMLowp.cpp b/src/runtime/NEON/functions/NEGEMMLowp.cpp
index 12136cbcb5..ab7fa079b1 100644
--- a/src/runtime/NEON/functions/NEGEMMLowp.cpp
+++ b/src/runtime/NEON/functions/NEGEMMLowp.cpp
@@ -31,120 +31,104 @@
#include "arm_compute/core/Types.h"
#include "arm_compute/core/Validate.h"
#include "arm_compute/runtime/NEON/NEScheduler.h"
+#include "arm_compute/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.h"
#include "arm_compute/runtime/TensorAllocator.h"
#include "support/ToolchainSupport.h"
using namespace arm_compute;
-#define NEGEMMLOWP_VALIDATE_DIMENSIONS(a, b, output) \
- ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN((a), 1, DataType::U8); \
- ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN((b), 1, DataType::U8); \
- ARM_COMPUTE_ERROR_ON_MSG((a)->info()->dimension(0) != (b)->info()->dimension(1), "The product AB is defined only if the number of columns in A is equal to the number of rows in B"); \
- ARM_COMPUTE_ERROR_ON_MSG((a)->info()->dimension(1) != (output)->info()->dimension(1), "The C matrix must have the same number of rows as the matrix A"); \
- ARM_COMPUTE_ERROR_ON_MSG((b)->info()->dimension(0) != (output)->info()->dimension(0), "The C matrix must have the same number of columns as the matrix C");
-
NEGEMMLowp::NEGEMMLowp(std::shared_ptr<IMemoryManager> memory_manager)
- : _memory_group(std::move(memory_manager)), _interleave_kernel(), _transpose_kernel(), _mm_kernel(), _mm_optimised_kernel(nullptr), _interleave_blocked(), _interleave_blocked_transposed(), _tmp_a(),
- _tmp_b()
+ : _memory_group(std::move(memory_manager)), _mm_func(), _mtx_a_reduction_kernel(), _mtx_b_reduction_kernel(), _finalize_kernel(), _vector_sum_col(), _vector_sum_row(), _mm_output(), _a_offset(0),
+ _b_offset(0)
{
}
-void NEGEMMLowp::configure(const ITensor *a, const ITensor *b, ITensor *output)
+void NEGEMMLowp::configure(const ITensor *a, const ITensor *b, ITensor *output, int32_t a_offset, int32_t b_offset, int32_t c_offset, int32_t output_mult_int, int32_t shift)
{
- NEGEMMLOWP_VALIDATE_DIMENSIONS(a, b, output);
- ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::U32);
+ ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN((a), 1, DataType::U8);
+ ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(a, b, output);
+ ARM_COMPUTE_ERROR_ON_MSG((a)->info()->dimension(0) != (b)->info()->dimension(1), "The product AB is defined only if the number of columns in A is equal to the number of rows in B");
+ ARM_COMPUTE_ERROR_ON_MSG((a)->info()->dimension(1) != (output)->info()->dimension(1), "The output matrix must have the same number of rows as the matrix A");
+ ARM_COMPUTE_ERROR_ON_MSG((b)->info()->dimension(0) != (output)->info()->dimension(0), "The output matrix must have the same number of columns as the matrix B");
+
+ _a_offset = a_offset;
+ _b_offset = b_offset;
+
+ // Initialize matrix multiply output tensor
+ const TensorShape &shape_mm_output = output->info()->tensor_shape();
+ TensorInfo info_mm_output(shape_mm_output, 1, DataType::S32);
+ _mm_output.allocator()->init(info_mm_output);
+ _memory_group.manage(&_mm_output);
- const struct CPUInfo ci = NEScheduler::get().cpu_info();
- const int cpu_has_dotprod = static_cast<int>(ci.CPU) & static_cast<int>(CPUTarget::DOT);
- if(cpu_has_dotprod != 0)
+ // Initialize Matrix B reduction kernel only if _a_offset is not equal to 0
+ if(_a_offset != 0)
{
-#ifdef ARM_COMPUTE_AARCH64_V8_2
- // NEGEMMLowpAArch64V8P4Kernel only compiled in AArch64 targets
- _mm_optimised_kernel = support::cpp14::make_unique<NEGEMMLowpAArch64V8P4Kernel>();
- TensorShape shape_a_int = a->info()->tensor_shape();
- shape_a_int.set(0, a->info()->dimension(0) * 8.f);
- shape_a_int.set(1, std::ceil(a->info()->dimension(1) / 8.f));
-
- TensorShape shape_b_int = b->info()->tensor_shape();
- shape_b_int.set(0, b->info()->dimension(0) * 12.f);
- shape_b_int.set(1, std::ceil(b->info()->dimension(1) / 12.f));
-
- TensorInfo info_a_int(shape_a_int, 1, a->info()->data_type());
- TensorInfo info_b_int(shape_b_int, 1, b->info()->data_type());
- _tmp_a.allocator()->init(info_a_int);
- _tmp_b.allocator()->init(info_b_int);
-
- _memory_group.manage(&_tmp_a);
- _memory_group.manage(&_tmp_b);
-
- _interleave_blocked.configure(a, &_tmp_a, 8, 4, false);
- _interleave_blocked_transposed.configure(b, &_tmp_b, 12, 4, true);
- _mm_optimised_kernel->configure(&_tmp_a, &_tmp_b, output);
-
- _tmp_a.allocator()->allocate();
- _tmp_b.allocator()->allocate();
-#endif /* ARM_COMPUTE_AARCH64_V8_2 */
+ TensorShape shape_vector_sum_col = b->info()->tensor_shape();
+ shape_vector_sum_col.remove_dimension(1);
+ TensorInfo info_vector_sum_col(shape_vector_sum_col, 1, DataType::S32);
+ _vector_sum_col.allocator()->init(info_vector_sum_col);
+ _memory_group.manage(&_vector_sum_col);
+
+ // Configure Matrix B reduction kernel
+ _mtx_b_reduction_kernel.configure(b, &_vector_sum_col, a->info()->dimension(0), false);
}
- else
+
+ // Initialize Matrix A reduction kernel only if _b_offset is not equal to 0
+ if(_b_offset != 0)
{
- ARM_COMPUTE_ERROR("Not implemented");
- //FIXME: This is in the process of being updated, for more info please refer to COMPMID-624.
+ TensorShape shape_vector_sum_row = a->info()->tensor_shape();
+ shape_vector_sum_row.set(Window::DimX, a->info()->dimension(1));
+ shape_vector_sum_row.remove_dimension(1);
+ TensorInfo info_vector_sum_row(shape_vector_sum_row, 1, DataType::S32);
+ _vector_sum_row.allocator()->init(info_vector_sum_row);
+ _memory_group.manage(&_vector_sum_row);
+
+ // Configure Matrix A reduction kernel
+ _mtx_a_reduction_kernel.configure(a, &_vector_sum_row, a->info()->dimension(0), false);
}
-}
-void NEGEMMLowp::run()
-{
- _memory_group.acquire();
+ // Configure matrix multiply function
+ _mm_func.configure(a, b, &_mm_output);
+
+ // Configure finalize kernel
+ _finalize_kernel.configure(_a_offset == 0 ? nullptr : &_vector_sum_col, _b_offset == 0 ? nullptr : &_vector_sum_row, &_mm_output, output, a->info()->dimension(0), a_offset, b_offset, c_offset,
+ output_mult_int, shift);
- if(_mm_optimised_kernel != nullptr)
+ // Allocate tensors
+ _mm_output.allocator()->allocate();
+
+ if(_a_offset != 0)
{
- NEScheduler::get().schedule(&_interleave_blocked, Window::DimY);
- NEScheduler::get().schedule(&_interleave_blocked_transposed, Window::DimY);
- NEScheduler::get().schedule(_mm_optimised_kernel.get(), Window::DimY);
+ _vector_sum_col.allocator()->allocate();
}
- else
+
+ if(_b_offset != 0)
{
- /* Run interleave kernel */
- NEScheduler::get().schedule(&_interleave_kernel, Window::DimY);
- /* Run transpose kernel */
- NEScheduler::get().schedule(&_transpose_kernel, Window::DimY);
- /* Run matrix multiply kernel */
- NEScheduler::get().schedule(&_mm_kernel, Window::DimY);
+ _vector_sum_row.allocator()->allocate();
}
-
- _memory_group.release();
}
-void NEGEMMLowp::configure(const ITensor *a, const ITensor *b, ITensor *output, int32_t a_offset, int32_t b_offset, int32_t output_offset, int32_t output_mult_int, int32_t shift)
+void NEGEMMLowp::run()
{
- NEGEMMLOWP_VALIDATE_DIMENSIONS(a, b, output);
- ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(a, b, output);
- ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::U8);
-
- /* The interleaved output matrix will have the following shape: [ a_height * 4, ceil(a_width / 4.0f) ] */
- TensorShape shape_tmp_a = a->info()->tensor_shape();
- shape_tmp_a.set(0, a->info()->dimension(0) * 4);
- shape_tmp_a.set(1, std::ceil(a->info()->dimension(1) / 4.f));
-
- TensorShape shape_tmp_b = b->info()->tensor_shape();
- shape_tmp_b.set(0, b->info()->dimension(1) * 16);
- shape_tmp_b.set(1, std::ceil(b->info()->dimension(0) / 16.f));
+ _memory_group.acquire();
- TensorInfo info_a(shape_tmp_a, 1, a->info()->data_type());
- TensorInfo info_b(shape_tmp_b, 1, b->info()->data_type());
- _tmp_a.allocator()->init(info_a);
- _tmp_b.allocator()->init(info_b);
+ // Run matrix A reduction kernel only if _b_offset is not equal to 0
+ if(_b_offset != 0)
+ {
+ NEScheduler::get().schedule(&_mtx_a_reduction_kernel, Window::DimX);
+ }
- // Manage intermediate buffers
- _memory_group.manage(&_tmp_a);
- _memory_group.manage(&_tmp_b);
+ // Run matrix B reduction kernel only if _a_offset is not equal to 0
+ if(_a_offset != 0)
+ {
+ NEScheduler::get().schedule(&_mtx_b_reduction_kernel, Window::DimX);
+ }
- _interleave_kernel.configure(a, &_tmp_a);
- _transpose_kernel.configure(b, &_tmp_b);
- _mm_kernel.configure(&_tmp_a, &_tmp_b, output, a_offset, b_offset, output_offset, output_mult_int, shift);
+ // Run matrix multiply core function
+ _mm_func.run();
- _tmp_a.allocator()->allocate();
- _tmp_b.allocator()->allocate();
-}
+ // Run finalise kernel
+ NEScheduler::get().schedule(&_finalize_kernel, Window::DimY);
-#undef NEGEMMLOWP_VALIDATE_DIMENSIONS
+ _memory_group.release();
+} \ No newline at end of file
diff --git a/src/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.cpp b/src/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.cpp
new file mode 100644
index 0000000000..11ae054e11
--- /dev/null
+++ b/src/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.cpp
@@ -0,0 +1,163 @@
+/*
+ * Copyright (c) 2017 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#include "arm_compute/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.h"
+
+#include "arm_compute/core/Error.h"
+#include "arm_compute/core/Helpers.h"
+#include "arm_compute/core/ITensor.h"
+#include "arm_compute/core/NEON/kernels/NEGEMMInterleave4x4Kernel.h"
+#include "arm_compute/core/NEON/kernels/NEGEMMInterleaveBlockedKernel.h"
+#include "arm_compute/core/NEON/kernels/NEGEMMLowpAssemblyBaseKernel.h"
+#include "arm_compute/core/NEON/kernels/NEGEMMLowpMatrixMultiplyKernel.h"
+#include "arm_compute/core/NEON/kernels/NEGEMMTranspose1xWKernel.h"
+#include "arm_compute/core/NEON/kernels/arm64/NEGEMMLowpAArch64V8P4Kernel.h"
+#include "arm_compute/core/TensorInfo.h"
+#include "arm_compute/core/Types.h"
+#include "arm_compute/core/Validate.h"
+#include "arm_compute/runtime/NEON/NEScheduler.h"
+#include "arm_compute/runtime/TensorAllocator.h"
+#include "support/ToolchainSupport.h"
+
+using namespace arm_compute;
+
+NEGEMMLowpMatrixMultiplyCore::NEGEMMLowpMatrixMultiplyCore(std::shared_ptr<IMemoryManager> memory_manager)
+ : _memory_group(std::move(memory_manager)), _mm_kernel(nullptr), _mtx_a_reshape_kernel(nullptr), _mtx_b_reshape_kernel(nullptr), _tmp_a(), _tmp_b()
+{
+}
+
+void NEGEMMLowpMatrixMultiplyCore::configure(const ITensor *a, const ITensor *b, ITensor *output)
+{
+ ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(a, 1, DataType::U8);
+ ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::S32);
+ ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(a, b);
+ ARM_COMPUTE_ERROR_ON_MSG((a)->info()->dimension(0) != (b)->info()->dimension(1), "The product AB is defined only if the number of columns in A is equal to the number of rows in B");
+ ARM_COMPUTE_ERROR_ON_MSG((a)->info()->dimension(1) != (output)->info()->dimension(1), "The output matrix must have the same number of rows as the matrix A");
+ ARM_COMPUTE_ERROR_ON_MSG((b)->info()->dimension(0) != (output)->info()->dimension(0), "The output matrix must have the same number of columns as the matrix B");
+
+#ifdef ARM_COMPUTE_AARCH64_V8_2
+ // Check for DOT product instruction
+ const struct CPUInfo ci = NEScheduler::get().cpu_info();
+ const int cpu_has_dotprod = static_cast<int>(ci.CPU) & static_cast<int>(CPUTarget::DOT);
+
+ if(cpu_has_dotprod != 0)
+ {
+ TensorShape shape_a_int = a->info()->tensor_shape();
+ shape_a_int.set(0, a->info()->dimension(0) * 8.f);
+ shape_a_int.set(1, std::ceil(a->info()->dimension(1) / 8.f));
+
+ TensorShape shape_b_int = b->info()->tensor_shape();
+ shape_b_int.set(0, b->info()->dimension(0) * 12.f);
+ shape_b_int.set(1, std::ceil(b->info()->dimension(1) / 12.f));
+
+ TensorInfo info_a_int(shape_a_int, 1, a->info()->data_type());
+ TensorInfo info_b_int(shape_b_int, 1, b->info()->data_type());
+ _tmp_a.allocator()->init(info_a_int);
+ _tmp_b.allocator()->init(info_b_int);
+ _memory_group.manage(&_tmp_a);
+ _memory_group.manage(&_tmp_b);
+
+ // Configure interleave blocked kernel for matrix A
+ {
+ auto k = arm_compute::support::cpp14::make_unique<NEGEMMInterleaveBlockedKernel>();
+ k->configure(a, &_tmp_a, 8, 4, false);
+ _mtx_a_reshape_kernel = std::move(k);
+ }
+
+ // Configure interleave blocked kernel for matrix B
+ {
+ auto k = arm_compute::support::cpp14::make_unique<NEGEMMInterleaveBlockedKernel>();
+ k->configure(b, &_tmp_b, 12, 4, true);
+ _mtx_b_reshape_kernel = std::move(k);
+ }
+
+ // Configure matrix multiply kernel
+ {
+ // NEGEMMLowpAArch64V8P4Kernel only compiled in AArch64 targets
+ auto k = arm_compute::support::cpp14::make_unique<NEGEMMLowpAArch64V8P4Kernel>();
+ k->configure(&_tmp_a, &_tmp_b, output);
+ _mm_kernel = std::move(k);
+ }
+ }
+ else
+#endif /* ARM_COMPUTE_AARCH64_V8_2 */
+ {
+ // The interleaved output matrix will have the following shape: [ a_height * 4, ceil(a_width / 4.0f) ]
+ TensorShape shape_tmp_a = a->info()->tensor_shape();
+ shape_tmp_a.set(0, a->info()->dimension(0) * 4);
+ shape_tmp_a.set(1, std::ceil(a->info()->dimension(1) / 4.f));
+
+ // The transpose1xW output matrix will have the following shape: [ b_height * 16, ceil(b_width / 16.0f) ]
+ TensorShape shape_tmp_b = b->info()->tensor_shape();
+ shape_tmp_b.set(0, b->info()->dimension(1) * 16);
+ shape_tmp_b.set(1, std::ceil(b->info()->dimension(0) / 16.f));
+
+ TensorInfo info_a(shape_tmp_a, 1, a->info()->data_type());
+ TensorInfo info_b(shape_tmp_b, 1, b->info()->data_type());
+ _tmp_a.allocator()->init(info_a);
+ _tmp_b.allocator()->init(info_b);
+ _memory_group.manage(&_tmp_a);
+ _memory_group.manage(&_tmp_b);
+
+ // Configure interleave kernel
+ {
+ auto k = arm_compute::support::cpp14::make_unique<NEGEMMInterleave4x4Kernel>();
+ k->configure(a, &_tmp_a);
+ _mtx_a_reshape_kernel = std::move(k);
+ }
+
+ // Configure transpose kernel
+ {
+ auto k = arm_compute::support::cpp14::make_unique<NEGEMMTranspose1xWKernel>();
+ k->configure(b, &_tmp_b);
+ _mtx_b_reshape_kernel = std::move(k);
+ }
+
+ // Configure matrix multiply kernel
+ {
+ auto k = arm_compute::support::cpp14::make_unique<NEGEMMLowpMatrixMultiplyKernel>();
+ k->configure(&_tmp_a, &_tmp_b, output);
+ _mm_kernel = std::move(k);
+ }
+ }
+
+ // Allocate tensors
+ _tmp_a.allocator()->allocate();
+ _tmp_b.allocator()->allocate();
+}
+
+void NEGEMMLowpMatrixMultiplyCore::run()
+{
+ _memory_group.acquire();
+
+ // Run reshape matrix A
+ NEScheduler::get().schedule(_mtx_a_reshape_kernel.get(), Window::DimY);
+
+ // Run reshape matrix B
+ NEScheduler::get().schedule(_mtx_b_reshape_kernel.get(), Window::DimY);
+
+ // Run matrix multiply kernel
+ NEScheduler::get().schedule(_mm_kernel.get(), Window::DimY);
+
+ _memory_group.release();
+} \ No newline at end of file