From e250389ed6d78153a55382fa5b3519c151bfd79f Mon Sep 17 00:00:00 2001 From: Michalis Spyrou Date: Mon, 23 Apr 2018 15:17:31 +0100 Subject: COMPMID-810 Add NHWC data format support for NEON convolution Change-Id: I2a7b49a12da7f3bc3f04749243b1dc111160de6e Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/129348 Tested-by: Jenkins Reviewed-by: Anthony Barbier --- .../NEON/functions/NEGEMMConvolutionLayer.cpp | 242 +++++++++++++-------- 1 file changed, 151 insertions(+), 91 deletions(-) (limited to 'src/runtime/NEON/functions/NEGEMMConvolutionLayer.cpp') diff --git a/src/runtime/NEON/functions/NEGEMMConvolutionLayer.cpp b/src/runtime/NEON/functions/NEGEMMConvolutionLayer.cpp index 5a35463365..a5f30557a0 100644 --- a/src/runtime/NEON/functions/NEGEMMConvolutionLayer.cpp +++ b/src/runtime/NEON/functions/NEGEMMConvolutionLayer.cpp @@ -109,6 +109,14 @@ Status NEConvolutionLayerReshapeWeights::validate(const ITensorInfo *weights, co ARM_COMPUTE_RETURN_ERROR_ON(biases->num_dimensions() > 1); } + // Checks performed when biases are present + if(append_bias) + { + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(weights, biases); + ARM_COMPUTE_RETURN_ERROR_ON(biases->dimension(0) != weights->dimension(3)); + ARM_COMPUTE_RETURN_ERROR_ON(biases->num_dimensions() > 1); + } + if(transpose1xW) { TensorInfo weights_reshaped = weights->clone()->set_tensor_shape(get_reshaped_weights_shape(weights, append_bias)); @@ -159,7 +167,7 @@ TensorShape get_reshaped_weights_shape_conv(const ITensorInfo *weights, bool app Status validate_and_initialize_values(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *biases, const PadStrideInfo &conv_info, const WeightsInfo &weights_info, const ActivationLayerInfo &act_info, DataType &dt, - bool &append_bias, + bool &append_bias, bool &skip_im2col, bool &are_weights_reshaped, unsigned int &kernel_width, unsigned int &kernel_height, bool &is_fully_connected_convolution, bool &is_interleaved, bool &is_quantized, bool &is_activationlayer_enabled, unsigned int &mat_weights_cols, unsigned int &mat_weights_rows, @@ -168,9 +176,17 @@ Status validate_and_initialize_values(const ITensorInfo *input, const ITensorInf ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::QASYMM8, DataType::QS16, DataType::F16, DataType::F32); ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, weights); ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input, weights); - ARM_COMPUTE_RETURN_ERROR_ON(!weights_info.are_reshaped() && weights->dimension(2) != input->dimension(2)); + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_LAYOUT(input, weights); + + DataLayout data_layout = input->data_layout(); + const int idx_width = get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH); + const int idx_height = get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT); + const int idx_channel = get_data_layout_dimension_index(data_layout, DataLayoutDimension::CHANNEL); + + ARM_COMPUTE_RETURN_ERROR_ON(!weights_info.are_reshaped() && weights->dimension(idx_channel) != input->dimension(idx_channel)); ARM_COMPUTE_RETURN_ERROR_ON(weights->num_dimensions() > 4); ARM_COMPUTE_RETURN_ERROR_ON(weights_info.are_reshaped() && is_data_type_quantized_asymmetric(input->data_type())); + ARM_COMPUTE_RETURN_ERROR_ON_MSG(data_layout == DataLayout::NHWC && input->data_type() != DataType::F32, "NHWC is only supported for FP32 data type."); dt = input->data_type(); is_quantized = is_data_type_quantized_asymmetric(dt); @@ -190,14 +206,16 @@ Status validate_and_initialize_values(const ITensorInfo *input, const ITensorInf ARM_COMPUTE_RETURN_ERROR_ON(biases->num_dimensions() > 1); } + // If we have 1x1 convolution and data layout is NHWC we can disable im2col append_bias = (biases != nullptr) && (!is_quantized); are_weights_reshaped = weights_info.are_reshaped(); - kernel_width = (are_weights_reshaped) ? weights_info.kernel_size().first : weights->dimension(0); - kernel_height = (are_weights_reshaped) ? weights_info.kernel_size().second : weights->dimension(1); + kernel_width = (are_weights_reshaped) ? weights_info.kernel_size().first : weights->dimension(idx_width); + kernel_height = (are_weights_reshaped) ? weights_info.kernel_size().second : weights->dimension(idx_height); mat_weights_cols = weights->dimension(3); - mat_weights_rows = weights->dimension(0) * weights->dimension(1) * weights->dimension(2) + (append_bias ? 1 : 0); + mat_weights_rows = weights->dimension(idx_width) * weights->dimension(idx_height) * weights->dimension(idx_channel) + ((append_bias && !skip_im2col) ? 1 : 0); + skip_im2col = (data_layout == DataLayout::NHWC && kernel_width == 1 && kernel_height == 1); - std::tie(conv_w, conv_h) = scaled_dimensions(input->dimension(0), input->dimension(1), kernel_width, kernel_height, + std::tie(conv_w, conv_h) = scaled_dimensions(input->dimension(idx_width), input->dimension(idx_height), kernel_width, kernel_height, conv_info, dilation); // Check if its a "fully connected" convolution @@ -211,9 +229,9 @@ Status validate_and_initialize_values(const ITensorInfo *input, const ITensorInf NEGEMMConvolutionLayer::NEGEMMConvolutionLayer(const std::shared_ptr &memory_manager) : _asm_glue(), _memory_group(memory_manager), _input_im2col_kernel(), _input_interleave_kernel(), _reshape_weights(), _mm_kernel(), _mm_gemmlowp(memory_manager), _gemmlowp_output_stage(), - _output_col2im_kernel(), _activationlayer_function(), _original_weights(nullptr), _input_im2col_reshaped(), _input_interleaved_reshaped(), _weights_reshaped(), _gemm_output(), _tmp_output(), - _workspace(), _B_pretransposed(), _append_bias(false), _is_fully_connected_convolution(false), _are_weights_reshaped(false), _is_quantized(false), _is_interleaved(false), - _is_activationlayer_enabled(false) + _output_col2im_kernel(), _activationlayer_function(), _add_bias_kernel(), _original_weights(nullptr), _input_im2col_reshaped(), _input_interleaved_reshaped(), _weights_reshaped(), _gemm_output(), + _tmp_output(), _workspace(), _B_pretransposed(), _data_layout(DataLayout::NCHW), _append_bias(false), _is_fully_connected_convolution(false), _are_weights_reshaped(false), _is_quantized(false), + _is_interleaved(false), _is_activationlayer_enabled(false), _skip_im2col(false) { } @@ -255,7 +273,13 @@ void NEGEMMConvolutionLayer::configure(const ITensor *input, const ITensor *weig unsigned int conv_w = 0; unsigned int conv_h = 0; - Status status = validate_and_initialize_values(input->info(), weights->info(), (biases == nullptr) ? nullptr : biases->info(), conv_info, weights_info, act_info, dt, _append_bias, + _data_layout = input->info()->data_layout(); + const bool is_nhwc = _data_layout == DataLayout::NHWC; + const int idx_width = get_data_layout_dimension_index(_data_layout, DataLayoutDimension::WIDTH); + const int idx_height = get_data_layout_dimension_index(_data_layout, DataLayoutDimension::HEIGHT); + const int idx_channel = get_data_layout_dimension_index(_data_layout, DataLayoutDimension::CHANNEL); + + Status status = validate_and_initialize_values(input->info(), weights->info(), (biases == nullptr) ? nullptr : biases->info(), conv_info, weights_info, act_info, dt, _append_bias, _skip_im2col, _are_weights_reshaped, kernel_width, kernel_height, _is_fully_connected_convolution, _is_interleaved, _is_quantized, _is_activationlayer_enabled, @@ -272,20 +296,12 @@ void NEGEMMConvolutionLayer::configure(const ITensor *input, const ITensor *weig // Reshape weights if needed if(run_optimised) { - if(_are_weights_reshaped) - { - mat_weights_cols = weights_info.num_kernels(); - mat_weights_rows = weights->info()->dimension(1); - } - else - { - TensorShape reshaped_weights_shape{ mat_weights_cols, mat_weights_rows }; + TensorShape reshaped_weights_shape{ mat_weights_cols, mat_weights_rows }; - // Create tensor to store the reshaped weights - _weights_reshaped.allocator()->init(TensorInfo(reshaped_weights_shape, 1, dt, fixed_point_position)); - _reshape_weights.configure(weights, biases, &_weights_reshaped, false /* 1xW transpose */); - weights = &_weights_reshaped; - } + // Create tensor to store the reshaped weights + _weights_reshaped.allocator()->init(TensorInfo(reshaped_weights_shape, 1, dt, fixed_point_position)); + _reshape_weights.configure(weights, biases, &_weights_reshaped, false /* 1xW transpose */); + weights = &_weights_reshaped; } else { @@ -294,12 +310,12 @@ void NEGEMMConvolutionLayer::configure(const ITensor *input, const ITensor *weig if(_is_fully_connected_convolution || _is_quantized) { mat_weights_cols = weights_info.num_kernels(); - mat_weights_rows = weights->info()->dimension(1); + mat_weights_rows = weights->info()->dimension(idx_height); } else { mat_weights_cols = weights_info.num_kernels(); - mat_weights_rows = weights_info.kernel_size().first * weights_info.kernel_size().second * input->info()->dimension(2) + (_append_bias ? 1 : 0); + mat_weights_rows = weights_info.kernel_size().first * weights_info.kernel_size().second * input->info()->dimension(idx_channel) + (_append_bias ? 1 : 0); } } else @@ -325,48 +341,56 @@ void NEGEMMConvolutionLayer::configure(const ITensor *input, const ITensor *weig } } - // Create tensor to store im2col reshaped inputs - const unsigned int mat_input_cols = mat_weights_rows; - const unsigned int mat_input_rows = conv_w * conv_h; - - TensorShape shape_im2col(input->info()->tensor_shape()); - shape_im2col.set(0, mat_input_cols); - shape_im2col.set(1, mat_input_rows); - shape_im2col.set(2, 1); - _input_im2col_reshaped.allocator()->init(input->info()->clone()->set_is_resizable(true).reset_padding().set_tensor_shape(shape_im2col)); - _memory_group.manage(&_input_im2col_reshaped); + // In case we skip im2col we have to add bias + if(!_skip_im2col) + { + const unsigned int mat_input_cols = mat_weights_rows; + const unsigned int mat_input_rows = conv_w * conv_h; + + // Create tensor to store im2col reshaped inputs + TensorShape shape_im2col(input->info()->tensor_shape()); + shape_im2col.set(0, mat_input_cols); + shape_im2col.set(1, mat_input_rows); + shape_im2col.set(2, 1); + _input_im2col_reshaped.allocator()->init(input->info()->clone()->set_is_resizable(true).reset_padding().set_tensor_shape(shape_im2col)); + _memory_group.manage(&_input_im2col_reshaped); + + // Create tensor (interleave) to prepare input tensor for GEMM + if(!_is_fully_connected_convolution && !run_optimised && _is_interleaved) + { + TensorShape shape_interleaved(shape_im2col); + shape_interleaved.set(idx_width, shape_interleaved.x() * 4); + shape_interleaved.set(idx_height, std::ceil(shape_interleaved[idx_height] / 4.f)); + _input_interleaved_reshaped.allocator()->init(input->info()->clone()->set_is_resizable(true).reset_padding().set_tensor_shape(shape_interleaved)); + _memory_group.manage(&_input_interleaved_reshaped); + } - // Create tensor (interleave) to prepare input tensor for GEMM - if(!_is_fully_connected_convolution && !run_optimised && _is_interleaved) + // Create GEMM output tensor + TensorShape shape_gemm(_input_im2col_reshaped.info()->tensor_shape()); + shape_gemm.set(0, mat_weights_cols); + shape_gemm.set(1, mat_input_rows); + const DataType gemm_data_type = _is_quantized ? DataType::S32 : dt; + // GEMM output should be S32 for acquiring raw integer accumulator without quantized postprocessing for quantized asymmetric input. + TensorInfo info_gemm(shape_gemm, 1, gemm_data_type, input->info()->fixed_point_position()); + info_gemm.set_quantization_info(output->info()->quantization_info()); + _gemm_output.allocator()->init(info_gemm); + + // FIXME: enabling memory manager for _gemm_output gives incorrect results (maybe bound to the assembly kernel in GEMMLowp?) + // _memory_group.manage(&_gemm_output); + + // Configure im2col + _input_im2col_kernel.configure(input, &_input_im2col_reshaped, Size2D(kernel_width, kernel_height), conv_info, _append_bias, false, false, dilation); + } + else if(_append_bias) { - TensorShape shape_interleaved(shape_im2col); - shape_interleaved.set(0, shape_interleaved.x() * 4); - shape_interleaved.set(1, std::ceil(shape_interleaved.y() / 4.f)); - _input_interleaved_reshaped.allocator()->init(input->info()->clone()->set_is_resizable(true).reset_padding().set_tensor_shape(shape_interleaved)); - _memory_group.manage(&_input_interleaved_reshaped); + // Configure add bias kernel + _add_bias_kernel.configure(output, biases, output, ConvertPolicy::SATURATE); } - // Create GEMM output tensor - TensorShape shape_gemm(_input_im2col_reshaped.info()->tensor_shape()); - shape_gemm.set(0, mat_weights_cols); - shape_gemm.set(1, mat_input_rows); - const DataType gemm_data_type = _is_quantized ? DataType::S32 : dt; - // GEMM output should be S32 for acquiring raw integer accumulator without quantized postprocessing for quantized asymmetric input. - TensorInfo info_gemm(shape_gemm, 1, gemm_data_type, input->info()->fixed_point_position()); - info_gemm.set_quantization_info(output->info()->quantization_info()); - _gemm_output.allocator()->init(info_gemm); - - // FIXME: enabling memory manager for _gemm_output gives incorrect results (maybe bound to the assembly kernel in GEMMLowp?) - // _memory_group.manage(&_gemm_output); - - // Configure kernels - // Configure im2col - _input_im2col_kernel.configure(input, &_input_im2col_reshaped, Size2D(kernel_width, kernel_height), conv_info, _append_bias, false, false, dilation); - // Configure matrix multiply if(run_optimised) { - if(!setup_assembly_kernel(&_input_im2col_reshaped, weights, &_gemm_output, 1.f, 0.f, true, _workspace, _B_pretransposed, _memory_group, _asm_glue)) + if(!setup_assembly_kernel(_skip_im2col ? input : &_input_im2col_reshaped, weights, is_nhwc ? output : &_gemm_output, 1.f, 0.f, true, _workspace, _B_pretransposed, _memory_group, _asm_glue)) { ARM_COMPUTE_ERROR("setup_assembly_kernel failed."); } @@ -379,8 +403,8 @@ void NEGEMMConvolutionLayer::configure(const ITensor *input, const ITensor *weig _input_interleave_kernel.configure(&_input_im2col_reshaped, &_input_interleaved_reshaped); // Configure GEMM - configure_mm(&_input_interleaved_reshaped, weights, &_gemm_output, _is_interleaved, GEMMReshapeInfo(_input_im2col_reshaped.info()->dimension(1), 0 /* no transpose */, - _input_im2col_reshaped.info()->dimension(0))); + configure_mm(&_input_interleaved_reshaped, weights, &_gemm_output, _is_interleaved, GEMMReshapeInfo(_input_im2col_reshaped.info()->dimension(idx_height), 0 /* no transpose */, + _input_im2col_reshaped.info()->dimension(idx_width))); _input_interleaved_reshaped.allocator()->allocate(); } else @@ -389,29 +413,36 @@ void NEGEMMConvolutionLayer::configure(const ITensor *input, const ITensor *weig } } - _input_im2col_reshaped.allocator()->allocate(); - - // Configure output stage for quantized case - if(_is_quantized) + if(!_skip_im2col) { - const QuantizationInfo output_quant_info = (output->info()->total_size() == 0) ? input->info()->quantization_info() : output->info()->quantization_info(); + _input_im2col_reshaped.allocator()->allocate(); - float multiplier = input->info()->quantization_info().scale * weights->info()->quantization_info().scale / output_quant_info.scale; - int output_multiplier, output_shift; - quantization::calculate_quantized_multiplier_less_than_one(multiplier, &output_multiplier, &output_shift); - _memory_group.manage(&_tmp_output); - _gemmlowp_output_stage.configure(&_gemm_output, biases, &_tmp_output, output_multiplier, output_shift, output_quant_info.offset); - } + // Configure output stage for quantized case + if(_is_quantized) + { + const QuantizationInfo output_quant_info = (output->info()->total_size() == 0) ? input->info()->quantization_info() : output->info()->quantization_info(); - // Configure Col2Im - _output_col2im_kernel.configure(_is_quantized ? &_tmp_output : &_gemm_output, output, Size2D(conv_w, conv_h)); - if(_is_quantized) - { - _tmp_output.allocator()->allocate(); + float multiplier = input->info()->quantization_info().scale * weights->info()->quantization_info().scale / output_quant_info.scale; + int output_multiplier, output_shift; + quantization::calculate_quantized_multiplier_less_than_one(multiplier, &output_multiplier, &output_shift); + _memory_group.manage(&_tmp_output); + _gemmlowp_output_stage.configure(&_gemm_output, biases, &_tmp_output, output_multiplier, output_shift, output_quant_info.offset); + } + + // Configure Col2Im + if(!is_nhwc) + { + _output_col2im_kernel.configure(_is_quantized ? &_tmp_output : &_gemm_output, output, Size2D(conv_w, conv_h)); + } + + if(_is_quantized) + { + _tmp_output.allocator()->allocate(); + } + _gemm_output.allocator()->allocate(); } - _gemm_output.allocator()->allocate(); - ARM_COMPUTE_ERROR_ON_MSG((output->info()->dimension(0) != conv_w) || (output->info()->dimension(1) != conv_h), "Output shape does not match the expected one"); + ARM_COMPUTE_ERROR_ON_MSG((output->info()->dimension(idx_width) != conv_w) || (output->info()->dimension(idx_height) != conv_h), "Output shape does not match the expected one"); // Allocate intermediate tensor if(!_are_weights_reshaped) @@ -433,6 +464,7 @@ Status NEGEMMConvolutionLayer::validate(const ITensorInfo *input, const ITensorI DataType dt{}; bool append_bias{}; + bool skip_im2col{}; bool are_weights_reshaped{}; bool is_fully_connected_convolution{}; bool is_interleaved{}; @@ -445,7 +477,12 @@ Status NEGEMMConvolutionLayer::validate(const ITensorInfo *input, const ITensorI unsigned int conv_w = 0; unsigned int conv_h = 0; - Status status = validate_and_initialize_values(input, weights, biases, conv_info, weights_info, act_info, dt, append_bias, are_weights_reshaped, kernel_width, kernel_height, + const DataLayout data_layout = input->data_layout(); + const bool is_nhwc = data_layout == DataLayout::NHWC; + const int idx_width = get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH); + const int idx_height = get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT); + + Status status = validate_and_initialize_values(input, weights, biases, conv_info, weights_info, act_info, dt, append_bias, skip_im2col, are_weights_reshaped, kernel_width, kernel_height, is_fully_connected_convolution, is_interleaved, is_quantized, is_activationlayer_enabled, mat_weights_cols, mat_weights_rows, conv_w, conv_h, dilation); @@ -461,7 +498,6 @@ Status NEGEMMConvolutionLayer::validate(const ITensorInfo *input, const ITensorI optimised_kernel = true; } - // Validate im2col const unsigned int mat_input_cols = mat_weights_rows; const unsigned int mat_input_rows = conv_w * conv_h; TensorShape shape_im2col = input->tensor_shape(); @@ -469,7 +505,17 @@ Status NEGEMMConvolutionLayer::validate(const ITensorInfo *input, const ITensorI shape_im2col.set(1, mat_input_rows); shape_im2col.set(2, 1); TensorInfo im2_col_info = input->clone()->set_tensor_shape(shape_im2col); - ARM_COMPUTE_RETURN_ON_ERROR(NEIm2ColKernel::validate(input, &im2_col_info, kernel_weights, conv_info, append_bias, false, false, dilation)); + + if(!skip_im2col) + { + // Validate im2col + ARM_COMPUTE_RETURN_ON_ERROR(NEIm2ColKernel::validate(input, &im2_col_info, kernel_weights, conv_info, append_bias, false, false, dilation)); + } + else if(append_bias) + { + // Validate add bias kernel + ARM_COMPUTE_RETURN_ON_ERROR(NEArithmeticAdditionKernel::validate(output, biases, output, ConvertPolicy::SATURATE)); + } // Create GEMM output tensor TensorShape shape_gemm(im2_col_info.tensor_shape()); @@ -511,8 +557,8 @@ Status NEGEMMConvolutionLayer::validate(const ITensorInfo *input, const ITensorI if(is_interleaved) { TensorShape shape_interleaved = shape_im2col; - shape_interleaved.set(0, shape_interleaved.x() * 4); - shape_interleaved.set(1, std::ceil(shape_interleaved.y() / 4.f)); + shape_interleaved.set(idx_width, shape_interleaved.x() * 4); + shape_interleaved.set(idx_height, std::ceil(shape_interleaved.y() / 4.f)); TensorInfo input_interleaved_info = input->clone()->set_tensor_shape(shape_interleaved); ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMInterleave4x4Kernel::validate(&im2_col_info, &input_interleaved_info)); ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMMatrixMultiplyKernel::validate(&input_interleaved_info, weights, &gemm_output_info, 1.f, is_interleaved, GEMMReshapeInfo(shape_im2col[1], // m @@ -524,10 +570,12 @@ Status NEGEMMConvolutionLayer::validate(const ITensorInfo *input, const ITensorI ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMMatrixMultiplyKernel::validate(&im2_col_info, weights, &gemm_output_info, 1.f, is_interleaved, GEMMReshapeInfo())); } } + if(!is_nhwc) + { + ARM_COMPUTE_RETURN_ON_ERROR(NECol2ImKernel::validate(&gemm_output_info, output, Size2D(conv_w, conv_h))); + } - ARM_COMPUTE_RETURN_ON_ERROR(NECol2ImKernel::validate(&gemm_output_info, output, Size2D(conv_w, conv_h))); - - ARM_COMPUTE_RETURN_ERROR_ON_MSG((output->dimension(0) != conv_w) || (output->dimension(1) != conv_h), "Output shape does not match the expected one"); + ARM_COMPUTE_RETURN_ERROR_ON_MSG((output->dimension(idx_width) != conv_w) || (output->dimension(idx_height) != conv_h), "Output shape does not match the expected one"); if(act_info.enabled()) { @@ -553,8 +601,12 @@ void NEGEMMConvolutionLayer::run() _memory_group.acquire(); - // Run input reshaping - NEScheduler::get().schedule(&_input_im2col_kernel, Window::DimY); + if(!_skip_im2col) + { + // Run input reshaping + unsigned int _y_dim = get_data_layout_dimension_index(_data_layout, DataLayoutDimension::HEIGHT); + NEScheduler::get().schedule(&_input_im2col_kernel, _y_dim); + } // Runs matrix multiply on reshaped matrices if(_asm_glue._optimised_kernel != nullptr) @@ -585,6 +637,11 @@ void NEGEMMConvolutionLayer::run() } } + if(_skip_im2col && _append_bias) + { + NEScheduler::get().schedule(&_add_bias_kernel, Window::DimY); + } + // Run output stage for quantized case if(_is_quantized) { @@ -592,7 +649,10 @@ void NEGEMMConvolutionLayer::run() } // Reshape output matrix - NEScheduler::get().schedule(&_output_col2im_kernel, Window::DimY); + if(_data_layout == DataLayout::NCHW) + { + NEScheduler::get().schedule(&_output_col2im_kernel, Window::DimY); + } if(_is_activationlayer_enabled) { -- cgit v1.2.1