From 28f0dd99fba11ed9b7165eca17d801bdfb421576 Mon Sep 17 00:00:00 2001 From: Usama Arif Date: Mon, 20 May 2019 13:44:34 +0100 Subject: COMPMID-2279: Implement REDUCE_MAX operator for NEON Change-Id: Iccd25b8aab1dd871c0d86ec3816b1cbf48370066 Signed-off-by: Usama Arif Reviewed-on: https://review.mlplatform.org/c/1193 Comments-Addressed: Arm Jenkins Tested-by: Arm Jenkins Reviewed-by: Pablo Marquez --- .../NEON/kernels/NEReductionOperationKernel.cpp | 34 +++++++++++++++++++++- 1 file changed, 33 insertions(+), 1 deletion(-) (limited to 'src/core/NEON/kernels/NEReductionOperationKernel.cpp') diff --git a/src/core/NEON/kernels/NEReductionOperationKernel.cpp b/src/core/NEON/kernels/NEReductionOperationKernel.cpp index b51d4b311f..e6edf22083 100644 --- a/src/core/NEON/kernels/NEReductionOperationKernel.cpp +++ b/src/core/NEON/kernels/NEReductionOperationKernel.cpp @@ -417,6 +417,7 @@ struct RedOpX case ReductionOperation::ARG_IDX_MAX: case ReductionOperation::ARG_IDX_MIN: case ReductionOperation::MIN: + case ReductionOperation::MAX: { init_res_value = *reinterpret_cast(input.ptr()); break; @@ -468,6 +469,11 @@ struct RedOpX vec_res_value = wrapper::vmin(vec_elements, vec_res_value); break; } + case ReductionOperation::MAX: + { + vec_res_value = wrapper::vmax(vec_elements, vec_res_value); + break; + } default: ARM_COMPUTE_ERROR("Not supported"); } @@ -518,6 +524,11 @@ struct RedOpX *(reinterpret_cast(output.ptr())) = wrapper::vgetlane(calculate_min(vec_res_value), 0); break; } + case ReductionOperation::MAX: + { + *(reinterpret_cast(output.ptr())) = wrapper::vgetlane(calculate_max(vec_res_value), 0); + break; + } default: ARM_COMPUTE_ERROR("Not supported"); } @@ -541,7 +552,7 @@ struct RedOpX_qasymm8 uint8x16_t vec_res_value = { 0 }; - if(op == ReductionOperation::ARG_IDX_MAX || op == ReductionOperation::ARG_IDX_MIN || op == ReductionOperation::MIN) + if(op == ReductionOperation::ARG_IDX_MAX || op == ReductionOperation::ARG_IDX_MIN || op == ReductionOperation::MIN || op == ReductionOperation::MAX) { vec_res_value = wrapper::vdup_n(*input.ptr(), wrapper::traits::vector_128_tag{}); } @@ -618,6 +629,11 @@ struct RedOpX_qasymm8 vec_res_value = wrapper::vmin(vec_elements, vec_res_value); break; } + case ReductionOperation::MAX: + { + vec_res_value = wrapper::vmax(vec_elements, vec_res_value); + break; + } default: ARM_COMPUTE_ERROR("Not supported"); } @@ -638,6 +654,11 @@ struct RedOpX_qasymm8 *(output.ptr()) = static_cast(wrapper::vgetlane(calculate_min(vec_res_value), 0)); break; } + case ReductionOperation::MAX: + { + *(output.ptr()) = static_cast(wrapper::vgetlane(calculate_max(vec_res_value), 0)); + break; + } case ReductionOperation::PROD: { auto carry_res = wrapper::vmul(vec_res_value1_f, vec_res_value2_f); @@ -694,6 +715,7 @@ struct RedOpYZW case ReductionOperation::ARG_IDX_MAX: case ReductionOperation::ARG_IDX_MIN: case ReductionOperation::MIN: + case ReductionOperation::MAX: { vec_res_value = wrapper::vloadq(reinterpret_cast(input.ptr())); break; @@ -761,6 +783,11 @@ struct RedOpYZW vec_res_value = wrapper::vmin(vec_elements, vec_res_value); break; } + case ReductionOperation::MAX: + { + vec_res_value = wrapper::vmax(vec_elements, vec_res_value); + break; + } default: ARM_COMPUTE_ERROR("Not supported"); } @@ -950,6 +977,11 @@ struct RedOpYZW_qasymm8 vec_res_value = wrapper::vmin(vec_elements, vec_res_value); break; } + case ReductionOperation::MAX: + { + vec_res_value = wrapper::vmax(vec_elements, vec_res_value); + break; + } default: ARM_COMPUTE_ERROR("Not supported"); } -- cgit v1.2.1