aboutsummaryrefslogtreecommitdiff
path: root/src/core/CL/kernels/CLPoolingLayerKernel.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/CL/kernels/CLPoolingLayerKernel.cpp')
-rw-r--r--src/core/CL/kernels/CLPoolingLayerKernel.cpp28
1 files changed, 14 insertions, 14 deletions
diff --git a/src/core/CL/kernels/CLPoolingLayerKernel.cpp b/src/core/CL/kernels/CLPoolingLayerKernel.cpp
index 8eaf5bf76f..032d451aad 100644
--- a/src/core/CL/kernels/CLPoolingLayerKernel.cpp
+++ b/src/core/CL/kernels/CLPoolingLayerKernel.cpp
@@ -172,7 +172,7 @@ std::tuple<Status, Window, CLPoolingConfig> validate_and_configure_window(ITenso
} // namespace
CLPoolingLayerKernel::CLPoolingLayerKernel()
- : _input(nullptr), _output(nullptr), _pool_info(), _border_size(0), _num_elems_processed_per_iteration(1)
+ : _input(nullptr), _output(nullptr), _pool_info(), _data_layout(DataLayout::UNKNOWN), _border_size(0), _num_elems_processed_per_iteration(1)
{
}
@@ -185,13 +185,18 @@ void CLPoolingLayerKernel::configure(const ICLTensor *input, ICLTensor *output,
{
ARM_COMPUTE_ERROR_ON_NULLPTR(input, output);
+ // Set instance variables
+ _input = input;
+ _output = output;
+ _pool_info = pool_info;
+ _data_layout = input->info()->data_layout();
+
int pool_stride_x = 0;
int pool_stride_y = 0;
const PoolingType pool_type = pool_info.pool_type();
- DataLayout data_layout = input->info()->data_layout();
- const int idx_width = get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH);
- const int idx_height = get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT);
- const int idx_channel = get_data_layout_dimension_index(data_layout, DataLayoutDimension::CHANNEL);
+ const int idx_width = get_data_layout_dimension_index(_data_layout, DataLayoutDimension::WIDTH);
+ const int idx_height = get_data_layout_dimension_index(_data_layout, DataLayoutDimension::HEIGHT);
+ const int idx_channel = get_data_layout_dimension_index(_data_layout, DataLayoutDimension::CHANNEL);
const int pool_size_x = pool_info.is_global_pooling() ? input->info()->dimension(idx_width) : pool_info.pool_size().width;
const int pool_size_y = pool_info.is_global_pooling() ? input->info()->dimension(idx_height) : pool_info.pool_size().height;
const PadStrideInfo pad_stride_info = pool_info.pad_stride_info();
@@ -218,11 +223,6 @@ void CLPoolingLayerKernel::configure(const ICLTensor *input, ICLTensor *output,
auto_init(input->info(), output->info(), pool_info);
ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input->info(), output->info(), pool_info));
- // Set instance variables
- _input = input;
- _output = output;
- _pool_info = pool_info;
-
const DataType data_type = input->info()->data_type();
build_opts.add_option("-DDATA_TYPE=" + get_cl_type_from_data_type(data_type));
@@ -237,7 +237,7 @@ void CLPoolingLayerKernel::configure(const ICLTensor *input, ICLTensor *output,
build_opts.add_option_if(data_type == DataType::F16, "-DFP16");
// Create kernel
- switch(data_layout)
+ switch(_data_layout)
{
case DataLayout::NCHW:
{
@@ -286,7 +286,7 @@ void CLPoolingLayerKernel::configure(const ICLTensor *input, ICLTensor *output,
ARM_COMPUTE_ERROR_THROW_ON(std::get<0>(win_config));
ICLKernel::configure_internal(std::get<1>(win_config));
- if(data_layout == DataLayout::NCHW)
+ if(_data_layout == DataLayout::NCHW)
{
CLPoolingConfig pooling_config = std::get<2>(win_config);
_num_elems_processed_per_iteration = pooling_config.first;
@@ -302,7 +302,7 @@ void CLPoolingLayerKernel::configure(const ICLTensor *input, ICLTensor *output,
_config_id = "pooling_layer_";
_config_id += lower_string(string_from_data_type(data_type));
_config_id += "_";
- _config_id += lower_string(string_from_data_layout(data_layout));
+ _config_id += lower_string(string_from_data_layout(_data_layout));
_config_id += "_";
_config_id += support::cpp11::to_string(output->info()->dimension(idx_width));
_config_id += "_";
@@ -333,7 +333,7 @@ void CLPoolingLayerKernel::run(const Window &window, cl::CommandQueue &queue)
// Collapse window
Window window_collapsed = window.collapse_if_possible(ICLKernel::window(), Window::DimZ);
- switch(_input->info()->data_layout())
+ switch(_data_layout)
{
case DataLayout::NCHW:
{