From df24618b53cffed1c574e11e9fd4ba7740f8c009 Mon Sep 17 00:00:00 2001 From: Pablo Tello Date: Mon, 3 Jul 2017 16:25:09 +0100 Subject: COMPMID-421: Added FP16 support to NENormalizationLayer and NEPixelWiseMultiplication. Change-Id: If174f8071502fc5cc94b27cd44a9b1d5e451a9e2 Reviewed-on: http://mpd-gerrit.cambridge.arm.com/79553 Tested-by: Kaizen Reviewed-by: Georgios Pinitas --- .../kernels/NEPixelWiseMultiplicationKernel.cpp | 45 ++++++++++++++++++++-- 1 file changed, 42 insertions(+), 3 deletions(-) (limited to 'src/core/NEON/kernels/NEPixelWiseMultiplicationKernel.cpp') diff --git a/src/core/NEON/kernels/NEPixelWiseMultiplicationKernel.cpp b/src/core/NEON/kernels/NEPixelWiseMultiplicationKernel.cpp index c3f61ac94a..83d6d8218e 100644 --- a/src/core/NEON/kernels/NEPixelWiseMultiplicationKernel.cpp +++ b/src/core/NEON/kernels/NEPixelWiseMultiplicationKernel.cpp @@ -38,6 +38,10 @@ #include #include +#if ARM_COMPUTE_ENABLE_FP16 +#include // needed for float16_t +#endif /* ARM_COMPUTE_ENABLE_FP16 */ + using namespace arm_compute; namespace arm_compute @@ -248,6 +252,32 @@ void mul_F32_F32_F32_n(const void *__restrict input1_ptr, const void *__restrict vst4q_f32(output, result); } +template +void mul_F16_F16_F16_n(const void *__restrict input1_ptr, const void *__restrict input2_ptr, void *__restrict output_ptr, float scale) +{ + ARM_COMPUTE_UNUSED(input1_ptr); + ARM_COMPUTE_UNUSED(input2_ptr); + ARM_COMPUTE_UNUSED(output_ptr); +#ifdef ARM_COMPUTE_ENABLE_FP16 + const auto input1 = static_cast(input1_ptr); + const auto input2 = static_cast(input2_ptr); + const auto output = static_cast(output_ptr); + const float16x8x2_t ta1 = vld2q_f16(input1); + const float16x8x2_t ta2 = vld2q_f16(input2); + const float16x8_t scale_vec = vdupq_n_f16(scale); + const float16x8x2_t result = + { + { + vmulq_f16(vmulq_f16(ta1.val[0], ta2.val[0]), scale_vec), + vmulq_f16(vmulq_f16(ta1.val[1], ta2.val[1]), scale_vec), + } + }; + vst2q_f16(output, result); +#else /* ARM_COMPUTE_ENABLE_FP16 */ + 
ARM_COMPUTE_ERROR("Not supported. Recompile the library with arch=arm64-v8.2-a."); +#endif /* ARM_COMPUTE_ENABLE_FP16 */ +} + template void mul_U8_U8_S16_n(const void *__restrict input1_ptr, const void *__restrict input2_ptr, void *__restrict output_ptr, int n) { @@ -347,6 +377,10 @@ void NEPixelWiseMultiplicationKernel::configure(const ITensor *input1, const ITe { set_format_if_unknown(*output->info(), Format::F32); } + else if(input1->info()->data_type() == DataType::F16 || input2->info()->data_type() == DataType::F16) + { + set_format_if_unknown(*output->info(), Format::F16); + } else if(input1->info()->data_type() == DataType::QS8 && input2->info()->data_type() == DataType::QS8) { set_data_type_if_unknown(*output->info(), DataType::QS8); @@ -355,9 +389,9 @@ void NEPixelWiseMultiplicationKernel::configure(const ITensor *input1, const ITe } ARM_COMPUTE_ERROR_ON_MISMATCHING_SHAPES(input1, input2, output); - ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input1, 1, DataType::U8, DataType::QS8, DataType::S16, DataType::F32); - ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input2, 1, DataType::U8, DataType::QS8, DataType::S16, DataType::F32); - ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::U8, DataType::QS8, DataType::S16, DataType::F32); + ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input1, 1, DataType::U8, DataType::QS8, DataType::S16, DataType::F16, DataType::F32); + ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input2, 1, DataType::U8, DataType::QS8, DataType::S16, DataType::F16, DataType::F32); + ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::U8, DataType::QS8, DataType::S16, DataType::F16, DataType::F32); ARM_COMPUTE_ERROR_ON_MSG(output->info()->data_type() == DataType::U8 && (input1->info()->data_type() != DataType::U8 || input2->info()->data_type() != DataType::U8), "Output can only be U8 if both inputs are U8"); if(input1->info()->data_type() == DataType::QS8) @@ -479,6 +513,11 @@ void 
NEPixelWiseMultiplicationKernel::configure(const ITensor *input1, const ITe _func_q_int = is_sat ? &mul_QS8_QS8_QS8_n : &mul_QS8_QS8_QS8_n; } } + else if(DataType::F16 == dt_input1 && DataType::F16 == dt_input2 && DataType::F16 == dt_output) + { + _func_float = &mul_F16_F16_F16_n; + _func_int = nullptr; + } else if(DataType::F32 == dt_input1 && DataType::F32 == dt_input2 && DataType::F32 == dt_output) { _func_float = &mul_F32_F32_F32_n; -- cgit v1.2.1