Diffstat (limited to 'src/core/NEON/kernels/NEReductionOperationKernel.cpp'):
 -rw-r--r--  src/core/NEON/kernels/NEReductionOperationKernel.cpp | 481
 1 file changed, 459 insertions(+), 22 deletions(-)
diff --git a/src/core/NEON/kernels/NEReductionOperationKernel.cpp b/src/core/NEON/kernels/NEReductionOperationKernel.cpp
index 30f21bbf33..b77219cd79 100644
--- a/src/core/NEON/kernels/NEReductionOperationKernel.cpp
+++ b/src/core/NEON/kernels/NEReductionOperationKernel.cpp
@@ -32,10 +32,11 @@
#include "arm_compute/core/TensorInfo.h"
#include "arm_compute/core/Validate.h"
+#include "arm_compute/core/NEON/wrapper/wrapper.h"
#include <arm_neon.h>
-using namespace arm_compute;
-
+namespace arm_compute
+{
namespace
{
template <class F>
@@ -57,31 +58,281 @@ public:
Iterator in(input, in_slice);
Iterator out(output, out_slice);
- f(in, out, in_slice, out_slice);
+ f(in, out, in_slice, out_slice, *input->info());
+ }
+ while(window.slide_window_slice_1D(in_slice) && out_window.slide_window_slice_1D(out_slice));
+ }
+ static void reduceY(const Window &window, const ITensor *input, ITensor *output, F f)
+ {
+ // Set in window
+ Window in_window(window);
+
+ in_window.set(Window::DimY, Window::Dimension(0, 1, 1));
+
+ // Get first input and output slices
+ Window in_slice = in_window.first_slice_window_2D();
+ Window out_slice = window.first_slice_window_2D();
+
+ do
+ {
+ Iterator in(input, in_slice);
+ Iterator out(output, out_slice);
+
+ f(in, out, in_slice, out_slice, *input->info(), 1);
+ }
+ while(in_window.slide_window_slice_2D(in_slice) && window.slide_window_slice_2D(out_slice));
+ }
+ static void reduceZ(const Window &window, const ITensor *input, ITensor *output, F f)
+ {
+ // Set in window
+ Window in_window(window);
+
+ in_window.set(Window::DimZ, Window::Dimension(0, 1, 1));
+
+ // Get first input and output slices
+ Window in_slice = in_window.first_slice_window_3D();
+ Window out_slice = window.first_slice_window_3D();
+
+ do
+ {
+ Iterator in(input, in_slice);
+ Iterator out(output, out_slice);
+
+ f(in, out, in_slice, out_slice, *input->info(), 2);
}
- while(window.slide_window_slice_1D(in_slice) && window.slide_window_slice_1D(out_slice));
+ while(in_window.slide_window_slice_3D(in_slice) && window.slide_window_slice_3D(out_slice));
+ }
+ static void reduceW(const Window &window, const ITensor *input, ITensor *output, F f)
+ {
+ // Set in/out window
+ Window in_window(window);
+ Window out_window(window);
+
+ in_window.set(3, Window::Dimension(0, 1, 1));
+ out_window.set(3, Window::Dimension(0, 1, 1));
+
+ // Get first input and output slices
+ Window in_slice = in_window.first_slice_window_4D();
+ Window out_slice = out_window.first_slice_window_4D();
+
+ do
+ {
+ Iterator in(input, in_slice);
+ Iterator out(output, out_slice);
+
+ f(in, out, in_slice, out_slice, *input->info(), 3);
+ }
+ while(in_window.slide_window_slice_4D(in_slice) && out_window.slide_window_slice_4D(out_slice));
}
};
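The Reducer template only drives the iteration: reduceX slides 1D slices across the row, while reduceY/Z/W pin the reduced dimension of the input window to a single step and leave the walk along that dimension to the functor. A minimal sketch of the same driver/functor split, with hypothetical names and none of the library's Window/Iterator machinery:

    // Sketch only: the driver owns the loop over outputs, the functor owns the math.
    template <class F>
    struct ReducerSketch
    {
        static void reduce_rows(const float *in, float *out, int rows, int cols, F f)
        {
            for(int r = 0; r < rows; ++r)
            {
                f(in + r * cols, out + r, cols); // one reduction per output row
            }
        }
    };

    struct SumOpSketch
    {
        void operator()(const float *row, float *dst, int len) const
        {
            float acc = 0.f;
            for(int i = 0; i < len; ++i)
            {
                acc += row[i];
            }
            *dst = acc;
        }
    };
    // Usage: ReducerSketch<SumOpSketch>::reduce_rows(src, dst, rows, cols, SumOpSketch{});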
-struct SumsqOpX
+template <typename T, int S, ReductionOperation op>
+struct RedOpX
{
- inline void operator()(Iterator &input, Iterator &output, Window &in_slice, Window &out_slice)
+ /** NEON vector tag type. */
+ using ExactTagType = typename wrapper::traits::neon_vector<T, S>::tag_type;
+
+ inline void operator()(Iterator &input, Iterator &output, Window &in_slice, Window &out_slice, const TensorInfo &in_info)
{
ARM_COMPUTE_UNUSED(out_slice);
- float32x4_t vec_sum_value = vdupq_n_f32(0.f);
+ auto vec_sum_value = wrapper::vdup_n(static_cast<T>(0.f), ExactTagType{});
execute_window_loop(in_slice, [&](const Coordinates & id)
{
- const auto in_ptr = reinterpret_cast<const float *>(input.ptr());
- const float32x4_t vec_elements = vld1q_f32(in_ptr);
- vec_sum_value = vaddq_f32(vmulq_f32(vec_elements, vec_elements), vec_sum_value);
+ const auto in_ptr = reinterpret_cast<const T *>(input.ptr());
+ const auto vec_elements = wrapper::vloadq(in_ptr);
+
+ if(op == ReductionOperation::SUM_SQUARE)
+ {
+ vec_sum_value = wrapper::vadd(wrapper::vmul(vec_elements, vec_elements), vec_sum_value);
+ }
+ else
+ {
+ vec_sum_value = wrapper::vadd(vec_elements, vec_sum_value);
+ }
},
input);
- float32x2_t carry_addition = vpadd_f32(vget_high_f32(vec_sum_value), vget_low_f32(vec_sum_value));
- carry_addition = vpadd_f32(carry_addition, carry_addition);
+ auto carry_addition = wrapper::vpadd(wrapper::vgethigh(vec_sum_value), wrapper::vgetlow(vec_sum_value));
+ carry_addition = wrapper::vpadd(carry_addition, carry_addition);
+
+ auto res = wrapper::vgetlane(carry_addition, 0);
+ if(op == ReductionOperation::MEAN_SUM)
+ {
+ res /= in_info.dimension(0);
+ }
- *(reinterpret_cast<float *>(output.ptr())) = vget_lane_f32(carry_addition, 0);
+ *(reinterpret_cast<T *>(output.ptr())) = res;
+ }
+};
+
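RedOpX accumulates a full vector across the row and then folds it to a scalar with two pairwise additions. A minimal sketch of that fold with plain NEON intrinsics (the kernel goes through the wrapper:: equivalents, and for MEAN_SUM divides the resulting scalar by dimension(0) afterwards):

    #include <arm_neon.h>

    static inline float horizontal_sum_f32x4(float32x4_t v)
    {
        float32x2_t p = vpadd_f32(vget_high_f32(v), vget_low_f32(v)); // {v2+v3, v0+v1}
        p             = vpadd_f32(p, p);                              // both lanes hold v0+v1+v2+v3
        return vget_lane_f32(p, 0);
    }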
+template <ReductionOperation op>
+struct RedOpX_qasymm8
+{
+ inline void operator()(Iterator &input, Iterator &output, Window &in_slice, Window &out_slice, const TensorInfo &in_info)
+ {
+ ARM_COMPUTE_UNUSED(out_slice);
+ auto vec_sum_value1 = vdupq_n_u32(static_cast<uint32_t>(0.f));
+ auto vec_sum_value2 = vdupq_n_u32(static_cast<uint32_t>(0.f));
+ auto vec_sum_value3 = vdupq_n_u32(static_cast<uint32_t>(0.f));
+ auto vec_sum_value4 = vdupq_n_u32(static_cast<uint32_t>(0.f));
+
+ execute_window_loop(in_slice, [&](const Coordinates & id)
+ {
+ const auto vec_elements = wrapper::vloadq(input.ptr());
+
+ const auto temp16x8t_1 = wrapper::vmovl(wrapper::vgetlow(vec_elements));
+ const auto temp16x8t_2 = wrapper::vmovl(wrapper::vgethigh(vec_elements));
+
+ const auto temp32x4t_1 = wrapper::vmovl(wrapper::vgetlow(temp16x8t_1));
+ const auto temp32x4t_2 = wrapper::vmovl(wrapper::vgethigh(temp16x8t_1));
+ const auto temp32x4t_3 = wrapper::vmovl(wrapper::vgetlow(temp16x8t_2));
+ const auto temp32x4t_4 = wrapper::vmovl(wrapper::vgethigh(temp16x8t_2));
+
+ vec_sum_value1 = wrapper::vadd(temp32x4t_1, vec_sum_value1);
+ vec_sum_value2 = wrapper::vadd(temp32x4t_2, vec_sum_value2);
+ vec_sum_value3 = wrapper::vadd(temp32x4t_3, vec_sum_value3);
+ vec_sum_value4 = wrapper::vadd(temp32x4t_4, vec_sum_value4);
+ },
+ input);
+
+ auto carry_addition = wrapper::vadd(vec_sum_value1, vec_sum_value2);
+ carry_addition = wrapper::vadd(carry_addition, vec_sum_value3);
+ carry_addition = wrapper::vadd(carry_addition, vec_sum_value4);
+
+ auto carry_paddition = wrapper::vpadd(wrapper::vgethigh(carry_addition), wrapper::vgetlow(carry_addition));
+ carry_paddition = wrapper::vpadd(carry_paddition, carry_paddition);
+ auto res = wrapper::vgetlane(carry_paddition, 0);
+
+ if(op == ReductionOperation::MEAN_SUM)
+ {
+ res /= in_info.dimension(0);
+ }
+
+ *(output.ptr()) = static_cast<uint8_t>(res);
+ }
+};
+
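RedOpX_qasymm8 widens every uint8x16 load into four uint32x4 accumulators so that long rows cannot overflow the 8- or 16-bit lanes. A standalone sketch of the same widen-then-accumulate pattern with plain intrinsics (assumptions: the length is a multiple of 16, no window or border handling):

    #include <arm_neon.h>
    #include <cstddef>
    #include <cstdint>

    uint32_t sum_u8_widened(const uint8_t *data, size_t n)
    {
        uint32x4_t acc1 = vdupq_n_u32(0), acc2 = vdupq_n_u32(0);
        uint32x4_t acc3 = vdupq_n_u32(0), acc4 = vdupq_n_u32(0);
        for(size_t i = 0; i < n; i += 16)
        {
            const uint8x16_t v  = vld1q_u8(data + i);
            const uint16x8_t lo = vmovl_u8(vget_low_u8(v));            // u8  -> u16
            const uint16x8_t hi = vmovl_u8(vget_high_u8(v));
            acc1 = vaddq_u32(acc1, vmovl_u16(vget_low_u16(lo)));       // u16 -> u32
            acc2 = vaddq_u32(acc2, vmovl_u16(vget_high_u16(lo)));
            acc3 = vaddq_u32(acc3, vmovl_u16(vget_low_u16(hi)));
            acc4 = vaddq_u32(acc4, vmovl_u16(vget_high_u16(hi)));
        }
        const uint32x4_t acc = vaddq_u32(vaddq_u32(acc1, acc2), vaddq_u32(acc3, acc4));
        uint32x2_t p         = vpadd_u32(vget_low_u32(acc), vget_high_u32(acc));
        p                    = vpadd_u32(p, p);
        return vget_lane_u32(p, 0);
    }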
+template <typename T, int S, ReductionOperation op>
+struct RedOpYZW
+{
+ /** NEON vector tag type. */
+ using ExactTagType = typename wrapper::traits::neon_vector<T, S>::tag_type;
+
+ inline void operator()(Iterator &input, Iterator &output, Window &in_slice, Window &out_slice, const TensorInfo &in_info, int axis)
+ {
+ ARM_COMPUTE_UNUSED(out_slice);
+
+ execute_window_loop(in_slice, [&](const Coordinates & id)
+ {
+ auto vec_sum_value = wrapper::vdup_n(static_cast<T>(0.f), ExactTagType{});
+ for(unsigned int dim = 0; dim < in_info.dimension(axis); ++dim)
+ {
+ T *in_ptr;
+ switch(axis)
+ {
+ case 1:
+ in_ptr = reinterpret_cast<T *>(input.ptr() + in_info.offset_element_in_bytes(Coordinates(0, dim)));
+ break;
+ case 2:
+ in_ptr = reinterpret_cast<T *>(input.ptr() + in_info.offset_element_in_bytes(Coordinates(0, 0, dim)));
+ break;
+ case 3:
+ in_ptr = reinterpret_cast<T *>(input.ptr() + in_info.offset_element_in_bytes(Coordinates(0, 0, 0, dim)));
+ break;
+ default:
+ ARM_COMPUTE_ERROR("Not supported");
+ }
+ const auto vec_elements = wrapper::vloadq(in_ptr);
+
+ if(op == ReductionOperation::SUM_SQUARE)
+ {
+ vec_sum_value = wrapper::vadd(wrapper::vmul(vec_elements, vec_elements), vec_sum_value);
+ }
+ else
+ {
+ vec_sum_value = wrapper::vadd(vec_elements, vec_sum_value);
+ }
+ }
+
+ if(op == ReductionOperation::MEAN_SUM)
+ {
+ auto vec_width_inv = wrapper::vinv(wrapper::vdup_n(static_cast<T>(in_info.dimension(axis)), ExactTagType{}));
+ vec_sum_value = wrapper::vmul(vec_sum_value, vec_width_inv);
+ }
+
+ wrapper::vstore(reinterpret_cast<T *>(output.ptr()), vec_sum_value);
+ },
+ input, output);
+ }
+};
+
+template <ReductionOperation op>
+struct RedOpYZW_qasymm8
+{
+ inline void operator()(Iterator &input, Iterator &output, Window &in_slice, Window &out_slice, const TensorInfo &in_info, int axis)
+ {
+ ARM_COMPUTE_UNUSED(out_slice);
+
+ execute_window_loop(in_slice, [&](const Coordinates & id)
+ {
+ auto vec_sum_value1 = vdupq_n_u32(static_cast<uint32_t>(0.f));
+ auto vec_sum_value2 = vdupq_n_u32(static_cast<uint32_t>(0.f));
+ auto vec_sum_value3 = vdupq_n_u32(static_cast<uint32_t>(0.f));
+ auto vec_sum_value4 = vdupq_n_u32(static_cast<uint32_t>(0.f));
+ for(unsigned int dim = 0; dim < in_info.dimension(axis); ++dim)
+ {
+ uint8_t *in_ptr;
+ switch(axis)
+ {
+ case 1:
+ in_ptr = input.ptr() + in_info.offset_element_in_bytes(Coordinates(0, dim));
+ break;
+ case 2:
+ in_ptr = input.ptr() + in_info.offset_element_in_bytes(Coordinates(0, 0, dim));
+ break;
+ case 3:
+ in_ptr = input.ptr() + in_info.offset_element_in_bytes(Coordinates(0, 0, 0, dim));
+ break;
+ default:
+ ARM_COMPUTE_ERROR("Not supported");
+ }
+ const auto vec_elements = wrapper::vloadq(in_ptr);
+
+ const auto temp16x8t_1 = wrapper::vmovl(wrapper::vgetlow(vec_elements));
+ const auto temp16x8t_2 = wrapper::vmovl(wrapper::vgethigh(vec_elements));
+
+ const auto temp32x4t_1 = wrapper::vmovl(wrapper::vgetlow(temp16x8t_1));
+ const auto temp32x4t_2 = wrapper::vmovl(wrapper::vgethigh(temp16x8t_1));
+ const auto temp32x4t_3 = wrapper::vmovl(wrapper::vgetlow(temp16x8t_2));
+ const auto temp32x4t_4 = wrapper::vmovl(wrapper::vgethigh(temp16x8t_2));
+
+ vec_sum_value1 = wrapper::vadd(temp32x4t_1, vec_sum_value1);
+ vec_sum_value2 = wrapper::vadd(temp32x4t_2, vec_sum_value2);
+ vec_sum_value3 = wrapper::vadd(temp32x4t_3, vec_sum_value3);
+ vec_sum_value4 = wrapper::vadd(temp32x4t_4, vec_sum_value4);
+ }
+
+ if(op == ReductionOperation::MEAN_SUM)
+ {
+ const auto vec_width_inv = wrapper::vinv(vdupq_n_f32(in_info.dimension(axis)));
+ const auto vec_sum_value1_f = wrapper::vmul(vcvtq_f32_u32(vec_sum_value1), vec_width_inv);
+ const auto vec_sum_value2_f = wrapper::vmul(vcvtq_f32_u32(vec_sum_value2), vec_width_inv);
+ const auto vec_sum_value3_f = wrapper::vmul(vcvtq_f32_u32(vec_sum_value3), vec_width_inv);
+ const auto vec_sum_value4_f = wrapper::vmul(vcvtq_f32_u32(vec_sum_value4), vec_width_inv);
+
+ vec_sum_value1 = vcvtq_u32_f32(vec_sum_value1_f);
+ vec_sum_value2 = vcvtq_u32_f32(vec_sum_value2_f);
+ vec_sum_value3 = vcvtq_u32_f32(vec_sum_value3_f);
+ vec_sum_value4 = vcvtq_u32_f32(vec_sum_value4_f);
+ }
+
+ const auto temp16x8t_1 = vcombine_u16(wrapper::vqmovn(vec_sum_value1), wrapper::vqmovn(vec_sum_value2));
+ const auto temp16x8t_2 = vcombine_u16(wrapper::vqmovn(vec_sum_value3), wrapper::vqmovn(vec_sum_value4));
+ auto res = vcombine_u8(wrapper::vqmovn(temp16x8t_1), wrapper::vqmovn(temp16x8t_2));
+ wrapper::vstore(output.ptr(), res);
+ },
+ input, output);
}
};
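Along the Y/Z/W axes the QASYMM8 reduction produces a whole vector per output location, so the four 32-bit accumulators are narrowed back down with saturating vqmovn steps: per-lane sums above 255 clamp instead of wrapping. A small sketch of that narrowing chain with plain intrinsics (the patch reaches vqmovn through the wrapper:: layer):

    #include <arm_neon.h>

    static inline uint8x16_t narrow_to_u8(uint32x4_t s1, uint32x4_t s2, uint32x4_t s3, uint32x4_t s4)
    {
        const uint16x8_t t1 = vcombine_u16(vqmovn_u32(s1), vqmovn_u32(s2)); // saturate u32 -> u16
        const uint16x8_t t2 = vcombine_u16(vqmovn_u32(s3), vqmovn_u32(s4));
        return vcombine_u8(vqmovn_u16(t1), vqmovn_u16(t2));                 // saturate u16 -> u8
    }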
@@ -90,7 +341,186 @@ void reduce_sumsq(const Window &window, const ITensor *input, ITensor *output, u
switch(axis)
{
case 0:
- return Reducer<SumsqOpX>::reduceX(window, input, output, SumsqOpX());
+ switch(input->info()->data_type())
+ {
+#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+ case DataType::F16:
+ return Reducer<RedOpX<float16_t, 8, ReductionOperation::SUM_SQUARE>>::reduceX(window, input, output, RedOpX<float16_t, 8, ReductionOperation::SUM_SQUARE>());
+#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+ case DataType::F32:
+ return Reducer<RedOpX<float, 4, ReductionOperation::SUM_SQUARE>>::reduceX(window, input, output, RedOpX<float, 4, ReductionOperation::SUM_SQUARE>());
+ case DataType::QASYMM8:
+ default:
+ ARM_COMPUTE_ERROR("Not supported");
+ }
+ case 1:
+ switch(input->info()->data_type())
+ {
+#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+ case DataType::F16:
+ return Reducer<RedOpYZW<float16_t, 8, ReductionOperation::SUM_SQUARE>>::reduceY(window, input, output, RedOpYZW<float16_t, 8, ReductionOperation::SUM_SQUARE>());
+#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+ case DataType::F32:
+ return Reducer<RedOpYZW<float, 4, ReductionOperation::SUM_SQUARE>>::reduceY(window, input, output, RedOpYZW<float, 4, ReductionOperation::SUM_SQUARE>());
+ case DataType::QASYMM8:
+ default:
+ ARM_COMPUTE_ERROR("Not supported");
+ }
+ case 2:
+ switch(input->info()->data_type())
+ {
+#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+ case DataType::F16:
+ return Reducer<RedOpYZW<float16_t, 8, ReductionOperation::SUM_SQUARE>>::reduceZ(window, input, output, RedOpYZW<float16_t, 8, ReductionOperation::SUM_SQUARE>());
+#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+ case DataType::F32:
+ return Reducer<RedOpYZW<float, 4, ReductionOperation::SUM_SQUARE>>::reduceZ(window, input, output, RedOpYZW<float, 4, ReductionOperation::SUM_SQUARE>());
+ case DataType::QASYMM8:
+ default:
+ ARM_COMPUTE_ERROR("Not supported");
+ }
+ case 3:
+ switch(input->info()->data_type())
+ {
+#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+ case DataType::F16:
+ return Reducer<RedOpYZW<float16_t, 8, ReductionOperation::SUM_SQUARE>>::reduceW(window, input, output, RedOpYZW<float16_t, 8, ReductionOperation::SUM_SQUARE>());
+#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+ case DataType::F32:
+ return Reducer<RedOpYZW<float, 4, ReductionOperation::SUM_SQUARE>>::reduceW(window, input, output, RedOpYZW<float, 4, ReductionOperation::SUM_SQUARE>());
+ case DataType::QASYMM8:
+ default:
+ ARM_COMPUTE_ERROR("Not supported");
+ }
+ default:
+ ARM_COMPUTE_ERROR("Unsupported reduction axis");
+ }
+}
+
+void reduce_sum(const Window &window, const ITensor *input, ITensor *output, unsigned int axis)
+{
+ switch(axis)
+ {
+ case 0:
+ switch(input->info()->data_type())
+ {
+ case DataType::QASYMM8:
+ return Reducer<RedOpX_qasymm8<ReductionOperation::SUM>>::reduceX(window, input, output, RedOpX_qasymm8<ReductionOperation::SUM>());
+#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+ case DataType::F16:
+ return Reducer<RedOpX<float16_t, 8, ReductionOperation::SUM>>::reduceX(window, input, output, RedOpX<float16_t, 8, ReductionOperation::SUM>());
+#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+ case DataType::F32:
+ return Reducer<RedOpX<float, 4, ReductionOperation::SUM>>::reduceX(window, input, output, RedOpX<float, 4, ReductionOperation::SUM>());
+ default:
+ ARM_COMPUTE_ERROR("Not supported");
+ }
+ case 1:
+ switch(input->info()->data_type())
+ {
+ case DataType::QASYMM8:
+ return Reducer<RedOpYZW_qasymm8<ReductionOperation::SUM>>::reduceY(window, input, output, RedOpYZW_qasymm8<ReductionOperation::SUM>());
+#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+ case DataType::F16:
+ return Reducer<RedOpYZW<float16_t, 8, ReductionOperation::SUM>>::reduceY(window, input, output, RedOpYZW<float16_t, 8, ReductionOperation::SUM>());
+#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+ case DataType::F32:
+ return Reducer<RedOpYZW<float, 4, ReductionOperation::SUM>>::reduceY(window, input, output, RedOpYZW<float, 4, ReductionOperation::SUM>());
+ default:
+ ARM_COMPUTE_ERROR("Not supported");
+ }
+ case 2:
+ switch(input->info()->data_type())
+ {
+ case DataType::QASYMM8:
+ return Reducer<RedOpYZW_qasymm8<ReductionOperation::SUM>>::reduceZ(window, input, output, RedOpYZW_qasymm8<ReductionOperation::SUM>());
+#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+ case DataType::F16:
+ return Reducer<RedOpYZW<float16_t, 8, ReductionOperation::SUM>>::reduceZ(window, input, output, RedOpYZW<float16_t, 8, ReductionOperation::SUM>());
+#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+ case DataType::F32:
+ return Reducer<RedOpYZW<float, 4, ReductionOperation::SUM>>::reduceZ(window, input, output, RedOpYZW<float, 4, ReductionOperation::SUM>());
+ default:
+ ARM_COMPUTE_ERROR("Not supported");
+ }
+ case 3:
+ switch(input->info()->data_type())
+ {
+ case DataType::QASYMM8:
+ return Reducer<RedOpYZW_qasymm8<ReductionOperation::SUM>>::reduceW(window, input, output, RedOpYZW_qasymm8<ReductionOperation::SUM>());
+#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+ case DataType::F16:
+ return Reducer<RedOpYZW<float16_t, 8, ReductionOperation::SUM>>::reduceW(window, input, output, RedOpYZW<float16_t, 8, ReductionOperation::SUM>());
+#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+ case DataType::F32:
+ return Reducer<RedOpYZW<float, 4, ReductionOperation::SUM>>::reduceW(window, input, output, RedOpYZW<float, 4, ReductionOperation::SUM>());
+ default:
+ ARM_COMPUTE_ERROR("Not supported");
+ }
+ default:
+ ARM_COMPUTE_ERROR("Unsupported reduction axis");
+ }
+}
+void reduce_mean_sum(const Window &window, const ITensor *input, ITensor *output, unsigned int axis)
+{
+ switch(axis)
+ {
+ case 0:
+ switch(input->info()->data_type())
+ {
+ case DataType::QASYMM8:
+ return Reducer<RedOpX_qasymm8<ReductionOperation::MEAN_SUM>>::reduceX(window, input, output, RedOpX_qasymm8<ReductionOperation::MEAN_SUM>());
+#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+ case DataType::F16:
+ return Reducer<RedOpX<float16_t, 8, ReductionOperation::MEAN_SUM>>::reduceX(window, input, output, RedOpX<float16_t, 8, ReductionOperation::MEAN_SUM>());
+#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+ case DataType::F32:
+ return Reducer<RedOpX<float, 4, ReductionOperation::MEAN_SUM>>::reduceX(window, input, output, RedOpX<float, 4, ReductionOperation::MEAN_SUM>());
+ default:
+ ARM_COMPUTE_ERROR("Not supported");
+ }
+ case 1:
+ switch(input->info()->data_type())
+ {
+ case DataType::QASYMM8:
+ return Reducer<RedOpYZW_qasymm8<ReductionOperation::MEAN_SUM>>::reduceY(window, input, output, RedOpYZW_qasymm8<ReductionOperation::MEAN_SUM>());
+#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+ case DataType::F16:
+ return Reducer<RedOpYZW<float16_t, 8, ReductionOperation::MEAN_SUM>>::reduceY(window, input, output, RedOpYZW<float16_t, 8, ReductionOperation::MEAN_SUM>());
+#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+ case DataType::F32:
+ return Reducer<RedOpYZW<float, 4, ReductionOperation::MEAN_SUM>>::reduceY(window, input, output, RedOpYZW<float, 4, ReductionOperation::MEAN_SUM>());
+ default:
+ ARM_COMPUTE_ERROR("Not supported");
+ }
+ case 2:
+ switch(input->info()->data_type())
+ {
+ case DataType::QASYMM8:
+ return Reducer<RedOpYZW_qasymm8<ReductionOperation::MEAN_SUM>>::reduceZ(window, input, output, RedOpYZW_qasymm8<ReductionOperation::MEAN_SUM>());
+#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+ case DataType::F16:
+ return Reducer<RedOpYZW<float16_t, 8, ReductionOperation::MEAN_SUM>>::reduceZ(window, input, output, RedOpYZW<float16_t, 8, ReductionOperation::MEAN_SUM>());
+#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+ case DataType::F32:
+ return Reducer<RedOpYZW<float, 4, ReductionOperation::MEAN_SUM>>::reduceZ(window, input, output, RedOpYZW<float, 4, ReductionOperation::MEAN_SUM>());
+ default:
+ ARM_COMPUTE_ERROR("Not supported");
+ }
+ case 3:
+ switch(input->info()->data_type())
+ {
+ case DataType::QASYMM8:
+ return Reducer<RedOpYZW_qasymm8<ReductionOperation::MEAN_SUM>>::reduceW(window, input, output, RedOpYZW_qasymm8<ReductionOperation::MEAN_SUM>());
+#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+ case DataType::F16:
+ return Reducer<RedOpYZW<float16_t, 8, ReductionOperation::MEAN_SUM>>::reduceW(window, input, output, RedOpYZW<float16_t, 8, ReductionOperation::MEAN_SUM>());
+#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+ case DataType::F32:
+ return Reducer<RedOpYZW<float, 4, ReductionOperation::MEAN_SUM>>::reduceW(window, input, output, RedOpYZW<float, 4, ReductionOperation::MEAN_SUM>());
+ default:
+ ARM_COMPUTE_ERROR("Not supported");
+ }
default:
ARM_COMPUTE_ERROR("Unsupported reduction axis");
}
@@ -109,16 +539,15 @@ Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, u
ARM_COMPUTE_UNUSED(op);
ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, output);
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F32);
- ARM_COMPUTE_RETURN_ERROR_ON(input->data_layout() != DataLayout::NCHW);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8, DataType::F16, DataType::F32);
ARM_COMPUTE_RETURN_ERROR_ON_MSG(axis >= TensorShape::num_max_dimensions, "Reduction axis greater than max number of dimensions");
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(axis > 0, "Unsupported reduction axis, Supported axis is 0");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(axis > 3, "Unsupported reduction axis");
if(output->total_size() != 0)
{
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
- ARM_COMPUTE_RETURN_ERROR_ON(output->data_layout() != DataLayout::NCHW);
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_LAYOUT(input, output);
const TensorShape output_shape = calculate_output_shape(input->tensor_shape(), axis);
const TensorInfo tensor_info_reshaped = input->clone()->set_tensor_shape(output_shape);
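calculate_output_shape is defined earlier in this file and is not part of this diff; the shape check above only makes sense if the reduced axis is kept with size 1 rather than removed. A hedged sketch of that assumed shape rule:

    // Assumption, for illustration only: the reduced dimension collapses to 1,
    // all other dimensions are left untouched.
    TensorShape calculate_output_shape_sketch(TensorShape shape, unsigned int axis)
    {
        shape.set(axis, 1);
        return shape;
    }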
@@ -170,10 +599,11 @@ void NEReductionOperationKernel::configure(const ITensor *input, ITensor *output
unsigned int num_elems_processed_per_iteration = 16 / data_size_from_type(input->info()->data_type());
- _input = input;
- _output = output;
- _border_size = (axis == 0) ? BorderSize(0, num_elems_processed_per_iteration - (input->info()->dimension(0) % num_elems_processed_per_iteration), 0, 0) : BorderSize();
- _op = op;
+ _input = input;
+ _output = output;
+ _border_size = (axis == 0) ? BorderSize(0, num_elems_processed_per_iteration - (input->info()->dimension(0) % num_elems_processed_per_iteration), 0, 0) : BorderSize();
+ _op = op;
+ _reduction_axis = axis;
// Configure kernel window
auto win_config = validate_and_configure_window(_input->info(), _output->info(), axis);
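For the X-axis case the border pads the row out to a full vector: 16 bytes are processed per iteration, i.e. 16 / sizeof(float) = 4 F32 elements or 16 QASYMM8 elements, so a hypothetical row of 10 F32 values would get a right border of 4 - (10 % 4) = 2 elements. Reductions along the other axes leave the border at zero.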
@@ -202,7 +632,14 @@ void NEReductionOperationKernel::run(const Window &window, const ThreadInfo &inf
case ReductionOperation::SUM_SQUARE:
reduce_sumsq(window, _input, _output, _reduction_axis);
break;
+ case ReductionOperation::MEAN_SUM:
+ reduce_mean_sum(window, _input, _output, _reduction_axis);
+ break;
+ case ReductionOperation::SUM:
+ reduce_sum(window, _input, _output, _reduction_axis);
+ break;
default:
ARM_COMPUTE_ERROR("Unsupported reduction operation.");
}
}
+} // namespace arm_compute