about | summary | refs | log | tree | commit | diff
path: root/src/core/NEON/kernels/NEPixelWiseMultiplicationKernel.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/NEON/kernels/NEPixelWiseMultiplicationKernel.cpp')
-rw-r--r-- src/core/NEON/kernels/NEPixelWiseMultiplicationKernel.cpp 175
1 files changed, 140 insertions, 35 deletions
diff --git a/src/core/NEON/kernels/NEPixelWiseMultiplicationKernel.cpp b/src/core/NEON/kernels/NEPixelWiseMultiplicationKernel.cpp
index a87588dbb3..ca59e66293 100644
--- a/src/core/NEON/kernels/NEPixelWiseMultiplicationKernel.cpp
+++ b/src/core/NEON/kernels/NEPixelWiseMultiplicationKernel.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2016-2019 ARM Limited.
+ * Copyright (c) 2016-2020 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -24,23 +24,12 @@
#include "arm_compute/core/NEON/kernels/NEPixelWiseMultiplicationKernel.h"
#include "arm_compute/core/CPP/Validate.h"
-#include "arm_compute/core/Error.h"
-#include "arm_compute/core/Helpers.h"
-#include "arm_compute/core/IAccessWindow.h"
-#include "arm_compute/core/ITensor.h"
#include "arm_compute/core/NEON/NEAsymm.h"
-#include "arm_compute/core/NEON/NEFixedPoint.h"
#include "arm_compute/core/NEON/NESymm.h"
#include "arm_compute/core/NEON/wrapper/wrapper.h"
#include "arm_compute/core/TensorInfo.h"
-#include "arm_compute/core/Types.h"
-#include "arm_compute/core/Validate.h"
#include <arm_neon.h>
-#include <climits>
-#include <cmath>
-#include <cstdint>
-#include <cstdlib>
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
#include <arm_fp16.h> // needed for float16_t
@@ -48,8 +37,6 @@
namespace arm_compute
{
-class Coordinates;
-
namespace
{
const float scale255_constant = 1.f / 255.f;
@@ -66,10 +53,9 @@ inline Status validate_arguments(const ITensorInfo *input1, const ITensorInfo *i
ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(input1);
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input1, 1, DataType::U8, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::S16, DataType::QSYMM16, DataType::F16, DataType::F32);
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input2, 1, DataType::U8, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::S16, DataType::QSYMM16, DataType::F16, DataType::F32);
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::U8, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::S16, DataType::QSYMM16, DataType::F16, DataType::F32);
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(output->data_type() == DataType::U8 && (input1->data_type() != DataType::U8 || input2->data_type() != DataType::U8),
- "Output can only be U8 if both inputs are U8");
-
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::U8, DataType::QASYMM8, DataType::QASYMM8_SIGNED,
+ DataType::S16, DataType::QSYMM16,
+ DataType::S32, DataType::F16, DataType::F32);
if(is_data_type_quantized(input1->data_type()) || is_data_type_quantized(input2->data_type()))
{
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input1, input2);
@@ -86,6 +72,12 @@ inline Status validate_arguments(const ITensorInfo *input1, const ITensorInfo *i
const TensorShape &out_shape = TensorShape::broadcast_shape(input1->tensor_shape(), input2->tensor_shape());
ARM_COMPUTE_RETURN_ERROR_ON_MSG(detail::have_different_dimensions(out_shape, output->tensor_shape(), 0), "Wrong shape for output");
ARM_COMPUTE_RETURN_ERROR_ON_MSG(out_shape.total_size() == 0, "Inputs are not broadcast compatible");
+
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(output->data_type() == DataType::U8 && (input1->data_type() != DataType::U8 || input2->data_type() != DataType::U8),
+ "Output can only be U8 if both inputs are U8");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(output->data_type() == DataType::S32 && (input1->data_type() != DataType::QSYMM16 || input2->data_type() != DataType::QSYMM16),
+ "Output can only be S32 if both inputs are QSYMM16");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(output->data_type() == DataType::S32 && scale != 1.f, "Unsupported scale for QSYMM16 inputs and S32 output");
}
if(std::abs(scale - scale255_constant) < 0.00001f)
@@ -266,8 +258,20 @@ void mul_saturate_QSYMM16_QSYMM16_QSYMM16_n(const void *__restrict input1_ptr, c
const auto input2 = static_cast<const qsymm16_t *__restrict>(input2_ptr);
const auto output = static_cast<qsymm16_t *__restrict>(output_ptr);
- const qsymm16x8x2_t input1_q = vld2q_s16(input1);
- const qsymm16x8x2_t input2_q = vld2q_s16(input2);
+ const qsymm16x8x2_t input1_q =
+ {
+ {
+ vld1q_s16(input1),
+ vld1q_s16(input1 + 8),
+ }
+ };
+ const qsymm16x8x2_t input2_q =
+ {
+ {
+ vld1q_s16(input2),
+ vld1q_s16(input2 + 8),
+ }
+ };
// Dequantitize inputs
const float32x4x4_t in1_f32x4x4 = vdequantize(input1_q, input1_qua_info);
@@ -284,7 +288,65 @@ void mul_saturate_QSYMM16_QSYMM16_QSYMM16_n(const void *__restrict input1_ptr, c
};
const qsymm16x8x2_t result = vquantize_qsymm16(out_f32x4x4, tmp_qua_info);
- vst2q_s16(output, result);
+ vst1q_s16(output, result.val[0]);
+ vst1q_s16(output + 8, result.val[1]);
+}
+
+void mul_QSYMM16_QSYMM16_S32_n(const void *__restrict input1_ptr, const void *__restrict input2_ptr, void *__restrict output_ptr, int scale)
+{
+ ARM_COMPUTE_UNUSED(scale);
+ const auto input1 = static_cast<const qsymm16_t *__restrict>(input1_ptr);
+ const auto input2 = static_cast<const qsymm16_t *__restrict>(input2_ptr);
+ const auto output = static_cast<int32_t *__restrict>(output_ptr);
+
+ const qsymm16x8x2_t input1_q =
+ {
+ {
+ vld1q_s16(input1),
+ vld1q_s16(input1 + 8),
+ }
+ };
+ const qsymm16x8x2_t input2_q =
+ {
+ {
+ vld1q_s16(input2),
+ vld1q_s16(input2 + 8),
+ }
+ };
+
+ const int32x4x4_t in1_s32 =
+ {
+ {
+ vmovl_s16(vget_low_s16(input1_q.val[0])),
+ vmovl_s16(vget_high_s16(input1_q.val[0])),
+ vmovl_s16(vget_low_s16(input1_q.val[1])),
+ vmovl_s16(vget_high_s16(input1_q.val[1])),
+ }
+ };
+ const int32x4x4_t in2_s32 =
+ {
+ {
+ vmovl_s16(vget_low_s16(input2_q.val[0])),
+ vmovl_s16(vget_high_s16(input2_q.val[0])),
+ vmovl_s16(vget_low_s16(input2_q.val[1])),
+ vmovl_s16(vget_high_s16(input2_q.val[1])),
+ }
+ };
+
+ const int32x4x4_t result =
+ {
+ {
+ vmulq_s32(in1_s32.val[0], in2_s32.val[0]),
+ vmulq_s32(in1_s32.val[1], in2_s32.val[1]),
+ vmulq_s32(in1_s32.val[2], in2_s32.val[2]),
+ vmulq_s32(in1_s32.val[3], in2_s32.val[3]),
+ }
+ };
+
+ vst1q_s32(output, result.val[0]);
+ vst1q_s32(output + 4, result.val[1]);
+ vst1q_s32(output + 8, result.val[2]);
+ vst1q_s32(output + 12, result.val[3]);
}
template <bool is_scale255, bool is_sat>
@@ -412,11 +474,24 @@ void mul_S16_S16_S16_n(const void *__restrict input1_ptr, const void *__restrict
const auto input2 = static_cast<const int16_t *__restrict>(input2_ptr);
const auto output = static_cast<int16_t *__restrict>(output_ptr);
- const int16x8x2_t ta1 = vld2q_s16(input1);
- const int16x8x2_t ta2 = vld2q_s16(input2);
+ const int16x8x2_t ta1 =
+ {
+ {
+ vld1q_s16(input1),
+ vld1q_s16(input1 + 8),
+ }
+ };
+ const int16x8x2_t ta2 =
+ {
+ {
+ vld1q_s16(input2),
+ vld1q_s16(input2 + 8),
+ }
+ };
const int16x8x2_t result = mul_S16_S16_S16_n_k<is_scale255, is_sat>(ta1, ta2, n);
- vst2q_s16(output, result);
+ vst1q_s16(output, result.val[0]);
+ vst1q_s16(output + 8, result.val[1]);
}
void mul_F32_F32_F32_n(const void *__restrict input1_ptr, const void *__restrict input2_ptr, void *__restrict output_ptr, float scale)
@@ -472,11 +547,23 @@ void c_mul_F32_F32_F32_n(const void *__restrict input1_ptr, const void *__restri
void mul_F16_F16_F16_n(const void *__restrict input1_ptr, const void *__restrict input2_ptr, void *__restrict output_ptr, float scale)
{
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
- const auto input1 = static_cast<const float16_t *__restrict>(input1_ptr);
- const auto input2 = static_cast<const float16_t *__restrict>(input2_ptr);
- const auto output = static_cast<float16_t *__restrict>(output_ptr);
- const float16x8x2_t ta1 = vld2q_f16(input1);
- const float16x8x2_t ta2 = vld2q_f16(input2);
+ const auto input1 = static_cast<const float16_t *__restrict>(input1_ptr);
+ const auto input2 = static_cast<const float16_t *__restrict>(input2_ptr);
+ const auto output = static_cast<float16_t *__restrict>(output_ptr);
+ const float16x8x2_t ta1 =
+ {
+ {
+ vld1q_f16(input1),
+ vld1q_f16(input1 + 8),
+ }
+ };
+ const float16x8x2_t ta2 =
+ {
+ {
+ vld1q_f16(input2),
+ vld1q_f16(input2 + 8),
+ }
+ };
const float16x8_t scale_vec = vdupq_n_f16(scale);
const float16x8x2_t result =
{
@@ -485,7 +572,8 @@ void mul_F16_F16_F16_n(const void *__restrict input1_ptr, const void *__restrict
vmulq_f16(vmulq_f16(ta1.val[1], ta2.val[1]), scale_vec),
}
};
- vst2q_f16(output, result);
+ vst1q_f16(output, result.val[0]);
+ vst1q_f16(output + 8, result.val[1]);
#else /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
ARM_COMPUTE_UNUSED(input1_ptr);
ARM_COMPUTE_UNUSED(input2_ptr);
@@ -550,8 +638,20 @@ void mul_S16_U8_S16_n(const void *__restrict input1_ptr, const void *__restrict
const auto input2 = static_cast<const uint8_t *__restrict>(input2_ptr);
const auto output = static_cast<int16_t *__restrict>(output_ptr);
- const int16x8x2_t ta1 = vld2q_s16(input1);
- const uint8x8x2_t ta2u = vld2_u8(input2);
+ const int16x8x2_t ta1 =
+ {
+ {
+ vld1q_s16(input1),
+ vld1q_s16(input1 + 8),
+ }
+ };
+ const uint8x8x2_t ta2u =
+ {
+ {
+ vld1_u8(input2),
+ vld1_u8(input2 + 8),
+ }
+ };
const int16x8x2_t ta2 =
{
{
@@ -562,7 +662,8 @@ void mul_S16_U8_S16_n(const void *__restrict input1_ptr, const void *__restrict
const int16x8x2_t result = mul_S16_S16_S16_n_k<is_scale255, is_sat>(ta1, ta2, n);
- vst2q_s16(output, result);
+ vst1q_s16(output, result.val[0]);
+ vst1q_s16(output + 8, result.val[1]);
}
template <bool is_scale255, bool is_sat>
@@ -629,10 +730,14 @@ void NEPixelWiseMultiplicationKernel::configure(const ITensor *input1, const ITe
{
_func_quantized = &mul_saturate_QASYMM8_SIGNED_QASYMM8_SIGNED_QASYMM8_SIGNED_n;
}
- else if(dt_input1 == DataType::QSYMM16 && dt_input2 == DataType::QSYMM16)
+ else if(dt_input1 == DataType::QSYMM16 && dt_input2 == DataType::QSYMM16 && dt_output == DataType::QSYMM16)
{
_func_quantized = &mul_saturate_QSYMM16_QSYMM16_QSYMM16_n;
}
+ else if(dt_input1 == DataType::QSYMM16 && dt_input2 == DataType::QSYMM16 && dt_output == DataType::S32)
+ {
+ _func_int = &mul_QSYMM16_QSYMM16_S32_n;
+ }
else if(DataType::U8 == dt_input1 && DataType::U8 == dt_input2 && DataType::U8 == dt_output)
{
if(is_scale_255)
@@ -750,7 +855,7 @@ void NEPixelWiseMultiplicationKernel::run(const Window &window, const ThreadInfo
Iterator input2(_input2, slice_input2);
Iterator output(_output, slice);
- if(is_data_type_quantized(_input1->info()->data_type()))
+ if((_run_optimized_qasymm8) || (_func_quantized != nullptr))
{
if(_run_optimized_qasymm8)
{