aboutsummaryrefslogtreecommitdiff
path: root/src/core/NEON/kernels/NEPixelWiseMultiplicationKernel.cpp
diff options
context:
space:
mode:
authorPablo Tello <pablo.tello@arm.com>2017-07-03 16:25:09 +0100
committerAnthony Barbier <anthony.barbier@arm.com>2018-09-17 14:15:39 +0100
commitdf24618b53cffed1c574e11e9fd4ba7740f8c009 (patch)
tree1f1145bca27c5dd0ca63538c2e8cdadd2b0a03cf /src/core/NEON/kernels/NEPixelWiseMultiplicationKernel.cpp
parentd1b0ecc206e3858327503888c4a46842ec1808e9 (diff)
downloadComputeLibrary-df24618b53cffed1c574e11e9fd4ba7740f8c009.tar.gz
COMPMID-421: Added FP16 suppot to NENormalizationLayer and NEPixelWiseMultiplication.
Change-Id: If174f8071502fc5cc94b27cd44a9b1d5e451a9e2 Reviewed-on: http://mpd-gerrit.cambridge.arm.com/79553 Tested-by: Kaizen <jeremy.johnson+kaizengerrit@arm.com> Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com>
Diffstat (limited to 'src/core/NEON/kernels/NEPixelWiseMultiplicationKernel.cpp')
-rw-r--r--src/core/NEON/kernels/NEPixelWiseMultiplicationKernel.cpp45
1 files changed, 42 insertions, 3 deletions
diff --git a/src/core/NEON/kernels/NEPixelWiseMultiplicationKernel.cpp b/src/core/NEON/kernels/NEPixelWiseMultiplicationKernel.cpp
index c3f61ac94a..83d6d8218e 100644
--- a/src/core/NEON/kernels/NEPixelWiseMultiplicationKernel.cpp
+++ b/src/core/NEON/kernels/NEPixelWiseMultiplicationKernel.cpp
@@ -38,6 +38,10 @@
#include <cstdint>
#include <cstdlib>
+#if ARM_COMPUTE_ENABLE_FP16
+#include <arm_fp16.h> // needed for float16_t
+#endif /* ARM_COMPUTE_ENABLE_FP16 */
+
using namespace arm_compute;
namespace arm_compute
@@ -249,6 +253,32 @@ void mul_F32_F32_F32_n(const void *__restrict input1_ptr, const void *__restrict
}
template <bool is_scale255, bool is_sat>
+void mul_F16_F16_F16_n(const void *__restrict input1_ptr, const void *__restrict input2_ptr, void *__restrict output_ptr, float scale)
+{
+ ARM_COMPUTE_UNUSED(input1_ptr);
+ ARM_COMPUTE_UNUSED(input2_ptr);
+ ARM_COMPUTE_UNUSED(output_ptr);
+#ifdef ARM_COMPUTE_ENABLE_FP16
+ const auto input1 = static_cast<const float16_t *__restrict>(input1_ptr);
+ const auto input2 = static_cast<const float16_t *__restrict>(input2_ptr);
+ const auto output = static_cast<float16_t *__restrict>(output_ptr);
+ const float16x8x2_t ta1 = vld2q_f16(input1);
+ const float16x8x2_t ta2 = vld2q_f16(input2);
+ const float16x8_t scale_vec = vdupq_n_f16(scale);
+ const float16x8x2_t result =
+ {
+ {
+ vmulq_f16(vmulq_f16(ta1.val[0], ta2.val[0]), scale_vec),
+ vmulq_f16(vmulq_f16(ta1.val[1], ta2.val[1]), scale_vec),
+ }
+ };
+ vst2q_f16(output, result);
+#else /* ARM_COMPUTE_ENABLE_FP16 */
+ ARM_COMPUTE_ERROR("Not supported. Recompile the library with arch=arm64-v8.2-a.");
+#endif /* ARM_COMPUTE_ENABLE_FP16 */
+}
+
+template <bool is_scale255, bool is_sat>
void mul_U8_U8_S16_n(const void *__restrict input1_ptr, const void *__restrict input2_ptr, void *__restrict output_ptr, int n)
{
const auto input1 = static_cast<const uint8_t *__restrict>(input1_ptr);
@@ -347,6 +377,10 @@ void NEPixelWiseMultiplicationKernel::configure(const ITensor *input1, const ITe
{
set_format_if_unknown(*output->info(), Format::F32);
}
+ else if(input1->info()->data_type() == DataType::F16 || input2->info()->data_type() == DataType::F16)
+ {
+ set_format_if_unknown(*output->info(), Format::F16);
+ }
else if(input1->info()->data_type() == DataType::QS8 && input2->info()->data_type() == DataType::QS8)
{
set_data_type_if_unknown(*output->info(), DataType::QS8);
@@ -355,9 +389,9 @@ void NEPixelWiseMultiplicationKernel::configure(const ITensor *input1, const ITe
}
ARM_COMPUTE_ERROR_ON_MISMATCHING_SHAPES(input1, input2, output);
- ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input1, 1, DataType::U8, DataType::QS8, DataType::S16, DataType::F32);
- ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input2, 1, DataType::U8, DataType::QS8, DataType::S16, DataType::F32);
- ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::U8, DataType::QS8, DataType::S16, DataType::F32);
+ ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input1, 1, DataType::U8, DataType::QS8, DataType::S16, DataType::F16, DataType::F32);
+ ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input2, 1, DataType::U8, DataType::QS8, DataType::S16, DataType::F16, DataType::F32);
+ ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::U8, DataType::QS8, DataType::S16, DataType::F16, DataType::F32);
ARM_COMPUTE_ERROR_ON_MSG(output->info()->data_type() == DataType::U8 && (input1->info()->data_type() != DataType::U8 || input2->info()->data_type() != DataType::U8),
"Output can only be U8 if both inputs are U8");
if(input1->info()->data_type() == DataType::QS8)
@@ -479,6 +513,11 @@ void NEPixelWiseMultiplicationKernel::configure(const ITensor *input1, const ITe
_func_q_int = is_sat ? &mul_QS8_QS8_QS8_n<false, true> : &mul_QS8_QS8_QS8_n<false, false>;
}
}
+ else if(DataType::F16 == dt_input1 && DataType::F16 == dt_input2 && DataType::F16 == dt_output)
+ {
+ _func_float = &mul_F16_F16_F16_n<false, false>;
+ _func_int = nullptr;
+ }
else if(DataType::F32 == dt_input1 && DataType::F32 == dt_input2 && DataType::F32 == dt_output)
{
_func_float = &mul_F32_F32_F32_n<false, false>;