aboutsummaryrefslogtreecommitdiff
path: root/src/core/NEON/kernels/NEGEMMMatrixVectorMultiplyKernel.cpp
diff options
context:
space:
mode:
author Georgios Pinitas <georgios.pinitas@arm.com> 2018-11-16 17:11:50 +0000
committer Georgios Pinitas <georgios.pinitas@arm.com> 2018-11-19 17:42:58 +0000
commit 8cffcd6b6e4e95f97767f2a25ccc8826dd69c358 (patch)
tree 339d4053464ef995d24da035595b44155810036d /src/core/NEON/kernels/NEGEMMMatrixVectorMultiplyKernel.cpp
parent d5c075c4ecdac35cd07538acc559a2d8805d8c1c (diff)
download ComputeLibrary-8cffcd6b6e4e95f97767f2a25ccc8826dd69c358.tar.gz
COMPMID-1644: NEDepthwiseConvolution for FP16 NHWC
Change-Id: I6e7dee8bd615a5eff01c523f208a218574ee5eab
Diffstat (limited to 'src/core/NEON/kernels/NEGEMMMatrixVectorMultiplyKernel.cpp')
-rw-r--r-- src/core/NEON/kernels/NEGEMMMatrixVectorMultiplyKernel.cpp 53
1 files changed, 50 insertions, 3 deletions
diff --git a/src/core/NEON/kernels/NEGEMMMatrixVectorMultiplyKernel.cpp b/src/core/NEON/kernels/NEGEMMMatrixVectorMultiplyKernel.cpp
index 238786953b..3a1595a0c9 100644
--- a/src/core/NEON/kernels/NEGEMMMatrixVectorMultiplyKernel.cpp
+++ b/src/core/NEON/kernels/NEGEMMMatrixVectorMultiplyKernel.cpp
@@ -43,11 +43,11 @@ namespace
{
Status validate_arguments(const ITensorInfo *input0, const ITensorInfo *input1, const ITensorInfo *output)
{
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input0, 1, DataType::QASYMM8, DataType::F32);
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_NOT_IN(output, DataType::S32, DataType::F32);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input0, 1, DataType::QASYMM8, DataType::F16, DataType::F32);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_NOT_IN(output, DataType::S32, DataType::F16, DataType::F32);
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input0, input1);
ARM_COMPUTE_RETURN_ERROR_ON(is_data_type_quantized_asymmetric(input0->data_type()) && (output->data_type() != DataType::S32));
- ARM_COMPUTE_RETURN_ERROR_ON(is_data_type_float(input0->data_type()) && (output->data_type() != DataType::F32));
+ ARM_COMPUTE_RETURN_ERROR_ON(is_data_type_float(input0->data_type()) && (output->data_type() != input0->data_type()));
ARM_COMPUTE_RETURN_ERROR_ON(input0->num_dimensions() == input1->num_dimensions());
ARM_COMPUTE_RETURN_ERROR_ON(input0->dimension(2) != input1->dimension(1));
@@ -87,6 +87,48 @@ void NEGEMMMatrixVectorMultiplyKernel::matrix_vector_multiply(const Window &wind
namespace arm_compute
{
+#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+template <> // FP16 specialisation: for each position of window_in, writes one half-precision dot product of an input row and a weights row
+void NEGEMMMatrixVectorMultiplyKernel::matrix_vector_multiply<half, half, half>(const Window &window_in,
+ const Window &window_w,
+ const Window &window_out)
+{
+ Iterator in(_input0, window_in); // walks _input0 over window_in
+ Iterator in2(_input1, window_w); // walks _input1 (the weights) over window_w
+ Iterator out(_output, window_out); // walks _output over window_out
+
+ const int input_w = _input0->info()->dimension(0); // extent of _input0 along dimension 0; length of each dot product
+ const int input_h = _input0->info()->dimension(1); // extent of _input0 along dimension 1; used to linearise (y, z) into an output index
+ const int input_stride_x = _input0->info()->strides_in_bytes().x(); // all strides below are in bytes
+ const int weights_stride_x = _input1->info()->strides_in_bytes().x();
+ const int weights_stride_y = _input1->info()->strides_in_bytes().y();
+ const int output_stride_x = _output->info()->strides_in_bytes().x();
+
+ execute_window_loop(window_in, [&](const Coordinates & id)
+ {
+ // Get pointers for the current window position: input, weights for this z, and the output element
+ const uint8_t *const input_ptr = in.ptr();
+ const uint8_t *const weights_ptr = in2.ptr() + id.z() * weights_stride_y; // advance id.z() steps along the weights' y stride
+ auto output_ptr = reinterpret_cast<__fp16 *>(out.ptr() + (id.y() + id.z() * input_h) * output_stride_x); // output element index is y + z * input_h
+
+ float16x8_t row_dot = vdupq_n_f16(0.f); // eight half-precision partial sums, zero-initialised; accumulation stays in FP16
+ for(int i = 0; i < input_w; i += 8) // one 8-lane vector per step; NOTE(review): full-width loads can read past input_w when it is not a multiple of 8 - presumably covered by tensor padding, confirm
+ {
+ const auto input = vld1q_f16(reinterpret_cast<const __fp16 *>(input_ptr + i * input_stride_x));
+ const auto weights = vld1q_f16(reinterpret_cast<const __fp16 *>(weights_ptr + i * weights_stride_x));
+ row_dot = vaddq_f16(row_dot, vmulq_f16(input, weights)); // lane-wise multiply, then accumulate
+ }
+
+ auto temp = vadd_f16(vget_high_f16(row_dot), vget_low_f16(row_dot)); // horizontal reduction, step 1: fold 8 lanes into 4
+ temp = vpadd_f16(temp, temp); // step 2: pairwise add, 4 partial sums -> 2
+ temp = vpadd_f16(temp, temp); // step 3: pairwise add again; every lane now holds the full sum
+
+ *output_ptr = vget_lane_f16(temp, 0); // store the reduced dot product
+ },
+ in, in2, out);
+}
+#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
+
template <>
void NEGEMMMatrixVectorMultiplyKernel::matrix_vector_multiply<float, float, float>(const Window &window_in,
const Window &window_w,
@@ -226,6 +268,11 @@ void NEGEMMMatrixVectorMultiplyKernel::configure(const ITensor *input0, const IT
case DataType::QASYMM8:
_func = &NEGEMMMatrixVectorMultiplyKernel::matrix_vector_multiply<uint8_t, uint8_t, int32_t>;
break;
+#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+ case DataType::F16:
+ _func = &NEGEMMMatrixVectorMultiplyKernel::matrix_vector_multiply<half, half, half>;
+ break;
+#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
case DataType::F32:
_func = &NEGEMMMatrixVectorMultiplyKernel::matrix_vector_multiply<float, float, float>;
break;