aboutsummaryrefslogtreecommitdiff
path: root/src/core/CPP/kernels
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/CPP/kernels')
-rw-r--r--src/core/CPP/kernels/CPPFlipWeightsKernel.cpp27
1 files changed, 17 insertions, 10 deletions
diff --git a/src/core/CPP/kernels/CPPFlipWeightsKernel.cpp b/src/core/CPP/kernels/CPPFlipWeightsKernel.cpp
index 741218e4f7..2d4c0ce5c8 100644
--- a/src/core/CPP/kernels/CPPFlipWeightsKernel.cpp
+++ b/src/core/CPP/kernels/CPPFlipWeightsKernel.cpp
@@ -42,25 +42,36 @@ CPPFlipWeightsKernel::CPPFlipWeightsKernel()
}
template <typename T>
-void CPPFlipWeightsKernel::flip_weights(const Window &window_input, const Window &window)
+void CPPFlipWeightsKernel::flip_weights(const Window &window_input)
{
// Create iterators
Iterator in(_input, window_input);
- Iterator out(_output, window);
+ const DataLayout data_layout = _input->info()->data_layout();
+ const size_t idx_w = get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH);
+ const size_t idx_h = get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT);
- const int kernel_size = _input->info()->dimension(0);
+ const int kernel_width = _input->info()->dimension(idx_w);
+ const int kernel_height = _input->info()->dimension(idx_h);
execute_window_loop(window_input, [&](const Coordinates & id)
{
- *((reinterpret_cast<T *>(out.ptr()) + kernel_size * (kernel_size - id.y() - 1) + (kernel_size - id.x() - 1))) = *(reinterpret_cast<const T *>(in.ptr()));
+ const unsigned int x = kernel_width - id[idx_w] - 1;
+ const unsigned int y = kernel_height - id[idx_h] - 1;
+ Coordinates output_coord(id);
+ output_coord.set(idx_w, x);
+ output_coord.set(idx_h, y);
+ *(reinterpret_cast<T *>(_output->ptr_to_element(output_coord))) = *(reinterpret_cast<const T *>(in.ptr()));
},
- in, out);
+ in);
}
void CPPFlipWeightsKernel::configure(const ITensor *input, ITensor *output)
{
ARM_COMPUTE_ERROR_ON_NULLPTR(input, output);
+ ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8, DataType::F16, DataType::F32);
+ ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_LAYOUT(input, output);
+ ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
_input = input;
_output = output;
@@ -98,9 +109,5 @@ void CPPFlipWeightsKernel::run(const Window &window, const ThreadInfo &info)
ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(ICPPKernel::window(), window);
ARM_COMPUTE_ERROR_ON(_func == nullptr);
- Window out_window{ window };
- out_window.set(Window::DimX, Window::Dimension(0, 0, 0));
- out_window.set(Window::DimY, Window::Dimension(0, 0, 0));
-
- (this->*_func)(window, out_window);
+ (this->*_func)(window);
}