aboutsummaryrefslogtreecommitdiff
path: root/src/core/NEON/kernels/NEReductionOperationKernel.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/NEON/kernels/NEReductionOperationKernel.cpp')
-rw-r--r--src/core/NEON/kernels/NEReductionOperationKernel.cpp25
1 files changed, 17 insertions, 8 deletions
diff --git a/src/core/NEON/kernels/NEReductionOperationKernel.cpp b/src/core/NEON/kernels/NEReductionOperationKernel.cpp
index aa20d1f40d..5f0a4dd371 100644
--- a/src/core/NEON/kernels/NEReductionOperationKernel.cpp
+++ b/src/core/NEON/kernels/NEReductionOperationKernel.cpp
@@ -41,7 +41,8 @@ namespace arm_compute
{
namespace
{
-uint32x4x4_t calculate_index(uint32_t idx, float32x4_t a, float32x4_t b, uint32x4x4_t c, ReductionOperation op, int axis)
+template <typename T>
+uint32x4x4_t calculate_index(uint32_t idx, T a, T b, uint32x4x4_t c, ReductionOperation op, int axis)
{
uint32x4_t mask{ 0 };
if(op == ReductionOperation::ARG_IDX_MIN)
@@ -107,8 +108,8 @@ uint32x4x4_t calculate_index(uint32_t idx, uint8x16_t a, uint8x16_t b, uint32x4x
return res;
}
-
-uint32_t calculate_vector_index(uint32x4x4_t vec_res_idx, float32x4_t vec_res_value, ReductionOperation op)
+template <typename T>
+uint32_t calculate_vector_index(uint32x4x4_t vec_res_idx, T vec_res_value, ReductionOperation op)
{
uint32x4_t res_idx_mask{ 0 };
uint32x4_t mask_ones = vdupq_n_u32(0xFFFFFFFF);
@@ -124,7 +125,7 @@ uint32_t calculate_vector_index(uint32x4x4_t vec_res_idx, float32x4_t vec_res_va
{
auto pmax = wrapper::vpmax(wrapper::vgethigh(vec_res_value), wrapper::vgetlow(vec_res_value));
pmax = wrapper::vpmax(pmax, pmax);
- auto mask = vceqq_f32(vec_res_value, wrapper::vcombine(pmax, pmax));
+ auto mask = wrapper::vceq(vec_res_value, wrapper::vcombine(pmax, pmax));
res_idx_mask = wrapper::vand(vec_res_idx.val[0], mask);
}
@@ -394,14 +395,14 @@ struct RedOpX
case ReductionOperation::ARG_IDX_MIN:
{
auto temp_vec_res_value = wrapper::vmin(vec_elements, vec_res_value);
- vec_res_idx = calculate_index(id.x(), temp_vec_res_value, vec_res_value, vec_res_idx, op, 0);
+ vec_res_idx = calculate_index<decltype(vec_res_value)>(id.x(), temp_vec_res_value, vec_res_value, vec_res_idx, op, 0);
vec_res_value = temp_vec_res_value;
break;
}
case ReductionOperation::ARG_IDX_MAX:
{
auto temp_vec_res_value = wrapper::vmax(vec_elements, vec_res_value);
- vec_res_idx = calculate_index(id.x(), temp_vec_res_value, vec_res_value, vec_res_idx, op, 0);
+ vec_res_idx = calculate_index<decltype(vec_res_value)>(id.x(), temp_vec_res_value, vec_res_value, vec_res_idx, op, 0);
vec_res_value = temp_vec_res_value;
break;
}
@@ -446,7 +447,7 @@ struct RedOpX
case ReductionOperation::ARG_IDX_MIN:
case ReductionOperation::ARG_IDX_MAX:
{
- auto res = calculate_vector_index(vec_res_idx, vec_res_value, op);
+ auto res = calculate_vector_index<decltype(vec_res_value)>(vec_res_idx, vec_res_value, op);
*(reinterpret_cast<uint32_t *>(output.ptr())) = res;
break;
}
@@ -943,6 +944,8 @@ void reduce_op(const Window &window, const ITensor *input, ITensor *output, unsi
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
case DataType::F32:
return Reducer<RedOpX<float, 4>>::reduceX(window, input, output, RedOpX<float, 4>(), op);
+ case DataType::S32:
+ return Reducer<RedOpX<int32_t, 4>>::reduceX(window, input, output, RedOpX<int32_t, 4>(), op);
default:
ARM_COMPUTE_ERROR("Not supported");
}
@@ -957,6 +960,8 @@ void reduce_op(const Window &window, const ITensor *input, ITensor *output, unsi
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
case DataType::F32:
return Reducer<RedOpYZW<float, 4>>::reduceY(window, input, output, RedOpYZW<float, 4>(), op);
+ case DataType::S32:
+ return Reducer<RedOpYZW<int32_t, 4>>::reduceY(window, input, output, RedOpYZW<int32_t, 4>(), op);
default:
ARM_COMPUTE_ERROR("Not supported");
}
@@ -971,6 +976,8 @@ void reduce_op(const Window &window, const ITensor *input, ITensor *output, unsi
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
case DataType::F32:
return Reducer<RedOpYZW<float, 4>>::reduceZ(window, input, output, RedOpYZW<float, 4>(), op);
+ case DataType::S32:
+ return Reducer<RedOpYZW<int32_t, 4>>::reduceZ(window, input, output, RedOpYZW<int32_t, 4>(), op);
default:
ARM_COMPUTE_ERROR("Not supported");
}
@@ -985,6 +992,8 @@ void reduce_op(const Window &window, const ITensor *input, ITensor *output, unsi
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
case DataType::F32:
return Reducer<RedOpYZW<float, 4>>::reduceW(window, input, output, RedOpYZW<float, 4>(), op);
+ case DataType::S32:
+ return Reducer<RedOpYZW<int32_t, 4>>::reduceW(window, input, output, RedOpYZW<int32_t, 4>(), op);
default:
ARM_COMPUTE_ERROR("Not supported");
}
@@ -1002,7 +1011,7 @@ Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, u
if(input->num_channels() == 1)
{
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8, DataType::F16, DataType::F32);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8, DataType::S32, DataType::F16, DataType::F32);
}
else
{