From 154bc1c3e6a0182e2130c7966af3944ee6ca20b3 Mon Sep 17 00:00:00 2001 From: giuros01 Date: Tue, 26 Mar 2019 17:44:40 +0000 Subject: COMPMID-1973: Implement FFTConvolutionLayer on NEON Change-Id: I2e667c0411bda0164a616ffe44473a78de6752c9 Signed-off-by: giuros01 Reviewed-on: https://review.mlplatform.org/c/1066 Reviewed-by: Gian Marco Iodice Tested-by: Arm Jenkins --- src/core/NEON/kernels/NEFFTDigitReverseKernel.cpp | 199 +++++++++++++++++---- src/core/NEON/kernels/NEFFTScaleKernel.cpp | 2 +- .../kernels/NEPixelWiseMultiplicationKernel.cpp | 137 ++++++++++++++ .../NEON/kernels/NEReductionOperationKernel.cpp | 101 ++++++++++- 4 files changed, 403 insertions(+), 36 deletions(-) (limited to 'src/core') diff --git a/src/core/NEON/kernels/NEFFTDigitReverseKernel.cpp b/src/core/NEON/kernels/NEFFTDigitReverseKernel.cpp index b2ffb01e99..cf77345da7 100644 --- a/src/core/NEON/kernels/NEFFTDigitReverseKernel.cpp +++ b/src/core/NEON/kernels/NEFFTDigitReverseKernel.cpp @@ -29,19 +29,24 @@ #include "arm_compute/core/Validate.h" #include "arm_compute/core/Window.h" +#include + namespace arm_compute { namespace { -Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, const ITensorInfo *idx, unsigned int axis) +Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, const ITensorInfo *idx, const FFTDigitReverseKernelInfo &config) { - ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 2, DataType::F32); + ARM_COMPUTE_RETURN_ERROR_ON(input->data_type() != DataType::F32); + ARM_COMPUTE_RETURN_ERROR_ON(input->num_channels() > 2); ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(idx, 1, DataType::U32); - ARM_COMPUTE_RETURN_ERROR_ON(axis > 1); + ARM_COMPUTE_RETURN_ERROR_ON(std::set({ 0, 1 }).count(config.axis) == 0); + ARM_COMPUTE_RETURN_ERROR_ON(input->tensor_shape()[config.axis] != idx->tensor_shape().x()); // Checks performed when output is configured if((output != nullptr) && (output->total_size() != 0)) { + ARM_COMPUTE_RETURN_ERROR_ON(output->num_channels() != 2); ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input, output); ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output); } @@ -49,75 +54,207 @@ Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, c return Status{}; } -std::pair validate_and_configure_window(ITensorInfo *input, ITensorInfo *output, ITensorInfo *idx, unsigned int axis) +std::pair validate_and_configure_window(ITensorInfo *input, ITensorInfo *output, ITensorInfo *idx, const FFTDigitReverseKernelInfo &config) { - ARM_COMPUTE_UNUSED(idx, axis); + ARM_COMPUTE_UNUSED(idx, config); - auto_init_if_empty(*output, *input); + auto_init_if_empty(*output, input->clone()->set_num_channels(2)); - Window win = calculate_max_window(*output, Steps()); - output->set_valid_region(ValidRegion(Coordinates(), output->tensor_shape())); + Window win = calculate_max_window(*input, Steps()); + input->set_valid_region(ValidRegion(Coordinates(), input->tensor_shape())); return std::make_pair(Status{}, win); } } // namespace NEFFTDigitReverseKernel::NEFFTDigitReverseKernel() - : _input(nullptr), _output(nullptr), _idx(nullptr), _axis(0) + : _func(nullptr), _input(nullptr), _output(nullptr), _idx(nullptr) { } -void NEFFTDigitReverseKernel::configure(const ITensor *input, ITensor *output, const ITensor *idx, unsigned int axis) +void NEFFTDigitReverseKernel::configure(const ITensor *input, ITensor *output, const ITensor *idx, const FFTDigitReverseKernelInfo &config) { ARM_COMPUTE_ERROR_ON_NULLPTR(input, output, idx); - ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input->info(), output->info(), idx->info(), axis)); + ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input->info(), output->info(), idx->info(), config)); _input = input; _output = output; _idx = idx; - _axis = axis; + + const size_t axis = config.axis; + const bool is_conj = config.conjugate; + const bool is_input_complex = (input->info()->num_channels() == 2); // Configure kernel window - auto win_config = validate_and_configure_window(input->info(), output->info(), idx->info(), axis); + auto win_config = validate_and_configure_window(input->info(), output->info(), idx->info(), config); ARM_COMPUTE_ERROR_THROW_ON(win_config.first); INEKernel::configure(win_config.second); + + if(axis == 0) + { + if(is_input_complex) + { + if(is_conj) + { + _func = &NEFFTDigitReverseKernel::digit_reverse_kernel_axis_0; + } + else + { + _func = &NEFFTDigitReverseKernel::digit_reverse_kernel_axis_0; + } + } + else + { + _func = &NEFFTDigitReverseKernel::digit_reverse_kernel_axis_0; + } + } + else if(axis == 1) + { + if(is_input_complex) + { + if(is_conj) + { + _func = &NEFFTDigitReverseKernel::digit_reverse_kernel_axis_1; + } + else + { + _func = &NEFFTDigitReverseKernel::digit_reverse_kernel_axis_1; + } + } + else + { + _func = &NEFFTDigitReverseKernel::digit_reverse_kernel_axis_1; + } + } + else + { + ARM_COMPUTE_ERROR("Not supported"); + } } -Status NEFFTDigitReverseKernel::validate(const ITensorInfo *input, const ITensorInfo *output, const ITensorInfo *idx, unsigned int axis) +Status NEFFTDigitReverseKernel::validate(const ITensorInfo *input, const ITensorInfo *output, const ITensorInfo *idx, const FFTDigitReverseKernelInfo &config) { - ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input, output, idx, axis)); - ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window(input->clone().get(), output->clone().get(), idx->clone().get(), axis).first); + ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input, output, idx, config)); + ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window(input->clone().get(), output->clone().get(), idx->clone().get(), config).first); return Status{}; } -void NEFFTDigitReverseKernel::run(const Window &window, const ThreadInfo &info) +template +void NEFFTDigitReverseKernel::digit_reverse_kernel_axis_0(const Window &window) { - ARM_COMPUTE_UNUSED(info); - Iterator out(_output, window); - const size_t element_size = _input->info()->element_size(); + const size_t N = _input->info()->dimension(0); + + // Copy the look-up buffer to a local array + std::vector buffer_idx(N); + std::copy_n(reinterpret_cast(_idx->buffer()), N, buffer_idx.data()); + + // Input/output iterators + Window slice = window; + slice.set(0, Window::DimX); + Iterator in(_input, slice); + Iterator out(_output, slice); + + // Row buffers + std::vector buffer_row_out(2 * N); + std::vector buffer_row_in(2 * N); + + execute_window_loop(slice, [&](const Coordinates &) + { + if(is_input_complex) + { + // Load + memcpy(buffer_row_in.data(), reinterpret_cast(in.ptr()), 2 * N * sizeof(float)); - // Pointers to the buffers - const size_t offset = _input->info()->offset_first_element_in_bytes(); - auto *idx_ptr = reinterpret_cast(_idx->buffer()); - uint8_t *input_ptr = offset + _input->buffer(); + // Shuffle + for(size_t x = 0; x < 2 * N; x += 2) + { + size_t idx = buffer_idx[x / 2]; + buffer_row_out[x] = buffer_row_in[2 * idx]; + buffer_row_out[x + 1] = (is_conj ? -buffer_row_in[2 * idx + 1] : buffer_row_in[2 * idx + 1]); + } + } + else + { + // Load + memcpy(buffer_row_in.data(), reinterpret_cast(in.ptr()), N * sizeof(float)); + + // Shuffle + for(size_t x = 0; x < N; ++x) + { + size_t idx = buffer_idx[x]; + buffer_row_out[2 * x] = buffer_row_in[idx]; + } + } + + // Copy back + memcpy(reinterpret_cast(out.ptr()), buffer_row_out.data(), 2 * N * sizeof(float)); + }, + in, out); +} + +template +void NEFFTDigitReverseKernel::digit_reverse_kernel_axis_1(const Window &window) +{ + const size_t Nx = _input->info()->dimension(0); + const size_t Ny = _input->info()->dimension(1); + + // Copy the look-up buffer to a local array + std::vector buffer_idx(Ny); + std::copy_n(reinterpret_cast(_idx->buffer()), Ny, buffer_idx.data()); + + // Output iterator + Window slice = window; + slice.set(0, Window::DimX); + Iterator out(_output, slice); + + // Row buffer + std::vector buffer_row(Nx); // Strides - const size_t stride_x = _input->info()->strides_in_bytes()[0]; - const size_t stride_y = _input->info()->strides_in_bytes()[1]; const size_t stride_z = _input->info()->strides_in_bytes()[2]; const size_t stride_w = _input->info()->strides_in_bytes()[3]; - execute_window_loop(window, [&](const Coordinates & id) + execute_window_loop(slice, [&](const Coordinates & id) { - unsigned int in_index_1d = idx_ptr[id[_axis]]; - auto reverse_id = id; - reverse_id.set(_axis, in_index_1d); + auto *out_ptr = reinterpret_cast(out.ptr()); + auto *in_ptr = reinterpret_cast(_input->buffer() + id.z() * stride_z + id[3] * stride_w); + const size_t y_shuffled = buffer_idx[id.y()]; + + if(is_input_complex) + { + // Shuffle the entire row into the output + memcpy(out_ptr, in_ptr + 2 * Nx * y_shuffled, 2 * Nx * sizeof(float)); - memcpy(out.ptr(), input_ptr + reverse_id.x() * stride_x + reverse_id.y() * stride_y + reverse_id.z() * stride_z + reverse_id[3] * stride_w, element_size); + // Conjugate if necessary + if(is_conj) + { + for(size_t x = 0; x < 2 * Nx; x += 2) + { + out_ptr[x + 1] = -out_ptr[x + 1]; + } + } + } + else + { + // Shuffle the entire row into the buffer + memcpy(buffer_row.data(), in_ptr + Nx * y_shuffled, Nx * sizeof(float)); + + // Copy the buffer to the output, with a zero imaginary part + for(size_t x = 0; x < 2 * Nx; x += 2) + { + out_ptr[x] = buffer_row[x / 2]; + } + } }, out); +} +void NEFFTDigitReverseKernel::run(const Window &window, const ThreadInfo &info) +{ ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this); ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window); + ARM_COMPUTE_UNUSED(info); + (this->*_func)(window); } + } // namespace arm_compute diff --git a/src/core/NEON/kernels/NEFFTScaleKernel.cpp b/src/core/NEON/kernels/NEFFTScaleKernel.cpp index 6568755e5d..56703bafcc 100644 --- a/src/core/NEON/kernels/NEFFTScaleKernel.cpp +++ b/src/core/NEON/kernels/NEFFTScaleKernel.cpp @@ -129,7 +129,7 @@ void NEFFTScaleKernel::run(const Window &window, const ThreadInfo &info) execute_window_loop(window, [&](const Coordinates &) { - scale_complex(reinterpret_cast(out.ptr()), reinterpret_cast(in.ptr()), _is_conj, _scale); + scale_complex(reinterpret_cast(in.ptr()), reinterpret_cast(out.ptr()), _is_conj, _scale); }, in, out); } diff --git a/src/core/NEON/kernels/NEPixelWiseMultiplicationKernel.cpp b/src/core/NEON/kernels/NEPixelWiseMultiplicationKernel.cpp index b565300906..fa16484cd3 100644 --- a/src/core/NEON/kernels/NEPixelWiseMultiplicationKernel.cpp +++ b/src/core/NEON/kernels/NEPixelWiseMultiplicationKernel.cpp @@ -30,6 +30,7 @@ #include "arm_compute/core/ITensor.h" #include "arm_compute/core/NEON/NEAsymm.h" #include "arm_compute/core/NEON/NEFixedPoint.h" +#include "arm_compute/core/NEON/wrapper/wrapper.h" #include "arm_compute/core/TensorInfo.h" #include "arm_compute/core/Types.h" #include "arm_compute/core/Validate.h" @@ -353,6 +354,35 @@ void mul_F32_F32_F32_n(const void *__restrict input1_ptr, const void *__restrict vst4q_f32(output, result); } +void c_mul_F32_F32_F32_n(const void *__restrict input1_ptr, const void *__restrict input2_ptr, void *__restrict output_ptr) +{ + const auto input1 = static_cast(input1_ptr); + const auto input2 = static_cast(input2_ptr); + const auto output = static_cast(output_ptr); + + const float32x4_t a = wrapper::vloadq(input1); + float32x4_t b = wrapper::vloadq(input2); + + using ExactTagType = typename wrapper::traits::neon_vector::tag_type; + + const float32x4_t mask = { -1.0f, 1.0f, -1.0f, 1.0f }; + const float32x2_t tmp00 = wrapper::vdup_n(wrapper::vgetlane(a, 0), ExactTagType{}); + const float32x2_t tmp01 = wrapper::vdup_n(wrapper::vgetlane(a, 1), ExactTagType{}); + const float32x2_t tmp10 = wrapper::vdup_n(wrapper::vgetlane(a, 2), ExactTagType{}); + const float32x2_t tmp11 = wrapper::vdup_n(wrapper::vgetlane(a, 3), ExactTagType{}); + + const float32x4_t tmp0 = wrapper::vcombine(tmp00, tmp10); + const float32x4_t tmp1 = wrapper::vcombine(tmp01, tmp11); + + float32x4_t res = wrapper::vmul(tmp0, b); + + b = wrapper::vrev64(b); + b = wrapper::vmul(b, mask); + + res = wrapper::vmla(res, tmp1, b); + wrapper::vstore(output, res); +} + void mul_F16_F16_F16_n(const void *__restrict input1_ptr, const void *__restrict input2_ptr, void *__restrict output_ptr, float scale) { #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC @@ -665,4 +695,111 @@ BorderSize NEPixelWiseMultiplicationKernel::border_size() const const unsigned int border = std::min(num_elems_processed_per_iteration - 1U, replicateSize); return BorderSize{ 0, border, 0, 0 }; } + +namespace +{ +constexpr unsigned int num_elems_processed_per_iteration_complex = 2; + +Status validate_arguments_complex(const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output) +{ + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input1, 2, DataType::F32); + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input2, 2, DataType::F32); + + const TensorShape &out_shape = TensorShape::broadcast_shape(input1->tensor_shape(), input2->tensor_shape()); + + ARM_COMPUTE_RETURN_ERROR_ON_MSG(out_shape.total_size() == 0, "Inputs are not broadcast compatible"); + + // Validate in case of configured output + if(output->total_size() > 0) + { + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 2, DataType::F32); + ARM_COMPUTE_RETURN_ERROR_ON_MSG(detail::have_different_dimensions(out_shape, output->tensor_shape(), 0), "Wrong shape for output"); + } + + return Status{}; +} + +std::pair validate_and_configure_window_complex(ITensorInfo *input1, ITensorInfo *input2, ITensorInfo *output) +{ + const std::pair broadcast_pair = ITensorInfo::broadcast_shape_and_valid_region(*input1, *input2); + const TensorShape &out_shape = broadcast_pair.first; + const ValidRegion &valid_region = broadcast_pair.second; + + // Auto initialize output if not initialized + const TensorInfo out_info(out_shape, input1->num_channels(), input1->data_type()); + auto_init_if_empty(*output, out_info); + + Window win = calculate_max_window(valid_region, Steps(num_elems_processed_per_iteration_complex)); + Window win_input1 = win.broadcast_if_dimension_le_one(*input1); + Window win_input2 = win.broadcast_if_dimension_le_one(*input2); + + AccessWindowHorizontal input1_access(input1, 0, num_elems_processed_per_iteration_complex); + AccessWindowHorizontal input2_access(input2, 0, num_elems_processed_per_iteration_complex); + AccessWindowHorizontal output_access(output, 0, num_elems_processed_per_iteration_complex); + + bool window_changed = update_window_and_padding(win_input1, input1_access) + || update_window_and_padding(win_input2, input2_access) + || update_window_and_padding(win, output_access); + + output_access.set_valid_region(win, valid_region); + + Status err = (window_changed) ? ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Insufficient Padding!") : Status{}; + return std::make_pair(err, win); +} +} // namespace + +NEComplexPixelWiseMultiplicationKernel::NEComplexPixelWiseMultiplicationKernel() + : _input1(nullptr), _input2(nullptr), _output(nullptr) +{ +} + +void NEComplexPixelWiseMultiplicationKernel::configure(const ITensor *input1, const ITensor *input2, ITensor *output) +{ + ARM_COMPUTE_ERROR_ON_NULLPTR(input1, input2, output); + ARM_COMPUTE_ERROR_THROW_ON(validate_arguments_complex(input1->info(), input2->info(), output->info())); + + // Configure kernel window + auto win_config = validate_and_configure_window_complex(input1->info(), input2->info(), output->info()); + ARM_COMPUTE_ERROR_THROW_ON(win_config.first); + + _input1 = input1; + _input2 = input2; + _output = output; + + // Create kernel + INEKernel::configure(win_config.second); +} + +Status NEComplexPixelWiseMultiplicationKernel::validate(const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output) +{ + ARM_COMPUTE_ERROR_ON_NULLPTR(input1, input2, output); + ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments_complex(input1, input2, output)); + ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window_complex(input1->clone().get(), input2->clone().get(), output->clone().get()).first); + + return Status{}; +} + +void NEComplexPixelWiseMultiplicationKernel::run(const Window &window, const ThreadInfo &info) +{ + ARM_COMPUTE_UNUSED(info); + ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this); + ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window); + + Iterator input1(_input1, window.broadcast_if_dimension_le_one(_input1->info()->tensor_shape())); + Iterator input2(_input2, window.broadcast_if_dimension_le_one(_input2->info()->tensor_shape())); + Iterator output(_output, window); + + execute_window_loop(window, [&](const Coordinates &) + { + c_mul_F32_F32_F32_n(input1.ptr(), input2.ptr(), output.ptr()); + }, + input1, input2, output); +} + +BorderSize NEComplexPixelWiseMultiplicationKernel::border_size() const +{ + const unsigned int replicateSize = _output->info()->dimension(0) - std::min(_input1->info()->dimension(0), _input2->info()->dimension(0)); + const unsigned int border = std::min(num_elems_processed_per_iteration_complex - 1U, replicateSize); + return { 0, border, 0, 0 }; +} } // namespace arm_compute diff --git a/src/core/NEON/kernels/NEReductionOperationKernel.cpp b/src/core/NEON/kernels/NEReductionOperationKernel.cpp index e6fdba2696..aa20d1f40d 100644 --- a/src/core/NEON/kernels/NEReductionOperationKernel.cpp +++ b/src/core/NEON/kernels/NEReductionOperationKernel.cpp @@ -602,7 +602,7 @@ struct RedOpYZW { ARM_COMPUTE_UNUSED(out_slice); - execute_window_loop(in_slice, [&](const Coordinates & id) + execute_window_loop(in_slice, [&](const Coordinates &) { neon_vector vec_res_value = { 0 }; if(op == ReductionOperation::ARG_IDX_MAX || op == ReductionOperation::ARG_IDX_MIN) @@ -688,13 +688,70 @@ struct RedOpYZW } }; +template +struct RedOpYZW_complex +{ + /** NEON vector tag type. */ + using ExactTagType = typename wrapper::traits::neon_vector::tag_type; + using neon_vector = typename wrapper::traits::neon_vector::type; + + inline void operator()(Iterator &input, Iterator &output, Window &in_slice, Window &out_slice, const TensorInfo &in_info, int, const ReductionOperation) + { + ARM_COMPUTE_UNUSED(out_slice); + ARM_COMPUTE_ERROR_ON(axis != 2); + + const size_t stride_z = in_info.strides_in_bytes()[axis]; + + execute_window_loop(in_slice, [&](const Coordinates &) + { + neon_vector vec_res_value_0 = { 0 }; + neon_vector vec_res_value_1 = { 0 }; + + vec_res_value_0 = wrapper::vdup_n(static_cast(0.f), ExactTagType{}); + vec_res_value_1 = wrapper::vdup_n(static_cast(0.f), ExactTagType{}); + + for(unsigned int dim = 0; dim < in_info.dimension(axis); ++dim) + { + T *in_ptr_0; + T *in_ptr_1; + switch(axis) + { + case 2: + in_ptr_0 = reinterpret_cast(input.ptr() + stride_z * dim); + in_ptr_1 = reinterpret_cast(input.ptr() + 16 + stride_z * dim); + break; + default: + ARM_COMPUTE_ERROR("Not supported"); + } + const auto vec_elements_0 = wrapper::vloadq(in_ptr_0); + const auto vec_elements_1 = wrapper::vloadq(in_ptr_1); + + switch(op) + { + case ReductionOperation::SUM: + vec_res_value_0 = wrapper::vadd(vec_elements_0, vec_res_value_0); + vec_res_value_1 = wrapper::vadd(vec_elements_1, vec_res_value_1); + break; + default: + ARM_COMPUTE_ERROR("Not supported"); + } + } + + wrapper::vstore(reinterpret_cast(output.ptr()), vec_res_value_0); + wrapper::vstore(reinterpret_cast(output.ptr() + 16), vec_res_value_1); + + }, + input, output); + } +}; + struct RedOpYZW_qasymm8 { inline void operator()(Iterator &input, Iterator &output, Window &in_slice, Window &out_slice, const TensorInfo &in_info, int axis, const ReductionOperation op) { ARM_COMPUTE_UNUSED(out_slice); - execute_window_loop(in_slice, [&](const Coordinates & id) + execute_window_loop(in_slice, [&](const Coordinates &) { uint32x4x4_t vec_res_idx{ { 0 } }; auto vec_res_value1 = vdupq_n_u32(0); @@ -848,6 +905,31 @@ struct RedOpYZW_qasymm8 void reduce_op(const Window &window, const ITensor *input, ITensor *output, unsigned int axis, const ReductionOperation op) { + const bool is_complex = (input->info()->num_channels() == 2); + + if(is_complex) + { + switch(axis) + { + case 2: + switch(input->info()->data_type()) + { + case DataType::F32: + switch(op) + { + case ReductionOperation::SUM: + return Reducer>::reduceZ(window, input, output, RedOpYZW_complex(), op); + default: + ARM_COMPUTE_ERROR("Not supported"); + } + default: + ARM_COMPUTE_ERROR("Not supported"); + } + default: + ARM_COMPUTE_ERROR("Not supported"); + } + } + switch(axis) { case 0: @@ -917,7 +999,17 @@ Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, u ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, output); ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(input); - ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8, DataType::F16, DataType::F32); + + if(input->num_channels() == 1) + { + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8, DataType::F16, DataType::F32); + } + else + { + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 2, DataType::F32); + ARM_COMPUTE_RETURN_ERROR_ON(op != ReductionOperation::SUM); + ARM_COMPUTE_RETURN_ERROR_ON(axis != 2); + } ARM_COMPUTE_RETURN_ERROR_ON_MSG(axis >= TensorShape::num_max_dimensions, "Reduction axis greater than max number of dimensions"); ARM_COMPUTE_RETURN_ERROR_ON_MSG(axis > 3, "Unsupported reduction axis"); @@ -929,6 +1021,7 @@ Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, u { ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output); ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_QUANTIZATION_INFO(input, output); + ARM_COMPUTE_RETURN_ERROR_ON(input->num_channels() != output->num_channels()); } else { @@ -951,7 +1044,7 @@ std::tuple validate_and_configure_window(ITensorInfo *input, ITe // Output auto initialization if not yet initialized const bool is_arg_min_max = (op == ReductionOperation::ARG_IDX_MIN || op == ReductionOperation::ARG_IDX_MAX); DataType output_data_type = is_arg_min_max ? DataType::U32 : input->data_type(); - auto_init_if_empty(*output, output_shape, 1, output_data_type, input->quantization_info()); + auto_init_if_empty(*output, input->clone()->set_tensor_shape(output_shape).set_data_type(output_data_type).reset_padding().set_is_resizable(true)); unsigned int num_elems_processed_per_iteration = 16 / data_size_from_type(input->data_type()); -- cgit v1.2.1