From afde732eb016f18c781923cf1e6c9edf68f586f7 Mon Sep 17 00:00:00 2001 From: Pablo Tello Date: Tue, 25 Jul 2017 09:19:46 +0100 Subject: COMPMID-421: Added FP16 support in the Neon Locally Connected Layer. Change-Id: I4b52a209a5ce1a7e69494008538ed242b14b5593 Reviewed-on: http://mpd-gerrit.cambridge.arm.com/81520 Tested-by: Kaizen Reviewed-by: Anthony Barbier --- .../NELocallyConnectedMatrixMultiplyKernel.cpp | 149 ++++++++++++++++++++- 1 file changed, 143 insertions(+), 6 deletions(-) (limited to 'src/core/NEON/kernels/NELocallyConnectedMatrixMultiplyKernel.cpp') diff --git a/src/core/NEON/kernels/NELocallyConnectedMatrixMultiplyKernel.cpp b/src/core/NEON/kernels/NELocallyConnectedMatrixMultiplyKernel.cpp index 895799c6ca..2b7b391c43 100644 --- a/src/core/NEON/kernels/NELocallyConnectedMatrixMultiplyKernel.cpp +++ b/src/core/NEON/kernels/NELocallyConnectedMatrixMultiplyKernel.cpp @@ -49,6 +49,126 @@ class Coordinates; namespace { +void vector_matrix_multiply_f16(const ITensor *input0, const ITensor *input1, ITensor *output, const Window &window) +{ +#ifdef ARM_COMPUTE_ENABLE_FP16 + const auto width_matrix_b = static_cast(output->info()->dimension(0)); + const auto in_b_stride = static_cast(input1->info()->strides_in_bytes()[1] / data_size_from_type(input1->info()->data_type())); + const auto num_elems_vec_a = static_cast(input0->info()->dimension(0)); + + // The implementation computes 16 elements per iteration + const int window_start_x = 16 * window.thread_id(); + const int window_step_x = 16 * window.num_threads(); + // Make sure (window_end_x - window_start_x) is a multiple of window_step_x + const int window_end_x = ceil_to_multiple(width_matrix_b - window_start_x, window_step_x) + window_start_x; + + Window win_out(window); + win_out.set(Window::DimX, Window::Dimension(window_start_x, window_end_x, window_step_x)); + + Window win_a(window); + win_a.set(Window::DimX, Window::Dimension(0, 1, 1)); + + Iterator ina(input0, win_a); + Iterator out(output, win_out); + + execute_window_loop(win_out, [&](const Coordinates & id) + { + if(id.x() > width_matrix_b) + { + return; + } + + float16x8_t acc0 = vdupq_n_f16(0.f); + float16x8_t acc1 = vdupq_n_f16(0.f); + float16x8_t acc2 = vdupq_n_f16(0.f); + float16x8_t acc3 = vdupq_n_f16(0.f); + + auto vec_a = reinterpret_cast(ina.ptr()); + auto matrix_b = reinterpret_cast(input1->ptr_to_element(Coordinates(id[0], 0, id[1]))); + + const float16_t *vec_a_end_addr = vec_a + num_elems_vec_a; + + for(; vec_a <= (vec_a_end_addr - 4);) + { + const float16x4_t a0l = vld1_f16(vec_a); + + float16x8_t b00 = vld1q_f16(matrix_b); + float16x8_t b01 = vld1q_f16(matrix_b + 8 + 0 * in_b_stride); + float16x8_t b02 = vld1q_f16(matrix_b + 16 + 0 * in_b_stride); + float16x8_t b03 = vld1q_f16(matrix_b + 24 + 0 * in_b_stride); + + float16x8_t b10 = vld1q_f16(matrix_b + 0 + 1 * in_b_stride); + float16x8_t b11 = vld1q_f16(matrix_b + 8 + 1 * in_b_stride); + float16x8_t b12 = vld1q_f16(matrix_b + 16 + 1 * in_b_stride); + float16x8_t b13 = vld1q_f16(matrix_b + 24 + 1 * in_b_stride); + + acc0 = vaddq_f16(acc0, vmulq_lane_f16(b00, a0l, 0)); + acc1 = vaddq_f16(acc1, vmulq_lane_f16(b01, a0l, 0)); + acc2 = vaddq_f16(acc2, vmulq_lane_f16(b02, a0l, 0)); + acc3 = vaddq_f16(acc3, vmulq_lane_f16(b03, a0l, 0)); + acc0 = vaddq_f16(acc0, vmulq_lane_f16(b10, a0l, 1)); + acc1 = vaddq_f16(acc1, vmulq_lane_f16(b11, a0l, 1)); + acc2 = vaddq_f16(acc2, vmulq_lane_f16(b12, a0l, 1)); + acc3 = vaddq_f16(acc3, vmulq_lane_f16(b13, a0l, 1)); + + matrix_b += 2 * in_b_stride; + + b00 = vld1q_f16(matrix_b); + b01 = vld1q_f16(matrix_b + 8 + 0 * in_b_stride); + b02 = vld1q_f16(matrix_b + 16 + 0 * in_b_stride); + b03 = vld1q_f16(matrix_b + 24 + 0 * in_b_stride); + b10 = vld1q_f16(matrix_b + 0 + 1 * in_b_stride); + b11 = vld1q_f16(matrix_b + 8 + 1 * in_b_stride); + b12 = vld1q_f16(matrix_b + 16 + 1 * in_b_stride); + b13 = vld1q_f16(matrix_b + 24 + 1 * in_b_stride); + + acc0 = vaddq_f16(acc0, vmulq_lane_f16(b00, a0l, 2)); + acc1 = vaddq_f16(acc1, vmulq_lane_f16(b01, a0l, 2)); + acc2 = vaddq_f16(acc2, vmulq_lane_f16(b02, a0l, 2)); + acc3 = vaddq_f16(acc3, vmulq_lane_f16(b03, a0l, 2)); + acc0 = vaddq_f16(acc0, vmulq_lane_f16(b10, a0l, 3)); + acc1 = vaddq_f16(acc1, vmulq_lane_f16(b11, a0l, 3)); + acc2 = vaddq_f16(acc2, vmulq_lane_f16(b12, a0l, 3)); + acc3 = vaddq_f16(acc3, vmulq_lane_f16(b13, a0l, 3)); + + vec_a += 4; + matrix_b += 2 * in_b_stride; + } + + for(; vec_a < vec_a_end_addr;) + { + const float16_t a0 = *vec_a; + const float16x8_t b00 = vld1q_f16(matrix_b); + const float16x8_t b01 = vld1q_f16(matrix_b + 8 + 0 * in_b_stride); + const float16x8_t b02 = vld1q_f16(matrix_b + 16 + 0 * in_b_stride); + const float16x8_t b03 = vld1q_f16(matrix_b + 24 + 0 * in_b_stride); + + acc0 = vaddq_f16(acc0, vmulq_n_f16(b00, a0)); + acc1 = vaddq_f16(acc1, vmulq_n_f16(b01, a0)); + acc2 = vaddq_f16(acc2, vmulq_n_f16(b02, a0)); + acc3 = vaddq_f16(acc3, vmulq_n_f16(b03, a0)); + + vec_a += 1; + matrix_b += in_b_stride; + } + + const auto vec_out = reinterpret_cast(out.ptr()); + + vst1q_f16(vec_out + 0, acc0); + vst1q_f16(vec_out + 8, acc1); + vst1q_f16(vec_out + 16, acc2); + vst1q_f16(vec_out + 24, acc3); + }, + ina, out); +#else /* ARM_COMPUTE_ENABLE_FP16 */ + ARM_COMPUTE_UNUSED(input0); + ARM_COMPUTE_UNUSED(input1); + ARM_COMPUTE_UNUSED(output); + ARM_COMPUTE_UNUSED(window); + ARM_COMPUTE_ERROR("Not supported, recompile with -march=armv8.2-a+fp16+simd."); +#endif /* ARM_COMPUTE_ENABLE_FP16 */ +} + void vector_matrix_multiply_f32(const ITensor *input0, const ITensor *input1, ITensor *output, const Window &window) { const auto width_matrix_b = static_cast(output->info()->dimension(0)); @@ -190,17 +310,17 @@ NELocallyConnectedMatrixMultiplyKernel::NELocallyConnectedMatrixMultiplyKernel() void NELocallyConnectedMatrixMultiplyKernel::configure(const ITensor *input0, const ITensor *input1, ITensor *output) { - ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input0, 1, DataType::F32); - ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input1, 1, DataType::F32); - ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::F32); - ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::F32); + ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input0, 1, DataType::F16, DataType::F32); + ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input1, 1, DataType::F16, DataType::F32); + ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::F16, DataType::F32); + ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::F16, DataType::F32); ARM_COMPUTE_ERROR_ON(input0->info()->dimension(0) != input1->info()->dimension(1)); _input0 = input0; _input1 = input1; _output = output; - unsigned int num_elems_processed_per_iteration_x = 16; + const unsigned int num_elems_processed_per_iteration_x = 16; // Configure kernel window Window win = calculate_max_window(*output->info(), Steps(num_elems_processed_per_iteration_x)); @@ -222,5 +342,22 @@ void NELocallyConnectedMatrixMultiplyKernel::run(const Window &window) ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this); ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window); - vector_matrix_multiply_f32(_input0, _input1, _output, window); + switch(_input0->info()->data_type()) + { + case DataType::F16: + { + vector_matrix_multiply_f16(_input0, _input1, _output, window); + break; + } + case DataType::F32: + { + vector_matrix_multiply_f32(_input0, _input1, _output, window); + break; + } + default: + { + ARM_COMPUTE_ERROR("Data type not supported"); + break; + } + } } -- cgit v1.2.1