diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/core/NEON/kernels/NEReductionOperationKernel.cpp | 34 | ||||
-rw-r--r-- | src/runtime/NEON/functions/NEReductionOperation.cpp | 26 |
2 files changed, 59 insertions, 1 deletions
diff --git a/src/core/NEON/kernels/NEReductionOperationKernel.cpp b/src/core/NEON/kernels/NEReductionOperationKernel.cpp index b51d4b311f..e6edf22083 100644 --- a/src/core/NEON/kernels/NEReductionOperationKernel.cpp +++ b/src/core/NEON/kernels/NEReductionOperationKernel.cpp @@ -417,6 +417,7 @@ struct RedOpX case ReductionOperation::ARG_IDX_MAX: case ReductionOperation::ARG_IDX_MIN: case ReductionOperation::MIN: + case ReductionOperation::MAX: { init_res_value = *reinterpret_cast<T *>(input.ptr()); break; @@ -468,6 +469,11 @@ struct RedOpX vec_res_value = wrapper::vmin(vec_elements, vec_res_value); break; } + case ReductionOperation::MAX: + { + vec_res_value = wrapper::vmax(vec_elements, vec_res_value); + break; + } default: ARM_COMPUTE_ERROR("Not supported"); } @@ -518,6 +524,11 @@ struct RedOpX *(reinterpret_cast<T *>(output.ptr())) = wrapper::vgetlane(calculate_min(vec_res_value), 0); break; } + case ReductionOperation::MAX: + { + *(reinterpret_cast<T *>(output.ptr())) = wrapper::vgetlane(calculate_max(vec_res_value), 0); + break; + } default: ARM_COMPUTE_ERROR("Not supported"); } @@ -541,7 +552,7 @@ struct RedOpX_qasymm8 uint8x16_t vec_res_value = { 0 }; - if(op == ReductionOperation::ARG_IDX_MAX || op == ReductionOperation::ARG_IDX_MIN || op == ReductionOperation::MIN) + if(op == ReductionOperation::ARG_IDX_MAX || op == ReductionOperation::ARG_IDX_MIN || op == ReductionOperation::MIN || op == ReductionOperation::MAX) { vec_res_value = wrapper::vdup_n(*input.ptr(), wrapper::traits::vector_128_tag{}); } @@ -618,6 +629,11 @@ struct RedOpX_qasymm8 vec_res_value = wrapper::vmin(vec_elements, vec_res_value); break; } + case ReductionOperation::MAX: + { + vec_res_value = wrapper::vmax(vec_elements, vec_res_value); + break; + } default: ARM_COMPUTE_ERROR("Not supported"); } @@ -638,6 +654,11 @@ struct RedOpX_qasymm8 *(output.ptr()) = static_cast<uint8_t>(wrapper::vgetlane(calculate_min(vec_res_value), 0)); break; } + case ReductionOperation::MAX: + { + *(output.ptr()) = static_cast<uint8_t>(wrapper::vgetlane(calculate_max(vec_res_value), 0)); + break; + } case ReductionOperation::PROD: { auto carry_res = wrapper::vmul(vec_res_value1_f, vec_res_value2_f); @@ -694,6 +715,7 @@ struct RedOpYZW case ReductionOperation::ARG_IDX_MAX: case ReductionOperation::ARG_IDX_MIN: case ReductionOperation::MIN: + case ReductionOperation::MAX: { vec_res_value = wrapper::vloadq(reinterpret_cast<T *>(input.ptr())); break; @@ -761,6 +783,11 @@ struct RedOpYZW vec_res_value = wrapper::vmin(vec_elements, vec_res_value); break; } + case ReductionOperation::MAX: + { + vec_res_value = wrapper::vmax(vec_elements, vec_res_value); + break; + } default: ARM_COMPUTE_ERROR("Not supported"); } @@ -950,6 +977,11 @@ struct RedOpYZW_qasymm8 vec_res_value = wrapper::vmin(vec_elements, vec_res_value); break; } + case ReductionOperation::MAX: + { + vec_res_value = wrapper::vmax(vec_elements, vec_res_value); + break; + } default: ARM_COMPUTE_ERROR("Not supported"); } diff --git a/src/runtime/NEON/functions/NEReductionOperation.cpp b/src/runtime/NEON/functions/NEReductionOperation.cpp index 81bb32f5dc..dc6cf59019 100644 --- a/src/runtime/NEON/functions/NEReductionOperation.cpp +++ b/src/runtime/NEON/functions/NEReductionOperation.cpp @@ -112,6 +112,32 @@ void NEReductionOperation::configure(ITensor *input, ITensor *output, unsigned i } break; } + case ReductionOperation::MAX: + { + switch(input->info()->data_type()) + { + case DataType::F32: + { + pixelValue = PixelValue(-std::numeric_limits<float>::max()); + break; + } + case DataType::F16: + { + pixelValue = PixelValue(static_cast<half>(-65504.0f)); + break; + } + case DataType::QASYMM8: + { + pixelValue = PixelValue(0, input->info()->data_type(), input->info()->quantization_info()); + break; + } + default: + { + ARM_COMPUTE_ERROR("Unsupported DataType"); + } + } + break; + } case ReductionOperation::ARG_IDX_MAX: case ReductionOperation::ARG_IDX_MIN: case ReductionOperation::MEAN_SUM: |