From a4a08ad5e33867f9938a3fbaf9b6dcc56ad8f7b5 Mon Sep 17 00:00:00 2001 From: Usama Arif Date: Mon, 20 May 2019 12:38:33 +0100 Subject: COMPMID-2280: Implement REDUCE_MIN operator for NEON Change-Id: Iaa8d97e3328ce69dae7a97a7111120ecc61fb465 Signed-off-by: Usama Arif Reviewed-on: https://review.mlplatform.org/c/1192 Comments-Addressed: Arm Jenkins Reviewed-by: Pablo Marquez Tested-by: Arm Jenkins --- .../NEON/kernels/NEReductionOperationKernel.cpp | 239 +++++++++++++++------ 1 file changed, 172 insertions(+), 67 deletions(-) (limited to 'src/core/NEON/kernels/NEReductionOperationKernel.cpp') diff --git a/src/core/NEON/kernels/NEReductionOperationKernel.cpp b/src/core/NEON/kernels/NEReductionOperationKernel.cpp index 5f0a4dd371..b51d4b311f 100644 --- a/src/core/NEON/kernels/NEReductionOperationKernel.cpp +++ b/src/core/NEON/kernels/NEReductionOperationKernel.cpp @@ -108,6 +108,34 @@ uint32x4x4_t calculate_index(uint32_t idx, uint8x16_t a, uint8x16_t b, uint32x4x return res; } + +// Helper function to calculate the minimum value of the input vector. All the elements in the output vector contain the min value. +float32x2_t calculate_min(float32x4_t in) +{ + auto pmin = wrapper::vpmin(wrapper::vgethigh(in), wrapper::vgetlow(in)); + return wrapper::vpmin(pmin, pmin); +} + +// Helper function to calculate the maximum value of the input vector. All the elements in the output vector contain the max value. +float32x2_t calculate_max(float32x4_t in) +{ + auto pmax = wrapper::vpmax(wrapper::vgethigh(in), wrapper::vgetlow(in)); + return wrapper::vpmax(pmax, pmax); +} +// Helper function to calculate the minimum value of the input vector. All the elements in the output vector contain the min value. +int32x2_t calculate_min(int32x4_t in) +{ + auto pmin = wrapper::vpmin(wrapper::vgethigh(in), wrapper::vgetlow(in)); + return wrapper::vpmin(pmin, pmin); +} + +// Helper function to calculate the maximum value of the input vector. 
All the elements in the output vector contain the max value. +int32x2_t calculate_max(int32x4_t in) +{ + auto pmax = wrapper::vpmax(wrapper::vgethigh(in), wrapper::vgetlow(in)); + return wrapper::vpmax(pmax, pmax); +} + template uint32_t calculate_vector_index(uint32x4x4_t vec_res_idx, T vec_res_value, ReductionOperation op) { @@ -116,15 +144,13 @@ uint32_t calculate_vector_index(uint32x4x4_t vec_res_idx, T vec_res_value, Reduc if(op == ReductionOperation::ARG_IDX_MIN) { - auto pmin = wrapper::vpmin(wrapper::vgethigh(vec_res_value), wrapper::vgetlow(vec_res_value)); - pmin = wrapper::vpmin(pmin, pmin); + auto pmin = calculate_min(vec_res_value); auto mask = wrapper::vceq(vec_res_value, wrapper::vcombine(pmin, pmin)); res_idx_mask = wrapper::vand(vec_res_idx.val[0], mask); } else { - auto pmax = wrapper::vpmax(wrapper::vgethigh(vec_res_value), wrapper::vgetlow(vec_res_value)); - pmax = wrapper::vpmax(pmax, pmax); + auto pmax = calculate_max(vec_res_value); auto mask = wrapper::vceq(vec_res_value, wrapper::vcombine(pmax, pmax)); res_idx_mask = wrapper::vand(vec_res_idx.val[0], mask); } @@ -137,6 +163,23 @@ uint32_t calculate_vector_index(uint32x4x4_t vec_res_idx, T vec_res_value, Reduc return (res - 0xFFFFFFFF); } +// Helper function to calculate the minimum value of the input vector. All the elements in the output vector contain the min value. +inline uint8x8_t calculate_min(uint8x16_t in) +{ + auto pmin = wrapper::vpmin(wrapper::vgethigh(in), wrapper::vgetlow(in)); + pmin = wrapper::vpmin(pmin, pmin); + pmin = wrapper::vpmin(pmin, pmin); + return wrapper::vpmin(pmin, pmin); +} +// Helper function to calculate the maximum value of the input vector. All the elements in the output vector contain the max value. 
+inline uint8x8_t calculate_max(uint8x16_t in) +{ + auto pmax = wrapper::vpmax(wrapper::vgethigh(in), wrapper::vgetlow(in)); + pmax = wrapper::vpmax(pmax, pmax); + pmax = wrapper::vpmax(pmax, pmax); + return wrapper::vpmax(pmax, pmax); +} + uint32_t calculate_vector_index(uint32x4x4_t vec_res_idx, uint8x16_t vec_res_value, ReductionOperation op) { uint32x4x4_t res_idx_mask{ { 0 } }; @@ -144,18 +187,12 @@ uint32_t calculate_vector_index(uint32x4x4_t vec_res_idx, uint8x16_t vec_res_val uint8x16_t mask_u8{ 0 }; if(op == ReductionOperation::ARG_IDX_MIN) { - auto pmin = wrapper::vpmin(wrapper::vgethigh(vec_res_value), wrapper::vgetlow(vec_res_value)); - pmin = wrapper::vpmin(pmin, pmin); - pmin = wrapper::vpmin(pmin, pmin); - pmin = wrapper::vpmin(pmin, pmin); + auto pmin = calculate_min(vec_res_value); mask_u8 = wrapper::vceq(vec_res_value, wrapper::vcombine(pmin, pmin)); } else { - auto pmax = wrapper::vpmax(wrapper::vgethigh(vec_res_value), wrapper::vgetlow(vec_res_value)); - pmax = wrapper::vpmax(pmax, pmax); - pmax = wrapper::vpmax(pmax, pmax); - pmax = wrapper::vpmax(pmax, pmax); + auto pmax = calculate_max(vec_res_value); mask_u8 = wrapper::vceq(vec_res_value, wrapper::vcombine(pmax, pmax)); } @@ -220,6 +257,21 @@ uint32x4x4_t calculate_index(uint32_t idx, float16x8_t a, float16x8_t b, uint32x return res; } +// Helper function to calculate the minimum value of the input vector. All the elements in the output vector contain the min value. +inline float16x4_t calculate_min(float16x8_t in) +{ + auto pmin = wrapper::vpmin(wrapper::vgethigh(in), wrapper::vgetlow(in)); + pmin = wrapper::vpmin(pmin, pmin); + return wrapper::vpmin(pmin, pmin); +} +// Helper function to calculate the maximum value of the input vector. All the elements in the output vector contain the max value. 
+inline float16x4_t calculate_max(float16x8_t in) +{ + auto pmax = wrapper::vpmax(wrapper::vgethigh(in), wrapper::vgetlow(in)); + pmax = wrapper::vpmax(pmax, pmax); + return wrapper::vpmax(pmax, pmax); +} + uint32_t calculate_vector_index(uint32x4x4_t vec_res_idx, float16x8_t vec_res_value, ReductionOperation op) { uint32x4x2_t res_idx_mask{ 0 }; @@ -227,16 +279,12 @@ uint32_t calculate_vector_index(uint32x4x4_t vec_res_idx, float16x8_t vec_res_va uint16x8_t mask_u16; if(op == ReductionOperation::ARG_IDX_MIN) { - auto pmin = wrapper::vpmin(wrapper::vgethigh(vec_res_value), wrapper::vgetlow(vec_res_value)); - pmin = wrapper::vpmin(pmin, pmin); - pmin = wrapper::vpmin(pmin, pmin); + auto pmin = calculate_min(vec_res_value); mask_u16 = wrapper::vceq(vec_res_value, wrapper::vcombine(pmin, pmin)); } else { - auto pmax = wrapper::vpmax(wrapper::vgethigh(vec_res_value), wrapper::vgetlow(vec_res_value)); - pmax = wrapper::vpmax(pmax, pmax); - pmax = wrapper::vpmax(pmax, pmax); + auto pmax = calculate_max(vec_res_value); mask_u16 = wrapper::vceq(vec_res_value, wrapper::vcombine(pmax, pmax)); } @@ -364,13 +412,22 @@ struct RedOpX { ARM_COMPUTE_UNUSED(out_slice); auto init_res_value = static_cast(0.f); - if(op == ReductionOperation::ARG_IDX_MAX || op == ReductionOperation::ARG_IDX_MIN) - { - init_res_value = *reinterpret_cast(input.ptr()); - } - else if(op == ReductionOperation::PROD) + switch(op) { - init_res_value = static_cast(1.f); + case ReductionOperation::ARG_IDX_MAX: + case ReductionOperation::ARG_IDX_MIN: + case ReductionOperation::MIN: + { + init_res_value = *reinterpret_cast(input.ptr()); + break; + } + case ReductionOperation::PROD: + { + init_res_value = static_cast(1.f); + break; + } + default: + break; } auto vec_res_value = wrapper::vdup_n(init_res_value, ExactTagType{}); uint32x4x4_t vec_res_idx{ { 0 } }; @@ -406,6 +463,11 @@ struct RedOpX vec_res_value = temp_vec_res_value; break; } + case ReductionOperation::MIN: + { + vec_res_value = 
wrapper::vmin(vec_elements, vec_res_value); + break; + } default: ARM_COMPUTE_ERROR("Not supported"); } @@ -451,6 +513,11 @@ struct RedOpX *(reinterpret_cast(output.ptr())) = res; break; } + case ReductionOperation::MIN: + { + *(reinterpret_cast(output.ptr())) = wrapper::vgetlane(calculate_min(vec_res_value), 0); + break; + } default: ARM_COMPUTE_ERROR("Not supported"); } @@ -474,7 +541,7 @@ struct RedOpX_qasymm8 uint8x16_t vec_res_value = { 0 }; - if(op == ReductionOperation::ARG_IDX_MAX || op == ReductionOperation::ARG_IDX_MIN) + if(op == ReductionOperation::ARG_IDX_MAX || op == ReductionOperation::ARG_IDX_MIN || op == ReductionOperation::MIN) { vec_res_value = wrapper::vdup_n(*input.ptr(), wrapper::traits::vector_128_tag{}); } @@ -546,48 +613,64 @@ struct RedOpX_qasymm8 vec_res_value = temp_vec_res_value; break; } + case ReductionOperation::MIN: + { + vec_res_value = wrapper::vmin(vec_elements, vec_res_value); + break; + } default: ARM_COMPUTE_ERROR("Not supported"); } }, input); - if(op == ReductionOperation::ARG_IDX_MIN || op == ReductionOperation::ARG_IDX_MAX) - { - auto res = calculate_vector_index(vec_res_idx, vec_res_value, op); - *(reinterpret_cast(output.ptr())) = res; - } - else if(op == ReductionOperation::PROD) - { - auto carry_res = wrapper::vmul(vec_res_value1_f, vec_res_value2_f); - carry_res = wrapper::vmul(carry_res, vec_res_value3_f); - carry_res = wrapper::vmul(carry_res, vec_res_value4_f); - - float res = wrapper::vgetlane(carry_res, 0); - res *= wrapper::vgetlane(carry_res, 1); - res *= wrapper::vgetlane(carry_res, 2); - res *= wrapper::vgetlane(carry_res, 3); - - //re-quantize result - res = sqcvt_qasymm8_f32(res, in_info.quantization_info().scale, in_info.quantization_info().offset); - *(output.ptr()) = static_cast(res); - } - else + switch(op) { - auto carry_res = wrapper::vadd(vec_res_value1, vec_res_value2); - carry_res = wrapper::vadd(carry_res, vec_res_value3); - carry_res = wrapper::vadd(carry_res, vec_res_value4); - - auto 
carry_paddition = wrapper::vpadd(wrapper::vgethigh(carry_res), wrapper::vgetlow(carry_res)); - carry_paddition = wrapper::vpadd(carry_paddition, carry_paddition); - auto res = wrapper::vgetlane(carry_paddition, 0); - - if(op == ReductionOperation::MEAN_SUM) + case ReductionOperation::ARG_IDX_MIN: + case ReductionOperation::ARG_IDX_MAX: { - res /= in_info.dimension(0); + auto res = calculate_vector_index(vec_res_idx, vec_res_value, op); + *(reinterpret_cast(output.ptr())) = res; + break; } + case ReductionOperation::MIN: + { + *(output.ptr()) = static_cast(wrapper::vgetlane(calculate_min(vec_res_value), 0)); + break; + } + case ReductionOperation::PROD: + { + auto carry_res = wrapper::vmul(vec_res_value1_f, vec_res_value2_f); + carry_res = wrapper::vmul(carry_res, vec_res_value3_f); + carry_res = wrapper::vmul(carry_res, vec_res_value4_f); + + float res = wrapper::vgetlane(carry_res, 0); + res *= wrapper::vgetlane(carry_res, 1); + res *= wrapper::vgetlane(carry_res, 2); + res *= wrapper::vgetlane(carry_res, 3); + + //re-quantize result + res = sqcvt_qasymm8_f32(res, in_info.quantization_info().scale, in_info.quantization_info().offset); + *(output.ptr()) = static_cast(res); + break; + } + default: + { + auto carry_res = wrapper::vadd(vec_res_value1, vec_res_value2); + carry_res = wrapper::vadd(carry_res, vec_res_value3); + carry_res = wrapper::vadd(carry_res, vec_res_value4); - *(output.ptr()) = static_cast(res); + auto carry_paddition = wrapper::vpadd(wrapper::vgethigh(carry_res), wrapper::vgetlow(carry_res)); + carry_paddition = wrapper::vpadd(carry_paddition, carry_paddition); + auto res = wrapper::vgetlane(carry_paddition, 0); + + if(op == ReductionOperation::MEAN_SUM) + { + res /= in_info.dimension(0); + } + + *(output.ptr()) = static_cast(res); + } } } }; @@ -606,17 +689,25 @@ struct RedOpYZW execute_window_loop(in_slice, [&](const Coordinates &) { neon_vector vec_res_value = { 0 }; - if(op == ReductionOperation::ARG_IDX_MAX || op == 
ReductionOperation::ARG_IDX_MIN) - { - vec_res_value = wrapper::vloadq(reinterpret_cast(input.ptr())); - } - else if(op == ReductionOperation::PROD) - { - vec_res_value = wrapper::vdup_n(static_cast(1.f), ExactTagType{}); - } - else + switch(op) { - vec_res_value = wrapper::vdup_n(static_cast(0.f), ExactTagType{}); + case ReductionOperation::ARG_IDX_MAX: + case ReductionOperation::ARG_IDX_MIN: + case ReductionOperation::MIN: + { + vec_res_value = wrapper::vloadq(reinterpret_cast(input.ptr())); + break; + } + case ReductionOperation::PROD: + { + vec_res_value = wrapper::vdup_n(static_cast(1.f), ExactTagType{}); + break; + } + default: + { + vec_res_value = wrapper::vdup_n(static_cast(0.f), ExactTagType{}); + break; + } } uint32x4x4_t vec_res_idx{ { 0 } }; @@ -665,6 +756,11 @@ struct RedOpYZW vec_res_value = temp_vec_res_value; break; } + case ReductionOperation::MIN: + { + vec_res_value = wrapper::vmin(vec_elements, vec_res_value); + break; + } default: ARM_COMPUTE_ERROR("Not supported"); } @@ -849,6 +945,11 @@ struct RedOpYZW_qasymm8 vec_res_value = temp_vec_res_value; break; } + case ReductionOperation::MIN: + { + vec_res_value = wrapper::vmin(vec_elements, vec_res_value); + break; + } default: ARM_COMPUTE_ERROR("Not supported"); } @@ -891,6 +992,10 @@ struct RedOpYZW_qasymm8 wrapper::vstore(reinterpret_cast(output.ptr()) + 8, vec_res_idx.val[2]); wrapper::vstore(reinterpret_cast(output.ptr()) + 12, vec_res_idx.val[3]); } + else if(op == ReductionOperation::MIN) /* plain MIN: store vec_res_value, the element-wise minima accumulated by the MIN case of the reduction loop */ + { + wrapper::vstore(output.ptr(), vec_res_value); + } else { const auto temp16x8t_1 = vcombine_u16(wrapper::vqmovn(vec_res_value1), wrapper::vqmovn(vec_res_value2)); -- cgit v1.2.1