diff options
Diffstat (limited to 'src/core/NEON/kernels/NEChannelShuffleLayerKernel.cpp')
-rw-r--r-- | src/core/NEON/kernels/NEChannelShuffleLayerKernel.cpp | 101 |
1 files changed, 55 insertions, 46 deletions
diff --git a/src/core/NEON/kernels/NEChannelShuffleLayerKernel.cpp b/src/core/NEON/kernels/NEChannelShuffleLayerKernel.cpp index 008ad7c9f4..3b53b7055f 100644 --- a/src/core/NEON/kernels/NEChannelShuffleLayerKernel.cpp +++ b/src/core/NEON/kernels/NEChannelShuffleLayerKernel.cpp @@ -30,6 +30,7 @@ #include "arm_compute/core/Utils.h" #include "arm_compute/core/Validate.h" #include "arm_compute/core/Window.h" + #include "src/core/CPP/Validate.h" #include "src/core/helpers/AutoConfiguration.h" #include "src/core/helpers/WindowHelpers.h" @@ -44,15 +45,19 @@ Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, u ARM_COMPUTE_RETURN_ERROR_ON(input->data_type() == DataType::UNKNOWN); ARM_COMPUTE_RETURN_ERROR_ON_DATA_LAYOUT_NOT_IN(input, DataLayout::NCHW, DataLayout::NHWC); - const unsigned int channels = input->dimension(get_data_layout_dimension_index(input->data_layout(), DataLayoutDimension::CHANNEL)); + const unsigned int channels = + input->dimension(get_data_layout_dimension_index(input->data_layout(), DataLayoutDimension::CHANNEL)); ARM_COMPUTE_RETURN_ERROR_ON_MSG(num_groups < 2, "Channel shuffling with less than 2 groups would be inefficient"); - ARM_COMPUTE_RETURN_ERROR_ON_MSG(num_groups == channels, "Channel shuffling with same number of groups as number of channels would be inefficient"); + ARM_COMPUTE_RETURN_ERROR_ON_MSG( + num_groups == channels, + "Channel shuffling with same number of groups as number of channels would be inefficient"); ARM_COMPUTE_RETURN_ERROR_ON(num_groups > channels); // There cannot be more groups than channels - ARM_COMPUTE_RETURN_ERROR_ON_MSG((channels % num_groups) != 0, "The number of channels must be a multiple of the number of groups"); + ARM_COMPUTE_RETURN_ERROR_ON_MSG((channels % num_groups) != 0, + "The number of channels must be a multiple of the number of groups"); // Checks performed when output is configured - if(output->total_size() != 0) + if (output->total_size() != 0) { ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input, output); ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output); @@ -68,24 +73,26 @@ void channel_shuffle_nhwc(const ITensor *input, ITensor *output, unsigned int nu const size_t element_size = input->info()->element_size(); const unsigned int K = input->info()->dimension(channel_idx) / num_groups; - const float rK = 1.f / K; + const double rK = 1.0 / K; Iterator in(input, window); - execute_window_loop(window, [&](const Coordinates & id) - { - // Shuffle channel - const unsigned int curr_channel = id.x(); - const unsigned int group_id = curr_channel * rK; - const unsigned int r = group_id * K; - const unsigned int channel_id = curr_channel - r; - - // Calculate output coordinates - Coordinates out_coords = id; - out_coords.set(Window::DimX, channel_id * num_groups + group_id); - std::copy_n(in.ptr(), element_size, output->ptr_to_element(out_coords)); - }, - in); + execute_window_loop( + window, + [&](const Coordinates &id) + { + // Shuffle channel + const unsigned int curr_channel = id.x(); + const unsigned int group_id = curr_channel * rK; + const unsigned int r = group_id * K; + const unsigned int channel_id = curr_channel - r; + + // Calculate output coordinates + Coordinates out_coords = id; + out_coords.set(Window::DimX, channel_id * num_groups + group_id); + std::copy_n(in.ptr(), element_size, output->ptr_to_element(out_coords)); + }, + in); } void channel_shuffle_nchw(const ITensor *input, ITensor *output, unsigned int num_groups, const Window &window) { @@ -103,38 +110,39 @@ void channel_shuffle_nchw(const ITensor *input, ITensor *output, unsigned int nu const size_t row_size = input->info()->dimension(width_idx) * input->info()->element_size(); const unsigned int K = input->info()->dimension(channel_idx) / num_groups; - const float rK = 1.f / K; + const double rK = 1.0 / K; Iterator in(input, win); - execute_window_loop(win, [&](const Coordinates & id) - { - // Shuffle channel - const unsigned int curr_channel = id.z(); - const unsigned int group_id = curr_channel * rK; - const unsigned int r = group_id * K; - const unsigned int channel_id = curr_channel - r; - - // Calculate output coordinates - Coordinates out_coords = id; - out_coords.set(Window::DimZ, channel_id * num_groups + group_id); - const uint8_t *input_ptr = in.ptr(); - uint8_t *output_ptr = output->ptr_to_element(out_coords); - - // Copy plane - for(unsigned int y = 0; y < height; ++y) + execute_window_loop( + win, + [&](const Coordinates &id) { - std::copy_n(input_ptr, row_size, output_ptr); - input_ptr += input_stride_y; - output_ptr += output_stride_y; - } - }, - in); + // Shuffle channel + const unsigned int curr_channel = id.z(); + const unsigned int group_id = curr_channel * rK; + const unsigned int r = group_id * K; + const unsigned int channel_id = curr_channel - r; + + // Calculate output coordinates + Coordinates out_coords = id; + out_coords.set(Window::DimZ, channel_id * num_groups + group_id); + const uint8_t *input_ptr = in.ptr(); + uint8_t *output_ptr = output->ptr_to_element(out_coords); + + // Copy plane + for (unsigned int y = 0; y < height; ++y) + { + std::copy_n(input_ptr, row_size, output_ptr); + input_ptr += input_stride_y; + output_ptr += output_stride_y; + } + }, + in); } } // namespace -NEChannelShuffleLayerKernel::NEChannelShuffleLayerKernel() - : _input(nullptr), _output(nullptr), _num_groups() +NEChannelShuffleLayerKernel::NEChannelShuffleLayerKernel() : _input(nullptr), _output(nullptr), _num_groups() { } @@ -158,7 +166,8 @@ void NEChannelShuffleLayerKernel::configure(const ITensor *input, ITensor *outpu INEKernel::configure(win); } -Status NEChannelShuffleLayerKernel::validate(const ITensorInfo *input, const ITensorInfo *output, unsigned int num_groups) +Status +NEChannelShuffleLayerKernel::validate(const ITensorInfo *input, const ITensorInfo *output, unsigned int num_groups) { ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input, output, num_groups)); return Status{}; @@ -170,7 +179,7 @@ void NEChannelShuffleLayerKernel::run(const Window &window, const ThreadInfo &in ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this); ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(IKernel::window(), window); - switch(_input->info()->data_layout()) + switch (_input->info()->data_layout()) { case DataLayout::NHWC: channel_shuffle_nhwc(_input, _output, _num_groups, window); |