diff options
Diffstat (limited to 'src/core/CL/kernels/CLPoolingLayerKernel.cpp')
-rw-r--r-- | src/core/CL/kernels/CLPoolingLayerKernel.cpp | 28 |
1 files changed, 14 insertions, 14 deletions
diff --git a/src/core/CL/kernels/CLPoolingLayerKernel.cpp b/src/core/CL/kernels/CLPoolingLayerKernel.cpp index 8e69157fdb..e3f1114f21 100644 --- a/src/core/CL/kernels/CLPoolingLayerKernel.cpp +++ b/src/core/CL/kernels/CLPoolingLayerKernel.cpp @@ -172,7 +172,7 @@ std::tuple<Status, Window, CLPoolingConfig> validate_and_configure_window(ITenso } // namespace CLPoolingLayerKernel::CLPoolingLayerKernel() - : _input(nullptr), _output(nullptr), _pool_info(), _border_size(0), _num_elems_processed_per_iteration(1) + : _input(nullptr), _output(nullptr), _pool_info(), _data_layout(DataLayout::UNKNOWN), _border_size(0), _num_elems_processed_per_iteration(1) { } @@ -185,13 +185,18 @@ void CLPoolingLayerKernel::configure(const ICLTensor *input, ICLTensor *output, { ARM_COMPUTE_ERROR_ON_NULLPTR(input, output); + // Set instance variables + _input = input; + _output = output; + _pool_info = pool_info; + _data_layout = input->info()->data_layout(); + int pool_stride_x = 0; int pool_stride_y = 0; const PoolingType pool_type = pool_info.pool_type(); - DataLayout data_layout = input->info()->data_layout(); - const int idx_width = get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH); - const int idx_height = get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT); - const int idx_channel = get_data_layout_dimension_index(data_layout, DataLayoutDimension::CHANNEL); + const int idx_width = get_data_layout_dimension_index(_data_layout, DataLayoutDimension::WIDTH); + const int idx_height = get_data_layout_dimension_index(_data_layout, DataLayoutDimension::HEIGHT); + const int idx_channel = get_data_layout_dimension_index(_data_layout, DataLayoutDimension::CHANNEL); const int pool_size_x = pool_info.is_global_pooling() ? input->info()->dimension(idx_width) : pool_info.pool_size().width; const int pool_size_y = pool_info.is_global_pooling() ? input->info()->dimension(idx_height) : pool_info.pool_size().height; const PadStrideInfo pad_stride_info = pool_info.pad_stride_info(); @@ -218,11 +223,6 @@ void CLPoolingLayerKernel::configure(const ICLTensor *input, ICLTensor *output, auto_init(input->info(), output->info(), pool_info); ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input->info(), output->info(), pool_info)); - // Set instance variables - _input = input; - _output = output; - _pool_info = pool_info; - const DataType data_type = input->info()->data_type(); build_opts.add_option("-DDATA_TYPE=" + get_cl_type_from_data_type(data_type)); @@ -243,7 +243,7 @@ void CLPoolingLayerKernel::configure(const ICLTensor *input, ICLTensor *output, build_opts.add_option_if(use_wider_accumulator, "-DFP_MIXED_PRECISION"); // Create kernel - switch(data_layout) + switch(_data_layout) { case DataLayout::NCHW: { @@ -292,7 +292,7 @@ void CLPoolingLayerKernel::configure(const ICLTensor *input, ICLTensor *output, ARM_COMPUTE_ERROR_THROW_ON(std::get<0>(win_config)); ICLKernel::configure_internal(std::get<1>(win_config)); - if(data_layout == DataLayout::NCHW) + if(_data_layout == DataLayout::NCHW) { CLPoolingConfig pooling_config = std::get<2>(win_config); _num_elems_processed_per_iteration = pooling_config.first; @@ -308,7 +308,7 @@ void CLPoolingLayerKernel::configure(const ICLTensor *input, ICLTensor *output, _config_id = "pooling_layer_"; _config_id += lower_string(string_from_data_type(data_type)); _config_id += "_"; - _config_id += lower_string(string_from_data_layout(data_layout)); + _config_id += lower_string(string_from_data_layout(_data_layout)); _config_id += "_"; _config_id += support::cpp11::to_string(output->info()->dimension(idx_width)); _config_id += "_"; @@ -339,7 +339,7 @@ void CLPoolingLayerKernel::run(const Window &window, cl::CommandQueue &queue) // Collapse window Window window_collapsed = window.collapse_if_possible(ICLKernel::window(), Window::DimZ); - switch(_input->info()->data_layout()) + switch(_data_layout) { case DataLayout::NCHW: { |