diff options
Diffstat (limited to 'src/core/CL/kernels/CLSpaceToBatchLayerKernel.cpp')
-rw-r--r-- | src/core/CL/kernels/CLSpaceToBatchLayerKernel.cpp | 41 |
1 files changed, 28 insertions, 13 deletions
diff --git a/src/core/CL/kernels/CLSpaceToBatchLayerKernel.cpp b/src/core/CL/kernels/CLSpaceToBatchLayerKernel.cpp index cda6e96806..9e4010e6c6 100644 --- a/src/core/CL/kernels/CLSpaceToBatchLayerKernel.cpp +++ b/src/core/CL/kernels/CLSpaceToBatchLayerKernel.cpp @@ -58,11 +58,16 @@ Status validate_arguments_static(const ITensorInfo *input, const int block_shape // Validate output if initialized if(output->total_size() != 0) { - ARM_COMPUTE_RETURN_ERROR_ON(output->tensor_shape()[0] < padding_left.x() + padding_right.y()); - ARM_COMPUTE_RETURN_ERROR_ON(input->tensor_shape()[0] / block_shape_x != (output->tensor_shape()[0] - padding_left.x() - padding_right.y())); - ARM_COMPUTE_RETURN_ERROR_ON(input->tensor_shape()[1] / block_shape_y != (output->tensor_shape()[1] - padding_left.x() - padding_right.y())); - ARM_COMPUTE_RETURN_ERROR_ON(input->tensor_shape()[2] != output->tensor_shape()[2]); - ARM_COMPUTE_RETURN_ERROR_ON(output->tensor_shape()[3] % (block_shape_x * block_shape_y) != 0); + const DataLayout data_layout = input->data_layout(); + const int idx_width = get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH); + const int idx_height = get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT); + const int idx_channel = get_data_layout_dimension_index(data_layout, DataLayoutDimension::CHANNEL); + const int idx_batch = get_data_layout_dimension_index(data_layout, DataLayoutDimension::BATCHES); + ARM_COMPUTE_RETURN_ERROR_ON(output->tensor_shape()[idx_width] < padding_left.x() + padding_right.y()); + ARM_COMPUTE_RETURN_ERROR_ON(input->tensor_shape()[idx_width] / block_shape_x != (output->tensor_shape()[idx_width] - padding_left.x() - padding_right.y())); + ARM_COMPUTE_RETURN_ERROR_ON(input->tensor_shape()[idx_height] / block_shape_y != (output->tensor_shape()[idx_height] - padding_left.x() - padding_right.y())); + ARM_COMPUTE_RETURN_ERROR_ON(input->tensor_shape()[idx_channel] != output->tensor_shape()[idx_channel]); + ARM_COMPUTE_RETURN_ERROR_ON(output->tensor_shape()[idx_batch] % (block_shape_x * block_shape_y) != 0); ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output); } @@ -85,13 +90,18 @@ void CLSpaceToBatchLayerKernel::configure(const ICLTensor *input, const ICLTenso _paddings = paddings; _output = output; + const DataLayout data_layout = input->info()->data_layout(); + const int idx_width = get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH); + const int idx_height = get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT); + const int idx_batch = get_data_layout_dimension_index(data_layout, DataLayoutDimension::BATCHES); + // Create kernel CLBuildOptions build_opts; build_opts.add_option("-DDATA_TYPE=" + get_cl_type_from_data_type(input->info()->data_type())); - build_opts.add_option("-DWIDTH_OUT=" + support::cpp11::to_string(output->info()->dimension(0))); - build_opts.add_option("-DHEIGHT_OUT=" + support::cpp11::to_string(output->info()->dimension(1))); - build_opts.add_option("-DBATCH_SIZE=" + support::cpp11::to_string(output->info()->dimension(3))); - _kernel = static_cast<cl::Kernel>(CLKernelLibrary::get().create_kernel("space_to_batch", build_opts.options())); + build_opts.add_option("-DWIDTH_OUT=" + support::cpp11::to_string(output->info()->dimension(idx_width))); + build_opts.add_option("-DHEIGHT_OUT=" + support::cpp11::to_string(output->info()->dimension(idx_height))); + build_opts.add_option("-DBATCH_SIZE=" + support::cpp11::to_string(output->info()->dimension(idx_batch))); + _kernel = static_cast<cl::Kernel>(CLKernelLibrary::get().create_kernel("space_to_batch_" + lower_string(string_from_data_layout(input->info()->data_layout())), build_opts.options())); // Configure kernel window Window win = calculate_max_window(*output->info(), Steps()); @@ -111,19 +121,24 @@ void CLSpaceToBatchLayerKernel::configure(const ICLTensor *input, const int bloc _input = input; _output = output; + const DataLayout data_layout = input->info()->data_layout(); + const int idx_width = get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH); + const int idx_height = get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT); + const int idx_batch = get_data_layout_dimension_index(data_layout, DataLayoutDimension::BATCHES); + // Create kernel CLBuildOptions build_opts; build_opts.add_option("-DDATA_TYPE=" + get_cl_type_from_data_type(input->info()->data_type())); - build_opts.add_option("-DWIDTH_OUT=" + support::cpp11::to_string(output->info()->dimension(0))); - build_opts.add_option("-DHEIGHT_OUT=" + support::cpp11::to_string(output->info()->dimension(1))); - build_opts.add_option("-DBATCH_SIZE=" + support::cpp11::to_string(output->info()->dimension(3))); + build_opts.add_option("-DWIDTH_OUT=" + support::cpp11::to_string(output->info()->dimension(idx_width))); + build_opts.add_option("-DHEIGHT_OUT=" + support::cpp11::to_string(output->info()->dimension(idx_height))); + build_opts.add_option("-DBATCH_SIZE=" + support::cpp11::to_string(output->info()->dimension(idx_batch))); build_opts.add_option("-DBLOCK_SHAPE_X=" + support::cpp11::to_string(block_shape_x)); build_opts.add_option("-DBLOCK_SHAPE_Y=" + support::cpp11::to_string(block_shape_y)); build_opts.add_option("-DPAD_START_X=" + support::cpp11::to_string(padding_left.x())); build_opts.add_option("-DPAD_END_X=" + support::cpp11::to_string(padding_right.x())); build_opts.add_option("-DPAD_START_Y=" + support::cpp11::to_string(padding_left.y())); build_opts.add_option("-DPAD_END_Y=" + support::cpp11::to_string(padding_right.y())); - _kernel = static_cast<cl::Kernel>(CLKernelLibrary::get().create_kernel("space_to_batch_static", build_opts.options())); + _kernel = static_cast<cl::Kernel>(CLKernelLibrary::get().create_kernel("space_to_batch_static_" + lower_string(string_from_data_layout(input->info()->data_layout())), build_opts.options())); // Configure kernel window Window win = calculate_max_window(*output->info(), Steps()); |