aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
author: Usama Arif <usama.arif@arm.com>, 2019-05-20 12:38:33 +0100
committer: Usama Arif <usama.arif@arm.com>, 2019-05-22 15:06:06 +0000
commit: a4a08ad5e33867f9938a3fbaf9b6dcc56ad8f7b5 (patch)
tree: 0689ecbe56aed40fd61fa250a4e8a7a98d549bc3
parent: b28905010a95044c7a1c0a5665fc886521a56541 (diff)
download: ComputeLibrary-a4a08ad5e33867f9938a3fbaf9b6dcc56ad8f7b5.tar.gz
COMPMID-2280: Implement REDUCE_MIN operator for NEON
Change-Id: Iaa8d97e3328ce69dae7a97a7111120ecc61fb465
Signed-off-by: Usama Arif <usama.arif@arm.com>
Reviewed-on: https://review.mlplatform.org/c/1192
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Pablo Marquez <pablo.tello@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
-rw-r--r-- arm_compute/core/Types.h 3
-rw-r--r-- src/core/NEON/kernels/NEReductionOperationKernel.cpp 239
-rw-r--r-- src/runtime/NEON/functions/NEReductionOperation.cpp 47
-rw-r--r-- tests/validation/NEON/ReductionOperation.cpp 3
-rw-r--r-- tests/validation/reference/ReductionOperation.cpp 31
-rw-r--r-- utils/TypePrinter.h 3
6 files changed, 255 insertions, 71 deletions
diff --git a/arm_compute/core/Types.h b/arm_compute/core/Types.h
index 83ab2d755a..241c1fe1f4 100644
--- a/arm_compute/core/Types.h
+++ b/arm_compute/core/Types.h
@@ -559,7 +559,8 @@ enum class ReductionOperation
MEAN_SUM, /**< Mean of sum */
PROD, /**< Product */
SUM_SQUARE, /**< Sum of squares */
- SUM /**< Sum */
+ SUM, /**< Sum */
+ MIN, /**< Min */
};
/** Available element-wise operations */
diff --git a/src/core/NEON/kernels/NEReductionOperationKernel.cpp b/src/core/NEON/kernels/NEReductionOperationKernel.cpp
index 5f0a4dd371..b51d4b311f 100644
--- a/src/core/NEON/kernels/NEReductionOperationKernel.cpp
+++ b/src/core/NEON/kernels/NEReductionOperationKernel.cpp
@@ -108,6 +108,34 @@ uint32x4x4_t calculate_index(uint32_t idx, uint8x16_t a, uint8x16_t b, uint32x4x
return res;
}
+
+// Helper function to calculate the minimum value of the input vector. All the elements in the output vector contain the min value.
+float32x2_t calculate_min(float32x4_t in)
+{
+ auto pmin = wrapper::vpmin(wrapper::vgethigh(in), wrapper::vgetlow(in));
+ return wrapper::vpmin(pmin, pmin);
+}
+
+// Helper function to calculate the maximum value of the input vector. All the elements in the output vector contain the max value.
+float32x2_t calculate_max(float32x4_t in)
+{
+ auto pmax = wrapper::vpmax(wrapper::vgethigh(in), wrapper::vgetlow(in));
+ return wrapper::vpmax(pmax, pmax);
+}
+// Helper function to calculate the minimum value of the input vector. All the elements in the output vector contain the min value.
+int32x2_t calculate_min(int32x4_t in)
+{
+ auto pmin = wrapper::vpmin(wrapper::vgethigh(in), wrapper::vgetlow(in));
+ return wrapper::vpmin(pmin, pmin);
+}
+
+// Helper function to calculate the maximum value of the input vector. All the elements in the output vector contain the max value.
+int32x2_t calculate_max(int32x4_t in)
+{
+ auto pmax = wrapper::vpmax(wrapper::vgethigh(in), wrapper::vgetlow(in));
+ return wrapper::vpmax(pmax, pmax);
+}
+
template <typename T>
uint32_t calculate_vector_index(uint32x4x4_t vec_res_idx, T vec_res_value, ReductionOperation op)
{
@@ -116,15 +144,13 @@ uint32_t calculate_vector_index(uint32x4x4_t vec_res_idx, T vec_res_value, Reduc
if(op == ReductionOperation::ARG_IDX_MIN)
{
- auto pmin = wrapper::vpmin(wrapper::vgethigh(vec_res_value), wrapper::vgetlow(vec_res_value));
- pmin = wrapper::vpmin(pmin, pmin);
+ auto pmin = calculate_min(vec_res_value);
auto mask = wrapper::vceq(vec_res_value, wrapper::vcombine(pmin, pmin));
res_idx_mask = wrapper::vand(vec_res_idx.val[0], mask);
}
else
{
- auto pmax = wrapper::vpmax(wrapper::vgethigh(vec_res_value), wrapper::vgetlow(vec_res_value));
- pmax = wrapper::vpmax(pmax, pmax);
+ auto pmax = calculate_max(vec_res_value);
auto mask = wrapper::vceq(vec_res_value, wrapper::vcombine(pmax, pmax));
res_idx_mask = wrapper::vand(vec_res_idx.val[0], mask);
}
@@ -137,6 +163,23 @@ uint32_t calculate_vector_index(uint32x4x4_t vec_res_idx, T vec_res_value, Reduc
return (res - 0xFFFFFFFF);
}
+// Helper function to calculate the minimum value of the input vector. All the elements in the output vector contain the min value.
+inline uint8x8_t calculate_min(uint8x16_t in)
+{
+ auto pmin = wrapper::vpmin(wrapper::vgethigh(in), wrapper::vgetlow(in));
+ pmin = wrapper::vpmin(pmin, pmin);
+ pmin = wrapper::vpmin(pmin, pmin);
+ return wrapper::vpmin(pmin, pmin);
+}
+// Helper function to calculate the maximum value of the input vector. All the elements in the output vector contain the max value.
+inline uint8x8_t calculate_max(uint8x16_t in)
+{
+ auto pmax = wrapper::vpmax(wrapper::vgethigh(in), wrapper::vgetlow(in));
+ pmax = wrapper::vpmax(pmax, pmax);
+ pmax = wrapper::vpmax(pmax, pmax);
+ return wrapper::vpmax(pmax, pmax);
+}
+
uint32_t calculate_vector_index(uint32x4x4_t vec_res_idx, uint8x16_t vec_res_value, ReductionOperation op)
{
uint32x4x4_t res_idx_mask{ { 0 } };
@@ -144,18 +187,12 @@ uint32_t calculate_vector_index(uint32x4x4_t vec_res_idx, uint8x16_t vec_res_val
uint8x16_t mask_u8{ 0 };
if(op == ReductionOperation::ARG_IDX_MIN)
{
- auto pmin = wrapper::vpmin(wrapper::vgethigh(vec_res_value), wrapper::vgetlow(vec_res_value));
- pmin = wrapper::vpmin(pmin, pmin);
- pmin = wrapper::vpmin(pmin, pmin);
- pmin = wrapper::vpmin(pmin, pmin);
+ auto pmin = calculate_min(vec_res_value);
mask_u8 = wrapper::vceq(vec_res_value, wrapper::vcombine(pmin, pmin));
}
else
{
- auto pmax = wrapper::vpmax(wrapper::vgethigh(vec_res_value), wrapper::vgetlow(vec_res_value));
- pmax = wrapper::vpmax(pmax, pmax);
- pmax = wrapper::vpmax(pmax, pmax);
- pmax = wrapper::vpmax(pmax, pmax);
+ auto pmax = calculate_max(vec_res_value);
mask_u8 = wrapper::vceq(vec_res_value, wrapper::vcombine(pmax, pmax));
}
@@ -220,6 +257,21 @@ uint32x4x4_t calculate_index(uint32_t idx, float16x8_t a, float16x8_t b, uint32x
return res;
}
+// Helper function to calculate the minimum value of the input vector. All the elements in the output vector contain the min value.
+inline float16x4_t calculate_min(float16x8_t in)
+{
+ auto pmin = wrapper::vpmin(wrapper::vgethigh(in), wrapper::vgetlow(in));
+ pmin = wrapper::vpmin(pmin, pmin);
+ return wrapper::vpmin(pmin, pmin);
+}
+// Helper function to calculate the maximum value of the input vector. All the elements in the output vector contain the max value.
+inline float16x4_t calculate_max(float16x8_t in)
+{
+ auto pmax = wrapper::vpmax(wrapper::vgethigh(in), wrapper::vgetlow(in));
+ pmax = wrapper::vpmax(pmax, pmax);
+ return wrapper::vpmax(pmax, pmax);
+}
+
uint32_t calculate_vector_index(uint32x4x4_t vec_res_idx, float16x8_t vec_res_value, ReductionOperation op)
{
uint32x4x2_t res_idx_mask{ 0 };
@@ -227,16 +279,12 @@ uint32_t calculate_vector_index(uint32x4x4_t vec_res_idx, float16x8_t vec_res_va
uint16x8_t mask_u16;
if(op == ReductionOperation::ARG_IDX_MIN)
{
- auto pmin = wrapper::vpmin(wrapper::vgethigh(vec_res_value), wrapper::vgetlow(vec_res_value));
- pmin = wrapper::vpmin(pmin, pmin);
- pmin = wrapper::vpmin(pmin, pmin);
+ auto pmin = calculate_min(vec_res_value);
mask_u16 = wrapper::vceq(vec_res_value, wrapper::vcombine(pmin, pmin));
}
else
{
- auto pmax = wrapper::vpmax(wrapper::vgethigh(vec_res_value), wrapper::vgetlow(vec_res_value));
- pmax = wrapper::vpmax(pmax, pmax);
- pmax = wrapper::vpmax(pmax, pmax);
+ auto pmax = calculate_max(vec_res_value);
mask_u16 = wrapper::vceq(vec_res_value, wrapper::vcombine(pmax, pmax));
}
@@ -364,13 +412,22 @@ struct RedOpX
{
ARM_COMPUTE_UNUSED(out_slice);
auto init_res_value = static_cast<T>(0.f);
- if(op == ReductionOperation::ARG_IDX_MAX || op == ReductionOperation::ARG_IDX_MIN)
- {
- init_res_value = *reinterpret_cast<T *>(input.ptr());
- }
- else if(op == ReductionOperation::PROD)
+ switch(op)
{
- init_res_value = static_cast<T>(1.f);
+ case ReductionOperation::ARG_IDX_MAX:
+ case ReductionOperation::ARG_IDX_MIN:
+ case ReductionOperation::MIN:
+ {
+ init_res_value = *reinterpret_cast<T *>(input.ptr());
+ break;
+ }
+ case ReductionOperation::PROD:
+ {
+ init_res_value = static_cast<T>(1.f);
+ break;
+ }
+ default:
+ break;
}
auto vec_res_value = wrapper::vdup_n(init_res_value, ExactTagType{});
uint32x4x4_t vec_res_idx{ { 0 } };
@@ -406,6 +463,11 @@ struct RedOpX
vec_res_value = temp_vec_res_value;
break;
}
+ case ReductionOperation::MIN:
+ {
+ vec_res_value = wrapper::vmin(vec_elements, vec_res_value);
+ break;
+ }
default:
ARM_COMPUTE_ERROR("Not supported");
}
@@ -451,6 +513,11 @@ struct RedOpX
*(reinterpret_cast<uint32_t *>(output.ptr())) = res;
break;
}
+ case ReductionOperation::MIN:
+ {
+ *(reinterpret_cast<T *>(output.ptr())) = wrapper::vgetlane(calculate_min(vec_res_value), 0);
+ break;
+ }
default:
ARM_COMPUTE_ERROR("Not supported");
}
@@ -474,7 +541,7 @@ struct RedOpX_qasymm8
uint8x16_t vec_res_value = { 0 };
- if(op == ReductionOperation::ARG_IDX_MAX || op == ReductionOperation::ARG_IDX_MIN)
+ if(op == ReductionOperation::ARG_IDX_MAX || op == ReductionOperation::ARG_IDX_MIN || op == ReductionOperation::MIN)
{
vec_res_value = wrapper::vdup_n(*input.ptr(), wrapper::traits::vector_128_tag{});
}
@@ -546,48 +613,64 @@ struct RedOpX_qasymm8
vec_res_value = temp_vec_res_value;
break;
}
+ case ReductionOperation::MIN:
+ {
+ vec_res_value = wrapper::vmin(vec_elements, vec_res_value);
+ break;
+ }
default:
ARM_COMPUTE_ERROR("Not supported");
}
},
input);
- if(op == ReductionOperation::ARG_IDX_MIN || op == ReductionOperation::ARG_IDX_MAX)
- {
- auto res = calculate_vector_index(vec_res_idx, vec_res_value, op);
- *(reinterpret_cast<uint32_t *>(output.ptr())) = res;
- }
- else if(op == ReductionOperation::PROD)
- {
- auto carry_res = wrapper::vmul(vec_res_value1_f, vec_res_value2_f);
- carry_res = wrapper::vmul(carry_res, vec_res_value3_f);
- carry_res = wrapper::vmul(carry_res, vec_res_value4_f);
-
- float res = wrapper::vgetlane(carry_res, 0);
- res *= wrapper::vgetlane(carry_res, 1);
- res *= wrapper::vgetlane(carry_res, 2);
- res *= wrapper::vgetlane(carry_res, 3);
-
- //re-quantize result
- res = sqcvt_qasymm8_f32(res, in_info.quantization_info().scale, in_info.quantization_info().offset);
- *(output.ptr()) = static_cast<uint8_t>(res);
- }
- else
+ switch(op)
{
- auto carry_res = wrapper::vadd(vec_res_value1, vec_res_value2);
- carry_res = wrapper::vadd(carry_res, vec_res_value3);
- carry_res = wrapper::vadd(carry_res, vec_res_value4);
-
- auto carry_paddition = wrapper::vpadd(wrapper::vgethigh(carry_res), wrapper::vgetlow(carry_res));
- carry_paddition = wrapper::vpadd(carry_paddition, carry_paddition);
- auto res = wrapper::vgetlane(carry_paddition, 0);
-
- if(op == ReductionOperation::MEAN_SUM)
+ case ReductionOperation::ARG_IDX_MIN:
+ case ReductionOperation::ARG_IDX_MAX:
{
- res /= in_info.dimension(0);
+ auto res = calculate_vector_index(vec_res_idx, vec_res_value, op);
+ *(reinterpret_cast<uint32_t *>(output.ptr())) = res;
+ break;
}
+ case ReductionOperation::MIN:
+ {
+ *(output.ptr()) = static_cast<uint8_t>(wrapper::vgetlane(calculate_min(vec_res_value), 0));
+ break;
+ }
+ case ReductionOperation::PROD:
+ {
+ auto carry_res = wrapper::vmul(vec_res_value1_f, vec_res_value2_f);
+ carry_res = wrapper::vmul(carry_res, vec_res_value3_f);
+ carry_res = wrapper::vmul(carry_res, vec_res_value4_f);
+
+ float res = wrapper::vgetlane(carry_res, 0);
+ res *= wrapper::vgetlane(carry_res, 1);
+ res *= wrapper::vgetlane(carry_res, 2);
+ res *= wrapper::vgetlane(carry_res, 3);
+
+ //re-quantize result
+ res = sqcvt_qasymm8_f32(res, in_info.quantization_info().scale, in_info.quantization_info().offset);
+ *(output.ptr()) = static_cast<uint8_t>(res);
+ break;
+ }
+ default:
+ {
+ auto carry_res = wrapper::vadd(vec_res_value1, vec_res_value2);
+ carry_res = wrapper::vadd(carry_res, vec_res_value3);
+ carry_res = wrapper::vadd(carry_res, vec_res_value4);
- *(output.ptr()) = static_cast<uint8_t>(res);
+ auto carry_paddition = wrapper::vpadd(wrapper::vgethigh(carry_res), wrapper::vgetlow(carry_res));
+ carry_paddition = wrapper::vpadd(carry_paddition, carry_paddition);
+ auto res = wrapper::vgetlane(carry_paddition, 0);
+
+ if(op == ReductionOperation::MEAN_SUM)
+ {
+ res /= in_info.dimension(0);
+ }
+
+ *(output.ptr()) = static_cast<uint8_t>(res);
+ }
}
}
};
@@ -606,17 +689,25 @@ struct RedOpYZW
execute_window_loop(in_slice, [&](const Coordinates &)
{
neon_vector vec_res_value = { 0 };
- if(op == ReductionOperation::ARG_IDX_MAX || op == ReductionOperation::ARG_IDX_MIN)
- {
- vec_res_value = wrapper::vloadq(reinterpret_cast<T *>(input.ptr()));
- }
- else if(op == ReductionOperation::PROD)
- {
- vec_res_value = wrapper::vdup_n(static_cast<T>(1.f), ExactTagType{});
- }
- else
+ switch(op)
{
- vec_res_value = wrapper::vdup_n(static_cast<T>(0.f), ExactTagType{});
+ case ReductionOperation::ARG_IDX_MAX:
+ case ReductionOperation::ARG_IDX_MIN:
+ case ReductionOperation::MIN:
+ {
+ vec_res_value = wrapper::vloadq(reinterpret_cast<T *>(input.ptr()));
+ break;
+ }
+ case ReductionOperation::PROD:
+ {
+ vec_res_value = wrapper::vdup_n(static_cast<T>(1.f), ExactTagType{});
+ break;
+ }
+ default:
+ {
+ vec_res_value = wrapper::vdup_n(static_cast<T>(0.f), ExactTagType{});
+ break;
+ }
}
uint32x4x4_t vec_res_idx{ { 0 } };
@@ -665,6 +756,11 @@ struct RedOpYZW
vec_res_value = temp_vec_res_value;
break;
}
+ case ReductionOperation::MIN:
+ {
+ vec_res_value = wrapper::vmin(vec_elements, vec_res_value);
+ break;
+ }
default:
ARM_COMPUTE_ERROR("Not supported");
}
@@ -849,6 +945,11 @@ struct RedOpYZW_qasymm8
vec_res_value = temp_vec_res_value;
break;
}
+ case ReductionOperation::MIN:
+ {
+ vec_res_value = wrapper::vmin(vec_elements, vec_res_value);
+ break;
+ }
default:
ARM_COMPUTE_ERROR("Not supported");
}
@@ -891,6 +992,10 @@ struct RedOpYZW_qasymm8
wrapper::vstore(reinterpret_cast<uint32_t *>(output.ptr()) + 8, vec_res_idx.val[2]);
wrapper::vstore(reinterpret_cast<uint32_t *>(output.ptr()) + 12, vec_res_idx.val[3]);
}
+ else if(op == ReductionOperation::MIN) // MIN keeps its running minimum in vec_res_value; arg-index results are the uint32 indices stored by the branch above
+ {
+ wrapper::vstore(output.ptr(), vec_res_value);
+ }
else
{
const auto temp16x8t_1 = vcombine_u16(wrapper::vqmovn(vec_res_value1), wrapper::vqmovn(vec_res_value2));
diff --git a/src/runtime/NEON/functions/NEReductionOperation.cpp b/src/runtime/NEON/functions/NEReductionOperation.cpp
index a0aed96521..81bb32f5dc 100644
--- a/src/runtime/NEON/functions/NEReductionOperation.cpp
+++ b/src/runtime/NEON/functions/NEReductionOperation.cpp
@@ -78,7 +78,52 @@ void NEReductionOperation::configure(ITensor *input, ITensor *output, unsigned i
{
// Configure fill border kernel
const BorderSize fill_border_size = _reduction_kernel.border_size();
- const PixelValue pixelValue = (op == ReductionOperation::PROD) ? PixelValue(1, input->info()->data_type(), input->info()->quantization_info()) : PixelValue(0, input->info()->data_type());
+ PixelValue pixelValue;
+ switch(op)
+ {
+ case ReductionOperation::PROD:
+ {
+ pixelValue = PixelValue(1, input->info()->data_type(), input->info()->quantization_info());
+ break;
+ }
+ case ReductionOperation::MIN:
+ {
+ switch(input->info()->data_type())
+ {
+ case DataType::F32:
+ {
+ pixelValue = PixelValue(std::numeric_limits<float>::max());
+ break;
+ }
+ case DataType::F16:
+ {
+ pixelValue = PixelValue(static_cast<half>(65504.0f));
+ break;
+ }
+ case DataType::QASYMM8:
+ {
+ pixelValue = PixelValue(255, input->info()->data_type(), input->info()->quantization_info());
+ break;
+ }
+ default:
+ {
+ ARM_COMPUTE_ERROR("Unsupported DataType");
+ }
+ }
+ break;
+ }
+ case ReductionOperation::ARG_IDX_MAX:
+ case ReductionOperation::ARG_IDX_MIN:
+ case ReductionOperation::MEAN_SUM:
+ case ReductionOperation::SUM_SQUARE:
+ case ReductionOperation::SUM:
+ {
+ pixelValue = PixelValue(0, input->info()->data_type());
+ break;
+ }
+ default:
+ ARM_COMPUTE_ERROR("Reduction Operation unsupported");
+ }
_fill_border_kernel.configure(input, fill_border_size, BorderMode::CONSTANT, pixelValue);
}
}
diff --git a/tests/validation/NEON/ReductionOperation.cpp b/tests/validation/NEON/ReductionOperation.cpp
index b9b4983ae6..074689d678 100644
--- a/tests/validation/NEON/ReductionOperation.cpp
+++ b/tests/validation/NEON/ReductionOperation.cpp
@@ -51,7 +51,8 @@ RelativeTolerance<float> tolerance_qasymm8(1);
const auto ReductionOperations = framework::dataset::make("ReductionOperation",
{
ReductionOperation::SUM,
- ReductionOperation::PROD
+ ReductionOperation::PROD,
+ ReductionOperation::MIN,
});
const auto QuantizationInfos = framework::dataset::make("QuantizationInfo",
diff --git a/tests/validation/reference/ReductionOperation.cpp b/tests/validation/reference/ReductionOperation.cpp
index c7624a4628..1f825f0e0f 100644
--- a/tests/validation/reference/ReductionOperation.cpp
+++ b/tests/validation/reference/ReductionOperation.cpp
@@ -42,7 +42,24 @@ template <typename T, typename OT>
OT reduce_operation(const T *ptr, int reduce_elements, ReductionOperation op, int stride)
{
using type = typename std::remove_cv<OT>::type;
- auto res = (op == ReductionOperation::PROD) ? type(1) : type(0);
+ T res;
+ switch(op)
+ {
+ case ReductionOperation::PROD:
+ {
+ res = type(1);
+ }
+ break;
+ case ReductionOperation::MIN:
+ {
+ res = *ptr;
+ }
+ break;
+ default:
+ {
+ res = type(0);
+ }
+ }
if(std::is_integral<type>::value)
{
@@ -65,6 +82,12 @@ OT reduce_operation(const T *ptr, int reduce_elements, ReductionOperation op, in
int_res = static_cast<uint32_t>(i);
}
break;
+ case ReductionOperation::MIN:
+ if(static_cast<T>(int_res) > elem)
+ {
+ int_res = elem;
+ }
+ break;
case ReductionOperation::SUM_SQUARE:
int_res += elem * elem;
break;
@@ -104,6 +127,12 @@ OT reduce_operation(const T *ptr, int reduce_elements, ReductionOperation op, in
res = static_cast<uint32_t>(i);
}
break;
+ case ReductionOperation::MIN:
+ if(res > elem)
+ {
+ res = elem;
+ }
+ break;
case ReductionOperation::SUM_SQUARE:
res += elem * elem;
break;
diff --git a/utils/TypePrinter.h b/utils/TypePrinter.h
index b8927615ee..9b8efe5a23 100644
--- a/utils/TypePrinter.h
+++ b/utils/TypePrinter.h
@@ -1446,6 +1446,9 @@ inline ::std::ostream &operator<<(::std::ostream &os, const ReductionOperation &
case ReductionOperation::PROD:
os << "PROD";
break;
+ case ReductionOperation::MIN:
+ os << "MIN";
+ break;
default:
ARM_COMPUTE_ERROR("NOT_SUPPORTED!");
}