aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--arm_compute/core/NEON/kernels/NELocallyConnectedMatrixMultiplyKernel.h2
-rw-r--r--arm_compute/runtime/NEON/functions/NELocallyConnectedLayer.h2
-rw-r--r--src/core/NEON/kernels/NELocallyConnectedMatrixMultiplyKernel.cpp149
-rw-r--r--tests/validation/NEON/DirectConvolutionLayer.cpp4
4 files changed, 147 insertions, 10 deletions
diff --git a/arm_compute/core/NEON/kernels/NELocallyConnectedMatrixMultiplyKernel.h b/arm_compute/core/NEON/kernels/NELocallyConnectedMatrixMultiplyKernel.h
index d4bff661f9..a2c6b51101 100644
--- a/arm_compute/core/NEON/kernels/NELocallyConnectedMatrixMultiplyKernel.h
+++ b/arm_compute/core/NEON/kernels/NELocallyConnectedMatrixMultiplyKernel.h
@@ -46,7 +46,7 @@ public:
NELocallyConnectedMatrixMultiplyKernel &operator=(NELocallyConnectedMatrixMultiplyKernel &&) = default;
/** Initialise the kernel's input and output
*
- * @param[in] input0 First input tensor. Data types supported: F32
+ * @param[in] input0 First input tensor. Data types supported: F16, F32
* @param[in] input1 Second input tensor containing the Matrix B. Data type supported: same as @p input0
* @param[out] output Output tensor to store the result of matrix multiplication. Data type supported: same as @p input0
*/
diff --git a/arm_compute/runtime/NEON/functions/NELocallyConnectedLayer.h b/arm_compute/runtime/NEON/functions/NELocallyConnectedLayer.h
index 1b2b2ee3cf..efb2ff6c8b 100644
--- a/arm_compute/runtime/NEON/functions/NELocallyConnectedLayer.h
+++ b/arm_compute/runtime/NEON/functions/NELocallyConnectedLayer.h
@@ -53,7 +53,7 @@ public:
*
* @param[in] input Source tensor. 3 lower dimensions represent a single input [width, height, IFM],
* while every optional dimension from 4 and above represent a batch of inputs.
- * Data types supported: F32.
+ * Data types supported: F16, F32.
* @param[in] weights Weights tensor. Weights are 5D tensor with dimensions [kernel_x, kernel_y, IFM, OFM, num_patches]. Data type supported:Same as @p input.
* @param[in] biases Biases tensor. Shared biases supported. Biases are 2D tensor with dimensions [OFM, num_patches]. Data type supported:Same as @p input.
* @param[out] output Destination tensor. 3 lower dimensions represent a single output [width, height, OFM], while the rest represent batch of outputs.
diff --git a/src/core/NEON/kernels/NELocallyConnectedMatrixMultiplyKernel.cpp b/src/core/NEON/kernels/NELocallyConnectedMatrixMultiplyKernel.cpp
index 895799c6ca..2b7b391c43 100644
--- a/src/core/NEON/kernels/NELocallyConnectedMatrixMultiplyKernel.cpp
+++ b/src/core/NEON/kernels/NELocallyConnectedMatrixMultiplyKernel.cpp
@@ -49,6 +49,126 @@ class Coordinates;
namespace
{
+void vector_matrix_multiply_f16(const ITensor *input0, const ITensor *input1, ITensor *output, const Window &window)
+{
+#ifdef ARM_COMPUTE_ENABLE_FP16
+ const auto width_matrix_b = static_cast<int>(output->info()->dimension(0));
+ const auto in_b_stride = static_cast<int>(input1->info()->strides_in_bytes()[1] / data_size_from_type(input1->info()->data_type()));
+ const auto num_elems_vec_a = static_cast<int>(input0->info()->dimension(0));
+
+ // The implementation computes 16 elements per iteration
+ const int window_start_x = 16 * window.thread_id();
+ const int window_step_x = 16 * window.num_threads();
+ // Make sure (window_end_x - window_start_x) is a multiple of window_step_x
+ const int window_end_x = ceil_to_multiple(width_matrix_b - window_start_x, window_step_x) + window_start_x;
+
+ Window win_out(window);
+ win_out.set(Window::DimX, Window::Dimension(window_start_x, window_end_x, window_step_x));
+
+ Window win_a(window);
+ win_a.set(Window::DimX, Window::Dimension(0, 1, 1));
+
+ Iterator ina(input0, win_a);
+ Iterator out(output, win_out);
+
+ execute_window_loop(win_out, [&](const Coordinates & id)
+ {
+ if(id.x() > width_matrix_b)
+ {
+ return;
+ }
+
+ float16x8_t acc0 = vdupq_n_f16(0.f);
+ float16x8_t acc1 = vdupq_n_f16(0.f);
+ float16x8_t acc2 = vdupq_n_f16(0.f);
+ float16x8_t acc3 = vdupq_n_f16(0.f);
+
+ auto vec_a = reinterpret_cast<const float16_t *>(ina.ptr());
+ auto matrix_b = reinterpret_cast<const float16_t *>(input1->ptr_to_element(Coordinates(id[0], 0, id[1])));
+
+ const float16_t *vec_a_end_addr = vec_a + num_elems_vec_a;
+
+ for(; vec_a <= (vec_a_end_addr - 4);)
+ {
+ const float16x4_t a0l = vld1_f16(vec_a);
+
+ float16x8_t b00 = vld1q_f16(matrix_b);
+ float16x8_t b01 = vld1q_f16(matrix_b + 8 + 0 * in_b_stride);
+ float16x8_t b02 = vld1q_f16(matrix_b + 16 + 0 * in_b_stride);
+ float16x8_t b03 = vld1q_f16(matrix_b + 24 + 0 * in_b_stride);
+
+ float16x8_t b10 = vld1q_f16(matrix_b + 0 + 1 * in_b_stride);
+ float16x8_t b11 = vld1q_f16(matrix_b + 8 + 1 * in_b_stride);
+ float16x8_t b12 = vld1q_f16(matrix_b + 16 + 1 * in_b_stride);
+ float16x8_t b13 = vld1q_f16(matrix_b + 24 + 1 * in_b_stride);
+
+ acc0 = vaddq_f16(acc0, vmulq_lane_f16(b00, a0l, 0));
+ acc1 = vaddq_f16(acc1, vmulq_lane_f16(b01, a0l, 0));
+ acc2 = vaddq_f16(acc2, vmulq_lane_f16(b02, a0l, 0));
+ acc3 = vaddq_f16(acc3, vmulq_lane_f16(b03, a0l, 0));
+ acc0 = vaddq_f16(acc0, vmulq_lane_f16(b10, a0l, 1));
+ acc1 = vaddq_f16(acc1, vmulq_lane_f16(b11, a0l, 1));
+ acc2 = vaddq_f16(acc2, vmulq_lane_f16(b12, a0l, 1));
+ acc3 = vaddq_f16(acc3, vmulq_lane_f16(b13, a0l, 1));
+
+ matrix_b += 2 * in_b_stride;
+
+ b00 = vld1q_f16(matrix_b);
+ b01 = vld1q_f16(matrix_b + 8 + 0 * in_b_stride);
+ b02 = vld1q_f16(matrix_b + 16 + 0 * in_b_stride);
+ b03 = vld1q_f16(matrix_b + 24 + 0 * in_b_stride);
+ b10 = vld1q_f16(matrix_b + 0 + 1 * in_b_stride);
+ b11 = vld1q_f16(matrix_b + 8 + 1 * in_b_stride);
+ b12 = vld1q_f16(matrix_b + 16 + 1 * in_b_stride);
+ b13 = vld1q_f16(matrix_b + 24 + 1 * in_b_stride);
+
+ acc0 = vaddq_f16(acc0, vmulq_lane_f16(b00, a0l, 2));
+ acc1 = vaddq_f16(acc1, vmulq_lane_f16(b01, a0l, 2));
+ acc2 = vaddq_f16(acc2, vmulq_lane_f16(b02, a0l, 2));
+ acc3 = vaddq_f16(acc3, vmulq_lane_f16(b03, a0l, 2));
+ acc0 = vaddq_f16(acc0, vmulq_lane_f16(b10, a0l, 3));
+ acc1 = vaddq_f16(acc1, vmulq_lane_f16(b11, a0l, 3));
+ acc2 = vaddq_f16(acc2, vmulq_lane_f16(b12, a0l, 3));
+ acc3 = vaddq_f16(acc3, vmulq_lane_f16(b13, a0l, 3));
+
+ vec_a += 4;
+ matrix_b += 2 * in_b_stride;
+ }
+
+ for(; vec_a < vec_a_end_addr;)
+ {
+ const float16_t a0 = *vec_a;
+ const float16x8_t b00 = vld1q_f16(matrix_b);
+ const float16x8_t b01 = vld1q_f16(matrix_b + 8 + 0 * in_b_stride);
+ const float16x8_t b02 = vld1q_f16(matrix_b + 16 + 0 * in_b_stride);
+ const float16x8_t b03 = vld1q_f16(matrix_b + 24 + 0 * in_b_stride);
+
+ acc0 = vaddq_f16(acc0, vmulq_n_f16(b00, a0));
+ acc1 = vaddq_f16(acc1, vmulq_n_f16(b01, a0));
+ acc2 = vaddq_f16(acc2, vmulq_n_f16(b02, a0));
+ acc3 = vaddq_f16(acc3, vmulq_n_f16(b03, a0));
+
+ vec_a += 1;
+ matrix_b += in_b_stride;
+ }
+
+ const auto vec_out = reinterpret_cast<float16_t *>(out.ptr());
+
+ vst1q_f16(vec_out + 0, acc0);
+ vst1q_f16(vec_out + 8, acc1);
+ vst1q_f16(vec_out + 16, acc2);
+ vst1q_f16(vec_out + 24, acc3);
+ },
+ ina, out);
+#else /* ARM_COMPUTE_ENABLE_FP16 */
+ ARM_COMPUTE_UNUSED(input0);
+ ARM_COMPUTE_UNUSED(input1);
+ ARM_COMPUTE_UNUSED(output);
+ ARM_COMPUTE_UNUSED(window);
+ ARM_COMPUTE_ERROR("Not supported, recompile with -march=armv8.2-a+fp16+simd.");
+#endif /* ARM_COMPUTE_ENABLE_FP16 */
+}
+
void vector_matrix_multiply_f32(const ITensor *input0, const ITensor *input1, ITensor *output, const Window &window)
{
const auto width_matrix_b = static_cast<int>(output->info()->dimension(0));
@@ -190,17 +310,17 @@ NELocallyConnectedMatrixMultiplyKernel::NELocallyConnectedMatrixMultiplyKernel()
void NELocallyConnectedMatrixMultiplyKernel::configure(const ITensor *input0, const ITensor *input1, ITensor *output)
{
- ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input0, 1, DataType::F32);
- ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input1, 1, DataType::F32);
- ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::F32);
- ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::F32);
+ ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input0, 1, DataType::F16, DataType::F32);
+ ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input1, 1, DataType::F16, DataType::F32);
+ ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::F16, DataType::F32);
+ ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::F16, DataType::F32);
ARM_COMPUTE_ERROR_ON(input0->info()->dimension(0) != input1->info()->dimension(1));
_input0 = input0;
_input1 = input1;
_output = output;
- unsigned int num_elems_processed_per_iteration_x = 16;
+ const unsigned int num_elems_processed_per_iteration_x = 16;
// Configure kernel window
Window win = calculate_max_window(*output->info(), Steps(num_elems_processed_per_iteration_x));
@@ -222,5 +342,22 @@ void NELocallyConnectedMatrixMultiplyKernel::run(const Window &window)
ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
- vector_matrix_multiply_f32(_input0, _input1, _output, window);
+ switch(_input0->info()->data_type())
+ {
+ case DataType::F16:
+ {
+ vector_matrix_multiply_f16(_input0, _input1, _output, window);
+ break;
+ }
+ case DataType::F32:
+ {
+ vector_matrix_multiply_f32(_input0, _input1, _output, window);
+ break;
+ }
+ default:
+ {
+ ARM_COMPUTE_ERROR("Data type not supported");
+ break;
+ }
+ }
}
diff --git a/tests/validation/NEON/DirectConvolutionLayer.cpp b/tests/validation/NEON/DirectConvolutionLayer.cpp
index 034a8b2045..effb898428 100644
--- a/tests/validation/NEON/DirectConvolutionLayer.cpp
+++ b/tests/validation/NEON/DirectConvolutionLayer.cpp
@@ -150,7 +150,7 @@ BOOST_DATA_TEST_CASE(W1x1,
RawTensor ref = Reference::compute_reference_convolution_layer(input_shape, w_shape, b_shape, d_shape, dt, conv_info, 0);
// Validate output
- validate(NEAccessor(dst), ref);
+ validate(Accessor(dst), ref);
}
BOOST_TEST_DECORATOR(*boost::unit_test::label("precommit"))
@@ -172,7 +172,7 @@ BOOST_DATA_TEST_CASE(W3x3, DirectConvolutionShapes() * boost::unit_test::data::m
RawTensor ref = Reference::compute_reference_convolution_layer(input_shape, w_shape, b_shape, d_shape, dt, conv_info, 0);
// Validate output
- validate(NEAccessor(dst), ref, tolerance_fp16);
+ validate(Accessor(dst), ref, tolerance_fp16);
}
BOOST_AUTO_TEST_SUITE_END()
#endif /* ARM_COMPUTE_ENABLE_FP16 */