diff options
author | Michalis Spyrou <michalis.spyrou@arm.com> | 2019-05-13 17:41:01 +0100 |
---|---|---|
committer | Michalis Spyrou <michalis.spyrou@arm.com> | 2019-05-20 13:59:51 +0000 |
commit | b9626ab169a168a7c1ca57edd1996e1e80938bf1 (patch) | |
tree | 57ce41fff5e2ece1e7d8f2a6f332c67e4534e752 /src/core/NEON/kernels | |
parent | 0af4418f4d4b6bceaea64fa21eaf127b1b8fed35 (diff) | |
download | ComputeLibrary-b9626ab169a168a7c1ca57edd1996e1e80938bf1.tar.gz |
COMPMID-2243 ArgMinMaxLayer: support new datatypes
Change-Id: I846e833e0c94090cbbdcd6aee6061cea8295f4f9
Signed-off-by: Michalis Spyrou <michalis.spyrou@arm.com>
Reviewed-on: https://review.mlplatform.org/c/1131
Reviewed-by: Giuseppe Rossini <giuseppe.rossini@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'src/core/NEON/kernels')
-rw-r--r-- | src/core/NEON/kernels/NEReductionOperationKernel.cpp | 25 |
1 file changed, 17 insertions, 8 deletions
diff --git a/src/core/NEON/kernels/NEReductionOperationKernel.cpp b/src/core/NEON/kernels/NEReductionOperationKernel.cpp index aa20d1f40d..5f0a4dd371 100644 --- a/src/core/NEON/kernels/NEReductionOperationKernel.cpp +++ b/src/core/NEON/kernels/NEReductionOperationKernel.cpp @@ -41,7 +41,8 @@ namespace arm_compute { namespace { -uint32x4x4_t calculate_index(uint32_t idx, float32x4_t a, float32x4_t b, uint32x4x4_t c, ReductionOperation op, int axis) +template <typename T> +uint32x4x4_t calculate_index(uint32_t idx, T a, T b, uint32x4x4_t c, ReductionOperation op, int axis) { uint32x4_t mask{ 0 }; if(op == ReductionOperation::ARG_IDX_MIN) @@ -107,8 +108,8 @@ uint32x4x4_t calculate_index(uint32_t idx, uint8x16_t a, uint8x16_t b, uint32x4x return res; } - -uint32_t calculate_vector_index(uint32x4x4_t vec_res_idx, float32x4_t vec_res_value, ReductionOperation op) +template <typename T> +uint32_t calculate_vector_index(uint32x4x4_t vec_res_idx, T vec_res_value, ReductionOperation op) { uint32x4_t res_idx_mask{ 0 }; uint32x4_t mask_ones = vdupq_n_u32(0xFFFFFFFF); @@ -124,7 +125,7 @@ uint32_t calculate_vector_index(uint32x4x4_t vec_res_idx, float32x4_t vec_res_va { auto pmax = wrapper::vpmax(wrapper::vgethigh(vec_res_value), wrapper::vgetlow(vec_res_value)); pmax = wrapper::vpmax(pmax, pmax); - auto mask = vceqq_f32(vec_res_value, wrapper::vcombine(pmax, pmax)); + auto mask = wrapper::vceq(vec_res_value, wrapper::vcombine(pmax, pmax)); res_idx_mask = wrapper::vand(vec_res_idx.val[0], mask); } @@ -394,14 +395,14 @@ struct RedOpX case ReductionOperation::ARG_IDX_MIN: { auto temp_vec_res_value = wrapper::vmin(vec_elements, vec_res_value); - vec_res_idx = calculate_index(id.x(), temp_vec_res_value, vec_res_value, vec_res_idx, op, 0); + vec_res_idx = calculate_index<decltype(vec_res_value)>(id.x(), temp_vec_res_value, vec_res_value, vec_res_idx, op, 0); vec_res_value = temp_vec_res_value; break; } case ReductionOperation::ARG_IDX_MAX: { auto temp_vec_res_value = 
wrapper::vmax(vec_elements, vec_res_value); - vec_res_idx = calculate_index(id.x(), temp_vec_res_value, vec_res_value, vec_res_idx, op, 0); + vec_res_idx = calculate_index<decltype(vec_res_value)>(id.x(), temp_vec_res_value, vec_res_value, vec_res_idx, op, 0); vec_res_value = temp_vec_res_value; break; } @@ -446,7 +447,7 @@ struct RedOpX case ReductionOperation::ARG_IDX_MIN: case ReductionOperation::ARG_IDX_MAX: { - auto res = calculate_vector_index(vec_res_idx, vec_res_value, op); + auto res = calculate_vector_index<decltype(vec_res_value)>(vec_res_idx, vec_res_value, op); *(reinterpret_cast<uint32_t *>(output.ptr())) = res; break; } @@ -943,6 +944,8 @@ void reduce_op(const Window &window, const ITensor *input, ITensor *output, unsi #endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC case DataType::F32: return Reducer<RedOpX<float, 4>>::reduceX(window, input, output, RedOpX<float, 4>(), op); + case DataType::S32: + return Reducer<RedOpX<int32_t, 4>>::reduceX(window, input, output, RedOpX<int32_t, 4>(), op); default: ARM_COMPUTE_ERROR("Not supported"); } @@ -957,6 +960,8 @@ void reduce_op(const Window &window, const ITensor *input, ITensor *output, unsi #endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC case DataType::F32: return Reducer<RedOpYZW<float, 4>>::reduceY(window, input, output, RedOpYZW<float, 4>(), op); + case DataType::S32: + return Reducer<RedOpYZW<int32_t, 4>>::reduceY(window, input, output, RedOpYZW<int32_t, 4>(), op); default: ARM_COMPUTE_ERROR("Not supported"); } @@ -971,6 +976,8 @@ void reduce_op(const Window &window, const ITensor *input, ITensor *output, unsi #endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC case DataType::F32: return Reducer<RedOpYZW<float, 4>>::reduceZ(window, input, output, RedOpYZW<float, 4>(), op); + case DataType::S32: + return Reducer<RedOpYZW<int32_t, 4>>::reduceZ(window, input, output, RedOpYZW<int32_t, 4>(), op); default: ARM_COMPUTE_ERROR("Not supported"); } @@ -985,6 +992,8 @@ void reduce_op(const Window &window, const ITensor 
*input, ITensor *output, unsi #endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC case DataType::F32: return Reducer<RedOpYZW<float, 4>>::reduceW(window, input, output, RedOpYZW<float, 4>(), op); + case DataType::S32: + return Reducer<RedOpYZW<int32_t, 4>>::reduceW(window, input, output, RedOpYZW<int32_t, 4>(), op); default: ARM_COMPUTE_ERROR("Not supported"); } @@ -1002,7 +1011,7 @@ Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, u if(input->num_channels() == 1) { - ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8, DataType::F16, DataType::F32); + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8, DataType::S32, DataType::F16, DataType::F32); } else { |