aboutsummaryrefslogtreecommitdiff
path: root/src/core/CL/kernels/CLUpsampleLayerKernel.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/CL/kernels/CLUpsampleLayerKernel.cpp')
-rw-r--r--src/core/CL/kernels/CLUpsampleLayerKernel.cpp12
1 files changed, 5 insertions, 7 deletions
diff --git a/src/core/CL/kernels/CLUpsampleLayerKernel.cpp b/src/core/CL/kernels/CLUpsampleLayerKernel.cpp
index ce5ed86332..2ccd540788 100644
--- a/src/core/CL/kernels/CLUpsampleLayerKernel.cpp
+++ b/src/core/CL/kernels/CLUpsampleLayerKernel.cpp
@@ -37,7 +37,7 @@
namespace arm_compute
{
CLUpsampleLayerKernel::CLUpsampleLayerKernel()
- : _input(nullptr), _output(nullptr), _info(), _num_elems_processed_per_iteration_input_x()
+ : _input(nullptr), _output(nullptr), _info(), _data_layout(DataLayout::UNKNOWN), _num_elems_processed_per_iteration_input_x()
{
}
@@ -71,13 +71,12 @@ void CLUpsampleLayerKernel::configure(const ICLTensor *input, ICLTensor *output,
_input = input;
_output = output;
_info = info;
+ _data_layout = input->info()->data_layout();
_num_elems_processed_per_iteration_input_x = 1;
- const DataLayout data_layout = input->info()->data_layout();
-
TensorShape output_shape = misc::shape_calculator::compute_upsample_shape(*input->info(), info);
auto_init_if_empty(*output->info(), output_shape, 1, input->info()->data_type());
- output->info()->set_data_layout(data_layout);
+ output->info()->set_data_layout(_data_layout);
unsigned int num_elems_processed_per_iteration_x = 16;
const int output_width_x = output->info()->dimension(0);
@@ -88,7 +87,7 @@ void CLUpsampleLayerKernel::configure(const ICLTensor *input, ICLTensor *output,
Window win{};
- switch(data_layout)
+ switch(_data_layout)
{
case DataLayout::NCHW:
{
@@ -140,8 +139,7 @@ void CLUpsampleLayerKernel::run(const Window &window, cl::CommandQueue &queue)
Window slice_out = collapsed_window.first_slice_window_3D();
Window slice_in = collapsed_window.first_slice_window_3D();
- DataLayout data_layout = _input->info()->data_layout();
- switch(data_layout)
+ switch(_data_layout)
{
case DataLayout::NCHW:
slice_in.set(Window::DimX, Window::Dimension(0, _input->info()->dimension(0), _num_elems_processed_per_iteration_input_x));