aboutsummaryrefslogtreecommitdiff
path: root/src/core/NEON/kernels/NEReductionOperationKernel.cpp
diff options
context:
space:
mode:
author: Usama Arif <usama.arif@arm.com> 2019-05-20 12:38:33 +0100
committer: Usama Arif <usama.arif@arm.com> 2019-05-22 15:06:06 +0000
commit: a4a08ad5e33867f9938a3fbaf9b6dcc56ad8f7b5 (patch)
tree: 0689ecbe56aed40fd61fa250a4e8a7a98d549bc3 /src/core/NEON/kernels/NEReductionOperationKernel.cpp
parent: b28905010a95044c7a1c0a5665fc886521a56541 (diff)
download: ComputeLibrary-a4a08ad5e33867f9938a3fbaf9b6dcc56ad8f7b5.tar.gz
COMPMID-2280: Implement REDUCE_MIN operator for NEON
Change-Id: Iaa8d97e3328ce69dae7a97a7111120ecc61fb465
Signed-off-by: Usama Arif <usama.arif@arm.com>
Reviewed-on: https://review.mlplatform.org/c/1192
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Pablo Marquez <pablo.tello@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'src/core/NEON/kernels/NEReductionOperationKernel.cpp')
-rw-r--r-- src/core/NEON/kernels/NEReductionOperationKernel.cpp | 239
1 file changed, 172 insertions, 67 deletions
diff --git a/src/core/NEON/kernels/NEReductionOperationKernel.cpp b/src/core/NEON/kernels/NEReductionOperationKernel.cpp
index 5f0a4dd371..b51d4b311f 100644
--- a/src/core/NEON/kernels/NEReductionOperationKernel.cpp
+++ b/src/core/NEON/kernels/NEReductionOperationKernel.cpp
@@ -108,6 +108,34 @@ uint32x4x4_t calculate_index(uint32_t idx, uint8x16_t a, uint8x16_t b, uint32x4x
return res;
}
+
+// Horizontal minimum of a float32x4_t: wrapper::vpmin is applied to the high and low halves of 'in', then to that result paired with itself, so every lane of the returned float32x2_t holds the smallest input lane.
+float32x2_t calculate_min(float32x4_t in)
+{
+ auto pmin = wrapper::vpmin(wrapper::vgethigh(in), wrapper::vgetlow(in));
+ return wrapper::vpmin(pmin, pmin);
+}
+
+// Horizontal maximum of a float32x4_t: wrapper::vpmax is applied to the high and low halves of 'in', then to that result paired with itself, so every lane of the returned float32x2_t holds the largest input lane.
+float32x2_t calculate_max(float32x4_t in)
+{
+ auto pmax = wrapper::vpmax(wrapper::vgethigh(in), wrapper::vgetlow(in));
+ return wrapper::vpmax(pmax, pmax);
+}
+// Horizontal minimum of an int32x4_t (signed 32-bit overload): wrapper::vpmin is applied to the high and low halves of 'in', then to that result paired with itself, so every lane of the returned int32x2_t holds the smallest input lane.
+int32x2_t calculate_min(int32x4_t in)
+{
+ auto pmin = wrapper::vpmin(wrapper::vgethigh(in), wrapper::vgetlow(in));
+ return wrapper::vpmin(pmin, pmin);
+}
+
+// Horizontal maximum of an int32x4_t (signed 32-bit overload): wrapper::vpmax is applied to the high and low halves of 'in', then to that result paired with itself, so every lane of the returned int32x2_t holds the largest input lane.
+int32x2_t calculate_max(int32x4_t in)
+{
+ auto pmax = wrapper::vpmax(wrapper::vgethigh(in), wrapper::vgetlow(in));
+ return wrapper::vpmax(pmax, pmax);
+}
+
template <typename T>
uint32_t calculate_vector_index(uint32x4x4_t vec_res_idx, T vec_res_value, ReductionOperation op)
{
@@ -116,15 +144,13 @@ uint32_t calculate_vector_index(uint32x4x4_t vec_res_idx, T vec_res_value, Reduc
if(op == ReductionOperation::ARG_IDX_MIN)
{
- auto pmin = wrapper::vpmin(wrapper::vgethigh(vec_res_value), wrapper::vgetlow(vec_res_value));
- pmin = wrapper::vpmin(pmin, pmin);
+ auto pmin = calculate_min(vec_res_value);
auto mask = wrapper::vceq(vec_res_value, wrapper::vcombine(pmin, pmin));
res_idx_mask = wrapper::vand(vec_res_idx.val[0], mask);
}
else
{
- auto pmax = wrapper::vpmax(wrapper::vgethigh(vec_res_value), wrapper::vgetlow(vec_res_value));
- pmax = wrapper::vpmax(pmax, pmax);
+ auto pmax = calculate_max(vec_res_value);
auto mask = wrapper::vceq(vec_res_value, wrapper::vcombine(pmax, pmax));
res_idx_mask = wrapper::vand(vec_res_idx.val[0], mask);
}
@@ -137,6 +163,23 @@ uint32_t calculate_vector_index(uint32x4x4_t vec_res_idx, T vec_res_value, Reduc
return (res - 0xFFFFFFFF);
}
+// Horizontal minimum of a uint8x16_t: wrapper::vpmin is applied to the high and low halves of 'in', then three more times to the intermediate result paired with itself (four reduction steps for 16 lanes), so every lane of the returned uint8x8_t holds the smallest input lane.
+inline uint8x8_t calculate_min(uint8x16_t in)
+{
+ auto pmin = wrapper::vpmin(wrapper::vgethigh(in), wrapper::vgetlow(in));
+ pmin = wrapper::vpmin(pmin, pmin);
+ pmin = wrapper::vpmin(pmin, pmin);
+ return wrapper::vpmin(pmin, pmin);
+}
+// Horizontal maximum of a uint8x16_t: wrapper::vpmax is applied to the high and low halves of 'in', then three more times to the intermediate result paired with itself (four reduction steps for 16 lanes), so every lane of the returned uint8x8_t holds the largest input lane.
+inline uint8x8_t calculate_max(uint8x16_t in)
+{
+ auto pmax = wrapper::vpmax(wrapper::vgethigh(in), wrapper::vgetlow(in));
+ pmax = wrapper::vpmax(pmax, pmax);
+ pmax = wrapper::vpmax(pmax, pmax);
+ return wrapper::vpmax(pmax, pmax);
+}
+
uint32_t calculate_vector_index(uint32x4x4_t vec_res_idx, uint8x16_t vec_res_value, ReductionOperation op)
{
uint32x4x4_t res_idx_mask{ { 0 } };
@@ -144,18 +187,12 @@ uint32_t calculate_vector_index(uint32x4x4_t vec_res_idx, uint8x16_t vec_res_val
uint8x16_t mask_u8{ 0 };
if(op == ReductionOperation::ARG_IDX_MIN)
{
- auto pmin = wrapper::vpmin(wrapper::vgethigh(vec_res_value), wrapper::vgetlow(vec_res_value));
- pmin = wrapper::vpmin(pmin, pmin);
- pmin = wrapper::vpmin(pmin, pmin);
- pmin = wrapper::vpmin(pmin, pmin);
+ auto pmin = calculate_min(vec_res_value);
mask_u8 = wrapper::vceq(vec_res_value, wrapper::vcombine(pmin, pmin));
}
else
{
- auto pmax = wrapper::vpmax(wrapper::vgethigh(vec_res_value), wrapper::vgetlow(vec_res_value));
- pmax = wrapper::vpmax(pmax, pmax);
- pmax = wrapper::vpmax(pmax, pmax);
- pmax = wrapper::vpmax(pmax, pmax);
+ auto pmax = calculate_max(vec_res_value);
mask_u8 = wrapper::vceq(vec_res_value, wrapper::vcombine(pmax, pmax));
}
@@ -220,6 +257,21 @@ uint32x4x4_t calculate_index(uint32_t idx, float16x8_t a, float16x8_t b, uint32x
return res;
}
+// Horizontal minimum of a float16x8_t: wrapper::vpmin is applied to the high and low halves of 'in', then twice more to the intermediate result paired with itself (three reduction steps for 8 lanes), so every lane of the returned float16x4_t holds the smallest input lane.
+inline float16x4_t calculate_min(float16x8_t in)
+{
+ auto pmin = wrapper::vpmin(wrapper::vgethigh(in), wrapper::vgetlow(in));
+ pmin = wrapper::vpmin(pmin, pmin);
+ return wrapper::vpmin(pmin, pmin);
+}
+// Horizontal maximum of a float16x8_t: wrapper::vpmax is applied to the high and low halves of 'in', then twice more to the intermediate result paired with itself (three reduction steps for 8 lanes), so every lane of the returned float16x4_t holds the largest input lane.
+inline float16x4_t calculate_max(float16x8_t in)
+{
+ auto pmax = wrapper::vpmax(wrapper::vgethigh(in), wrapper::vgetlow(in));
+ pmax = wrapper::vpmax(pmax, pmax);
+ return wrapper::vpmax(pmax, pmax);
+}
+
uint32_t calculate_vector_index(uint32x4x4_t vec_res_idx, float16x8_t vec_res_value, ReductionOperation op)
{
uint32x4x2_t res_idx_mask{ 0 };
@@ -227,16 +279,12 @@ uint32_t calculate_vector_index(uint32x4x4_t vec_res_idx, float16x8_t vec_res_va
uint16x8_t mask_u16;
if(op == ReductionOperation::ARG_IDX_MIN)
{
- auto pmin = wrapper::vpmin(wrapper::vgethigh(vec_res_value), wrapper::vgetlow(vec_res_value));
- pmin = wrapper::vpmin(pmin, pmin);
- pmin = wrapper::vpmin(pmin, pmin);
+ auto pmin = calculate_min(vec_res_value);
mask_u16 = wrapper::vceq(vec_res_value, wrapper::vcombine(pmin, pmin));
}
else
{
- auto pmax = wrapper::vpmax(wrapper::vgethigh(vec_res_value), wrapper::vgetlow(vec_res_value));
- pmax = wrapper::vpmax(pmax, pmax);
- pmax = wrapper::vpmax(pmax, pmax);
+ auto pmax = calculate_max(vec_res_value);
mask_u16 = wrapper::vceq(vec_res_value, wrapper::vcombine(pmax, pmax));
}
@@ -364,13 +412,22 @@ struct RedOpX
{
ARM_COMPUTE_UNUSED(out_slice);
auto init_res_value = static_cast<T>(0.f);
- if(op == ReductionOperation::ARG_IDX_MAX || op == ReductionOperation::ARG_IDX_MIN)
- {
- init_res_value = *reinterpret_cast<T *>(input.ptr());
- }
- else if(op == ReductionOperation::PROD)
+ switch(op)
{
- init_res_value = static_cast<T>(1.f);
+ case ReductionOperation::ARG_IDX_MAX:
+ case ReductionOperation::ARG_IDX_MIN:
+ case ReductionOperation::MIN:
+ {
+ init_res_value = *reinterpret_cast<T *>(input.ptr());
+ break;
+ }
+ case ReductionOperation::PROD:
+ {
+ init_res_value = static_cast<T>(1.f);
+ break;
+ }
+ default:
+ break;
}
auto vec_res_value = wrapper::vdup_n(init_res_value, ExactTagType{});
uint32x4x4_t vec_res_idx{ { 0 } };
@@ -406,6 +463,11 @@ struct RedOpX
vec_res_value = temp_vec_res_value;
break;
}
+ case ReductionOperation::MIN:
+ {
+ vec_res_value = wrapper::vmin(vec_elements, vec_res_value);
+ break;
+ }
default:
ARM_COMPUTE_ERROR("Not supported");
}
@@ -451,6 +513,11 @@ struct RedOpX
*(reinterpret_cast<uint32_t *>(output.ptr())) = res;
break;
}
+ case ReductionOperation::MIN:
+ {
+ *(reinterpret_cast<T *>(output.ptr())) = wrapper::vgetlane(calculate_min(vec_res_value), 0);
+ break;
+ }
default:
ARM_COMPUTE_ERROR("Not supported");
}
@@ -474,7 +541,7 @@ struct RedOpX_qasymm8
uint8x16_t vec_res_value = { 0 };
- if(op == ReductionOperation::ARG_IDX_MAX || op == ReductionOperation::ARG_IDX_MIN)
+ if(op == ReductionOperation::ARG_IDX_MAX || op == ReductionOperation::ARG_IDX_MIN || op == ReductionOperation::MIN)
{
vec_res_value = wrapper::vdup_n(*input.ptr(), wrapper::traits::vector_128_tag{});
}
@@ -546,48 +613,64 @@ struct RedOpX_qasymm8
vec_res_value = temp_vec_res_value;
break;
}
+ case ReductionOperation::MIN:
+ {
+ vec_res_value = wrapper::vmin(vec_elements, vec_res_value);
+ break;
+ }
default:
ARM_COMPUTE_ERROR("Not supported");
}
},
input);
- if(op == ReductionOperation::ARG_IDX_MIN || op == ReductionOperation::ARG_IDX_MAX)
- {
- auto res = calculate_vector_index(vec_res_idx, vec_res_value, op);
- *(reinterpret_cast<uint32_t *>(output.ptr())) = res;
- }
- else if(op == ReductionOperation::PROD)
- {
- auto carry_res = wrapper::vmul(vec_res_value1_f, vec_res_value2_f);
- carry_res = wrapper::vmul(carry_res, vec_res_value3_f);
- carry_res = wrapper::vmul(carry_res, vec_res_value4_f);
-
- float res = wrapper::vgetlane(carry_res, 0);
- res *= wrapper::vgetlane(carry_res, 1);
- res *= wrapper::vgetlane(carry_res, 2);
- res *= wrapper::vgetlane(carry_res, 3);
-
- //re-quantize result
- res = sqcvt_qasymm8_f32(res, in_info.quantization_info().scale, in_info.quantization_info().offset);
- *(output.ptr()) = static_cast<uint8_t>(res);
- }
- else
+ switch(op)
{
- auto carry_res = wrapper::vadd(vec_res_value1, vec_res_value2);
- carry_res = wrapper::vadd(carry_res, vec_res_value3);
- carry_res = wrapper::vadd(carry_res, vec_res_value4);
-
- auto carry_paddition = wrapper::vpadd(wrapper::vgethigh(carry_res), wrapper::vgetlow(carry_res));
- carry_paddition = wrapper::vpadd(carry_paddition, carry_paddition);
- auto res = wrapper::vgetlane(carry_paddition, 0);
-
- if(op == ReductionOperation::MEAN_SUM)
+ case ReductionOperation::ARG_IDX_MIN:
+ case ReductionOperation::ARG_IDX_MAX:
{
- res /= in_info.dimension(0);
+ auto res = calculate_vector_index(vec_res_idx, vec_res_value, op);
+ *(reinterpret_cast<uint32_t *>(output.ptr())) = res;
+ break;
}
+ case ReductionOperation::MIN:
+ {
+ *(output.ptr()) = static_cast<uint8_t>(wrapper::vgetlane(calculate_min(vec_res_value), 0));
+ break;
+ }
+ case ReductionOperation::PROD:
+ {
+ auto carry_res = wrapper::vmul(vec_res_value1_f, vec_res_value2_f);
+ carry_res = wrapper::vmul(carry_res, vec_res_value3_f);
+ carry_res = wrapper::vmul(carry_res, vec_res_value4_f);
+
+ float res = wrapper::vgetlane(carry_res, 0);
+ res *= wrapper::vgetlane(carry_res, 1);
+ res *= wrapper::vgetlane(carry_res, 2);
+ res *= wrapper::vgetlane(carry_res, 3);
+
+ //re-quantize result
+ res = sqcvt_qasymm8_f32(res, in_info.quantization_info().scale, in_info.quantization_info().offset);
+ *(output.ptr()) = static_cast<uint8_t>(res);
+ break;
+ }
+ default:
+ {
+ auto carry_res = wrapper::vadd(vec_res_value1, vec_res_value2);
+ carry_res = wrapper::vadd(carry_res, vec_res_value3);
+ carry_res = wrapper::vadd(carry_res, vec_res_value4);
- *(output.ptr()) = static_cast<uint8_t>(res);
+ auto carry_paddition = wrapper::vpadd(wrapper::vgethigh(carry_res), wrapper::vgetlow(carry_res));
+ carry_paddition = wrapper::vpadd(carry_paddition, carry_paddition);
+ auto res = wrapper::vgetlane(carry_paddition, 0);
+
+ if(op == ReductionOperation::MEAN_SUM)
+ {
+ res /= in_info.dimension(0);
+ }
+
+ *(output.ptr()) = static_cast<uint8_t>(res);
+ }
}
}
};
@@ -606,17 +689,25 @@ struct RedOpYZW
execute_window_loop(in_slice, [&](const Coordinates &)
{
neon_vector vec_res_value = { 0 };
- if(op == ReductionOperation::ARG_IDX_MAX || op == ReductionOperation::ARG_IDX_MIN)
- {
- vec_res_value = wrapper::vloadq(reinterpret_cast<T *>(input.ptr()));
- }
- else if(op == ReductionOperation::PROD)
- {
- vec_res_value = wrapper::vdup_n(static_cast<T>(1.f), ExactTagType{});
- }
- else
+ switch(op)
{
- vec_res_value = wrapper::vdup_n(static_cast<T>(0.f), ExactTagType{});
+ case ReductionOperation::ARG_IDX_MAX:
+ case ReductionOperation::ARG_IDX_MIN:
+ case ReductionOperation::MIN:
+ {
+ vec_res_value = wrapper::vloadq(reinterpret_cast<T *>(input.ptr()));
+ break;
+ }
+ case ReductionOperation::PROD:
+ {
+ vec_res_value = wrapper::vdup_n(static_cast<T>(1.f), ExactTagType{});
+ break;
+ }
+ default:
+ {
+ vec_res_value = wrapper::vdup_n(static_cast<T>(0.f), ExactTagType{});
+ break;
+ }
}
uint32x4x4_t vec_res_idx{ { 0 } };
@@ -665,6 +756,11 @@ struct RedOpYZW
vec_res_value = temp_vec_res_value;
break;
}
+ case ReductionOperation::MIN:
+ {
+ vec_res_value = wrapper::vmin(vec_elements, vec_res_value);
+ break;
+ }
default:
ARM_COMPUTE_ERROR("Not supported");
}
@@ -849,6 +945,11 @@ struct RedOpYZW_qasymm8
vec_res_value = temp_vec_res_value;
break;
}
+ case ReductionOperation::MIN:
+ {
+ vec_res_value = wrapper::vmin(vec_elements, vec_res_value);
+ break;
+ }
default:
ARM_COMPUTE_ERROR("Not supported");
}
@@ -891,6 +992,10 @@ struct RedOpYZW_qasymm8
wrapper::vstore(reinterpret_cast<uint32_t *>(output.ptr()) + 8, vec_res_idx.val[2]);
wrapper::vstore(reinterpret_cast<uint32_t *>(output.ptr()) + 12, vec_res_idx.val[3]);
}
+ else if(op == ReductionOperation::ARG_IDX_MIN)
+ {
+ wrapper::vstore(output.ptr(), vec_res_value);
+ }
else
{
const auto temp16x8t_1 = vcombine_u16(wrapper::vqmovn(vec_res_value1), wrapper::vqmovn(vec_res_value2));