From 9d3a831d4131f8a8b37f127f11d36848d33e8496 Mon Sep 17 00:00:00 2001 From: Michele Di Giorgio Date: Tue, 20 Nov 2018 12:31:24 +0000 Subject: COMPMID-1648: CLNormalizationLayer IN_MAP_2D support for NHWC for FP32/FP16 Change-Id: I49f1d865f5e7562f1d80db849353a89ef77e6a9e --- src/core/CL/kernels/CLNormalizationLayerKernel.cpp | 32 ++++++++++++++++------ 1 file changed, 23 insertions(+), 9 deletions(-) (limited to 'src/core/CL/kernels') diff --git a/src/core/CL/kernels/CLNormalizationLayerKernel.cpp b/src/core/CL/kernels/CLNormalizationLayerKernel.cpp index 67357da7d1..9623ec6a89 100644 --- a/src/core/CL/kernels/CLNormalizationLayerKernel.cpp +++ b/src/core/CL/kernels/CLNormalizationLayerKernel.cpp @@ -37,20 +37,21 @@ using namespace arm_compute; namespace { +constexpr unsigned int num_elems_processed_per_iteration = 4; Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, NormalizationLayerInfo norm_info) { ARM_COMPUTE_RETURN_ERROR_ON_F16_UNSUPPORTED(input); ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F16, DataType::F32); + ARM_COMPUTE_RETURN_ERROR_ON_DATA_LAYOUT_NOT_IN(input, DataLayout::NCHW, DataLayout::NHWC); ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(output); - ARM_COMPUTE_RETURN_ERROR_ON_MSG(input->data_layout() == DataLayout::NHWC && norm_info.type() == NormType::IN_MAP_2D, - "Only Cross-map and 1D In-map normalization is supported for NHWC layout"); ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(norm_info.norm_size() % 2), "Normalization size should be odd"); // Checks performed when output is configured if(output->total_size() != 0) { ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output); + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_LAYOUT(input, output); ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input, output); } @@ -62,8 +63,6 @@ std::pair validate_and_configure_window(ITensorInfo *input, ITen // Output tensor auto initialization if not yet initialized auto_init_if_empty(*output, *input->clone()); - const unsigned int num_elems_processed_per_iteration = 4; - const unsigned int norm_idx = get_normalization_dimension_index(input->data_layout(), norm_info); const bool is_norm_accross_width = norm_idx == 0; @@ -118,15 +117,14 @@ void CLNormalizationLayerKernel::configure(const ICLTensor *input, ICLTensor *ou _input = input; _output = output; - const unsigned int num_elems_processed_per_iteration = 4; - const bool is_in_map_2D = (norm_info.type() == NormType::IN_MAP_2D); - const DataLayout data_layout = input->info()->data_layout(); const unsigned int norm_idx = get_normalization_dimension_index(data_layout, norm_info); _is_norm_across_width = norm_idx == 0; const unsigned int border_width = _is_norm_across_width ? num_elems_processed_per_iteration - 1 : 0; _border_size = BorderSize(0, border_width); + const bool is_in_map_2D = (norm_info.type() == NormType::IN_MAP_2D); + // Set build options CLBuildOptions build_opts; build_opts.add_option(("-DDATA_TYPE=" + get_cl_type_from_data_type(input->info()->data_type()))); @@ -140,8 +138,24 @@ void CLNormalizationLayerKernel::configure(const ICLTensor *input, ICLTensor *ou build_opts.add_option_if(norm_info.is_in_map() || (data_layout == DataLayout::NHWC && norm_info.is_cross_map()), "-DWIDTH_SIZE=" + support::cpp11::to_string(input->info()->dimension(0))); // Create kernel - std::string kernel_name = _is_norm_across_width ? "normalization_layer_in_map" : "normalization_layer_cross_map"; - _kernel = static_cast(CLKernelLibrary::get().create_kernel(kernel_name, build_opts.options())); + std::string kernel_name; + if(norm_info.is_in_map()) + { + kernel_name = "normalization_layer_in_map_" + lower_string(string_from_data_layout(data_layout)); + } + else + { + if(data_layout == DataLayout::NCHW) + { + kernel_name = "normalization_layer_cross_map"; + } + else + { + // 1D Cross-Map normalization in NHWC is the same as 1D In-Map normalization in NCHW + kernel_name = "normalization_layer_in_map_nchw"; + } + } + _kernel = static_cast(CLKernelLibrary::get().create_kernel(kernel_name, build_opts.options())); // Configure kernel window auto win_config = validate_and_configure_window(input->info(), output->info(), norm_info); -- cgit v1.2.1