aboutsummaryrefslogtreecommitdiff
path: root/src/core/NEON/kernels/NELocallyConnectedMatrixMultiplyKernel.cpp
diff options
context:
space:
mode:
authorPablo Tello <pablo.tello@arm.com>2017-07-25 09:19:46 +0100
committerAnthony Barbier <anthony.barbier@arm.com>2018-09-17 14:16:42 +0100
commitafde732eb016f18c781923cf1e6c9edf68f586f7 (patch)
tree72303ece7d6c0caf2eb6c8d0e5e869c0b4e975de /src/core/NEON/kernels/NELocallyConnectedMatrixMultiplyKernel.cpp
parentf6ad98a95cc4a638e133538ae682185032c16201 (diff)
downloadComputeLibrary-afde732eb016f18c781923cf1e6c9edf68f586f7.tar.gz
COMPMID-421: Added FP16 support in the Neon Locally Connected Layer.
Change-Id: I4b52a209a5ce1a7e69494008538ed242b14b5593 Reviewed-on: http://mpd-gerrit.cambridge.arm.com/81520 Tested-by: Kaizen <jeremy.johnson+kaizengerrit@arm.com> Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
Diffstat (limited to 'src/core/NEON/kernels/NELocallyConnectedMatrixMultiplyKernel.cpp')
-rw-r--r--src/core/NEON/kernels/NELocallyConnectedMatrixMultiplyKernel.cpp149
1 files changed, 143 insertions, 6 deletions
diff --git a/src/core/NEON/kernels/NELocallyConnectedMatrixMultiplyKernel.cpp b/src/core/NEON/kernels/NELocallyConnectedMatrixMultiplyKernel.cpp
index 895799c6ca..2b7b391c43 100644
--- a/src/core/NEON/kernels/NELocallyConnectedMatrixMultiplyKernel.cpp
+++ b/src/core/NEON/kernels/NELocallyConnectedMatrixMultiplyKernel.cpp
@@ -49,6 +49,126 @@ class Coordinates;
namespace
{
+void vector_matrix_multiply_f16(const ITensor *input0, const ITensor *input1, ITensor *output, const Window &window)
+{
+#ifdef ARM_COMPUTE_ENABLE_FP16
+ const auto width_matrix_b = static_cast<int>(output->info()->dimension(0));
+ const auto in_b_stride = static_cast<int>(input1->info()->strides_in_bytes()[1] / data_size_from_type(input1->info()->data_type()));
+ const auto num_elems_vec_a = static_cast<int>(input0->info()->dimension(0));
+
+ // The implementation computes 16 elements per iteration
+ const int window_start_x = 16 * window.thread_id();
+ const int window_step_x = 16 * window.num_threads();
+ // Make sure (window_end_x - window_start_x) is a multiple of window_step_x
+ const int window_end_x = ceil_to_multiple(width_matrix_b - window_start_x, window_step_x) + window_start_x;
+
+ Window win_out(window);
+ win_out.set(Window::DimX, Window::Dimension(window_start_x, window_end_x, window_step_x));
+
+ Window win_a(window);
+ win_a.set(Window::DimX, Window::Dimension(0, 1, 1));
+
+ Iterator ina(input0, win_a);
+ Iterator out(output, win_out);
+
+ execute_window_loop(win_out, [&](const Coordinates & id)
+ {
+ if(id.x() > width_matrix_b)
+ {
+ return;
+ }
+
+ float16x8_t acc0 = vdupq_n_f16(0.f);
+ float16x8_t acc1 = vdupq_n_f16(0.f);
+ float16x8_t acc2 = vdupq_n_f16(0.f);
+ float16x8_t acc3 = vdupq_n_f16(0.f);
+
+ auto vec_a = reinterpret_cast<const float16_t *>(ina.ptr());
+ auto matrix_b = reinterpret_cast<const float16_t *>(input1->ptr_to_element(Coordinates(id[0], 0, id[1])));
+
+ const float16_t *vec_a_end_addr = vec_a + num_elems_vec_a;
+
+ for(; vec_a <= (vec_a_end_addr - 4);)
+ {
+ const float16x4_t a0l = vld1_f16(vec_a);
+
+ float16x8_t b00 = vld1q_f16(matrix_b);
+ float16x8_t b01 = vld1q_f16(matrix_b + 8 + 0 * in_b_stride);
+ float16x8_t b02 = vld1q_f16(matrix_b + 16 + 0 * in_b_stride);
+ float16x8_t b03 = vld1q_f16(matrix_b + 24 + 0 * in_b_stride);
+
+ float16x8_t b10 = vld1q_f16(matrix_b + 0 + 1 * in_b_stride);
+ float16x8_t b11 = vld1q_f16(matrix_b + 8 + 1 * in_b_stride);
+ float16x8_t b12 = vld1q_f16(matrix_b + 16 + 1 * in_b_stride);
+ float16x8_t b13 = vld1q_f16(matrix_b + 24 + 1 * in_b_stride);
+
+ acc0 = vaddq_f16(acc0, vmulq_lane_f16(b00, a0l, 0));
+ acc1 = vaddq_f16(acc1, vmulq_lane_f16(b01, a0l, 0));
+ acc2 = vaddq_f16(acc2, vmulq_lane_f16(b02, a0l, 0));
+ acc3 = vaddq_f16(acc3, vmulq_lane_f16(b03, a0l, 0));
+ acc0 = vaddq_f16(acc0, vmulq_lane_f16(b10, a0l, 1));
+ acc1 = vaddq_f16(acc1, vmulq_lane_f16(b11, a0l, 1));
+ acc2 = vaddq_f16(acc2, vmulq_lane_f16(b12, a0l, 1));
+ acc3 = vaddq_f16(acc3, vmulq_lane_f16(b13, a0l, 1));
+
+ matrix_b += 2 * in_b_stride;
+
+ b00 = vld1q_f16(matrix_b);
+ b01 = vld1q_f16(matrix_b + 8 + 0 * in_b_stride);
+ b02 = vld1q_f16(matrix_b + 16 + 0 * in_b_stride);
+ b03 = vld1q_f16(matrix_b + 24 + 0 * in_b_stride);
+ b10 = vld1q_f16(matrix_b + 0 + 1 * in_b_stride);
+ b11 = vld1q_f16(matrix_b + 8 + 1 * in_b_stride);
+ b12 = vld1q_f16(matrix_b + 16 + 1 * in_b_stride);
+ b13 = vld1q_f16(matrix_b + 24 + 1 * in_b_stride);
+
+ acc0 = vaddq_f16(acc0, vmulq_lane_f16(b00, a0l, 2));
+ acc1 = vaddq_f16(acc1, vmulq_lane_f16(b01, a0l, 2));
+ acc2 = vaddq_f16(acc2, vmulq_lane_f16(b02, a0l, 2));
+ acc3 = vaddq_f16(acc3, vmulq_lane_f16(b03, a0l, 2));
+ acc0 = vaddq_f16(acc0, vmulq_lane_f16(b10, a0l, 3));
+ acc1 = vaddq_f16(acc1, vmulq_lane_f16(b11, a0l, 3));
+ acc2 = vaddq_f16(acc2, vmulq_lane_f16(b12, a0l, 3));
+ acc3 = vaddq_f16(acc3, vmulq_lane_f16(b13, a0l, 3));
+
+ vec_a += 4;
+ matrix_b += 2 * in_b_stride;
+ }
+
+ for(; vec_a < vec_a_end_addr;)
+ {
+ const float16_t a0 = *vec_a;
+ const float16x8_t b00 = vld1q_f16(matrix_b);
+ const float16x8_t b01 = vld1q_f16(matrix_b + 8 + 0 * in_b_stride);
+ const float16x8_t b02 = vld1q_f16(matrix_b + 16 + 0 * in_b_stride);
+ const float16x8_t b03 = vld1q_f16(matrix_b + 24 + 0 * in_b_stride);
+
+ acc0 = vaddq_f16(acc0, vmulq_n_f16(b00, a0));
+ acc1 = vaddq_f16(acc1, vmulq_n_f16(b01, a0));
+ acc2 = vaddq_f16(acc2, vmulq_n_f16(b02, a0));
+ acc3 = vaddq_f16(acc3, vmulq_n_f16(b03, a0));
+
+ vec_a += 1;
+ matrix_b += in_b_stride;
+ }
+
+ const auto vec_out = reinterpret_cast<float16_t *>(out.ptr());
+
+ vst1q_f16(vec_out + 0, acc0);
+ vst1q_f16(vec_out + 8, acc1);
+ vst1q_f16(vec_out + 16, acc2);
+ vst1q_f16(vec_out + 24, acc3);
+ },
+ ina, out);
+#else /* ARM_COMPUTE_ENABLE_FP16 */
+ ARM_COMPUTE_UNUSED(input0);
+ ARM_COMPUTE_UNUSED(input1);
+ ARM_COMPUTE_UNUSED(output);
+ ARM_COMPUTE_UNUSED(window);
+ ARM_COMPUTE_ERROR("Not supported, recompile with -march=armv8.2-a+fp16+simd.");
+#endif /* ARM_COMPUTE_ENABLE_FP16 */
+}
+
void vector_matrix_multiply_f32(const ITensor *input0, const ITensor *input1, ITensor *output, const Window &window)
{
const auto width_matrix_b = static_cast<int>(output->info()->dimension(0));
@@ -190,17 +310,17 @@ NELocallyConnectedMatrixMultiplyKernel::NELocallyConnectedMatrixMultiplyKernel()
void NELocallyConnectedMatrixMultiplyKernel::configure(const ITensor *input0, const ITensor *input1, ITensor *output)
{
- ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input0, 1, DataType::F32);
- ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input1, 1, DataType::F32);
- ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::F32);
- ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::F32);
+ ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input0, 1, DataType::F16, DataType::F32);
+ ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input1, 1, DataType::F16, DataType::F32);
+ ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::F16, DataType::F32);
+ ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::F16, DataType::F32);
ARM_COMPUTE_ERROR_ON(input0->info()->dimension(0) != input1->info()->dimension(1));
_input0 = input0;
_input1 = input1;
_output = output;
- unsigned int num_elems_processed_per_iteration_x = 16;
+ const unsigned int num_elems_processed_per_iteration_x = 16;
// Configure kernel window
Window win = calculate_max_window(*output->info(), Steps(num_elems_processed_per_iteration_x));
@@ -222,5 +342,22 @@ void NELocallyConnectedMatrixMultiplyKernel::run(const Window &window)
ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
- vector_matrix_multiply_f32(_input0, _input1, _output, window);
+ switch(_input0->info()->data_type())
+ {
+ case DataType::F16:
+ {
+ vector_matrix_multiply_f16(_input0, _input1, _output, window);
+ break;
+ }
+ case DataType::F32:
+ {
+ vector_matrix_multiply_f32(_input0, _input1, _output, window);
+ break;
+ }
+ default:
+ {
+ ARM_COMPUTE_ERROR("Data type not supported");
+ break;
+ }
+ }
}