aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorPablo Tello <pablo.tello@arm.com>2017-06-28 17:27:56 +0100
committerAnthony Barbier <anthony.barbier@arm.com>2018-09-17 14:15:39 +0100
commit221f38176b0d4dbc212441779d9bbac3cc0eecfa (patch)
treee838d673b35c5b40d4b484a3645cc7ae3c9d3ecc
parent6410fb2a14427713251f5d97144ac5d4f17c988c (diff)
downloadComputeLibrary-221f38176b0d4dbc212441779d9bbac3cc0eecfa.tar.gz
COMPMID-421: Fixed FP16 support in Neon GEMM.
Fixed GEMM FP16 problem with matrices that are not multiple of 32. Added a new test suite NEON/GEMM/Float16/SmallGEMM. Implemented FP16 function to multiply vector by a matrix. Change-Id: Ie6c692885a48d0206bd6fe748332fa83bc286d67 Reviewed-on: http://mpd-gerrit.cambridge.arm.com/79118 Tested-by: Kaizen <jeremy.johnson+kaizengerrit@arm.com> Reviewed-by: Moritz Pflanzer <moritz.pflanzer@arm.com>
-rwxr-xr-xscripts/check_clang-tidy.py1
-rw-r--r--scripts/clang-tidy.h25
-rw-r--r--src/core/NEON/kernels/NEGEMMMatrixAdditionKernel.cpp47
-rw-r--r--src/core/NEON/kernels/NEGEMMMatrixMultiplyKernel.cpp197
-rw-r--r--tests/NEON/Helper.h1
-rw-r--r--tests/TensorLibrary.h2
-rw-r--r--tests/validation/NEON/GEMM.cpp20
-rw-r--r--tests/validation/Reference.cpp2
-rw-r--r--tests/validation/Reference.h1
9 files changed, 241 insertions, 55 deletions
diff --git a/scripts/check_clang-tidy.py b/scripts/check_clang-tidy.py
index a376d0b898..6ab1747482 100755
--- a/scripts/check_clang-tidy.py
+++ b/scripts/check_clang-tidy.py
@@ -39,6 +39,7 @@ if __name__ == "__main__":
("Validation.cpp" in line and "parameter 'expected_labels' is unused" in line) or
("Reference.cpp" in line and "parameter 'rois' is unused" in line) or
("ReferenceCPP.cpp" in line and "parameter 'rois' is unused" in line) or
+ ("NEGEMMMatrixMultiplyKernel.cpp" in line and "do not use C-style cast to convert between unrelated types" in line) or
"3rdparty" in line):
continue
diff --git a/scripts/clang-tidy.h b/scripts/clang-tidy.h
index 32b0f6955e..cbc0d07cd6 100644
--- a/scripts/clang-tidy.h
+++ b/scripts/clang-tidy.h
@@ -1,5 +1,30 @@
#include <arm_neon.h>
+inline float16x8_t vmulq_lane_f16 (float16x8_t, float16x4_t, const int)
+{
+ return vdupq_n_f16(0);
+}
+
+inline float16x4_t vmul_f16 (float16x4_t, float16x4_t)
+{
+ return vdup_n_u16(0);
+}
+
+inline float16x4_t vadd_f16 (float16x4_t, float16x4_t)
+{
+ return vdup_n_u16(0);
+}
+
+inline float16x4_t vmul_lane_f16 (float16x4_t, float16x4_t, const int)
+{
+ return vdup_n_u16(0);
+}
+
+inline float16x4_t vmul_n_f16 (float16x4_t, float16_t)
+{
+ return vdup_n_u16(0);
+}
+
inline float16x8_t vcvtq_f16_u16(uint16x8_t)
{
return vdupq_n_f16(0);
diff --git a/src/core/NEON/kernels/NEGEMMMatrixAdditionKernel.cpp b/src/core/NEON/kernels/NEGEMMMatrixAdditionKernel.cpp
index 71dd4c7aa1..7d659ab2e6 100644
--- a/src/core/NEON/kernels/NEGEMMMatrixAdditionKernel.cpp
+++ b/src/core/NEON/kernels/NEGEMMMatrixAdditionKernel.cpp
@@ -52,25 +52,8 @@ void matrix_addition_f32(const ITensor *input, ITensor *output, const Window &wi
const auto in_ptr = reinterpret_cast<const float *>(in.ptr());
const auto out_ptr = reinterpret_cast<float *>(out.ptr());
- float32x4x4_t alpha_ab =
- {
- {
- vld1q_f32(out_ptr + 0),
- vld1q_f32(out_ptr + 4),
- vld1q_f32(out_ptr + 8),
- vld1q_f32(out_ptr + 12)
- }
- };
-
- const float32x4x4_t c =
- {
- {
- vld1q_f32(in_ptr + 0),
- vld1q_f32(in_ptr + 4),
- vld1q_f32(in_ptr + 8),
- vld1q_f32(in_ptr + 12)
- }
- };
+ float32x4x4_t alpha_ab = vld4q_f32(out_ptr);
+ const float32x4x4_t c = vld4q_f32(in_ptr);
// Multiply matrix C by its weight and accumulate
alpha_ab.val[0] = vmlaq_f32(alpha_ab.val[0], c.val[0], beta_f32);
@@ -78,10 +61,7 @@ void matrix_addition_f32(const ITensor *input, ITensor *output, const Window &wi
alpha_ab.val[2] = vmlaq_f32(alpha_ab.val[2], c.val[2], beta_f32);
alpha_ab.val[3] = vmlaq_f32(alpha_ab.val[3], c.val[3], beta_f32);
- vst1q_f32(out_ptr + 0, alpha_ab.val[0]);
- vst1q_f32(out_ptr + 4, alpha_ab.val[1]);
- vst1q_f32(out_ptr + 8, alpha_ab.val[2]);
- vst1q_f32(out_ptr + 12, alpha_ab.val[3]);
+ vst4q_f32(out_ptr, alpha_ab);
},
in, out);
}
@@ -99,28 +79,13 @@ void matrix_addition_f16(const ITensor *input, ITensor *output, const Window &wi
const auto in_ptr = reinterpret_cast<const float16_t *>(in.ptr());
const auto out_ptr = reinterpret_cast<float16_t *>(out.ptr());
- float16x8x2_t alpha_ab =
- {
- {
- vld1q_f16(out_ptr + 0),
- vld1q_f16(out_ptr + 8)
- }
- };
-
- float16x8x2_t c =
- {
- {
- vld1q_f16(in_ptr + 0),
- vld1q_f16(in_ptr + 8)
- }
- };
-
+ float16x8x2_t alpha_ab = vld2q_f16(out_ptr);
+ const float16x8x2_t c = vld2q_f16(in_ptr);
// Multiply matrix C by its weight and accumulate
alpha_ab.val[0] = vaddq_f16(alpha_ab.val[0], vmulq_f16(c.val[0], beta_f16));
alpha_ab.val[1] = vaddq_f16(alpha_ab.val[1], vmulq_f16(c.val[1], beta_f16));
- vst1q_f16(out_ptr + 0, alpha_ab.val[0]);
- vst1q_f16(out_ptr + 8, alpha_ab.val[1]);
+ vst2q_f16(out_ptr + 0, alpha_ab);
},
in, out);
}
diff --git a/src/core/NEON/kernels/NEGEMMMatrixMultiplyKernel.cpp b/src/core/NEON/kernels/NEGEMMMatrixMultiplyKernel.cpp
index 1db025723c..101c5c8132 100644
--- a/src/core/NEON/kernels/NEGEMMMatrixMultiplyKernel.cpp
+++ b/src/core/NEON/kernels/NEGEMMMatrixMultiplyKernel.cpp
@@ -50,6 +50,147 @@ class Coordinates;
namespace
{
template <bool multiply_alpha>
+void vector_matrix_multiply_f16(const ITensor *input0, const ITensor *input1, ITensor *output, const Window &window, float alpha)
+{
+#ifdef ARM_COMPUTE_ENABLE_FP16
+ const auto width_matrix_b = static_cast<int>(output->info()->dimension(0));
+ const auto in_b_stride = static_cast<int>(input1->info()->strides_in_bytes()[1] / data_size_from_type(input1->info()->data_type()));
+ const auto num_elems_vec_a = static_cast<int>(input0->info()->dimension(0));
+
+ // The implementation computes 32 elements per iteration
+ const int window_start_x = 32 * window.thread_id();
+ const int window_step_x = 32 * window.num_threads();
+ const int window_end_x = ceil_to_multiple(width_matrix_b - window_start_x, window_step_x) + window_start_x;
+ ARM_COMPUTE_ERROR_ON_MSG((window_end_x - window_start_x) % window_step_x, " (window_end_x - window_start_x) must be multiple of window_step_x");
+
+ Window win_out(window);
+ win_out.set(Window::DimX, Window::Dimension(window_start_x, window_end_x, window_step_x));
+ win_out.set(Window::DimY, Window::Dimension(0, 1, 1));
+
+ Window win_a(window);
+ win_a.set(Window::DimX, Window::Dimension(0, 0, 0));
+ win_a.set(Window::DimY, Window::Dimension(0, 0, 0));
+
+ Window win_b;
+ // Don't slice matrix B along the z dimension if matrix B has just 2 dimensions and matrix A more than 2
+ // This scenario can happen when the the matrix multiplication is used to perform a convolution operation
+ if(input1->info()->num_dimensions() >= 3)
+ {
+ win_b = window;
+ }
+ win_b.set(Window::DimX, Window::Dimension(window_start_x, window_end_x, window_step_x));
+ win_b.set(Window::DimY, Window::Dimension(0, 1, 1));
+
+ Iterator ina(input0, win_a);
+ Iterator inb(input1, win_b);
+ Iterator out(output, win_out);
+
+ const float16x8_t alpha_f16 = vdupq_n_f16(alpha);
+ ARM_COMPUTE_UNUSED(alpha_f16);
+
+ execute_window_loop(win_out, [&](const Coordinates & id)
+ {
+ if(id.x() > width_matrix_b)
+ {
+ return;
+ }
+
+ float16x8_t acc0 = vdupq_n_f16(0.f);
+ float16x8_t acc1 = vdupq_n_f16(0.f);
+ float16x8_t acc2 = vdupq_n_f16(0.f);
+ float16x8_t acc3 = vdupq_n_f16(0.f);
+
+ auto vec_a = reinterpret_cast<const float16_t *>(ina.ptr());
+ auto matrix_b = reinterpret_cast<const float16_t *>(inb.ptr());
+
+ const float16_t *vec_a_end_addr = vec_a + num_elems_vec_a;
+ for(; vec_a <= (vec_a_end_addr - 4);)
+ {
+ const float16x4_t a0l = vld1_f16(vec_a);
+
+ float16x8_t b00 = vld1q_f16(matrix_b + 0 + 0 * in_b_stride);
+ float16x8_t b01 = vld1q_f16(matrix_b + 8 + 0 * in_b_stride);
+ float16x8_t b02 = vld1q_f16(matrix_b + 16 + 0 * in_b_stride);
+ float16x8_t b03 = vld1q_f16(matrix_b + 24 + 0 * in_b_stride);
+ float16x8_t b10 = vld1q_f16(matrix_b + 0 + 1 * in_b_stride);
+ float16x8_t b11 = vld1q_f16(matrix_b + 8 + 1 * in_b_stride);
+ float16x8_t b12 = vld1q_f16(matrix_b + 16 + 1 * in_b_stride);
+ float16x8_t b13 = vld1q_f16(matrix_b + 24 + 1 * in_b_stride);
+
+ acc0 = vaddq_f16(acc0, vmulq_lane_f16(b00, a0l, 0));
+ acc1 = vaddq_f16(acc1, vmulq_lane_f16(b01, a0l, 0));
+ acc2 = vaddq_f16(acc2, vmulq_lane_f16(b02, a0l, 0));
+ acc3 = vaddq_f16(acc3, vmulq_lane_f16(b03, a0l, 0));
+ acc0 = vaddq_f16(acc0, vmulq_lane_f16(b10, a0l, 1));
+ acc1 = vaddq_f16(acc1, vmulq_lane_f16(b11, a0l, 1));
+ acc2 = vaddq_f16(acc2, vmulq_lane_f16(b12, a0l, 1));
+ acc3 = vaddq_f16(acc3, vmulq_lane_f16(b13, a0l, 1));
+
+ matrix_b += 2 * in_b_stride;
+
+ b00 = vld1q_f16(matrix_b + 0 + 0 * in_b_stride);
+ b01 = vld1q_f16(matrix_b + 8 + 0 * in_b_stride);
+ b02 = vld1q_f16(matrix_b + 16 + 0 * in_b_stride);
+ b03 = vld1q_f16(matrix_b + 24 + 0 * in_b_stride);
+ b10 = vld1q_f16(matrix_b + 0 + 1 * in_b_stride);
+ b11 = vld1q_f16(matrix_b + 8 + 1 * in_b_stride);
+ b12 = vld1q_f16(matrix_b + 16 + 1 * in_b_stride);
+ b13 = vld1q_f16(matrix_b + 24 + 1 * in_b_stride);
+
+ acc0 = vaddq_f16(acc0, vmulq_lane_f16(b00, a0l, 2));
+ acc1 = vaddq_f16(acc1, vmulq_lane_f16(b01, a0l, 2));
+ acc2 = vaddq_f16(acc2, vmulq_lane_f16(b02, a0l, 2));
+ acc3 = vaddq_f16(acc3, vmulq_lane_f16(b03, a0l, 2));
+ acc0 = vaddq_f16(acc0, vmulq_lane_f16(b10, a0l, 3));
+ acc1 = vaddq_f16(acc1, vmulq_lane_f16(b11, a0l, 3));
+ acc2 = vaddq_f16(acc2, vmulq_lane_f16(b12, a0l, 3));
+ acc3 = vaddq_f16(acc3, vmulq_lane_f16(b13, a0l, 3));
+
+ vec_a += 4;
+ matrix_b += 2 * in_b_stride;
+ }
+
+ for(; vec_a < vec_a_end_addr;)
+ {
+ const float16_t a0 = *vec_a;
+ const float16x8_t b00 = vld1q_f16(matrix_b + 0 + 0 * in_b_stride);
+ const float16x8_t b01 = vld1q_f16(matrix_b + 8 + 0 * in_b_stride);
+ const float16x8_t b02 = vld1q_f16(matrix_b + 16 + 0 * in_b_stride);
+ const float16x8_t b03 = vld1q_f16(matrix_b + 24 + 0 * in_b_stride);
+
+ acc0 = vaddq_f16(acc0, vmulq_n_f16(b00, a0));
+ acc1 = vaddq_f16(acc1, vmulq_n_f16(b01, a0));
+ acc2 = vaddq_f16(acc2, vmulq_n_f16(b02, a0));
+ acc3 = vaddq_f16(acc3, vmulq_n_f16(b03, a0));
+
+ vec_a += 1;
+ matrix_b += in_b_stride;
+ }
+
+ // Multiply by the weight of matrix product (alpha)
+ if(multiply_alpha)
+ {
+ acc0 = vmulq_f16(acc0, alpha_f16);
+ acc1 = vmulq_f16(acc1, alpha_f16);
+ acc2 = vmulq_f16(acc2, alpha_f16);
+ acc3 = vmulq_f16(acc3, alpha_f16);
+ }
+
+ const auto vec_out = reinterpret_cast<float16_t *>(out.ptr());
+
+ vst1q_f16(vec_out + 0, acc0);
+ vst1q_f16(vec_out + 8, acc1);
+ vst1q_f16(vec_out + 16, acc2);
+ vst1q_f16(vec_out + 24, acc3);
+
+ },
+ ina, inb, out);
+#else /* ARM_COMPUTE_ENABLE_FP16 */
+ ARM_COMPUTE_ERROR("Not implemented");
+#endif /* ARM_COMPUTE_ENABLE_FP16 */
+}
+
+template <bool multiply_alpha>
void vector_matrix_multiply_f32(const ITensor *input0, const ITensor *input1, ITensor *output, const Window &window, float alpha)
{
const auto width_matrix_b = static_cast<int>(output->info()->dimension(0));
@@ -639,9 +780,9 @@ template <bool multiply_alpha>
void matrix_matrix_multiply_f16(const ITensor *input0, const ITensor *input1, ITensor *output, const Window &window, float alpha)
{
#ifdef ARM_COMPUTE_ENABLE_FP16
-
- const size_t in_b_stride = input1->info()->strides_in_bytes()[1] / data_size_from_type(input1->info()->data_type());
- const size_t out_stride = output->info()->strides_in_bytes()[1] / data_size_from_type(output->info()->data_type());
+ const size_t in_b_stride = input1->info()->strides_in_bytes()[1] / data_size_from_type(input1->info()->data_type());
+ const size_t out_stride = output->info()->strides_in_bytes()[1] / data_size_from_type(output->info()->data_type());
+ const int num_elems_matrix_b_x = input1->info()->dimension(0);
// Set step_x and step_y for matrix A. Scale by a factor of 4 the Y range as the input interleaved matrix A has 4 times less the rows of the output matrix
Window win_a(window);
@@ -663,9 +804,6 @@ void matrix_matrix_multiply_f16(const ITensor *input0, const ITensor *input1, IT
Iterator inb(input1, win_b);
Iterator out(output, window);
- // Number of iterations of inner loop. Since 8 is the number of accumulations per loop, num_it = (width_mtx_b / 4) / 8
- const size_t num_it = ((input1->info()->dimension(0)) >> 2) >> 3;
-
const float16x8_t alpha_f16 = vdupq_n_f16(alpha);
execute_window_loop(window, [&](const Coordinates & id)
@@ -711,10 +849,14 @@ void matrix_matrix_multiply_f16(const ITensor *input0, const ITensor *input1, IT
The size of the output tensor's XY-plane must be the following shape [ width * 8, height / 8 ]. All other dimensions must have the same size.
*/
- for(size_t k = num_it; k > 0; mtx_a0 += 16, mtx_b0 += 32, --k)
+ const float16_t *mtx_b0_end_addr = mtx_b0 + num_elems_matrix_b_x;
+
+ for(; mtx_b0 <= (mtx_b0_end_addr - 32);)
+
{
const float16x8_t p00 = vld1q_f16(mtx_a0);
const float16x8_t p02 = vld1q_f16(mtx_a0 + 8);
+
const float16x8_t q00 = vld1q_f16(mtx_b0);
const float16x8_t q02 = vld1q_f16(mtx_b0 + 8);
const float16x8_t q04 = vld1q_f16(mtx_b0 + 16);
@@ -739,6 +881,24 @@ void matrix_matrix_multiply_f16(const ITensor *input0, const ITensor *input1, IT
c.val[1] = vaddq_f16(c.val[1], vmulq_n_f16(q06, vgetq_lane_f16(p02, 5)));
c.val[2] = vaddq_f16(c.val[2], vmulq_n_f16(q06, vgetq_lane_f16(p02, 6)));
c.val[3] = vaddq_f16(c.val[3], vmulq_n_f16(q06, vgetq_lane_f16(p02, 7)));
+
+ mtx_a0 += 16;
+ mtx_b0 += 32;
+ }
+
+ for(; mtx_b0 < mtx_b0_end_addr;)
+
+ {
+ const float16x4_t p00 = vld1_f16(mtx_a0);
+ const float16x8_t q00 = vld1q_f16(mtx_b0);
+
+ c.val[0] = vaddq_f16(c.val[0], vmulq_n_f16(q00, vget_lane_f16(p00, 0)));
+ c.val[1] = vaddq_f16(c.val[1], vmulq_n_f16(q00, vget_lane_f16(p00, 1)));
+ c.val[2] = vaddq_f16(c.val[2], vmulq_n_f16(q00, vget_lane_f16(p00, 2)));
+ c.val[3] = vaddq_f16(c.val[3], vmulq_n_f16(q00, vget_lane_f16(p00, 3)));
+
+ mtx_a0 += 4;
+ mtx_b0 += 8;
}
if(multiply_alpha)
@@ -1037,6 +1197,13 @@ void NEGEMMMatrixMultiplyKernel::configure(const ITensor *input0, const ITensor
num_elems_processed_per_iteration_x = 32;
break;
}
+#ifdef ARM_COMPUTE_ENABLE_FP16
+ case DataType::F16:
+ {
+ num_elems_processed_per_iteration_x = 32;
+ break;
+ }
+#endif /* ARM_COMPUTE_ENABLE_FP16 */
default:
{
ARM_COMPUTE_ERROR("Data type not supported");
@@ -1074,13 +1241,13 @@ void NEGEMMMatrixMultiplyKernel::configure(const ITensor *input0, const ITensor
num_elems_processed_per_iteration_x = 32;
break;
}
+#ifdef ARM_COMPUTE_ENABLE_FP16
case DataType::F16:
{
-#ifdef ARM_COMPUTE_ENABLE_FP16
num_elems_processed_per_iteration_x = 8;
break;
-#endif
}
+#endif /* ARM_COMPUTE_ENABLE_FP16 */
default:
{
ARM_COMPUTE_ERROR("Data type not supported");
@@ -1128,6 +1295,14 @@ void NEGEMMMatrixMultiplyKernel::run(const Window &window)
vector_matrix_multiply_qs8<false>(_input0, _input1, _output, window, _alpha);
break;
}
+#ifdef ARM_COMPUTE_ENABLE_FP16
+ case DataType::F16:
+ {
+ multiply_alpha ? vector_matrix_multiply_f16<true>(_input0, _input1, _output, window, _alpha) :
+ vector_matrix_multiply_f16<false>(_input0, _input1, _output, window, _alpha);
+ break;
+ }
+#endif /* ARM_COMPUTE_ENABLE_FP16 */
default:
{
ARM_COMPUTE_ERROR("Data type not supported");
@@ -1151,14 +1326,14 @@ void NEGEMMMatrixMultiplyKernel::run(const Window &window)
matrix_matrix_multiply_qs8<false>(_input0, _input1, _output, window, _alpha);
break;
}
+#ifdef ARM_COMPUTE_ENABLE_FP16
case DataType::F16:
{
-#ifdef ARM_COMPUTE_ENABLE_FP16
multiply_alpha ? matrix_matrix_multiply_f16<true>(_input0, _input1, _output, window, _alpha) :
matrix_matrix_multiply_f16<false>(_input0, _input1, _output, window, _alpha);
break;
-#endif
}
+#endif /* ARM_COMPUTE_ENABLE_FP16 */
default:
{
ARM_COMPUTE_ERROR("Data type not supported");
diff --git a/tests/NEON/Helper.h b/tests/NEON/Helper.h
index e77615406e..0651c9c709 100644
--- a/tests/NEON/Helper.h
+++ b/tests/NEON/Helper.h
@@ -27,6 +27,7 @@
#include "arm_compute/runtime/Array.h"
#include <algorithm>
+#include <vector>
namespace arm_compute
{
diff --git a/tests/TensorLibrary.h b/tests/TensorLibrary.h
index 4d7143a206..69b2381171 100644
--- a/tests/TensorLibrary.h
+++ b/tests/TensorLibrary.h
@@ -505,7 +505,7 @@ void TensorLibrary::fill_tensor_uniform(T &&tensor, std::random_device::result_t
fill(tensor, distribution_f16, seed_offset);
break;
}
-#endif
+#endif /* ARM_COMPUTE_ENABLE_FP16 */
case DataType::F32:
{
// It doesn't make sense to check [-inf, inf], so hard code it to a big number
diff --git a/tests/validation/NEON/GEMM.cpp b/tests/validation/NEON/GEMM.cpp
index 35f65c8fe2..75ce39716c 100644
--- a/tests/validation/NEON/GEMM.cpp
+++ b/tests/validation/NEON/GEMM.cpp
@@ -77,7 +77,7 @@ Tensor compute_gemm(const TensorShape &src_shape1, const TensorShape &src_shape2
BOOST_TEST(!dst.info()->is_resizable());
// Fill tensors
- if(dt == DataType::F32)
+ if(dt == DataType::F16 || dt == DataType::F32)
{
std::uniform_real_distribution<> distribution(-1.0f, 1.0f);
library->fill(NEAccessor(src1), distribution, 0);
@@ -137,6 +137,24 @@ BOOST_DATA_TEST_CASE(Configuration,
validate(dst.info()->valid_region(), dst_valid_region);
}
+#ifdef ARM_COMPUTE_ENABLE_FP16
+BOOST_AUTO_TEST_SUITE(Float16)
+BOOST_TEST_DECORATOR(*boost::unit_test::label("precommit"))
+BOOST_DATA_TEST_CASE(SmallGEMM, SmallGEMMDataset() * boost::unit_test::data::make(DataType::F16),
+ gemm_set, dt)
+{
+ // Compute reference
+ RawTensor ref_dst = Reference::compute_reference_gemm(gemm_set.shape_a, gemm_set.shape_b, gemm_set.shape_c, gemm_set.shape_d, gemm_set.alpha, gemm_set.beta, dt);
+
+ // Compute function
+ Tensor dst = compute_gemm(gemm_set.shape_a, gemm_set.shape_b, gemm_set.shape_c, gemm_set.shape_d, gemm_set.alpha, gemm_set.beta, dt);
+
+ // Validate output
+ validate(NEAccessor(dst), ref_dst, tolerance_f32);
+}
+BOOST_AUTO_TEST_SUITE_END()
+#endif /* ARM_COMPUTE_ENABLE_FP16 */
+
BOOST_AUTO_TEST_SUITE(Float)
BOOST_TEST_DECORATOR(*boost::unit_test::label("precommit"))
BOOST_DATA_TEST_CASE(SmallGEMM, SmallGEMMDataset() * boost::unit_test::data::make(DataType::F32),
diff --git a/tests/validation/Reference.cpp b/tests/validation/Reference.cpp
index 0518819173..62dfcba37e 100644
--- a/tests/validation/Reference.cpp
+++ b/tests/validation/Reference.cpp
@@ -335,7 +335,7 @@ RawTensor Reference::compute_reference_gemm(const TensorShape &src_shape1, const
RawTensor dst = library->get(dst_shape, dt, 1, fixed_point_position);
// Fill reference
- if(dt == DataType::F32)
+ if(dt == DataType::F16 || dt == DataType::F32)
{
std::uniform_real_distribution<> distribution(-1.0f, 1.0f);
library->fill(src1, distribution, 0);
diff --git a/tests/validation/Reference.h b/tests/validation/Reference.h
index 3aca1eaaae..ebd5fa76c4 100644
--- a/tests/validation/Reference.h
+++ b/tests/validation/Reference.h
@@ -26,6 +26,7 @@
#include "RawTensor.h"
#include "Types.h"
+#include <vector>
#include <vector>