Diffstat (limited to 'src/core/NEON/kernels/NEReductionOperationKernel.cpp')
-rw-r--r--  src/core/NEON/kernels/NEReductionOperationKernel.cpp  101
1 file changed, 97 insertions(+), 4 deletions(-)
diff --git a/src/core/NEON/kernels/NEReductionOperationKernel.cpp b/src/core/NEON/kernels/NEReductionOperationKernel.cpp
index e6fdba2696..aa20d1f40d 100644
--- a/src/core/NEON/kernels/NEReductionOperationKernel.cpp
+++ b/src/core/NEON/kernels/NEReductionOperationKernel.cpp
@@ -602,7 +602,7 @@ struct RedOpYZW
{
ARM_COMPUTE_UNUSED(out_slice);
- execute_window_loop(in_slice, [&](const Coordinates & id)
+ execute_window_loop(in_slice, [&](const Coordinates &)
{
neon_vector vec_res_value = { 0 };
if(op == ReductionOperation::ARG_IDX_MAX || op == ReductionOperation::ARG_IDX_MIN)
@@ -688,13 +688,70 @@ struct RedOpYZW
}
};
+template <typename T, int S, int axis, ReductionOperation op>
+struct RedOpYZW_complex
+{
+ /** NEON vector tag type. */
+ using ExactTagType = typename wrapper::traits::neon_vector<T, S>::tag_type;
+ using neon_vector = typename wrapper::traits::neon_vector<T, S>::type;
+
+ inline void operator()(Iterator &input, Iterator &output, Window &in_slice, Window &out_slice, const TensorInfo &in_info, int, const ReductionOperation)
+ {
+ ARM_COMPUTE_UNUSED(out_slice);
+ ARM_COMPUTE_ERROR_ON(axis != 2);
+
+ const size_t stride_z = in_info.strides_in_bytes()[axis];
+
+ execute_window_loop(in_slice, [&](const Coordinates &)
+ {
+ neon_vector vec_res_value_0 = { 0 };
+ neon_vector vec_res_value_1 = { 0 };
+
+ vec_res_value_0 = wrapper::vdup_n(static_cast<T>(0.f), ExactTagType{});
+ vec_res_value_1 = wrapper::vdup_n(static_cast<T>(0.f), ExactTagType{});
+
+ for(unsigned int dim = 0; dim < in_info.dimension(axis); ++dim)
+ {
+ T *in_ptr_0;
+ T *in_ptr_1;
+ switch(axis)
+ {
+ case 2:
+ in_ptr_0 = reinterpret_cast<T *>(input.ptr() + stride_z * dim);
+ in_ptr_1 = reinterpret_cast<T *>(input.ptr() + 16 + stride_z * dim);
+ break;
+ default:
+ ARM_COMPUTE_ERROR("Not supported");
+ }
+ const auto vec_elements_0 = wrapper::vloadq(in_ptr_0);
+ const auto vec_elements_1 = wrapper::vloadq(in_ptr_1);
+
+ switch(op)
+ {
+ case ReductionOperation::SUM:
+ vec_res_value_0 = wrapper::vadd(vec_elements_0, vec_res_value_0);
+ vec_res_value_1 = wrapper::vadd(vec_elements_1, vec_res_value_1);
+ break;
+ default:
+ ARM_COMPUTE_ERROR("Not supported");
+ }
+ }
+
+ wrapper::vstore(reinterpret_cast<T *>(output.ptr()), vec_res_value_0);
+ wrapper::vstore(reinterpret_cast<T *>(output.ptr() + 16), vec_res_value_1);
+
+ },
+ input, output);
+ }
+};
+
struct RedOpYZW_qasymm8
{
inline void operator()(Iterator &input, Iterator &output, Window &in_slice, Window &out_slice, const TensorInfo &in_info, int axis, const ReductionOperation op)
{
ARM_COMPUTE_UNUSED(out_slice);
- execute_window_loop(in_slice, [&](const Coordinates & id)
+ execute_window_loop(in_slice, [&](const Coordinates &)
{
uint32x4x4_t vec_res_idx{ { 0 } };
auto vec_res_value1 = vdupq_n_u32(0);
@@ -848,6 +905,31 @@ struct RedOpYZW_qasymm8
void reduce_op(const Window &window, const ITensor *input, ITensor *output, unsigned int axis, const ReductionOperation op)
{
+ const bool is_complex = (input->info()->num_channels() == 2);
+
+ if(is_complex)
+ {
+ switch(axis)
+ {
+ case 2:
+ switch(input->info()->data_type())
+ {
+ case DataType::F32:
+ switch(op)
+ {
+ case ReductionOperation::SUM:
+ return Reducer<RedOpYZW_complex<float, 4, 2, ReductionOperation::SUM>>::reduceZ(window, input, output, RedOpYZW_complex<float, 4, 2, ReductionOperation::SUM>(), op);
+ default:
+ ARM_COMPUTE_ERROR("Not supported");
+ }
+ default:
+ ARM_COMPUTE_ERROR("Not supported");
+ }
+ default:
+ ARM_COMPUTE_ERROR("Not supported");
+ }
+ }
+
switch(axis)
{
case 0:
@@ -917,7 +999,17 @@ Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, u
ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, output);
ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(input);
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8, DataType::F16, DataType::F32);
+
+ if(input->num_channels() == 1)
+ {
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8, DataType::F16, DataType::F32);
+ }
+ else
+ {
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 2, DataType::F32);
+ ARM_COMPUTE_RETURN_ERROR_ON(op != ReductionOperation::SUM);
+ ARM_COMPUTE_RETURN_ERROR_ON(axis != 2);
+ }
ARM_COMPUTE_RETURN_ERROR_ON_MSG(axis >= TensorShape::num_max_dimensions, "Reduction axis greater than max number of dimensions");
ARM_COMPUTE_RETURN_ERROR_ON_MSG(axis > 3, "Unsupported reduction axis");
@@ -929,6 +1021,7 @@ Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, u
{
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_QUANTIZATION_INFO(input, output);
+ ARM_COMPUTE_RETURN_ERROR_ON(input->num_channels() != output->num_channels());
}
else
{
@@ -951,7 +1044,7 @@ std::tuple<Status, Window> validate_and_configure_window(ITensorInfo *input, ITe
// Output auto initialization if not yet initialized
const bool is_arg_min_max = (op == ReductionOperation::ARG_IDX_MIN || op == ReductionOperation::ARG_IDX_MAX);
DataType output_data_type = is_arg_min_max ? DataType::U32 : input->data_type();
- auto_init_if_empty(*output, output_shape, 1, output_data_type, input->quantization_info());
+ auto_init_if_empty(*output, input->clone()->set_tensor_shape(output_shape).set_data_type(output_data_type).reset_padding().set_is_resizable(true));
unsigned int num_elems_processed_per_iteration = 16 / data_size_from_type(input->data_type());
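
For reference, below is a minimal sketch (not part of the diff above) of how the complex-reduction path added by this patch could be exercised through the public NEReductionOperation runtime function. The tensor shape, fill step and variable names are illustrative assumptions; per the new validation code, only a 2-channel F32 input with ReductionOperation::SUM along axis 2 is accepted.

    #include "arm_compute/core/TensorInfo.h"
    #include "arm_compute/core/TensorShape.h"
    #include "arm_compute/core/Types.h"
    #include "arm_compute/runtime/NEON/functions/NEReductionOperation.h"
    #include "arm_compute/runtime/Tensor.h"

    using namespace arm_compute;

    int main()
    {
        // Illustrative shape: a 2-channel (complex) F32 tensor of 8 x 4 x 3,
        // reduced along the Z axis (axis 2), as required by the new checks.
        Tensor src{};
        Tensor dst{};
        src.allocator()->init(TensorInfo(TensorShape(8U, 4U, 3U), 2 /* num_channels: complex */, DataType::F32));

        // configure() auto-initialises the output; it keeps the two channels
        // (see the num_channels check added to validate_arguments above).
        NEReductionOperation reduction{};
        reduction.configure(&src, &dst, /* axis */ 2U, ReductionOperation::SUM);

        src.allocator()->allocate();
        dst.allocator()->allocate();

        // ... fill src with interleaved (real, imag) pairs ...

        reduction.run();
        return 0;
    }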