From a4a08ad5e33867f9938a3fbaf9b6dcc56ad8f7b5 Mon Sep 17 00:00:00 2001 From: Usama Arif Date: Mon, 20 May 2019 12:38:33 +0100 Subject: COMPMID-2280: Implement REDUCE_MIN operator for NEON Change-Id: Iaa8d97e3328ce69dae7a97a7111120ecc61fb465 Signed-off-by: Usama Arif Reviewed-on: https://review.mlplatform.org/c/1192 Comments-Addressed: Arm Jenkins Reviewed-by: Pablo Marquez Tested-by: Arm Jenkins --- arm_compute/core/Types.h | 3 +- .../NEON/kernels/NEReductionOperationKernel.cpp | 239 +++++++++++++++------ .../NEON/functions/NEReductionOperation.cpp | 47 +++- tests/validation/NEON/ReductionOperation.cpp | 3 +- tests/validation/reference/ReductionOperation.cpp | 31 ++- utils/TypePrinter.h | 3 + 6 files changed, 255 insertions(+), 71 deletions(-) diff --git a/arm_compute/core/Types.h b/arm_compute/core/Types.h index 83ab2d755a..241c1fe1f4 100644 --- a/arm_compute/core/Types.h +++ b/arm_compute/core/Types.h @@ -559,7 +559,8 @@ enum class ReductionOperation MEAN_SUM, /**< Mean of sum */ PROD, /**< Product */ SUM_SQUARE, /**< Sum of squares */ - SUM /**< Sum */ + SUM, /**< Sum */ + MIN, /**< Min */ }; /** Available element-wise operations */ diff --git a/src/core/NEON/kernels/NEReductionOperationKernel.cpp b/src/core/NEON/kernels/NEReductionOperationKernel.cpp index 5f0a4dd371..b51d4b311f 100644 --- a/src/core/NEON/kernels/NEReductionOperationKernel.cpp +++ b/src/core/NEON/kernels/NEReductionOperationKernel.cpp @@ -108,6 +108,34 @@ uint32x4x4_t calculate_index(uint32_t idx, uint8x16_t a, uint8x16_t b, uint32x4x return res; } + +// Helper function to calculate the minimum value of the input vector. All the elements in the output vector contain the min value. +float32x2_t calculate_min(float32x4_t in) +{ + auto pmin = wrapper::vpmin(wrapper::vgethigh(in), wrapper::vgetlow(in)); + return wrapper::vpmin(pmin, pmin); +} + +// Helper function to calculate the maximum value of the input vector. All the elements in the output vector contain the max value. +float32x2_t calculate_max(float32x4_t in) +{ + auto pmax = wrapper::vpmax(wrapper::vgethigh(in), wrapper::vgetlow(in)); + return wrapper::vpmax(pmax, pmax); +} +// Helper function to calculate the minimum value of the input vector. All the elements in the output vector contain the min value. +int32x2_t calculate_min(int32x4_t in) +{ + auto pmin = wrapper::vpmin(wrapper::vgethigh(in), wrapper::vgetlow(in)); + return wrapper::vpmin(pmin, pmin); +} + +// Helper function to calculate the maximum value of the input vector. All the elements in the output vector contain the max value. +int32x2_t calculate_max(int32x4_t in) +{ + auto pmax = wrapper::vpmax(wrapper::vgethigh(in), wrapper::vgetlow(in)); + return wrapper::vpmax(pmax, pmax); +} + template uint32_t calculate_vector_index(uint32x4x4_t vec_res_idx, T vec_res_value, ReductionOperation op) { @@ -116,15 +144,13 @@ uint32_t calculate_vector_index(uint32x4x4_t vec_res_idx, T vec_res_value, Reduc if(op == ReductionOperation::ARG_IDX_MIN) { - auto pmin = wrapper::vpmin(wrapper::vgethigh(vec_res_value), wrapper::vgetlow(vec_res_value)); - pmin = wrapper::vpmin(pmin, pmin); + auto pmin = calculate_min(vec_res_value); auto mask = wrapper::vceq(vec_res_value, wrapper::vcombine(pmin, pmin)); res_idx_mask = wrapper::vand(vec_res_idx.val[0], mask); } else { - auto pmax = wrapper::vpmax(wrapper::vgethigh(vec_res_value), wrapper::vgetlow(vec_res_value)); - pmax = wrapper::vpmax(pmax, pmax); + auto pmax = calculate_max(vec_res_value); auto mask = wrapper::vceq(vec_res_value, wrapper::vcombine(pmax, pmax)); res_idx_mask = wrapper::vand(vec_res_idx.val[0], mask); } @@ -137,6 +163,23 @@ uint32_t calculate_vector_index(uint32x4x4_t vec_res_idx, T vec_res_value, Reduc return (res - 0xFFFFFFFF); } +// Helper function to calculate the minimum value of the input vector. All the elements in the output vector contain the min value. +inline uint8x8_t calculate_min(uint8x16_t in) +{ + auto pmin = wrapper::vpmin(wrapper::vgethigh(in), wrapper::vgetlow(in)); + pmin = wrapper::vpmin(pmin, pmin); + pmin = wrapper::vpmin(pmin, pmin); + return wrapper::vpmin(pmin, pmin); +} +// Helper function to calculate the maximum value of the input vector. All the elements in the output vector contain the max value. +inline uint8x8_t calculate_max(uint8x16_t in) +{ + auto pmax = wrapper::vpmax(wrapper::vgethigh(in), wrapper::vgetlow(in)); + pmax = wrapper::vpmax(pmax, pmax); + pmax = wrapper::vpmax(pmax, pmax); + return wrapper::vpmax(pmax, pmax); +} + uint32_t calculate_vector_index(uint32x4x4_t vec_res_idx, uint8x16_t vec_res_value, ReductionOperation op) { uint32x4x4_t res_idx_mask{ { 0 } }; @@ -144,18 +187,12 @@ uint32_t calculate_vector_index(uint32x4x4_t vec_res_idx, uint8x16_t vec_res_val uint8x16_t mask_u8{ 0 }; if(op == ReductionOperation::ARG_IDX_MIN) { - auto pmin = wrapper::vpmin(wrapper::vgethigh(vec_res_value), wrapper::vgetlow(vec_res_value)); - pmin = wrapper::vpmin(pmin, pmin); - pmin = wrapper::vpmin(pmin, pmin); - pmin = wrapper::vpmin(pmin, pmin); + auto pmin = calculate_min(vec_res_value); mask_u8 = wrapper::vceq(vec_res_value, wrapper::vcombine(pmin, pmin)); } else { - auto pmax = wrapper::vpmax(wrapper::vgethigh(vec_res_value), wrapper::vgetlow(vec_res_value)); - pmax = wrapper::vpmax(pmax, pmax); - pmax = wrapper::vpmax(pmax, pmax); - pmax = wrapper::vpmax(pmax, pmax); + auto pmax = calculate_max(vec_res_value); mask_u8 = wrapper::vceq(vec_res_value, wrapper::vcombine(pmax, pmax)); } @@ -220,6 +257,21 @@ uint32x4x4_t calculate_index(uint32_t idx, float16x8_t a, float16x8_t b, uint32x return res; } +// Helper function to calculate the minimum value of the input vector. All the elements in the output vector contain the min value. +inline float16x4_t calculate_min(float16x8_t in) +{ + auto pmin = wrapper::vpmin(wrapper::vgethigh(in), wrapper::vgetlow(in)); + pmin = wrapper::vpmin(pmin, pmin); + return wrapper::vpmin(pmin, pmin); +} +// Helper function to calculate the maximum value of the input vector. All the elements in the output vector contain the max value. +inline float16x4_t calculate_max(float16x8_t in) +{ + auto pmax = wrapper::vpmax(wrapper::vgethigh(in), wrapper::vgetlow(in)); + pmax = wrapper::vpmax(pmax, pmax); + return wrapper::vpmax(pmax, pmax); +} + uint32_t calculate_vector_index(uint32x4x4_t vec_res_idx, float16x8_t vec_res_value, ReductionOperation op) { uint32x4x2_t res_idx_mask{ 0 }; @@ -227,16 +279,12 @@ uint32_t calculate_vector_index(uint32x4x4_t vec_res_idx, float16x8_t vec_res_va uint16x8_t mask_u16; if(op == ReductionOperation::ARG_IDX_MIN) { - auto pmin = wrapper::vpmin(wrapper::vgethigh(vec_res_value), wrapper::vgetlow(vec_res_value)); - pmin = wrapper::vpmin(pmin, pmin); - pmin = wrapper::vpmin(pmin, pmin); + auto pmin = calculate_min(vec_res_value); mask_u16 = wrapper::vceq(vec_res_value, wrapper::vcombine(pmin, pmin)); } else { - auto pmax = wrapper::vpmax(wrapper::vgethigh(vec_res_value), wrapper::vgetlow(vec_res_value)); - pmax = wrapper::vpmax(pmax, pmax); - pmax = wrapper::vpmax(pmax, pmax); + auto pmax = calculate_max(vec_res_value); mask_u16 = wrapper::vceq(vec_res_value, wrapper::vcombine(pmax, pmax)); } @@ -364,13 +412,22 @@ struct RedOpX { ARM_COMPUTE_UNUSED(out_slice); auto init_res_value = static_cast(0.f); - if(op == ReductionOperation::ARG_IDX_MAX || op == ReductionOperation::ARG_IDX_MIN) - { - init_res_value = *reinterpret_cast(input.ptr()); - } - else if(op == ReductionOperation::PROD) + switch(op) { - init_res_value = static_cast(1.f); + case ReductionOperation::ARG_IDX_MAX: + case ReductionOperation::ARG_IDX_MIN: + case ReductionOperation::MIN: + { + init_res_value = *reinterpret_cast(input.ptr()); + break; + } + case ReductionOperation::PROD: + { + init_res_value = static_cast(1.f); + break; + } + default: + break; } auto vec_res_value = wrapper::vdup_n(init_res_value, ExactTagType{}); uint32x4x4_t vec_res_idx{ { 0 } }; @@ -406,6 +463,11 @@ struct RedOpX vec_res_value = temp_vec_res_value; break; } + case ReductionOperation::MIN: + { + vec_res_value = wrapper::vmin(vec_elements, vec_res_value); + break; + } default: ARM_COMPUTE_ERROR("Not supported"); } @@ -451,6 +513,11 @@ struct RedOpX *(reinterpret_cast(output.ptr())) = res; break; } + case ReductionOperation::MIN: + { + *(reinterpret_cast(output.ptr())) = wrapper::vgetlane(calculate_min(vec_res_value), 0); + break; + } default: ARM_COMPUTE_ERROR("Not supported"); } @@ -474,7 +541,7 @@ struct RedOpX_qasymm8 uint8x16_t vec_res_value = { 0 }; - if(op == ReductionOperation::ARG_IDX_MAX || op == ReductionOperation::ARG_IDX_MIN) + if(op == ReductionOperation::ARG_IDX_MAX || op == ReductionOperation::ARG_IDX_MIN || op == ReductionOperation::MIN) { vec_res_value = wrapper::vdup_n(*input.ptr(), wrapper::traits::vector_128_tag{}); } @@ -546,48 +613,64 @@ struct RedOpX_qasymm8 vec_res_value = temp_vec_res_value; break; } + case ReductionOperation::MIN: + { + vec_res_value = wrapper::vmin(vec_elements, vec_res_value); + break; + } default: ARM_COMPUTE_ERROR("Not supported"); } }, input); - if(op == ReductionOperation::ARG_IDX_MIN || op == ReductionOperation::ARG_IDX_MAX) - { - auto res = calculate_vector_index(vec_res_idx, vec_res_value, op); - *(reinterpret_cast(output.ptr())) = res; - } - else if(op == ReductionOperation::PROD) - { - auto carry_res = wrapper::vmul(vec_res_value1_f, vec_res_value2_f); - carry_res = wrapper::vmul(carry_res, vec_res_value3_f); - carry_res = wrapper::vmul(carry_res, vec_res_value4_f); - - float res = wrapper::vgetlane(carry_res, 0); - res *= wrapper::vgetlane(carry_res, 1); - res *= wrapper::vgetlane(carry_res, 2); - res *= wrapper::vgetlane(carry_res, 3); - - //re-quantize result - res = sqcvt_qasymm8_f32(res, in_info.quantization_info().scale, in_info.quantization_info().offset); - *(output.ptr()) = static_cast(res); - } - else + switch(op) { - auto carry_res = wrapper::vadd(vec_res_value1, vec_res_value2); - carry_res = wrapper::vadd(carry_res, vec_res_value3); - carry_res = wrapper::vadd(carry_res, vec_res_value4); - - auto carry_paddition = wrapper::vpadd(wrapper::vgethigh(carry_res), wrapper::vgetlow(carry_res)); - carry_paddition = wrapper::vpadd(carry_paddition, carry_paddition); - auto res = wrapper::vgetlane(carry_paddition, 0); - - if(op == ReductionOperation::MEAN_SUM) + case ReductionOperation::ARG_IDX_MIN: + case ReductionOperation::ARG_IDX_MAX: { - res /= in_info.dimension(0); + auto res = calculate_vector_index(vec_res_idx, vec_res_value, op); + *(reinterpret_cast(output.ptr())) = res; + break; } + case ReductionOperation::MIN: + { + *(output.ptr()) = static_cast(wrapper::vgetlane(calculate_min(vec_res_value), 0)); + break; + } + case ReductionOperation::PROD: + { + auto carry_res = wrapper::vmul(vec_res_value1_f, vec_res_value2_f); + carry_res = wrapper::vmul(carry_res, vec_res_value3_f); + carry_res = wrapper::vmul(carry_res, vec_res_value4_f); + + float res = wrapper::vgetlane(carry_res, 0); + res *= wrapper::vgetlane(carry_res, 1); + res *= wrapper::vgetlane(carry_res, 2); + res *= wrapper::vgetlane(carry_res, 3); + + //re-quantize result + res = sqcvt_qasymm8_f32(res, in_info.quantization_info().scale, in_info.quantization_info().offset); + *(output.ptr()) = static_cast(res); + break; + } + default: + { + auto carry_res = wrapper::vadd(vec_res_value1, vec_res_value2); + carry_res = wrapper::vadd(carry_res, vec_res_value3); + carry_res = wrapper::vadd(carry_res, vec_res_value4); - *(output.ptr()) = static_cast(res); + auto carry_paddition = wrapper::vpadd(wrapper::vgethigh(carry_res), wrapper::vgetlow(carry_res)); + carry_paddition = wrapper::vpadd(carry_paddition, carry_paddition); + auto res = wrapper::vgetlane(carry_paddition, 0); + + if(op == ReductionOperation::MEAN_SUM) + { + res /= in_info.dimension(0); + } + + *(output.ptr()) = static_cast(res); + } } } }; @@ -606,17 +689,25 @@ struct RedOpYZW execute_window_loop(in_slice, [&](const Coordinates &) { neon_vector vec_res_value = { 0 }; - if(op == ReductionOperation::ARG_IDX_MAX || op == ReductionOperation::ARG_IDX_MIN) - { - vec_res_value = wrapper::vloadq(reinterpret_cast(input.ptr())); - } - else if(op == ReductionOperation::PROD) - { - vec_res_value = wrapper::vdup_n(static_cast(1.f), ExactTagType{}); - } - else + switch(op) { - vec_res_value = wrapper::vdup_n(static_cast(0.f), ExactTagType{}); + case ReductionOperation::ARG_IDX_MAX: + case ReductionOperation::ARG_IDX_MIN: + case ReductionOperation::MIN: + { + vec_res_value = wrapper::vloadq(reinterpret_cast(input.ptr())); + break; + } + case ReductionOperation::PROD: + { + vec_res_value = wrapper::vdup_n(static_cast(1.f), ExactTagType{}); + break; + } + default: + { + vec_res_value = wrapper::vdup_n(static_cast(0.f), ExactTagType{}); + break; + } } uint32x4x4_t vec_res_idx{ { 0 } }; @@ -665,6 +756,11 @@ struct RedOpYZW vec_res_value = temp_vec_res_value; break; } + case ReductionOperation::MIN: + { + vec_res_value = wrapper::vmin(vec_elements, vec_res_value); + break; + } default: ARM_COMPUTE_ERROR("Not supported"); } @@ -849,6 +945,11 @@ struct RedOpYZW_qasymm8 vec_res_value = temp_vec_res_value; break; } + case ReductionOperation::MIN: + { + vec_res_value = wrapper::vmin(vec_elements, vec_res_value); + break; + } default: ARM_COMPUTE_ERROR("Not supported"); } @@ -891,6 +992,10 @@ struct RedOpYZW_qasymm8 wrapper::vstore(reinterpret_cast(output.ptr()) + 8, vec_res_idx.val[2]); wrapper::vstore(reinterpret_cast(output.ptr()) + 12, vec_res_idx.val[3]); } + else if(op == ReductionOperation::ARG_IDX_MIN) + { + wrapper::vstore(output.ptr(), vec_res_value); + } else { const auto temp16x8t_1 = vcombine_u16(wrapper::vqmovn(vec_res_value1), wrapper::vqmovn(vec_res_value2)); diff --git a/src/runtime/NEON/functions/NEReductionOperation.cpp b/src/runtime/NEON/functions/NEReductionOperation.cpp index a0aed96521..81bb32f5dc 100644 --- a/src/runtime/NEON/functions/NEReductionOperation.cpp +++ b/src/runtime/NEON/functions/NEReductionOperation.cpp @@ -78,7 +78,52 @@ void NEReductionOperation::configure(ITensor *input, ITensor *output, unsigned i { // Configure fill border kernel const BorderSize fill_border_size = _reduction_kernel.border_size(); - const PixelValue pixelValue = (op == ReductionOperation::PROD) ? PixelValue(1, input->info()->data_type(), input->info()->quantization_info()) : PixelValue(0, input->info()->data_type()); + PixelValue pixelValue; + switch(op) + { + case ReductionOperation::PROD: + { + pixelValue = PixelValue(1, input->info()->data_type(), input->info()->quantization_info()); + break; + } + case ReductionOperation::MIN: + { + switch(input->info()->data_type()) + { + case DataType::F32: + { + pixelValue = PixelValue(std::numeric_limits::max()); + break; + } + case DataType::F16: + { + pixelValue = PixelValue(static_cast(65504.0f)); + break; + } + case DataType::QASYMM8: + { + pixelValue = PixelValue(255, input->info()->data_type(), input->info()->quantization_info()); + break; + } + default: + { + ARM_COMPUTE_ERROR("Unsupported DataType"); + } + } + break; + } + case ReductionOperation::ARG_IDX_MAX: + case ReductionOperation::ARG_IDX_MIN: + case ReductionOperation::MEAN_SUM: + case ReductionOperation::SUM_SQUARE: + case ReductionOperation::SUM: + { + pixelValue = PixelValue(0, input->info()->data_type()); + break; + } + default: + ARM_COMPUTE_ERROR("Reduction Operation unsupported"); + } _fill_border_kernel.configure(input, fill_border_size, BorderMode::CONSTANT, pixelValue); } } diff --git a/tests/validation/NEON/ReductionOperation.cpp b/tests/validation/NEON/ReductionOperation.cpp index b9b4983ae6..074689d678 100644 --- a/tests/validation/NEON/ReductionOperation.cpp +++ b/tests/validation/NEON/ReductionOperation.cpp @@ -51,7 +51,8 @@ RelativeTolerance tolerance_qasymm8(1); const auto ReductionOperations = framework::dataset::make("ReductionOperation", { ReductionOperation::SUM, - ReductionOperation::PROD + ReductionOperation::PROD, + ReductionOperation::MIN, }); const auto QuantizationInfos = framework::dataset::make("QuantizationInfo", diff --git a/tests/validation/reference/ReductionOperation.cpp b/tests/validation/reference/ReductionOperation.cpp index c7624a4628..1f825f0e0f 100644 --- a/tests/validation/reference/ReductionOperation.cpp +++ b/tests/validation/reference/ReductionOperation.cpp @@ -42,7 +42,24 @@ template OT reduce_operation(const T *ptr, int reduce_elements, ReductionOperation op, int stride) { using type = typename std::remove_cv::type; - auto res = (op == ReductionOperation::PROD) ? type(1) : type(0); + T res; + switch(op) + { + case ReductionOperation::PROD: + { + res = type(1); + } + break; + case ReductionOperation::MIN: + { + res = *ptr; + } + break; + default: + { + res = type(0); + } + } if(std::is_integral::value) { @@ -65,6 +82,12 @@ OT reduce_operation(const T *ptr, int reduce_elements, ReductionOperation op, in int_res = static_cast(i); } break; + case ReductionOperation::MIN: + if(static_cast(int_res) > elem) + { + int_res = elem; + } + break; case ReductionOperation::SUM_SQUARE: int_res += elem * elem; break; @@ -104,6 +127,12 @@ OT reduce_operation(const T *ptr, int reduce_elements, ReductionOperation op, in res = static_cast(i); } break; + case ReductionOperation::MIN: + if(res > elem) + { + res = elem; + } + break; case ReductionOperation::SUM_SQUARE: res += elem * elem; break; diff --git a/utils/TypePrinter.h b/utils/TypePrinter.h index b8927615ee..9b8efe5a23 100644 --- a/utils/TypePrinter.h +++ b/utils/TypePrinter.h @@ -1446,6 +1446,9 @@ inline ::std::ostream &operator<<(::std::ostream &os, const ReductionOperation & case ReductionOperation::PROD: os << "PROD"; break; + case ReductionOperation::MIN: + os << "MIN"; + break; default: ARM_COMPUTE_ERROR("NOT_SUPPORTED!"); } -- cgit v1.2.1