aboutsummaryrefslogtreecommitdiff
path: root/src/core/CPP/kernels/CPPUpsampleKernel.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/CPP/kernels/CPPUpsampleKernel.cpp')
-rw-r--r--src/core/CPP/kernels/CPPUpsampleKernel.cpp32
1 files changed, 22 insertions, 10 deletions
diff --git a/src/core/CPP/kernels/CPPUpsampleKernel.cpp b/src/core/CPP/kernels/CPPUpsampleKernel.cpp
index c190543216..8348b4335e 100644
--- a/src/core/CPP/kernels/CPPUpsampleKernel.cpp
+++ b/src/core/CPP/kernels/CPPUpsampleKernel.cpp
@@ -71,15 +71,19 @@ void CPPUpsampleKernel::run(const Window &window, const ThreadInfo &info)
ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(ICPPKernel::window(), window);
+ const DataLayout data_layout = _input->info()->data_layout();
+ const size_t idx_w = get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH);
+ const size_t idx_h = get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT);
+
// Initialize _scaled_output buffer
- const int width_scaled = _output->info()->dimension(0);
- const int height_scaled = _output->info()->dimension(1);
- const int stride_x = _info.stride().first;
- const int stride_y = _info.stride().second;
- const int start_x = _info.pad_left();
- const int start_y = _info.pad_top();
- const int end_x = width_scaled - _info.pad_right();
- const int end_y = height_scaled - _info.pad_bottom();
+ const int width_scaled = _output->info()->dimension(idx_w);
+ const int height_scaled = _output->info()->dimension(idx_h);
+ const int stride_width = _info.stride().first;
+ const int stride_height = _info.stride().second;
+ const int start_width = _info.pad_left();
+ const int start_height = _info.pad_top();
+ const int end_width = width_scaled - _info.pad_right();
+ const int end_height = height_scaled - _info.pad_bottom();
const size_t element_size = _input->info()->element_size();
// The fill value is normally 0, but for quantized types '0' corresponds to the offset
@@ -103,8 +107,16 @@ void CPPUpsampleKernel::run(const Window &window, const ThreadInfo &info)
// Create window
Window window_out(window);
- window_out.set(Window::DimX, Window::Dimension(start_x, end_x, stride_x));
- window_out.set(Window::DimY, Window::Dimension(start_y, end_y, stride_y));
+ if(data_layout == DataLayout::NCHW)
+ {
+ window_out.set(Window::DimX, Window::Dimension(start_width, end_width, stride_width));
+ window_out.set(Window::DimY, Window::Dimension(start_height, end_height, stride_height));
+ }
+ else
+ {
+ window_out.set(Window::DimY, Window::Dimension(start_width, end_width, stride_width));
+ window_out.set(Window::DimZ, Window::Dimension(start_height, end_height, stride_height));
+ }
// Create iterators
Iterator in(_input, window);