From 28f0dd99fba11ed9b7165eca17d801bdfb421576 Mon Sep 17 00:00:00 2001 From: Usama Arif Date: Mon, 20 May 2019 13:44:34 +0100 Subject: COMPMID-2279: Implement REDUCE_MAX operator for NEON Change-Id: Iccd25b8aab1dd871c0d86ec3816b1cbf48370066 Signed-off-by: Usama Arif Reviewed-on: https://review.mlplatform.org/c/1193 Comments-Addressed: Arm Jenkins Tested-by: Arm Jenkins Reviewed-by: Pablo Marquez --- arm_compute/core/Types.h | 1 + .../NEON/kernels/NEReductionOperationKernel.cpp | 34 +++++++++++++++++++++- .../NEON/functions/NEReductionOperation.cpp | 26 +++++++++++++++++ tests/validation/NEON/ReductionOperation.cpp | 1 + tests/validation/reference/ReductionOperation.cpp | 13 +++++++++ utils/TypePrinter.h | 3 ++ 6 files changed, 77 insertions(+), 1 deletion(-) diff --git a/arm_compute/core/Types.h b/arm_compute/core/Types.h index 241c1fe1f4..65db06b878 100644 --- a/arm_compute/core/Types.h +++ b/arm_compute/core/Types.h @@ -561,6 +561,7 @@ enum class ReductionOperation SUM_SQUARE, /**< Sum of squares */ SUM, /**< Sum */ MIN, /**< Min */ + MAX, /**< Max */ }; /** Available element-wise operations */ diff --git a/src/core/NEON/kernels/NEReductionOperationKernel.cpp b/src/core/NEON/kernels/NEReductionOperationKernel.cpp index b51d4b311f..e6edf22083 100644 --- a/src/core/NEON/kernels/NEReductionOperationKernel.cpp +++ b/src/core/NEON/kernels/NEReductionOperationKernel.cpp @@ -417,6 +417,7 @@ struct RedOpX case ReductionOperation::ARG_IDX_MAX: case ReductionOperation::ARG_IDX_MIN: case ReductionOperation::MIN: + case ReductionOperation::MAX: { init_res_value = *reinterpret_cast(input.ptr()); break; @@ -468,6 +469,11 @@ struct RedOpX vec_res_value = wrapper::vmin(vec_elements, vec_res_value); break; } + case ReductionOperation::MAX: + { + vec_res_value = wrapper::vmax(vec_elements, vec_res_value); + break; + } default: ARM_COMPUTE_ERROR("Not supported"); } @@ -518,6 +524,11 @@ struct RedOpX *(reinterpret_cast(output.ptr())) = wrapper::vgetlane(calculate_min(vec_res_value), 0); break; } + case ReductionOperation::MAX: + { + *(reinterpret_cast(output.ptr())) = wrapper::vgetlane(calculate_max(vec_res_value), 0); + break; + } default: ARM_COMPUTE_ERROR("Not supported"); } @@ -541,7 +552,7 @@ struct RedOpX_qasymm8 uint8x16_t vec_res_value = { 0 }; - if(op == ReductionOperation::ARG_IDX_MAX || op == ReductionOperation::ARG_IDX_MIN || op == ReductionOperation::MIN) + if(op == ReductionOperation::ARG_IDX_MAX || op == ReductionOperation::ARG_IDX_MIN || op == ReductionOperation::MIN || op == ReductionOperation::MAX) { vec_res_value = wrapper::vdup_n(*input.ptr(), wrapper::traits::vector_128_tag{}); } @@ -618,6 +629,11 @@ struct RedOpX_qasymm8 vec_res_value = wrapper::vmin(vec_elements, vec_res_value); break; } + case ReductionOperation::MAX: + { + vec_res_value = wrapper::vmax(vec_elements, vec_res_value); + break; + } default: ARM_COMPUTE_ERROR("Not supported"); } @@ -638,6 +654,11 @@ struct RedOpX_qasymm8 *(output.ptr()) = static_cast(wrapper::vgetlane(calculate_min(vec_res_value), 0)); break; } + case ReductionOperation::MAX: + { + *(output.ptr()) = static_cast(wrapper::vgetlane(calculate_max(vec_res_value), 0)); + break; + } case ReductionOperation::PROD: { auto carry_res = wrapper::vmul(vec_res_value1_f, vec_res_value2_f); @@ -694,6 +715,7 @@ struct RedOpYZW case ReductionOperation::ARG_IDX_MAX: case ReductionOperation::ARG_IDX_MIN: case ReductionOperation::MIN: + case ReductionOperation::MAX: { vec_res_value = wrapper::vloadq(reinterpret_cast(input.ptr())); break; @@ -761,6 +783,11 @@ struct RedOpYZW vec_res_value = wrapper::vmin(vec_elements, vec_res_value); break; } + case ReductionOperation::MAX: + { + vec_res_value = wrapper::vmax(vec_elements, vec_res_value); + break; + } default: ARM_COMPUTE_ERROR("Not supported"); } @@ -950,6 +977,11 @@ struct RedOpYZW_qasymm8 vec_res_value = wrapper::vmin(vec_elements, vec_res_value); break; } + case ReductionOperation::MAX: + { + vec_res_value = wrapper::vmax(vec_elements, vec_res_value); + break; + } default: ARM_COMPUTE_ERROR("Not supported"); } diff --git a/src/runtime/NEON/functions/NEReductionOperation.cpp b/src/runtime/NEON/functions/NEReductionOperation.cpp index 81bb32f5dc..dc6cf59019 100644 --- a/src/runtime/NEON/functions/NEReductionOperation.cpp +++ b/src/runtime/NEON/functions/NEReductionOperation.cpp @@ -112,6 +112,32 @@ void NEReductionOperation::configure(ITensor *input, ITensor *output, unsigned i } break; } + case ReductionOperation::MAX: + { + switch(input->info()->data_type()) + { + case DataType::F32: + { + pixelValue = PixelValue(-std::numeric_limits::max()); + break; + } + case DataType::F16: + { + pixelValue = PixelValue(static_cast(-65504.0f)); + break; + } + case DataType::QASYMM8: + { + pixelValue = PixelValue(0, input->info()->data_type(), input->info()->quantization_info()); + break; + } + default: + { + ARM_COMPUTE_ERROR("Unsupported DataType"); + } + } + break; + } case ReductionOperation::ARG_IDX_MAX: case ReductionOperation::ARG_IDX_MIN: case ReductionOperation::MEAN_SUM: diff --git a/tests/validation/NEON/ReductionOperation.cpp b/tests/validation/NEON/ReductionOperation.cpp index 074689d678..5b697a5efa 100644 --- a/tests/validation/NEON/ReductionOperation.cpp +++ b/tests/validation/NEON/ReductionOperation.cpp @@ -53,6 +53,7 @@ const auto ReductionOperations = framework::dataset::make("ReductionOperation", ReductionOperation::SUM, ReductionOperation::PROD, ReductionOperation::MIN, + ReductionOperation::MAX, }); const auto QuantizationInfos = framework::dataset::make("QuantizationInfo", diff --git a/tests/validation/reference/ReductionOperation.cpp b/tests/validation/reference/ReductionOperation.cpp index 1f825f0e0f..571b991b92 100644 --- a/tests/validation/reference/ReductionOperation.cpp +++ b/tests/validation/reference/ReductionOperation.cpp @@ -51,6 +51,7 @@ OT reduce_operation(const T *ptr, int reduce_elements, ReductionOperation op, in } break; case ReductionOperation::MIN: + case ReductionOperation::MAX: { res = *ptr; } @@ -88,6 +89,12 @@ OT reduce_operation(const T *ptr, int reduce_elements, ReductionOperation op, in int_res = elem; } break; + case ReductionOperation::MAX: + if(static_cast(int_res) < elem) + { + int_res = elem; + } + break; case ReductionOperation::SUM_SQUARE: int_res += elem * elem; break; @@ -133,6 +140,12 @@ OT reduce_operation(const T *ptr, int reduce_elements, ReductionOperation op, in res = elem; } break; + case ReductionOperation::MAX: + if(res < elem) + { + res = elem; + } + break; case ReductionOperation::SUM_SQUARE: res += elem * elem; break; diff --git a/utils/TypePrinter.h b/utils/TypePrinter.h index 9b8efe5a23..74dd0bbc35 100644 --- a/utils/TypePrinter.h +++ b/utils/TypePrinter.h @@ -1449,6 +1449,9 @@ inline ::std::ostream &operator<<(::std::ostream &os, const ReductionOperation & case ReductionOperation::MIN: os << "MIN"; break; + case ReductionOperation::MAX: + os << "MAX"; + break; default: ARM_COMPUTE_ERROR("NOT_SUPPORTED!"); } -- cgit v1.2.1